├── .env.examples
├── .gitignore
├── LICENSE
├── README.md
├── examples
│   ├── client.py
│   ├── data
│   │   ├── demo.json
│   │   ├── single_code_block.json
│   │   ├── single_custom_tool.json
│   │   ├── single_intention.json
│   │   ├── single_knowledge_base.json
│   │   ├── single_llm.json
│   │   ├── single_logic_branches.json
│   │   ├── single_sp_app.json
│   │   ├── single_tool_blip.json
│   │   ├── single_tool_google.json
│   │   ├── single_tool_web_reader.json
│   │   └── single_tts_gpt_sovits.json
│   └── server.py
├── logo.png
├── poetry.lock
├── pyproject.toml
└── src
    ├── __init__.py
    └── argo_workflow_runner
        ├── __init__.py
        ├── cli.py
        ├── client
        │   └── __init__.py
        ├── configs.py
        ├── core
        │   ├── __init__.py
        │   ├── exec_node.py
        │   ├── llm_memory.py
        │   ├── schema.py
        │   ├── server.py
        │   └── workflow_manager.py
        ├── env_settings.py
        ├── modules
        │   ├── __init__.py
        │   ├── agent.py
        │   ├── code_block.py
        │   ├── custom_tool.py
        │   ├── intention.py
        │   ├── knowledge_base.py
        │   ├── llm.py
        │   ├── logic_branches.py
        │   ├── sp_app.py
        │   ├── tool_blip.py
        │   ├── tool_evluate.py
        │   ├── tool_google.py
        │   ├── tool_web_reader.py
        │   └── tts.py
        └── utils
            ├── __init__.py
            ├── llm.py
            ├── sse_client.py
            ├── tts.py
            ├── web_search.py
            └── ws_messager.py
/.env.examples:
--------------------------------------------------------------------------------
1 | LLM_BASE_URL = 'https://api.openai.com/v1'
2 | LLM_KEY = ''
3 | LLM_MODEL="gpt-3.5-turbo"
4 | TTS_GPT_SOVITS_URL = 'http://192.168.31.182:5000/tts'
5 |
6 | TAVILY_API_KEY = ''
7 | SERP_API_KEY = ''
8 |
9 | DOWNLOAD_URL_FMT = 'http://192.168.31.182:9101/file/{file_name}'
10 | BLIP_URL = 'http://127.0.0.1:7801/blip/api/generate'
11 | KB_URL = 'http://127.0.0.1:13001/knowledge/search'
12 |
13 | REDIS_URL = 'redis://localhost'
14 |
15 | RESTFUL_SERVER_HOST="127.0.0.1"
16 | RESTFUL_SERVER_PORT=8003
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Python
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | *.so
6 | .Python
7 | build/
8 | develop-eggs/
9 | dist/
10 | downloads/
11 | eggs/
12 | .eggs/
13 | lib/
14 | lib64/
15 | parts/
16 | sdist/
17 | var/
18 | wheels/
19 | *.egg-info/
20 | .installed.cfg
21 | *.egg
22 |
23 | # Virtual Environment
24 | .env
25 | .venv
26 | env/
27 | venv/
28 | ENV/
29 |
30 | # IDE
31 | .idea/
32 | .vscode/
33 | *.swp
34 | *.swo
35 |
36 | # Test
37 | .coverage
38 | htmlcov/
39 | .pytest_cache/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI3lab/argo-workflow-runner/0e177ea21748302ccff23ba61460ce920f17659e/LICENSE
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 | ---
10 |
11 | A workflow runner engine for Argo.
12 |
13 | ## Design goals
14 |
15 | - Distributed agentic workflow support
16 | - Highly extensible - integrate your own tools
17 | - Web3 Support
18 | Many built-in modules included in the framework
19 |
20 |
21 | This project is under heavy development.
22 |
23 | ## Quick Start
24 |
25 | ### Prerequisites
26 | - Python 3.10+
27 | - Redis
28 |
29 | 1. Install Poetry
30 |
31 | We use [poetry](https://python-poetry.org/docs/#installing-with-the-official-installer) to build this project.
32 |
33 |
34 | ```bash
35 | curl -sSL https://install.python-poetry.org | python3 - --version 1.8.4
36 | ...
37 | poetry install
38 | ```
39 |
40 | 2. Install redis
41 |
42 | Install Redis with your operating system's package manager, or run it with Docker.
43 |
44 | 3. Configure
45 |
46 | Copy .env.examples to .env and modify it to match your environment
47 |
48 | 4. Start the workflow server and test it with the client
49 |
50 | ```bash
51 | :~/train/argo-workflow-runner/examples$ python server.py
52 | INFO: Started server process [202004]
53 | INFO: Waiting for application startup.
54 | INFO: Application startup complete.
55 | 2025-01-16 10:47:03 - server.py[line:715] - INFO: server listening on 127.0.0.1:8004
56 | 2025-01-16 10:47:03 - server.py[line:184] - INFO: websocket server started
57 | INFO: Uvicorn running on http://127.0.0.1:8003 (Press CTRL+C to quit)
58 |
59 |
60 | ~/train/argo-workflow-runner/examples$ python client.py -s data/single_llm.json
61 | ws://127.0.0.1:8004/?workflow_id=e112556a-492a-4aa2-9dc2-c00bc75d7f38
62 | 2025-01-16 10:50:19 - client.py[line:56] - INFO: ws connected: ws://127.0.0.1:8004/?workflow_id=e112556a-492a-4aa2-9dc2-c00bc75d7f38
63 | 2025-01-16 10:50:19 - client.py[line:95] - INFO: recv msg: {"type":"enter","node_id":"123","node_type":"llm"}
64 | あなたは私に、中国語のテキストを日本語に翻訳してほしいとおっしゃっています。元の意味を保ち、不要な言葉を避けるようにします。
65 | 2025-01-16 10:50:20 - client.py[line:95] - INFO: recv msg: {"type":"result","node_id":"123","node_type":"llm","data":{"result":"あなたは私に、中国語のテキストを日本語に翻訳してほしいとおっしゃっています。元の意味を保ち、不要な言葉を避けるようにします。"}}
66 | 2025-01-16 10:50:20 - client.py[line:97] - INFO: audio_len: 0
67 |
68 | ```
69 |
--------------------------------------------------------------------------------
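Note on the quick-start flow: `examples/client.py` below is a thin wrapper around two calls — a REST POST that registers the workflow and returns a `workflow_id`, and a WebSocket connection that triggers execution and streams events back. A distilled sketch of the same flow (assuming the default host/ports from `.env.examples`):

```python
import asyncio
import json

import aiohttp
import websockets


async def run(node_json: dict):
    # 1. register a single-node test workflow; the server replies with an id
    async with aiohttp.ClientSession() as session:
        async with session.post('http://127.0.0.1:8003/test_single_node', json=node_json) as resp:
            resp.raise_for_status()
            wf_id = (await resp.json())['workflow_id']

    # 2. connecting the WebSocket starts execution; JSON events
    #    (and binary audio frames, first byte = message type) stream back
    async with websockets.connect(f'ws://127.0.0.1:8004/?workflow_id={wf_id}') as ws:
        async for message in ws:
            if isinstance(message, bytes):
                continue  # audio chunk; see AudioPlayer in examples/client.py
            print(json.loads(message))


with open('data/single_llm.json') as fp:
    asyncio.run(run(json.load(fp)))
```

For a full workflow graph, POST to `/send_workflow` instead of `/test_single_node`; the WebSocket side is identical.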
/examples/client.py:
--------------------------------------------------------------------------------
1 | import json
2 | import asyncio
3 | import aiohttp
4 | import websockets
5 | import subprocess
6 | import shlex
7 | import argparse
8 | from typing import Dict
9 |
10 | from argo_workflow_runner.configs import logger
11 | from argo_workflow_runner.env_settings import settings
12 | from argo_workflow_runner.utils.ws_messager import WsMessageType
13 |
14 | class AudioPlayer():
15 | def __init__(self) -> None:
16 | self._player = None
17 |
18 | def recv_audio_bytes(self, audio_bytes):
19 | if self._player is None:
20 | ffplay_process = "ffplay -autoexit -nodisp -hide_banner -loglevel error -i pipe:0"
21 | self._player = subprocess.Popen(
22 | shlex.split(ffplay_process),
23 | stdin=subprocess.PIPE,
24 | bufsize=10*1024*1024,
25 | )
26 | logger.info('Audio player opened.')
27 |
28 | self._player.stdin.write(audio_bytes)
29 |
30 | def close(self):
31 | if self._player is None:
32 | return
33 | self._player.stdin.close()
34 | self._player.wait()
35 | self._player = None
36 | logger.info('Audio player closed.')
37 |
38 | async def send_executing_data(data: Dict, is_test=False):
39 | url = f'http://{settings.RESTFUL_SERVER_HOST}:{settings.RESTFUL_SERVER_PORT}/{"test_single_node" if is_test else "send_workflow"}'
40 | async with aiohttp.ClientSession() as session:
41 | async with session.post(url, json=data) as resp:
42 | resp.raise_for_status()
43 |
44 | resp_json = await resp.json()
45 | return resp_json['workflow_id']
46 |
47 | async def main(data: Dict, is_test=False):
48 | workflow_id = await send_executing_data(data, is_test)
49 |
50 |
51 | url = f'ws://{settings.WS_SERVER_HOST}:{settings.WS_SERVER_PORT}/?workflow_id={workflow_id}'
52 | print(url)
53 | async with websockets.connect(
54 | url, max_size=settings.WS_MAX_SIZE,
55 | ) as websocket:
56 | logger.info(f'ws connected: {url}')
57 | audio_len = 0
58 | text_appending = False
59 | _fp = None
60 | async for message in websocket:
61 | if isinstance(message, bytes):
62 | msg_type = int(message[0])
63 | message = message[1:]
64 | if msg_type == WsMessageType.AUDIO.value:
65 | if len(message) > 0:
66 | if _fp is None:
67 | _fp = open('output.wav', 'wb')
68 | _fp.write(message)
69 | audio_len += len(message)
70 | # audio_player.recv_audio_bytes(message)
71 | else:
72 | logger.info('recv audio end.')
73 | if _fp is not None:
74 | chunk_size = audio_len - 44  # data payload = total bytes minus the 44-byte WAV header
75 | _fp.seek(40, 0)  # the 'data' chunk-size field sits at bytes 40-43
76 | _fp.write(chunk_size.to_bytes(4, byteorder='little'))
77 |
78 | _fp.close()
79 | _fp = None
80 | continue
81 |
82 | msg = json.loads(message)
83 | if msg['type'] == 'text':
84 | data = msg['data']
85 | if not data['is_end']:
86 | print(data['text'], end='')
87 | text_appending = True
88 | else:
89 | if text_appending:
90 | print(data['text'])
91 | else:
92 | logger.info(f'recv msg: {message}')
93 | text_appending = False
94 | else:
95 | logger.info(f'recv msg: {message}')
96 |
97 | logger.info(f'audio_len: {audio_len}')
98 |
99 | # audio_player.close()
100 |
101 | if __name__ == '__main__':
102 | parser = argparse.ArgumentParser(description="terminal for testing workflow runner.")
103 |
104 | parser.add_argument(
105 | '-w',
106 | '--workflow',
107 | help=("workflow file path")
108 | )
109 | parser.add_argument(
110 | '-s',
111 | '--single',
112 | help=("single node file path")
113 | )
114 |
115 | args = parser.parse_args()
116 |
117 | if args.workflow:
118 | with open(args.workflow, 'rb') as fp:
119 | workflow_info = json.load(fp)
120 | asyncio.run(main(workflow_info))
121 | elif args.single:
122 | with open(args.single, 'rb') as fp:
123 | single_node_info = json.load(fp)
124 | asyncio.run(main(single_node_info, True))
125 |
--------------------------------------------------------------------------------
/examples/data/demo.json:
--------------------------------------------------------------------------------
1 | {
2 | "start": {
3 | "__text": "Nice to meet you."
4 | },
5 | "nodes": [
6 | {
7 | "id": "node_1",
8 | "type": "logic_branches",
9 | "config": {
10 | "branches": [
11 | {
12 | "name": "pictures in question",
13 | "conditions": [
14 | {
15 | "cond_param": "__file",
16 | "compare_type": "not_empty"
17 | }
18 | ]
19 | }
20 | ]
21 | }
22 | },
23 | {
24 | "id": "node_2",
25 | "type": "intention",
26 | "config": {
27 | "inputs": [
28 | "__text"
29 | ],
30 | "model": "openai/o1-mini",
31 | "memory_cnt": 0,
32 | "branches": [
33 | {
34 | "title": "MEME",
35 | "instruction": "When the user wants to generate images for MEME smart contract, select branch 'MEME'."
36 | },
37 | {
38 | "title": "web_reader",
39 | "instruction": "When the user give a website URL and wants to summarize the article, select branch 'web_reader'."
40 | },
41 | {
42 | "title": "google",
43 | "instruction": "When the user wants to query some realtime information, select branch 'google'."
44 | }
45 | ]
46 | }
47 | },
48 | {
49 | "id": "node_3",
50 | "type": "tool",
51 | "config": {
52 | "inputs": [
53 | "__file",
54 | "__file_transfer_type"
55 | ],
56 | "name": "blip"
57 | }
58 | },
59 | {
60 | "id": "node_4",
61 | "type": "sp_app",
62 | "config": {
63 | "inputs": [
64 | "__text"
65 | ],
66 | "name": "meme",
67 | "memory_cnt": 0
68 | }
69 | },
70 | {
71 | "id": "node_51",
72 | "type": "tool",
73 | "config": {
74 | "inputs": [
75 | "__text"
76 | ],
77 | "name": "web_reader"
78 | }
79 | },
80 | {
81 | "id": "node_52",
82 | "type": "llm",
83 | "config": {
84 | "inputs": [
85 | "node_51"
86 | ],
87 | "prompt": "As an information organization expert, please organize the input information in a clear and concise manner for easy reading",
88 | "prompt_params": [],
89 | "temperature": 0.2,
90 | "model": "openai/o1-mini",
91 | "memory_cnt": 0
92 | }
93 | },
94 | {
95 | "id": "node_61",
96 | "type": "tool",
97 | "config": {
98 | "inputs": [
99 | "__text"
100 | ],
101 | "name": "google"
102 | }
103 | },
104 | {
105 | "id": "node_62",
106 | "type": "llm",
107 | "config": {
108 | "inputs": [
109 | "node_61"
110 | ],
111 | "prompt": "You are an expert in information organization. Please organize the input information in a clear and concise manner to facilitate easy reading.",
112 | "prompt_params": [],
113 | "temperature": 0.2,
114 | "model": "openai/o1-mini",
115 | "memory_cnt": 0
116 | }
117 | },
118 | {
119 | "id": "node_7",
120 | "type": "llm",
121 | "config": {
122 | "inputs": [
123 | "__text"
124 | ],
125 | "prompt": "You are an AI assistant. Please engage in friendly conversation with the user.",
126 | "prompt_params": [],
127 | "temperature": 0.2,
128 | "model": "openai/o1-mini",
129 | "memory_cnt": 0
130 | }
131 | }
132 | ],
133 | "edges": [
134 | {
135 | "to_node": "node_1"
136 | },
137 | {
138 | "from_node": "node_1",
139 | "from_branch": 0,
140 | "to_node": "node_3"
141 | },
142 | {
143 | "from_node": "node_1",
144 | "from_branch": -1,
145 | "to_node": "node_2"
146 | },
147 | {
148 | "from_node": "node_2",
149 | "from_branch": 0,
150 | "to_node": "node_4"
151 | },
152 | {
153 | "from_node": "node_2",
154 | "from_branch": 1,
155 | "to_node": "node_51"
156 | },
157 | {
158 | "from_node": "node_51",
159 | "to_node": "node_52"
160 | },
161 | {
162 | "from_node": "node_2",
163 | "from_branch": 2,
164 | "to_node": "node_61"
165 | },
166 | {
167 | "from_node": "node_61",
168 | "to_node": "node_62"
169 | },
170 | {
171 | "from_node": "node_2",
172 | "from_branch": -1,
173 | "to_node": "node_7"
174 | },
175 | {
176 | "from_node": "node_3"
177 | },
178 | {
179 | "from_node": "node_4"
180 | },
181 | {
182 | "from_node": "node_52"
183 | },
184 | {
185 | "from_node": "node_62"
186 | },
187 | {
188 | "from_node": "node_7"
189 | }
190 | ]
191 | }
--------------------------------------------------------------------------------
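Reading the `edges` above: an edge with no `from_node` marks the graph entry point, an edge with no `to_node` terminates that path, and `from_branch` picks an outgoing route by the branch index the upstream node selected, with `-1` appearing to serve as the fallback route when no listed branch matched (see `BranchRouter` in `core/exec_node.py` and the edge wiring in `core/workflow_manager.py`).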
/examples/data/single_code_block.json:
--------------------------------------------------------------------------------
1 | {
2 | "start": {
3 | "__arg1": "abc",
4 | "__arg2": "123"
5 | },
6 | "node": {
7 | "type": "code_block",
8 | "config": {
9 | "args": {
10 | "arg1": "__arg1",
11 | "arg2": "__arg2"
12 | },
13 | "code": "def main(arg1, arg2):\n \n import json\n res = json.dumps([arg1 + arg2])\n return {'result':res}\n"
14 | }
15 | }
16 | }
--------------------------------------------------------------------------------
/examples/data/single_custom_tool.json:
--------------------------------------------------------------------------------
1 | {
2 | "start": {
3 | "node_1": {
4 | "symbol": "LTCUSDT"
5 | }
6 | },
7 | "node": {
8 | "type": "custom_tool",
9 | "config": {
10 | "inputs": [
11 | "node_1"
12 | ],
13 | "url": "https://api.binance.com/api/v3/ticker/price",
14 | "method": "GET",
15 | "headers": {},
16 | "name": "binance-spot-price"
17 | }
18 | }
19 | }
--------------------------------------------------------------------------------
/examples/data/single_intention.json:
--------------------------------------------------------------------------------
1 | {
2 | "start": {
3 | "__text": "generate images for MEME smart contract"
4 | },
5 | "node": {
6 | "type": "intention",
7 | "config": {
8 | "inputs": [
9 | "__text"
10 | ],
11 | "model": "gpt-3.5-turbo",
12 | "memory_cnt": 0,
13 | "branches": [
14 | {
15 | "title": "MEME",
16 | "instruction": "When the user wants to generate images for MEME smart contract, select branch 'MEME'."
17 | },
18 | {
19 | "title": "web_reader",
20 | "instruction": "When the user give a website URL and wants to summarize the article, select branch 'web_reader'."
21 | },
22 | {
23 | "title": "google",
24 | "instruction": "When the user wants to query some realtime information, select branch 'google'."
25 | }
26 | ]
27 | }
28 | }
29 | }
--------------------------------------------------------------------------------
/examples/data/single_knowledge_base.json:
--------------------------------------------------------------------------------
1 | {
2 | "start": {
3 | "__user_id": "0x76ee9597e931e09443caf20374dc0fe3d29e2020",
4 | "__text": "eth status conversion"
5 | },
6 | "node": {
7 | "type": "knowledge_base",
8 | "config": {
9 | "inputs": [
10 | "__text"
11 | ],
12 | "kb_name": "eth",
13 | "search_type": "vector",
14 | "similarity": 0.75,
15 | "cnt": 3
16 | }
17 | }
18 | }
19 |
--------------------------------------------------------------------------------
/examples/data/single_llm.json:
--------------------------------------------------------------------------------
1 | {
2 | "start": {
3 | "__session_id": "abc",
4 | "node_51": "Please translate it to japanaese。"
5 | },
6 | "node": {
7 | "id": "123",
8 | "type": "llm",
9 | "config": {
10 | "inputs": [
11 | "node_51"
12 | ],
13 | "prompt": "You are a chatbot",
14 | "prompt_params": [],
15 | "temperature": 0.2,
16 | "model": "gpt-3.5-turbo",
17 | "memory_cnt": 2
18 | }
19 | }
20 | }
--------------------------------------------------------------------------------
/examples/data/single_logic_branches.json:
--------------------------------------------------------------------------------
1 | {
2 | "start": {
3 | "__text": "Nice to meet you."
4 | },
5 | "node": {
6 | "type": "logic_branches",
7 | "config": {
8 | "branches": [
9 | {
10 | "name": "include pictures",
11 | "conditions": [
12 | {
13 | "cond_param": "__file",
14 | "compare_type": "not_empty"
15 | }
16 | ]
17 | }
18 | ]
19 | }
20 | }
21 | }
--------------------------------------------------------------------------------
/examples/data/single_sp_app.json:
--------------------------------------------------------------------------------
1 | {
2 | "start": {
3 | "__text": "Nice to meet you."
4 | },
5 | "node": {
6 | "type": "sp_app",
7 | "config": {
8 | "inputs": [
9 | "__text"
10 | ],
11 | "name": "meme",
12 | "memory_cnt": 0
13 | }
14 | }
15 | }
--------------------------------------------------------------------------------
/examples/data/single_tool_blip.json:
--------------------------------------------------------------------------------
1 | {
2 | "start": {
3 | "__file": "https://img.xoocity.com/learns/learn_min1.jpg",
4 | "__file_transfer_type": "url"
5 | },
6 | "node": {
7 | "type": "tool",
8 | "config": {
9 | "inputs": [
10 | "__file",
11 | "__file_transfer_type"
12 | ],
13 | "name": "blip"
14 | }
15 | }
16 | }
--------------------------------------------------------------------------------
/examples/data/single_tool_google.json:
--------------------------------------------------------------------------------
1 | {
2 | "start": {
3 | "__text": "Current BTC Price."
4 | },
5 | "node": {
6 | "type": "tool",
7 | "config": {
8 | "inputs": [
9 | "__text"
10 | ],
11 | "name": "google"
12 | }
13 | }
14 | }
--------------------------------------------------------------------------------
/examples/data/single_tool_web_reader.json:
--------------------------------------------------------------------------------
1 | {
2 | "start": {
3 | "__text": "https://www.163.com/news/article/JG2DM99E0001899O.html"
4 | },
5 | "node": {
6 | "type": "tool",
7 | "config": {
8 | "inputs": [
9 | "__text"
10 | ],
11 | "name": "web_reader"
12 | }
13 | }
14 | }
--------------------------------------------------------------------------------
/examples/data/single_tts_gpt_sovits.json:
--------------------------------------------------------------------------------
1 | {
2 | "start": {
3 | "__text": "Nice to meet you."
4 | },
5 | "node": {
6 | "type": "tts",
7 | "config": {
8 | "inputs": [
9 | "__text"
10 | ],
11 | "name": "gpt-sovits",
12 | "voice": "Emma"
13 | }
14 | }
15 | }
--------------------------------------------------------------------------------
/examples/server.py:
--------------------------------------------------------------------------------
1 | # examples/server.py
2 |
3 | import asyncio
4 | from argo_workflow_runner.core.server import WorkflowServer
5 |
6 |
7 | async def run_async_server():
8 | server = WorkflowServer()
9 | try:
10 | await server.start()
11 | except KeyboardInterrupt:
12 | await server.stop()
13 |
14 |
15 | if __name__ == "__main__":
16 |
17 | asyncio.run(run_async_server())
--------------------------------------------------------------------------------
/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI3lab/argo-workflow-runner/0e177ea21748302ccff23ba61460ce920f17659e/logo.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "argo-workflow-runner"
3 | version = "0.1.0"
4 | description = "A Python runner for Argo Workflows"
5 | authors = ["Sam Gao Hunter Han "]
6 | readme = "README.md"
7 |
8 | homepage = "https://ai3labs.net"
9 | repository = "https://github.com/yourusername/argo-workflow-runner"
10 | documentation = "https://argo-workflow-runner.readthedocs.io"
11 | packages = [
12 | { include = "argo_workflow_runner", from = "src" }
13 | ]
14 |
15 | classifiers = [
16 | "Programming Language :: Python :: 3",
17 | "License :: OSI Approved :: MIT License",
18 | "Operating System :: OS Independent",
19 | ]
20 |
21 | [tool.poetry.dependencies]
22 | python = ">=3.10.10,<4.0"
23 | langchain-community="^0.3.3"
24 | langchain-openai="^0.2.14"
25 | langgraph="^0.2.39"
26 | fastapi = "^0.115.6"
27 | redis = "^5.2.1"
28 | python-dotenv = "^1.0.1"
29 | uvicorn = "^0.34.0"
30 | python-decouple = "^3.8"
31 | aiofiles = "^24.1.0"
32 | websockets = "13.1"
33 | aiohttp = "^3.11.11"
34 | python-multipart = "^0.0.20"
35 | tavily-python = "^0.5.0"
36 | python-ulid = "^3.0.0"
37 | passlib = "^1.7.4"
38 | httpx = "^0.28.1"
39 |
40 | [tool.poetry.dev-dependencies]
41 | pytest = "^7.0.0"
42 | black = "^22.0.0"
43 | flake8 = "^4.0.0"
44 | mypy = "^0.950"
45 | pytest-cov = "^3.0.0"
46 |
47 | [tool.poetry.scripts]
48 | argo-runner = "argo_workflow_runner.cli:cli"
49 |
50 | [build-system]
51 | requires = ["poetry-core"]
52 | build-backend = "poetry.core.masonry.api"
53 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/src/argo_workflow_runner/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI3lab/argo-workflow-runner/0e177ea21748302ccff23ba61460ce920f17659e/src/argo_workflow_runner/__init__.py
--------------------------------------------------------------------------------
/src/argo_workflow_runner/cli.py:
--------------------------------------------------------------------------------
1 |
2 | import click
3 | from argo_workflow_runner.core.server import WorkflowServer, ServerConfig  # NOTE: core/server.py does not define ServerConfig, and WorkflowServer() takes no arguments; this CLI assumes an extended server
4 |
5 | @click.group()
6 | def cli():
7 | pass
8 |
9 | @cli.command()
10 | @click.option('--host', default='0.0.0.0', help='Server host')
11 | @click.option('--restful-port', default=8000, help='REST API port')
12 | @click.option('--websocket-port', default=8001, help='WebSocket port')
13 | @click.option('--upload-dir', default='./uploads', help='Upload directory')
14 | def serve(host, restful_port, websocket_port, upload_dir):
15 | """Start the workflow server"""
16 | config = ServerConfig(
17 | host=host,
18 | restful_port=restful_port,
19 | websocket_port=websocket_port,
20 | upload_dir=upload_dir
21 | )
22 | server = WorkflowServer(config)
23 | server.run()
24 |
25 | if __name__ == '__main__':
26 | cli()
--------------------------------------------------------------------------------
/src/argo_workflow_runner/client/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI3lab/argo-workflow-runner/0e177ea21748302ccff23ba61460ce920f17659e/src/argo_workflow_runner/client/__init__.py
--------------------------------------------------------------------------------
/src/argo_workflow_runner/configs.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from logging.handlers import RotatingFileHandler
3 | import os
4 | import sys
5 | from argo_workflow_runner.env_settings import settings
6 |
7 | USING_FILE_HANDLER = False
8 |
9 | def init_logger():
10 | try:
11 | import langchain
12 | langchain.verbose = False
13 | except Exception:
14 | pass
15 |
16 | # log format
17 | LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
18 | logger = logging.getLogger('0xbot')
19 | logger.setLevel(logging.INFO)
20 |
21 | handlers = None
22 | if USING_FILE_HANDLER:
23 | # log dir
24 | LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")
25 | if not os.path.exists(LOG_PATH):
26 | os.mkdir(LOG_PATH)
27 |
28 | proc_name = os.path.basename(sys.argv[0])
29 | proc_name = proc_name[0:-3]  # strip the '.py' extension
30 | log_file = LOG_PATH + os.sep + f'{proc_name}.log'
31 | file_handler = RotatingFileHandler(
32 | filename=log_file,
33 | mode='a',
34 | maxBytes=(100*1024*1024),
35 | backupCount=10,
36 | encoding="utf-8",
37 | )
38 | handlers = [file_handler]
39 |
40 | logging.basicConfig(
41 | level=logging.INFO,
42 | handlers=handlers,
43 | format=LOG_FORMAT,
44 | datefmt="%Y-%m-%d %H:%M:%S"
45 | )
46 | logging.getLogger("requests").setLevel(logging.WARNING)
47 | logging.getLogger("urllib3").setLevel(logging.WARNING)
48 | logging.getLogger("websockets").setLevel(logging.INFO)
49 |
50 | return logger
51 |
52 | logger = init_logger()
53 |
54 |
55 | UPLOAD_DIR = os.path.join(settings.WORKING_DIR, "upload")
56 | if not os.path.exists(UPLOAD_DIR):
57 | os.mkdir(UPLOAD_DIR)
58 |
59 | COMPOSITE_MODULES = {'tool', 'custom_tool'}
60 |
--------------------------------------------------------------------------------
/src/argo_workflow_runner/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI3lab/argo-workflow-runner/0e177ea21748302ccff23ba61460ce920f17659e/src/argo_workflow_runner/core/__init__.py
--------------------------------------------------------------------------------
/src/argo_workflow_runner/core/exec_node.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional, List, Any
2 | import websockets
3 | from langgraph.graph import END
4 |
5 | from argo_workflow_runner.configs import logger, COMPOSITE_MODULES
6 | from argo_workflow_runner.core.schema import ExecResponse
7 | from argo_workflow_runner.utils.llm import llm_chat_stream
8 |
9 | class ExecNode():
10 | def __init__(self, info: Dict[str, Any], websocket: websockets.WebSocketServerProtocol):
11 | self.id = info['id']
12 | node_key = info['type']
13 | if node_key in COMPOSITE_MODULES:
14 | name = info['config']['name']
15 | node_key = f'{node_key}/{name}'
16 | self.type = node_key
17 | self.config = info['config']
18 |
19 | self.websocket = websocket
20 |
21 | self.select_branch_id = -1
22 |
23 | async def try_to_execute(self, state: Dict):
24 | try:
25 | await self.execute(state)
26 | except Exception as e:
27 | await self.send_response(ExecResponse(
28 | type='error',
29 | node_id=self.id,
30 | node_type=self.type,
31 | data={
32 | 'msg': str(e),
33 | },
34 | ))
35 | async for text_chunk, done in error_report(str(e)):
36 | await self.send_response(ExecResponse(
37 | type='text',
38 | node_id=self.id,
39 | node_type=self.type,
40 | data={
41 | 'text': text_chunk,
42 | 'is_end': done,
43 | },
44 | ))
45 | raise e
46 |
47 | async def execute(self, state: Dict):
48 | logger.info(f'execute: {self.type}#{self.id}, state: {state.keys()}')
49 | await self.send_response(ExecResponse(
50 | type='enter',
51 | node_id=self.id,
52 | node_type=self.type,
53 | ))
54 |
55 | def select_branch(self):
56 | logger.error(f'select_branch method not implemented for node: {self.id}, {self.type}')
57 |
58 |
59 | async def send_response(self, rsp: ExecResponse):
60 | msg = rsp.model_dump_json(exclude_none=True)
61 | await self.websocket.send(msg)
62 |
63 | class BranchRouter():
64 | def __init__(self, from_node: ExecNode, edges: List[Dict[str, Any]]):
65 | self.from_node = from_node
66 | self.edges = edges
67 |
68 | def run(self, state: Dict):
69 | for edge in self.edges:
70 | if edge['from_branch'] == self.from_node.select_branch_id:
71 | if 'to_node' in edge:
72 | return edge['to_node']
73 | else:
74 | return END
75 |
76 | raise Exception(f'No branch selected for node: {self.from_node.id}, {self.from_node.type}')
77 |
78 | ERROR_REPORT_PROMPT = """
79 | <|begin_of_text|><|start_header_id|>system<|end_header_id|>
80 | You are an assistant named Argo that manages agents running for the user.
81 | But sometimes errors may occur when the agents running.
82 | Here is the error:
83 | '{error}'
84 | You need to politely remind the user of this error information.
85 | Please generate three sentences maximum to tell the user.
86 | <|eot_id|><|start_header_id|>user<|end_header_id|>
87 | <|eot_id|><|start_header_id|>assistant<|end_header_id|>
88 | """
89 |
90 | async def error_report(error: str):
91 | async for result in llm_chat_stream(ERROR_REPORT_PROMPT, {'error': error}):
92 | yield (result, False)
93 | yield ('', True)
94 |
--------------------------------------------------------------------------------
/src/argo_workflow_runner/core/llm_memory.py:
--------------------------------------------------------------------------------
1 | import redis.asyncio as redis
2 | import json
3 | from typing import List, Tuple
4 |
5 | from argo_workflow_runner.env_settings import settings
6 |
7 | class LLMMemoryManager():
8 | def __init__(self):
9 | self.pool = redis.ConnectionPool.from_url(settings.REDIS_URL)
10 |
11 | def get_client(self) -> redis.Redis:
12 | return redis.Redis(connection_pool=self.pool)
13 |
14 | async def close_client(self, client: redis.Redis):
15 | await client.aclose()
16 |
17 | async def close_pool(self):
18 | await self.pool.aclose()
19 |
20 | async def set_memory(self, session_id: str, node_id: str, human_words: str, ai_words: str, cnt: int):
21 | client = self.get_client()
22 |
23 | list_key = f'{session_id}/{node_id}'
24 | val = [human_words, ai_words]
25 | await client.rpush(list_key, json.dumps(val))
26 |
27 | await client.ltrim(list_key, -cnt, -1)  # keep only the most recent cnt exchanges (rpush appends to the tail)
28 |
29 | await self.close_client(client)
30 |
31 | async def get_memory(self, session_id: str, node_id: str) -> List[Tuple[str, str]]:
32 | client = self.get_client()
33 |
34 | list_key = f'{session_id}/{node_id}'
35 | values = await client.lrange(list_key, 0, -1)
36 |
37 | mem_list = []
38 | for val in values:
39 | arr = json.loads(val)
40 | mem_list.append(('human', arr[0]))
41 | mem_list.append(('ai', arr[1]))
42 |
43 | await self.close_client(client)
44 |
45 | return mem_list
46 |
47 | llm_memory_mgr = LLMMemoryManager()
48 |
--------------------------------------------------------------------------------
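A quick sketch of how this store behaves: each `{session_id}/{node_id}` key is a Redis list, and each entry is one human/ai exchange. A one-off usage example (the `sess-1`/`node-1` keys are illustrative; requires a reachable `REDIS_URL`):

```python
import asyncio

from argo_workflow_runner.core.llm_memory import llm_memory_mgr


async def demo():
    # store one exchange, keeping at most the 2 most recent exchanges
    await llm_memory_mgr.set_memory('sess-1', 'node-1', 'hello', 'hi there', cnt=2)
    history = await llm_memory_mgr.get_memory('sess-1', 'node-1')
    print(history)  # [('human', 'hello'), ('ai', 'hi there')]
    await llm_memory_mgr.close_pool()  # fine for a one-off script; the server closes this at shutdown


asyncio.run(demo())
```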
/src/argo_workflow_runner/core/schema.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel, Field
2 | from typing import Literal, Dict, Optional, List, Any
3 |
4 | class RspUpload(BaseModel):
5 | name: str = Field(description='Uploaded file name')
6 |
7 | class CommonModule(BaseModel):
8 | inputs: List[str]
9 |
10 | class StartNodeConfig(BaseModel):
11 | text: Optional[str] = None
12 | file: Optional[str] = None
13 | file_transfer_type: Optional[Literal['upload', 'url']] = 'url'
14 |
15 | class LogicBranchCondition(BaseModel):
16 | cond_param: str
17 | compare_type: Literal['include', 'not_include', 'equal', 'not_equal', 'empty', 'not_empty', 'start_with', 'end_with']
18 | cond_val: Optional[str] = None
19 | logic_relation: Optional[Literal['and', 'or']] = 'and'
20 |
21 | class LogicBranch(BaseModel):
22 | name: str
23 | conditions: List[LogicBranchCondition]
24 |
25 | class LogicBranchesConfig(BaseModel):
26 | branches: List[LogicBranch]
27 |
28 | class IntentionBranch(BaseModel):
29 | title: str
30 | instruction: str
31 |
32 | class IntentionConfig(CommonModule):
33 | model: str
34 | memory_cnt: Optional[int] = 0
35 | branches: List[IntentionBranch]
36 |
37 | class ToolConfig(CommonModule):
38 | name: str
39 |
40 |
41 |
42 | class CustomToolConfig(ToolConfig):
43 | url: str
44 | method: str
45 | headers: Dict[str, str]
46 |
47 | class TTSConfig(ToolConfig):
48 | voice: str
49 |
50 | class SpAppConfig(CommonModule):
51 | name: str
52 | memory_cnt: Optional[int] = 0
53 |
54 | class LLMConfig(CommonModule):
55 | prompt: str
56 | prompt_params: List[str]
57 | temperature: float = 0.0
58 | model: str
59 | memory_cnt: Optional[int] = 0
60 |
61 | class AgentConfig(CommonModule):
62 | name: str
63 | agent_id: Optional[str] = None
64 |
65 | class CodeBlockConfig(BaseModel):
66 | args: Dict[str, str]
67 | code: str
68 |
69 | class KnowledgeBaseConfig(CommonModule):
70 | kb_name: str
71 | search_type: Literal['vector', 'enhance'] = 'vector'
72 | similarity: float = 0.75
73 | cnt: int = 3
74 |
75 | class EdgeConfig(BaseModel):
76 | from_node: Optional[str] = None
77 | from_branch: Optional[int] = None
78 | to_node: Optional[str] = None
79 |
80 | class ExecResponse(BaseModel):
81 | type: Literal['enter', 'result', 'error', 'text', 'json', 'images', 'billing', 'app']
82 | node_id: Optional[str] = None
83 | node_type: Optional[str] = None
84 | data: Optional[Dict] = None
85 |
--------------------------------------------------------------------------------
/src/argo_workflow_runner/core/server.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from typing import Optional
3 | from dataclasses import dataclass
4 |
5 | from urllib.parse import urlparse
6 | import uvicorn
7 | from fastapi import FastAPI, UploadFile
8 | from fastapi.responses import FileResponse
9 | import websockets
10 | import signal
11 | import os
12 | import uuid
13 | import aiofiles
14 | from pydantic import BaseModel, Field
15 |
16 | from typing import Dict
17 | from argo_workflow_runner.core.schema import RspUpload
18 | from argo_workflow_runner.env_settings import settings
19 | from argo_workflow_runner.core.llm_memory import llm_memory_mgr
20 | from argo_workflow_runner.configs import logger
21 | from argo_workflow_runner.core.workflow_manager import workflow_manager
22 |
23 |
24 | class RestfulServer(uvicorn.Server):
25 | def install_signal_handlers(self) -> None:
26 | pass
27 |
28 | class ServerManager():
29 | def __init__(self):
30 | self._restful_servers = []
31 | self._ws_servers = []
32 | self._is_running = True
33 | signal.signal(signal.SIGINT, lambda _, __: self.terminate_all())
34 |
35 | def reg_ws_server(self, server):
36 | self._ws_servers.append(server)
37 |
38 | def create_restful_server(self, config: uvicorn.Config):
39 | server = RestfulServer(config)
40 | self._restful_servers.append(server)
41 | return server
42 |
43 | def is_running(self):
44 | return self._is_running
45 |
46 | def terminate_all(self):
47 | for svr in self._ws_servers:
48 | svr.close()
49 | for svr in self._restful_servers:
50 | svr.should_exit = True
51 |
52 | self._is_running = False
53 |
54 | logger.info('Require to terminate all servers.')
55 |
56 |
57 | async def on_connected(websocket: websockets.WebSocketServerProtocol):
58 | try:
59 | logger.info('connected.')
60 | path = websocket.path
61 | parse_result = urlparse(path)
62 | query = parse_result.query
63 | prefix = 'workflow_id='
64 | if not query.startswith(prefix):
65 | logger.error('No workflow_id in query')
66 | await websocket.close(1000, "No workflow_id provided")
67 | return
68 |
69 | workflow_id = query[len(prefix):]
70 | logger.info(f'Received workflow_id: {workflow_id}')
71 |
72 | workflow = await workflow_manager.get_workflow(workflow_id)
73 | if not workflow:
74 | logger.error(f'No workflow found for id: {workflow_id}')
75 | await websocket.close(1000, "Workflow not found")
76 | return
77 |
78 | logger.info(f'run_workflow: {workflow_id}')
79 | await workflow_manager.run_workflow(workflow, websocket)
80 | await workflow_manager.rmv_workflow(workflow_id)
81 |
82 | except Exception as e:
83 | logger.error(f"Error in websocket connection: {str(e)}", exc_info=True)
84 | try:
85 | await websocket.close(1011, f"Internal server error: {str(e)}")
86 | except Exception:
87 | pass  # connection already failed; nothing left to close cleanly
88 |
89 |
90 |
91 |
92 |
93 | app = FastAPI()
94 |
95 |
96 | @app.post('/send_workflow')
97 | async def send_workflow(workflow: Dict):
98 | wf_id = await workflow_manager.add_workflow(workflow)
99 | return {
100 | 'workflow_id': wf_id,
101 | }
102 |
103 |
104 | @app.post('/test_single_node')
105 | async def test_single_node(node: Dict):
106 | wf_id = await workflow_manager.add_single_node_workflow(node)
107 | return {
108 | 'workflow_id': wf_id,
109 | }
110 |
111 |
112 | @app.get("/file/{file_name}")
113 | async def api_get_file(file_name: str):
114 | UPLOAD_DIR = os.path.join(settings.WORKING_DIR, "upload")
115 |
116 | file_path = os.path.join(UPLOAD_DIR, file_name)
117 | if os.path.exists(file_path):
118 | return FileResponse(path=file_path, filename=file_name)
119 | else:
120 | return {
121 | 'error': f'No file named {file_name}',
122 | }
123 |
124 |
125 |
126 | @app.post("/upload")
127 | async def api_upload(file: UploadFile) -> RspUpload:
128 | _, file_extension = os.path.splitext(file.filename)
129 | file_name = str(uuid.uuid4()) + file_extension
130 | UPLOAD_DIR = os.path.join(settings.WORKING_DIR, "upload")
131 |
132 | file_path = os.path.join(UPLOAD_DIR, file_name)
133 |
134 | try:
135 | contents = await file.read()  # UploadFile.read() is async; avoids blocking the event loop
136 | async with aiofiles.open(file_path, 'wb') as fp:
137 | await fp.write(contents)
138 | except Exception:
139 | return {"error": "There was an error uploading the file"}
140 | finally:
141 | file.file.close()
142 |
143 | return RspUpload(
144 | name=file_name,
145 | )
146 |
147 |
148 |
149 |
150 | class WorkflowServer:
151 | def __init__(self):
152 | self.server_manager = ServerManager()
153 |
154 | async def start(self):
155 | """Start both REST and WebSocket servers"""
156 | tasks = [
157 | self._start_restful_server(),
158 | self._start_ws_server()
159 | ]
160 | await asyncio.gather(*tasks)
161 |
162 | def run(self):
163 | """Synchronous method to start the server"""
164 | asyncio.run(self.start())
165 |
166 | async def stop(self):
167 | """Stop all servers"""
168 | self.server_manager.terminate_all()
169 |
170 | async def _start_restful_server(self):
171 | # Move restful server logic here
172 | config = uvicorn.Config("argo_workflow_runner.core.server:app", host=settings.RESTFUL_SERVER_HOST, port=settings.RESTFUL_SERVER_PORT, log_level="info")
173 | server = self.server_manager.create_restful_server(config)
174 | await server.serve()
175 | logger.info('restful server exit')
176 | async def _start_ws_server(self):
177 | server = await websockets.serve(
178 | on_connected,
179 | settings.WS_SERVER_HOST, settings.WS_SERVER_PORT,
180 | max_size=settings.WS_MAX_SIZE,
181 | start_serving=True,
182 | )
183 | self.server_manager.reg_ws_server(server)
184 | logger.info('websocket server started')
185 |
186 | await server.wait_closed()
187 |
188 | logger.info('websocket server exit')
189 |
190 | await llm_memory_mgr.close_pool()
191 | logger.info('llm_memory_mgr closed')
--------------------------------------------------------------------------------
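Besides the two workflow endpoints, the REST app exposes a small file store. A sketch of uploading a file and getting back the server-generated name (assuming the default REST address; `logo.png` is just an example path):

```python
import asyncio

import aiohttp


async def upload(path: str) -> str:
    async with aiohttp.ClientSession() as session:
        with open(path, 'rb') as fp:
            form = aiohttp.FormData()
            # the field name must be 'file' to match the UploadFile parameter
            form.add_field('file', fp, filename=path)
            async with session.post('http://127.0.0.1:8003/upload', data=form) as resp:
                resp.raise_for_status()
                return (await resp.json())['name']  # server-generated uuid name


name = asyncio.run(upload('logo.png'))
print(name)
# the stored file is then served back from GET /file/{name},
# which is what DOWNLOAD_URL_FMT in .env.examples points at
```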
/src/argo_workflow_runner/core/workflow_manager.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import uuid
3 | from typing import Dict, Optional, List, Any, Type
4 | import websockets
5 | from langgraph.graph import END, StateGraph
6 |
7 | from argo_workflow_runner.configs import logger, COMPOSITE_MODULES
8 | from argo_workflow_runner.core.exec_node import ExecNode, BranchRouter
9 | from argo_workflow_runner.modules import *
10 |
11 |
12 | class WorkflowManager():
13 | def __init__(self):
14 | self._workflow_map: Dict[str, Dict] = {}
15 | self._lock = asyncio.Lock()
16 | self._exec_node_map: Dict[str, Type[ExecNode]] = {
17 | 'logic_branches': LogicBranchesNode,
18 | 'intention': IntentionNode,
19 | 'llm': LLMNode,
20 | 'code_block': CodeBlockNode,
21 | 'custom_tool': CustomToolNode,
22 | 'sp_app':SpAppNode,
23 | 'agent': AgentNode,
24 | 'knowledge_base': KnowledgeBaseNode,
25 | 'tool/blip': ToolBlipNode,
26 | 'tool/google': ToolGoogleNode,
27 | 'tool/web_reader': ToolWebReaderNode,
28 | 'tts': TTSNode,
29 |
30 | }
31 |
32 | def add_module(self, module_type: str, module_class: Type[ExecNode]) -> None:
33 | """
34 | Add a new module type to the workflow manager
35 |
36 | Args:
37 | module_type: The type identifier for the module
38 | module_class: The ExecNode class implementation
39 |
40 | Raises:
41 | ValueError: If module_type already exists
42 | TypeError: If module_class is not a subclass of ExecNode
43 | """
44 | if not issubclass(module_class, ExecNode):
45 | raise TypeError(f"Module class must be a subclass of ExecNode")
46 |
47 | if module_type in self._exec_node_map:
48 | raise ValueError(f"Module type '{module_type}' already exists")
49 |
50 | self._exec_node_map[module_type] = module_class
51 |
52 | def get_module(self, module_type: str) -> Optional[Type[ExecNode]]:
53 | """
54 | Get module class by type
55 |
56 | Args:
57 | module_type: The type identifier for the module
58 |
59 | Returns:
60 | The ExecNode class if found, None otherwise
61 | """
62 | return self._exec_node_map.get(module_type)
63 |
64 | def list_modules(self) -> List[str]:
65 | """
66 | List all registered module types
67 |
68 | Returns:
69 | List of module type strings
70 | """
71 | return list(self._exec_node_map.keys())
72 |
73 | async def add_workflow(self, workflow: Dict) -> str:
74 | """
75 | Add a new workflow
76 |
77 | Args:
78 | workflow: Workflow configuration dictionary
79 |
80 | Returns:
81 | Workflow ID string
82 | """
83 | async with self._lock:
84 | while True:
85 | wf_id = str(uuid.uuid4())
86 | if wf_id in self._workflow_map:
87 | continue
88 | self._workflow_map[wf_id] = workflow
89 | return wf_id
90 |
91 | async def get_workflow(self, wf_id: str) -> Optional[Dict]:
92 | """
93 | Get workflow by ID
94 |
95 | Args:
96 | wf_id: Workflow ID
97 |
98 | Returns:
99 | Workflow configuration if found, None otherwise
100 | """
101 | async with self._lock:
102 | if wf_id not in self._workflow_map:
103 | return None
104 |
105 | return self._workflow_map[wf_id]
106 |
107 | async def rmv_workflow(self, wf_id: str):
108 | """
109 | Remove workflow by ID
110 |
111 | Args:
112 | wf_id: Workflow ID to remove
113 | """
114 | async with self._lock:
115 | if wf_id not in self._workflow_map:
116 | return
117 |
118 | del self._workflow_map[wf_id]
119 |
120 | async def add_single_node_workflow(self, node_info: Dict) -> str:
121 | """
122 | Create a workflow with a single node
123 |
124 | Args:
125 | node_info: Node configuration dictionary
126 |
127 | Returns:
128 | Workflow ID string
129 | """
130 | if 'id' in node_info['node']:
131 | node_id = node_info['node']['id']
132 | else:
133 | node_id = 'node_1'
134 | node_info['node']['id'] = node_id
135 | workflow = {
136 | 'start': node_info['start'],
137 | 'nodes': [
138 | node_info['node']
139 | ],
140 | 'edges': [
141 | {
142 | 'to_node': node_id,
143 | },
144 | {
145 | 'from_node': node_id,
146 | },
147 | ],
148 | }
149 | return await self.add_workflow(workflow)
150 |
151 | async def run_workflow(self, workflow: Dict, websocket: websockets.WebSocketServerProtocol):
152 | """
153 | Execute a workflow
154 |
155 | Args:
156 | workflow: Workflow configuration dictionary
157 | websocket: WebSocket connection for communication
158 |
159 | Raises:
160 | Exception: If node type is not found
161 | """
162 | graph_builder = StateGraph(dict)
163 |
164 | node_map: Dict[str, ExecNode] = {}
165 | for node_info in workflow['nodes']:
166 | node_key = node_info['type']
167 | if (node_key in COMPOSITE_MODULES) and (node_key != 'custom_tool'):
168 | name = node_info['config']['name']
169 | node_key = f'{node_key}/{name}'
170 |
171 | exec_node_cls = self._exec_node_map.get(node_key)
172 | if exec_node_cls is None:
173 | raise Exception(f'No such exec node: {node_key}')
174 |
175 | node: ExecNode = exec_node_cls(node_info, websocket)
176 | graph_builder.add_node(
177 | node.id,
178 | node.try_to_execute,
179 | )
180 | node_map[node.id] = node
181 |
182 | condition_edges = {}
183 | for edge in workflow['edges']:
184 | if 'from_node' not in edge:
185 | graph_builder.set_entry_point(edge['to_node'])
186 | elif 'to_node' not in edge:
187 | graph_builder.add_edge(edge['from_node'], END)
188 | else:
189 | if 'from_branch' in edge:
190 | if edge['from_node'] not in condition_edges:
191 | condition_edges[edge['from_node']] = []
192 | condition_edges[edge['from_node']].append(edge)
193 | else:
194 | graph_builder.add_edge(edge['from_node'], edge['to_node'])
195 |
196 | for from_node_id, edges in condition_edges.items():
197 | from_node = node_map[from_node_id]
198 | router = BranchRouter(from_node, edges)
199 | graph_builder.add_conditional_edges(
200 | from_node_id,
201 | router.run,
202 | )
203 |
204 | start_input: Dict = workflow.get('start', {})
205 | input = {}
206 | for k, v in start_input.items():
207 | input[k] = v
208 |
209 | graph = graph_builder.compile()
210 | logger.info(input)
211 | try:
212 | await graph.ainvoke(input)
213 | except Exception as e:
214 | logger.error('run_workflow error', exc_info=e)
215 |
216 |
217 | workflow_manager = WorkflowManager()
--------------------------------------------------------------------------------
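The `add_module` hook above is the extension point behind the README's "integrate your own tools" goal. A minimal sketch of registering a custom node type (the `echo` module name and `EchoNode` class are hypothetical, not part of the repo):

```python
from typing import Dict

from argo_workflow_runner.core.exec_node import ExecNode
from argo_workflow_runner.core.schema import ExecResponse
from argo_workflow_runner.core.workflow_manager import workflow_manager


class EchoNode(ExecNode):
    """Hypothetical module: copies its first input into the state unchanged."""

    async def execute(self, state: Dict):
        await super().execute(state)  # emits the 'enter' event
        text = state.get(self.config['inputs'][0], '')
        state[self.id] = text  # downstream nodes read this by node id
        await self.send_response(ExecResponse(
            type='result',
            node_id=self.id,
            node_type=self.type,
            data={'result': text},
        ))


workflow_manager.add_module('echo', EchoNode)
```

A node with `"type": "echo"` in the workflow JSON would then be routed to this class by `run_workflow`.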
/src/argo_workflow_runner/env_settings.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union, List
2 | from pydantic import AnyHttpUrl, field_validator, BaseModel
3 | from dotenv import load_dotenv
4 | from decouple import config, Csv
5 |
6 | load_dotenv()
7 |
8 |
9 | class EnvSettings(BaseModel):
10 |
11 | LLM_MODEL: str = config("LLM_MODEL", default="gpt-3.5-turbo", cast=str)
12 | LLM_BASE_URL: str = config("LLM_BASE_URL", default="", cast=str)
13 | LLM_KEY: str = config("LLM_KEY", default="", cast=str)
14 |
15 |
16 | TTS_GPT_SOVITS_URL: str = config("TTS_GPT_SOVITS_URL", default="", cast=str)
17 |
18 | TAVILY_API_KEY : str = config("TAVILY_API_KEY", default="", cast=str)
19 | SERP_API_KEY : str = config("SERP_API_KEY", default="", cast=str)
20 |
21 | DOWNLOAD_URL_FMT : str = config("DOWNLOAD_URL_FMT", default="", cast=str)
22 | BLIP_URL : str = config("BLIP_URL", default="", cast=str)
23 | KB_URL : str = config("KB_URL", default="", cast=str)
24 |
25 | REDIS_URL : str = config("REDIS_URL", default="redis://localhost", cast=str)
26 |
27 | RESTFUL_SERVER_HOST: str = config("RESTFUL_SERVER_HOST", default="127.0.0.1", cast=str)
28 | RESTFUL_SERVER_PORT: int = config("RESTFUL_SERVER_PORT", default=8003, cast=int)
29 | WS_SERVER_HOST: str = config("WS_SERVER_HOST", default="127.0.0.1", cast=str)
30 | WS_SERVER_PORT: int = config("WS_SERVER_PORT", default=8004, cast=int)
31 | WS_MAX_SIZE: int = config("WS_SERVER_PORT", default=1*1024*1024, cast=int)
32 | WORKING_DIR: str = config("WORKING_DIR", default=".", cast=str)
33 |
34 | settings = EnvSettings()
--------------------------------------------------------------------------------
/src/argo_workflow_runner/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from argo_workflow_runner.modules.logic_branches import LogicBranchesNode
2 | from argo_workflow_runner.modules.intention import IntentionNode
3 | from argo_workflow_runner.modules.llm import LLMNode
4 | from argo_workflow_runner.modules.code_block import CodeBlockNode
5 | from argo_workflow_runner.modules.custom_tool import CustomToolNode
6 | from argo_workflow_runner.modules.tool_blip import ToolBlipNode
7 | from argo_workflow_runner.modules.agent import AgentNode
8 | from argo_workflow_runner.modules.knowledge_base import KnowledgeBaseNode
9 | from argo_workflow_runner.modules.tool_google import ToolGoogleNode
10 | from argo_workflow_runner.modules.tool_web_reader import ToolWebReaderNode
11 | from argo_workflow_runner.modules.tts import TTSNode
12 | from argo_workflow_runner.modules.sp_app import SpAppNode
13 |
14 |
15 |
16 |
--------------------------------------------------------------------------------
/src/argo_workflow_runner/modules/agent.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 |
3 | from argo_workflow_runner.core.exec_node import ExecNode
4 | from argo_workflow_runner.core.schema import (
5 | AgentConfig,
6 | ExecResponse,
7 | )
8 | from argo_workflow_runner.configs import logger
9 | import aiohttp
10 |
11 | class AgentNode(ExecNode):
12 | def __init__(self, info, websocket):
13 | super().__init__(info, websocket)
14 | self.config_model = AgentConfig.model_validate(self.config)
15 |
16 | async def execute(self, state: Dict):
17 | await super().execute(state)
18 | param_obj = state.get(self.config_model.inputs[0], None)
19 | if param_obj is None:
20 | raise Exception(f'No available input: {self.config_model.inputs[0]}')
21 | params = {}
22 | for key in self.config_model.inputs:
23 | val = state.get(key, None)
24 | if val is not None:
25 | params[key] = val
26 |
27 | async with aiohttp.ClientSession(
28 | headers=self.config_model.headers,  # NOTE: 'headers' and 'url' are read here but AgentConfig in schema.py does not declare them
29 | ) as session:
30 |
31 | kwargs = {
32 | 'json': {'text': param_obj},  # aiohttp has no 'text' kwarg; send the payload as JSON
33 | }
34 |
35 | async with session.request(method="POST", url=self.config_model.url, **kwargs) as resp:
36 | resp.raise_for_status()
37 | resp_text = await resp.text()
38 | logger.info(resp_text)
39 |
40 | state[self.id] = resp_text
41 |
42 | await self.send_response(ExecResponse(
43 | type='result',
44 | node_id=self.id,
45 | node_type=self.type,
46 | data={
47 | 'result': resp_text,
48 | },
49 | ))
50 |
51 |
--------------------------------------------------------------------------------
/src/argo_workflow_runner/modules/code_block.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import logging
3 | import os
4 | import resource
5 | import traceback
6 |
7 | from concurrent.futures import ProcessPoolExecutor
8 | from dataclasses import dataclass
9 | from typing import Any, Dict, TypeAlias, TypeVar, Set
10 |
11 | # logger is used by set_resource_limits() below
12 | from argo_workflow_runner.configs import logger
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 | from argo_workflow_runner.core.exec_node import ExecNode
24 | from argo_workflow_runner.core.schema import (
25 | CodeBlockConfig,
26 | ExecResponse,
27 | )
28 |
29 | JsonDict: TypeAlias = Dict[str, Any]
30 | T = TypeVar('T')
31 |
32 |
33 | @dataclass
34 | class ExecutionResult:
35 | status: str
36 | result: Any | None = None
37 | error: str | None = None
38 | traceback: str | None = None
39 |
40 | def to_dict(self) -> JsonDict:
41 | return {k: v for k, v in self.__dict__.items() if v is not None}
42 |
43 | SAFE_BUILTINS: Set[str] = {
44 | 'abs', 'all', 'any', 'ascii', 'bin', 'bool', 'bytearray', 'bytes',
45 | 'chr', 'complex', 'dict', 'divmod', 'enumerate', 'filter', 'float',
46 | 'format', 'frozenset', 'hash', 'hex', 'int', 'isinstance', 'issubclass',
47 | 'iter', 'len', 'list', 'map', 'max', 'min', 'next', 'oct', 'ord',
48 | 'pow', 'print', 'range', 'repr', 'reversed', 'round', 'set', 'slice',
49 | 'sorted', 'str', 'sum', 'tuple', 'type', 'zip'
50 | }
51 |
52 | FORBIDDEN_MODULES: Set[str] = {
53 | 'os', 'sys', 'subprocess', 'socket', 'requests', 'urllib',
54 | 'pathlib', 'pickle', 'shutil', 'importlib', 'builtins'
55 | }
56 | def create_safe_globals() -> Dict[str, Any]:
57 | import builtins
58 | safe_globals = {
59 | '__builtins__': {
60 | name: getattr(builtins, name)
61 | for name in SAFE_BUILTINS
62 | if hasattr(builtins, name)
63 | },
64 | 'print': lambda *args, **kwargs: None
65 | }
66 | return safe_globals
67 |
68 | def validate_code(code: str) -> None:
69 |
70 | for module in FORBIDDEN_MODULES:
71 | if f"import {module}" in code or f"from {module}" in code:
72 | raise SecurityError(f"Importing module '{module}' is not allowed")
73 |
74 | if 'eval(' in code or 'exec(' in code:
75 | raise SecurityError("Using eval() or exec() is not allowed")
76 |
77 | if 'open(' in code or 'file(' in code:
78 | raise SecurityError("File operations are not allowed")
79 |
80 | RESOURCE_LIMITS = {
81 | 'CPU_TIME': 1,
82 | 'MEMORY': 100 * 1024 * 1024, # 100MB
83 | 'FILE_SIZE': 1024 * 1024, # 1MB
84 | 'PROCESSES': 1,
85 | 'OPEN_FILES': 10
86 | }
87 |
88 | def set_resource_limits() -> None:
89 | """resource limit"""
90 | try:
91 | resource.setrlimit(resource.RLIMIT_CPU,
92 | (RESOURCE_LIMITS['CPU_TIME'], RESOURCE_LIMITS['CPU_TIME']))
93 |
94 | resource.setrlimit(resource.RLIMIT_AS,
95 | (RESOURCE_LIMITS['MEMORY'], RESOURCE_LIMITS['MEMORY']))
96 |
97 | resource.setrlimit(resource.RLIMIT_FSIZE,
98 | (RESOURCE_LIMITS['FILE_SIZE'], RESOURCE_LIMITS['FILE_SIZE']))
99 |
100 | resource.setrlimit(resource.RLIMIT_NPROC,
101 | (RESOURCE_LIMITS['PROCESSES'], RESOURCE_LIMITS['PROCESSES']))
102 |
103 | resource.setrlimit(resource.RLIMIT_NOFILE,
104 | (RESOURCE_LIMITS['OPEN_FILES'], RESOURCE_LIMITS['OPEN_FILES']))
105 |
106 | except Exception as e:
107 | logger.error(f"Failed to set resource limits: {e}")
108 | raise
109 | class SecurityError(Exception):
110 | pass
111 |
112 | class CodeBlockNode(ExecNode):
113 | @staticmethod
114 | async def execute_code_in_process(code: str, params: JsonDict) -> ExecutionResult:
115 | loop = asyncio.get_event_loop()
116 | with ProcessPoolExecutor(max_workers=1) as executor:
117 | try:
118 | future = loop.run_in_executor(
119 | executor,
120 | CodeBlockNode._execute_code_safely,
121 | code,
122 | params
123 | )
124 | result = await asyncio.wait_for(future, timeout=RESOURCE_LIMITS['CPU_TIME'])
125 | return result
126 | except asyncio.TimeoutError:
127 | return ExecutionResult(
128 | status='error',
129 | error='Execution timeout'
130 | )
131 | except Exception as e:
132 | return ExecutionResult(
133 | status='error',
134 | error=str(e),
135 | traceback=traceback.format_exc()
136 | )
137 |
138 | @staticmethod
139 | def _execute_code_safely(code: str, params: JsonDict) -> ExecutionResult:
140 | try:
141 | set_resource_limits()
142 |
143 | validate_code(code)
144 |
145 | globals_dict = create_safe_globals()
146 | globals_dict.update(params)
147 |
148 | compiled_code = compile(code, '<string>', 'exec')
149 |
150 | exec(compiled_code, globals_dict)
151 |
152 | if 'main' not in globals_dict:
153 | return ExecutionResult(
154 | status='error',
155 | error="No main function defined"
156 | )
157 |
158 | result = globals_dict['main'](**params)
159 | return ExecutionResult(
160 | status='success',
161 | result=result
162 | )
163 |
164 | except SecurityError as e:
165 | return ExecutionResult(
166 | status='error',
167 | error=f"Security violation: {str(e)}"
168 | )
169 | except Exception as e:
170 | return ExecutionResult(
171 | status='error',
172 | error=str(e),
173 | traceback=traceback.format_exc()
174 | )
175 |
176 |
177 | def __init__(self, info, websocket):
178 | super().__init__(info, websocket)
179 | self.config_model = CodeBlockConfig.model_validate(self.config)
180 |
181 | async def execute(self, state: Dict):
182 | await super().execute(state)
183 |
184 | params = {}
185 | for arg_key, key in self.config_model.args.items():
186 | val = state.get(key, None)
187 | if val is None:
188 | raise Exception(f'No available input: {key} in {self.id}#{self.type}')
189 | params[arg_key] = val
190 |
191 |
192 | res_obj = await self.execute_code_in_process(self.config_model.code, params)
193 | logging.info(f"CodeBlockNode execute result: {res_obj}")
194 | if res_obj.status == 'success':
195 |
196 | state[self.id] = res_obj
197 |
198 | await self.send_response(ExecResponse(
199 | type='result',
200 | node_id=self.id,
201 | node_type=self.type,
202 | data={
203 | 'result': res_obj.result
204 | },
205 | ))
206 | else:
207 | await self.send_response(ExecResponse(
208 | type='error',
209 | node_id=self.id,
210 | node_type=self.type,
211 | data={
212 | 'msg': str(res_obj),
213 | },
214 | ))
215 |
216 | #TODO:Put the code_block in docker and standalone server
--------------------------------------------------------------------------------
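The contract for sandboxed snippets, as enforced above: the code must define a `main` function whose parameters match the configured `args`, must not import anything in `FORBIDDEN_MODULES`, and runs under the CPU/memory rlimits set before execution. The snippet from examples/data/single_code_block.json satisfies it:

```python
# accepted by validate_code(): defines main(), imports only allowed modules
def main(arg1, arg2):
    import json  # 'json' is not in FORBIDDEN_MODULES
    res = json.dumps([arg1 + arg2])
    return {'result': res}
```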
/src/argo_workflow_runner/modules/custom_tool.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import aiohttp
3 | from typing import Dict
4 |
5 | from argo_workflow_runner.core.exec_node import ExecNode
6 | from argo_workflow_runner.core.schema import (
7 | CustomToolConfig,
8 | ExecResponse,
9 | )
10 | from argo_workflow_runner.env_settings import settings
11 | from argo_workflow_runner.configs import logger
12 |
13 | class CustomToolNode(ExecNode):
14 | def __init__(self, info, websocket):
15 | super().__init__(info, websocket)
16 | self.config_model = CustomToolConfig.model_validate(self.config)
17 |
18 | async def execute(self, state: Dict):
19 | await super().execute(state)
20 |
21 | text = state.get(self.config_model.inputs[0], None)
22 | if text is None:
23 | raise Exception(f'human words should not be empty for {self.id}, {self.type}')
24 |
25 |
26 | # param_obj = state.get(self.config_model.inputs[0], None)
27 | # if param_obj is None:
28 | # raise Exception(f'No available input: {self.config_model.inputs[0]}')
29 | # fields = self.config_model.fields
30 | #
31 | # param_obj = {}
32 | # for field in fields:
33 | # param_obj[field] = await param_obj.get(field, None)
34 | #
35 | # if isinstance(param_obj, str):
36 | # param_obj = {"json": param_obj} #
37 |
38 | async with aiohttp.ClientSession(
39 | headers=self.config_model.headers,
40 | ) as session:
41 | # if self.config_model.method == 'GET':
42 | # kwargs = {
43 | # 'params': param_obj
44 | # }
45 | # else:
46 | # kwargs = {
47 | # 'json': param_obj
48 | # }
49 | #
50 | #
51 | # logger.info(f'param_obj: {param_obj}')
52 | # logger.info(f'Executing {self.config_model.method} {self.config_model.inputs[0]}')
53 | # logger.info(f'Params: {kwargs}')
54 | data = {
55 | "text": text,
56 | }
57 |
58 |
59 | async with session.request(method=self.config_model.method, url=self.config_model.url, json=data) as resp:
60 | resp.raise_for_status()
61 | resp_text = await resp.text()
62 |
63 | logger.info(resp_text)
64 |
65 | state[self.id] = resp_text
66 | logger.info(f'CustomToolNode result resp_text: {resp_text}')
67 |
68 | await self.send_response(ExecResponse(
69 | type='result',
70 | node_id=self.id,
71 | node_type=self.type,
72 | data={
73 | 'result': resp_text,
74 | },
75 | ))
76 |
--------------------------------------------------------------------------------
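
`CustomToolNode` sends `{"text": <input>}` as JSON to the configured URL (with the configured method and headers) and stores the raw response body in state. Any HTTP endpoint that accepts that shape will do; here is a minimal, hypothetical aiohttp server for local testing — the route and port are arbitrary placeholders:

```python
from aiohttp import web


async def handle_tool(request: web.Request) -> web.Response:
    payload = await request.json()           # CustomToolNode sends {"text": <input>}
    text = payload.get('text', '')
    return web.Response(text=text.upper())   # any string; the node stores resp.text()


app = web.Application()
app.add_routes([web.post('/my-tool', handle_tool)])

if __name__ == '__main__':
    # Point CustomToolConfig.url at http://127.0.0.1:8080/my-tool
    web.run_app(app, port=8080)
```
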
/src/argo_workflow_runner/modules/intention.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | from langchain.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
3 | from langchain_core.output_parsers import StrOutputParser
4 |
5 | from argo_workflow_runner.configs import logger
6 | from argo_workflow_runner.core.exec_node import ExecNode
7 | from argo_workflow_runner.core.schema import (
8 | IntentionConfig,
9 | IntentionBranch,
10 | ExecResponse,
11 | )
12 | from argo_workflow_runner.utils.llm import get_chat_model
13 | from argo_workflow_runner.configs import logger
14 |
15 | SYSTEM_TEMPLATE = """
16 | You are an AI assistant that determines the user's intention based on the user's words.
17 | The user may have the following intentions:
18 | {content}
19 | If none of the above intent options are relevant to the user's utterance, select branch 'none'.
20 | Please provide the intent branch option that is closest to the user's utterance, without preamble or explanation.
21 | """
22 | HUMAN_TEMPLATE = '{text}'
23 |
24 | class IntentionNode(ExecNode):
25 | def __init__(self, info, websocket):
26 | super().__init__(info, websocket)
27 | self.config_model = IntentionConfig.model_validate(self.config)
28 |
29 | async def execute(self, state: Dict):
30 | await super().execute(state)
31 |
32 | chat_model = await get_chat_model(model=self.config_model.model, temperature=0)
33 |
34 | system_message_prompt = SystemMessagePromptTemplate.from_template(SYSTEM_TEMPLATE)
35 | human_message_prompt = HumanMessagePromptTemplate.from_template(HUMAN_TEMPLATE)
36 | chat_prompt = ChatPromptTemplate.from_messages([
37 | system_message_prompt,
38 | human_message_prompt,
39 | ])
40 |
41 | chain = chat_prompt | chat_model | StrOutputParser()
42 |
43 | content = '\n'.join(map(lambda x: x.instruction, self.config_model.branches))
44 | input = {
45 | 'content': content,
46 | 'text': state.get(self.config_model.inputs[0], ''),
47 | }
48 | result = await chain.ainvoke(input)
49 | logger.info(f'selected branch: {result} in {self.id} ')
50 |
51 | self.select_branch_id = 0
52 | select_branch_name = 'default branch'
53 |
54 | if result != 'none':
55 | branch_cnt = len(self.config_model.branches)
56 | for idx in range(branch_cnt):
57 | branch = self.config_model.branches[idx]
58 | if branch.title == result:
59 | self.select_branch_id = idx
60 | select_branch_name = branch.title
61 | break
62 |
63 | result_data = {
64 | 'select_branch_id': self.select_branch_id,
65 | 'select_branch_name': select_branch_name,
66 | }
67 | logger.info(f'selected branch: {result_data} in {self.id}, {self.type}')
68 | state[self.id] = result_data
69 |
70 | await self.send_response(ExecResponse(
71 | type='result',
72 | node_id=self.id,
73 | node_type=self.type,
74 | data=result_data,
75 | ))
76 |
--------------------------------------------------------------------------------
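
`IntentionNode` joins the branch instructions into the `{content}` slot of `SYSTEM_TEMPLATE`, then matches the model's one-word reply against `branch.title` — so each instruction should state its title explicitly. A self-contained sketch of that mapping, where the `Branch` dataclass stands in for the project's `IntentionBranch` schema:

```python
from dataclasses import dataclass


@dataclass
class Branch:
    title: str
    instruction: str


branches = [
    Branch('weather', "weather: the user asks about the weather"),
    Branch('music', "music: the user wants to hear a song"),
]

# This becomes the {content} slot of SYSTEM_TEMPLATE above.
content = '\n'.join(b.instruction for b in branches)

reply = 'music'  # what the LLM is expected to return, or 'none'
select_branch_id, select_branch_name = 0, 'default branch'
if reply != 'none':
    for idx, b in enumerate(branches):
        if b.title == reply:
            select_branch_id, select_branch_name = idx, b.title
            break

print(select_branch_id, select_branch_name)  # 1 music
```
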
/src/argo_workflow_runner/modules/knowledge_base.py:
--------------------------------------------------------------------------------
1 | import aiohttp
2 | from typing import Dict
3 |
4 | from argo_workflow_runner.core.exec_node import ExecNode
5 | from argo_workflow_runner.core.schema import (
6 | KnowledgeBaseConfig,
7 | ExecResponse,
8 | )
9 | from argo_workflow_runner.env_settings import settings
10 | from argo_workflow_runner.configs import logger
11 |
12 | class KnowledgeBaseNode(ExecNode):
13 | def __init__(self, info, websocket):
14 | super().__init__(info, websocket)
15 | self.config_model = KnowledgeBaseConfig.model_validate(self.config)
16 |
17 | async def execute(self, state: Dict):
18 | await super().execute(state)
19 |
20 | user_id = state.get('__user_id', 'workflow_runner')
21 | query_info = state.get(self.config_model.inputs[0], None)
22 | if query_info is None:
23 | raise Exception(f'No available input: {self.config_model.inputs[0]}')
24 |
25 | async with aiohttp.ClientSession() as session:
26 | payload = {
27 | "knowledge_base_id": self.config_model.knowledge_base_id,
28 | "user_id": self.config_model.knowledge_user_id,
29 | "q": query_info,
30 | "similarity": self.config_model.similarity,
31 | "top_k": self.config_model.cnt,
32 | "search_mode": self.config_model.search_type,
33 | }
34 | url = settings.KB_URL
35 | async with session.post(url, json=payload) as resp:
36 | resp.raise_for_status()
37 | resp_json = await resp.json()
38 | result = resp_json['data']
39 |
40 | state[self.id] = result
41 |
42 | await self.send_response(ExecResponse(
43 | type='result',
44 | node_id=self.id,
45 | node_type=self.type,
46 | data={
47 | 'result': result,
48 | },
49 | ))
50 |
--------------------------------------------------------------------------------
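
For reference, the request body `KnowledgeBaseNode` posts to `settings.KB_URL`, shown with placeholder values; the keys mirror the payload built above, the values are made up:

```python
import json

payload = {
    "knowledge_base_id": "kb_123",        # config_model.knowledge_base_id
    "user_id": "workflow_runner",         # config_model.knowledge_user_id
    "q": "what is an agentic workflow?",  # output of the upstream node
    "similarity": 0.75,                   # config_model.similarity
    "top_k": 5,                           # config_model.cnt
    "search_mode": "semantic",            # config_model.search_type
}
print(json.dumps(payload, indent=2))
```
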
/src/argo_workflow_runner/modules/llm.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | from langchain.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
3 | from langchain_core.output_parsers import StrOutputParser
4 | from argo_workflow_runner.configs import logger
5 |
6 | from argo_workflow_runner.core.exec_node import ExecNode
7 | from argo_workflow_runner.core.schema import (
8 | LLMConfig,
9 | ExecResponse,
10 | )
11 | from argo_workflow_runner.utils.llm import get_chat_model
12 | from argo_workflow_runner.core.llm_memory import llm_memory_mgr
13 |
14 | HUMAN_TEMPLATE = '{text}'
15 |
16 | class LLMNode(ExecNode):
17 | def __init__(self, info, websocket):
18 | super().__init__(info, websocket)
19 | self.config_model = LLMConfig.model_validate(self.config)
20 |
21 | async def execute(self, state: Dict):
22 | await super().execute(state)
23 |
24 | session_id = state.get('__session_id', '')
25 |
26 | chat_model = await get_chat_model(model=self.config_model.model, temperature=self.config_model.temperature)
27 |
28 | system_message_prompt = SystemMessagePromptTemplate.from_template(self.config_model.prompt)
29 | human_message_prompt = HumanMessagePromptTemplate.from_template(HUMAN_TEMPLATE)
30 |
31 | mem_list = []
32 | if self.config_model.memory_cnt > 0:
33 | mem_list = await llm_memory_mgr.get_memory(session_id, self.id)
34 |
35 | chat_prompt = ChatPromptTemplate.from_messages([
36 | system_message_prompt,
37 | *mem_list,
38 | human_message_prompt,
39 | ])
40 |
41 | chain = chat_prompt | chat_model | StrOutputParser()
42 |
43 | human_words = state.get(self.config_model.inputs[0], None)
44 | if human_words is None:
45 | raise Exception(f'human words should not be empty for {self.id}, {self.type}')
46 |
47 | input = {
48 | 'text': human_words,
49 | }
50 | for key in self.config_model.prompt_params:
51 | val = state.get(key, None)
52 | if val is not None:
53 | input[key] = val
54 |
55 | result = ''
56 | async for text_chunk in chain.astream(input):
57 | await self.send_response(ExecResponse(
58 | type='text',
59 | node_id=self.id,
60 | node_type=self.type,
61 | data={
62 | 'text': text_chunk,
63 | 'is_end': False,
64 | },
65 | ))
66 | result += text_chunk
67 | await self.send_response(ExecResponse(
68 | type='text',
69 | node_id=self.id,
70 | node_type=self.type,
71 | data={
72 | 'text': '',
73 | 'is_end': True,
74 | },
75 | ))
76 |
77 | state[self.id] = result
78 | logger.info(f'LLM result : {result}')
79 |
80 | await self.send_response(ExecResponse(
81 | type='result',
82 | node_id=self.id,
83 | node_type=self.type,
84 | data={
85 | 'result': result,
86 | },
87 | ))
88 |
89 | if self.config_model.memory_cnt > 0:
90 | await llm_memory_mgr.set_memory(session_id, self.id, human_words, result, self.config_model.memory_cnt)
91 |
--------------------------------------------------------------------------------
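
`LLMNode` streams `type='text'` messages with `is_end: False`, closes the stream with an empty `is_end: True` chunk, then sends one `type='result'` message carrying the full text. A sketch of a client reassembling that sequence, assuming `ExecResponse` serializes to JSON with these field names; the `demo` list stands in for a real WebSocket receive loop:

```python
import json


def consume(messages):
    buffer = ''
    for raw in messages:
        msg = json.loads(raw)
        if msg['type'] == 'text':
            if not msg['data']['is_end']:
                buffer += msg['data']['text']
        elif msg['type'] == 'result':
            return msg['data']['result']  # equals the joined chunks
    return buffer


demo = [
    json.dumps({'type': 'text', 'node_id': 'n1', 'node_type': 'llm',
                'data': {'text': 'Hel', 'is_end': False}}),
    json.dumps({'type': 'text', 'node_id': 'n1', 'node_type': 'llm',
                'data': {'text': 'lo', 'is_end': False}}),
    json.dumps({'type': 'text', 'node_id': 'n1', 'node_type': 'llm',
                'data': {'text': '', 'is_end': True}}),
    json.dumps({'type': 'result', 'node_id': 'n1', 'node_type': 'llm',
                'data': {'result': 'Hello'}}),
]
print(consume(demo))  # Hello
```
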
/src/argo_workflow_runner/modules/logic_branches.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 |
3 | from argo_workflow_runner.core.exec_node import ExecNode
4 | from argo_workflow_runner.core.schema import (
5 | LogicBranchesConfig,
6 | LogicBranchCondition,
7 | ExecResponse,
8 | )
9 | from argo_workflow_runner.configs import logger
10 |
11 | class LogicBranchesNode(ExecNode):
12 | def __init__(self, info, websocket):
13 | super().__init__(info, websocket)
14 | self.config_model = LogicBranchesConfig.model_validate(self.config)
15 |
16 | async def execute(self, state: Dict):
17 | await super().execute(state)
18 |
19 | self.select_branch_id = 0
20 | select_branch_name = 'default'
21 | logger.info(f"Select branch state: {state}")
22 |
23 | branch_cnt = len(self.config_model.branches)
24 | for idx in range(branch_cnt):
25 | branch = self.config_model.branches[idx]
26 | matched = True
27 | for cond in branch.conditions:
28 | if cond.logic_relation == 'and':
29 | matched = (matched and self.logic_calculate(state, cond))
30 | else:
31 | matched = (matched or self.logic_calculate(state, cond))
32 |
33 | if matched:
34 | self.select_branch_id = idx
35 | select_branch_name = branch.name
36 | break
37 |
38 | result_data = {
39 | 'select_branch_id': self.select_branch_id,
40 | 'select_branch_name': select_branch_name,
41 | }
42 | logger.info(f"Select branch: {result_data}")
43 | state[self.id] = result_data
44 |
45 | await self.send_response(ExecResponse(
46 | type='result',
47 | node_id=self.id,
48 | node_type=self.type,
49 | data=result_data,
50 | ))
51 |
52 |
53 | def logic_calculate(self, state: Dict, cond: LogicBranchCondition) -> bool:
54 | state_val = str(state.get(cond.cond_param, ''))
55 | if cond.cond_val is None:
56 | cond_val = ''
57 | else:
58 | cond_val = str(cond.cond_val)
59 |
60 | compare_type = cond.compare_type
61 | if compare_type == 'include':
62 | return state_val.find(cond_val) != -1
63 |
64 | elif compare_type == 'not_include':
65 | return state_val.find(cond_val) == -1
66 |
67 | elif compare_type == 'equal':
68 | return state_val == cond_val
69 |
70 | elif compare_type == 'not_equal':
71 | return state_val != cond_val
72 |
73 | elif compare_type == 'empty':
74 | return state_val == ''
75 |
76 | elif compare_type == 'not_empty':
77 | return state_val != ''
78 |
79 | elif compare_type == 'start_with':
80 | return state_val.startswith(cond_val)
81 |
82 | elif compare_type == 'end_with':
83 | return state_val.endswith(cond_val)
84 |
85 | raise Exception(f'Invalid compare_type({compare_type}) for node: {self.id}')
86 |
--------------------------------------------------------------------------------
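
Note that `matched` starts as `True`, so a branch whose first condition uses `logic_relation == 'or'` matches unconditionally; conditions meant to gate a branch should use `'and'`. A worked example of the comparison operators with the node machinery stripped away:

```python
def compare(state_val: str, compare_type: str, cond_val: str = '') -> bool:
    # Same semantics as logic_calculate above, as a plain function.
    ops = {
        'include': cond_val in state_val,
        'not_include': cond_val not in state_val,
        'equal': state_val == cond_val,
        'not_equal': state_val != cond_val,
        'empty': state_val == '',
        'not_empty': state_val != '',
        'start_with': state_val.startswith(cond_val),
        'end_with': state_val.endswith(cond_val),
    }
    return ops[compare_type]


state = {'llm_1': 'play some jazz'}
print(compare(str(state.get('llm_1', '')), 'include', 'jazz'))     # True
print(compare(str(state.get('llm_1', '')), 'start_with', 'play'))  # True
print(compare(str(state.get('missing', '')), 'empty'))             # True
```
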
/src/argo_workflow_runner/modules/sp_app.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 |
3 | from argo_workflow_runner.core.exec_node import ExecNode
4 | from argo_workflow_runner.core.schema import (
5 | SpAppConfig,
6 | ExecResponse,
7 | )
8 | from argo_workflow_runner.configs import logger
9 |
10 | class SpAppNode(ExecNode):
11 | def __init__(self, info, websocket):
12 | super().__init__(info, websocket)
13 | self.config_model = SpAppConfig.model_validate(self.config)
14 |
15 | async def execute(self, state: Dict):
16 | await super().execute(state)
17 |
18 | params = {}
19 | for key in self.config_model.inputs:
20 | val = state.get(key, None)
21 | if val is not None:
22 | params[key] = val
23 |
24 | await self.send_response(ExecResponse(
25 | type='app',
26 | node_id=self.id,
27 | node_type=self.type,
28 | data={
29 | 'name': self.config_model.name,
30 | 'params': params,
31 | },
32 | ))
33 |
--------------------------------------------------------------------------------
/src/argo_workflow_runner/modules/tool_blip.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import aiohttp
3 | from typing import Dict
4 |
5 | from argo_workflow_runner.core.exec_node import ExecNode
6 | from argo_workflow_runner.core.schema import (
7 | ToolConfig,
8 | ExecResponse,
9 | )
10 | from argo_workflow_runner.env_settings import settings
11 | from argo_workflow_runner.configs import logger
12 |
13 |
14 |
15 | lock = asyncio.Lock()
16 |
17 | class ToolBlipNode(ExecNode):
18 | def __init__(self, info, websocket):
19 | super().__init__(info, websocket)
20 | self.config_model = ToolConfig.model_validate(self.config)
21 |
22 | async def execute(self, state: Dict):
23 | await super().execute(state)
24 |
25 | image_name = state.get(self.config_model.inputs[0], None)
26 | if image_name is None:
27 | raise Exception(f'No available input: {self.config_model.inputs[0]}')
28 |
29 | transfer_type = state.get('__file_transfer_type', None)
30 | if transfer_type is None:
31 | raise Exception('No transfer_type specified.')
32 | if transfer_type == 'upload':
33 | image_url = settings.DOWNLOAD_URL_FMT.format(file_name=image_name)
34 | else:
35 | image_url = image_name
36 |
37 | async with lock:
38 | async with aiohttp.ClientSession() as session:
39 | payload = {
40 | 'input': image_url,
41 | }
42 | url = settings.BLIP_URL
43 | async with session.post(url, json=payload) as resp:
44 | resp.raise_for_status()
45 | resp_json = await resp.json()
46 | result = resp_json['output']
47 |
48 | state[self.id] = result
49 |
50 | await self.send_response(ExecResponse(
51 | type='result',
52 | node_id=self.id,
53 | node_type=self.type,
54 | data={
55 | 'result': result,
56 | },
57 | ))
58 |
--------------------------------------------------------------------------------
/src/argo_workflow_runner/modules/tool_evluate.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import aiohttp
3 | from typing import Dict
4 |
5 | from argo_workflow_runner.core.exec_node import ExecNode
6 | from argo_workflow_runner.core.schema import (
7 | ToolConfig,
8 | ExecResponse,
9 | )
10 | from argo_workflow_runner.env_settings import settings
11 | from argo_workflow_runner.configs import logger
12 |
13 |
14 |
15 | lock = asyncio.Lock()
16 |
17 | class ToolEvaluateNode(ExecNode):
18 | def __init__(self, info, websocket):
19 | super().__init__(info, websocket)
20 | self.config_model = ToolConfig.model_validate(self.config)
21 |
22 | async def execute(self, state: Dict):
23 | await super().execute(state)
24 |
25 | image_name = state.get(self.config_model.inputs[0], None)
26 | if image_name is None:
27 | raise Exception(f'No available input: {self.config_model.inputs[0]}')
28 |
29 | transfer_type = state.get('__file_transfer_type', None)
30 | if transfer_type is None:
31 | raise Exception('No transfer_type specified.')
32 | if transfer_type == 'upload':
33 | image_url = settings.DOWNLOAD_URL_FMT.format(file_name=image_name)
34 | else:
35 | image_url = image_name
36 |
37 | async with lock:
38 | async with aiohttp.ClientSession() as session:
39 | payload = {
40 | 'input': image_url,
41 | }
42 | url = settings.BLIP_URL
43 | async with session.post(url, json=payload) as resp:
44 | resp.raise_for_status()
45 | resp_json = await resp.json()
46 | result = resp_json['output']
47 |
48 | state[self.id] = result
49 |
50 | await self.send_response(ExecResponse(
51 | type='result',
52 | node_id=self.id,
53 | node_type=self.type,
54 | data={
55 | 'result': result,
56 | },
57 | ))
58 |
--------------------------------------------------------------------------------
/src/argo_workflow_runner/modules/tool_google.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 |
3 | from argo_workflow_runner.core.exec_node import ExecNode
4 | from argo_workflow_runner.core.schema import (
5 | ToolConfig,
6 | ExecResponse,
7 | )
8 | from argo_workflow_runner.utils.web_search import serp_api_search
9 | from argo_workflow_runner.configs import logger
10 |
11 | class ToolGoogleNode(ExecNode):
12 | def __init__(self, info, websocket):
13 | super().__init__(info, websocket)
14 | self.config_model = ToolConfig.model_validate(self.config)
15 |
16 | async def execute(self, state: Dict):
17 | await super().execute(state)
18 |
19 | query_info = state.get(self.config_model.inputs[0], None)
20 | if query_info is None:
21 | raise Exception(f'No available input: {self.config_model.inputs[0]}')
22 |
23 | result = await serp_api_search(query_info)
24 |
25 | state[self.id] = result
26 |
27 | await self.send_response(ExecResponse(
28 | type='result',
29 | node_id=self.id,
30 | node_type=self.type,
31 | data={
32 | 'result': result,
33 | },
34 | ))
35 |
--------------------------------------------------------------------------------
/src/argo_workflow_runner/modules/tool_web_reader.py:
--------------------------------------------------------------------------------
1 | import aiohttp
2 | from typing import Dict
3 |
4 | from argo_workflow_runner.core.exec_node import ExecNode
5 | from argo_workflow_runner.core.schema import (
6 | ToolConfig,
7 | ExecResponse,
8 | )
9 | from argo_workflow_runner.configs import logger
10 |
11 | class ToolWebReaderNode(ExecNode):
12 | def __init__(self, info, websocket):
13 | super().__init__(info, websocket)
14 | self.config_model = ToolConfig.model_validate(self.config)
15 |
16 | async def execute(self, state: Dict):
17 | await super().execute(state)
18 |
19 | url = state.get(self.config_model.inputs[0], None)
20 | if url is None:
21 | raise Exception(f'No available input: {self.config_model.inputs[0]}')
22 |
23 | result = ''
24 | async with aiohttp.ClientSession() as session:
25 | async with session.get(url) as resp:
26 | resp.raise_for_status()
27 |
28 | result = await resp.text()
29 |
30 | state[self.id] = result
31 |
32 | await self.send_response(ExecResponse(
33 | type='result',
34 | node_id=self.id,
35 | node_type=self.type,
36 | data={
37 | 'result': result,
38 | },
39 | ))
40 |
--------------------------------------------------------------------------------
/src/argo_workflow_runner/modules/tts.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import aiohttp
3 | from typing import Dict
4 |
5 | from argo_workflow_runner.core.exec_node import ExecNode
6 | from argo_workflow_runner.core.schema import (
7 | TTSConfig,
8 | ExecResponse,
9 | )
10 | import argo_workflow_runner.utils.tts as tts
11 | import argo_workflow_runner.utils.ws_messager as ws_messager
12 | from argo_workflow_runner.configs import logger
13 |
14 | class TTSNode(ExecNode):
15 | def __init__(self, info, websocket):
16 | super().__init__(info, websocket)
17 | self.config_model = TTSConfig.model_validate(self.config)
18 |
19 | async def execute(self, state: Dict):
20 | await super().execute(state)
21 |
22 | text = state.get(self.config_model.inputs[0], None)
23 | if text is None:
24 | raise Exception(f'No available input: {self.config_model.inputs[0]}, in {self.id}#{self.type}#{self.config_model.name}')
25 |
26 | async for audio in tts.fetch_audio_stream(text, self.config_model.voice):
27 | await ws_messager.ws_send_audio(self.websocket, audio)
28 |
29 | await ws_messager.ws_send_audio_end(self.websocket)
30 |
--------------------------------------------------------------------------------
/src/argo_workflow_runner/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI3lab/argo-workflow-runner/0e177ea21748302ccff23ba61460ce920f17659e/src/argo_workflow_runner/utils/__init__.py
--------------------------------------------------------------------------------
/src/argo_workflow_runner/utils/llm.py:
--------------------------------------------------------------------------------
1 | from langchain_openai import ChatOpenAI
2 | from langchain.prompts import PromptTemplate
3 | from langchain.prompts.chat import ChatPromptTemplate
4 | from langchain.chains.summarize import load_summarize_chain
5 | from langchain_core.documents import Document
6 | from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
7 | from typing import Union, List, Tuple, Dict, Any, Literal
8 | from argo_workflow_runner.env_settings import settings
9 |
10 |
11 | async def get_chat_model(
12 | model,
13 | temperature=0,
14 | ) -> ChatOpenAI:
15 | return ChatOpenAI(
16 | temperature=temperature,
17 | base_url=settings.LLM_BASE_URL,
18 | api_key=settings.LLM_KEY,
19 | model=model,
20 | )
21 |
22 | async def llm_chat(
23 | prompt: Union[str, ChatPromptTemplate],
24 | input: Dict[str, Any],
25 | model='meta-llama/llama-3.2-11b-vision-instruct:free',
26 | temperature=0,
27 | ):
28 | chat_model = await get_chat_model(model=model, temperature=temperature)
29 |
30 |     if isinstance(prompt, str):
31 | prompt_tpl = PromptTemplate.from_template(prompt)
32 | else:
33 | prompt_tpl = prompt
34 |
35 | chain = prompt_tpl | chat_model | StrOutputParser()
36 |
37 | result = await chain.ainvoke(input)
38 |
39 | return result
40 |
41 | async def llm_chat_stream(
42 | prompt: Union[str, ChatPromptTemplate],
43 | input: Dict[str, Any],
44 | model='meta-llama/llama-3.2-11b-vision-instruct:free',
45 | temperature=0,
46 | ):
47 | chat_model = await get_chat_model(model=model, temperature=temperature)
48 |
49 |     if isinstance(prompt, str):
50 | prompt_tpl = PromptTemplate.from_template(prompt)
51 | else:
52 | prompt_tpl = prompt
53 |
54 | chain = prompt_tpl | chat_model | StrOutputParser()
55 | async for result in chain.astream(input):
56 | yield result
57 |
58 | async def summarize(
59 | docs: List[Document],
60 | model='meta-llama/llama-3.2-11b-vision-instruct:free',
61 | temperature=0,
62 | ):
63 | chat_model = await get_chat_model(model=model, temperature=temperature)
64 |
65 | chain = load_summarize_chain(chat_model, chain_type="stuff")
66 | result = await chain.ainvoke(docs, return_only_outputs=True)
67 |
68 | return result["output_text"]
69 |
70 | async def llm_analyze(
71 | prompt: Union[str, ChatPromptTemplate],
72 | input: Dict[str, Any],
73 | output_type: Literal['json', 'text']='json',
74 | model='meta-llama/llama-3.2-11b-vision-instruct:free',
75 | temperature=0,
76 | ):
77 | chat_model = await get_chat_model(model=model, temperature=temperature)
78 |
79 |     if isinstance(prompt, str):
80 | prompt_tpl = PromptTemplate.from_template(prompt)
81 | else:
82 | prompt_tpl = prompt
83 |
84 | if output_type == 'json':
85 | chain = prompt_tpl | chat_model | JsonOutputParser()
86 | else:
87 | chain = prompt_tpl | chat_model | StrOutputParser()
88 |
89 | return await chain.ainvoke(input)
90 |
--------------------------------------------------------------------------------
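
Usage sketch for the helpers above. It assumes `LLM_BASE_URL`/`LLM_KEY` are configured in the environment settings and that the endpoint serves the model named here; the prompts and model name are placeholders:

```python
import asyncio

from argo_workflow_runner.utils.llm import llm_chat, llm_chat_stream


async def main():
    # One-shot completion.
    answer = await llm_chat(
        prompt='Translate to French: {text}',
        input={'text': 'good morning'},
        model='gpt-3.5-turbo',
    )
    print(answer)

    # Token-by-token streaming.
    async for chunk in llm_chat_stream(
        prompt='Write one sentence about {topic}.',
        input={'topic': 'workflows'},
        model='gpt-3.5-turbo',
    ):
        print(chunk, end='', flush=True)


asyncio.run(main())
```
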
/src/argo_workflow_runner/utils/sse_client.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Main module."""
3 | import asyncio
4 | import logging
5 | from datetime import timedelta
6 | from typing import Optional, Dict, Any
7 |
8 | import attr
9 | from aiohttp import hdrs, ClientSession, ClientConnectionError
10 | from multidict import MultiDict
11 | from yarl import URL
12 |
13 | READY_STATE_CONNECTING = 0
14 | READY_STATE_OPEN = 1
15 | READY_STATE_CLOSED = 2
16 |
17 | DEFAULT_RECONNECTION_TIME = timedelta(seconds=5)
18 | DEFAULT_MAX_CONNECT_RETRY = 5
19 | DEFAULT_MAX_READ_RETRY = 10
20 |
21 | CONTENT_TYPE_EVENT_STREAM = 'text/event-stream'
22 | LAST_EVENT_ID_HEADER = 'Last-Event-Id'
23 |
24 | _LOGGER = logging.getLogger(__name__)
25 |
26 |
27 | @attr.s(slots=True, frozen=True)
28 | class MessageEvent:
29 | """Represent DOM MessageEvent Interface
30 |
31 | .. seealso:: https://www.w3.org/TR/eventsource/#dispatchMessage section 4
32 | .. seealso:: https://developer.mozilla.org/en-US/docs/Web/API/MessageEvent
33 | """
34 | type = attr.ib(type=str)
35 | message = attr.ib(type=str)
36 | data = attr.ib(type=str)
37 | origin = attr.ib(type=str)
38 | last_event_id = attr.ib(type=str)
39 |
40 |
41 | class EventSource:
42 | """Represent EventSource Interface as an async context manager.
43 |
44 | .. code-block:: python
45 |
46 | from aiohttp_sse_client import client as sse_client
47 |
48 | async with sse_client.EventSource(
49 | 'https://stream.wikimedia.org/v2/stream/recentchange'
50 | ) as event_source:
51 | try:
52 | async for event in event_source:
53 | print(event)
54 | except ConnectionError:
55 | pass
56 |
57 | .. seealso:: https://www.w3.org/TR/eventsource/#eventsource
58 | """
59 | def __init__(self, url: str,
60 | option: Optional[Dict[str, Any]] = None,
61 | reconnection_time: timedelta = DEFAULT_RECONNECTION_TIME,
62 | max_connect_retry: int = DEFAULT_MAX_CONNECT_RETRY,
63 | session: Optional[ClientSession] = None,
64 | on_open=None,
65 | on_message=None,
66 | on_error=None,
67 | not_reconnect_when_ended=True,
68 | **kwargs):
69 | """Construct EventSource instance.
70 |
71 | :param url: specifies the URL to which to connect
72 | :param option: specifies the settings, if any,
73 |             in the form of a Dict[str, Any]. Accept the "method" key for
74 | specifying the HTTP method with which connection
75 | should be established
76 | :param reconnection_time: wait time before try to reconnect in case
77 | connection broken
78 |         :param session: specifies an aiohttp.ClientSession; if not given, create
79 | a default ClientSession
80 | :param on_open: event handler for open event
81 | :param on_message: event handler for message event
82 | :param on_error: event handler for error event
83 | :param kwargs: keyword arguments will pass to underlying
84 | aiohttp request() method.
85 | """
86 | self._url = URL(url)
87 | self._ready_state = READY_STATE_CONNECTING
88 |
89 | if session is not None:
90 | self._session = session
91 | self._need_close_session = False
92 | else:
93 | self._session = ClientSession()
94 | self._need_close_session = True
95 |
96 | self._on_open = on_open
97 | self._on_message = on_message
98 | self._on_error = on_error
99 |
100 | self._reconnection_time = reconnection_time
101 |         self._original_reconnection_time = reconnection_time
102 | self._max_connect_retry = max_connect_retry
103 | self._last_event_id = ''
104 | self._kwargs = kwargs
105 | if 'headers' not in self._kwargs:
106 | self._kwargs['headers'] = MultiDict()
107 |
108 | self._event_id = ''
109 | self._event_type = ''
110 | self._event_data = ''
111 |
112 | self._origin = None
113 | self._response = None
114 |
115 | self._method = 'GET' if option is None else option.get('method', 'GET')
116 |
117 | self._not_reconnect_when_ended = not_reconnect_when_ended
118 |
119 | def __enter__(self):
120 | """Use async with instead."""
121 | raise TypeError("Use async with instead")
122 |
123 | def __exit__(self, *exc):
124 | """Should exist in pair with __enter__ but never executed."""
125 | pass # pragma: no cover
126 |
127 | async def __aenter__(self) -> 'EventSource':
128 | """Connect and listen Server-Sent Event."""
129 | await self.connect(self._max_connect_retry)
130 | return self
131 |
132 | async def __aexit__(self, *exc):
133 | """Close connection and session if need."""
134 | await self.close()
135 | if self._need_close_session:
136 | await self._session.close()
137 | pass
138 |
139 | @property
140 | def url(self) -> URL:
141 | """Return URL to which to connect."""
142 | return self._url
143 |
144 | @property
145 | def ready_state(self) -> int:
146 | """Return ready state."""
147 | return self._ready_state
148 |
149 | def __aiter__(self):
150 | """Return"""
151 | return self
152 |
153 | async def __anext__(self) -> MessageEvent:
154 | """Process events"""
155 | if not self._response:
156 | raise ValueError
157 |
158 | # async for ... in StreamReader only split line by \n
159 | while self._response.status != 204:
160 | async for line_in_bytes in self._response.content:
161 | line = line_in_bytes.decode('utf8') # type: str
162 | line = line.rstrip('\n').rstrip('\r')
163 |
164 | if line == '':
165 | # empty line
166 | event = self._dispatch_event()
167 | if event is not None:
168 | return event
169 | continue
170 |
171 | if line[0] == ':':
172 | # comment line, ignore
173 | continue
174 |
175 | if ':' in line:
176 | # contains ':'
177 | fields = line.split(':', 1)
178 | field_name = fields[0]
179 | field_value = fields[1].lstrip(' ')
180 | self._process_field(field_name, field_value)
181 | else:
182 | self._process_field(line, '')
183 |
184 | if self._not_reconnect_when_ended:
185 | raise StopAsyncIteration
186 | self._ready_state = READY_STATE_CONNECTING
187 | if self._on_error:
188 | self._on_error()
189 | self._reconnection_time *= 2
190 | _LOGGER.debug('wait %s seconds for retry',
191 | self._reconnection_time.total_seconds())
192 | await asyncio.sleep(
193 | self._reconnection_time.total_seconds())
194 | await self.connect()
195 | raise StopAsyncIteration
196 |
197 | async def connect(self, retry=0):
198 | """Connect to resource."""
199 | _LOGGER.debug('connect')
200 | headers = self._kwargs['headers']
201 |
202 | # For HTTP connections, the Accept header may be included;
203 | # if included, it must contain only formats of event framing that are
204 | # supported by the user agent (one of which must be text/event-stream,
205 | # as described below).
206 | headers[hdrs.ACCEPT] = CONTENT_TYPE_EVENT_STREAM
207 |
208 | # If the event source's last event ID string is not the empty string,
209 | # then a Last-Event-Id HTTP header must be included with the request,
210 | # whose value is the value of the event source's last event ID string,
211 | # encoded as UTF-8.
212 | if self._last_event_id != '':
213 | headers[LAST_EVENT_ID_HEADER] = self._last_event_id
214 |
215 | # User agents should use the Cache-Control: no-cache header in
216 | # requests to bypass any caches for requests of event sources.
217 | headers[hdrs.CACHE_CONTROL] = 'no-cache'
218 |
219 | try:
220 | response = await self._session.request(
221 | self._method,
222 | self._url,
223 | **self._kwargs
224 | )
225 | except ClientConnectionError:
226 | if retry <= 0 or self._ready_state == READY_STATE_CLOSED:
227 | await self._fail_connect()
228 | raise
229 | else:
230 | self._ready_state = READY_STATE_CONNECTING
231 | if self._on_error:
232 | self._on_error()
233 | self._reconnection_time *= 2
234 | _LOGGER.debug('wait %s seconds for retry',
235 | self._reconnection_time.total_seconds())
236 | await asyncio.sleep(
237 | self._reconnection_time.total_seconds())
238 | await self.connect(retry - 1)
239 | return
240 |
241 | if response.status >= 400 or response.status == 305:
242 | error_message = 'fetch {} failed: {}'.format(
243 | self._url, response.status)
244 | _LOGGER.error(error_message)
245 |
246 | await self._fail_connect()
247 |
248 | if response.status in [305, 401, 407]:
249 | raise ConnectionRefusedError(error_message)
250 | raise ConnectionError(error_message)
251 |
252 | if response.status != 200:
253 | error_message = 'fetch {} failed with wrong response status: {}'. \
254 | format(self._url, response.status)
255 | _LOGGER.error(error_message)
256 | await self._fail_connect()
257 | raise ConnectionAbortedError(error_message)
258 |
259 | if response.content_type != CONTENT_TYPE_EVENT_STREAM:
260 | error_message = \
261 | 'fetch {} failed with wrong Content-Type: {}'.format(
262 | self._url, response.headers.get(hdrs.CONTENT_TYPE))
263 | _LOGGER.error(error_message)
264 |
265 | await self._fail_connect()
266 | raise ConnectionAbortedError(error_message)
267 |
268 | # only status == 200 and content_type == 'text/event-stream'
269 | await self._connected()
270 |
271 | self._response = response
272 | self._origin = str(response.real_url.origin())
273 |
274 | async def close(self):
275 | """Close connection."""
276 | _LOGGER.debug('close')
277 | self._ready_state = READY_STATE_CLOSED
278 | if self._response is not None:
279 | self._response.close()
280 | self._response = None
281 |
282 | async def _connected(self):
283 | """Announce the connection is made."""
284 | if self._ready_state != READY_STATE_CLOSED:
285 | self._ready_state = READY_STATE_OPEN
286 | if self._on_open:
287 | self._on_open()
288 |             self._reconnection_time = self._original_reconnection_time
289 |
290 | async def _fail_connect(self):
291 | """Announce the connection is failed."""
292 | if self._ready_state != READY_STATE_CLOSED:
293 | self._ready_state = READY_STATE_CLOSED
294 | if self._on_error:
295 | self._on_error()
296 | pass
297 |
298 | def _dispatch_event(self):
299 | """Dispatch event."""
300 | self._last_event_id = self._event_id
301 |
302 | if self._event_data == '':
303 | self._event_type = ''
304 | return
305 |
306 | self._event_data = self._event_data.rstrip('\n')
307 |
308 | message = MessageEvent(
309 | type=self._event_type if self._event_type != '' else None,
310 | message=self._event_type,
311 | data=self._event_data,
312 | origin=self._origin,
313 | last_event_id=self._last_event_id
314 | )
315 | _LOGGER.debug(message)
316 | if self._on_message:
317 | self._on_message(message)
318 |
319 | self._event_type = ''
320 | self._event_data = ''
321 | return message
322 |
323 | def _process_field(self, field_name, field_value):
324 | """Process field."""
325 | if field_name == 'event':
326 | self._event_type = field_value
327 |
328 | elif field_name == 'data':
329 | self._event_data += field_value
330 | self._event_data += '\n'
331 |
332 |         elif field_name == 'id' and '\u0000' not in field_value:
333 | self._event_id = field_value
334 |
335 | elif field_name == 'retry':
336 | try:
337 | retry_in_ms = int(field_value)
338 | self._reconnection_time = timedelta(milliseconds=retry_in_ms)
339 | except ValueError:
340 | _LOGGER.warning('Received invalid retry value %s, ignore it',
341 | field_value)
342 | pass
343 |
344 | pass
--------------------------------------------------------------------------------
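
A standalone walk-through of the SSE line grammar that `__anext__` and `_process_field` implement: an empty line dispatches the buffered event, a leading `:` marks a comment, and `data` lines accumulate with the trailing newline stripped at dispatch. The input below is a made-up stream:

```python
raw = [
    ': a comment, ignored',
    'event: update',
    'id: 42',
    'data: first line',
    'data: second line',
    '',  # empty line -> dispatch one event
]

event_type, event_id, event_data = '', '', ''
for line in raw:
    if line == '':
        print({'type': event_type, 'id': event_id,
               'data': event_data.rstrip('\n')})
        event_type, event_data = '', ''
        continue
    if line[0] == ':':
        continue  # comment line
    name, _, value = line.partition(':')
    value = value.lstrip(' ')
    if name == 'event':
        event_type = value
    elif name == 'data':
        event_data += value + '\n'
    elif name == 'id':
        event_id = value

# prints: {'type': 'update', 'id': '42', 'data': 'first line\nsecond line'}
```
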
/src/argo_workflow_runner/utils/tts.py:
--------------------------------------------------------------------------------
1 | from urllib.parse import quote
2 | import aiohttp
3 | from argo_workflow_runner.configs import logger
4 | from argo_workflow_runner.env_settings import settings
5 |
6 |
7 | async def gpt_sovits_audio_stream_gen(text, voice):
8 |     urlencoded_text = quote(text)
9 | got_first_chunk = False
10 | async with aiohttp.ClientSession() as session:
11 | payload = {
12 | "cha_name": voice,
13 | "character_emotion": "default",
14 | "text": urlencoded_text,
15 | "text_language": "auto",
16 | "batch_size": 10,
17 | "speed": 1,
18 | "top_k": 6,
19 | "top_p": 0.8,
20 | "temperature": 0.8,
21 | "stream": "True",
22 | "cut_method": "auto_cut_25",
23 | "seed": -1,
24 | "save_temp": "False"
25 | }
26 | url = settings.TTS_GPT_SOVITS_URL
27 | async with session.post(url, json=payload) as resp:
28 | if resp.status == 200:
29 | async for data in resp.content.iter_chunked(1024):
30 | if not got_first_chunk:
31 |                         logger.info('gpt_sovits_audio_stream_gen got first chunk')
32 | got_first_chunk = True
33 | yield data
34 | else:
35 | content = await resp.text()
36 | logger.error(f'connect to {url} fail: {content}')
37 |
38 | async def fetch_audio_stream(text, voice='Emma'):
39 | async_audio_stream = gpt_sovits_audio_stream_gen
40 |
41 | async for audio_data in async_audio_stream(text, voice):
42 | yield audio_data
43 |
--------------------------------------------------------------------------------
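
Usage sketch: collect the audio stream into a file. It assumes a GPT-SoVITS server is reachable at `settings.TTS_GPT_SOVITS_URL`; the output file name is a placeholder, and the format is whatever raw audio the server streams:

```python
import asyncio

from argo_workflow_runner.utils.tts import fetch_audio_stream


async def main():
    with open('out.audio', 'wb') as f:
        async for chunk in fetch_audio_stream('Hello there', voice='Emma'):
            f.write(chunk)  # raw audio bytes as streamed by the server


asyncio.run(main())
```
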
/src/argo_workflow_runner/utils/web_search.py:
--------------------------------------------------------------------------------
1 | import os
2 | import asyncio
3 | from langchain_community.tools.tavily_search import TavilySearchResults
4 | import serpapi
5 | import time
6 | from argo_workflow_runner.env_settings import settings
7 | from argo_workflow_runner.configs import logger
8 |
9 |
10 |
11 | os.environ['TAVILY_API_KEY'] = settings.TAVILY_API_KEY
12 |
13 | async def tavily_search(question: str):
14 | web_search_tool = TavilySearchResults(k=3)
15 | try:
16 | start_tm = time.time()
17 | docs = await web_search_tool.ainvoke({"query": question})
18 |         logger.info(f'tavily_search took {time.time() - start_tm:.2f} seconds')
19 | return [d["content"] for d in docs]
20 |     except Exception:
21 | return []
22 |
23 |
24 | async def serp_api_search(question: str):
25 | params = {
26 | "engine": "google",
27 | "q": question,
28 | "location": "Seattle-Tacoma, WA, Washington, United States",
29 | "hl": "en",
30 | "gl": "us",
31 | "google_domain": "google.com",
32 | "num": "10",
33 | "safe": "active",
34 | }
35 |
36 | client = serpapi.Client(api_key=settings.SERP_API_KEY)
37 | try:
38 | start_tm = time.time()
39 | results = await asyncio.to_thread(client.search, params)
40 |         logger.info(f'serp_api_search took {time.time() - start_tm:.2f} seconds')
41 |
42 | return list(map(lambda x: x['snippet'], results['organic_results']))
43 |     except Exception:
44 | return []
45 |
--------------------------------------------------------------------------------
/src/argo_workflow_runner/utils/ws_messager.py:
--------------------------------------------------------------------------------
1 | from enum import Enum, unique
2 | import websockets
3 |
4 | @unique
5 | class WsMessageType(Enum):
6 | AUDIO = 0x1
7 |
8 | async def ws_send_audio(websocket: websockets.WebSocketServerProtocol, msg: bytes):
9 | await ws_send_msg(websocket, WsMessageType.AUDIO, msg)
10 |
11 | async def ws_send_audio_end(websocket: websockets.WebSocketServerProtocol):
12 | await ws_send_msg_end(websocket, WsMessageType.AUDIO)
13 |
14 | async def ws_send_msg(websocket: websockets.WebSocketServerProtocol, msg_type: WsMessageType, msg: bytes):
15 | if len(msg) == 0:
16 | return
17 | wrapped_msg = msg_type.value.to_bytes(1, byteorder='little') + msg
18 | await websocket.send(wrapped_msg)
19 |
20 | async def ws_send_msg_end(websocket: websockets.WebSocketServerProtocol, msg_type: WsMessageType):
21 | wrapped_msg = msg_type.value.to_bytes(1, byteorder='little')
22 | await websocket.send(wrapped_msg)
23 |
--------------------------------------------------------------------------------
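
Client-side view of the one-byte framing used above: byte 0 is the `WsMessageType` value, the remainder is the payload, and a frame with no payload marks end-of-stream for that type. A minimal decoder sketch:

```python
AUDIO = 0x1  # mirrors WsMessageType.AUDIO


def decode_frame(frame: bytes):
    # Byte 0 is the message type; the rest is the payload.
    return frame[0], frame[1:]


t, body = decode_frame(bytes([AUDIO]) + b'\x00\x01\x02')
assert t == AUDIO and body == b'\x00\x01\x02'

t, body = decode_frame(bytes([AUDIO]))
assert t == AUDIO and body == b''  # bare type byte -> end-of-audio marker
```
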