├── requirements.txt ├── temporal ├── api │ ├── __init__.py │ ├── command │ │ ├── __init__.py │ │ └── v1.py │ ├── common │ │ ├── __init__.py │ │ └── v1.py │ ├── enums │ │ └── __init__.py │ ├── failure │ │ ├── __init__.py │ │ └── v1.py │ ├── filter │ │ ├── __init__.py │ │ └── v1.py │ ├── history │ │ └── __init__.py │ ├── query │ │ ├── __init__.py │ │ └── v1.py │ ├── version │ │ ├── __init__.py │ │ └── v1.py │ ├── workflow │ │ ├── __init__.py │ │ └── v1.py │ ├── errordetails │ │ ├── __init__.py │ │ └── v1.py │ ├── namespace │ │ ├── __init__.py │ │ └── v1.py │ ├── replication │ │ ├── __init__.py │ │ └── v1.py │ ├── taskqueue │ │ ├── __init__.py │ │ └── v1.py │ └── workflowservice │ │ └── __init__.py ├── __init__.py ├── constants.py ├── util.py ├── service_helpers.py ├── replay_interceptor.py ├── workerfactory.py ├── decisions.py ├── retry.py ├── async_activity.py ├── converter.py ├── conversions.py ├── exception_handling.py ├── errors.py ├── exceptions.py ├── activity_loop.py ├── activity_method.py ├── marker.py ├── activity.py ├── clock_decision_context.py └── worker.py ├── tests ├── unittests │ ├── __init__.py │ ├── test_data_converter.py │ ├── test_payload.py │ ├── test_failure.py │ └── test_workflow.py ├── test_workflow_return_list.py ├── test_workflow_return_value.py ├── test_sleep.py ├── __init__.py ├── test_start_workflow.py ├── interceptor_testing_utils.py ├── test_workflow_single_argument.py ├── test_timer.py ├── test_workflow_workflow_id_run_id.py ├── test_workflow_multi_argument.py ├── test_workflow_exception.py ├── test_current_time_millis.py ├── test_workflow_now.py ├── conftest.py ├── test_workflow_random_uuid.py ├── test_workflow_get_logger.py ├── test_workflow_new_random.py ├── test_activity_return_list.py ├── test_activity_return_value.py ├── test_cron.py ├── test_signal_arguments.py ├── test_activity_multi_argument.py ├── test_activity_single_argument.py ├── test_signal.py ├── test_start_activity.py ├── test_start_async_activity.py ├── 
test_workflow_get_version_single.py ├── test_workflow_long_history.py ├── test_workflow_untyped_activity_stub.py ├── test_await_till.py ├── test_activity_retry_maximum_attempts.py ├── test_activity_async_all_of.py ├── test_async_any_of_timer.py ├── test_activity_activity_attributes.py ├── test_start_workflow_start_parameters.py ├── test_start_workflow_workflow_method_parameters.py ├── test_activity_exception.py ├── test_do_not_complete_on_return_complete.py ├── test_query.py ├── test_activity_heartbeat.py ├── test_data_converter.py ├── test_activity_async_sync.py ├── test_activity_async_any_of.py ├── test_workflow_get_version_with_update.py ├── test_do_not_complete_on_return_complete_exceptionally.py ├── test_activity_method_activity_options.py ├── test_typed_data_converter.py ├── test_typed_data_converter_query_signal.py └── test_activity_method_activity_options_from_stub.py ├── test-utils └── java-test-client │ ├── .gitignore │ ├── settings.gradle │ ├── lib │ └── py4j0.10.8.1.jar │ ├── gradle │ └── wrapper │ │ ├── gradle-wrapper.jar │ │ └── gradle-wrapper.properties │ ├── src │ └── main │ │ └── java │ │ ├── JavaGateway.java │ │ └── GreetingWorkflow.java │ ├── build.gradle │ ├── gradlew.bat │ └── gradlew ├── MANIFEST.in ├── pytest.ini ├── dev-requirements.txt ├── .gitignore ├── install-deps.sh ├── LICENSE ├── setup.py └── README.md /requirements.txt: -------------------------------------------------------------------------------- 1 | . 
2 | -------------------------------------------------------------------------------- /temporal/api/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unittests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /temporal/api/command/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /temporal/api/common/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /temporal/api/enums/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /temporal/api/failure/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /temporal/api/filter/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /temporal/api/history/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /temporal/api/query/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /temporal/api/version/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /temporal/api/workflow/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /temporal/api/errordetails/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /temporal/api/namespace/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /temporal/api/replication/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /temporal/api/taskqueue/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /temporal/api/workflowservice/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /temporal/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | DEFAULT_VERSION = -1 3 | 4 | -------------------------------------------------------------------------------- /test-utils/java-test-client/.gitignore: -------------------------------------------------------------------------------- 1 | .gradle 2 | out 3 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | exclude cadence/thrift/* 2 | exclude cadence/* 3 | exclude tests/* 4 | 
-------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | markers = 3 | worker_config: config for temporal worker 4 | -------------------------------------------------------------------------------- /test-utils/java-test-client/settings.gradle: -------------------------------------------------------------------------------- 1 | rootProject.name = 'java-test-client' 2 | 3 | -------------------------------------------------------------------------------- /dev-requirements.txt: -------------------------------------------------------------------------------- 1 | py4j==0.10.8.1 2 | pytest==5.2.0 3 | pytest-asyncio==0.10.0 4 | pytest-repeat==0.8.0 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | dist 2 | build 3 | cadence_client.egg-info/ 4 | /__init__.py 5 | temporal-api 6 | dependencies 7 | /temporal_python_sdk.egg-info/ 8 | -------------------------------------------------------------------------------- /test-utils/java-test-client/lib/py4j0.10.8.1.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/firdaus/temporal-python-sdk/HEAD/test-utils/java-test-client/lib/py4j0.10.8.1.jar -------------------------------------------------------------------------------- /test-utils/java-test-client/gradle/wrapper/gradle-wrapper.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/firdaus/temporal-python-sdk/HEAD/test-utils/java-test-client/gradle/wrapper/gradle-wrapper.jar -------------------------------------------------------------------------------- /temporal/constants.py: -------------------------------------------------------------------------------- 1 | 2 | CODE_OK = 0x00 3 | 
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class OpenRequestInfo:
    """Bookkeeping for a single in-flight (open) request in the decision loop.

    Pairs a completion callback with opaque caller-supplied state, mirroring
    the Java SDK's callback/context pattern.
    """

    # Invoked when the request completes (Java BiConsumer equivalent).
    # None until a handler is registered, hence Optional.
    completion_handle: Optional[Callable] = None
    # Arbitrary context the caller wants handed back on completion.
    user_context: Optional[object] = None
def get_identity():
    """Return this worker's identity string in the form ``<pid>@<hostname>``."""
    return f"{os.getpid()}@{socket.gethostname()}"
-------------------------------------------------------------------------------- /temporal/api/replication/v1.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # sources: temporal/api/replication/v1/message.proto 3 | # plugin: python-betterproto 4 | from dataclasses import dataclass 5 | from typing import List 6 | 7 | import betterproto 8 | 9 | 10 | @dataclass 11 | class ClusterReplicationConfig(betterproto.Message): 12 | cluster_name: str = betterproto.string_field(1) 13 | 14 | 15 | @dataclass 16 | class NamespaceReplicationConfig(betterproto.Message): 17 | active_cluster_name: str = betterproto.string_field(1) 18 | clusters: List["ClusterReplicationConfig"] = betterproto.message_field(2) 19 | -------------------------------------------------------------------------------- /temporal/api/version/v1.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
import functools
import inspect
from typing import Callable


def get_replay_aware_interceptor(fn: Callable):
    """Wrap *fn* so it is skipped while the workflow is replaying history.

    Intended for side-effecting calls (e.g. logging) that must execute only
    on the first, non-replay pass. During replay the wrapper returns None.
    """
    @functools.wraps(fn)  # preserve __name__/__doc__ of the wrapped method
    def interceptor(*args, **kwargs):
        # Imported lazily to avoid a circular import with decision_loop.
        from .decision_loop import ITask
        task: ITask = ITask.current()
        if not task.decider.decision_context.is_replaying():
            return fn(*args, **kwargs)
        # Replaying: swallow the call (implicitly returns None).

    return interceptor


def make_replay_aware(target: object) -> object:
    """Replace every bound method on *target* with a replay-aware wrapper.

    Idempotent: a marker attribute prevents wrapping the same object twice.
    Returns *target* for call-chaining convenience.
    """
    # TODO: Consider using metaclasses instead
    if hasattr(target, "_cadence_python_intercepted"):
        return target
    for name, fn in inspect.getmembers(target):
        if inspect.ismethod(fn):
            setattr(target, name, get_replay_aware_interceptor(fn))
    target._cadence_python_intercepted = True  # type: ignore
    return target
WorkflowClient 17 | namespace: str = None 18 | options: WorkerFactoryOptions = None 19 | workers: List[Worker] = field(default_factory=list) 20 | 21 | def new_worker(self, task_queue: str, worker_options: WorkerOptions = None) -> Worker: 22 | worker = Worker(client=self.client, namespace=self.namespace, task_queue=task_queue, options=worker_options) 23 | self.workers.append(worker) 24 | return worker 25 | 26 | def start(self): 27 | for worker in self.workers: 28 | worker.start() 29 | -------------------------------------------------------------------------------- /tests/test_workflow_return_value.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from temporal.workflow import workflow_method, WorkflowClient 4 | 5 | TASK_QUEUE = "test_workflow_return_value_tq" 6 | NAMESPACE = "default" 7 | 8 | 9 | class GreetingWorkflow: 10 | @workflow_method(task_queue=TASK_QUEUE) 11 | async def get_greeting(self) -> None: 12 | raise NotImplementedError 13 | 14 | 15 | class GreetingWorkflowImpl(GreetingWorkflow): 16 | 17 | async def get_greeting(self): 18 | return "from-workflow" 19 | 20 | 21 | @pytest.mark.asyncio 22 | @pytest.mark.worker_config(NAMESPACE, TASK_QUEUE, activities=[], workflows=[GreetingWorkflowImpl]) 23 | async def test(worker): 24 | client = WorkflowClient.new_client(namespace=NAMESPACE) 25 | greeting_workflow: GreetingWorkflow = client.new_workflow_stub(GreetingWorkflow) 26 | return_value = await greeting_workflow.get_greeting() 27 | 28 | assert return_value == "from-workflow" 29 | -------------------------------------------------------------------------------- /temporal/api/query/v1.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # sources: temporal/api/query/v1/message.proto 3 | # plugin: python-betterproto 4 | from dataclasses import dataclass 5 | 6 | import betterproto 7 | 8 | from temporal.api.common import v1 as v1common 9 | from temporal.api.enums import v1 as v1enums 10 | 11 | 12 | @dataclass 13 | class WorkflowQuery(betterproto.Message): 14 | query_type: str = betterproto.string_field(1) 15 | query_args: v1common.Payloads = betterproto.message_field(2) 16 | 17 | 18 | @dataclass 19 | class WorkflowQueryResult(betterproto.Message): 20 | result_type: v1enums.QueryResultType = betterproto.enum_field(1) 21 | answer: v1common.Payloads = betterproto.message_field(2) 22 | error_message: str = betterproto.string_field(3) 23 | 24 | 25 | @dataclass 26 | class QueryRejected(betterproto.Message): 27 | status: v1enums.WorkflowExecutionStatus = betterproto.enum_field(1) 28 | -------------------------------------------------------------------------------- /temporal/api/filter/v1.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
import time

import pytest

from temporal.workflow import workflow_method, WorkflowClient, Workflow

TASK_QUEUE = "test_sleep"
NAMESPACE = "default"


class GreetingWorkflow:
    """Workflow interface: a single method that just sleeps."""

    @workflow_method(task_queue=TASK_QUEUE)
    async def get_greeting(self) -> None:
        raise NotImplementedError


class GreetingWorkflowImpl(GreetingWorkflow):
    """Implementation that blocks on a five-second durable sleep."""

    async def get_greeting(self):
        await Workflow.sleep(5)


@pytest.mark.asyncio
@pytest.mark.worker_config(NAMESPACE, TASK_QUEUE, activities=[], workflows=[GreetingWorkflowImpl])
async def test(worker):
    """Workflow.sleep(5) must make the call take more than 5 wall-clock seconds."""
    client = WorkflowClient.new_client(namespace=NAMESPACE)
    stub: GreetingWorkflow = client.new_workflow_stub(GreetingWorkflow)
    started = time.time()
    await stub.get_greeting()
    elapsed = time.time() - started
    assert elapsed > 5
# Shared test-package plumbing: event-loop teardown helpers for the suite.
import asyncio
import logging

from temporal.worker import Worker
from temporal.workflow import WorkflowClient

logging.basicConfig(level=logging.INFO)

# Workers registered here are stopped via cleanup_worker() during teardown.
workers_to_cleanup = []


def setup_module():
    # No per-module setup needed; kept for pytest symmetry with teardown_module.
    pass


async def monitor_teardown():
    """Periodically print how many tasks remain on the shared event loop.

    NOTE(review): this coroutine never exits; it appears intended purely as a
    progress indicator while teardown_module drains pending tasks — confirm.
    """
    from .conftest import loop
    while True:
        running_tasks = asyncio.all_tasks(loop)
        print(f"running_tasks={len(running_tasks)}")
        await asyncio.sleep(5)


def teardown_module():
    """Drain every task still pending on the event loop shared via conftest."""
    from .conftest import loop
    # Snapshot pending tasks BEFORE starting the monitor, so the monitor's own
    # never-ending task is not included in the gather below.
    pending = asyncio.all_tasks(loop)
    print("Waiting for workers to cleanup....")
    loop.create_task(monitor_teardown())
    loop.run_until_complete(asyncio.gather(*pending))


async def cleanup_worker(client: WorkflowClient, worker: Worker):
    """Stop *worker* (waiting for it to finish) and close *client*."""
    workers_to_cleanup.append(worker)
    await worker.stop(background=False)
    client.close()
import logging
import logging.config

# Shared hit counter: incremented once per log record seen by CounterFilter.
counter_filter_counter = 0


class CounterFilter(logging.Filter):
    """Logging filter that counts every record and lets all of them through."""

    def filter(self, record):
        global counter_filter_counter
        counter_filter_counter = counter_filter_counter + 1
        return True


def reset_counter_filter_counter():
    """Reset the shared hit counter to zero (call between tests)."""
    global counter_filter_counter
    counter_filter_counter = 0


def get_counter_filter_counter():
    """Return how many records have passed through CounterFilter so far."""
    return counter_filter_counter


# dictConfig layout: 'test-logger' writes to a console handler whose records
# all pass through the counting filter above.
LOGGING = {
    'version': 1,
    'filters': {'counter-filter': {'()': CounterFilter}},
    'handlers': {'console': {'class': 'logging.StreamHandler', 'filters': ['counter-filter']}},
    'loggers': {'test-logger': {'level': 'DEBUG', 'handlers': ['console']}},
}
import pytest
import time

from temporal.workflow import workflow_method, WorkflowClient, Workflow

TASK_QUEUE = "test_timer"
NAMESPACE = "default"


class GreetingWorkflow:
    """Workflow interface: a single method that just waits on a Temporal timer."""

    @workflow_method(task_queue=TASK_QUEUE)
    async def get_greeting(self) -> None:
        raise NotImplementedError


class GreetingWorkflowImpl(GreetingWorkflow):
    """Implementation that blocks on a 5-second workflow timer."""

    async def get_greeting(self):
        await Workflow.new_timer(5)


@pytest.mark.asyncio
@pytest.mark.worker_config(NAMESPACE, TASK_QUEUE, activities=[], workflows=[GreetingWorkflowImpl])
async def test(worker):
    client = WorkflowClient.new_client(namespace=NAMESPACE)
    stub: GreetingWorkflow = client.new_workflow_stub(GreetingWorkflow)
    started = time.time()
    await stub.get_greeting()
    elapsed = time.time() - started
    # The 5-second timer inside the workflow must dominate the wall-clock time.
    assert elapsed > 5
    client.close()
def test_get_fn_args_type_hints_no_arguments():
    """A zero-argument callable yields an empty hint list."""
    def hello():
        pass
    assert len(get_fn_args_type_hints(hello)) == 0


def test_get_fn_type_with_annotations():
    """Annotated parameters come back as their declared types, in order."""
    def hello(a: str, b: Dict):
        pass
    hints = get_fn_args_type_hints(hello)
    assert len(hints) == 2
    assert hints[0] is str
    assert hints[1] is Dict


def test_get_fn_type_method():
    """Bound methods skip 'self' and report hints for the remaining parameters."""
    class Person:
        def hello(self, a: str, b: Dict):
            pass
    hints = get_fn_args_type_hints(Person().hello)
    assert len(hints) == 2
    assert hints[0] is str
    assert hints[1] is Dict
str) 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Mohammed Firdaus 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
import pytest

from temporal.workflow import workflow_method, WorkflowClient

TASK_QUEUE = "test_workflow_multi_argument"
NAMESPACE = "default"


class GreetingWorkflow:
    """Workflow interface with a two-argument workflow method."""

    @workflow_method(task_queue=TASK_QUEUE)
    async def get_greeting(self, arg1, arg2) -> None:
        raise NotImplementedError


class GreetingWorkflowImpl(GreetingWorkflow):
    """Captures the received arguments on class attributes for the test to inspect."""

    # Both slots start empty and are filled by the workflow method.
    argument1: object = None
    argument2: object = None

    async def get_greeting(self, arg1, arg2):
        cls = GreetingWorkflowImpl
        cls.argument1 = arg1
        cls.argument2 = arg2


@pytest.mark.asyncio
@pytest.mark.worker_config(NAMESPACE, TASK_QUEUE, activities=[], workflows=[GreetingWorkflowImpl])
async def test(worker):
    client = WorkflowClient.new_client(namespace=NAMESPACE)
    stub: GreetingWorkflow = client.new_workflow_stub(GreetingWorkflow)
    await stub.get_greeting("first-argument", "second-argument")

    assert GreetingWorkflowImpl.argument1 == "first-argument"
    assert GreetingWorkflowImpl.argument2 == "second-argument"
class GreetingWorkflowImpl(GreetingWorkflow):
    """Implementation whose workflow method always fails with GreetingException."""

    async def get_greeting(self):
        raise GreetingException("blah")


@pytest.mark.asyncio
@pytest.mark.worker_config(NAMESPACE, TASK_QUEUE, activities=[], workflows=[GreetingWorkflowImpl])
async def test(worker):
    """A workflow-raised exception must surface as the same type on the client side."""
    client = WorkflowClient.new_client(namespace=NAMESPACE)
    greeting_workflow: GreetingWorkflow = client.new_workflow_stub(GreetingWorkflow)
    # pytest.raises asserts both that an exception is raised and that it is of
    # the expected type, replacing the manual try/except/flag bookkeeping.
    with pytest.raises(GreetingException):
        await greeting_workflow.get_greeting()
from dataclasses import dataclass
from enum import Enum


class DecisionState(Enum):
    """Lifecycle states of a single decision's state machine."""
    CREATED = 1
    DECISION_SENT = 2
    CANCELED_BEFORE_INITIATED = 3
    INITIATED = 4
    STARTED = 5
    CANCELED_AFTER_INITIATED = 6
    CANCELED_AFTER_STARTED = 7
    CANCELLATION_DECISION_SENT = 8
    COMPLETED_AFTER_CANCELLATION_DECISION_SENT = 9
    COMPLETED = 10


class DecisionTarget(Enum):
    """Kind of entity a decision acts upon."""
    ACTIVITY = 1
    CHILD_WORKFLOW = 2
    CANCEL_EXTERNAL_WORKFLOW = 3
    SIGNAL_EXTERNAL_WORKFLOW = 4
    TIMER = 5
    MARKER = 6

    # Probably won't end up using this since the Python version won't have something analagous to
    # CompleteWorkflowStateMachine
    SELF = 7


@dataclass
class DecisionId:
    """Identifies a decision by its target kind and its history event id.

    Instances are hashable and usable as dict/set keys; equality and hashing
    are both derived from the (decision_target, decision_event_id) pair.
    """
    decision_target: DecisionTarget
    decision_event_id: int

    def __str__(self):
        return f"{self.decision_target}:{self.decision_event_id}"

    def __hash__(self):
        # Hash the identifying pair directly so hashing can never drift from
        # __eq__ (the previous str()-based hash coupled it to __str__ formatting).
        return hash((self.decision_target, self.decision_event_id))

    def __eq__(self, other: object):
        # Returning NotImplemented (rather than False) for foreign types lets
        # Python attempt the reflected comparison, per the data-model protocol.
        if not isinstance(other, DecisionId):
            return NotImplemented
        return (self.decision_target == other.decision_target) and \
            (self.decision_event_id == other.decision_event_id)
class GreetingWorkflowImpl(GreetingWorkflow):
    """Samples Workflow.current_time_millis() into a, b, c around three 1s sleeps."""

    async def get_greeting(self):
        global a, b, c
        for bucket in (a, b, c):
            bucket.append(Workflow.current_time_millis())
            await Workflow.sleep(1)


@pytest.mark.asyncio
@pytest.mark.worker_config(NAMESPACE, TASK_QUEUE, activities=[], workflows=[GreetingWorkflowImpl])
async def test(worker):
    global a, b, c
    a, b, c = [], [], []
    client = WorkflowClient.new_client(namespace=NAMESPACE)
    stub: GreetingWorkflow = client.new_workflow_stub(GreetingWorkflow)
    await stub.get_greeting()
    # Replay re-executes the method, appending repeatedly; within one bucket every
    # replayed reading must be identical, and the three buckets must be distinct.
    assert len(a) >= 4 and len(set(a)) == 1
    assert len(b) >= 3 and len(set(b)) == 1
    assert len(c) >= 2 and len(set(c)) == 1
    assert len(set(itertools.chain(a, b, c))) == 3
@pytest.mark.asyncio
@pytest.mark.worker_config(NAMESPACE, TASK_QUEUE, activities=[], workflows=[GreetingWorkflowImpl])
async def test(worker):
    global a, b, c
    a, b, c = [], [], []
    client = WorkflowClient.new_client(namespace=NAMESPACE)
    stub: GreetingWorkflow = client.new_workflow_stub(GreetingWorkflow)
    await stub.get_greeting()
    # Each bucket is appended to once per (re)execution of the workflow method,
    # so replays must see one frozen timestamp per bucket, distinct across buckets.
    assert len(a) >= 4 and len(set(a)) == 1
    assert len(b) >= 3 and len(set(b)) == 1
    assert len(c) >= 2 and len(set(c)) == 1
    assert len(set(itertools.chain(a, b, c))) == 3
    for stamp in itertools.chain(a, b, c):
        assert isinstance(stamp, datetime)
# The one event loop shared by every test in the session.
loop: Optional[AbstractEventLoop] = None


@pytest.fixture
def event_loop():
    """pytest-asyncio hook: reuse a single event loop for the whole session."""
    global loop
    if not loop:
        loop = asyncio.get_event_loop()
    yield loop


@pytest.fixture
async def worker(request):
    """Build and start a Temporal worker from the test's @pytest.mark.worker_config marker.

    Marker shape: positional (namespace, task_queue); keyword activities
    (list of (instance, name) pairs), workflows (list of classes), and an
    optional data_converter.
    """
    marker = request.node.get_closest_marker("worker_config")
    namespace, task_queue = marker.args[0], marker.args[1]
    activities = marker.kwargs.get("activities", [])
    workflows = marker.kwargs.get("workflows", [])
    data_converter = marker.kwargs.get("data_converter", DEFAULT_DATA_CONVERTER_INSTANCE)

    client: WorkflowClient = WorkflowClient.new_client("localhost", 7233, data_converter=data_converter)
    factory = WorkerFactory(client, namespace)
    worker_instance = factory.new_worker(task_queue)
    for impl, name in activities:
        worker_instance.register_activities_implementation(impl, name)
    for workflow_cls in workflows:
        worker_instance.register_workflow_implementation_type(workflow_cls)
    factory.start()

    yield worker_instance

    # Fire-and-forget; the module-level teardown awaits all outstanding tasks.
    asyncio.create_task(cleanup_worker(client, worker_instance))
@pytest.mark.asyncio
@pytest.mark.worker_config(NAMESPACE, TASK_QUEUE, activities=[], workflows=[GreetingWorkflowImpl])
async def test(worker):
    global a, b, c
    a, b, c = [], [], []
    client = WorkflowClient.new_client(namespace=NAMESPACE)
    stub: GreetingWorkflow = client.new_workflow_stub(GreetingWorkflow)
    await stub.get_greeting()
    # Each bucket must hold one deterministic UUID string repeated across replays,
    # and the three buckets must differ from each other.
    assert len(a) >= 4 and len(set(a)) == 1
    assert len(b) >= 3 and len(set(b)) == 1
    assert len(c) >= 2 and len(set(c)) == 1
    assert len(set(itertools.chain(a, b, c))) == 3
    for text in itertools.chain(a, b, c):
        # NOTE(review): UUID(..., version=3) overwrites the version bits rather
        # than validating them, so this mainly checks the string parses as a
        # UUID — confirm whether version-3 output was actually intended.
        UUID(text, version=3)
@pytest.mark.asyncio
@pytest.mark.worker_config(NAMESPACE, TASK_QUEUE, activities=[], workflows=[GreetingWorkflowImpl])
async def test(worker):
    reset_counter_filter_counter()
    logging.config.dictConfig(LOGGING)

    client = WorkflowClient.new_client(namespace=NAMESPACE)
    stub: GreetingWorkflow = client.new_workflow_stub(GreetingWorkflow)
    await stub.get_greeting()
    # The workflow logs exactly five records through 'test-logger'; replays must
    # not duplicate them, so the counting filter should have seen each one once.
    assert get_counter_filter_counter() == 5
import pytest
from datetime import timedelta

from temporal.workflow import workflow_method, WorkflowClient, Workflow
from temporal.activity_method import activity_method

TASK_QUEUE = "test_activity_return_list_tq"
NAMESPACE = "default"


class GreetingActivities:
    """Activity interface for the list-payload round trip."""

    @activity_method(task_queue=TASK_QUEUE, schedule_to_close_timeout=timedelta(seconds=1000))
    async def compose_greeting(self) -> str:
        raise NotImplementedError


class GreetingActivitiesImpl:
    """Implementation returning a list, exercising list payload conversion."""

    async def compose_greeting(self):
        return ["a", "b", "c"]


class GreetingWorkflow:
    """Workflow interface that forwards the activity's result."""

    @workflow_method(task_queue=TASK_QUEUE)
    async def get_greeting(self) -> None:
        raise NotImplementedError


class GreetingWorkflowImpl(GreetingWorkflow):

    def __init__(self):
        self.greeting_activities: GreetingActivities = Workflow.new_activity_stub(GreetingActivities)

    async def get_greeting(self):
        return await self.greeting_activities.compose_greeting()


@pytest.mark.asyncio
@pytest.mark.worker_config(NAMESPACE, TASK_QUEUE, activities=[(GreetingActivitiesImpl(), "GreetingActivities")],
                           workflows=[GreetingWorkflowImpl])
async def test(worker):
    client = WorkflowClient.new_client(namespace=NAMESPACE)
    stub: GreetingWorkflow = client.new_workflow_stub(GreetingWorkflow)
    # The list must survive the activity -> workflow -> client conversion chain intact.
    assert await stub.get_greeting() == ["a", "b", "c"]
import pytest
from datetime import timedelta

from temporal.workflow import workflow_method, WorkflowClient, Workflow
from temporal.activity_method import activity_method

TASK_QUEUE = "test_activity_return_value"
NAMESPACE = "default"


class GreetingActivities:
    """Activity interface for the scalar-return round trip."""

    @activity_method(task_queue=TASK_QUEUE, schedule_to_close_timeout=timedelta(seconds=1000))
    async def compose_greeting(self) -> str:
        raise NotImplementedError


class GreetingActivitiesImpl:
    """Implementation returning a plain string."""

    async def compose_greeting(self):
        return "from-activity"


class GreetingWorkflow:
    """Workflow interface that forwards the activity's result."""

    @workflow_method(task_queue=TASK_QUEUE)
    async def get_greeting(self) -> None:
        raise NotImplementedError


class GreetingWorkflowImpl(GreetingWorkflow):

    def __init__(self):
        self.greeting_activities: GreetingActivities = Workflow.new_activity_stub(GreetingActivities)

    async def get_greeting(self):
        return await self.greeting_activities.compose_greeting()


@pytest.mark.asyncio
@pytest.mark.worker_config(NAMESPACE, TASK_QUEUE, activities=[(GreetingActivitiesImpl(), "GreetingActivities")],
                           workflows=[GreetingWorkflowImpl])
async def test(worker):
    client = WorkflowClient.new_client(namespace=NAMESPACE)
    stub: GreetingWorkflow = client.new_workflow_stub(GreetingWorkflow)
    # The string must pass from the activity through the workflow back to the client.
    assert await stub.get_greeting() == "from-activity"
class GreetingWorkflow:
    """Workflow interface scheduled by an every-minute cron expression."""

    @workflow_method(task_queue=TASK_QUEUE)
    @cron_schedule("*/1 * * * *")
    async def get_greeting(self) -> None:
        raise NotImplementedError


class GreetingWorkflowImpl(GreetingWorkflow):
    """Counts how many times the cron schedule fired."""

    async def get_greeting(self):
        global invoke_count
        invoke_count += 1


@pytest.mark.asyncio
@pytest.mark.worker_config(NAMESPACE, TASK_QUEUE, activities=[], workflows=[GreetingWorkflowImpl])
async def test(worker):
    client = WorkflowClient.new_client(namespace=NAMESPACE)
    stub: GreetingWorkflow = client.new_workflow_stub(GreetingWorkflow)
    context: WorkflowExecutionContext = await client.start(stub.get_greeting)
    # Give the every-minute cron three minutes to fire at least twice.
    await asyncio.sleep(60 * 3)
    assert invoke_count >= 2
    # Terminate the cron workflow so it does not keep running after the test.
    request = TerminateWorkflowExecutionRequest()
    request.namespace = NAMESPACE
    request.identity = get_identity()
    request.workflow_execution = context.workflow_execution
    request.workflow_execution.run_id = None
    response: TerminateWorkflowExecutionResponse = await client.service.terminate_workflow_execution(request=request)
class GreetingWorkflowImpl(GreetingWorkflow):
    """Stores whatever arguments the 'hello' signal delivers."""

    def __init__(self):
        self.signal_arguments = []

    async def hello(self, a, b):
        self.signal_arguments = [a, b]

    async def get_greeting(self):
        # Block until the signal handler has populated the argument list.
        await Workflow.await_till(lambda: self.signal_arguments)
        return self.signal_arguments


@pytest.mark.asyncio
@pytest.mark.worker_config(NAMESPACE, TASK_QUEUE, activities=[], workflows=[GreetingWorkflowImpl])
async def test(worker):
    client = WorkflowClient.new_client(namespace=NAMESPACE)
    stub: GreetingWorkflow = client.new_workflow_stub(GreetingWorkflow)
    context = await WorkflowClient.start(stub.get_greeting)
    stub = client.new_workflow_stub_from_workflow_id(
        GreetingWorkflow, workflow_id=context.workflow_execution.workflow_id)
    await stub.hello("1", 2)
    # Signal arguments must arrive with their original types (str and int here).
    assert await client.wait_for_close(context) == ["1", 2]
class GreetingWorkflowImpl(GreetingWorkflow):
    """Invokes the two-argument activity and returns its result."""

    def __init__(self):
        self.greeting_activities: GreetingActivities = Workflow.new_activity_stub(GreetingActivities)

    async def get_greeting(self):
        return await self.greeting_activities.compose_greeting("a", "b")


@pytest.mark.asyncio
@pytest.mark.worker_config(NAMESPACE, TASK_QUEUE, activities=[(GreetingActivitiesImpl(), "GreetingActivities")],
                           workflows=[GreetingWorkflowImpl])
async def test(worker):
    client = WorkflowClient.new_client(namespace=NAMESPACE)
    stub: GreetingWorkflow = client.new_workflow_stub(GreetingWorkflow)
    # The activity concatenates its two string arguments.
    assert await stub.get_greeting() == "ab"
@pytest.mark.asyncio
@pytest.mark.worker_config(NAMESPACE, TASK_QUEUE, activities=[(GreetingActivitiesImpl(), "GreetingActivities")],
                           workflows=[GreetingWorkflowImpl])
async def test(worker):
    client = WorkflowClient.new_client(namespace=NAMESPACE)
    stub: GreetingWorkflow = client.new_workflow_stub(GreetingWorkflow)
    # The activity echoes its single argument back unchanged.
    assert await stub.get_greeting() == "first-argument"
class GreetingActivitiesImpl:
    """Activity implementation that records (on the class) that it was run."""

    # Flipped by compose_greeting; the test reads it after the workflow finishes.
    activity_method_executed: bool = False

    async def compose_greeting(self):
        # Record execution on the class itself so the test can observe it
        # without holding a reference to this particular instance.
        GreetingActivitiesImpl.activity_method_executed = True
class GreetingActivitiesImpl:
    """Activity implementation whose only effect is setting a class-level flag."""

    # Observed by the test after the workflow completes.
    activity_method_executed: bool = False

    async def compose_greeting(self):
        # Class attribute (not instance attribute): survives after the worker
        # drops its reference to this instance.
        GreetingActivitiesImpl.activity_method_executed = True
# Generated by the protocol buffer compiler. DO NOT EDIT!
# sources: temporal/api/errordetails/v1/message.proto
# plugin: python-betterproto
from dataclasses import dataclass

import betterproto


@dataclass
class NotFoundFailure(betterproto.Message):
    # Cluster routing hints returned when the requested entity is not found.
    current_cluster: str = betterproto.string_field(1)
    active_cluster: str = betterproto.string_field(2)


@dataclass
class WorkflowExecutionAlreadyStartedFailure(betterproto.Message):
    # Identifies the execution that already exists for the same workflow id.
    start_request_id: str = betterproto.string_field(1)
    run_id: str = betterproto.string_field(2)


@dataclass
class NamespaceNotActiveFailure(betterproto.Message):
    # Raised when a request targets a namespace that is active in another cluster.
    namespace: str = betterproto.string_field(1)
    current_cluster: str = betterproto.string_field(2)
    active_cluster: str = betterproto.string_field(3)


@dataclass
class ClientVersionNotSupportedFailure(betterproto.Message):
    client_version: str = betterproto.string_field(1)
    client_impl: str = betterproto.string_field(2)
    supported_versions: str = betterproto.string_field(3)


@dataclass
class FeatureVersionNotSupportedFailure(betterproto.Message):
    feature: str = betterproto.string_field(1)
    feature_version: str = betterproto.string_field(2)
    supported_versions: str = betterproto.string_field(3)


@dataclass
class NamespaceAlreadyExistsFailure(betterproto.Message):
    # Marker failure type; carries no fields.
    pass


@dataclass
class CancellationAlreadyRequestedFailure(betterproto.Message):
    # Marker failure type; carries no fields.
    pass


@dataclass
class QueryFailedFailure(betterproto.Message):
    # Marker failure type; carries no fields.
    pass
INITIAL_DELAY_SECONDS = 3
BACK_OFF_MULTIPLIER = 2
MAX_DELAY_SECONDS = 5 * 60
RESET_DELAY_AFTER_SECONDS = 10 * 60


def retry(logger=None):
    """Decorator factory that retries an async function with exponential backoff.

    The wrapped coroutine is re-invoked after every exception.  The delay starts
    at INITIAL_DELAY_SECONDS, multiplies by BACK_OFF_MULTIPLIER on consecutive
    failures, is capped at MAX_DELAY_SECONDS, and resets to the initial value
    when the previous failure happened more than RESET_DELAY_AFTER_SECONDS ago.
    The loop ends only when the wrapped function returns without raising.

    :param logger: optional logger for debug/error messages.  BUG FIX: the
        original dereferenced ``logger`` unconditionally, so the default of
        ``None`` crashed with AttributeError on the first log call.
    """
    def wrapper(fp):
        async def retry_loop(*args, **kwargs):
            last_failed_time = -1
            delay_seconds = INITIAL_DELAY_SECONDS
            while True:
                try:
                    await fp(*args, **kwargs)
                    if logger is not None:
                        logger.debug("@retry decorated function %s exited, ending retry loop",
                                     fp.__name__)
                    break
                except Exception as ex:
                    now = calendar.timegm(time.gmtime())
                    # A long quiet period means the previous outage is over:
                    # restart the backoff schedule from the beginning.
                    if last_failed_time == -1 or (now - last_failed_time) > RESET_DELAY_AFTER_SECONDS:
                        delay_seconds = INITIAL_DELAY_SECONDS
                    else:
                        delay_seconds = min(delay_seconds * BACK_OFF_MULTIPLIER, MAX_DELAY_SECONDS)
                    last_failed_time = now
                    if logger is not None:
                        logger.error("%s failed: %s, retrying in %d seconds", fp.__name__, ex,
                                     delay_seconds, exc_info=True)
                    await asyncio.sleep(delay_seconds)

        return retry_loop

    return wrapper


if __name__ == "__main__":
    import logging
    logging.basicConfig(level=logging.DEBUG)
    logger = logging.getLogger("retry-test")

    @retry(logger=logger)
    async def main():
        raise Exception("blah")

    asyncio.run(main())
class GreetingWorkflowImpl(GreetingWorkflow):
    """Records Workflow.get_version results before and after a timer fires."""

    async def get_greeting(self):
        global version_found_in_step_1_0, version_found_in_step_1_1, \
            version_found_in_step_2_0, version_found_in_step_2_1

        # Before the timer: both lookups of the same change id should agree.
        version_found_in_step_1_0 = Workflow.get_version("first-item", DEFAULT_VERSION, 2)
        version_found_in_step_1_1 = Workflow.get_version("first-item", DEFAULT_VERSION, 2)
        # Timer forces an additional decision task; the recorded version
        # must stay stable across the replay boundary.
        await Workflow.sleep(60)
        version_found_in_step_2_0 = Workflow.get_version("first-item", DEFAULT_VERSION, 2)
        version_found_in_step_2_1 = Workflow.get_version("first-item", DEFAULT_VERSION, 2)
class GreetingWorkflowImpl(GreetingWorkflow):
    """Invokes the greeting activity 200 times to produce a long history."""

    def __init__(self):
        # Typed stub; each call below becomes one activity task in history.
        self.greeting_activities: GreetingActivities = Workflow.new_activity_stub(GreetingActivities)

    async def get_greeting(self):
        last_index = 0
        for last_index in range(200):
            await self.greeting_activities.compose_greeting(last_index)
        # Returns the final loop index (199), which the test asserts on.
        return last_index
def test_deserialize_exception_with_traceback():
    """Round-trip an exception carrying a traceback through Failure serialization."""
    try:
        raise Exception("blah")
    except Exception as ex:
        e = ex
    assert e.__traceback__ is not None

    # Exception -> Failure proto -> string -> Failure proto must be lossless.
    failure: Failure = serialize_exception(e)
    assert isinstance(failure, Failure)
    encoded = failure_to_str(failure)
    assert isinstance(encoded, str)
    decoded_failure: Failure = str_to_failure(encoded)
    assert isinstance(decoded_failure, Failure)
    assert failure.to_dict() == decoded_failure.to_dict()

    # Deserialization must rebuild an equivalent traceback, not just the message.
    restored = deserialize_exception(decoded_failure)
    assert isinstance(restored, Exception)
    assert restored.__traceback__ is not None
    assert traceback.format_tb(e.__traceback__) == traceback.format_tb(restored.__traceback__)
class GreetingWorkflowImpl(GreetingWorkflow):
    """Invokes an activity through an untyped stub addressed by type name."""

    def __init__(self):
        # Untyped stub: activities are named "<InterfaceClass>::<method>"
        # and options must be supplied explicitly.
        options = ActivityOptions(task_queue=TASK_QUEUE,
                                  schedule_to_close_timeout=timedelta(seconds=1000))
        self.stub = Workflow.new_untyped_activity_stub(activity_options=options)

    async def get_greeting(self):
        return await self.stub.execute("GreetingActivities::compose_greeting", "blah")
@pytest.mark.asyncio
@pytest.mark.worker_config(NAMESPACE, TASK_QUEUE, activities=[], workflows=[GreetingWorkflowImpl])
async def test(worker):
    """Signals the workflow three times; await_till must release only on "done"."""
    global await_returned
    client = WorkflowClient.new_client(namespace=NAMESPACE)
    stub: GreetingWorkflow = client.new_workflow_stub(GreetingWorkflow)
    context = await WorkflowClient.start(stub.get_greeting)
    # Re-bind the stub to the running execution so signals target it.
    stub = client.new_workflow_stub_from_workflow_id(
        GreetingWorkflow, workflow_id=context.workflow_execution.workflow_id)
    # Only the final message satisfies the workflow's await_till predicate.
    await stub.send_message("first")
    await asyncio.sleep(1)
    await stub.send_message("second")
    await asyncio.sleep(1)
    await stub.send_message("done")
    await client.wait_for_close(context)
    assert await_returned
class GreetingWorkflowImpl(GreetingWorkflow):
    """Workflow whose activity stub is configured to retry up to three times."""

    def __init__(self):
        # At most 3 attempts, doubling the backoff between them; the activity
        # succeeds on the third attempt, so exactly 3 invocations are expected.
        params = RetryParameters(backoff_coefficient=2.0, maximum_attempts=3)
        self.greeting_activities: GreetingActivities = Workflow.new_activity_stub(
            GreetingActivities, retry_parameters=params)

    async def get_greeting(self):
        await self.greeting_activities.compose_greeting()
class GreetingWorkflowImpl(GreetingWorkflow):
    """Starts three activity invocations concurrently and awaits all of them."""

    def __init__(self):
        self.greeting_activities: GreetingActivities = Workflow.new_activity_stub(GreetingActivities)

    async def get_greeting(self):
        # Launch all three without awaiting individually, then join on all_of.
        pending = [Async.function(self.greeting_activities.compose_greeting, arg)
                   for arg in (10, 20, 30)]
        await Async.all_of(pending)
@dataclass
class TaskQueueMetadata(betterproto.Message):
    # Wrapped double so an unset rate is distinguishable from 0.0.
    max_tasks_per_second: Optional[float] = betterproto.message_field(
        1, wraps=betterproto.TYPE_DOUBLE
    )


@dataclass
class TaskQueueStatus(betterproto.Message):
    # Snapshot of queue depth and read/ack progress, as reported by the server.
    backlog_count_hint: int = betterproto.int64_field(1)
    read_level: int = betterproto.int64_field(2)
    ack_level: int = betterproto.int64_field(3)
    rate_per_second: float = betterproto.double_field(4)
    task_id_block: "TaskIdBlock" = betterproto.message_field(5)


@dataclass
class TaskIdBlock(betterproto.Message):
    # Inclusive range of task ids currently assigned to the queue owner.
    start_id: int = betterproto.int64_field(1)
    end_id: int = betterproto.int64_field(2)


@dataclass
class TaskQueuePartitionMetadata(betterproto.Message):
    key: str = betterproto.string_field(1)
    owner_host_name: str = betterproto.string_field(2)


@dataclass
class PollerInfo(betterproto.Message):
    # Unix Nano
    last_access_time: datetime = betterproto.message_field(1)
    identity: str = betterproto.string_field(2)
    rate_per_second: float = betterproto.double_field(3)
class GreetingWorkflowImpl(GreetingWorkflow):
    """Races a 5-second timer against an activity that never completes."""

    def __init__(self):
        self.greeting_activities: GreetingActivities = Workflow.new_activity_stub(GreetingActivities)

    async def get_greeting(self):
        global done, timer
        timer = Workflow.new_timer(5)
        # The activity calls do_not_complete_on_return, so the timer must win.
        racers = [
            Async.function(self.greeting_activities.compose_greeting, 50),
            timer,
        ]
        done, pending = await Async.any_of(racers)
class GreetingActivitiesImpl:
    """Captures the ambient activity-context attributes into module globals."""

    activity_method_executed: bool = False

    async def compose_greeting(self):
        global namespace, task_token, workflow_execution
        # Snapshot the activity context so the test can assert on it after
        # the workflow has completed.
        namespace = Activity.get_namespace()
        task_token = Activity.get_task_token()
        workflow_execution = Activity.get_workflow_execution()
@pytest.mark.asyncio
@pytest.mark.worker_config(NAMESPACE, TASK_QUEUE, activities=[(GreetingActivitiesImpl(), "GreetingActivities")],
                           workflows=[GreetingWorkflowImpl])
async def test(worker):
    """Runs the workflow and verifies the activity saw its execution context."""
    client = WorkflowClient.new_client(namespace=NAMESPACE)
    workflow_stub: GreetingWorkflow = client.new_workflow_stub(GreetingWorkflow)
    await workflow_stub.get_greeting()

    # The activity should have captured all three context attributes.
    assert namespace == "default"
    assert task_token is not None
    assert workflow_execution is not None
    assert workflow_execution.workflow_id is not None
    assert workflow_execution.run_id is not None
class GreetingWorkflow:
    # Workflow interface whose @workflow_method carries every start option
    # inline (id, reuse policy, timeouts, memo, search attributes), so the
    # server-recorded WorkflowExecutionStarted attributes can be asserted on.
    @workflow_method(task_queue=TASK_QUEUE,
                     workflow_id=workflow_id,
                     workflow_id_reuse_policy=WorkflowIdReusePolicy.WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE,
                     workflow_execution_timeout=timedelta(seconds=1000),
                     workflow_run_timeout=timedelta(seconds=500),
                     workflow_task_timeout=timedelta(seconds=30),
                     memo={"name": "bob"},
                     search_attributes={})
    async def get_greeting(self) -> None:
        # Interface stub only; GreetingWorkflowImpl provides the body.
        raise NotImplementedError
class GreetingWorkflowImpl(GreetingWorkflow):
    """Calls the activity with retries disabled and stashes whatever it raises."""

    def __init__(self):
        self.greeting_activities: GreetingActivities = Workflow.new_activity_stub(
            GreetingActivities, retry_parameters=RetryParameters(maximum_attempts=1))

    async def get_greeting(self):
        global caught_exception
        try:
            await self.greeting_activities.compose_greeting()
        except Exception as activity_error:
            # Captured for the test body to inspect after the workflow returns.
            caught_exception = activity_error
def greeting_activities_thread_func(task_token):
    """Complete the activity out-of-band roughly 10 seconds later.

    Runs on a plain ``threading.Thread``, so it spins up its own event loop
    via ``asyncio.run`` instead of touching the worker's loop.

    :param task_token: opaque token identifying the activity to complete.
    """
    async def fn():
        client = WorkflowClient.new_client(namespace=NAMESPACE)
        activity_completion_client = client.new_activity_completion_client()
        # Was a blocking time.sleep(10); inside a coroutine the non-blocking
        # form is the correct idiom (same 10s delay, identical net effect
        # since this loop runs nothing else).
        await asyncio.sleep(10)
        return_value = "from-activity-completion-client"
        await activity_completion_client.complete(task_token, return_value)
        client.close()

    asyncio.run(fn())
class GreetingWorkflowImpl(GreetingWorkflow):
    """Workflow implementation whose query methods expose its current status."""

    def __init__(self):
        # Query-visible state; None until get_greeting starts executing.
        self.status = None

    async def get_status(self):
        # Plain scalar query result.
        return self.status

    async def get_status_as_list(self):
        # Same status, wrapped so the test covers list-typed query results.
        return [self.status]

    async def get_status_as_list_with_args(self, a, b):
        # Covers query methods that take arguments; echoes them back.
        return [self.status, a, b]

    async def get_greeting(self):
        global workflow_started
        self.status = "STARTED"
        # Signals the test process that queries may begin.
        workflow_started = True
        await Workflow.sleep(60)
        self.status = "DONE"
@dataclass
class NamespaceInfo(betterproto.Message):
    # NOTE: generated by python-betterproto from
    # temporal/api/namespace/v1/message.proto -- regenerate rather than hand-edit.
    name: str = betterproto.string_field(1)
    state: v1enums.NamespaceState = betterproto.enum_field(2)
    description: str = betterproto.string_field(3)
    owner_email: str = betterproto.string_field(4)
    # A key-value map for any customized purpose.
    data: Dict[str, str] = betterproto.map_field(
        5, betterproto.TYPE_STRING, betterproto.TYPE_STRING
    )
    id: str = betterproto.string_field(6)
@dataclass
class BadBinaryInfo(betterproto.Message):
    # NOTE: generated by python-betterproto; regenerate rather than hand-edit.
    # Human-entered explanation of why the binary was marked bad.
    reason: str = betterproto.string_field(1)
    # Identity of whoever marked it.
    operator: str = betterproto.string_field(2)
    create_time: datetime = betterproto.message_field(3)
class GreetingActivitiesImpl:
    """First attempt records a heartbeat then fails; the retry reads it back."""

    def __init__(self):
        self.heartbeated = False

    async def compose_greeting(self):
        global captured_heartbeat_value
        if self.heartbeated:
            # Retry attempt: the server hands back the last recorded heartbeat.
            captured_heartbeat_value = Activity.get_heartbeat_details()
            return captured_heartbeat_value
        await Activity.heartbeat(HEARTBEAT_VALUE)
        self.heartbeated = True
        raise Exception("Blah")
    @staticmethod
    def call(self, parameters, args: List[object]):
        """Schedule an activity task and return its future.

        ``self`` here is the activity *stub* object (despite the
        @staticmethod), and ``parameters`` is a private copy the callers
        (function/function_with_self) deep-copied, so mutating it is safe.
        """
        # Stub-level options/retry settings override whatever the method carried.
        if hasattr(self, "_activity_options") and self._activity_options:
            self._activity_options.fill_execute_activity_parameters(parameters)
        if self._retry_parameters:
            parameters.retry_parameters = self._retry_parameters
        # Local import (presumably to dodge an import cycle) -- TODO confirm.
        from temporal.decision_loop import DecisionContext
        decision_context: DecisionContext = self._decision_context
        # Serialize the call arguments with the client's configured data converter.
        parameters.input = decision_context.decider.worker.client.data_converter.to_payloads(args)
        return decision_context.schedule_activity_task(parameters=parameters)
def any_of(futures: List[Union[ActivityFuture, Future]], timeout_seconds=0): 34 | done, pending = [], [] 35 | 36 | def condition(): 37 | done[:] = [] 38 | pending[:] = [] 39 | for f in futures: 40 | if f.done(): 41 | done.append(f) 42 | else: 43 | pending.append(f) 44 | if done: 45 | return True 46 | else: 47 | return False 48 | 49 | await Workflow.await_till(condition, timeout_seconds=timeout_seconds) 50 | return done, pending 51 | 52 | @staticmethod 53 | async def all_of(futures: List[Union[ActivityFuture, Future]], timeout_seconds=0): 54 | 55 | def condition(): 56 | for f in futures: 57 | if not f.done(): 58 | return False 59 | return True 60 | 61 | await Workflow.await_till(condition, timeout_seconds=timeout_seconds) 62 | 63 | 64 | from temporal.workflow import Workflow 65 | -------------------------------------------------------------------------------- /test-utils/java-test-client/gradlew.bat: -------------------------------------------------------------------------------- 1 | @if "%DEBUG%" == "" @echo off 2 | @rem ########################################################################## 3 | @rem 4 | @rem Gradle startup script for Windows 5 | @rem 6 | @rem ########################################################################## 7 | 8 | @rem Set local scope for the variables with windows NT shell 9 | if "%OS%"=="Windows_NT" setlocal 10 | 11 | set DIRNAME=%~dp0 12 | if "%DIRNAME%" == "" set DIRNAME=. 13 | set APP_BASE_NAME=%~n0 14 | set APP_HOME=%DIRNAME% 15 | 16 | @rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. 17 | set DEFAULT_JVM_OPTS= 18 | 19 | @rem Find java.exe 20 | if defined JAVA_HOME goto findJavaFromJavaHome 21 | 22 | set JAVA_EXE=java.exe 23 | %JAVA_EXE% -version >NUL 2>&1 24 | if "%ERRORLEVEL%" == "0" goto init 25 | 26 | echo. 27 | echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 28 | echo. 
29 | echo Please set the JAVA_HOME variable in your environment to match the 30 | echo location of your Java installation. 31 | 32 | goto fail 33 | 34 | :findJavaFromJavaHome 35 | set JAVA_HOME=%JAVA_HOME:"=% 36 | set JAVA_EXE=%JAVA_HOME%/bin/java.exe 37 | 38 | if exist "%JAVA_EXE%" goto init 39 | 40 | echo. 41 | echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 42 | echo. 43 | echo Please set the JAVA_HOME variable in your environment to match the 44 | echo location of your Java installation. 45 | 46 | goto fail 47 | 48 | :init 49 | @rem Get command-line arguments, handling Windows variants 50 | 51 | if not "%OS%" == "Windows_NT" goto win9xME_args 52 | 53 | :win9xME_args 54 | @rem Slurp the command line arguments. 55 | set CMD_LINE_ARGS= 56 | set _SKIP=2 57 | 58 | :win9xME_args_slurp 59 | if "x%~1" == "x" goto execute 60 | 61 | set CMD_LINE_ARGS=%* 62 | 63 | :execute 64 | @rem Setup the command line 65 | 66 | set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar 67 | 68 | @rem Execute Gradle 69 | "%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% 70 | 71 | :end 72 | @rem End local scope for the variables with windows NT shell 73 | if "%ERRORLEVEL%"=="0" goto mainEnd 74 | 75 | :fail 76 | rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of 77 | rem the _cmd.exe /c_ return code! 
@dataclass
class Greeting:
    """Round-trip value object for the pickle data-converter test."""
    # Both fields participate in the generated __eq__/__repr__.
    name: str
    age: int
pickle.dumps(arg) 56 | return payload 57 | 58 | def from_payload(self, payload: Payload, type_hint: type = None) -> object: 59 | obj = pickle.loads(payload.data) 60 | return obj 61 | 62 | 63 | @pytest.mark.asyncio 64 | @pytest.mark.worker_config(NAMESPACE, TASK_QUEUE, activities=[(GreetingActivitiesImpl(), "GreetingActivities")], 65 | workflows=[GreetingWorkflowImpl], data_converter=PickleDataConverter()) 66 | async def test(worker): 67 | client = WorkflowClient.new_client(namespace=NAMESPACE, data_converter=PickleDataConverter()) 68 | greeting_workflow: GreetingWorkflow = client.new_workflow_stub(GreetingWorkflow) 69 | ret_value = await greeting_workflow.get_greeting() 70 | 71 | assert ret_value == Greeting("Bob", 20) 72 | client.close() 73 | -------------------------------------------------------------------------------- /tests/test_activity_async_sync.py: -------------------------------------------------------------------------------- 1 | import threading 2 | 3 | import pytest 4 | from datetime import timedelta 5 | import asyncio 6 | 7 | from temporal.activity import Activity 8 | from temporal.async_activity import Async 9 | from temporal.decision_loop import ActivityFuture 10 | from temporal.workflow import workflow_method, WorkflowClient, Workflow 11 | from temporal.activity_method import activity_method 12 | 13 | 14 | TASK_QUEUE = "test_activity_async_sync_tq" 15 | NAMESPACE = "default" 16 | executed = False 17 | 18 | 19 | def greeting_activities_thread_func(task_token, sleep_seconds): 20 | async def fn(): 21 | client = WorkflowClient.new_client(namespace=NAMESPACE) 22 | activity_completion_client = client.new_activity_completion_client() 23 | print(f"Sleeping for {sleep_seconds} seconds") 24 | await asyncio.sleep(sleep_seconds) 25 | await activity_completion_client.complete(task_token, sleep_seconds) 26 | client.close() 27 | 28 | asyncio.run(fn()) 29 | 30 | 31 | class GreetingActivities: 32 | @activity_method(task_queue=TASK_QUEUE, 
class GreetingWorkflowImpl(GreetingWorkflow):
    """Launches the activity asynchronously and waits on its ActivityFuture."""

    def __init__(self):
        self.greeting_activities: GreetingActivities = Workflow.new_activity_stub(GreetingActivities)

    async def get_greeting(self):
        future: ActivityFuture = Async.function(self.greeting_activities.compose_greeting, 10)
        await future
        # get_result() yields whatever the completion client supplied.
        return future.get_result()
def test_create_search_attributes():
    """Search-attribute values are JSON-encoded into indexed_fields payloads."""
    attrs = {"name": "bob", "age": 20}
    search_attributes = create_search_attributes(attrs, DEFAULT_DATA_CONVERTER_INSTANCE)
    fields = search_attributes.indexed_fields
    assert "name" in fields
    assert "age" in fields
    assert fields["name"].data == b'"bob"'
    assert fields["age"].data == b'20'
def get_fn_args_type_hints(fn: Callable) -> List[type]:
    """Return the declared annotation for each positional parameter of ``fn``.

    Parameters without an annotation map to ``None``. For bound methods the
    implicit ``self``/``cls`` slot is dropped so the hints line up with the
    arguments callers actually pass.

    :param fn: a plain function or bound method.
    :return: one entry per caller-visible positional parameter.
    """
    spec = inspect.getfullargspec(fn)
    # Idiomatic comprehension replaces the manual append loop.
    hints = [spec.annotations.get(name) for name in spec.args]
    if inspect.ismethod(fn):
        # getfullargspec still reports "self" for bound methods; drop its slot.
        hints = hints[1:]
    return hints
ENCODINGS: 60 | payload = fn(arg) 61 | if payload is not None: 62 | return payload 63 | raise Exception(f"Object cannot be encoded: {arg}") 64 | 65 | def from_payload(self, payload: Payload, type_hint: type = None) -> object: 66 | encoding: bytes = payload.metadata[METADATA_ENCODING_KEY] 67 | decoding = DECODINGS.get(encoding) 68 | if not decoding: 69 | raise Exception(f"Unsupported encoding: {str(encoding, 'utf-8')}") 70 | return decoding(payload) 71 | 72 | 73 | DEFAULT_DATA_CONVERTER_INSTANCE = DefaultDataConverter() 74 | -------------------------------------------------------------------------------- /tests/test_activity_async_any_of.py: -------------------------------------------------------------------------------- 1 | import threading 2 | 3 | import pytest 4 | from datetime import timedelta 5 | import asyncio 6 | 7 | from temporal.activity import Activity 8 | from temporal.async_activity import Async 9 | from temporal.workflow import workflow_method, WorkflowClient, Workflow 10 | from temporal.activity_method import activity_method 11 | 12 | 13 | TASK_QUEUE = "test_activity_async_any_of_tq" 14 | NAMESPACE = "default" 15 | executed = False 16 | 17 | 18 | def greeting_activities_thread_func(task_token, sleep_seconds): 19 | async def fn(): 20 | client = WorkflowClient.new_client(namespace=NAMESPACE) 21 | activity_completion_client = client.new_activity_completion_client() 22 | print(f"Sleeping for {sleep_seconds} seconds") 23 | await asyncio.sleep(sleep_seconds) 24 | await activity_completion_client.complete(task_token, sleep_seconds) 25 | client.close() 26 | 27 | asyncio.run(fn()) 28 | 29 | 30 | class GreetingActivities: 31 | @activity_method(task_queue=TASK_QUEUE, schedule_to_close_timeout=timedelta(seconds=1000)) 32 | async def compose_greeting(self, arg) -> str: 33 | raise NotImplementedError 34 | 35 | 36 | class GreetingActivitiesImpl: 37 | 38 | async def compose_greeting(self, sleep_seconds): 39 | Activity.do_not_complete_on_return() 40 | thread = 
class GreetingWorkflowImpl(GreetingWorkflow):
    """Races three activity invocations; returns the first (10s) result."""

    def __init__(self):
        self.greeting_activities: GreetingActivities = Workflow.new_activity_stub(GreetingActivities)

    async def get_greeting(self):
        compose = self.greeting_activities.compose_greeting
        futures = [Async.function(compose, delay) for delay in (10, 20, 30)]
        done, _pending = await Async.any_of(futures)
        return done[0].get_result()
@dataclass
class Payload(betterproto.Message):
    # NOTE: generated by python-betterproto from
    # temporal/api/common/v1/message.proto -- regenerate rather than hand-edit.
    # Codec metadata; converters key decoding off an "encoding" entry here
    # (see temporal.conversions.METADATA_ENCODING_KEY).
    metadata: Dict[str, bytes] = betterproto.map_field(
        1, betterproto.TYPE_STRING, betterproto.TYPE_BYTES
    )
    # The serialized value itself.
    data: bytes = betterproto.bytes_field(2)
class TestWorkflowGetVersion:
    """Workflow interface stub; concrete V1/V2 impls are registered with the worker."""

    @workflow_method(task_queue=TASK_QUEUE)
    async def get_greetings(self) -> list:
        raise NotImplementedError
class TestWorkflowGetVersionImplV2(TestWorkflowGetVersion):
    """V2 implementation: probes Workflow.get_version twice before and twice
    after a timer, recording each result in module-level globals so the test
    body can assert on them after the workflow completes."""

    def __init__(self):
        pass

    async def get_greetings(self):
        global v2_hits, v2_done
        global version_found_in_v2_step_1_0, version_found_in_v2_step_1_1
        global version_found_in_v2_step_2_0, version_found_in_v2_step_2_1

        v2_hits += 1

        # Probe the change id before the timer fires.
        version_found_in_v2_step_1_0 = Workflow.get_version("first-item", DEFAULT_VERSION, 2)
        version_found_in_v2_step_1_1 = Workflow.get_version("first-item", DEFAULT_VERSION, 2)
        await Workflow.sleep(60)
        # Probe again after replay past the timer; values should be unchanged.
        version_found_in_v2_step_2_0 = Workflow.get_version("first-item", DEFAULT_VERSION, 2)
        version_found_in_v2_step_2_1 = Workflow.get_version("first-item", DEFAULT_VERSION, 2)
        v2_done = True
class GreetingException(Exception):
    """Application-level error raised by the test activity via
    complete_exceptionally().

    The previous explicit ``__init__(self, *args)`` merely forwarded to
    ``Exception.__init__`` and has been removed as redundant boilerplate;
    construction, ``args``, and ``str()`` behavior are unchanged, which also
    keeps ``deserialize_exception`` (``klass(*args)``) working.
    """
def camel_to_snake(name):
    """Convert a CamelCase identifier to snake_case.

    Two passes: first split an uppercase-led word off whatever precedes it,
    then split lower/digit-to-upper boundaries, finally lowercase everything.
    """
    partially_split = re.sub(r'(.)([A-Z][a-z]+)', r'\1_\2', name)
    fully_split = re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', partially_split)
    return fully_split.lower()
def snake_to_title(snake_str):
    """Convert snake_case to TitleCase (UpperCamelCase), capitalizing every
    underscore-separated component including the first."""
    return ''.join(part.title() for part in snake_str.split('_'))
def decode_json_string(payload: Payload) -> object:
    """Decode a json/plain payload: interpret its bytes as UTF-8 JSON text
    and parse it into a Python object."""
    # TODO (parity with the Java SDK's ObjectMapper configuration):
    # mapper.configure(SerializationFeature.WRITE_DATES_AS_TIMESTAMPS, false);
    # mapper.registerModule(new JavaTimeModule());
    # mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
    text = payload.data.decode("utf-8")
    return json.loads(text)
def import_class_from_string(path):
    """Resolve a dotted path such as 'pkg.mod.ClassName' to the attribute it
    names, importing the module on demand.

    Based on: https://stackoverflow.com/a/30042585
    """
    from importlib import import_module
    module_name, _, attr_name = path.rpartition('.')
    module = import_module(module_name)
    return getattr(module, attr_name)
def str_to_failure(s: str) -> Failure:
    """Parse a JSON string (as produced by failure_to_str) back into a
    Failure protobuf message."""
    failure = Failure()
    return failure.from_json(s)
@dataclass
class ChildWorkflowExecutionFailureInfo(betterproto.Message):
    # Generated from temporal/api/failure/v1/message.proto by python-betterproto;
    # do not edit by hand -- field numbers must stay in sync with the schema.
    namespace: str = betterproto.string_field(1)
    workflow_execution: v1common.WorkflowExecution = betterproto.message_field(2)
    workflow_type: v1common.WorkflowType = betterproto.message_field(3)
    initiated_event_id: int = betterproto.int64_field(4)
    started_event_id: int = betterproto.int64_field(5)
    retry_state: v1enums.RetryState = betterproto.enum_field(6)
@dataclass
class WorkflowExecutionAlreadyStartedError(Exception):
    """Raised when starting a workflow whose workflow id is already running.

    Field names keep the wire-level camelCase spelling expected by
    find_error(); snake_case properties are provided for Pythonic access.
    """
    message: Optional[str]
    startRequestId: Optional[str]
    runId: Optional[str]

    def __str__(self):
        # Consistent with BadRequestError / EntityNotExistsError in this
        # module. Without this, str() of a dataclass-based Exception is empty:
        # the generated __init__ never populates Exception.args.
        return self.message or ""

    @property
    def start_request_id(self):
        return self.startRequestId

    @property
    def run_id(self):
        return self.runId
def find_error(response):
    """Scan a service response for any known error field.

    Checks each attribute named in CADENCE_ERROR_FIELDS; for the first one
    that is set, builds the mapped exception class from the error object's
    public attributes (skipping IGNORE_FIELDS_IN_ERRORS and dunders).
    Returns the exception instance, or None when no error field is present.
    """
    for attr_name, error_cls in CADENCE_ERROR_FIELDS.items():
        raw_error = getattr(response, attr_name, None)
        if not raw_error:
            continue
        field_values = {
            name: getattr(raw_error, name)
            for name in dir(raw_error)
            if name not in IGNORE_FIELDS_IN_ERRORS and not name.startswith("__")
        }
        return error_cls(**field_values)
    return None
@pytest.mark.asyncio
@pytest.mark.worker_config(NAMESPACE, TASK_QUEUE, activities=[(GreetingActivitiesImpl(), "GreetingActivities")],
                           workflows=[GreetingWorkflowImpl])
async def test(worker):
    # End-to-end check: every option declared on the @activity_method stub
    # (timeouts and retry policy) must be recorded verbatim in the
    # ACTIVITY_TASK_SCHEDULED event of the workflow history.
    client = WorkflowClient.new_client(namespace=NAMESPACE)
    greeting_workflow: GreetingWorkflow = client.new_workflow_stub(GreetingWorkflow)
    await greeting_workflow.get_greeting()
    # Fetch the finished workflow's history directly from the service using
    # the fixed workflow_id generated at module import time.
    request = GetWorkflowExecutionHistoryRequest(namespace=NAMESPACE,
                                                 execution=WorkflowExecution(workflow_id=workflow_id))
    response = await client.service.get_workflow_execution_history(request=request)
    # First activity-scheduled event carries the options under test.
    e = next(filter(lambda v: v.event_type == EventType.EVENT_TYPE_ACTIVITY_TASK_SCHEDULED, response.history.events))
    assert e.activity_task_scheduled_event_attributes.schedule_to_close_timeout == timedelta(seconds=1000)
    assert e.activity_task_scheduled_event_attributes.schedule_to_start_timeout == timedelta(seconds=500)
    assert e.activity_task_scheduled_event_attributes.start_to_close_timeout == timedelta(seconds=800)
    assert e.activity_task_scheduled_event_attributes.heartbeat_timeout == timedelta(seconds=600)
    assert e.activity_task_scheduled_event_attributes.retry_policy.initial_interval == timedelta(seconds=70)
    assert e.activity_task_scheduled_event_attributes.retry_policy.backoff_coefficient == 5.0
    assert e.activity_task_scheduled_event_attributes.retry_policy.maximum_interval == timedelta(seconds=700)
    assert e.activity_task_scheduled_event_attributes.retry_policy.maximum_attempts == 8
    assert e.activity_task_scheduled_event_attributes.retry_policy.non_retryable_error_types == ["DummyError"]
@dataclass
class WorkflowExecutionConfig(betterproto.Message):
    # Generated from temporal/api/workflow/v1/message.proto by python-betterproto;
    # do not edit by hand -- field numbers must stay in sync with the schema.
    task_queue: v1taskqueue.TaskQueue = betterproto.message_field(1)
    workflow_execution_timeout: timedelta = betterproto.message_field(2)
    workflow_run_timeout: timedelta = betterproto.message_field(3)
    default_workflow_task_timeout: timedelta = betterproto.message_field(4)
@dataclass
class PendingChildExecutionInfo(betterproto.Message):
    # Generated from temporal/api/workflow/v1/message.proto by python-betterproto;
    # do not edit by hand -- field numbers must stay in sync with the schema.
    workflow_id: str = betterproto.string_field(1)
    run_id: str = betterproto.string_field(2)
    workflow_type_name: str = betterproto.string_field(3)
    initiated_id: int = betterproto.int64_field(4)
    # Default: PARENT_CLOSE_POLICY_TERMINATE.
    parent_close_policy: v1enums.ParentClosePolicy = betterproto.enum_field(5)
class ActivityTaskTimeoutException(Exception):
    """Raised when an activity task times out; carries the history event id,
    the timeout kind, and any last-heartbeat details for inspection."""

    def __init__(self, event_id: int = None, timeout_type: TimeoutType = None, details: bytes = None, *args: object) -> None:
        super().__init__(*args)
        # Mirror the constructor arguments onto the instance.
        self.event_id = event_id
        self.timeout_type = timeout_type
        self.details = details
@dataclass
class WorkflowException(Exception):
    """Base exception for workflow failures; identifies the failed run by
    workflow type, workflow id, and run id."""
    # Annotation is quoted so it is evaluated lazily; behavior is unchanged.
    workflow_type: str = None
    execution: "WorkflowExecution" = None

    def __str__(self):
        # Fix: the original f-string dropped the closing quote after RunID
        # and left a stray trailing space, producing a malformed message.
        return f'{type(self).__name__}: WorkflowType="{self.workflow_type}", ' \
               f'WorkflowID="{self.execution.workflow_id}", RunID="{self.execution.run_id}"'
class GreetingActivitiesImpl:
    # Activity implementation for the typed-data-converter test: ignores its
    # arguments and returns the module-level ACTIVITY_RET_TYPE sentinel so
    # the test can track it through the custom converter.

    async def compose_greeting(self, arg1: ActivityArgType1, arg2: ActivityArgType2) -> ActivityRetType:
        return ACTIVITY_RET_TYPE
| async def get_greeting(self, arg1: WorkflowArgType1, arg2: WorkflowArgType2) -> WorkflowRetType: 70 | raise NotImplementedError 71 | 72 | 73 | class GreetingWorkflowImpl(GreetingWorkflow): 74 | 75 | def __init__(self): 76 | self.greeting_activities: GreetingActivities = Workflow.new_activity_stub(GreetingActivities) 77 | 78 | async def get_greeting(self, arg1: WorkflowArgType1, arg2: WorkflowArgType2) -> WorkflowRetType: 79 | ret_value: ActivityRetType = await self.greeting_activities.compose_greeting(ACTIVITY_ARG_TYPE_1, 80 | ACTIVITY_ARG_TYPE_2) 81 | return WORKFLOW_RET_TYPE 82 | 83 | 84 | deserialized = [] 85 | 86 | 87 | class PickleDataConverter(DataConverter): 88 | 89 | def to_payload(self, arg: object) -> Payload: 90 | payload = Payload() 91 | payload.metadata = {METADATA_ENCODING_KEY: b"PYTHON_PICKLE"} 92 | payload.data = pickle.dumps(arg) 93 | return payload 94 | 95 | def from_payload(self, payload: Payload, type_hint: type = None) -> object: 96 | obj = pickle.loads(payload.data) 97 | deserialized.append((type_hint, obj)) 98 | return obj 99 | 100 | 101 | @pytest.mark.asyncio 102 | @pytest.mark.worker_config(NAMESPACE, TASK_QUEUE, activities=[(GreetingActivitiesImpl(), "GreetingActivities")], 103 | workflows=[GreetingWorkflowImpl], data_converter=PickleDataConverter()) 104 | async def test(worker): 105 | client = WorkflowClient.new_client(namespace=NAMESPACE, data_converter=PickleDataConverter()) 106 | greeting_workflow: GreetingWorkflow = client.new_workflow_stub(GreetingWorkflow) 107 | await greeting_workflow.get_greeting(WORKFLOW_ARG_TYPE_1, WORKFLOW_ARG_TYPE_2) 108 | for t, obj in deserialized: 109 | assert isinstance(obj, t) 110 | assert deserialized == [ 111 | # Invoking workflow 112 | (WorkflowArgType1, WORKFLOW_ARG_TYPE_1), 113 | (WorkflowArgType2, WORKFLOW_ARG_TYPE_2), 114 | # Invoking activity 115 | (ActivityArgType1, ACTIVITY_ARG_TYPE_1), 116 | (ActivityArgType2, ACTIVITY_ARG_TYPE_2), 117 | # Replaying workflow with activity result 118 | 
(WorkflowArgType1, WORKFLOW_ARG_TYPE_1), 119 | (WorkflowArgType2, WORKFLOW_ARG_TYPE_2), 120 | (ActivityRetType, ACTIVITY_RET_TYPE), 121 | # Processing workflow return value 122 | (WorkflowRetType, WORKFLOW_RET_TYPE) 123 | ] 124 | -------------------------------------------------------------------------------- /tests/test_typed_data_converter_query_signal.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import pickle 3 | from dataclasses import dataclass 4 | 5 | import pytest 6 | 7 | from temporal.api.common.v1 import Payload 8 | from temporal.conversions import METADATA_ENCODING_KEY 9 | from temporal.converter import DataConverter 10 | from temporal.workflow import workflow_method, WorkflowClient, Workflow, query_method, signal_method 11 | 12 | TASK_QUEUE = "test_typed_data_converter_query_signal" 13 | NAMESPACE = "default" 14 | workflow_started = False 15 | 16 | 17 | @dataclass 18 | class QueryArgType1: 19 | pass 20 | 21 | 22 | @dataclass 23 | class QueryArgType2: 24 | pass 25 | 26 | 27 | @dataclass 28 | class QueryRetType: 29 | pass 30 | 31 | 32 | @dataclass 33 | class SignalArgType1: 34 | pass 35 | 36 | 37 | @dataclass 38 | class SignalArgType2: 39 | pass 40 | 41 | 42 | QUERY_ARG_TYPE_1 = QueryArgType1() 43 | QUERY_ARG_TYPE_2 = QueryArgType2() 44 | QUERY_RET_TYPE = QueryRetType() 45 | SIGNAL_ARG_TYPE_1 = SignalArgType1() 46 | SIGNAL_ARG_TYPE_2 = SignalArgType2() 47 | 48 | 49 | class GreetingWorkflow: 50 | 51 | @query_method 52 | async def get_status(self, arg1: QueryArgType1, arg2: QueryArgType2) -> QueryRetType: 53 | raise NotImplementedError 54 | 55 | @signal_method 56 | async def push_status(self, arg1: SignalArgType1, arg2: SignalArgType2): 57 | raise NotImplementedError 58 | 59 | @workflow_method(task_queue=TASK_QUEUE) 60 | async def get_greeting(self) -> None: 61 | raise NotImplementedError 62 | 63 | 64 | class GreetingWorkflowImpl(GreetingWorkflow): 65 | 66 | @query_method 67 | async def 
get_status(self, arg1: QueryArgType1, arg2: QueryArgType2) -> QueryRetType: 68 | return QUERY_RET_TYPE 69 | 70 | @signal_method 71 | async def push_status(self, arg1: SignalArgType1, arg2: SignalArgType2): 72 | pass 73 | 74 | async def get_greeting(self): 75 | global workflow_started 76 | workflow_started = True 77 | await Workflow.sleep(10) 78 | 79 | 80 | deserialized = [] 81 | 82 | 83 | class PickleDataConverter(DataConverter): 84 | 85 | def to_payload(self, arg: object) -> Payload: 86 | payload = Payload() 87 | payload.metadata = {METADATA_ENCODING_KEY: b"PYTHON_PICKLE"} 88 | payload.data = pickle.dumps(arg) 89 | return payload 90 | 91 | def from_payload(self, payload: Payload, type_hint: type = None) -> object: 92 | obj = pickle.loads(payload.data) 93 | if type_hint: 94 | deserialized.append((type_hint, obj)) 95 | return obj 96 | 97 | 98 | @pytest.mark.asyncio 99 | @pytest.mark.worker_config(NAMESPACE, TASK_QUEUE, activities=[], workflows=[GreetingWorkflowImpl], 100 | data_converter=PickleDataConverter()) 101 | async def test(worker): 102 | global workflow_started 103 | client = WorkflowClient.new_client(namespace=NAMESPACE, data_converter=PickleDataConverter()) 104 | greeting_workflow: GreetingWorkflow = client.new_workflow_stub(GreetingWorkflow) 105 | context = await WorkflowClient.start(greeting_workflow.get_greeting) 106 | while not workflow_started: 107 | await asyncio.sleep(2) 108 | greeting_workflow = client.new_workflow_stub_from_workflow_id(GreetingWorkflow, 109 | workflow_id=context.workflow_execution.workflow_id) 110 | await greeting_workflow.push_status(SIGNAL_ARG_TYPE_1, SIGNAL_ARG_TYPE_2) 111 | await greeting_workflow.get_status(QUERY_ARG_TYPE_1, QUERY_ARG_TYPE_2) 112 | await client.wait_for_close(context) 113 | assert deserialized == [ 114 | # Signal invocation 115 | (SignalArgType1, SIGNAL_ARG_TYPE_1), 116 | (SignalArgType2, SIGNAL_ARG_TYPE_2), 117 | # Replay signal to invoke query 118 | (SignalArgType1, SIGNAL_ARG_TYPE_1), 119 | 
(SignalArgType2, SIGNAL_ARG_TYPE_2), 120 | (QueryArgType1, QUERY_ARG_TYPE_1), 121 | (QueryArgType2, QUERY_ARG_TYPE_2), 122 | # Query return type being deserialized 123 | (QueryRetType, QUERY_RET_TYPE), 124 | # Workflow replayed after sleep is over 125 | (SignalArgType1, SIGNAL_ARG_TYPE_1), 126 | (SignalArgType2, SIGNAL_ARG_TYPE_2) 127 | ] 128 | -------------------------------------------------------------------------------- /tests/test_activity_method_activity_options_from_stub.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | 3 | import pytest 4 | from datetime import timedelta 5 | 6 | from temporal.api.common.v1 import WorkflowExecution 7 | from temporal.api.enums.v1 import EventType 8 | from temporal.api.workflowservice.v1 import GetWorkflowExecutionHistoryRequest 9 | from temporal.workflow import workflow_method, WorkflowClient, Workflow 10 | from temporal.activity_method import activity_method, RetryParameters, ActivityOptions 11 | 12 | TASK_QUEUE = "test_activity_method_activity_options_from_stub" 13 | NAMESPACE = "default" 14 | workflow_id = "test_activity_method_activity_options_from_stub-" + str(uuid.uuid4()) 15 | 16 | 17 | class GreetingActivities: 18 | @activity_method(task_queue=TASK_QUEUE) 19 | async def compose_greeting(self) -> str: 20 | raise NotImplementedError 21 | 22 | 23 | class GreetingActivitiesImpl: 24 | 25 | async def compose_greeting(self): 26 | pass 27 | 28 | 29 | class GreetingWorkflow: 30 | @workflow_method(task_queue=TASK_QUEUE, workflow_id=workflow_id) 31 | async def get_greeting(self) -> None: 32 | raise NotImplementedError 33 | 34 | 35 | class GreetingWorkflowImpl(GreetingWorkflow): 36 | 37 | def __init__(self): 38 | self.greeting_activities: GreetingActivities 39 | self.greeting_activities = Workflow.new_activity_stub(GreetingActivities, 40 | activity_options=ActivityOptions( 41 | schedule_to_close_timeout=timedelta(seconds=1000), 42 | 
schedule_to_start_timeout=timedelta(seconds=500), 43 | start_to_close_timeout=timedelta(seconds=800), 44 | heartbeat_timeout=timedelta(seconds=600)), 45 | retry_parameters=RetryParameters( 46 | initial_interval=timedelta(seconds=70), 47 | backoff_coefficient=5.0, 48 | maximum_interval=timedelta(seconds=700), 49 | maximum_attempts=8, 50 | non_retryable_error_types=["DummyError"])) 51 | async def get_greeting(self): 52 | await self.greeting_activities.compose_greeting() 53 | 54 | 55 | @pytest.mark.asyncio 56 | @pytest.mark.worker_config(NAMESPACE, TASK_QUEUE, activities=[(GreetingActivitiesImpl(), "GreetingActivities")], 57 | workflows=[GreetingWorkflowImpl]) 58 | async def test(worker): 59 | client = WorkflowClient.new_client(namespace=NAMESPACE) 60 | greeting_workflow: GreetingWorkflow = client.new_workflow_stub(GreetingWorkflow) 61 | await greeting_workflow.get_greeting() 62 | request = GetWorkflowExecutionHistoryRequest(namespace=NAMESPACE, 63 | execution=WorkflowExecution(workflow_id=workflow_id)) 64 | response = await client.service.get_workflow_execution_history(request=request) 65 | e = next(filter(lambda v: v.event_type == EventType.EVENT_TYPE_ACTIVITY_TASK_SCHEDULED, response.history.events)) 66 | assert e.activity_task_scheduled_event_attributes.schedule_to_close_timeout == timedelta(seconds=1000) 67 | assert e.activity_task_scheduled_event_attributes.schedule_to_start_timeout == timedelta(seconds=500) 68 | assert e.activity_task_scheduled_event_attributes.start_to_close_timeout == timedelta(seconds=800) 69 | assert e.activity_task_scheduled_event_attributes.heartbeat_timeout == timedelta(seconds=600) 70 | assert e.activity_task_scheduled_event_attributes.retry_policy.initial_interval == timedelta(seconds=70) 71 | assert e.activity_task_scheduled_event_attributes.retry_policy.backoff_coefficient == 5.0 72 | assert e.activity_task_scheduled_event_attributes.retry_policy.maximum_interval == timedelta(seconds=700) 73 | assert 
e.activity_task_scheduled_event_attributes.retry_policy.maximum_attempts == 8 74 | assert e.activity_task_scheduled_event_attributes.retry_policy.non_retryable_error_types == ["DummyError"] 75 | -------------------------------------------------------------------------------- /temporal/activity_loop.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import json 4 | import inspect 5 | from typing import List 6 | 7 | from grpclib import GRPCError 8 | 9 | from temporal.activity import ActivityContext, ActivityTask, complete_exceptionally, complete 10 | from temporal.api.taskqueue.v1 import TaskQueue, TaskQueueMetadata 11 | from temporal.converter import get_fn_args_type_hints 12 | from temporal.retry import retry 13 | from temporal.service_helpers import get_identity 14 | from temporal.worker import Worker, StopRequestedException 15 | from temporal.api.workflowservice.v1 import WorkflowServiceStub as WorkflowService, PollActivityTaskQueueRequest, \ 16 | PollActivityTaskQueueResponse 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | @retry(logger=logger) 22 | async def activity_task_loop_func(worker: Worker): 23 | service: WorkflowService = worker.client.service 24 | logger.info(f"Activity task worker started: {get_identity()}") 25 | try: 26 | while True: 27 | if worker.is_stop_requested(): 28 | return 29 | try: 30 | polling_start = datetime.datetime.now() 31 | polling_request: PollActivityTaskQueueRequest = PollActivityTaskQueueRequest() 32 | polling_request.task_queue_metadata = TaskQueueMetadata() 33 | polling_request.task_queue_metadata.max_tasks_per_second = 200000 34 | polling_request.namespace = worker.namespace 35 | polling_request.identity = get_identity() 36 | polling_request.task_queue = TaskQueue() 37 | polling_request.task_queue.name = worker.task_queue 38 | task: PollActivityTaskQueueResponse 39 | task = await service.poll_activity_task_queue(request=polling_request) 
40 | polling_end = datetime.datetime.now() 41 | logger.debug("PollActivityTaskQueue: %dms", (polling_end - polling_start).total_seconds() * 1000) 42 | except StopRequestedException: 43 | return 44 | except GRPCError as ex: 45 | logger.error("Error invoking poll_activity_task_queue: %s", ex, exc_info=True) 46 | continue 47 | task_token = task.task_token 48 | if not task_token: 49 | logger.debug("PollActivityTaskQueue has no task_token (expected): %s", task) 50 | continue 51 | 52 | logger.info(f"Request for activity: {task.activity_type.name}") 53 | fn = worker.activities.get(task.activity_type.name) 54 | if not fn: 55 | logger.error("Activity type not found: " + task.activity_type.name) 56 | continue 57 | 58 | args: List[object] = worker.client.data_converter.from_payloads(task.input, 59 | get_fn_args_type_hints(fn)) 60 | 61 | process_start = datetime.datetime.now() 62 | activity_context = ActivityContext() 63 | activity_context.client = worker.client 64 | activity_context.activity_task = ActivityTask.from_poll_for_activity_task_response(task) 65 | activity_context.namespace = worker.namespace 66 | try: 67 | ActivityContext.set(activity_context) 68 | if inspect.iscoroutinefunction(fn): 69 | return_value = await fn(*args) 70 | else: 71 | raise Exception(f"Activity method {fn.__module__}.{fn.__qualname__} should be a coroutine") 72 | if activity_context.do_not_complete: 73 | logger.info(f"Not completing activity {task.activity_type.name}({str(args)[1:-1]})") 74 | continue 75 | 76 | logger.info( 77 | f"Activity {task.activity_type.name}({str(args)[1:-1]}) returned {return_value}") 78 | 79 | try: 80 | await complete(worker.client, task_token, return_value) 81 | except GRPCError as ex: 82 | logger.error("Error invoking respond_activity_task_completed: %s", ex, exc_info=True) 83 | except Exception as ex: 84 | logger.error(f"Activity {task.activity_type.name} failed: {type(ex).__name__}({ex})", exc_info=True) 85 | try: 86 | await complete_exceptionally(worker.client, 
task_token, ex) 87 | except GRPCError as ex2: 88 | logger.error("Error invoking respond_activity_task_failed: %s", ex2, exc_info=True) 89 | finally: 90 | ActivityContext.set(None) 91 | process_end = datetime.datetime.now() 92 | logger.info("Process ActivityTask: %dms", (process_end - process_start).total_seconds() * 1000) 93 | finally: 94 | worker.notify_thread_stopped() 95 | logger.info("Activity loop ended") 96 | -------------------------------------------------------------------------------- /temporal/activity_method.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from dataclasses import dataclass, field 3 | from datetime import timedelta 4 | from typing import Callable, List 5 | 6 | from temporal.api.common.v1 import RetryPolicy, ActivityType, Payloads 7 | 8 | 9 | 10 | def get_activity_method_name(method: Callable): 11 | return "::".join(method.__qualname__.split(".")[-2:]) 12 | 13 | 14 | @dataclass 15 | class RetryParameters: 16 | initial_interval: timedelta = None 17 | backoff_coefficient: float = None 18 | maximum_interval: timedelta = None 19 | maximum_attempts: int = None 20 | non_retryable_error_types: List[str] = field(default_factory=list) 21 | 22 | def to_retry_policy(self) -> RetryPolicy: 23 | policy = RetryPolicy() 24 | policy.initial_interval = self.initial_interval 25 | policy.backoff_coefficient = self.backoff_coefficient 26 | policy.maximum_interval = self.maximum_interval 27 | policy.maximum_attempts = self.maximum_attempts 28 | policy.non_retryable_error_types = self.non_retryable_error_types 29 | return policy 30 | 31 | 32 | @dataclass 33 | class ExecuteActivityParameters: 34 | fn: Callable = None 35 | activity_id: str = "" 36 | activity_type: ActivityType = None 37 | heartbeat_timeout: timedelta = None 38 | input: Payloads = None 39 | schedule_to_close_timeout: timedelta = None 40 | schedule_to_start_timeout: timedelta = None 41 | start_to_close_timeout: timedelta = None 42 | 
task_queue: str = "" 43 | retry_parameters: RetryParameters = None 44 | 45 | 46 | def activity_method(func: Callable = None, name: str = "", schedule_to_close_timeout: timedelta = None, 47 | schedule_to_start_timeout: timedelta = None, start_to_close_timeout: timedelta = None, 48 | heartbeat_timeout: timedelta = None, task_queue: str = "", retry_parameters: RetryParameters = None): 49 | def wrapper(fn: Callable): 50 | # noinspection PyProtectedMember 51 | async def stub_activity_fn(self, *args): 52 | from .async_activity import Async 53 | from .decision_loop import ActivityFuture 54 | future: ActivityFuture = Async.function_with_self(stub_activity_fn, self, *args) 55 | return await future.wait_for_result() 56 | 57 | if not task_queue: 58 | raise Exception("task_queue parameter is mandatory") 59 | 60 | execute_parameters = ExecuteActivityParameters() 61 | execute_parameters.fn = fn 62 | execute_parameters.activity_type = ActivityType() 63 | execute_parameters.activity_type.name = name if name else get_activity_method_name(fn) 64 | execute_parameters.schedule_to_close_timeout = schedule_to_close_timeout 65 | execute_parameters.schedule_to_start_timeout = schedule_to_start_timeout 66 | execute_parameters.start_to_close_timeout = start_to_close_timeout 67 | execute_parameters.heartbeat_timeout = heartbeat_timeout 68 | execute_parameters.task_queue = task_queue 69 | execute_parameters.retry_parameters = retry_parameters 70 | # noinspection PyTypeHints 71 | stub_activity_fn._execute_parameters = execute_parameters # type: ignore 72 | fn.stub_activity_fn = stub_activity_fn 73 | return fn 74 | 75 | if func and inspect.isfunction(func): 76 | raise Exception("activity_method must be called with arguments") 77 | else: 78 | return wrapper 79 | 80 | 81 | @dataclass 82 | class ActivityOptions: 83 | schedule_to_close_timeout: timedelta = None 84 | schedule_to_start_timeout: timedelta = None 85 | start_to_close_timeout: timedelta = None 86 | heartbeat_timeout: timedelta = None 87 
| task_queue: str = None 88 | 89 | def fill_execute_activity_parameters(self, execute_parameters: ExecuteActivityParameters): 90 | if self.schedule_to_close_timeout is not None: 91 | execute_parameters.schedule_to_close_timeout = self.schedule_to_close_timeout 92 | if self.schedule_to_start_timeout is not None: 93 | execute_parameters.schedule_to_start_timeout = self.schedule_to_start_timeout 94 | if self.start_to_close_timeout is not None: 95 | execute_parameters.start_to_close_timeout = self.start_to_close_timeout 96 | if self.heartbeat_timeout is not None: 97 | execute_parameters.heartbeat_timeout = self.heartbeat_timeout 98 | if self.task_queue is not None: 99 | execute_parameters.task_queue = self.task_queue 100 | 101 | 102 | @dataclass 103 | class UntypedActivityStub: 104 | _decision_context: object = None 105 | _retry_parameters: RetryParameters = None 106 | _activity_options: ActivityOptions = None 107 | 108 | async def execute(self, activity_name: str, *args): 109 | f = await self.execute_async(activity_name, *args) 110 | return await f.wait_for_result() 111 | 112 | async def execute_async(self, activity_name: str, *args): 113 | from .async_activity import Async 114 | execute_parameters = ExecuteActivityParameters() 115 | execute_parameters.activity_type = ActivityType() 116 | execute_parameters.activity_type.name = activity_name 117 | return Async.call(self, execute_parameters, args) 118 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ---- 2 | NOTE: I'm no longer working on this SDK. I will be refocusing my efforts on downstream tools. 3 | ---- 4 | 5 | # Unofficial Python SDK for the Temporal Workflow Engine 6 | 7 | ## Status 8 | 9 | This should be considered EXPERIMENTAL at the moment. At the moment, all I can say is that the [test cases](https://gist.github.com/firdaus/4ec442f2c626122ad0c8d379a7ffd8bc) currently pass. 
I have not tested this for any real world use cases yet. 10 | 11 | ## Installation 12 | 13 | ``` 14 | pip install temporal-python-sdk 15 | ``` 16 | ## Sample Code 17 | 18 | Sample code for using this library can be found in [Workflows in Python Using Temporal](https://onepointzero.app/workflows-in-python-using-temporal/). 19 | 20 | ## Hello World 21 | 22 | ```python 23 | import asyncio 24 | import logging 25 | from datetime import timedelta 26 | 27 | from temporal.activity_method import activity_method 28 | from temporal.workerfactory import WorkerFactory 29 | from temporal.workflow import workflow_method, Workflow, WorkflowClient 30 | 31 | logging.basicConfig(level=logging.INFO) 32 | 33 | TASK_QUEUE = "HelloActivity-python-tq" 34 | NAMESPACE = "default" 35 | 36 | # Activities Interface 37 | class GreetingActivities: 38 | @activity_method(task_queue=TASK_QUEUE, schedule_to_close_timeout=timedelta(seconds=1000)) 39 | async def compose_greeting(self, greeting: str, name: str) -> str: 40 | raise NotImplementedError 41 | 42 | 43 | # Activities Implementation 44 | class GreetingActivitiesImpl: 45 | async def compose_greeting(self, greeting: str, name: str): 46 | return greeting + " " + name + "!" 
47 | 48 | 49 | # Workflow Interface 50 | class GreetingWorkflow: 51 | @workflow_method(task_queue=TASK_QUEUE) 52 | async def get_greeting(self, name: str) -> str: 53 | raise NotImplementedError 54 | 55 | 56 | # Workflow Implementation 57 | class GreetingWorkflowImpl(GreetingWorkflow): 58 | 59 | def __init__(self): 60 | self.greeting_activities: GreetingActivities = Workflow.new_activity_stub(GreetingActivities) 61 | pass 62 | 63 | async def get_greeting(self, name): 64 | return await self.greeting_activities.compose_greeting("Hello", name) 65 | 66 | 67 | async def client_main(): 68 | client = WorkflowClient.new_client(namespace=NAMESPACE) 69 | 70 | factory = WorkerFactory(client, NAMESPACE) 71 | worker = factory.new_worker(TASK_QUEUE) 72 | worker.register_activities_implementation(GreetingActivitiesImpl(), "GreetingActivities") 73 | worker.register_workflow_implementation_type(GreetingWorkflowImpl) 74 | factory.start() 75 | 76 | greeting_workflow: GreetingWorkflow = client.new_workflow_stub(GreetingWorkflow) 77 | result = await greeting_workflow.get_greeting("Python") 78 | print(result) 79 | print("Stopping workers.....") 80 | await worker.stop() 81 | print("Workers stopped......") 82 | 83 | if __name__ == '__main__': 84 | asyncio.run(client_main()) 85 | ``` 86 | 87 | ## Roadmap 88 | 89 | 1.0 90 | - [x] Workflow argument passing and return values 91 | - [x] Activity invocation 92 | - [x] Activity heartbeat and Activity.getHeartbeatDetails() 93 | - [x] doNotCompleteOnReturn 94 | - [x] ActivityCompletionClient 95 | - [x] complete 96 | - [x] complete_exceptionally 97 | - [x] Activity get_namespace(), get_task_token() get_workflow_execution() 98 | - [x] Activity Retry 99 | - [x] Activity Failure Exceptions 100 | - [x] workflow_execution_timeout / workflow_run_timeout / workflow_task_timeout 101 | - [x] Workflow exceptions 102 | - [x] Cron workflows 103 | - [x] Workflow static methods: 104 | - [x] await_till() 105 | - [x] sleep() 106 | - [x] current_time_millis() 107 | 
- [x] now() 108 | - [x] random_uuid() 109 | - [x] new_random() 110 | - [x] get_workflow_id() 111 | - [x] get_run_id() 112 | - [x] get_version() 113 | - [x] get_logger() 114 | - [x] Activity invocation parameters 115 | - [x] Query method 116 | - [x] Signal methods 117 | - [x] Workflow start parameters - workflow_id etc... 118 | - [x] Workflow client - starting workflows synchronously 119 | - [x] Workflow client - starting workflows asynchronously (WorkflowClient.start) 120 | - [x] Get workflow result after async execution (client.wait_for_close) 121 | - [x] Workflow client - invoking signals 122 | - [x] Workflow client - invoking queries 123 | 124 | 1.1 125 | - [x] ActivityStub and Workflow.newUntypedActivityStub 126 | - [x] Remove threading, use coroutines for all concurrency 127 | - [x] Classes as arguments and return values to/from activity and workflow methods (DataConverter) 128 | - [x] Type hints for DataConverter 129 | - [x] Parallel activity execution (STATUS: there's a working but not finalized API). 
130 | 131 | 1.2 132 | - [x] Timers 133 | - [x] Custom workflow ids through start() and new_workflow_stub() 134 | 135 | 136 | Other: 137 | - [ ] WorkflowStub and WorkflowClient.newUntypedWorkflowStub 138 | - [ ] ContinueAsNew 139 | - [ ] Sticky workflows 140 | - [ ] Child Workflows 141 | - [ ] Support for keyword arguments 142 | - [ ] Compatibility with Java client 143 | - [ ] Compatibility with Golang client 144 | - [ ] Upgrade python-betterproto 145 | - [ ] sideEffect/mutableSideEffect 146 | - [ ] Local activity 147 | - [ ] Cancellation Scopes 148 | - [ ] Explicit activity ids for activity invocations 149 | -------------------------------------------------------------------------------- /test-utils/java-test-client/gradlew: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | ############################################################################## 4 | ## 5 | ## Gradle start up script for UN*X 6 | ## 7 | ############################################################################## 8 | 9 | # Attempt to set APP_HOME 10 | # Resolve links: $0 may be a link 11 | PRG="$0" 12 | # Need this for relative symlinks. 13 | while [ -h "$PRG" ] ; do 14 | ls=`ls -ld "$PRG"` 15 | link=`expr "$ls" : '.*-> \(.*\)$'` 16 | if expr "$link" : '/.*' > /dev/null; then 17 | PRG="$link" 18 | else 19 | PRG=`dirname "$PRG"`"/$link" 20 | fi 21 | done 22 | SAVED="`pwd`" 23 | cd "`dirname \"$PRG\"`/" >/dev/null 24 | APP_HOME="`pwd -P`" 25 | cd "$SAVED" >/dev/null 26 | 27 | APP_NAME="Gradle" 28 | APP_BASE_NAME=`basename "$0"` 29 | 30 | # Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. 31 | DEFAULT_JVM_OPTS="" 32 | 33 | # Use the maximum available, or set MAX_FD != -1 to use that value. 
34 | MAX_FD="maximum" 35 | 36 | warn () { 37 | echo "$*" 38 | } 39 | 40 | die () { 41 | echo 42 | echo "$*" 43 | echo 44 | exit 1 45 | } 46 | 47 | # OS specific support (must be 'true' or 'false'). 48 | cygwin=false 49 | msys=false 50 | darwin=false 51 | nonstop=false 52 | case "`uname`" in 53 | CYGWIN* ) 54 | cygwin=true 55 | ;; 56 | Darwin* ) 57 | darwin=true 58 | ;; 59 | MINGW* ) 60 | msys=true 61 | ;; 62 | NONSTOP* ) 63 | nonstop=true 64 | ;; 65 | esac 66 | 67 | CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar 68 | 69 | # Determine the Java command to use to start the JVM. 70 | if [ -n "$JAVA_HOME" ] ; then 71 | if [ -x "$JAVA_HOME/jre/sh/java" ] ; then 72 | # IBM's JDK on AIX uses strange locations for the executables 73 | JAVACMD="$JAVA_HOME/jre/sh/java" 74 | else 75 | JAVACMD="$JAVA_HOME/bin/java" 76 | fi 77 | if [ ! -x "$JAVACMD" ] ; then 78 | die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME 79 | 80 | Please set the JAVA_HOME variable in your environment to match the 81 | location of your Java installation." 82 | fi 83 | else 84 | JAVACMD="java" 85 | which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 86 | 87 | Please set the JAVA_HOME variable in your environment to match the 88 | location of your Java installation." 89 | fi 90 | 91 | # Increase the maximum file descriptors if we can. 92 | if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then 93 | MAX_FD_LIMIT=`ulimit -H -n` 94 | if [ $? -eq 0 ] ; then 95 | if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then 96 | MAX_FD="$MAX_FD_LIMIT" 97 | fi 98 | ulimit -n $MAX_FD 99 | if [ $? 
-ne 0 ] ; then 100 | warn "Could not set maximum file descriptor limit: $MAX_FD" 101 | fi 102 | else 103 | warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" 104 | fi 105 | fi 106 | 107 | # For Darwin, add options to specify how the application appears in the dock 108 | if $darwin; then 109 | GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" 110 | fi 111 | 112 | # For Cygwin, switch paths to Windows format before running java 113 | if $cygwin ; then 114 | APP_HOME=`cygpath --path --mixed "$APP_HOME"` 115 | CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` 116 | JAVACMD=`cygpath --unix "$JAVACMD"` 117 | 118 | # We build the pattern for arguments to be converted via cygpath 119 | ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` 120 | SEP="" 121 | for dir in $ROOTDIRSRAW ; do 122 | ROOTDIRS="$ROOTDIRS$SEP$dir" 123 | SEP="|" 124 | done 125 | OURCYGPATTERN="(^($ROOTDIRS))" 126 | # Add a user-defined pattern to the cygpath arguments 127 | if [ "$GRADLE_CYGPATTERN" != "" ] ; then 128 | OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" 129 | fi 130 | # Now convert the arguments - kludge to limit ourselves to /bin/sh 131 | i=0 132 | for arg in "$@" ; do 133 | CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` 134 | CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option 135 | 136 | if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition 137 | eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` 138 | else 139 | eval `echo args$i`="\"$arg\"" 140 | fi 141 | i=$((i+1)) 142 | done 143 | case $i in 144 | (0) set -- ;; 145 | (1) set -- "$args0" ;; 146 | (2) set -- "$args0" "$args1" ;; 147 | (3) set -- "$args0" "$args1" "$args2" ;; 148 | (4) set -- "$args0" "$args1" "$args2" "$args3" ;; 149 | (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; 150 | (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; 151 | (7) set -- "$args0" "$args1" "$args2" "$args3" 
"$args4" "$args5" "$args6" ;; 152 | (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; 153 | (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; 154 | esac 155 | fi 156 | 157 | # Escape application args 158 | save () { 159 | for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done 160 | echo " " 161 | } 162 | APP_ARGS=$(save "$@") 163 | 164 | # Collect all arguments for the java command, following the shell quoting and substitution rules 165 | eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" 166 | 167 | # by default we should be in the correct project dir, but when run from Finder on Mac, the cwd is wrong 168 | if [ "$(uname)" = "Darwin" ] && [ "$HOME" = "$PWD" ]; then 169 | cd "$(dirname "$0")" 170 | fi 171 | 172 | exec "$JAVACMD" "$@" 173 | -------------------------------------------------------------------------------- /temporal/marker.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from dataclasses import dataclass, field 4 | from dataclasses_json import dataclass_json, LetterCase 5 | 6 | from typing import Dict 7 | 8 | from temporal.api.common.v1 import Header, Payload, Payloads 9 | from temporal.api.enums.v1 import EventType 10 | from temporal.api.history.v1 import MarkerRecordedEventAttributes, HistoryEvent 11 | 12 | from .decision_loop import DecisionContext 13 | 14 | MUTABLE_MARKER_HEADER_KEY = "MutableMarkerHeader" 15 | 16 | 17 | class MarkerInterface: 18 | @staticmethod 19 | def from_event_attributes(attributes: MarkerRecordedEventAttributes) -> MarkerInterface: 20 | if attributes.header and attributes.header.fields and MUTABLE_MARKER_HEADER_KEY in attributes.header.fields: 21 | buffer: bytes = 
attributes.header.fields.get(MUTABLE_MARKER_HEADER_KEY).data 22 | header = MarkerHeader.from_json(str(buffer, "utf-8")) # type: ignore 23 | return MarkerData(header=header, data=attributes.details) 24 | else: 25 | # return PlainMarkerData.from_json(str(attributes.details, "utf-8")) 26 | raise Exception("PlainMarkerData is not supported") 27 | 28 | def get_id(self) -> str: 29 | raise NotImplementedError() 30 | 31 | def get_access_count(self) -> int: 32 | raise NotImplementedError() 33 | 34 | def get_data(self) -> Dict[str, Payloads]: 35 | raise NotImplementedError() 36 | 37 | 38 | @dataclass_json(letter_case=LetterCase.CAMEL) 39 | @dataclass 40 | class MarkerHeader: 41 | id: str = None 42 | event_id: int = None 43 | access_count: int = 0 44 | 45 | 46 | @dataclass_json(letter_case=LetterCase.CAMEL) 47 | @dataclass 48 | class MarkerData(MarkerInterface): 49 | header: MarkerHeader = None 50 | data: Dict[str, Payloads] = None 51 | 52 | @staticmethod 53 | def create(id: str, event_id: int, data: Dict[str, Payloads], access_count: int) -> MarkerData: 54 | header = MarkerHeader(id=id, event_id=event_id, access_count=access_count) 55 | return MarkerData(header=header, data=data) 56 | 57 | def get_header(self) -> Header: 58 | header_bytes = self.header.to_json().encode("utf-8") # type: ignore 59 | header = Header() 60 | header.fields[MUTABLE_MARKER_HEADER_KEY] = Payload(data=header_bytes) 61 | return header 62 | 63 | def get_access_count(self) -> int: 64 | return self.header.access_count 65 | 66 | def get_data(self) -> Dict[str, Payloads]: 67 | return self.data 68 | 69 | def get_id(self) -> str: 70 | return self.header.id 71 | 72 | 73 | @dataclass_json(letter_case=LetterCase.CAMEL) 74 | @dataclass 75 | class MarkerResult: 76 | data: Dict[str, Payloads] = None 77 | access_count: int = 0 78 | replayed: bool = False 79 | 80 | 81 | @dataclass 82 | class MarkerHandler: 83 | decision_context: DecisionContext 84 | marker_name: str 85 | mutable_marker_results: Dict[str, 
@dataclass
class MarkerHandler:
    """Records and replays mutable markers (e.g. "Version" markers) for a single
    marker name within one workflow decision context."""
    decision_context: DecisionContext
    marker_name: str
    mutable_marker_results: Dict[str, MarkerResult] = field(default_factory=dict)

    def record_mutable_marker(self, id: str, event_id: int, data: Dict[str, Payloads], access_count: int):
        """Emit a RecordMarker decision for `id` and remember its payloads locally."""
        marker = MarkerData.create(id=id, event_id=event_id, data=data, access_count=access_count)
        known = self.mutable_marker_results.get(id)
        if known is not None:
            known.replayed = True
        else:
            self.mutable_marker_results[id] = MarkerResult(data=data)
        self.decision_context.record_marker(self.marker_name, marker.get_header(), data)

    def set_data(self, id, data: Dict[str, Payloads]):
        """Cache payloads without emitting a decision — used when DEFAULT_VERSION
        is the implicit current version."""
        self.mutable_marker_results[id] = MarkerResult(data=data)

    def mark_replayed(self, id):
        """Flag `id` so handle() will not re-record a marker for it."""
        self.mutable_marker_results[id].replayed = True

    def handle(self, id: str, func) -> Dict[str, Payloads]:
        """Return the payloads for `id`, computing and recording them via `func`
        on first (non-replay) execution."""
        event_id = self.decision_context.decider.next_decision_event_id
        cached: MarkerResult = self.mutable_marker_results.get(id)

        if cached is None and not self.decision_context.is_replaying():
            # First execution: compute, record a marker, and return the data.
            computed = func()
            if computed:
                self.record_mutable_marker(id, event_id, computed, 0)
                return computed
            # TODO: Should this ever happen? - at least for version it will never happen
            return None

        if cached is None:
            # Replaying and no marker was seen in history for this id.
            return None

        if self.decision_context.is_replaying() and not cached.replayed:
            # Need to insert marker to ensure that event_id is incremented.
            self.record_mutable_marker(id, event_id, cached.data, 0)
        return cached.data

    # This method is currently not being used - after adopting the version logic
    # from the Golang client.
    def get_marker_data_from_history(self, event_id: int, marker_id: str, expected_access_count: int) -> \
            Dict[str, Payloads]:
        """Extract this handler's marker payloads from the history event at
        `event_id`, or None when the event does not match."""
        event: HistoryEvent = self.decision_context.decider.get_optional_decision_event(event_id)
        if not event or event.event_type != EventType.EVENT_TYPE_MARKER_RECORDED:
            return None

        attributes: MarkerRecordedEventAttributes = event.marker_recorded_event_attributes
        if attributes.marker_name != self.marker_name:
            return None

        marker_data = MarkerInterface.from_event_attributes(attributes)
        if marker_data.get_id() != marker_id or marker_data.get_access_count() > expected_access_count:
            return None

        return marker_data.get_data()
@dataclass
class ActivityTask:
    """Mutable snapshot of a polled activity task, copied off a
    PollActivityTaskQueueResponse."""

    @classmethod
    def from_poll_for_activity_task_response(cls, task: PollActivityTaskQueueResponse) -> 'ActivityTask':
        """Build an ActivityTask by copying the relevant attributes from the
        poll response onto a fresh instance."""
        instance: 'ActivityTask' = cls()
        for name in ("task_token", "workflow_execution", "activity_id",
                     "activity_type", "scheduled_time", "schedule_to_close_timeout",
                     "start_to_close_timeout", "heartbeat_timeout", "attempt",
                     "heartbeat_details", "workflow_namespace"):
            setattr(instance, name, getattr(task, name))
        return instance

    task_token: bytes = None
    workflow_execution: WorkflowExecution = None
    activity_id: str = None
    activity_type: ActivityType = None
    scheduled_time: datetime = None
    schedule_to_close_timeout: timedelta = None
    start_to_close_timeout: timedelta = None
    heartbeat_timeout: timedelta = None
    attempt: int = None
    heartbeat_details: Payloads = None
    workflow_namespace: str = None
async def heartbeat(client: 'WorkflowClient', task_token: bytes, details: object):
    """Send a heartbeat RPC for the task identified by `task_token`.

    Raises ActivityCancelledException when the service reports that
    cancellation has been requested for the activity.
    """
    request: RecordActivityTaskHeartbeatRequest = RecordActivityTaskHeartbeatRequest()
    request.details = client.data_converter.to_payloads([details])
    request.identity = get_identity()
    request.task_token = task_token
    response = await client.service.record_activity_task_heartbeat(request=request)
    if response.cancel_requested:
        raise ActivityCancelledException()


class ActivityContext:
    """Ambient per-task state for the currently executing activity.

    Bound to `current_activity_context` (a ContextVar) so that activity code
    can reach it through the static helpers on `Activity`.
    """
    # Class-level defaults; instances overwrite these attribute-by-attribute.
    client: 'WorkflowClient' = None
    activity_task: 'ActivityTask' = None
    namespace: str = None
    do_not_complete: bool = False

    @staticmethod
    def get() -> 'ActivityContext':
        """Return the context bound to the current execution."""
        return current_activity_context.get()

    @staticmethod
    def set(context: Optional['ActivityContext']):
        current_activity_context.set(context)

    # Plan for heartbeat:
    # - We will possibly implement a non-async version of this
    # - We might standardize on a background thread (using an in-memory queue)
    #   for heartbeats, in which case it doesn't matter whether it's an async
    #   method or not.
    async def heartbeat(self, details: object):
        """Heartbeat this task; raises ActivityCancelledException on cancel."""
        await heartbeat(self.client, self.activity_task.task_token, details)

    def get_heartbeat_details(self) -> object:
        """Deserialize and return the first payload of the last recorded
        heartbeat, or None when no heartbeat details exist."""
        details: Payloads = self.activity_task.heartbeat_details
        if not details:
            return None
        payloads: List[object] = self.client.data_converter.from_payloads(details)
        return payloads[0]

    def do_not_complete_on_return(self):
        """Mark that the activity result should not be reported automatically
        when the activity function returns."""
        self.do_not_complete = True


class Activity:
    """Static facade over the current ActivityContext."""

    @staticmethod
    def get_task_token() -> bytes:
        return ActivityContext.get().activity_task.task_token

    @staticmethod
    def get_workflow_execution() -> 'WorkflowExecution':
        return ActivityContext.get().activity_task.workflow_execution

    @staticmethod
    def get_namespace() -> str:
        return ActivityContext.get().namespace

    @staticmethod
    def get_heartbeat_details() -> object:
        return ActivityContext.get().get_heartbeat_details()

    @staticmethod
    async def heartbeat(details: object):
        await ActivityContext.get().heartbeat(details)

    @staticmethod
    def get_activity_task() -> 'ActivityTask':
        return ActivityContext.get().activity_task

    @staticmethod
    def do_not_complete_on_return():
        return ActivityContext.get().do_not_complete_on_return()


@dataclass
class ActivityCompletionClient:
    """Completes, fails, or heartbeats activities out-of-band, e.g. after an
    activity used do_not_complete_on_return()."""
    client: 'WorkflowClient'

    async def heartbeat(self, task_token: bytes, details: object):
        # FIX: heartbeat() is a coroutine function; previously it was invoked
        # without `await`, so the RPC never actually ran. Made async and
        # awaited, consistent with complete()/complete_exceptionally() below.
        await heartbeat(self.client, task_token, details)

    async def complete(self, task_token: bytes, return_value: object):
        await complete(self.client, task_token, return_value)

    async def complete_exceptionally(self, task_token: bytes, ex: Exception):
        await complete_exceptionally(self.client, task_token, ex)


async def complete_exceptionally(client: 'WorkflowClient', task_token, ex: Exception):
    """Report the activity identified by `task_token` as failed with `ex`."""
    respond: RespondActivityTaskFailedRequest = RespondActivityTaskFailedRequest()
    respond.task_token = task_token
    respond.identity = get_identity()
    respond.failure = serialize_exception(ex)
    await client.service.respond_activity_task_failed(request=respond)


async def complete(client: 'WorkflowClient', task_token, return_value: object):
    """Report the activity identified by `task_token` as completed with
    `return_value` (serialized through the client's data converter)."""
    respond = RespondActivityTaskCompletedRequest()
    respond.task_token = task_token
    respond.result = client.data_converter.to_payloads([return_value])
    respond.identity = get_identity()
    await client.service.respond_activity_task_completed(request=respond)
    def __post_init__(self):
        # Version markers are persisted/replayed through a dedicated handler
        # keyed by the "Version" marker name.
        self.version_handler = MarkerHandler(self.decision_context, VERSION_MARKER_NAME)

    def set_replay_current_time_milliseconds(self, s: datetime):
        # Setter for the deterministic replay clock — the workflow's notion of
        # "now" during decision processing.
        self.replay_current_time_milliseconds = s

    def current_time_millis(self) -> datetime:
        # NOTE: despite the name, this returns a datetime, not an epoch-millis int.
        return self.replay_current_time_milliseconds
timer.timer_id = str(self.decider.get_and_increment_next_id()) 54 | start_event_id: int = self.decider.start_timer(timer) 55 | context.completion_handle = lambda ctx, e: callback(e) # type: ignore 56 | self.scheduled_timers[start_event_id] = context 57 | return TimerCancellationHandler(start_event_id=start_event_id, clock_decision_context=self) 58 | 59 | def is_replaying(self): 60 | return self.replaying 61 | 62 | def set_replaying(self, replaying): 63 | self.replaying = replaying 64 | 65 | def timer_cancelled(self, start_event_id: int, reason: Exception): 66 | scheduled: OpenRequestInfo = self.scheduled_timers.pop(start_event_id, None) 67 | if not scheduled: 68 | return 69 | callback = scheduled.completion_handle 70 | exception = CancellationException("Cancelled by request") 71 | exception.init_cause(reason) 72 | callback(None, exception) 73 | 74 | def handle_timer_fired(self, attributes: TimerFiredEventAttributes): 75 | started_event_id: int = attributes.started_event_id 76 | if self.decider.handle_timer_closed(attributes): 77 | scheduled = self.scheduled_timers.pop(started_event_id, None) 78 | if scheduled: 79 | callback = scheduled.completion_handle 80 | callback(None, None) 81 | 82 | def handle_timer_canceled(self, event: HistoryEvent): 83 | attributes: TimerCanceledEventAttributes = event.timer_canceled_event_attributes 84 | started_event_id: int = attributes.started_event_id 85 | if self.decider.handle_timer_canceled(event): 86 | self.timer_cancelled(started_event_id, None) 87 | 88 | def get_version(self, change_id: str, min_supported: int, max_supported: int) -> int: 89 | def func(): 90 | d: Dict[str, Payloads] = {"VERSION": self.decider.worker.client.data_converter.to_payloads([max_supported])} 91 | return d 92 | 93 | result: Dict[str, Payloads] = self.version_handler.handle(change_id, func) 94 | if result is None: 95 | result = {"VERSION": self.decider.worker.client.data_converter.to_payloads([DEFAULT_VERSION])} 96 | 
self.version_handler.set_data(change_id, result) 97 | self.version_handler.mark_replayed(change_id) # so that we don't ever emit a MarkerRecorded for this 98 | 99 | version: int = self.decider.worker.client.data_converter.from_payloads(result["VERSION"])[0] # type: ignore 100 | self.validate_version(change_id, version, min_supported, max_supported) 101 | return version 102 | 103 | def validate_version(self, change_id: str, version: int, min_supported: int, max_supported: int): 104 | if version < min_supported or version > max_supported: 105 | raise Exception(f"Version {version} of changeID {change_id} is not supported. " 106 | f"Supported version is between {min_supported} and {max_supported}.") 107 | 108 | def handle_marker_recorded(self, event: HistoryEvent): 109 | """ 110 | Will be executed more than once for the same event. 111 | """ 112 | attributes = event.marker_recorded_event_attributes 113 | name: str = attributes.marker_name 114 | if SIDE_EFFECT_MARKER_NAME == name: 115 | # TODO 116 | # sideEffectResults.put(event.getEventId(), attributes.getDetails()); 117 | pass 118 | elif LOCAL_ACTIVITY_MARKER_NAME == name: 119 | # TODO 120 | # handleLocalActivityMarker(attributes); 121 | pass 122 | elif VERSION_MARKER_NAME == name: 123 | marker_data = MarkerInterface.from_event_attributes(attributes) 124 | change_id: str = marker_data.get_id() 125 | data: Dict[str, Payloads] = marker_data.get_data() 126 | self.version_handler.mutable_marker_results[change_id] = MarkerResult(data=data) 127 | elif MUTABLE_SIDE_EFFECT_MARKER_NAME != name: 128 | # TODO 129 | # if (log.isWarnEnabled()) { 130 | # log.warn("Unexpected marker: " + event); 131 | # } 132 | pass 133 | 134 | 135 | @dataclass 136 | class TimerCancellationHandler: 137 | start_event_id: int 138 | clock_decision_context: ClockDecisionContext 139 | 140 | def accept(self, reason: Exception): 141 | self.clock_decision_context.decider.cancel_timer(self.start_event_id, 142 | lambda: 
self.clock_decision_context.timer_cancelled( 143 | self.start_event_id, reason)) 144 | -------------------------------------------------------------------------------- /temporal/worker.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from dataclasses import dataclass, field 3 | from typing import Callable, Dict, Tuple 4 | import inspect 5 | import logging 6 | 7 | from temporal.constants import DEFAULT_SOCKET_TIMEOUT_SECONDS 8 | from temporal.conversions import camel_to_snake, snake_to_camel, snake_to_title 9 | 10 | from .workflow import WorkflowMethod, SignalMethod, QueryMethod, WorkflowClient 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | @dataclass 16 | class WorkerOptions: 17 | pass 18 | 19 | 20 | def _find_interface_class(impl_cls) -> type: 21 | hierarchy = list(inspect.getmro(impl_cls)) 22 | hierarchy.reverse() 23 | hierarchy.pop(0) # remove object 24 | for cls in hierarchy: 25 | for method_name, fn in inspect.getmembers(cls, predicate=inspect.isfunction): 26 | # first class with a "_workflow_method" is considered the interface 27 | if hasattr(fn, "_workflow_method"): 28 | return cls 29 | return impl_cls 30 | 31 | 32 | def _find_metadata_field(cls, metadata_field, method_name): 33 | for c in inspect.getmro(cls): 34 | if not hasattr(c, method_name): 35 | continue 36 | m = getattr(c, method_name) 37 | if not hasattr(m, metadata_field): 38 | continue 39 | return getattr(m, metadata_field) 40 | return None 41 | 42 | 43 | def _get_wm(cls: type, method_name: str) -> WorkflowMethod: 44 | metadata_field = "_workflow_method" 45 | return _find_metadata_field(cls, metadata_field, method_name) 46 | 47 | 48 | def _get_sm(cls: type, method_name: str) -> SignalMethod: 49 | metadata_field = "_signal_method" 50 | return _find_metadata_field(cls, metadata_field, method_name) 51 | 52 | 53 | def _get_qm(cls: type, method_name: str) -> QueryMethod: 54 | metadata_field = "_query_method" 55 | return 
_find_metadata_field(cls, metadata_field, method_name) 56 | 57 | 58 | @dataclass 59 | class Worker: 60 | client: WorkflowClient 61 | namespace: str = None 62 | task_queue: str = None 63 | options: WorkerOptions = None 64 | activities: Dict[str, Callable] = field(default_factory=dict) 65 | workflow_methods: Dict[str, Tuple[type, Callable]] = field(default_factory=dict) 66 | threads_started: int = 0 67 | threads_stopped: int = 0 68 | stop_requested: bool = False 69 | timeout: int = DEFAULT_SOCKET_TIMEOUT_SECONDS 70 | num_activity_tasks = 1 71 | num_worker_tasks = 1 72 | 73 | def register_activities_implementation(self, activities_instance: object, activities_cls_name: str = None): 74 | cls_name = activities_cls_name if activities_cls_name else type(activities_instance).__name__ 75 | for method_name, fn in inspect.getmembers(activities_instance, predicate=inspect.ismethod): 76 | if method_name.startswith("_"): 77 | continue 78 | self.activities[f'{cls_name}::{camel_to_snake(method_name)}'] = fn 79 | self.activities[f'{cls_name}::{snake_to_camel(method_name)}'] = fn 80 | self.activities[f'{cls_name}::{snake_to_title(method_name)}'] = fn 81 | 82 | def register_workflow_implementation_type(self, impl_cls: type, workflow_cls_name: str = None): 83 | cls_name = workflow_cls_name if workflow_cls_name else _find_interface_class(impl_cls).__name__ 84 | if not hasattr(impl_cls, "_signal_methods"): 85 | impl_cls._signal_methods = {} # type: ignore 86 | if not hasattr(impl_cls, "_query_methods"): 87 | impl_cls._query_methods = {} # type: ignore 88 | for method_name, fn in inspect.getmembers(impl_cls, predicate=inspect.isfunction): 89 | wm: WorkflowMethod = _get_wm(impl_cls, method_name) 90 | if wm: 91 | impl_fn = getattr(impl_cls, method_name) 92 | self.workflow_methods[wm._name] = (impl_cls, impl_fn) 93 | if "::" in wm._name: 94 | _, method_name = wm._name.split("::") 95 | self.workflow_methods[f'{cls_name}::{camel_to_snake(method_name)}'] = (impl_cls, impl_fn) 96 | 
self.workflow_methods[f'{cls_name}::{snake_to_camel(method_name)}'] = (impl_cls, impl_fn) 97 | continue 98 | sm: SignalMethod = _get_sm(impl_cls, method_name) 99 | if sm: 100 | impl_fn = getattr(impl_cls, method_name) 101 | impl_cls._signal_methods[sm.name] = impl_fn # type: ignore 102 | if "::" in sm.name: 103 | _, method_name = sm.name.split("::") 104 | impl_cls._signal_methods[f'{cls_name}::{camel_to_snake(method_name)}'] = impl_fn # type: ignore 105 | impl_cls._signal_methods[f'{cls_name}::{snake_to_camel(method_name)}'] = impl_fn # type: ignore 106 | continue 107 | qm: QueryMethod = _get_qm(impl_cls, method_name) 108 | if qm: 109 | impl_fn = getattr(impl_cls, method_name) 110 | impl_cls._query_methods[qm.name] = impl_fn # type: ignore 111 | if "::" in qm.name: 112 | _, method_name = qm.name.split("::") 113 | impl_cls._query_methods[f'{cls_name}::{camel_to_snake(method_name)}'] = impl_fn # type: ignore 114 | impl_cls._query_methods[f'{cls_name}::{snake_to_camel(method_name)}'] = impl_fn # type: ignore 115 | 116 | def start(self): 117 | from .activity_loop import activity_task_loop_func 118 | from .decision_loop import decision_task_loop_func 119 | self.threads_stopped = 0 120 | self.threads_started = 0 121 | self.stop_requested = False 122 | if self.activities: 123 | for i in range(0, self.num_activity_tasks): 124 | asyncio.create_task(activity_task_loop_func(self)) 125 | self.threads_started += 1 126 | if self.workflow_methods: 127 | for i in range(0, self.num_worker_tasks): 128 | decision_task_loop_func(self) 129 | self.threads_started += 1 130 | 131 | async def stop(self, background=False): 132 | self.stop_requested = True 133 | if background: 134 | return 135 | else: 136 | while self.threads_stopped != self.threads_started: 137 | await asyncio.sleep(5) 138 | 139 | def is_stopped(self): 140 | return self.threads_stopped == self.threads_started 141 | 142 | def is_stop_requested(self): 143 | return self.stop_requested 144 | 145 | def 
notify_thread_stopped(self): 146 | self.threads_stopped += 1 147 | 148 | def get_workflow_method(self, workflow_type_name: str) -> Tuple[type, Callable]: 149 | return self.workflow_methods[workflow_type_name] 150 | 151 | def set_timeout(self, timeout): 152 | self.timeout = timeout 153 | 154 | def get_timeout(self): 155 | return self.timeout 156 | 157 | def raise_if_stop_requested(self): 158 | if self.is_stop_requested(): 159 | raise StopRequestedException() 160 | 161 | 162 | class StopRequestedException(Exception): 163 | pass 164 | -------------------------------------------------------------------------------- /temporal/api/command/v1.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # sources: temporal/api/command/v1/message.proto 3 | # plugin: python-betterproto 4 | from dataclasses import dataclass 5 | from datetime import timedelta 6 | from typing import Dict 7 | 8 | import betterproto 9 | 10 | from temporal.api.common import v1 as v1common 11 | from temporal.api.enums import v1 as v1enums 12 | from temporal.api.failure import v1 as v1failure 13 | from temporal.api.taskqueue import v1 as v1taskqueue 14 | 15 | 16 | @dataclass 17 | class ScheduleActivityTaskCommandAttributes(betterproto.Message): 18 | activity_id: str = betterproto.string_field(1) 19 | activity_type: v1common.ActivityType = betterproto.message_field(2) 20 | namespace: str = betterproto.string_field(3) 21 | task_queue: v1taskqueue.TaskQueue = betterproto.message_field(4) 22 | header: v1common.Header = betterproto.message_field(5) 23 | input: v1common.Payloads = betterproto.message_field(6) 24 | # (-- api-linter: core::0140::prepositions=disabled aip.dev/not- 25 | # precedent: "to" is used to indicate interval. --) Indicates how long the 26 | # caller is willing to wait for an activity completion. Limits for how long 27 | # retries are happening. 
@dataclass
class RequestCancelActivityTaskCommandAttributes(betterproto.Message):
    # Proto field 1. NOTE(review): semantics inferred from the name — the id of
    # the scheduled-activity history event to cancel; confirm against
    # temporal/api/command/v1/message.proto.
    scheduled_event_id: int = betterproto.int64_field(1)
@dataclass
class CompleteWorkflowExecutionCommandAttributes(betterproto.Message):
    """Result payloads for a successfully completing workflow execution."""
    result: v1common.Payloads = betterproto.message_field(1)


@dataclass
class FailWorkflowExecutionCommandAttributes(betterproto.Message):
    """Failure attached to a workflow execution that is being failed."""
    failure: v1failure.Failure = betterproto.message_field(1)


@dataclass
class CancelTimerCommandAttributes(betterproto.Message):
    """Identifies the timer (by id) that should be cancelled."""
    timer_id: str = betterproto.string_field(1)


@dataclass
class CancelWorkflowExecutionCommandAttributes(betterproto.Message):
    """Detail payloads accompanying a workflow execution cancellation."""
    details: v1common.Payloads = betterproto.message_field(1)


@dataclass
class RequestCancelExternalWorkflowExecutionCommandAttributes(betterproto.Message):
    """Targets an external workflow execution for cancellation."""
    namespace: str = betterproto.string_field(1)
    workflow_id: str = betterproto.string_field(2)
    run_id: str = betterproto.string_field(3)
    control: str = betterproto.string_field(4)
    child_workflow_only: bool = betterproto.bool_field(5)


@dataclass
class SignalExternalWorkflowExecutionCommandAttributes(betterproto.Message):
    """Carries a named signal (with input payloads) to an external execution."""
    namespace: str = betterproto.string_field(1)
    execution: v1common.WorkflowExecution = betterproto.message_field(2)
    signal_name: str = betterproto.string_field(3)
    input: v1common.Payloads = betterproto.message_field(4)
    control: str = betterproto.string_field(5)
    child_workflow_only: bool = betterproto.bool_field(6)


@dataclass
class UpsertWorkflowSearchAttributesCommandAttributes(betterproto.Message):
    """Search attributes to upsert on the current workflow execution."""
    search_attributes: v1common.SearchAttributes = betterproto.message_field(1)
@dataclass
class ContinueAsNewWorkflowExecutionCommandAttributes(betterproto.Message):
    """Parameters for restarting the workflow as a fresh run (continue-as-new)."""
    workflow_type: v1common.WorkflowType = betterproto.message_field(1)
    task_queue: v1taskqueue.TaskQueue = betterproto.message_field(2)
    input: v1common.Payloads = betterproto.message_field(3)
    # workflow_execution_timeout is omitted as it shouldn't be overridden from
    # within a workflow. Timeout of a single workflow run.
    workflow_run_timeout: timedelta = betterproto.message_field(4)
    # Timeout of a single workflow task.
    workflow_task_timeout: timedelta = betterproto.message_field(5)
    backoff_start_interval: timedelta = betterproto.message_field(6)
    retry_policy: v1common.RetryPolicy = betterproto.message_field(7)
    initiator: v1enums.ContinueAsNewInitiator = betterproto.enum_field(8)
    failure: v1failure.Failure = betterproto.message_field(9)
    last_completion_result: v1common.Payloads = betterproto.message_field(10)
    cron_schedule: str = betterproto.string_field(11)
    header: v1common.Header = betterproto.message_field(12)
    memo: v1common.Memo = betterproto.message_field(13)
    search_attributes: v1common.SearchAttributes = betterproto.message_field(14)
@dataclass
class Command(betterproto.Message):
    """A single workflow command: `command_type` plus exactly one member of the
    `attributes` oneof group (betterproto `group="attributes"`)."""
    command_type: v1enums.CommandType = betterproto.enum_field(1)
    schedule_activity_task_command_attributes: "ScheduleActivityTaskCommandAttributes" = betterproto.message_field(
        2, group="attributes"
    )
    start_timer_command_attributes: "StartTimerCommandAttributes" = betterproto.message_field(
        3, group="attributes"
    )
    complete_workflow_execution_command_attributes: "CompleteWorkflowExecutionCommandAttributes" = betterproto.message_field(
        4, group="attributes"
    )
    fail_workflow_execution_command_attributes: "FailWorkflowExecutionCommandAttributes" = betterproto.message_field(
        5, group="attributes"
    )
    request_cancel_activity_task_command_attributes: "RequestCancelActivityTaskCommandAttributes" = betterproto.message_field(
        6, group="attributes"
    )
    cancel_timer_command_attributes: "CancelTimerCommandAttributes" = betterproto.message_field(
        7, group="attributes"
    )
    cancel_workflow_execution_command_attributes: "CancelWorkflowExecutionCommandAttributes" = betterproto.message_field(
        8, group="attributes"
    )
    request_cancel_external_workflow_execution_command_attributes: "RequestCancelExternalWorkflowExecutionCommandAttributes" = betterproto.message_field(
        9, group="attributes"
    )
    record_marker_command_attributes: "RecordMarkerCommandAttributes" = betterproto.message_field(
        10, group="attributes"
    )
    continue_as_new_workflow_execution_command_attributes: "ContinueAsNewWorkflowExecutionCommandAttributes" = betterproto.message_field(
        11, group="attributes"
    )
    start_child_workflow_execution_command_attributes: "StartChildWorkflowExecutionCommandAttributes" = betterproto.message_field(
        12, group="attributes"
    )
    signal_external_workflow_execution_command_attributes: "SignalExternalWorkflowExecutionCommandAttributes" = betterproto.message_field(
        13, group="attributes"
    )
    upsert_workflow_search_attributes_command_attributes: "UpsertWorkflowSearchAttributesCommandAttributes" = betterproto.message_field(
        14, group="attributes"
    )