├── .env.example ├── .github └── workflows │ └── ci.yml ├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── integration_tests ├── env_variable_names.py ├── helpers.py ├── samples │ ├── demo │ │ └── demo.py │ ├── doc │ │ └── doc.py │ ├── short_mode │ │ └── test.py │ └── socket_mode │ │ ├── websocket_client_example_another.py │ │ ├── websocket_client_example_ase.py │ │ └── websocket_client_example_intactive.py └── sparkai │ └── memory │ └── test_memory.py ├── log.jpg ├── pyproject.toml ├── sparkai ├── __init__.py ├── core │ ├── __init__.py │ ├── _api │ │ ├── __init__.py │ │ ├── beta_decorator.py │ │ ├── deprecation.py │ │ ├── internal.py │ │ └── path.py │ ├── _base_api.py │ ├── caches.py │ ├── callbacks │ │ ├── __init__.py │ │ ├── base.py │ │ ├── manager.py │ │ ├── stdout.py │ │ └── streaming_stdout.py │ ├── globals │ │ └── __init__.py │ ├── language_models │ │ ├── __init__.py │ │ ├── base.py │ │ ├── chat_models.py │ │ └── llms.py │ ├── load │ │ ├── __init__.py │ │ ├── dump.py │ │ ├── load.py │ │ ├── mapping.py │ │ └── serializable.py │ ├── messages │ │ ├── __init__.py │ │ ├── ai.py │ │ ├── base.py │ │ ├── chat.py │ │ ├── function.py │ │ ├── human.py │ │ ├── image_chat.py │ │ ├── system.py │ │ └── tool.py │ ├── outputs │ │ ├── __init__.py │ │ ├── chat_generation.py │ │ ├── chat_result.py │ │ ├── generation.py │ │ ├── llm_result.py │ │ └── run_info.py │ ├── prompt_values.py │ ├── prompts │ │ ├── __init__.py │ │ ├── base.py │ │ ├── chat.py │ │ ├── few_shot.py │ │ ├── few_shot_with_templates.py │ │ ├── image.py │ │ ├── loading.py │ │ ├── pipeline.py │ │ ├── prompt.py │ │ └── string.py │ ├── pydantic_v1 │ │ ├── __init__.py │ │ ├── dataclasses.py │ │ └── main.py │ ├── runnables │ │ ├── __init__.py │ │ ├── base.py │ │ ├── branch.py │ │ ├── config.py │ │ ├── configurable.py │ │ ├── fallbacks.py │ │ ├── graph.py │ │ ├── graph_draw.py │ │ ├── history.py │ │ ├── passthrough.py │ │ ├── retry.py │ │ ├── router.py │ │ ├── schema.py │ │ └── utils.py │ ├── tools.py │ ├── 
tracers │ │ ├── __init__.py │ │ ├── base.py │ │ ├── context.py │ │ ├── evaluation.py │ │ ├── langchain.py │ │ ├── langchain_v1.py │ │ ├── log_stream.py │ │ ├── memory_stream.py │ │ ├── root_listeners.py │ │ ├── run_collector.py │ │ ├── schemas.py │ │ └── stdout.py │ └── utils │ │ ├── __init__.py │ │ ├── _merge.py │ │ ├── aiter.py │ │ ├── env.py │ │ ├── formatting.py │ │ ├── function_calling.py │ │ ├── html.py │ │ ├── image.py │ │ ├── input.py │ │ ├── interactive_env.py │ │ ├── iter.py │ │ ├── json_schema.py │ │ ├── loading.py │ │ ├── pydantic.py │ │ ├── strings.py │ │ └── utils.py ├── deprecation.py ├── depreciated │ ├── __init__.py │ ├── client │ │ ├── __init__.py │ │ ├── llm.py │ │ └── sample_langchain_spark.py │ └── service │ │ ├── __init__.py │ │ ├── api_server.py │ │ └── spark_ws.py ├── embedding │ ├── __init__.py │ ├── spark_embedding.py │ └── sparkai_base.py ├── errors │ └── __init__.py ├── exceptions.py ├── frameworks │ ├── __init__.py │ ├── autogen │ │ └── __init__.py │ └── llama_index │ │ └── __init__.py ├── http_retry │ ├── __init__.py │ ├── async_handler.py │ ├── builtin_async_handlers.py │ ├── builtin_handlers.py │ ├── builtin_interval_calculators.py │ ├── handler.py │ ├── interval_calculator.py │ ├── jitter.py │ ├── request.py │ ├── response.py │ └── state.py ├── llm │ ├── __init__.py │ └── llm.py ├── log │ ├── __init__.py │ └── logger.py ├── memory │ ├── __init__.py │ ├── buffer.py │ ├── buffer_window.py │ ├── chat_memory.py │ ├── chat_message_histories │ │ ├── __init__.py │ │ ├── dynamodb.py │ │ ├── file.py │ │ ├── in_memory.py │ │ ├── postgres.py │ │ └── redis.py │ ├── combined.py │ ├── readonly.py │ ├── simple.py │ ├── token_buffer.py │ └── utils.py ├── messages.py ├── models │ ├── __init__.py │ ├── basic_objects.py │ └── chat │ │ └── __init__.py ├── prompts │ └── classification │ │ └── __init__.py ├── proxy_env_variable_loader.py ├── schema.py ├── socket_mode │ ├── __init__.py │ ├── client.py │ ├── interval_runner.py │ ├── listeners.py │ ├── 
request.py │ ├── response.py │ └── websocket_client │ │ └── __init__.py ├── spark_proxy │ ├── generate_message.py │ ├── generate_stream.py │ ├── main.py │ ├── openai_types.py │ ├── server.py │ ├── spark_api.py │ └── spark_auth.py ├── v2 │ ├── __init__.py │ ├── client │ │ ├── __init__.py │ │ ├── common │ │ │ ├── __init__.py │ │ │ └── consts.py │ │ ├── http │ │ │ └── __init__.py │ │ └── ws │ │ │ └── __init__.py │ ├── core │ │ └── __init__.py │ └── llm │ │ └── __init__.py ├── version.py └── xf_util.py ├── tests ├── embedding_test │ ├── embedding_test.py │ └── test_llama.py ├── examples │ ├── docs │ │ ├── agent_artchitect.png │ │ ├── agents.png │ │ ├── autogen_grouchat_with_graph.md │ │ ├── llama-index.png │ │ ├── llama_index.md │ │ ├── proxy_open_ai.md │ │ ├── test_openai_proxy.py │ │ └── token_usage.png │ ├── llama_index_embedding.png │ ├── llama_index_text.py │ ├── llm_test.py │ └── spark_llama_index.png ├── openai_test │ ├── mock_multi_function.py │ └── multi_function.py └── sparkai_test │ ├── prompts │ ├── wrapper_write_code.txt │ └── wrapper_write_code_spark.txt │ └── wrapper_write.py └── weichat.jpg /.env.example: -------------------------------------------------------------------------------- 1 | # spark 授权信息 2 | SPARKAI_APP_ID= 3 | SPARKAI_API_KEY= 4 | SPARKAI_API_SECRET= 5 | SPARKAI_DOMAIN= 6 | SPARKAI_URL= 7 | 8 | 9 | # spark embedding 授权信息 10 | SPARK_Embedding_APP_ID= 11 | SPARK_Embedding_API_KEY= 12 | SPARK_Embedding_API_SECRET= 13 | SPARKAI_Embedding_DOMAIN= -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | on: 3 | push: 4 | # Pattern matched against refs/tags 5 | tags: 6 | - 'v[0-9]+.[0-9]+.[0-9]+' 7 | jobs: 8 | ci: 9 | strategy: 10 | fail-fast: false 11 | matrix: 12 | python-version: [ "3.9", "3.10", "3.11"] 13 | 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v4 17 | - name: 
Install Python 18 | uses: actions/setup-python@v4 19 | # see details (matrix, python-version, python-version-file, etc.) 20 | # https://github.com/actions/setup-python 21 | - name: Install poetry 22 | uses: abatilo/actions-poetry@v2 23 | 24 | - name: Setup a local virtual environment (if no poetry.toml file) 25 | run: | 26 | poetry config virtualenvs.create true --local 27 | poetry config virtualenvs.in-project true --local 28 | pip install pytest 29 | 30 | - uses: actions/cache@v3 31 | name: Define a cache for the virtual environment based on the dependencies lock file 32 | with: 33 | path: ./.venv 34 | key: venv-${{ hashFiles('poetry.lock') }} 35 | 36 | - name: Install the project dependencies 37 | run: poetry install 38 | 39 | - name: Run the automated tests (for example) 40 | run: poetry run pytest -v -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | **__pycache__** 3 | .venv 4 | poetry.lock 5 | .env 6 | .env_* 7 | 8 | 9 | /meetkai/ 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 iFLYTEK 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | build: 2 | poetry build 3 | ls -lt dist | grep "tar.gz" | head -n 1 |awk '{print "./dist/"$$9}' |xargs pip install 4 | 5 | 6 | publish: 7 | poetry publish 8 | 9 | publish-custom: 10 | poetry publish -r my-custom-repo -------------------------------------------------------------------------------- /integration_tests/env_variable_names.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iflytek/spark-ai-python/735eb3d56d6f0fffadb6c9cc0bb7d6530b788684/integration_tests/env_variable_names.py -------------------------------------------------------------------------------- /integration_tests/helpers.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import inspect 3 | import sys 4 | from asyncio.events import AbstractEventLoop 5 | 6 | 7 | def async_test(coro): 8 | loop: AbstractEventLoop = asyncio.new_event_loop() 9 | asyncio.set_event_loop(loop) 10 | 11 | def wrapper(*args, **kwargs): 12 | current_loop: AbstractEventLoop = asyncio.get_event_loop() 13 | return current_loop.run_until_complete(coro(*args, **kwargs)) 14 | 15 | return wrapper 16 | 17 | 18 | def is_not_specified() -> bool: 19 | # get the caller's filepath 20 | frame = inspect.stack()[1] 21 | module = 
inspect.getmodule(frame[0]) 22 | filepath: str = module.__file__ 23 | 24 | # python setup.py integration_tests --test-target=web/test_issue_560.py 25 | test_target: str = sys.argv[1] # e.g., web/test_issue_560.py 26 | return not test_target or not filepath.endswith(test_target) 27 | -------------------------------------------------------------------------------- /integration_tests/samples/demo/demo.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import json 3 | import logging 4 | 5 | logging.basicConfig(level=logging.WARNING) 6 | 7 | import os 8 | from threading import Event 9 | from sparkai.socket_mode.websocket_client import SparkAISocketModeClient 10 | from sparkai.memory import ChatMessageHistory 11 | 12 | if __name__ == "__main__": 13 | client = SparkAISocketModeClient( 14 | app_id=os.environ.get("APP_ID"), 15 | api_key=os.environ.get("API_KEY"), 16 | api_secret=os.environ.get("API_SECRET"), 17 | chat_interactive=False, 18 | trace_enabled=False, 19 | conversation_memory=ChatMessageHistory() 20 | ) 21 | 22 | client.connect() 23 | result = client.chat_with_histories( 24 | [ 25 | {'role': 'user', 'content': '请帮我完成目标:\n\n帮我生成一个 2到2000的随机数\n\n'}, {'role': 'assistant', 26 | 'content': '{\n\n"thoughts": {\n\n"text": "Generate a random number between 2 and 2000.",\n\n"reasoning": "To complete this task, I will need to access the internet for information gathering.",\n\n"plan": "I will use the random_number command with the min and max arguments set to 2 and 2000, respectively.",\n\n"criticism": "",\n\n"speak": "The random number generated is: 1587."\n\n},\n\n"command": {\n\n"name": "random_number",\n\n"args": {\n\n"min": "2",\n\n"max": "2000"\n\n}\n\n}\n\n}'}, 27 | {'role': 'user', 'content': '\n请帮我完成目标:\n\n帮我把这个随机数 发给 ybyang7@iflytek.com 并告诉他这个随机数很重要\n\n'}]) 28 | 29 | if result: 30 | print(result.content) 31 | 32 | Event().wait() 33 | 
-------------------------------------------------------------------------------- /integration_tests/samples/doc/doc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding:utf-8 3 | """ 4 | @author: nivic ybyang7 5 | @license: Apache Licence 6 | @file: doc 7 | @time: 2023/05/18 8 | @contact: ybyang7@iflytek.com 9 | @site: 10 | @software: PyCharm 11 | 12 | """ 13 | 14 | -------------------------------------------------------------------------------- /integration_tests/samples/socket_mode/websocket_client_example_another.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import json 3 | import logging 4 | 5 | logging.basicConfig(level=logging.WARNING) 6 | 7 | import os 8 | from threading import Event 9 | from sparkai.socket_mode.websocket_client import SparkAISocketModeClient 10 | from sparkai.memory import ChatMessageHistory 11 | 12 | print_question = False 13 | 14 | response_format = { 15 | "thoughts": { 16 | "text": "thought", 17 | "speak": "thoughts summary to say to user", 18 | "plan": "- short bulleted - list that conveys - long-term plan", 19 | "reasoning": "reasoning" 20 | } 21 | } 22 | rf = json.dumps(response_format, indent=4) 23 | 24 | question = "" 25 | query_prompt = f''' 26 | 帮我润色下如下问题: 27 | 28 | {question} 29 | 30 | ''' 31 | 32 | from sparkai.prompts.classification import PROMPTS 33 | 34 | query_prompt1 = f''' 35 | 总结下述问题并按照如下json格式输出: 36 | {rf} 37 | 38 | 请注意回答的结果必须满足下述约束: 39 | 1. 结果响应只能包含json内容 40 | 2. 结果响应不能有markdown内容 41 | 3. 
结果中json格式务必正确且能够被python json.loads 解析 42 | 43 | 现在请回答: {question} 44 | 45 | ''' 46 | if __name__ == "__main__": 47 | client = SparkAISocketModeClient( 48 | app_id=os.environ.get("APP_ID"), 49 | api_key=os.environ.get("API_KEY"), 50 | api_secret=os.environ.get("API_SECRET"), 51 | chat_interactive=False, 52 | trace_enabled=False, 53 | conversation_memory=ChatMessageHistory() 54 | ) 55 | 56 | q = PROMPTS + "帮我发送一份邮件给 ybyang7@iflytek.com, 内容由你帮我生成一段写进去,主要表达欢迎他加入公司的意思就可以" 57 | # q = PROMPTS + "2023年5月8日,合肥天气怎么样" 58 | if print_question: 59 | print("Question: ", q) 60 | client.connect() 61 | result = client.chat_with_histories( 62 | [ 63 | {'role': 'user', 'content': '请帮我完成目标:\n\n帮我生成一个 2到2000的随机数\n\n'}, {'role': 'assistant', 64 | 'content': '{\n\n"thoughts": {\n\n"text": "Generate a random number between 2 and 2000.",\n\n"reasoning": "To complete this task, I will need to access the internet for information gathering.",\n\n"plan": "I will use the random_number command with the min and max arguments set to 2 and 2000, respectively.",\n\n"criticism": "",\n\n"speak": "The random number generated is: 1587."\n\n},\n\n"command": {\n\n"name": "random_number",\n\n"args": {\n\n"min": "2",\n\n"max": "2000"\n\n}\n\n}\n\n}'}, 65 | {'role': 'user', 'content': '\n请帮我完成目标:\n\n帮我把这个随机数 发给 ybyang7@iflytek.com 并告诉他这个随机数很重要\n\n'}]) 66 | 67 | if result: 68 | print(result.content) 69 | # result = client.chat_in("你是谁") 70 | 71 | Event().wait() 72 | -------------------------------------------------------------------------------- /integration_tests/samples/socket_mode/websocket_client_example_ase.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | logging.basicConfig(level=logging.ERROR) 4 | 5 | import os 6 | from threading import Event 7 | from sparkai.socket_mode.request import SocketModeRequest 8 | from sparkai.socket_mode.websocket_client import SparkAISocketModeClient 9 | 10 | client = SparkAISocketModeClient( 11 | 
app_id=os.environ.get("APP_ID"), 12 | api_key=os.environ.get("API_KEY"), 13 | api_secret=os.environ.get("API_SECRET"), 14 | chat_interactive=True, 15 | trace_enabled=False, 16 | ) 17 | 18 | if __name__ == "__main__": 19 | def process(client: SparkAISocketModeClient, req: SocketModeRequest): 20 | pass 21 | 22 | 23 | def on_message(ws, message): 24 | pass 25 | 26 | 27 | def on_open(ws): 28 | pass 29 | 30 | 31 | def on_close(ws): 32 | pass 33 | 34 | 35 | def on_error(ws, error): 36 | pass 37 | 38 | 39 | client.socket_mode_request_listeners.append(process) 40 | client.on_message_listeners.append(on_message) 41 | client.on_open_listeners.append(on_open) 42 | client.on_close_listeners.append(on_close) 43 | client.on_error_listeners.append(on_error) 44 | 45 | client.connect() 46 | 47 | Event().wait() 48 | -------------------------------------------------------------------------------- /integration_tests/samples/socket_mode/websocket_client_example_intactive.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import json 3 | import logging 4 | 5 | logging.basicConfig(level=logging.WARNING) 6 | 7 | import os 8 | from threading import Event 9 | from sparkai.socket_mode.websocket_client import SparkAISocketModeClient 10 | from sparkai.memory import ChatMessageHistory 11 | 12 | print_question = False 13 | 14 | response_format = { 15 | "thoughts": { 16 | "text": "thought", 17 | "speak": "thoughts summary to say to user", 18 | "plan": "- short bulleted - list that conveys - long-term plan", 19 | "reasoning": "reasoning" 20 | } 21 | } 22 | rf = json.dumps(response_format, indent=4) 23 | 24 | question = "" 25 | query_prompt = f''' 26 | 帮我润色下如下问题: 27 | 28 | {question} 29 | 30 | ''' 31 | 32 | from sparkai.prompts.classification import PROMPTS 33 | 34 | query_prompt1 = f''' 35 | 总结下述问题并按照如下json格式输出: 36 | {rf} 37 | 38 | 请注意回答的结果必须满足下述约束: 39 | 1. 结果响应只能包含json内容 40 | 2. 结果响应不能有markdown内容 41 | 3. 
结果中json格式务必正确且能够被python json.loads 解析 42 | 43 | 现在请回答: {question} 44 | 45 | ''' 46 | if __name__ == "__main__": 47 | client = SparkAISocketModeClient( 48 | app_id=os.environ.get("APP_ID"), 49 | api_key=os.environ.get("API_KEY"), 50 | api_secret=os.environ.get("API_SECRET"), 51 | chat_interactive=True, 52 | trace_enabled=False, 53 | conversation_memory=ChatMessageHistory() 54 | ) 55 | 56 | client.connect() 57 | 58 | 59 | Event().wait() 60 | -------------------------------------------------------------------------------- /integration_tests/sparkai/memory/test_memory.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding:utf-8 3 | """ 4 | @author: nivic ybyang7 5 | @license: Apache Licence 6 | @file: test_memory 7 | @time: 2023/04/29 8 | @contact: ybyang7@iflytek.com 9 | @site: 10 | @software: PyCharm 11 | 12 | # code is far away from bugs with the god animal protecting 13 | I love animals. They taste delicious. 14 | ┏┓ ┏┓ 15 | ┏┛┻━━━┛┻┓ 16 | ┃ ☃ ┃ 17 | ┃ ┳┛ ┗┳ ┃ 18 | ┃ ┻ ┃ 19 | ┗━┓ ┏━┛ 20 | ┃ ┗━━━┓ 21 | ┃ 神兽保佑 ┣┓ 22 | ┃ 永无BUG! ┏┛ 23 | ┗┓┓┏━┳┓┏┛ 24 | ┃┫┫ ┃┫┫ 25 | ┗┻┛ ┗┻┛ 26 | """ 27 | 28 | # Copyright (c) 2022. Lorem ipsum dolor sit amet, consectetur adipiscing elit. 29 | # Morbi non lorem porttitor neque feugiat blandit. Ut vitae ipsum eget quam lacinia accumsan. 30 | # Etiam sed turpis ac ipsum condimentum fringilla. Maecenas magna. 31 | # Proin dapibus sapien vel ante. Aliquam erat volutpat. Pellentesque sagittis ligula eget metus. 32 | # Vestibulum commodo. Ut rhoncus gravida arcu. 
33 | 34 | 35 | from sparkai.memory import ChatMessageHistory,ConversationBufferMemory 36 | import json 37 | 38 | if __name__ == '__main__': 39 | 40 | history = ChatMessageHistory() 41 | 42 | history.add_user_message("hi!") 43 | 44 | history.add_ai_message("whats up?") 45 | 46 | print(str(history.messages)) 47 | 48 | memory = ConversationBufferMemory(return_messages=True) 49 | memory.chat_memory.add_user_message("hi!") 50 | memory.chat_memory.add_ai_message("whats up?") 51 | 52 | print(memory.load_memory_variables({})) -------------------------------------------------------------------------------- /log.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iflytek/spark-ai-python/735eb3d56d6f0fffadb6c9cc0bb7d6530b788684/log.jpg -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "spark-ai-python" 3 | version = "0.4.5" 4 | description = "a sdk for iflytek's spark LLM." 
5 | authors = ["whybeyoung ", "mingduan "] 6 | license = "MIT" 7 | readme = "README.md" 8 | packages = [ 9 | { include = "sparkai" } 10 | ] 11 | 12 | [tool.poetry.dependencies] 13 | python = ">=3.8.1,<3.13" 14 | aiohttp = ">3.3" 15 | requests = "*" 16 | nest-asyncio = "^1.6.0" 17 | websocket-client = "^1.7.0" 18 | websockets = "*" 19 | nest_asyncio = "*" 20 | python-dotenv = "*" 21 | 22 | packaging = "*" 23 | tenacity = "*" 24 | jsonpatch = "*" 25 | pydantic = "*" 26 | pyyaml = "*" 27 | httpx = "*" 28 | llama-index = "^0.10.24" 29 | llama-index-vector-stores-chroma = "^0.1.6" 30 | llama-index-core = { version = "^0.10.24.post1", optional = true } 31 | pyautogen = { version = ">=0.2.20", optional = true } 32 | uvicorn = { version = ">=0.26.0", optional = true } 33 | fastapi = { extras = ["all"], version = "^0.110.0", optional = true} 34 | 35 | [tool.poetry.extras] 36 | llama_index = [ 37 | "llama-index-core", 38 | "llama-index", 39 | "llama-index-vector-stores-chroma", 40 | 41 | 42 | ] 43 | autogen = [ 44 | "pyautogen" 45 | ] 46 | proxy = [ 47 | "fastapi", 48 | "uvicorn" 49 | ] 50 | 51 | [tool.pytest.ini_options] 52 | 53 | [build-system] 54 | requires = ["poetry-core"] 55 | build-backend = "poetry.core.masonry.api" 56 | 57 | [[tool.poetry.source]] 58 | name = "my-custom-repo" # This name will be used in the configuration to retreive the proper credentials 59 | url = "https://repo.model.xfyun.cn/api/packages/administrator/pypi" # URL used to download your packages from 60 | priority = "primary" 61 | 62 | 63 | [[tool.poetry.source]] 64 | name = "mirrors" 65 | url = "https://pypi.tuna.tsinghua.edu.cn/simple/" 66 | priority = "default" 67 | 68 | -------------------------------------------------------------------------------- /sparkai/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iflytek/spark-ai-python/735eb3d56d6f0fffadb6c9cc0bb7d6530b788684/sparkai/__init__.py 
-------------------------------------------------------------------------------- /sparkai/core/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding:utf-8 3 | """ 4 | @author: nivic ybyang7 5 | @license: Apache Licence 6 | @file: __init__.py 7 | @time: 2024/02/23 8 | @contact: ybyang7@iflytek.com 9 | @site: 10 | @software: PyCharm 11 | 12 | # code is far away from bugs with the god animal protecting 13 | I love animals. They taste delicious. 14 | ┏┓ ┏┓ 15 | ┏┛┻━━━┛┻┓ 16 | ┃ ☃ ┃ 17 | ┃ ┳┛ ┗┳ ┃ 18 | ┃ ┻ ┃ 19 | ┗━┓ ┏━┛ 20 | ┃ ┗━━━┓ 21 | ┃ 神兽保佑 ┣┓ 22 | ┃ 永无BUG! ┏┛ 23 | ┗┓┓┏━┳┓┏┛ 24 | ┃┫┫ ┃┫┫ 25 | ┗┻┛ ┗┻┛ 26 | """ 27 | 28 | # Copyright (c) 2022. Lorem ipsum dolor sit amet, consectetur adipiscing elit. 29 | # Morbi non lorem porttitor neque feugiat blandit. Ut vitae ipsum eget quam lacinia accumsan. 30 | # Etiam sed turpis ac ipsum condimentum fringilla. Maecenas magna. 31 | # Proin dapibus sapien vel ante. Aliquam erat volutpat. Pellentesque sagittis ligula eget metus. 32 | # Vestibulum commodo. Ut rhoncus gravida arcu. 33 | -------------------------------------------------------------------------------- /sparkai/core/_api/__init__.py: -------------------------------------------------------------------------------- 1 | """Helper functions for managing the LangChain API. 2 | 3 | This module is only relevant for LangChain developers, not for users. 4 | 5 | .. warning:: 6 | 7 | This module and its submodules are for internal use only. Do not use them 8 | in your own code. We may change the API at any time with no warning. 
9 | 10 | """ 11 | from .beta_decorator import ( 12 | SparkAIBetaWarning, 13 | beta, 14 | suppress_sparkai_beta_warning, 15 | surface_sparkai_beta_warnings, 16 | ) 17 | from .deprecation import ( 18 | SparkAIDeprecationWarning, 19 | deprecated, 20 | suppress_sparkai_deprecation_warning, 21 | surface_sparkai_deprecation_warnings, 22 | warn_deprecated, 23 | ) 24 | from .path import as_import_path, get_relative_path 25 | 26 | __all__ = [ 27 | "as_import_path", 28 | "beta", 29 | "deprecated", 30 | "get_relative_path", 31 | "SparkAIBetaWarning", 32 | "SparkAIDeprecationWarning", 33 | "suppress_sparkai_beta_warning", 34 | "surface_sparkai_beta_warnings", 35 | "suppress_sparkai_deprecation_warning", 36 | "surface_sparkai_deprecation_warnings", 37 | "warn_deprecated", 38 | ] 39 | -------------------------------------------------------------------------------- /sparkai/core/_api/internal.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | 3 | 4 | def is_caller_internal(depth: int = 2) -> bool: 5 | """Return whether the caller at `depth` of this function is internal.""" 6 | try: 7 | frame = inspect.currentframe() 8 | except AttributeError: 9 | return False 10 | if frame is None: 11 | return False 12 | try: 13 | for _ in range(depth): 14 | frame = frame.f_back 15 | if frame is None: 16 | return False 17 | caller_module = inspect.getmodule(frame) 18 | if caller_module is None: 19 | return False 20 | caller_module_name = caller_module.__name__ 21 | return caller_module_name.startswith("depreciated") 22 | finally: 23 | del frame 24 | -------------------------------------------------------------------------------- /sparkai/core/_api/path.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import Optional, Union 4 | 5 | HERE = Path(__file__).parent 6 | 7 | # Get directory of depreciated package 8 | PACKAGE_DIR = HERE.parent 9 | 
SEPARATOR = os.sep 10 | 11 | 12 | def get_relative_path( 13 | file: Union[Path, str], *, relative_to: Path = PACKAGE_DIR 14 | ) -> str: 15 | """Get the path of the file as a relative path to the package directory.""" 16 | if isinstance(file, str): 17 | file = Path(file) 18 | return str(file.relative_to(relative_to)) 19 | 20 | 21 | def as_import_path( 22 | file: Union[Path, str], 23 | *, 24 | suffix: Optional[str] = None, 25 | relative_to: Path = PACKAGE_DIR, 26 | ) -> str: 27 | """Path of the file as a LangChain import exclude depreciated top namespace.""" 28 | if isinstance(file, str): 29 | file = Path(file) 30 | path = get_relative_path(file, relative_to=relative_to) 31 | if file.is_file(): 32 | path = path[: -len(file.suffix)] 33 | import_path = path.replace(SEPARATOR, ".") 34 | if suffix: 35 | import_path += "." + suffix 36 | return import_path 37 | -------------------------------------------------------------------------------- /sparkai/core/_base_api.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding:utf-8 3 | """ 4 | @author: nivic ybyang7 5 | @license: Apache Licence 6 | @file: _base_api 7 | @time: 2024/02/23 8 | @contact: ybyang7@iflytek.com 9 | @site: 10 | @software: PyCharm 11 | 12 | # code is far away from bugs with the god animal protecting 13 | I love animals. They taste delicious. 14 | ┏┓ ┏┓ 15 | ┏┛┻━━━┛┻┓ 16 | ┃ ☃ ┃ 17 | ┃ ┳┛ ┗┳ ┃ 18 | ┃ ┻ ┃ 19 | ┗━┓ ┏━┛ 20 | ┃ ┗━━━┓ 21 | ┃ 神兽保佑 ┣┓ 22 | ┃ 永无BUG! ┏┛ 23 | ┗┓┓┏━┳┓┏┛ 24 | ┃┫┫ ┃┫┫ 25 | ┗┻┛ ┗┻┛ 26 | """ 27 | 28 | # Copyright (c) 2022. Lorem ipsum dolor sit amet, consectetur adipiscing elit. 29 | # Morbi non lorem porttitor neque feugiat blandit. Ut vitae ipsum eget quam lacinia accumsan. 30 | # Etiam sed turpis ac ipsum condimentum fringilla. Maecenas magna. 31 | # Proin dapibus sapien vel ante. Aliquam erat volutpat. Pellentesque sagittis ligula eget metus. 32 | # Vestibulum commodo. Ut rhoncus gravida arcu. 
33 | from __future__ import annotations 34 | from typing import TYPE_CHECKING 35 | 36 | if TYPE_CHECKING: 37 | from .._client import ZhipuAI 38 | 39 | 40 | class BaseAPI: 41 | _client: ZhipuAI 42 | 43 | def __init__(self, client: ZhipuAI) -> None: 44 | self._client = client 45 | self._delete = client.delete 46 | self._get = client.get 47 | self._post = client.post 48 | self._put = client.put 49 | self._patch = client.patch -------------------------------------------------------------------------------- /sparkai/core/caches.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from abc import ABC, abstractmethod 4 | from typing import Any, Optional, Sequence 5 | 6 | from sparkai.core.outputs import Generation 7 | 8 | RETURN_VAL_TYPE = Sequence[Generation] 9 | 10 | 11 | class BaseCache(ABC): 12 | """Base interface for cache.""" 13 | 14 | @abstractmethod 15 | def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: 16 | """Look up based on prompt and llm_string.""" 17 | 18 | @abstractmethod 19 | def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: 20 | """Update cache based on prompt and llm_string.""" 21 | 22 | @abstractmethod 23 | def clear(self, **kwargs: Any) -> None: 24 | """Clear cache that can take additional keyword arguments.""" 25 | -------------------------------------------------------------------------------- /sparkai/core/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | from sparkai.core.callbacks.base import ( 2 | AsyncCallbackHandler, 3 | BaseCallbackHandler, 4 | BaseCallbackManager, 5 | CallbackManagerMixin, 6 | Callbacks, 7 | ChainManagerMixin, 8 | LLMManagerMixin, 9 | RetrieverManagerMixin, 10 | RunManagerMixin, 11 | ToolManagerMixin, 12 | ) 13 | from sparkai.core.callbacks.manager import ( 14 | AsyncCallbackManager, 15 | AsyncCallbackManagerForLLMRun, 16 | 
AsyncCallbackManagerForToolRun, 17 | AsyncParentRunManager, 18 | AsyncRunManager, 19 | BaseRunManager, 20 | CallbackManager, 21 | CallbackManagerForLLMRun, 22 | CallbackManagerForToolRun, 23 | ParentRunManager, 24 | RunManager, 25 | ) 26 | from sparkai.core.callbacks.stdout import StdOutCallbackHandler 27 | from sparkai.core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler 28 | 29 | __all__ = [ 30 | "RetrieverManagerMixin", 31 | "LLMManagerMixin", 32 | "ChainManagerMixin", 33 | "ToolManagerMixin", 34 | "Callbacks", 35 | "CallbackManagerMixin", 36 | "RunManagerMixin", 37 | "BaseCallbackHandler", 38 | "AsyncCallbackHandler", 39 | "BaseCallbackManager", 40 | "BaseRunManager", 41 | "RunManager", 42 | "ParentRunManager", 43 | "AsyncRunManager", 44 | "AsyncParentRunManager", 45 | "CallbackManagerForLLMRun", 46 | "AsyncCallbackManagerForLLMRun", 47 | "CallbackManagerForToolRun", 48 | "AsyncCallbackManagerForToolRun", 49 | "CallbackManager", 50 | "AsyncCallbackManager", 51 | "StdOutCallbackHandler", 52 | "StreamingStdOutCallbackHandler", 53 | ] 54 | -------------------------------------------------------------------------------- /sparkai/core/callbacks/stdout.py: -------------------------------------------------------------------------------- 1 | """Callback Handler that prints to std out.""" 2 | from __future__ import annotations 3 | 4 | from typing import TYPE_CHECKING, Any, Dict, Optional 5 | 6 | from sparkai.core.callbacks.base import BaseCallbackHandler 7 | from sparkai.core.utils import print_text 8 | 9 | 10 | 11 | class StdOutCallbackHandler(BaseCallbackHandler): 12 | """Callback Handler that prints to std out.""" 13 | 14 | def __init__(self, color: Optional[str] = None) -> None: 15 | """Initialize callback handler.""" 16 | self.color = color 17 | 18 | def on_chain_start( 19 | self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any 20 | ) -> None: 21 | """Print out that we are entering a chain.""" 22 | class_name = 
serialized.get("name", serialized.get("id", [""])[-1]) 23 | print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m") 24 | 25 | def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: 26 | """Print out that we finished a chain.""" 27 | print("\n\033[1m> Finished chain.\033[0m") 28 | 29 | def on_tool_end( 30 | self, 31 | output: str, 32 | color: Optional[str] = None, 33 | observation_prefix: Optional[str] = None, 34 | llm_prefix: Optional[str] = None, 35 | **kwargs: Any, 36 | ) -> None: 37 | """If not the final action, print out observation.""" 38 | if observation_prefix is not None: 39 | print_text(f"\n{observation_prefix}") 40 | print_text(output, color=color or self.color) 41 | if llm_prefix is not None: 42 | print_text(f"\n{llm_prefix}") 43 | 44 | def on_text( 45 | self, 46 | text: str, 47 | color: Optional[str] = None, 48 | end: str = "", 49 | **kwargs: Any, 50 | ) -> None: 51 | """Run when agent ends.""" 52 | print_text(text, color=color or self.color, end=end) 53 | 54 | -------------------------------------------------------------------------------- /sparkai/core/callbacks/streaming_stdout.py: -------------------------------------------------------------------------------- 1 | """Callback Handler streams to stdout on new llm token.""" 2 | from __future__ import annotations 3 | 4 | import sys 5 | from typing import TYPE_CHECKING, Any, Dict, List 6 | 7 | from sparkai.core.callbacks.base import BaseCallbackHandler 8 | 9 | if TYPE_CHECKING: 10 | from sparkai.core.messages import BaseMessage 11 | from sparkai.core.outputs import LLMResult 12 | 13 | 14 | class StreamingStdOutCallbackHandler(BaseCallbackHandler): 15 | """Callback handler for streaming. 
Only works with LLMs that support streaming.""" 16 | 17 | def on_llm_start( 18 | self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any 19 | ) -> None: 20 | """Run when LLM starts running.""" 21 | 22 | def on_chat_model_start( 23 | self, 24 | serialized: Dict[str, Any], 25 | messages: List[List[BaseMessage]], 26 | **kwargs: Any, 27 | ) -> None: 28 | """Run when LLM starts running.""" 29 | 30 | def on_llm_new_token(self, token: str, **kwargs: Any) -> None: 31 | """Run on new LLM token. Only available when streaming is enabled.""" 32 | sys.stdout.write(token) 33 | sys.stdout.flush() 34 | 35 | def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: 36 | """Run when LLM ends running.""" 37 | 38 | def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: 39 | """Run when LLM errors.""" 40 | 41 | def on_chain_start( 42 | self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any 43 | ) -> None: 44 | """Run when chain starts running.""" 45 | 46 | def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: 47 | """Run when chain ends running.""" 48 | 49 | def on_chain_error(self, error: BaseException, **kwargs: Any) -> None: 50 | """Run when chain errors.""" 51 | 52 | def on_tool_start( 53 | self, serialized: Dict[str, Any], input_str: str, **kwargs: Any 54 | ) -> None: 55 | """Run when tool starts running.""" 56 | 57 | def on_tool_end(self, output: str, **kwargs: Any) -> None: 58 | """Run when tool ends running.""" 59 | 60 | def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: 61 | """Run when tool errors.""" 62 | 63 | def on_text(self, text: str, **kwargs: Any) -> None: 64 | """Run on arbitrary text.""" 65 | 66 | -------------------------------------------------------------------------------- /sparkai/core/globals/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | """Global values and configuration that apply to all of SparkAI.""" 
3 | import warnings 4 | from typing import TYPE_CHECKING, Optional 5 | 6 | if TYPE_CHECKING: 7 | from sparkai.core.caches import BaseCache 8 | 9 | 10 | _verbose: bool = False 11 | _debug: bool = False 12 | _llm_cache: Optional["BaseCache"] = None 13 | 14 | 15 | def set_verbose(value: bool) -> None: 16 | 17 | global _verbose 18 | _verbose = value 19 | 20 | 21 | def get_verbose() -> bool: 22 | old_verbose = False 23 | 24 | global _verbose 25 | return _verbose or old_verbose 26 | 27 | 28 | def set_debug(value: bool) -> None: 29 | 30 | global _debug 31 | _debug = value 32 | 33 | 34 | def get_debug() -> bool: 35 | old_debug = False 36 | 37 | global _debug 38 | return _debug or old_debug 39 | 40 | 41 | def set_llm_cache(value: Optional["BaseCache"]) -> None: 42 | 43 | global _llm_cache 44 | _llm_cache = value 45 | 46 | 47 | def get_llm_cache() -> "BaseCache": 48 | """Get the value of the `llm_cache` global setting.""" 49 | old_llm_cache = None 50 | global _llm_cache 51 | return _llm_cache or old_llm_cache 52 | -------------------------------------------------------------------------------- /sparkai/core/language_models/__init__.py: -------------------------------------------------------------------------------- 1 | from sparkai.core.language_models.base import ( 2 | BaseLanguageModel, 3 | LanguageModelInput, 4 | LanguageModelLike, 5 | LanguageModelOutput, 6 | get_tokenizer, 7 | ) 8 | from sparkai.core.language_models.chat_models import BaseChatModel, SimpleChatModel 9 | from sparkai.core.language_models.llms import LLM, BaseLLM 10 | 11 | __all__ = [ 12 | "BaseLanguageModel", 13 | "BaseChatModel", 14 | "SimpleChatModel", 15 | "BaseLLM", 16 | "LLM", 17 | "LanguageModelInput", 18 | "get_tokenizer", 19 | "LanguageModelOutput", 20 | "LanguageModelLike", 21 | ] 22 | -------------------------------------------------------------------------------- /sparkai/core/load/__init__.py: -------------------------------------------------------------------------------- 1 | """Serialization 
and deserialization.""" 2 | from sparkai.core.load.dump import dumpd, dumps 3 | from sparkai.core.load.load import load, loads 4 | from sparkai.core.load.serializable import Serializable 5 | 6 | __all__ = ["dumpd", "dumps", "load", "loads", "Serializable"] 7 | -------------------------------------------------------------------------------- /sparkai/core/load/dump.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Any, Dict 3 | 4 | from sparkai.core.load.serializable import Serializable, to_json_not_implemented 5 | 6 | 7 | def default(obj: Any) -> Any: 8 | """Return a default value for a Serializable object or 9 | a SerializedNotImplemented object.""" 10 | if isinstance(obj, Serializable): 11 | return obj.to_json() 12 | else: 13 | return to_json_not_implemented(obj) 14 | 15 | 16 | def dumps(obj: Any, *, pretty: bool = False, **kwargs: Any) -> str: 17 | """Return a json string representation of an object.""" 18 | if "default" in kwargs: 19 | raise ValueError("`default` should not be passed to dumps") 20 | try: 21 | if pretty: 22 | indent = kwargs.pop("indent", 2) 23 | return json.dumps(obj, default=default, indent=indent, **kwargs) 24 | else: 25 | return json.dumps(obj, default=default, **kwargs) 26 | except TypeError: 27 | if pretty: 28 | return json.dumps(to_json_not_implemented(obj), indent=indent, **kwargs) 29 | else: 30 | return json.dumps(to_json_not_implemented(obj), **kwargs) 31 | 32 | 33 | def dumpd(obj: Any) -> Dict[str, Any]: 34 | """Return a json dict representation of an object.""" 35 | return json.loads(dumps(obj)) 36 | -------------------------------------------------------------------------------- /sparkai/core/messages/ai.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Literal 2 | 3 | from sparkai.core.messages.base import ( 4 | BaseMessage, 5 | BaseMessageChunk, 6 | merge_content, 7 | ) 8 | 9 | 10 | class 
AIMessage(BaseMessage): 11 | """A Message from an AI.""" 12 | 13 | example: bool = False 14 | """Whether this Message is being passed in to the model as part of an example 15 | conversation. 16 | """ 17 | 18 | type: Literal["ai"] = "ai" 19 | 20 | @classmethod 21 | def get_lc_namespace(cls) -> List[str]: 22 | """Get the namespace of the sparkai object.""" 23 | return ["sparkai", "messages"] 24 | 25 | 26 | AIMessage.update_forward_refs() 27 | 28 | 29 | class AIMessageChunk(AIMessage, BaseMessageChunk): 30 | """A Message chunk from an AI.""" 31 | 32 | # Ignoring mypy re-assignment here since we're overriding the value 33 | # to make sure that the chunk variant can be discriminated from the 34 | # non-chunk variant. 35 | type: Literal["AIMessageChunk"] = "AIMessageChunk" # type: ignore[assignment] # noqa: E501 36 | 37 | @classmethod 38 | def get_lc_namespace(cls) -> List[str]: 39 | """Get the namespace of the sparkai object.""" 40 | return ["sparkai", "messages"] 41 | 42 | def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore 43 | if isinstance(other, AIMessageChunk): 44 | if self.example != other.example: 45 | raise ValueError( 46 | "Cannot concatenate AIMessageChunks with different example values." 47 | ) 48 | 49 | return self.__class__( 50 | example=self.example, 51 | content=merge_content(self.content, other.content), 52 | additional_kwargs=self._merge_kwargs_dict( 53 | self.additional_kwargs, other.additional_kwargs 54 | ), 55 | ) 56 | 57 | return super().__add__(other) 58 | -------------------------------------------------------------------------------- /sparkai/core/messages/chat.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Literal 2 | 3 | from sparkai.core.messages.base import ( 4 | BaseMessage, 5 | BaseMessageChunk, 6 | merge_content, 7 | ) 8 | 9 | 10 | class ChatMessage(BaseMessage): 11 | """A Message that can be assigned an arbitrary speaker (i.e. 
class ChatMessageChunk(ChatMessage, BaseMessageChunk):
    """A Chat Message chunk."""

    # Overridden so the chunk variant can be told apart from the non-chunk
    # variant during (de)serialization; mypy's re-assignment warning is ignored.
    type: Literal["ChatMessageChunk"] = "ChatMessageChunk"  # type: ignore

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the sparkai object."""
        return ["sparkai", "messages"]

    def __add__(self, other: Any) -> BaseMessageChunk:  # type: ignore
        """Concatenate this chunk with the next streamed chunk."""
        if not isinstance(other, BaseMessageChunk):
            return super().__add__(other)
        if isinstance(other, ChatMessageChunk) and self.role != other.role:
            raise ValueError(
                "Cannot concatenate ChatMessageChunks with different roles."
            )
        # Merging is identical for a same-role ChatMessageChunk and for any
        # other BaseMessageChunk: keep our role, merge content and kwargs.
        return self.__class__(
            role=self.role,
            content=merge_content(self.content, other.content),
            additional_kwargs=self._merge_kwargs_dict(
                self.additional_kwargs, other.additional_kwargs
            ),
        )
class FunctionCallMessage(BaseMessage):
    """A FunctionCall Message from an AI."""

    example: bool = False
    """Whether this Message is being passed in to the model as part of an example
    conversation.
    """

    function_call: dict = {}

    # Fix: the annotation previously read Literal["ai"] while the default
    # value was "assistant", so the declared type never matched the runtime
    # value. The default value itself is unchanged.
    type: Literal["assistant"] = "assistant"

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the sparkai object."""
        return ["sparkai", "messages"]


FunctionCallMessage.update_forward_refs()


class FunctionCallMessageChunk(FunctionCallMessage, BaseMessageChunk):
    """A FunctionCall Message chunk."""

    name: str = "function_call"
    # Ignoring mypy re-assignment here since we're overriding the value
    # to make sure that the chunk variant can be discriminated from the
    # non-chunk variant.
    type: Literal["FunctionCallMessageChunk"] = "FunctionCallMessageChunk"  # type: ignore[assignment]

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the sparkai object."""
        return ["sparkai", "messages"]

    def __add__(self, other: Any) -> BaseMessageChunk:  # type: ignore
        """Concatenate two streamed function-call chunks with the same name."""
        if isinstance(other, FunctionCallMessageChunk):
            if self.name != other.name:
                raise ValueError(
                    "Cannot concatenate FunctionCallMessageChunks with different names."
                )

            return self.__class__(
                name=self.name,
                content=merge_content(self.content, other.content),
                # Function-call payloads are merged whole; they are not
                # streamed as partial chunks at present.
                function_call=self._merge_kwargs_dict(self.function_call, other.function_call),
                additional_kwargs=self._merge_kwargs_dict(
                    self.additional_kwargs, other.additional_kwargs
                ),
            )
        return super().__add__(other)
class ImageChatMessageChunk(ImageChatMessage, BaseMessageChunk):
    """An Image Chat Message chunk."""

    # Ignoring mypy re-assignment here since we're overriding the value
    # to make sure that the chunk variant can be discriminated from the
    # non-chunk variant.
    type: Literal["ImageChatMessageChunk"] = "ImageChatMessageChunk"  # type: ignore

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the sparkai object."""
        return ["sparkai", "messages"]

    def __add__(self, other: Any) -> BaseMessageChunk:  # type: ignore
        """Concatenate this chunk with the next streamed chunk."""
        if isinstance(other, ImageChatMessageChunk):
            if self.role != other.role:
                # Fix: the error previously named "ChatMessageChunks" — a
                # copy-paste from chat.py — instead of this class.
                raise ValueError(
                    "Cannot concatenate ImageChatMessageChunks with different roles."
                )

            return self.__class__(
                role=self.role,
                content=merge_content(self.content, other.content),
                additional_kwargs=self._merge_kwargs_dict(
                    self.additional_kwargs, other.additional_kwargs
                ),
            )
        elif isinstance(other, BaseMessageChunk):
            return self.__class__(
                role=self.role,
                content=merge_content(self.content, other.content),
                additional_kwargs=self._merge_kwargs_dict(
                    self.additional_kwargs, other.additional_kwargs
                ),
            )
        else:
            return super().__add__(other)
class ToolMessageChunk(ToolMessage, BaseMessageChunk):
    """A Tool Message chunk."""

    # Ignoring mypy re-assignment here since we're overriding the value
    # to make sure that the chunk variant can be discriminated from the
    # non-chunk variant.
    type: Literal["ToolMessageChunk"] = "ToolMessageChunk"  # type: ignore[assignment]

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the sparkai object."""
        return ["sparkai", "messages"]

    def __add__(self, other: Any) -> BaseMessageChunk:  # type: ignore
        """Concatenate two streamed tool chunks answering the same tool call."""
        if isinstance(other, ToolMessageChunk):
            if self.tool_call_id != other.tool_call_id:
                # Fix: the message said "different names" (copy-paste from
                # FunctionMessageChunk) but the comparison is on tool_call_id.
                raise ValueError(
                    "Cannot concatenate ToolMessageChunks with different tool_call_ids."
                )

            return self.__class__(
                tool_call_id=self.tool_call_id,
                content=merge_content(self.content, other.content),
                additional_kwargs=self._merge_kwargs_dict(
                    self.additional_kwargs, other.additional_kwargs
                ),
            )

        return super().__add__(other)
used exclusively for serialization purposes.""" 21 | 22 | @root_validator 23 | def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]: 24 | """Set the text attribute to be the contents of the message.""" 25 | try: 26 | values["text"] = values["message"].content 27 | except (KeyError, AttributeError) as e: 28 | raise ValueError("Error while initializing ChatGeneration") from e 29 | return values 30 | 31 | @classmethod 32 | def get_lc_namespace(cls) -> List[str]: 33 | """Get the namespace of the depreciated object.""" 34 | return ["depreciated", "schema", "output"] 35 | 36 | 37 | class ChatGenerationChunk(ChatGeneration): 38 | """A ChatGeneration chunk, which can be concatenated with other 39 | ChatGeneration chunks. 40 | 41 | Attributes: 42 | message: The message chunk output by the chat model. 43 | """ 44 | 45 | message: BaseMessageChunk 46 | # Override type to be ChatGeneration, ignore mypy error as this is intentional 47 | type: Literal["ChatGenerationChunk"] = "ChatGenerationChunk" # type: ignore[assignment] # noqa: E501 48 | """Type is used exclusively for serialization purposes.""" 49 | 50 | @classmethod 51 | def get_lc_namespace(cls) -> List[str]: 52 | """Get the namespace of the depreciated object.""" 53 | return ["depreciated", "schema", "output"] 54 | 55 | def __add__(self, other: ChatGenerationChunk) -> ChatGenerationChunk: 56 | if isinstance(other, ChatGenerationChunk): 57 | generation_info = merge_dicts( 58 | self.generation_info or {}, 59 | other.generation_info or {}, 60 | ) 61 | return ChatGenerationChunk( 62 | message=self.message + other.message, 63 | generation_info=generation_info or None, 64 | ) 65 | else: 66 | raise TypeError( 67 | f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'" 68 | ) 69 | -------------------------------------------------------------------------------- /sparkai/core/outputs/chat_result.py: -------------------------------------------------------------------------------- 1 | from typing import List, 
Optional 2 | 3 | from sparkai.core.outputs.chat_generation import ChatGeneration 4 | from sparkai.core.pydantic_v1 import BaseModel 5 | 6 | 7 | class ChatResult(BaseModel): 8 | """Class that contains all results for a single chat model call.""" 9 | 10 | generations: List[ChatGeneration] 11 | """List of the chat generations. This is a List because an input can have multiple 12 | candidate generations. 13 | """ 14 | llm_output: Optional[dict] = None 15 | """For arbitrary LLM provider specific output.""" 16 | -------------------------------------------------------------------------------- /sparkai/core/outputs/generation.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any, Dict, List, Literal, Optional 4 | 5 | from sparkai.core.load import Serializable 6 | from sparkai.core.utils._merge import merge_dicts 7 | 8 | 9 | class Generation(Serializable): 10 | """A single text generation output.""" 11 | 12 | text: str 13 | """Generated text output.""" 14 | 15 | generation_info: Optional[Dict[str, Any]] = None 16 | """Raw response from the provider. May include things like the 17 | reason for finishing or token log probabilities. 
class LLMResult(BaseModel):
    """Class that contains all results for a batched LLM call."""

    generations: List[List[Generation]]
    """List of generated outputs. This is a List[List[]] because
    each input could have multiple candidate generations."""
    llm_output: Optional[dict] = None
    """Arbitrary LLM provider-specific output."""
    run: Optional[List[RunInfo]] = None
    """List of metadata info for model call for each input."""

    def flatten(self) -> List[LLMResult]:
        """Flatten generations into a single list.

        Unpack List[List[Generation]] -> List[LLMResult] where each returned LLMResult
        contains only a single Generation. If token usage information is available,
        it is kept only for the LLMResult corresponding to the top-choice
        Generation, to avoid over-counting of token usage downstream.

        Returns:
            List of LLMResults where each returned LLMResult contains a single
            Generation.
        """
        llm_results = []
        for i, gen_list in enumerate(self.generations):
            # Avoid double counting tokens in OpenAICallback
            if i == 0:
                # Only the first result keeps the full llm_output (including
                # any token usage) unchanged.
                llm_results.append(
                    LLMResult(
                        generations=[gen_list],
                        llm_output=self.llm_output,
                    )
                )
            else:
                # Later results carry a deep copy whose token_usage is
                # blanked, so callbacks that sum usage count it exactly once.
                if self.llm_output is not None:
                    llm_output = deepcopy(self.llm_output)
                    llm_output["token_usage"] = dict()
                else:
                    llm_output = None
                llm_results.append(
                    LLMResult(
                        generations=[gen_list],
                        llm_output=llm_output,
                    )
                )
        return llm_results

    def __eq__(self, other: object) -> bool:
        """Check for LLMResult equality by ignoring any metadata related to runs."""
        if not isinstance(other, LLMResult):
            return NotImplemented
        # `run` metadata is deliberately excluded from the comparison.
        return (
            self.generations == other.generations
            and self.llm_output == other.llm_output
        )
BaseModel 6 | 7 | 8 | class RunInfo(BaseModel): 9 | """Class that contains metadata for a single execution of a Chain or model.""" 10 | 11 | run_id: UUID 12 | """A unique identifier for the model or chain run.""" 13 | -------------------------------------------------------------------------------- /sparkai/core/prompt_values.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from abc import ABC, abstractmethod 4 | from typing import List, Literal, Sequence 5 | 6 | from typing_extensions import TypedDict 7 | 8 | from sparkai.core.load.serializable import Serializable 9 | from sparkai.core.messages import ( 10 | AnyMessage, 11 | BaseMessage, 12 | HumanMessage, 13 | get_buffer_string, 14 | ) 15 | 16 | 17 | class PromptValue(Serializable, ABC): 18 | """Base abstract class for inputs to any language model. 19 | 20 | PromptValues can be converted to both LLM (pure text-generation) inputs and 21 | ChatModel inputs. 
22 | """ 23 | 24 | @classmethod 25 | def is_lc_serializable(cls) -> bool: 26 | """Return whether this class is serializable.""" 27 | return True 28 | 29 | @classmethod 30 | def get_lc_namespace(cls) -> List[str]: 31 | """Get the namespace of the depreciated object.""" 32 | return ["depreciated", "schema", "prompt"] 33 | 34 | @abstractmethod 35 | def to_string(self) -> str: 36 | """Return prompt value as string.""" 37 | 38 | @abstractmethod 39 | def to_messages(self) -> List[BaseMessage]: 40 | """Return prompt as a list of Messages.""" 41 | 42 | 43 | class StringPromptValue(PromptValue): 44 | """String prompt value.""" 45 | 46 | text: str 47 | """Prompt text.""" 48 | type: Literal["StringPromptValue"] = "StringPromptValue" 49 | 50 | @classmethod 51 | def get_lc_namespace(cls) -> List[str]: 52 | """Get the namespace of the depreciated object.""" 53 | return ["depreciated", "prompts", "base"] 54 | 55 | def to_string(self) -> str: 56 | """Return prompt as string.""" 57 | return self.text 58 | 59 | def to_messages(self) -> List[BaseMessage]: 60 | """Return prompt as messages.""" 61 | return [HumanMessage(content=self.text)] 62 | 63 | 64 | class ChatPromptValue(PromptValue): 65 | """Chat prompt value. 66 | 67 | A type of a prompt value that is built from messages. 
68 | """ 69 | 70 | messages: Sequence[BaseMessage] 71 | """List of messages.""" 72 | 73 | def to_string(self) -> str: 74 | """Return prompt as string.""" 75 | return get_buffer_string(self.messages) 76 | 77 | def to_messages(self) -> List[BaseMessage]: 78 | """Return prompt as a list of messages.""" 79 | return list(self.messages) 80 | 81 | @classmethod 82 | def get_lc_namespace(cls) -> List[str]: 83 | """Get the namespace of the depreciated object.""" 84 | return ["depreciated", "prompts", "chat"] 85 | 86 | 87 | class ImageURL(TypedDict, total=False): 88 | detail: Literal["auto", "low", "high"] 89 | """Specifies the detail level of the image.""" 90 | 91 | url: str 92 | """Either a URL of the image or the base64 encoded image data.""" 93 | 94 | 95 | class ImagePromptValue(PromptValue): 96 | """Image prompt value.""" 97 | 98 | image_url: ImageURL 99 | """Prompt image.""" 100 | type: Literal["ImagePromptValue"] = "ImagePromptValue" 101 | 102 | def to_string(self) -> str: 103 | """Return prompt as string.""" 104 | return self.image_url["url"] 105 | 106 | def to_messages(self) -> List[BaseMessage]: 107 | """Return prompt as messages.""" 108 | return [HumanMessage(content=[self.image_url])] 109 | 110 | 111 | class ChatPromptValueConcrete(ChatPromptValue): 112 | """Chat prompt value which explicitly lists out the message types it accepts. 113 | For use in external schemas.""" 114 | 115 | messages: Sequence[AnyMessage] 116 | 117 | type: Literal["ChatPromptValueConcrete"] = "ChatPromptValueConcrete" 118 | 119 | @classmethod 120 | def get_lc_namespace(cls) -> List[str]: 121 | """Get the namespace of the depreciated object.""" 122 | return ["depreciated", "prompts", "chat"] 123 | -------------------------------------------------------------------------------- /sparkai/core/prompts/__init__.py: -------------------------------------------------------------------------------- 1 | """**Prompt** is the input to the model. 
"""**Prompt** is the input to the model.

Prompt is often constructed
from multiple components. Prompt classes and functions make constructing
and working with prompts easy.

**Class hierarchy:**

.. code-block::

    BasePromptTemplate --> PipelinePromptTemplate
                           StringPromptTemplate --> PromptTemplate
                                                    FewShotPromptTemplate
                                                    FewShotPromptWithTemplates
                           BaseChatPromptTemplate --> AutoGPTPrompt
                                                      ChatPromptTemplate --> AgentScratchPadChatPromptTemplate



    BaseMessagePromptTemplate --> MessagesPlaceholder
                                  BaseStringMessagePromptTemplate --> ChatMessagePromptTemplate
                                                                      HumanMessagePromptTemplate
                                                                      AIMessagePromptTemplate
                                                                      SystemMessagePromptTemplate

"""  # noqa: E501
# NOTE: import from the vendored ``sparkai.core.prompts`` modules (the sibling
# files in this package) rather than from ``langchain_core``, matching the
# other ``sparkai.core`` packages so langchain is not a runtime dependency.
from sparkai.core.prompts.base import BasePromptTemplate, format_document
from sparkai.core.prompts.chat import (
    AIMessagePromptTemplate,
    BaseChatPromptTemplate,
    ChatMessagePromptTemplate,
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
    SystemMessagePromptTemplate,
)
from sparkai.core.prompts.few_shot import (
    FewShotChatMessagePromptTemplate,
    FewShotPromptTemplate,
)
from sparkai.core.prompts.few_shot_with_templates import FewShotPromptWithTemplates
from sparkai.core.prompts.loading import load_prompt
from sparkai.core.prompts.pipeline import PipelinePromptTemplate
from sparkai.core.prompts.prompt import PromptTemplate
from sparkai.core.prompts.string import (
    StringPromptTemplate,
    check_valid_template,
    get_template_variables,
    jinja2_formatter,
    validate_jinja2,
)

__all__ = [
    "AIMessagePromptTemplate",
    "BaseChatPromptTemplate",
    "BasePromptTemplate",
    "ChatMessagePromptTemplate",
    "ChatPromptTemplate",
    "FewShotPromptTemplate",
    "FewShotPromptWithTemplates",
    "FewShotChatMessagePromptTemplate",
    "HumanMessagePromptTemplate",
    "MessagesPlaceholder",
    "PipelinePromptTemplate",
    "PromptTemplate",
    "StringPromptTemplate",
    "SystemMessagePromptTemplate",
    "load_prompt",
    "format_document",
    "check_valid_template",
    "get_template_variables",
    "jinja2_formatter",
    "validate_jinja2",
]
from typing import Any

# NOTE: use the vendored sparkai.core modules (prompt_values.py in this dump
# defines ImagePromptValue/ImageURL/PromptValue) instead of langchain_core,
# consistent with the rest of the sparkai.core package.
from sparkai.core.prompt_values import ImagePromptValue, ImageURL, PromptValue
from sparkai.core.prompts.base import BasePromptTemplate
from sparkai.core.pydantic_v1 import Field
from sparkai.core.utils import image as image_utils


class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
    """An image prompt template for a multimodal model."""

    template: dict = Field(default_factory=dict)
    """Template for the prompt."""

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the template.

        Raises:
            ValueError: if ``input_variables`` contains any of the reserved
                names ``url``, ``path`` or ``detail`` (those are filled in by
                :meth:`format` itself).
        """
        if "input_variables" not in kwargs:
            kwargs["input_variables"] = []

        overlap = set(kwargs["input_variables"]) & set(("url", "path", "detail"))
        if overlap:
            raise ValueError(
                "input_variables for the image template cannot contain"
                " any of 'url', 'path', or 'detail'."
                f" Found: {overlap}"
            )
        super().__init__(**kwargs)

    @property
    def _prompt_type(self) -> str:
        """Return the prompt type key."""
        return "image-prompt"

    def format_prompt(self, **kwargs: Any) -> PromptValue:
        """Create Chat Messages."""
        return ImagePromptValue(image_url=self.format(**kwargs))

    def format(
        self,
        **kwargs: Any,
    ) -> ImageURL:
        """Format the prompt with the inputs.

        Args:
            kwargs: Any arguments to be passed to the prompt template.

        Returns:
            A formatted string.

        Example:

            .. code-block:: python

                prompt.format(variable1="foo")
        """
        # String-valued template entries are treated as format strings;
        # everything else is passed through untouched.
        formatted = {}
        for k, v in self.template.items():
            if isinstance(v, str):
                formatted[k] = v.format(**kwargs)
            else:
                formatted[k] = v
        # Direct kwargs take precedence over template-derived values.
        url = kwargs.get("url") or formatted.get("url")
        path = kwargs.get("path") or formatted.get("path")
        detail = kwargs.get("detail") or formatted.get("detail")
        if not url and not path:
            raise ValueError("Must provide either url or path.")
        if not url:
            if not isinstance(path, str):
                raise ValueError("path must be a string.")
            # A local file path is inlined as a base64 data URL.
            url = image_utils.image_to_data_url(path)
        if not isinstance(url, str):
            raise ValueError("url must be a string.")
        output: ImageURL = {"url": url}
        if detail:
            # Don't check literal values here: let the API check them
            output["detail"] = detail  # type: ignore[typeddict-item]
        return output

    def pretty_repr(self, html: bool = False) -> str:
        raise NotImplementedError()
    final_prompt: BasePromptTemplate
    """The final prompt that is returned."""
    pipeline_prompts: List[Tuple[str, BasePromptTemplate]]
    """A list of tuples, consisting of a string (`name`) and a Prompt Template."""

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the depreciated object."""
        return ["depreciated", "prompts", "pipeline"]

    @root_validator(pre=True)
    def get_input_variables(cls, values: Dict) -> Dict:
        """Get input variables."""
        # Variables that an earlier pipeline prompt produces (its `name`) are
        # satisfied internally, so only the remaining ones must come from the
        # caller.
        created_variables = set()
        all_variables = set()
        for k, prompt in values["pipeline_prompts"]:
            created_variables.add(k)
            all_variables.update(prompt.input_variables)
        values["input_variables"] = list(all_variables.difference(created_variables))
        return values

    def format_prompt(self, **kwargs: Any) -> PromptValue:
        """Format each pipeline prompt in order, feeding each result into
        ``kwargs`` under its name, then format the final prompt."""
        for k, prompt in self.pipeline_prompts:
            _inputs = _get_inputs(kwargs, prompt.input_variables)
            # Chat templates produce message lists; string templates produce
            # plain strings.
            if isinstance(prompt, BaseChatPromptTemplate):
                kwargs[k] = prompt.format_messages(**_inputs)
            else:
                kwargs[k] = prompt.format(**_inputs)
        _inputs = _get_inputs(kwargs, self.final_prompt.input_variables)
        return self.final_prompt.format_prompt(**_inputs)

    def format(self, **kwargs: Any) -> str:
        """Format the pipeline and return the final prompt as a string."""
        return self.format_prompt(**kwargs).to_string()

    @property
    def _prompt_type(self) -> str:
        # NOTE(review): presumably pipeline prompts are not meant to be
        # serialized via a prompt-type key, hence the bare ValueError; confirm.
        raise ValueError
-------------------------------------------------------------------------------- 1 | from importlib import metadata 2 | 3 | ## Create namespaces for pydantic v1 and v2. 4 | # This code must stay at the top of the file before other modules may 5 | # attempt to import pydantic since it adds pydantic_v1 and pydantic_v2 to sys.modules. 6 | # 7 | # This hack is done for the following reasons: 8 | # * Langchain will attempt to remain compatible with both pydantic v1 and v2 since 9 | # both dependencies and dependents may be stuck on either version of v1 or v2. 10 | # * Creating namespaces for pydantic v1 and v2 should allow us to write code that 11 | # unambiguously uses either v1 or v2 API. 12 | # * This change is easier to roll out and roll back. 13 | 14 | try: 15 | from pydantic.v1 import * # noqa: F403 # type: ignore 16 | except ImportError: 17 | from pydantic import * # noqa: F403 # type: ignore 18 | 19 | 20 | try: 21 | _PYDANTIC_MAJOR_VERSION: int = int(metadata.version("pydantic").split(".")[0]) 22 | except metadata.PackageNotFoundError: 23 | _PYDANTIC_MAJOR_VERSION = 0 24 | -------------------------------------------------------------------------------- /sparkai/core/pydantic_v1/dataclasses.py: -------------------------------------------------------------------------------- 1 | try: 2 | from pydantic.v1.dataclasses import * # noqa: F403 3 | except ImportError: 4 | from pydantic.dataclasses import * # noqa: F403 5 | -------------------------------------------------------------------------------- /sparkai/core/pydantic_v1/main.py: -------------------------------------------------------------------------------- 1 | try: 2 | from pydantic.v1.main import * # noqa: F403 3 | except ImportError: 4 | from pydantic.main import * # noqa: F403 5 | -------------------------------------------------------------------------------- /sparkai/core/runnables/__init__.py: -------------------------------------------------------------------------------- 1 | """LangChain **Runnable** 
and the **LangChain Expression Language (LCEL)**. 2 | 3 | The LangChain Expression Language (LCEL) offers a declarative method to build 4 | production-grade programs that harness the power of LLMs. 5 | 6 | Programs created using LCEL and LangChain Runnables inherently support 7 | synchronous, asynchronous, batch, and streaming operations. 8 | 9 | Support for **async** allows servers hosting LCEL based programs to scale better 10 | for higher concurrent loads. 11 | 12 | **Streaming** of intermediate outputs as they're being generated allows for 13 | creating more responsive UX. 14 | 15 | This module contains schema and implementation of LangChain Runnables primitives. 16 | """ 17 | from sparkai.core.runnables.base import ( 18 | Runnable, 19 | RunnableBinding, 20 | RunnableGenerator, 21 | RunnableLambda, 22 | RunnableMap, 23 | RunnableParallel, 24 | RunnableSequence, 25 | RunnableSerializable, 26 | chain, 27 | ) 28 | from sparkai.core.runnables.branch import RunnableBranch 29 | from sparkai.core.runnables.config import ( 30 | RunnableConfig, 31 | ensure_config, 32 | get_config_list, 33 | patch_config, 34 | run_in_executor, 35 | ) 36 | from sparkai.core.runnables.fallbacks import RunnableWithFallbacks 37 | from sparkai.core.runnables.passthrough import ( 38 | RunnableAssign, 39 | RunnablePassthrough, 40 | RunnablePick, 41 | ) 42 | from sparkai.core.runnables.router import RouterInput, RouterRunnable 43 | from sparkai.core.runnables.utils import ( 44 | AddableDict, 45 | ConfigurableField, 46 | ConfigurableFieldMultiOption, 47 | ConfigurableFieldSingleOption, 48 | ConfigurableFieldSpec, 49 | aadd, 50 | add, 51 | ) 52 | 53 | __all__ = [ 54 | "chain", 55 | "AddableDict", 56 | "ConfigurableField", 57 | "ConfigurableFieldSingleOption", 58 | "ConfigurableFieldMultiOption", 59 | "ConfigurableFieldSpec", 60 | "ensure_config", 61 | "run_in_executor", 62 | "patch_config", 63 | "RouterInput", 64 | "RouterRunnable", 65 | "Runnable", 66 | "RunnableSerializable", 67 | 
"RunnableBinding", 68 | "RunnableBranch", 69 | "RunnableConfig", 70 | "RunnableGenerator", 71 | "RunnableLambda", 72 | "RunnableMap", 73 | "RunnableParallel", 74 | "RunnablePassthrough", 75 | "RunnableAssign", 76 | "RunnablePick", 77 | "RunnableSequence", 78 | "RunnableWithFallbacks", 79 | "get_config_list", 80 | "aadd", 81 | "add", 82 | ] 83 | -------------------------------------------------------------------------------- /sparkai/core/tracers/__init__.py: -------------------------------------------------------------------------------- 1 | """**Tracers** are classes for tracing runs. 2 | 3 | **Class hierarchy:** 4 | 5 | .. code-block:: 6 | 7 | BaseCallbackHandler --> BaseTracer --> Tracer # Examples: LangChainTracer, RootListenersTracer 8 | --> # Examples: LogStreamCallbackHandler 9 | """ # noqa: E501 10 | 11 | __all__ = [ 12 | "BaseTracer", 13 | "EvaluatorCallbackHandler", 14 | "LangChainTracer", 15 | "ConsoleCallbackHandler", 16 | "RunLog", 17 | "RunLogPatch", 18 | "LogStreamCallbackHandler", 19 | ] 20 | 21 | from sparkai.core.tracers.base import BaseTracer 22 | from sparkai.core.tracers.evaluation import EvaluatorCallbackHandler 23 | from sparkai.core.tracers.langchain import LangChainTracer 24 | from sparkai.core.tracers.log_stream import ( 25 | LogStreamCallbackHandler, 26 | RunLog, 27 | RunLogPatch, 28 | ) 29 | from sparkai.core.tracers.stdout import ConsoleCallbackHandler 30 | -------------------------------------------------------------------------------- /sparkai/core/tracers/root_listeners.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional, Union 2 | from uuid import UUID 3 | 4 | from sparkai.core.runnables.config import ( 5 | RunnableConfig, 6 | call_func_with_variable_args, 7 | ) 8 | from sparkai.core.tracers.base import BaseTracer 9 | from sparkai.core.tracers.schemas import Run 10 | 11 | Listener = Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]] 12 | 13 | 
class RootListenersTracer(BaseTracer):
    """Tracer that calls listeners on run start, end, and error."""

    def __init__(
        self,
        *,
        config: RunnableConfig,
        on_start: Optional[Listener],
        on_end: Optional[Listener],
        on_error: Optional[Listener],
    ) -> None:
        """Store the runnable config and the three optional listeners."""
        super().__init__()

        self.config = config
        self._arg_on_start = on_start
        self._arg_on_end = on_end
        self._arg_on_error = on_error
        # Set once, on the first run observed: identifies the root of the tree.
        self.root_id: Optional[UUID] = None

    def _persist_run(self, run: Run) -> None:
        # Legacy hook called once for an entire run tree; the listeners are
        # driven from the create/update hooks instead, so nothing to do here.
        pass

    def _on_run_create(self, run: Run) -> None:
        """Record the root run and fire the start listener exactly once."""
        if self.root_id is not None:
            # A root is already known — ignore child runs.
            return
        self.root_id = run.id
        if self._arg_on_start is not None:
            call_func_with_variable_args(self._arg_on_start, run, self.config)

    def _on_run_update(self, run: Run) -> None:
        """Fire the end or error listener when the root run finishes."""
        if run.id != self.root_id:
            return
        # Pick the success or failure listener based on the run outcome.
        listener = self._arg_on_end if run.error is None else self._arg_on_error
        if listener is not None:
            call_func_with_variable_args(listener, run, self.config)
15 | 16 | Parameters 17 | ---------- 18 | example_id : Optional[Union[UUID, str]], default=None 19 | The ID of the example being traced. It can be either a UUID or a string. 20 | """ 21 | 22 | name: str = "run-collector_callback_handler" 23 | 24 | def __init__( 25 | self, example_id: Optional[Union[UUID, str]] = None, **kwargs: Any 26 | ) -> None: 27 | """ 28 | Initialize the RunCollectorCallbackHandler. 29 | 30 | Parameters 31 | ---------- 32 | example_id : Optional[Union[UUID, str]], default=None 33 | The ID of the example being traced. It can be either a UUID or a string. 34 | """ 35 | super().__init__(**kwargs) 36 | self.example_id = ( 37 | UUID(example_id) if isinstance(example_id, str) else example_id 38 | ) 39 | self.traced_runs: List[Run] = [] 40 | 41 | def _persist_run(self, run: Run) -> None: 42 | """ 43 | Persist a run by adding it to the traced_runs list. 44 | 45 | Parameters 46 | ---------- 47 | run : Run 48 | The run to be persisted. 49 | """ 50 | run_ = run.copy() 51 | run_.reference_example_id = self.example_id 52 | self.traced_runs.append(run_) 53 | -------------------------------------------------------------------------------- /sparkai/core/tracers/schemas.py: -------------------------------------------------------------------------------- 1 | """Schemas for tracers.""" 2 | from __future__ import annotations 3 | 4 | import datetime 5 | import warnings 6 | from typing import Any, Dict, List, Optional, Type 7 | from uuid import UUID 8 | 9 | 10 | from sparkai.core._api import deprecated 11 | from sparkai.core.outputs import LLMResult 12 | from sparkai.core.pydantic_v1 import BaseModel, Field, root_validator 13 | 14 | 15 | 16 | 17 | @deprecated("0.1.0", removal="0.2.0") 18 | class TracerSessionV1Base(BaseModel): 19 | """Base class for TracerSessionV1.""" 20 | 21 | start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) 22 | name: Optional[str] = None 23 | extra: Optional[Dict[str, Any]] = None 24 | 25 | 26 | 
@deprecated("0.1.0", removal="0.2.0")
class TracerSessionV1Create(TracerSessionV1Base):
    """Create class for TracerSessionV1."""


@deprecated("0.1.0", removal="0.2.0")
class TracerSessionV1(TracerSessionV1Base):
    """TracerSessionV1 schema."""

    id: int


@deprecated("0.1.0", removal="0.2.0")
class TracerSessionBase(TracerSessionV1Base):
    """Base class for TracerSession."""

    tenant_id: UUID


@deprecated("0.1.0", removal="0.2.0")
class TracerSession(TracerSessionBase):
    """TracerSessionV1 schema for the V2 API."""

    id: UUID


@deprecated("0.1.0", alternative="Run", removal="0.2.0")
class BaseRun(BaseModel):
    """Base class for Run."""

    # NOTE(review): utcnow() yields *naive* datetimes — presumably intended as
    # UTC throughout; confirm before mixing with timezone-aware values.
    uuid: str
    parent_uuid: Optional[str] = None
    start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
    end_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
    extra: Optional[Dict[str, Any]] = None
    execution_order: int
    child_execution_order: int
    serialized: Dict[str, Any]
    session_id: int
    error: Optional[str] = None


@deprecated("0.1.0", alternative="Run", removal="0.2.0")
class LLMRun(BaseRun):
    """Class for LLMRun."""

    prompts: List[str]
    response: Optional[LLMResult] = None


@deprecated("0.1.0", alternative="Run", removal="0.2.0")
class ChainRun(BaseRun):
    """Class for ChainRun."""

    # Forward references (ToolRun below) resolve lazily thanks to the
    # module's `from __future__ import annotations`.
    inputs: Dict[str, Any]
    outputs: Optional[Dict[str, Any]] = None
    child_llm_runs: List[LLMRun] = Field(default_factory=list)
    child_chain_runs: List[ChainRun] = Field(default_factory=list)
    child_tool_runs: List[ToolRun] = Field(default_factory=list)


@deprecated("0.1.0", alternative="Run", removal="0.2.0")
class ToolRun(BaseRun):
    """Class for ToolRun."""

    tool_input: str
    output: Optional[str] = None
    action: str
    child_llm_runs: List[LLMRun] = Field(default_factory=list)
    child_chain_runs: List[ChainRun] = Field(default_factory=list)
    child_tool_runs: List[ToolRun] = Field(default_factory=list)


# Begin V2 API Schemas




__all__ = [
    "BaseRun",
    "ChainRun",
    "LLMRun",
    "ToolRun",
    "TracerSession",
    "TracerSessionBase",
    "TracerSessionV1",
    "TracerSessionV1Base",
    "TracerSessionV1Create",
]
def merge_dicts(left: Dict[str, Any], right: Dict[str, Any]) -> Dict[str, Any]:
    """Merge two dicts, combining values key by key.

    Keys present only in one dict are kept as-is. When a key exists in both:
    a ``None`` on the left is replaced by the right-hand value; a ``None`` on
    the right (or an equal value) leaves the left untouched; strings are
    concatenated; dicts are merged recursively; lists are concatenated.

    Raises:
        TypeError: if both sides hold non-None values of different types, or
            of a type with no defined merge behavior.
    """
    merged = left.copy()
    for key, incoming in right.items():
        if key not in merged:
            merged[key] = incoming
            continue
        current = merged[key]
        if current is None and incoming:
            # A left-side None is simply overwritten by a truthy right value.
            merged[key] = incoming
        elif incoming is None or current == incoming:
            # Nothing new to merge in.
            continue
        elif type(current) is not type(incoming):
            raise TypeError(
                f'additional_kwargs["{key}"] already exists in this message,'
                " but with a different type."
            )
        elif isinstance(current, str):
            merged[key] = current + incoming
        elif isinstance(current, dict):
            merged[key] = merge_dicts(current, incoming)
        elif isinstance(current, list):
            merged[key] = current + incoming
        else:
            raise TypeError(
                f"Additional kwargs key {key} already exists in left dict and value has "
                f"unsupported type {type(current)}."
            )
    return merged
"""Utilities for formatting strings."""
from string import Formatter
from typing import Any, List, Mapping, Sequence


class StrictFormatter(Formatter):
    """A subclass of formatter that checks for extra keys."""

    def vformat(
        self, format_string: str, args: Sequence, kwargs: Mapping[str, Any]
    ) -> str:
        """Format ``format_string``, rejecting positional arguments."""
        if args:
            raise ValueError(
                "No arguments should be provided, "
                "everything should be passed as keyword arguments."
            )
        return super().vformat(format_string, args, kwargs)

    def validate_input_variables(
        self, format_string: str, input_variables: List[str]
    ) -> None:
        """Check that ``format_string`` can be rendered with exactly
        ``input_variables`` by formatting it with placeholder values."""
        placeholders = {name: "foo" for name in input_variables}
        super().format(format_string, **placeholders)


# Shared module-level instance.
formatter = StrictFormatter()
import re
from typing import List, Optional, Sequence, Union
from urllib.parse import urljoin, urlparse

PREFIXES_TO_IGNORE = ("javascript:", "mailto:", "#")
SUFFIXES_TO_IGNORE = (
    ".css",
    ".js",
    ".ico",
    ".png",
    ".jpg",
    ".jpeg",
    ".gif",
    ".svg",
    ".csv",
    ".bz2",
    ".zip",
    ".epub",
)
# Negative lookaheads: skip hrefs ending in an asset suffix (followed by a
# quote or '#') or starting with one of the ignored scheme prefixes.
SUFFIXES_TO_IGNORE_REGEX = (
    "(?!" + "|".join([re.escape(s) + r"[\#'\"]" for s in SUFFIXES_TO_IGNORE]) + ")"
)
PREFIXES_TO_IGNORE_REGEX = (
    "(?!" + "|".join([re.escape(s) for s in PREFIXES_TO_IGNORE]) + ")"
)
DEFAULT_LINK_REGEX = (
    rf"href=[\"']{PREFIXES_TO_IGNORE_REGEX}((?:{SUFFIXES_TO_IGNORE_REGEX}.)*?)[\#'\"]"
)


def find_all_links(
    raw_html: str, *, pattern: Union[str, re.Pattern, None] = None
) -> List[str]:
    """Extract all links from a raw html string.

    Args:
        raw_html: original html.
        pattern: Regex to use for extracting links from raw html.

    Returns:
        List[str]: all links
    """
    # Deduplicate via a set; result order is therefore unspecified.
    return list(set(re.findall(pattern or DEFAULT_LINK_REGEX, raw_html)))


def extract_sub_links(
    raw_html: str,
    url: str,
    *,
    base_url: Optional[str] = None,
    pattern: Union[str, re.Pattern, None] = None,
    prevent_outside: bool = True,
    exclude_prefixes: Sequence[str] = (),
) -> List[str]:
    """Extract all links from a raw html string and convert into absolute paths.

    Args:
        raw_html: original html.
        url: the url of the html.
        base_url: the base url to check for outside links against.
        pattern: Regex to use for extracting links from raw html.
        prevent_outside: If True, ignore external links which are not children
            of the base url.
        exclude_prefixes: Exclude any URLs that start with one of these prefixes.

    Returns:
        List[str]: sub links
    """
    base = url if base_url is None else base_url
    base_netloc = urlparse(base).netloc

    absolute_paths = set()
    for link in find_all_links(raw_html, pattern=pattern):
        parsed = urlparse(link)
        if parsed.scheme in ("http", "https"):
            # Already absolute, e.g. https://to/path
            absolute_paths.add(link)
        elif link.startswith("//"):
            # Protocol-relative link: borrow the scheme from the page url.
            absolute_paths.add(f"{urlparse(url).scheme}:{link}")
        else:
            absolute_paths.add(urljoin(url, parsed.path))

    results = []
    for path in absolute_paths:
        if any(path.startswith(prefix) for prefix in exclude_prefixes):
            continue
        if prevent_outside:
            if urlparse(path).netloc != base_netloc:
                continue
            # Verifies the rest of the path after the netloc when the base is
            # more specific than just the host.
            if not path.startswith(base):
                continue
        results.append(path)
    return results
"""Handle chained inputs."""
from typing import Dict, List, Optional, TextIO

# ANSI SGR codes for the supported highlight colors.
_TEXT_COLOR_MAPPING = {
    "blue": "36;1",
    "yellow": "33;1",
    "pink": "38;5;200",
    "green": "32;1",
    "red": "31;1",
}


def get_color_mapping(
    items: List[str], excluded_colors: Optional[List] = None
) -> Dict[str, str]:
    """Assign each item a supported color, cycling through the palette."""
    palette = list(_TEXT_COLOR_MAPPING)
    if excluded_colors is not None:
        palette = [color for color in palette if color not in excluded_colors]
    return {item: palette[i % len(palette)] for i, item in enumerate(items)}


def get_colored_text(text: str, color: str) -> str:
    """Wrap ``text`` in the ANSI escape sequence for ``color``."""
    code = _TEXT_COLOR_MAPPING[color]
    return f"\u001b[{code}m\033[1;3m{text}\u001b[0m"


def get_bolded_text(text: str) -> str:
    """Wrap ``text`` in ANSI bold escape codes."""
    return f"\033[1m{text}\033[0m"


def print_text(
    text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None
) -> None:
    """Print text with highlighting and no end characters."""
    rendered = get_colored_text(text, color) if color else text
    print(rendered, end=end, file=file)
    if file:
        file.flush()  # ensure all printed content are written to file
-------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from copy import deepcopy 4 | from typing import Any, List, Optional, Sequence 5 | 6 | 7 | def _retrieve_ref(path: str, schema: dict) -> dict: 8 | components = path.split("/") 9 | if components[0] != "#": 10 | raise ValueError( 11 | "ref paths are expected to be URI fragments, meaning they should start " 12 | "with #." 13 | ) 14 | out = schema 15 | for component in components[1:]: 16 | if component.isdigit(): 17 | out = out[int(component)] 18 | else: 19 | out = out[component] 20 | return deepcopy(out) 21 | 22 | 23 | def _dereference_refs_helper( 24 | obj: Any, full_schema: dict, skip_keys: Sequence[str] 25 | ) -> Any: 26 | if isinstance(obj, dict): 27 | obj_out = {} 28 | for k, v in obj.items(): 29 | if k in skip_keys: 30 | obj_out[k] = v 31 | elif k == "$ref": 32 | ref = _retrieve_ref(v, full_schema) 33 | return _dereference_refs_helper(ref, full_schema, skip_keys) 34 | elif isinstance(v, (list, dict)): 35 | obj_out[k] = _dereference_refs_helper(v, full_schema, skip_keys) 36 | else: 37 | obj_out[k] = v 38 | return obj_out 39 | elif isinstance(obj, list): 40 | return [_dereference_refs_helper(el, full_schema, skip_keys) for el in obj] 41 | else: 42 | return obj 43 | 44 | 45 | def _infer_skip_keys(obj: Any, full_schema: dict) -> List[str]: 46 | keys = [] 47 | if isinstance(obj, dict): 48 | for k, v in obj.items(): 49 | if k == "$ref": 50 | ref = _retrieve_ref(v, full_schema) 51 | keys.append(v.split("/")[1]) 52 | keys += _infer_skip_keys(ref, full_schema) 53 | elif isinstance(v, (list, dict)): 54 | keys += _infer_skip_keys(v, full_schema) 55 | elif isinstance(obj, list): 56 | for el in obj: 57 | keys += _infer_skip_keys(el, full_schema) 58 | return keys 59 | 60 | 61 | def dereference_refs( 62 | schema_obj: dict, 63 | *, 64 | full_schema: Optional[dict] = None, 65 | skip_keys: Optional[Sequence[str]] = None, 66 | ) -> dict: 67 | """Try to 
substitute $refs in JSON Schema.""" 68 | 69 | full_schema = full_schema or schema_obj 70 | skip_keys = ( 71 | skip_keys 72 | if skip_keys is not None 73 | else _infer_skip_keys(schema_obj, full_schema) 74 | ) 75 | return _dereference_refs_helper(schema_obj, full_schema, skip_keys) 76 | -------------------------------------------------------------------------------- /sparkai/core/utils/loading.py: -------------------------------------------------------------------------------- 1 | """Utilities for loading configurations from langchain_core-hub.""" 2 | 3 | import os 4 | import re 5 | import tempfile 6 | from pathlib import Path, PurePosixPath 7 | from typing import Any, Callable, Optional, Set, TypeVar, Union 8 | from urllib.parse import urljoin 9 | 10 | import requests 11 | 12 | DEFAULT_REF = os.environ.get("LANGCHAIN_HUB_DEFAULT_REF", "master") 13 | URL_BASE = os.environ.get( 14 | "LANGCHAIN_HUB_URL_BASE", 15 | "https://raw.githubusercontent.com/hwchase17/langchain-hub/{ref}/", 16 | ) 17 | HUB_PATH_RE = re.compile(r"lc(?P@[^:]+)?://(?P.*)") 18 | 19 | T = TypeVar("T") 20 | 21 | 22 | def try_load_from_hub( 23 | path: Union[str, Path], 24 | loader: Callable[[str], T], 25 | valid_prefix: str, 26 | valid_suffixes: Set[str], 27 | **kwargs: Any, 28 | ) -> Optional[T]: 29 | """Load configuration from hub. 
"""Utilities for tests."""


def get_pydantic_major_version() -> int:
    """Get the major version of Pydantic (0 when pydantic is not installed)."""
    try:
        import pydantic
    except ImportError:
        # Pydantic is not installed at all.
        return 0
    return int(pydantic.__version__.split(".")[0])


# Computed once at import time so callers can branch cheaply.
PYDANTIC_MAJOR_VERSION = get_pydantic_major_version()
from typing import Any, List


def stringify_value(val: Any) -> str:
    """Stringify a value.

    Args:
        val: The value to stringify.

    Returns:
        str: The stringified value.
    """
    if isinstance(val, str):
        return val
    if isinstance(val, dict):
        # Dicts are rendered on their own lines, preceded by a newline.
        return "\n" + stringify_dict(val)
    if isinstance(val, list):
        return "\n".join(stringify_value(item) for item in val)
    return str(val)


def stringify_dict(data: dict) -> str:
    """Render a dict as ``key: value`` lines, one entry per line."""
    parts = [key + ": " + stringify_value(value) + "\n" for key, value in data.items()]
    return "".join(parts)


def comma_list(items: List[Any]) -> str:
    """Convert a list to a comma-separated string."""
    return ", ".join(map(str, items))


import os
import warnings


def show_message(old: str, new: str) -> None:
    """Warn that *old* is deprecated in favor of *new* (opt-out via env var)."""
    if os.environ.get("SPARKAICLIENT_SKIP_DEPRECATION"):  # for unit tests etc.
        return
    warnings.warn(
        f"{old} package is deprecated. Please use {new} package instead. "
        "For more info, go to xfyun.cn"
    )
29 | # Morbi non lorem porttitor neque feugiat blandit. Ut vitae ipsum eget quam lacinia accumsan. 30 | # Etiam sed turpis ac ipsum condimentum fringilla. Maecenas magna. 31 | # Proin dapibus sapien vel ante. Aliquam erat volutpat. Pellentesque sagittis ligula eget metus. 32 | # Vestibulum commodo. Ut rhoncus gravida arcu. 33 | -------------------------------------------------------------------------------- /sparkai/depreciated/client/llm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding:utf-8 3 | """ 4 | @author: nivic ybyang7 5 | @license: Apache Licence 6 | @file: llm 7 | @time: 2024/02/02 8 | @contact: ybyang7@iflytek.com 9 | @site: 10 | @software: PyCharm 11 | 12 | # code is far away from bugs with the god animal protecting 13 | I love animals. They taste delicious. 14 | ┏┓ ┏┓ 15 | ┏┛┻━━━┛┻┓ 16 | ┃ ☃ ┃ 17 | ┃ ┳┛ ┗┳ ┃ 18 | ┃ ┻ ┃ 19 | ┗━┓ ┏━┛ 20 | ┃ ┗━━━┓ 21 | ┃ 神兽保佑 ┣┓ 22 | ┃ 永无BUG! ┏┛ 23 | ┗┓┓┏━┳┓┏┛ 24 | ┃┫┫ ┃┫┫ 25 | ┗┻┛ ┗┻┛ 26 | """ 27 | 28 | # Copyright (c) 2022. Lorem ipsum dolor sit amet, consectetur adipiscing elit. 29 | # Morbi non lorem porttitor neque feugiat blandit. Ut vitae ipsum eget quam lacinia accumsan. 30 | # Etiam sed turpis ac ipsum condimentum fringilla. Maecenas magna. 31 | # Proin dapibus sapien vel ante. Aliquam erat volutpat. Pellentesque sagittis ligula eget metus. 32 | # Vestibulum commodo. Ut rhoncus gravida arcu. 
import logging
import os
from typing import Any, List, Mapping, Optional

import requests

from sparkai.core.callbacks.manager import CallbackManagerForLLMRun
from sparkai.llms.base import LLM
from .llms.utils import enforce_stop_tokens

from sparkai.api_resources.chat_completion import *

logger = logging.getLogger(__name__)


class SparkLLM(LLM):
    """Custom LLM wrapper for Xunfei SparkLLM to get LangChain support.

    Credentials are read at call time from the SPARK_APP_ID, SPARK_API_KEY,
    SPARK_API_SECRET and SPARK_API_BASE environment variables.
    """

    # Endpoint URL to use. Points at the deployed FastAPI service that calls
    # the Spark LLM.
    endpoint_url: str = "http://127.0.0.1:8000/qa?"
    # Keyword arguments to pass to the model.
    model_kwargs: Optional[dict] = None
    # max_token: int = 4000  # max tokens allowed; consider enabling in production
    # temperature: float = 0.75  # model temperature 0-10; consider enabling in production
    # history: List[List] = []  # conversation history; consider enabling in production
    # top_p: float = 0.85  # nucleus sampling 0-1; consider enabling in production
    # with_history: bool = False  # whether to use history; consider enabling in production

    @property
    def _llm_type(self) -> str:
        """Type tag LangChain uses to identify this LLM."""
        return "SparkLLM"

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        _model_kwargs = self.model_kwargs or {}
        return {
            "endpoint_url": self.endpoint_url,
            "model_kwargs": _model_kwargs,
        }

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Send *prompt* to the Spark websocket API and return the reply text.

        Args:
            prompt: user prompt, forwarded as a single 'user' message.
            stop: accepted for LLM interface compatibility; currently ignored
                by the underlying websocket call.
            run_manager: accepted for interface compatibility; currently unused.

        Raises:
            ValueError: if the API returns a non-zero status code.
        """
        api_key = os.environ.get("SPARK_API_KEY")
        api_secret = os.environ.get("SPARK_API_SECRET")
        api_base = os.environ.get("SPARK_API_BASE")
        app_id = os.environ.get("SPARK_APP_ID")
        client = SparkOnceWebsocket(
            api_key=api_key, api_secret=api_secret, app_id=app_id, api_base=api_base
        )

        messages = [{'role': 'user', 'content': prompt}]
        # Bug fix: the full prompt used to be print()-ed to stdout from library
        # code; log it at debug level (lazy %-formatting) instead.
        logger.debug("SparkLLM request: %s", messages[0]['content'])

        code, response = client.send_messages(messages)
        logger.debug("SparkLLM response: %s", response)

        if code != 0:
            raise ValueError(f"Failed with response: {response}")
        return response
# Demo: wire SparkLLM into a LangChain LLMChain with conversation memory.

from sparkai.depreciated.client.llm import SparkLLM

from langchain.prompts import (
    ChatPromptTemplate,
    MessagesPlaceholder,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.chains import LLMChain
from langchain.memory import ConversationBufferMemory

# SparkLLM reads its Spark credentials from environment variables at call time.
llm = SparkLLM()
prompt = ChatPromptTemplate(
    messages=[
        SystemMessagePromptTemplate.from_template(
            "You are a nice chatbot having a conversation with a human."
        ),
        # The `variable_name` here is what must align with memory
        MessagesPlaceholder(variable_name="chat_history"),
        HumanMessagePromptTemplate.from_template("{question}")
    ]
)
# Notice that we `return_messages=True` to fit into the MessagesPlaceholder
# Notice that `"chat_history"` aligns with the MessagesPlaceholder name.
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
conversation = LLMChain(
    llm=llm,
    prompt=prompt,
    verbose=True,  # echoes the fully-formatted prompt on every call
    memory=memory
)

# Three sequential turns; the buffer memory carries context between them.
conversation.run(question="Answer briefly. What are the first 3 colors of a rainbow?")
conversation.run(question="And the next 4?")
conversation.run(question="Thanks. Let's start a new conversation. What is the recommendation for breakfast?")
14 | ┏┓ ┏┓ 15 | ┏┛┻━━━┛┻┓ 16 | ┃ ☃ ┃ 17 | ┃ ┳┛ ┗┳ ┃ 18 | ┃ ┻ ┃ 19 | ┗━┓ ┏━┛ 20 | ┃ ┗━━━┓ 21 | ┃ 神兽保佑 ┣┓ 22 | ┃ 永无BUG! ┏┛ 23 | ┗┓┓┏━┳┓┏┛ 24 | ┃┫┫ ┃┫┫ 25 | ┗┻┛ ┗┻┛ 26 | """ 27 | 28 | # Copyright (c) 2022. Lorem ipsum dolor sit amet, consectetur adipiscing elit. 29 | # Morbi non lorem porttitor neque feugiat blandit. Ut vitae ipsum eget quam lacinia accumsan. 30 | # Etiam sed turpis ac ipsum condimentum fringilla. Maecenas magna. 31 | # Proin dapibus sapien vel ante. Aliquam erat volutpat. Pellentesque sagittis ligula eget metus. 32 | # Vestibulum commodo. Ut rhoncus gravida arcu. 33 | -------------------------------------------------------------------------------- /sparkai/depreciated/service/api_server.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding:utf-8 3 | """ 4 | @author: nivic ybyang7 5 | @license: Apache Licence 6 | @file: api_server 7 | @time: 2024/02/02 8 | @contact: ybyang7@iflytek.com 9 | @site: 10 | @software: PyCharm 11 | 12 | # code is far away from bugs with the god animal protecting 13 | I love animals. They taste delicious. 14 | ┏┓ ┏┓ 15 | ┏┛┻━━━┛┻┓ 16 | ┃ ☃ ┃ 17 | ┃ ┳┛ ┗┳ ┃ 18 | ┃ ┻ ┃ 19 | ┗━┓ ┏━┛ 20 | ┃ ┗━━━┓ 21 | ┃ 神兽保佑 ┣┓ 22 | ┃ 永无BUG! ┏┛ 23 | ┗┓┓┏━┳┓┏┛ 24 | ┃┫┫ ┃┫┫ 25 | ┗┻┛ ┗┻┛ 26 | """ 27 | 28 | # Copyright (c) 2022. Lorem ipsum dolor sit amet, consectetur adipiscing elit. 29 | # Morbi non lorem porttitor neque feugiat blandit. Ut vitae ipsum eget quam lacinia accumsan. 30 | # Etiam sed turpis ac ipsum condimentum fringilla. Maecenas magna. 31 | # Proin dapibus sapien vel ante. Aliquam erat volutpat. Pellentesque sagittis ligula eget metus. 32 | # Vestibulum commodo. Ut rhoncus gravida arcu. 
def getText(role, content):
    """Wrap a single (role, content) pair as a one-element message list."""
    return [{"role": role, "content": content}]


def getlength(text):
    """Total number of characters across all message contents."""
    return sum(len(message["content"]) for message in text)


def checklen(text):
    """Drop the oldest messages in place until total content length <= 8000."""
    while getlength(text) > 8000:
        del text[0]
    return text
"""Errors that can be raised by this SDK"""


class SparkAIClientError(Exception):
    """Base class for Client errors"""


class BotUserAccessError(SparkAIClientError):
    """Error raised when an 'xoxb-*' token is
    being used for a SparkAI API method that only accepts 'xoxp-*' tokens.
    """


class SparkAIRequestError(SparkAIClientError):
    """Error raised when there's a problem with the request that's being submitted."""


class SparkAIApiError(SparkAIClientError):
    """Error raised when SparkAI does not send the expected response.

    Attributes:
        response (SparkAIResponse): The SparkAIResponse object containing all of the data sent back from the API.

    Note:
        The message (str) passed into the exception is used when
        a user converts the exception to a str.
        i.e. str(SparkAIApiError("This text will be sent as a string."))
    """

    def __init__(self, message, response):
        msg = f"{message}\nThe server responded with: {response}"
        self.response = response
        super(SparkAIApiError, self).__init__(msg)


class SparkAITokenRotationError(SparkAIClientError):
    """Error raised when the oauth.v2.access call for token rotation fails"""

    api_error: SparkAIApiError

    def __init__(self, api_error: SparkAIApiError):
        self.api_error = api_error
        # Bug fix: the Exception base was never initialized, so str(err) was
        # empty; forward the underlying API error's message instead.
        super().__init__(str(api_error))


class SparkAIClientNotConnectedError(SparkAIClientError):
    """Error raised when attempting to send messages over the websocket when the
    connection is closed."""


class SparkAIObjectFormationError(SparkAIClientError):
    """Error raised when a constructed object is not valid/malformed"""


class SparkAIClientConfigurationError(SparkAIClientError):
    """Error raised because of invalid configuration on the client side:
    * when attempting to send messages over the websocket when the connection is closed.
    * when external system (e.g., Amazon S3) configuration / credentials are not correct
    """


# NOTE(review): unlike its siblings this derives from the builtin
# ConnectionError rather than SparkAIClientError, so handlers catching
# SparkAIClientError will not see it — confirm this is intentional.
class SparkAIConnectionError(ConnectionError):
    """Connection-level failure carrying the provider's error code."""

    def __init__(self, error_code, message):
        self.error_code = error_code
        self.message = message
        super().__init__(message)
""" 2 | from typing import Any, Optional 3 | 4 | 5 | class SparkAIException(Exception): 6 | """General LangChain exception.""" 7 | 8 | 9 | class TracerException(SparkAIException): 10 | """Base class for exceptions in tracers module.""" 11 | 12 | 13 | class OutputParserException(ValueError, SparkAIException): 14 | """Exception that output parsers should raise to signify a parsing error. 15 | 16 | This exists to differentiate parsing errors from other code or execution errors 17 | that also may arise inside the output parser. OutputParserExceptions will be 18 | available to catch and handle in ways to fix the parsing error, while other 19 | errors will be raised. 20 | 21 | Args: 22 | error: The error that's being re-raised or an error message. 23 | observation: String explanation of error which can be passed to a 24 | model to try and remediate the issue. 25 | llm_output: String model output which is error-ing. 26 | send_to_llm: Whether to send the observation and llm_output back to an Agent 27 | after an OutputParserException has been raised. This gives the underlying 28 | model driving the agent the context that the previous output was improperly 29 | structured, in the hopes that it will update the output to the correct 30 | format. 
31 | """ 32 | 33 | def __init__( 34 | self, 35 | error: Any, 36 | observation: Optional[str] = None, 37 | llm_output: Optional[str] = None, 38 | send_to_llm: bool = False, 39 | ): 40 | super(OutputParserException, self).__init__(error) 41 | if send_to_llm: 42 | if observation is None or llm_output is None: 43 | raise ValueError( 44 | "Arguments 'observation' & 'llm_output'" 45 | " are required if 'send_to_llm' is True" 46 | ) 47 | self.observation = observation 48 | self.llm_output = llm_output 49 | self.send_to_llm = send_to_llm 50 | -------------------------------------------------------------------------------- /sparkai/frameworks/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding:utf-8 3 | """ 4 | @author: nivic ybyang7 5 | @license: Apache Licence 6 | @file: __init__.py 7 | @time: 2024/03/27 8 | @contact: ybyang7@iflytek.com 9 | @site: 10 | @software: PyCharm 11 | 12 | # code is far away from bugs with the god animal protecting 13 | I love animals. They taste delicious. 14 | ┏┓ ┏┓ 15 | ┏┛┻━━━┛┻┓ 16 | ┃ ☃ ┃ 17 | ┃ ┳┛ ┗┳ ┃ 18 | ┃ ┻ ┃ 19 | ┗━┓ ┏━┛ 20 | ┃ ┗━━━┓ 21 | ┃ 神兽保佑 ┣┓ 22 | ┃ 永无BUG! ┏┛ 23 | ┗┓┓┏━┳┓┏┛ 24 | ┃┫┫ ┃┫┫ 25 | ┗┻┛ ┗┻┛ 26 | """ 27 | 28 | # Copyright (c) 2022. Lorem ipsum dolor sit amet, consectetur adipiscing elit. 29 | # Morbi non lorem porttitor neque feugiat blandit. Ut vitae ipsum eget quam lacinia accumsan. 30 | # Etiam sed turpis ac ipsum condimentum fringilla. Maecenas magna. 31 | # Proin dapibus sapien vel ante. Aliquam erat volutpat. Pellentesque sagittis ligula eget metus. 32 | # Vestibulum commodo. Ut rhoncus gravida arcu. 
from typing import List

from .handler import RetryHandler
from .builtin_handlers import (
    ConnectionErrorRetryHandler,
    RateLimitErrorRetryHandler,
)
from .interval_calculator import RetryIntervalCalculator
from .builtin_interval_calculators import (
    FixedValueRetryIntervalCalculator,
    BackoffRetryIntervalCalculator,
)
from .jitter import Jitter
from .request import HttpRequest
from .response import HttpResponse
from .state import RetryState

# Shared singleton handlers reused by the helper factories below.
connect_error_retry_handler = ConnectionErrorRetryHandler()
rate_limit_error_retry_handler = RateLimitErrorRetryHandler()


def default_retry_handlers() -> List[RetryHandler]:
    """Handlers applied by default: connection-error retries only."""
    return [connect_error_retry_handler]


def all_builtin_retry_handlers() -> List[RetryHandler]:
    """Every built-in handler, including rate-limit (HTTP 429) retries."""
    return [
        connect_error_retry_handler,
        rate_limit_error_retry_handler,
    ]


__all__ = [
    "RetryHandler",
    "ConnectionErrorRetryHandler",
    "RateLimitErrorRetryHandler",
    "RetryIntervalCalculator",
    "FixedValueRetryIntervalCalculator",
    "BackoffRetryIntervalCalculator",
    "Jitter",
    "HttpRequest",
    "HttpResponse",
    "RetryState",
    "connect_error_retry_handler",
    "rate_limit_error_retry_handler",
    "default_retry_handlers",
    "all_builtin_retry_handlers",
]
3 | """ 4 | 5 | import asyncio 6 | from typing import Optional 7 | 8 | from sparkai.http_retry.state import RetryState 9 | from sparkai.http_retry.request import HttpRequest 10 | from sparkai.http_retry.response import HttpResponse 11 | from sparkai.http_retry.interval_calculator import RetryIntervalCalculator 12 | from sparkai.http_retry.builtin_interval_calculators import ( 13 | BackoffRetryIntervalCalculator, 14 | ) 15 | 16 | default_interval_calculator = BackoffRetryIntervalCalculator() 17 | 18 | 19 | class AsyncRetryHandler: 20 | """asyncio compatible RetryHandler interface. 21 | You can pass an array of handlers to customize retry logics in supported API clients. 22 | """ 23 | 24 | max_retry_count: int 25 | interval_calculator: RetryIntervalCalculator 26 | 27 | def __init__( 28 | self, 29 | max_retry_count: int = 1, 30 | interval_calculator: RetryIntervalCalculator = default_interval_calculator, 31 | ): 32 | """RetryHandler interface. 33 | 34 | Args: 35 | max_retry_count: The maximum times to do retries 36 | interval_calculator: Pass an interval calculator for customizing the logic 37 | """ 38 | self.max_retry_count = max_retry_count 39 | self.interval_calculator = interval_calculator 40 | 41 | async def can_retry_async( 42 | self, 43 | *, 44 | state: RetryState, 45 | request: HttpRequest, 46 | response: Optional[HttpResponse] = None, 47 | error: Optional[Exception] = None, 48 | ) -> bool: 49 | if state.current_attempt >= self.max_retry_count: 50 | return False 51 | return await self._can_retry_async( 52 | state=state, 53 | request=request, 54 | response=response, 55 | error=error, 56 | ) 57 | 58 | async def _can_retry_async( 59 | self, 60 | *, 61 | state: RetryState, 62 | request: HttpRequest, 63 | response: Optional[HttpResponse] = None, 64 | error: Optional[Exception] = None, 65 | ) -> bool: 66 | raise NotImplementedError() 67 | 68 | async def prepare_for_next_attempt_async( 69 | self, 70 | *, 71 | state: RetryState, 72 | request: HttpRequest, 73 | 
response: Optional[HttpResponse] = None, 74 | error: Optional[Exception] = None, 75 | ) -> None: 76 | state.next_attempt_requested = True 77 | duration = self.interval_calculator.calculate_sleep_duration(state.current_attempt) 78 | await asyncio.sleep(duration) 79 | state.increment_current_attempt() 80 | 81 | 82 | __all__ = [ 83 | "RetryState", 84 | "HttpRequest", 85 | "HttpResponse", 86 | "RetryIntervalCalculator", 87 | "BackoffRetryIntervalCalculator", 88 | "default_interval_calculator", 89 | ] 90 | -------------------------------------------------------------------------------- /sparkai/http_retry/builtin_async_handlers.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import random 3 | from typing import Optional, List, Type 4 | 5 | from aiohttp import ServerDisconnectedError, ServerConnectionError, ClientOSError 6 | 7 | from sparkai.http_retry.async_handler import AsyncRetryHandler 8 | from sparkai.http_retry.interval_calculator import RetryIntervalCalculator 9 | from sparkai.http_retry.state import RetryState 10 | from sparkai.http_retry.request import HttpRequest 11 | from sparkai.http_retry.response import HttpResponse 12 | from sparkai.http_retry.handler import default_interval_calculator 13 | 14 | 15 | class AsyncConnectionErrorRetryHandler(AsyncRetryHandler): 16 | """RetryHandler that does retries for connectivity issues.""" 17 | 18 | def __init__( 19 | self, 20 | max_retry_count: int = 1, 21 | interval_calculator: RetryIntervalCalculator = default_interval_calculator, 22 | error_types: List[Type[Exception]] = [ 23 | ServerConnectionError, 24 | ServerDisconnectedError, 25 | # ClientOSError: [Errno 104] Connection reset by peer 26 | ClientOSError, 27 | ], 28 | ): 29 | super().__init__(max_retry_count, interval_calculator) 30 | self.error_types_to_do_retries = error_types 31 | 32 | async def _can_retry_async( 33 | self, 34 | *, 35 | state: RetryState, 36 | request: HttpRequest, 37 | response: 
class AsyncRateLimitErrorRetryHandler(AsyncRetryHandler):
    """RetryHandler that does retries for rate limited errors."""

    async def _can_retry_async(
        self,
        *,
        state: RetryState,
        request: HttpRequest,
        response: Optional[HttpResponse],
        error: Optional[Exception],
    ) -> bool:
        """Retry only when the server answered HTTP 429."""
        if response is None:
            return False
        return response.status_code == 429

    async def prepare_for_next_attempt_async(
        self,
        *,
        state: RetryState,
        request: HttpRequest,
        response: Optional[HttpResponse] = None,
        error: Optional[Exception] = None,
    ) -> None:
        """Honor Retry-After (plus jitter) before the next attempt."""
        if response is None:
            # NOTE(review): `error` may itself be None here, which would make
            # `raise error` fail with a TypeError — confirm intended contract.
            raise error

        state.next_attempt_requested = True
        header_name: Optional[str] = None
        for candidate in response.headers.keys():
            if candidate.lower() == "retry-after":
                header_name = candidate
                break
        if header_name is None:
            # This situation usually does not arise. Just in case.
            sleep_for = 1 + random.random()
        else:
            # NOTE(review): headers.get(...) is indexed with [0], which implies
            # header values are sequences (e.g. list of strings) — confirm
            # against HttpResponse.headers; with a plain string this would read
            # only the first character.
            sleep_for = int(response.headers.get(header_name)[0]) + random.random()
        await asyncio.sleep(sleep_for)
        state.increment_current_attempt()
class ConnectionErrorRetryHandler(RetryHandler):
    """RetryHandler that does retries for connectivity issues."""

    # Connectivity failures that are worth retrying by default.
    DEFAULT_ERROR_TYPES = (
        # To cover URLError:
        URLError,
        ConnectionResetError,
        RemoteDisconnected,
    )

    def __init__(
        self,
        max_retry_count: int = 1,
        interval_calculator: RetryIntervalCalculator = default_interval_calculator,
        error_types: Optional[List[Type[Exception]]] = None,
    ):
        """Create a handler that retries on connectivity errors.

        Args:
            max_retry_count: The maximum times to do retries
            interval_calculator: Pass an interval calculator for customizing the logic
            error_types: Exception types considered retryable; defaults to
                urllib connection-level errors.
        """
        super().__init__(max_retry_count, interval_calculator)
        # BUG FIX: the previous default was a mutable list literal shared
        # across every instance; build a fresh list per instance instead.
        if error_types is None:
            self.error_types_to_do_retries = list(self.DEFAULT_ERROR_TYPES)
        else:
            self.error_types_to_do_retries = list(error_types)

    def _can_retry(
        self,
        *,
        state: RetryState,
        request: HttpRequest,
        response: Optional[HttpResponse] = None,
        error: Optional[Exception] = None,
    ) -> bool:
        """Retry only when the failure is one of the configured error types."""
        if error is None:
            return False

        # A URLError accompanied by a response means the server actually
        # answered (e.g. a 40x status), so it is not a connectivity failure.
        if isinstance(error, URLError) and response is not None:
            return False

        return any(isinstance(error, t) for t in self.error_types_to_do_retries)
class BackoffRetryIntervalCalculator(RetryIntervalCalculator):
    """Retry interval calculator implementing Exponential Backoff And Jitter.

    see also: https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
    """

    backoff_factor: float
    jitter: Jitter

    def __init__(self, backoff_factor: float = 0.5, jitter: Optional[Jitter] = None):
        """Retry interval calculator that calculates in the manner of Exponential Backoff And Jitter

        Args:
            backoff_factor: The factor for the backoff interval calculation
            jitter: The jitter logic implementation
        """
        self.backoff_factor = backoff_factor
        self.jitter = RandomJitter() if jitter is None else jitter

    def calculate_sleep_duration(self, current_attempt: int) -> float:
        # The base interval doubles with each attempt, then jitter is applied
        # to avoid synchronized retry storms.
        base_interval = self.backoff_factor * (2 ** current_attempt)
        return self.jitter.recalculate(base_interval)
3 | """ 4 | 5 | import time 6 | from typing import Optional 7 | 8 | from sparkai.http_retry.state import RetryState 9 | from sparkai.http_retry.request import HttpRequest 10 | from sparkai.http_retry.response import HttpResponse 11 | from sparkai.http_retry.interval_calculator import RetryIntervalCalculator 12 | from sparkai.http_retry.builtin_interval_calculators import ( 13 | BackoffRetryIntervalCalculator, 14 | ) 15 | 16 | default_interval_calculator = BackoffRetryIntervalCalculator() 17 | 18 | 19 | # Note that you cannot add aiohttp to this class as the external dependency is optional 20 | class RetryHandler: 21 | """RetryHandler interface. 22 | You can pass an array of handlers to customize retry logics in supported API clients. 23 | """ 24 | 25 | max_retry_count: int 26 | interval_calculator: RetryIntervalCalculator 27 | 28 | def __init__( 29 | self, 30 | max_retry_count: int = 1, 31 | interval_calculator: RetryIntervalCalculator = default_interval_calculator, 32 | ): 33 | """RetryHandler interface. 
34 | 35 | Args: 36 | max_retry_count: The maximum times to do retries 37 | interval_calculator: Pass an interval calculator for customizing the logic 38 | """ 39 | self.max_retry_count = max_retry_count 40 | self.interval_calculator = interval_calculator 41 | 42 | def can_retry( 43 | self, 44 | *, 45 | state: RetryState, 46 | request: HttpRequest, 47 | response: Optional[HttpResponse] = None, 48 | error: Optional[Exception] = None, 49 | ) -> bool: 50 | if state.current_attempt >= self.max_retry_count: 51 | return False 52 | return self._can_retry( 53 | state=state, 54 | request=request, 55 | response=response, 56 | error=error, 57 | ) 58 | 59 | def _can_retry( 60 | self, 61 | *, 62 | state: RetryState, 63 | request: HttpRequest, 64 | response: Optional[HttpResponse] = None, 65 | error: Optional[Exception] = None, 66 | ) -> bool: 67 | raise NotImplementedError() 68 | 69 | def prepare_for_next_attempt( 70 | self, 71 | *, 72 | state: RetryState, 73 | request: HttpRequest, 74 | response: Optional[HttpResponse] = None, 75 | error: Optional[Exception] = None, 76 | ) -> None: 77 | state.next_attempt_requested = True 78 | duration = self.interval_calculator.calculate_sleep_duration(state.current_attempt) 79 | time.sleep(duration) 80 | state.increment_current_attempt() 81 | -------------------------------------------------------------------------------- /sparkai/http_retry/interval_calculator.py: -------------------------------------------------------------------------------- 1 | class RetryIntervalCalculator: 2 | """Retry interval calculator interface.""" 3 | 4 | def calculate_sleep_duration(self, current_attempt: int) -> float: 5 | """Calculates an interval duration in seconds. 
class Jitter:
    """Jitter interface"""

    def recalculate(self, duration: float) -> float:
        """Recalculate the given duration.
        see also: https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/

        Args:
            duration: the duration in seconds

        Returns:
            A new duration that the jitter amount is added
        """
        raise NotImplementedError()


class RandomJitter(Jitter):
    """Random jitter implementation"""

    def recalculate(self, duration: float) -> float:
        # Add a uniform random offset in [0, 1) to spread retries out.
        offset = random.random()
        return duration + offset
class HttpResponse:
    """HTTP response representation"""

    status_code: int
    headers: Dict[str, List[str]]
    body: Optional[Dict[str, Any]]
    data: Optional[bytes]

    def __init__(
        self,
        *,
        status_code: Union[int, str],
        headers: Dict[str, Union[str, List[str]]],
        body: Optional[Dict[str, Any]] = None,
        data: Optional[bytes] = None,
    ):
        # The status code may arrive as a string; always store an int.
        self.status_code = int(status_code)
        # Normalize header values so every entry is a list of strings.
        normalized: Dict[str, List[str]] = {}
        for name, value in headers.items():
            normalized[name] = value if isinstance(value, list) else [value]
        self.headers = normalized
        self.body = body
        self.data = data
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding:utf-8 3 | """ 4 | @author: nivic ybyang7 5 | @license: Apache Licence 6 | @file: __init__.py 7 | @time: 2024/02/02 8 | @contact: ybyang7@iflytek.com 9 | @site: 10 | @software: PyCharm 11 | 12 | # code is far away from bugs with the god animal protecting 13 | I love animals. They taste delicious. 14 | ┏┓ ┏┓ 15 | ┏┛┻━━━┛┻┓ 16 | ┃ ☃ ┃ 17 | ┃ ┳┛ ┗┳ ┃ 18 | ┃ ┻ ┃ 19 | ┗━┓ ┏━┛ 20 | ┃ ┗━━━┓ 21 | ┃ 神兽保佑 ┣┓ 22 | ┃ 永无BUG! ┏┛ 23 | ┗┓┓┏━┳┓┏┛ 24 | ┃┫┫ ┃┫┫ 25 | ┗┻┛ ┗┻┛ 26 | """ 27 | 28 | # Copyright (c) 2022. Lorem ipsum dolor sit amet, consectetur adipiscing elit. 29 | # Morbi non lorem porttitor neque feugiat blandit. Ut vitae ipsum eget quam lacinia accumsan. 30 | # Etiam sed turpis ac ipsum condimentum fringilla. Maecenas magna. 31 | # Proin dapibus sapien vel ante. Aliquam erat volutpat. Pellentesque sagittis ligula eget metus. 32 | # Vestibulum commodo. Ut rhoncus gravida arcu. 33 | -------------------------------------------------------------------------------- /sparkai/log/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding:utf-8 3 | """ 4 | @author: nivic ybyang7 5 | @license: Apache Licence 6 | @file: __init__.py 7 | @time: 2023/07/23 8 | @contact: ybyang7@iflytek.com 9 | @site: 10 | @software: PyCharm 11 | 12 | # code is far away from bugs with the god animal protecting 13 | I love animals. They taste delicious. 14 | ┏┓ ┏┓ 15 | ┏┛┻━━━┛┻┓ 16 | ┃ ☃ ┃ 17 | ┃ ┳┛ ┗┳ ┃ 18 | ┃ ┻ ┃ 19 | ┗━┓ ┏━┛ 20 | ┃ ┗━━━┓ 21 | ┃ 神兽保佑 ┣┓ 22 | ┃ 永无BUG! ┏┛ 23 | ┗┓┓┏━┳┓┏┛ 24 | ┃┫┫ ┃┫┫ 25 | ┗┻┛ ┗┻┛ 26 | """ 27 | 28 | # Copyright (c) 2022. Lorem ipsum dolor sit amet, consectetur adipiscing elit. 29 | # Morbi non lorem porttitor neque feugiat blandit. Ut vitae ipsum eget quam lacinia accumsan. 30 | # Etiam sed turpis ac ipsum condimentum fringilla. Maecenas magna. 31 | # Proin dapibus sapien vel ante. 
import logging
import inspect


class CustomLogRecord(logging.LogRecord):
    """LogRecord that attributes each record to the caller's file and line.

    The stock record would point at this wrapper module; we walk the stack
    past this module and the stdlib ``logging`` package to reach the real
    call site.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        frame = inspect.currentframe().f_back
        while frame:
            if frame.f_globals['__name__'] != __name__ and frame.f_globals['__name__'] != 'logging':
                break
            frame = frame.f_back

        if frame:
            self.filename = frame.f_code.co_filename
            self.lineno = frame.f_lineno
        else:
            # No non-logging frame found; fall back to placeholders.
            self.filename = "unknown"
            self.lineno = 0


class SingletonMeta(type):
    """Metaclass that caches exactly one instance per class."""

    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]


class Logger(metaclass=SingletonMeta):
    """Process-wide singleton logger for the SDK, writing to the console."""

    # Accepted level names mapped to stdlib levels. "trace" has no stdlib
    # counterpart and maps to DEBUG; anything unknown falls back to ERROR.
    _LEVEL_NAMES = {
        "debug": logging.DEBUG,
        "info": logging.INFO,
        "warning": logging.WARNING,
        "warn": logging.WARNING,
        "error": logging.ERROR,
        "trace": logging.DEBUG,
    }

    def __init__(self, logger_name='SparkPythonSDK', log_level=logging.ERROR):
        # Guard against re-initialization: the singleton metaclass returns
        # the same instance, but __init__ could still run via subclassing.
        if not hasattr(self, 'logger'):
            self.logger = logging.getLogger(logger_name)
            self.logger.setLevel(log_level)
            self.logger.makeRecord = self._make_custom_log_record

            self.console_handler = logging.StreamHandler()
            self.console_handler.setLevel(log_level)

            formatter = logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
                datefmt='%Y-%m-%d %H:%M:%S %Z')

            self.console_handler.setFormatter(formatter)
            self.logger.addHandler(self.console_handler)

    def setLevel(self, level):
        """Set the level on both the logger and its console handler.

        Args:
            level: one of "debug", "info", "warning"/"warn", "error",
                "trace"; any other value falls back to ``logging.ERROR``.
        """
        # BUG FIX: the original if/elif chain referenced an undefined name
        # (`leve`) and omitted "warn" from its valid-name list, so "warn"
        # silently became ERROR. A lookup table fixes both.
        resolved = self._LEVEL_NAMES.get(level, logging.ERROR)
        self.logger.setLevel(resolved)
        self.console_handler.setLevel(resolved)

    def _make_custom_log_record(self, name, level, fn, lno, msg, args, exc_info, func=None, extra=None, sinfo=None):
        # Route the logger's record factory to the caller-aware record type.
        return CustomLogRecord(name, level, fn, lno, msg, args, exc_info, func=func, extra=extra, sinfo=sinfo)

    def debug(self, message, *args):
        self.logger.debug(message)
        if args:
            self.logger.debug(*args)

    def info(self, message, *args):
        self.logger.info(message)
        if args:
            self.logger.info(*args)

    def warning(self, message, *args):
        self.logger.warning(message)
        if args:
            self.logger.warning(*args)

    def error(self, message, *args):
        self.logger.error(message)
        if args:
            self.logger.error(*args)

    def critical(self, message, *args):
        self.logger.critical(message)
        if args:
            self.logger.critical(*args)


logger = Logger('SparkPythonSDK')
class ConversationBufferMemory(BaseChatMemory):
    """Buffer for storing conversation memory."""

    human_prefix: str = "user"
    ai_prefix: str = "assistant"
    memory_key: str = "history"  #: :meta private:

    @property
    def buffer(self) -> Any:
        """String buffer of memory."""
        # Either hand back the raw message objects or render them into a
        # single prefixed transcript string, depending on configuration.
        if not self.return_messages:
            return get_buffer_string(
                self.chat_memory.messages,
                human_prefix=self.human_prefix,
                ai_prefix=self.ai_prefix,
            )
        return self.chat_memory.messages

    @property
    def memory_variables(self) -> List[str]:
        """Will always return list of memory variables.

        :meta private:
        """
        return [self.memory_key]

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Return history buffer."""
        return {self.memory_key: self.buffer}
class ConversationBufferWindowMemory(BaseChatMemory):
    """Buffer for storing conversation memory."""

    human_prefix: str = "user"
    ai_prefix: str = "assistant"
    memory_key: str = "history"  #: :meta private:
    # Number of most-recent exchanges to expose (2 messages per exchange).
    k: int = 5

    @property
    def buffer(self) -> List[BaseMessage]:
        """String buffer of memory."""
        return self.chat_memory.messages

    @property
    def memory_variables(self) -> List[str]:
        """Will always return list of memory variables.

        :meta private:
        """
        return [self.memory_key]

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Return history buffer."""
        # Keep only the last k exchanges; a non-positive k yields nothing.
        window: Any = [] if self.k <= 0 else self.buffer[-2 * self.k:]
        if self.return_messages:
            return {self.memory_key: window}
        rendered = get_buffer_string(
            window,
            human_prefix=self.human_prefix,
            ai_prefix=self.ai_prefix,
        )
        return {self.memory_key: rendered}
class DynamoDBChatMessageHistory(BaseChatMessageHistory):
    """Chat message history that stores history in AWS DynamoDB.
    This class expects that a DynamoDB table with name `table_name`
    and a partition Key of `SessionId` is present.

    Args:
        table_name: name of the DynamoDB table
        session_id: arbitrary key that is used to store the messages
            of a single chat session.
    """

    def __init__(self, table_name: str, session_id: str):
        # boto3 is imported lazily so the package stays an optional dependency.
        import boto3

        client = boto3.resource("dynamodb")
        self.table = client.Table(table_name)
        self.session_id = session_id

    @property
    def messages(self) -> List[BaseMessage]:  # type: ignore
        """Retrieve the messages from DynamoDB"""
        from botocore.exceptions import ClientError

        # BUG FIX: `response` was unbound when get_item raised a ClientError,
        # so the `if response` check below crashed with UnboundLocalError
        # instead of returning an empty history.
        response = None
        try:
            response = self.table.get_item(Key={"SessionId": self.session_id})
        except ClientError as error:
            if error.response["Error"]["Code"] == "ResourceNotFoundException":
                logger.warning("No record found with session id: %s", self.session_id)
            else:
                logger.error(error)

        if response and "Item" in response:
            items = response["Item"]["History"]
        else:
            items = []

        return messages_from_dict(items)

    def add_user_message(self, message: str) -> None:
        """Append a human-authored message to the stored history."""
        self.append(HumanMessage(content=message))

    def add_ai_message(self, message: str) -> None:
        """Append an AI-authored message to the stored history."""
        self.append(AIMessage(content=message))

    def append(self, message: BaseMessage) -> None:
        """Append the message to the record in DynamoDB"""
        from botocore.exceptions import ClientError

        # Read-modify-write of the whole history list for this session.
        messages = messages_to_dict(self.messages)
        messages.append(_message_to_dict(message))

        try:
            self.table.put_item(
                Item={"SessionId": self.session_id, "History": messages}
            )
        except ClientError as err:
            logger.error(err)

    def clear(self) -> None:
        """Clear session memory from DynamoDB"""
        from botocore.exceptions import ClientError

        try:
            self.table.delete_item(Key={"SessionId": self.session_id})
        except ClientError as err:
            logger.error(err)
class FileChatMessageHistory(BaseChatMessageHistory):
    """
    Chat message history that stores history in a local file.

    Args:
        file_path: path of the local file to store the messages.
    """

    def __init__(self, file_path: str):
        self.file_path = Path(file_path)
        # Seed a brand-new history file with an empty JSON list so reads
        # always find valid JSON.
        if not self.file_path.exists():
            self.file_path.touch()
            self.file_path.write_text(json.dumps([]))

    @property
    def messages(self) -> List[BaseMessage]:  # type: ignore
        """Retrieve the messages from the local file"""
        raw = self.file_path.read_text()
        return messages_from_dict(json.loads(raw))

    def add_user_message(self, message: str) -> None:
        self.append(HumanMessage(content=message))

    def add_ai_message(self, message: str) -> None:
        self.append(AIMessage(content=message))

    def append(self, message: BaseMessage) -> None:
        """Append the message to the record in the local file"""
        # Load the current history, add the new entry, and rewrite the file.
        history = messages_to_dict(self.messages)
        serialized = messages_to_dict([message])[0]
        history.append(serialized)
        self.file_path.write_text(json.dumps(history))

    def clear(self) -> None:
        """Clear session memory from the local file"""
        self.file_path.write_text(json.dumps([]))
class ChatMessageHistory(BaseChatMessageHistory, BaseModel):
    """In-memory chat message history backed by a plain Python list."""

    # pydantic field; mutable defaults are copied per instance by pydantic,
    # so separate histories do not share one list.
    messages: List[BaseMessage] = []

    def add_user_message(self, message: str) -> None:
        """Append a human-authored message to the history."""
        self.messages.append(HumanMessage(content=message))

    def add_ai_message(self, message: str) -> None:
        """Append an AI-authored message to the history."""
        self.messages.append(AIMessage(content=message))

    def clear(self) -> None:
        """Drop every stored message."""
        self.messages = []
class RedisChatMessageHistory(BaseChatMessageHistory):
    """Chat message history stored as a Redis list (newest entry at the head)."""

    def __init__(
        self,
        session_id: str,
        url: str = "redis://localhost:6379/0",
        key_prefix: str = "message_store:",
        ttl: Optional[int] = None,
    ):
        try:
            import redis
        except ImportError:
            raise ValueError(
                "Could not import redis python package. "
                "Please install it with `pip install redis`."
            )

        self.session_id = session_id
        self.key_prefix = key_prefix
        self.ttl = ttl

        try:
            self.redis_client = redis.Redis.from_url(url=url)
        except redis.exceptions.ConnectionError as exc:
            logger.error(exc)

    @property
    def key(self) -> str:
        """Redis key under which this session's messages live."""
        return self.key_prefix + self.session_id

    @property
    def messages(self) -> List[BaseMessage]:  # type: ignore
        """Retrieve the messages from Redis in chronological order."""
        raw_entries = self.redis_client.lrange(self.key, 0, -1)
        # lpush stores newest first, so walk the list in reverse to restore
        # chronological order.
        records = [json.loads(entry.decode("utf-8")) for entry in raw_entries[::-1]]
        return messages_from_dict(records)

    def add_user_message(self, message: str) -> None:
        """Record a human-authored message."""
        self.append(HumanMessage(content=message))

    def add_ai_message(self, message: str) -> None:
        """Record an AI-authored message."""
        self.append(AIMessage(content=message))

    def append(self, message: BaseMessage) -> None:
        """Push *message* onto the head of the Redis list, refreshing the TTL."""
        payload = json.dumps(_message_to_dict(message))
        self.redis_client.lpush(self.key, payload)
        if self.ttl:
            self.redis_client.expire(self.key, self.ttl)

    def clear(self) -> None:
        """Delete this session's Redis key entirely."""
        self.redis_client.delete(self.key)
class ReadOnlySharedMemory(BaseMemory):
    """Wraps another memory so readers can see it but writers cannot touch it."""

    # The underlying memory being exposed read-only.
    memory: BaseMemory

    @property
    def memory_variables(self) -> List[str]:
        """Delegate the variable list to the wrapped memory."""
        return self.memory.memory_variables

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Delegate loading to the wrapped memory."""
        return self.memory.load_memory_variables(inputs)

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Intentionally a no-op: this view never mutates the wrapped memory."""
        pass

    def clear(self) -> None:
        """Intentionally a no-op: clearing is disallowed on a read-only view."""
        pass
18 | """String buffer of memory.""" 19 | return self.chat_memory.messages 20 | 21 | @property 22 | def memory_variables(self) -> List[str]: 23 | """Will always return list of memory variables. 24 | 25 | :meta private: 26 | """ 27 | return [self.memory_key] 28 | 29 | def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]: 30 | """Return history buffer.""" 31 | buffer: Any = self.buffer 32 | if self.return_messages: 33 | final_buffer: Any = buffer 34 | else: 35 | final_buffer = get_buffer_string( 36 | buffer, 37 | human_prefix=self.human_prefix, 38 | ai_prefix=self.ai_prefix, 39 | ) 40 | return {self.memory_key: final_buffer} 41 | 42 | def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: 43 | """Save context from this conversation to buffer. Pruned.""" 44 | super().save_context(inputs, outputs) 45 | # Prune buffer if it exceeds max token limit 46 | buffer = self.chat_memory.messages 47 | curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer) 48 | if curr_buffer_length > self.max_token_limit: 49 | pruned_memory = [] 50 | while curr_buffer_length > self.max_token_limit: 51 | pruned_memory.append(buffer.pop(0)) 52 | curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer) 53 | -------------------------------------------------------------------------------- /sparkai/memory/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List 2 | 3 | from sparkai.schema import get_buffer_string # noqa: 401 4 | 5 | 6 | def get_prompt_input_key(inputs: Dict[str, Any], memory_variables: List[str]) -> str: 7 | # "stop" is a special key that can be passed as input but is not used to 8 | # format the prompt. 
9 | prompt_input_keys = list(set(inputs).difference(memory_variables + ["stop"])) 10 | if len(prompt_input_keys) != 1: 11 | raise ValueError(f"One input key expected got {prompt_input_keys}") 12 | return prompt_input_keys[0] 13 | -------------------------------------------------------------------------------- /sparkai/messages.py: -------------------------------------------------------------------------------- 1 | from sparkai.core.messages import ( 2 | AIMessage, 3 | AIMessageChunk, 4 | AnyMessage, 5 | BaseMessage, 6 | BaseMessageChunk, 7 | ChatMessage, 8 | ChatMessageChunk, 9 | FunctionMessage, 10 | FunctionMessageChunk, 11 | HumanMessage, 12 | HumanMessageChunk, 13 | SystemMessage, 14 | SystemMessageChunk, 15 | ToolMessage, 16 | ToolMessageChunk, 17 | _message_from_dict, 18 | get_buffer_string, 19 | merge_content, 20 | message_to_dict, 21 | messages_from_dict, 22 | messages_to_dict, 23 | ) 24 | 25 | # Backwards compatibility. 26 | _message_to_dict = message_to_dict 27 | 28 | __all__ = [ 29 | "get_buffer_string", 30 | "BaseMessage", 31 | "merge_content", 32 | "BaseMessageChunk", 33 | "HumanMessage", 34 | "HumanMessageChunk", 35 | "AIMessage", 36 | "AIMessageChunk", 37 | "SystemMessage", 38 | "SystemMessageChunk", 39 | "FunctionMessage", 40 | "FunctionMessageChunk", 41 | "ToolMessage", 42 | "ToolMessageChunk", 43 | "ChatMessage", 44 | "ChatMessageChunk", 45 | "messages_to_dict", 46 | "messages_from_dict", 47 | "_message_to_dict", 48 | "_message_from_dict", 49 | "message_to_dict", 50 | "AnyMessage", 51 | ] 52 | -------------------------------------------------------------------------------- /sparkai/models/__init__.py: -------------------------------------------------------------------------------- 1 | """Classes for constructing Slack-specific data structure""" 2 | 3 | import logging 4 | from typing import Union, Dict, Any, Sequence, List 5 | 6 | from .basic_objects import BaseObject 7 | from .basic_objects import EnumValidator 8 | from .basic_objects import 
def show_unknown_key_warning(name: Union[str, object], others: dict):
    """Log a debug-level notice about constructor kwargs that were ignored.

    Args:
        name: the component name, or an object whose class name is reported.
        others: leftover keyword args; mutated in place ("type" is removed,
            since it is an expected key rather than an unknown one).
    """
    if "type" in others:
        others.pop("type")
    if len(others) > 0:
        keys = ", ".join(others.keys())
        logger = logging.getLogger(__name__)
        # Bug fix: `isinstance(name, object)` is always True in Python, so a
        # string name was being replaced by its class name ("str"). Only fall
        # back to the class name for non-string values.
        if not isinstance(name, str):
            name = name.__class__.__name__
        logger.debug(
            f"!!! {name}'s constructor args ({keys}) were ignored."
            f"If they should be supported by this library, report this issue to the project :bow: "
            f"https://github.com/iflytek/spark-ai-sdk/issues"
        )
"""Internal module for loading proxy-related env variables"""
import logging
import os
from typing import Optional

_default_logger = logging.getLogger(__name__)


def load_http_proxy_from_env(logger: logging.Logger = _default_logger) -> Optional[str]:
    """Resolve an HTTP(S) proxy URL from the conventional env variables.

    HTTPS_PROXY/https_proxy take precedence over HTTP_PROXY/http_proxy.

    Args:
        logger: logger used for debug output.

    Returns:
        The proxy URL, or None when no variable is set or the value is blank
        (a blank value means the caller intends to unset the proxy).
    """
    proxy_url = (
        os.environ.get("HTTPS_PROXY")
        or os.environ.get("https_proxy")
        or os.environ.get("HTTP_PROXY")
        or os.environ.get("http_proxy")
    )
    if proxy_url is None:
        return None
    if len(proxy_url.strip()) == 0:
        # If the value is an empty string, the intention should be unsetting it.
        # (Message previously said "The Slack SDK" -- copied from slack_sdk.)
        logger.debug("The SDK ignored the proxy env variable as an empty value is set.")
        return None

    logger.debug(f"HTTP proxy URL has been loaded from an env variable: {proxy_url}")
    return proxy_url
class IntervalRunner:
    """Runs a zero-arg callable repeatedly on a daemon thread until shut down."""

    event: Event
    thread: Thread

    def __init__(self, target: Callable[[], None], interval_seconds: float = 0.1):
        self.target = target
        self.interval_seconds = interval_seconds
        self.event = threading.Event()
        # Daemon thread so a forgotten runner never blocks interpreter exit.
        self.thread = threading.Thread(target=self._run, daemon=True)

    def _run(self) -> None:
        """Invoke the target, then sleep up to the interval, until stopped."""
        while not self.event.is_set():
            self.target()
            self.event.wait(self.interval_seconds)

    def start(self) -> "IntervalRunner":
        """Kick off the background thread; returns self for chaining."""
        self.thread.start()
        return self

    def is_alive(self) -> bool:
        """True while the background thread exists and is still running."""
        return self.thread is not None and self.thread.is_alive()

    def shutdown(self):
        """Signal the loop to stop, wait for the thread, and discard it."""
        if not self.is_alive():
            return
        self.event.set()
        self.thread.join()
        self.thread = None
"BaseSocketModeClient", request: SocketModeRequest): # type: ignore # noqa: F821 # noqa: F821 17 | raise NotImplementedError() 18 | -------------------------------------------------------------------------------- /sparkai/socket_mode/request.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Optional 2 | 3 | from sparkai.models import JsonObject 4 | 5 | 6 | class SocketModeRequest: 7 | type: str 8 | envelope_id: str 9 | payload: dict 10 | accepts_response_payload: bool 11 | retry_attempt: Optional[int] # events_api 12 | retry_reason: Optional[str] # events_api 13 | 14 | def __init__( 15 | self, 16 | type: str, 17 | envelope_id: str, 18 | payload: Union[dict, JsonObject, str], 19 | accepts_response_payload: Optional[bool] = None, 20 | retry_attempt: Optional[int] = None, 21 | retry_reason: Optional[str] = None, 22 | ): 23 | self.type = type 24 | self.envelope_id = envelope_id 25 | 26 | if isinstance(payload, JsonObject): 27 | self.payload = payload.to_dict() 28 | elif isinstance(payload, dict): 29 | self.payload = payload 30 | elif isinstance(payload, str): 31 | self.payload = {"text": payload} 32 | else: 33 | unexpected_payload_type = type(payload) # type: ignore 34 | raise ValueError(f"Unsupported payload data type ({unexpected_payload_type})") 35 | 36 | self.accepts_response_payload = accepts_response_payload or False 37 | self.retry_attempt = retry_attempt 38 | self.retry_reason = retry_reason 39 | 40 | @classmethod 41 | def from_dict(cls, message: dict) -> Optional["SocketModeRequest"]: 42 | if all(k in message for k in ("type", "envelope_id", "payload")): 43 | return SocketModeRequest( 44 | type=message.get("type"), 45 | envelope_id=message.get("envelope_id"), 46 | payload=message.get("payload"), 47 | accepts_response_payload=message.get("accepts_response_payload") or False, 48 | retry_attempt=message.get("retry_attempt"), 49 | retry_reason=message.get("retry_reason"), 50 | ) 51 | return None 52 | 
53 | def to_dict(self) -> dict: # skipcq: PYL-W0221 54 | d = {"envelope_id": self.envelope_id} 55 | if self.payload is not None: 56 | d["payload"] = self.payload 57 | return d 58 | -------------------------------------------------------------------------------- /sparkai/socket_mode/response.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Optional 2 | 3 | from sparkai.models import JsonObject 4 | 5 | 6 | class SocketModeResponse: 7 | envelope_id: str 8 | payload: dict 9 | 10 | def __init__(self, envelope_id: str, payload: Optional[Union[dict, JsonObject, str]] = None): 11 | self.envelope_id = envelope_id 12 | 13 | if payload is None: 14 | self.payload = None 15 | elif isinstance(payload, JsonObject): 16 | self.payload = payload.to_dict() 17 | elif isinstance(payload, dict): 18 | self.payload = payload 19 | elif isinstance(payload, str): 20 | self.payload = {"text": payload} 21 | else: 22 | raise ValueError(f"Unsupported payload data type ({type(payload)})") 23 | 24 | def to_dict(self) -> dict: # skipcq: PYL-W0221 25 | d = {"envelope_id": self.envelope_id} 26 | if self.payload is not None: 27 | d["payload"] = self.payload 28 | return d 29 | -------------------------------------------------------------------------------- /sparkai/spark_proxy/generate_message.py: -------------------------------------------------------------------------------- 1 | from sparkai.spark_proxy.openai_types import ChatMessage, Function, FunctionCall, Tool, ToolCall 2 | from typing import List, Optional 3 | 4 | from sparkai.spark_proxy.spark_api import SparkAPI 5 | 6 | s_k = "key&secret&appid" 7 | 8 | 9 | def generate_message( 10 | *, 11 | key: str, 12 | messages: List[ChatMessage], 13 | functions: Optional[List[Function]] = None, 14 | tools: Optional[List[Tool]] = None, 15 | temperature: float = 0.7, 16 | model: str = None 17 | ) -> ChatMessage: 18 | s_api = SparkAPI(key, model=model, temperature=temperature) 19 | 
def generate_stream(
    *,
    key: str,
    model: str,
    messages: List[ChatMessage],
    functions: Optional[List[Function]] = None,
    tools: Optional[List[Tool]] = None,
    temperature: float = 0.7,
    stop: Optional[List[str]] = None,
) -> Generator[Dict, Any, Any]:
    """Stream chat completions from the Spark API.

    Args:
        key: "key&secret&appid"-style credential string.
        model: Spark model identifier.
        messages: the conversation so far.
        functions: legacy-style function declarations exposed to the model.
        tools: tool declarations; each tool's function is flattened into the
            same list as `functions` before the call.
        temperature: sampling temperature.
        stop: accepted for OpenAI API parity; not forwarded to the Spark
            call here (stop handling is done by the caller).

    Returns:
        A generator of raw Spark response payload dicts.
    """
    s_api = SparkAPI(key, model=model, temperature=temperature)

    # Flatten both `functions` and `tools` into the single list Spark expects.
    function_list = []
    for fn in functions or []:
        function_list.append(
            {
                'name': fn.name,
                'description': fn.description,
                'parameters': fn.parameters
            }
        )
    for tool in tools or []:
        fn = tool.function
        function_list.append(
            {
                'name': fn.name,
                'description': fn.description,
                'parameters': fn.parameters
            }
        )
    return s_api.yield_call(messages, function_list)
class ChatMessage(BaseModel):
    """One OpenAI-style chat message, optionally carrying function/tool calls."""

    role: Optional[str] = None
    tool_call_id: Optional[str] = None
    content: Optional[str] = None
    name: Optional[str] = None
    function_call: Optional[FunctionCall] = None
    tool_calls: Optional[List[ToolCall]] = None

    def _render_assistant(self) -> str:
        """Render the assistant role, which may include a function call."""
        if self.function_call is not None:
            call_text = f"assistant to={self.function_call.name}:\n{self.function_call.arguments}"
            if self.content is not None:
                return f"assistant:\n{self.content}\n{call_text}"
            return call_text
        if self.content is None:
            return "assistant"
        return f"assistant:\n{self.content}\n"

    def __str__(self) -> str:
        if self.role == "system":
            return f"system:\n{self.content}\n"
        if self.role == "function":
            return f"function name={self.name}:\n{self.content}\n"
        if self.role == "user":
            return "user:\n" if self.content is None else f"user:\n{self.content}\n"
        if self.role == "assistant":
            return self._render_assistant()
        raise ValueError(f"Unsupported role: {self.role}")
@app.post("/v1/chat/completions")
async def chat_endpoint(authorization: Annotated[str | None, Header()] = None, chat_input: ChatInput = None):
    """OpenAI-compatible chat completion endpoint backed by Spark.

    The Spark credential is carried in the Authorization header as
    "Bearer <key&secret&appid>". Returns a one-shot completion or an SSE
    stream depending on `chat_input.stream`.
    """
    # Bug fix: `authorization.split()[1]` raised AttributeError/IndexError
    # (an opaque 500) when the header was missing or malformed; reject with
    # an explicit 401 instead.
    if authorization is None or len(authorization.split()) < 2:
        return JSONResponse(
            status_code=401,
            content={"error": "Missing or malformed Authorization header"},
        )
    key = authorization.split()[1]
    request_id = str(uuid.uuid4())
    if not chat_input.stream:
        response_message = generate_message(
            key=key,
            messages=chat_input.messages,
            functions=chat_input.functions,
            tools=chat_input.tools,
            temperature=chat_input.temperature,
            model=chat_input.model
        )
        logger.debug(chat_input.messages)
        finish_reason = "stop"
        if response_message.function_call is not None:
            # need to add this to follow the format of openAI function calling
            finish_reason = "function_call"
        result = ChatCompletion(
            id=request_id,
            choices=[Choice.from_message(response_message, finish_reason)],
        )
        return result.dict(exclude_none=True)

    else:
        logger.debug(chat_input.messages)

        response_generator = generate_stream(
            key=key,
            messages=chat_input.messages,
            functions=chat_input.functions,
            tools=chat_input.tools,
            temperature=chat_input.temperature,
            model=chat_input.model,  # type: ignore
            stop=chat_input.stop
        )

        def get_response_stream():
            index = 0
            collected = """"""
            for response in response_generator:
                text0 = response['payload']['choices']['text'][0]
                if 'function_call' in text0:
                    choice = {
                        'delta': ChatMessage(
                            content=text0['function_call']
                        ),
                        'finish_reason': 'tool_calls',
                        'index': index
                    }
                else:
                    collected += text0['content']
                    # Cut the stream short as soon as any stop sequence shows
                    # up in the accumulated text.
                    for s in chat_input.stop or []:
                        if s in collected:
                            yield "data: [DONE]\n\n"
                            return
                    choice = {
                        'delta': ChatMessage(
                            content=text0['content'],
                            role='assistant'
                        ),
                        'finish_reason': 'stop',
                        'index': index
                    }
                index += 1
                chunk = StreamChoice(**choice)
                result = ChatCompletionChunk(id=request_id, choices=[chunk])
                chunk_dic = result.dict(exclude_unset=True)
                chunk_data = json.dumps(chunk_dic, ensure_ascii=False)
                yield f"data: {chunk_data}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(get_response_stream(), media_type="text/event-stream")
urllib.parse import urlencode 7 | 8 | 9 | # 接口地址,接口地址表单形式,包含了token 10 | url = 'wss://spark-api.xf-yun.com/v3.1/chat' 11 | 12 | # Weekday and month names for HTTP date/time formatting; always English! 13 | _weekdayname = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"] 14 | _monthname = [None, # Dummy so we can use 1-based month numbers 15 | "Jan", "Feb", "Mar", "Apr", "May", "Jun", 16 | "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"] 17 | 18 | 19 | def format_date_time(timestamp): 20 | year, month, day, hh, mm, ss, wd, y, z = time.gmtime(timestamp) 21 | return "%s, %02d %3s %4d %02d:%02d:%02d GMT" % ( 22 | _weekdayname[wd], day, _monthname[month], year, hh, mm, ss 23 | ) 24 | 25 | 26 | def create_url(host, path, api_key, api_secret, spark_url): 27 | # 生成RFC1123格式的时间戳 28 | now = datetime.datetime.now() 29 | date = format_date_time(time.mktime(now.timetuple())) 30 | 31 | # 拼接字符串 32 | signature_origin = "host: " + host + "\n" 33 | signature_origin += "date: " + date + "\n" 34 | signature_origin += "GET " + path + " HTTP/1.1" 35 | 36 | # 进行hmac-sha256进行加密 37 | signature_sha = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'), 38 | digestmod=hashlib.sha256).digest() 39 | 40 | signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8') 41 | 42 | authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"' 43 | 44 | authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8') 45 | 46 | # 将请求的鉴权参数组合为字典 47 | v = { 48 | "authorization": authorization, 49 | "date": date, 50 | "host": host 51 | } 52 | # 拼接鉴权参数,生成url 53 | url = spark_url + '?' 
+ urlencode(v) 54 | # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致 55 | return url 56 | -------------------------------------------------------------------------------- /sparkai/v2/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding:utf-8 3 | """ 4 | @author: nivic ybyang7 5 | @license: Apache Licence 6 | @file: __init__.py 7 | @time: 2024/04/19 8 | @contact: ybyang7@iflytek.com 9 | @site: 10 | @software: PyCharm 11 | 12 | # code is far away from bugs with the god animal protecting 13 | I love animals. They taste delicious. 14 | ┏┓ ┏┓ 15 | ┏┛┻━━━┛┻┓ 16 | ┃ ☃ ┃ 17 | ┃ ┳┛ ┗┳ ┃ 18 | ┃ ┻ ┃ 19 | ┗━┓ ┏━┛ 20 | ┃ ┗━━━┓ 21 | ┃ 神兽保佑 ┣┓ 22 | ┃ 永无BUG! ┏┛ 23 | ┗┓┓┏━┳┓┏┛ 24 | ┃┫┫ ┃┫┫ 25 | ┗┻┛ ┗┻┛ 26 | """ 27 | 28 | # Copyright (c) 2022. Lorem ipsum dolor sit amet, consectetur adipiscing elit. 29 | # Morbi non lorem porttitor neque feugiat blandit. Ut vitae ipsum eget quam lacinia accumsan. 30 | # Etiam sed turpis ac ipsum condimentum fringilla. Maecenas magna. 31 | # Proin dapibus sapien vel ante. Aliquam erat volutpat. Pellentesque sagittis ligula eget metus. 32 | # Vestibulum commodo. Ut rhoncus gravida arcu. 33 | -------------------------------------------------------------------------------- /sparkai/v2/client/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding:utf-8 3 | """ 4 | @author: nivic ybyang7 5 | @license: Apache Licence 6 | @file: __init__.py 7 | @time: 2024/04/19 8 | @contact: ybyang7@iflytek.com 9 | @site: 10 | @software: PyCharm 11 | 12 | # code is far away from bugs with the god animal protecting 13 | I love animals. They taste delicious. 14 | ┏┓ ┏┓ 15 | ┏┛┻━━━┛┻┓ 16 | ┃ ☃ ┃ 17 | ┃ ┳┛ ┗┳ ┃ 18 | ┃ ┻ ┃ 19 | ┗━┓ ┏━┛ 20 | ┃ ┗━━━┓ 21 | ┃ 神兽保佑 ┣┓ 22 | ┃ 永无BUG! ┏┛ 23 | ┗┓┓┏━┳┓┏┛ 24 | ┃┫┫ ┃┫┫ 25 | ┗┻┛ ┗┻┛ 26 | """ 27 | 28 | # Copyright (c) 2022. 
Lorem ipsum dolor sit amet, consectetur adipiscing elit.
# Morbi non lorem porttitor neque feugiat blandit. Ut vitae ipsum eget quam lacinia accumsan.
# Etiam sed turpis ac ipsum condimentum fringilla. Maecenas magna.
# Proin dapibus sapien vel ante. Aliquam erat volutpat. Pellentesque sagittis ligula eget metus.
# Vestibulum commodo. Ut rhoncus gravida arcu.
--------------------------------------------------------------------------------
/sparkai/v2/client/common/__init__.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# coding:utf-8
"""Package marker for sparkai.v2.client.common (shared constants/helpers).

@author: nivic ybyang7
@license: Apache Licence
@file: __init__.py
@time: 2024/04/21
@contact: ybyang7@iflytek.com
"""

# Copyright (c) 2022. Lorem ipsum dolor sit amet, consectetur adipiscing elit.
# Morbi non lorem porttitor neque feugiat blandit. Ut vitae ipsum eget quam lacinia accumsan.
# Etiam sed turpis ac ipsum condimentum fringilla. Maecenas magna.
# Proin dapibus sapien vel ante. Aliquam erat volutpat. Pellentesque sagittis ligula eget metus.
# Vestibulum commodo. Ut rhoncus gravida arcu.
--------------------------------------------------------------------------------
/sparkai/v2/client/common/consts.py:
--------------------------------------------------------------------------------
# Vendor identifier string.
IFLYTEK = "iflytek"
# Default Spark model domain used when the caller does not configure one.
# (Name kept as-is -- it is re-exported via wildcard imports elsewhere.)
DefaultDomain = "generalv3.5"
--------------------------------------------------------------------------------
/sparkai/v2/client/http/__init__.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# coding:utf-8
"""Async HTTP client scaffolding for the v2 Spark client.

@author: nivic ybyang7
@license: Apache Licence
@file: __init__.py
@time: 2024/04/19
@contact: ybyang7@iflytek.com
"""
import asyncio
import queue
import threading
from abc import ABC
from queue import Queue
from typing import Optional, Dict

import httpx

# Copyright (c) 2022. Lorem ipsum dolor sit amet, consectetur adipiscing elit.
# Morbi non lorem porttitor neque feugiat blandit. Ut vitae ipsum eget quam lacinia accumsan.
# Etiam sed turpis ac ipsum condimentum fringilla. Maecenas magna.
# Proin dapibus sapien vel ante. Aliquam erat volutpat. Pellentesque sagittis ligula eget metus.
# Vestibulum commodo. Ut rhoncus gravida arcu.
41 | 42 | from sparkai.v2.client.common.consts import * 43 | 44 | 45 | class HttpClient(ABC): 46 | def __init__( 47 | self, 48 | app_id: str, 49 | api_key: str, 50 | api_secret: str, 51 | result_q: queue.Queue, 52 | api_url: Optional[str] = None, 53 | spark_domain: Optional[str] = None, 54 | model_kwargs: Optional[dict] = None, 55 | user_agent: Optional[str] = None 56 | ): 57 | self.api_url = api_url 58 | self.domain = spark_domain 59 | self.app_id = app_id 60 | self.api_key = api_key 61 | self.api_secret = api_secret 62 | self.model_kwargs = model_kwargs 63 | self.queue = result_q 64 | self.blocking_message = {"content": "", "role": "assistant"} 65 | self.api_secret = api_secret 66 | self.extra_user_agent = user_agent 67 | 68 | async def a_request(self, params: dict, data:dict, method="GET", headers={}) -> httpx.Response: 69 | async with httpx.AsyncClient() as client: 70 | if method == "GET": 71 | response = await client.get(self.api_url, params=params, headers=headers) 72 | elif method == "POST": 73 | response = await client.post(self.api_url,params=params,data=data,headers=headers) 74 | self.queue.put(response) 75 | def request(self): 76 | pass 77 | 78 | async def a_start(self): 79 | await self.a_request(params={},data={}) 80 | while not self.queue.empty(): 81 | result = self.queue.get() 82 | print(result) 83 | print(result.content) 84 | 85 | def start(self): 86 | p = [] 87 | for i in range(10): 88 | t = threading.Thread(target=asyncio.run, args=(self.a_start(),)) 89 | t.start() 90 | p.append(t) 91 | for t in p: 92 | t.join() 93 | 94 | if __name__ == '__main__': 95 | c = HttpClient(api_key="",api_secret="", api_url="https://www.xx.com", app_id="", result_q=queue.Queue()) 96 | 97 | c.start() -------------------------------------------------------------------------------- /sparkai/v2/client/ws/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding:utf-8 3 | """ 4 | @author: nivic ybyang7 
@license: Apache Licence
@file: __init__.py
@time: 2024/04/19
@contact: ybyang7@iflytek.com

Package marker for sparkai.v2.client.ws (websocket transport).
"""

# Copyright (c) 2022. Lorem ipsum dolor sit amet, consectetur adipiscing elit.
# Morbi non lorem porttitor neque feugiat blandit. Ut vitae ipsum eget quam lacinia accumsan.
# Etiam sed turpis ac ipsum condimentum fringilla. Maecenas magna.
# Proin dapibus sapien vel ante. Aliquam erat volutpat. Pellentesque sagittis ligula eget metus.
# Vestibulum commodo. Ut rhoncus gravida arcu.
--------------------------------------------------------------------------------
/sparkai/v2/core/__init__.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# coding:utf-8
"""Package marker for sparkai.v2.core.

@author: nivic ybyang7
@license: Apache Licence
@file: __init__.py
@time: 2024/04/19
@contact: ybyang7@iflytek.com
"""

# Copyright (c) 2022. Lorem ipsum dolor sit amet, consectetur adipiscing elit.
# Morbi non lorem porttitor neque feugiat blandit. Ut vitae ipsum eget quam lacinia accumsan.
# Etiam sed turpis ac ipsum condimentum fringilla. Maecenas magna.
# Proin dapibus sapien vel ante. Aliquam erat volutpat. Pellentesque sagittis ligula eget metus.
# Vestibulum commodo. Ut rhoncus gravida arcu.
--------------------------------------------------------------------------------
/sparkai/v2/llm/__init__.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# coding:utf-8
"""Package marker for sparkai.v2.llm.

@author: nivic ybyang7
@license: Apache Licence
@file: __init__.py
@time: 2024/04/19
@contact: ybyang7@iflytek.com
"""

# Copyright (c) 2022. Lorem ipsum dolor sit amet, consectetur adipiscing elit.
# Morbi non lorem porttitor neque feugiat blandit. Ut vitae ipsum eget quam lacinia accumsan.
# Etiam sed turpis ac ipsum condimentum fringilla. Maecenas magna.
# Proin dapibus sapien vel ante. Aliquam erat volutpat. Pellentesque sagittis ligula eget metus.
# Vestibulum commodo. Ut rhoncus gravida arcu.
33 | -------------------------------------------------------------------------------- /sparkai/version.py: -------------------------------------------------------------------------------- 1 | """Check the latest version at https://pypi.org/project/spark-ai-sdk/""" 2 | __version__ = "0.3.4" 3 | -------------------------------------------------------------------------------- /sparkai/xf_util.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import hashlib 3 | import hmac 4 | 5 | from datetime import datetime 6 | from time import mktime 7 | from urllib import parse 8 | from urllib.parse import urlencode 9 | from wsgiref.handlers import format_date_time 10 | 11 | 12 | # 生成鉴权的url 13 | def build_auth_request_url(request_url, method="GET", api_key="", api_secret=""): 14 | url_result = parse.urlparse(request_url) 15 | date = format_date_time(mktime(datetime.now().timetuple())) 16 | signature_origin = "host: {}\ndate: {}\n{} {} HTTP/1.1".format(url_result.hostname, date, method, url_result.path) 17 | signature_sha = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'), 18 | digestmod=hashlib.sha256).digest() 19 | signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8') 20 | authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % ( 21 | api_key, "hmac-sha256", "host date request-line", signature_sha) 22 | authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8') 23 | values = { 24 | "host": url_result.hostname, 25 | "date": date, 26 | "authorization": authorization 27 | } 28 | return request_url + "?" 
+ urlencode(values) 29 | -------------------------------------------------------------------------------- /tests/embedding_test/embedding_test.py: -------------------------------------------------------------------------------- 1 | from sparkai.embedding.spark_embedding import Embeddingmodel, SparkEmbeddingFunction 2 | import chromadb 3 | import os 4 | 5 | try: 6 | from dotenv import load_dotenv 7 | except ImportError: 8 | raise RuntimeError( 9 | 'Python environment for SPARK AI is not completely set up: required package "python-dotenv" is missing.') from None 10 | 11 | load_dotenv() 12 | 13 | 14 | def test_embedding(): 15 | model = Embeddingmodel( 16 | spark_embedding_app_id=os.environ['SPARK_Embedding_APP_ID'], 17 | spark_embedding_api_key=os.environ['SPARK_Embedding_API_KEY'], 18 | spark_embedding_api_secret=os.environ['SPARK_Embedding_API_SECRET'], 19 | spark_embedding_domain=os.environ['SPARK_Embedding_DOMAIN'], 20 | ) 21 | # desc = {"messages":[{"content":"cc","role":"user"}]} 22 | desc = {"content": "cc", "role": "user"} 23 | # 调用embedding方法 24 | a = model.embedding(text=desc, kind='text') 25 | # print(len(a)) 26 | print(a) 27 | 28 | 29 | def test_chroma_embedding(): 30 | chroma_client = chromadb.Client() 31 | sparkmodel = SparkEmbeddingFunction( 32 | spark_embedding_app_id=os.environ['SPARK_Embedding_APP_ID'], 33 | spark_embedding_api_key=os.environ['SPARK_Embedding_API_KEY'], 34 | spark_embedding_api_secret=os.environ['SPARK_Embedding_API_SECRET'], 35 | spark_embedding_domain=os.environ['SPARK_Embedding_DOMAIN'], 36 | ) 37 | a = sparkmodel(["This is a document", "This is another document"]) 38 | # print(type(a)) 39 | # print(a[0]) 40 | # print(a[0][1]) 41 | # 可以正确的生成embedding结果 42 | collection = chroma_client.get_or_create_collection(name="my_collection", embedding_function=sparkmodel) 43 | # 为什么是None 44 | collection.add( 45 | documents=["This is a document", "cc", "1122"], 46 | metadatas=[{"source": "my_source"}, {"source": "my_source"}, {"source": 
"my_source"}], 47 | ids=["id1", "id2", "id3"] 48 | ) 49 | # print(collection.peek()) #显示前五条数据 50 | print(collection.count()) # 数据库中数据量 51 | results = collection.query( 52 | query_texts=["ac", 'documents'], 53 | n_results=2 54 | ) 55 | print(results) # 查询结果 56 | 57 | 58 | if __name__ == "__main__": 59 | test_embedding() 60 | test_chroma_embedding() 61 | -------------------------------------------------------------------------------- /tests/embedding_test/test_llama.py: -------------------------------------------------------------------------------- 1 | # 自定义功能 2 | from sparkai.embedding.sparkai_base import SparkAiEmbeddingModel 3 | from llama_index.core import VectorStoreIndex, SimpleDirectoryReader 4 | from llama_index.vector_stores.chroma import ChromaVectorStore 5 | from llama_index.core import StorageContext 6 | import chromadb 7 | import os 8 | from sparkai.embedding.spark_embedding import SparkEmbeddingFunction 9 | from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Settings 10 | from sparkai.frameworks.llama_index import SparkAI 11 | 12 | try: 13 | from dotenv import load_dotenv 14 | except ImportError: 15 | raise RuntimeError( 16 | 'Python environment for SPARK AI is not completely set up: required package "python-dotenv" is missing.') from None 17 | 18 | load_dotenv() 19 | 20 | 21 | def llama_query(): 22 | chroma_client = chromadb.Client() 23 | chroma_collection = chroma_client.get_or_create_collection(name="spark") 24 | # define embedding function 25 | embed_model = SparkAiEmbeddingModel(spark_embedding_app_id=os.environ['SPARK_Embedding_APP_ID'], 26 | spark_embedding_api_key=os.environ['SPARK_Embedding_API_KEY'], 27 | spark_embedding_api_secret=os.environ['SPARK_Embedding_API_SECRET'], 28 | spark_embedding_domain=os.environ['SPARKAI_Embedding_DOMAIN'], 29 | qps=2) 30 | # define LLM Model 31 | sparkai = SparkAI( 32 | spark_api_url=os.environ["SPARKAI_URL"], 33 | spark_app_id=os.environ["SPARKAI_APP_ID"], 34 | 
spark_api_key=os.environ["SPARKAI_API_KEY"], 35 | spark_api_secret=os.environ["SPARKAI_API_SECRET"], 36 | spark_llm_domain=os.environ["SPARKAI_DOMAIN"], 37 | streaming=False, 38 | temperature=0.01, 39 | ) 40 | # load documents 41 | # Invoke-WebRequest -Uri 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/docs/examples/data/paul_graham/paul_graham_essay.txt' -OutFile 'data\paul_graham\paul_graham_essay.txt' 42 | documents = SimpleDirectoryReader("D:\data\paul_graham").load_data() 43 | # set up ChromaVectorStore and load in data 44 | vector_store = ChromaVectorStore(chroma_collection=chroma_collection) 45 | storage_context = StorageContext.from_defaults(vector_store=vector_store) 46 | index = VectorStoreIndex.from_documents(documents, storage_context=storage_context, embed_model=embed_model) 47 | 48 | # query 49 | query_engine = index.as_query_engine(llm=sparkai, similarity_top_k=2) 50 | response = query_engine.query("What did the author do growing up?") 51 | print(response) 52 | 53 | 54 | if __name__ == "__main__": 55 | for i in range(10): 56 | llama_query() 57 | -------------------------------------------------------------------------------- /tests/examples/docs/agent_artchitect.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iflytek/spark-ai-python/735eb3d56d6f0fffadb6c9cc0bb7d6530b788684/tests/examples/docs/agent_artchitect.png -------------------------------------------------------------------------------- /tests/examples/docs/agents.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iflytek/spark-ai-python/735eb3d56d6f0fffadb6c9cc0bb7d6530b788684/tests/examples/docs/agents.png -------------------------------------------------------------------------------- /tests/examples/docs/autogen_grouchat_with_graph.md: -------------------------------------------------------------------------------- 1 | # 
智能体框架 AutoGen 深入体验及部分思考

## 背景

大模型爆发,AI Agent 爆发,AI Agent Framework(智能体框架)也在爆发,各家都提前在 Agent 领域布局各自的框架类基础设施。

前有 `AutoGPT`、`ChatDev`、`Langchain`、`MetaGPT`、`SuperAGI`、`AgentGPT` 一众框架平台出圈,从 23 年下半年到 2024 年初,

`AutoGen`、`Langflow`、`LangGraph`、`OpenAgents`、`Gpt_Pilot`、`Devin`、`OpenDevin`、`AIOS` 等又火速升温。

![img.png](agents.png)

这一切都在体现着 `AI Agent` 正在成为行业新宠。今天我们从偏技术的角度剖析微软开源的 `AutoGen` 框架。

请耐心读完,读完你可能会发现,什么都没有。。。


## 何为 Agent?何为 AI Agent?

![1img.png](agent_artchitect.png)

一直以来大家对 Agent 的定义都不太清晰,究其本源,还是因为 Agent 的定义过于宽泛,仿佛万物皆可 Agent。

再加上一个 AI 限定,摇身一变成为「智能体」。

***在大模型的背景下,`Agent` 一词已经逐渐默认等同于「智能体」和 `AI Agent`。***

### AI Agent 定义

可以实现自主理解、长期记忆、规划决策、执行复杂任务的 `Agent`。

### AI Agent 和大模型的关系

破除大模型“有脑无手”的困局:不仅可以教会用户如何做,更会直接执行。


以上两个定义都有些宽泛,如果有更加明确、标准的定义,还请私信联系我交流探讨。


## 走近 AutoGen

* AutoGen 的几个典型案例


##




--------------------------------------------------------------------------------
/tests/examples/docs/llama-index.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iflytek/spark-ai-python/735eb3d56d6f0fffadb6c9cc0bb7d6530b788684/tests/examples/docs/llama-index.png
--------------------------------------------------------------------------------
/tests/examples/docs/token_usage.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iflytek/spark-ai-python/735eb3d56d6f0fffadb6c9cc0bb7d6530b788684/tests/examples/docs/token_usage.png
--------------------------------------------------------------------------------
/tests/examples/llama_index_embedding.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iflytek/spark-ai-python/735eb3d56d6f0fffadb6c9cc0bb7d6530b788684/tests/examples/llama_index_embedding.png
--------------------------------------------------------------------------------
/tests/examples/spark_llama_index.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iflytek/spark-ai-python/735eb3d56d6f0fffadb6c9cc0bb7d6530b788684/tests/examples/spark_llama_index.png
--------------------------------------------------------------------------------
/tests/sparkai_test/wrapper_write.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# coding:utf-8
"""Manual smoke test: send a prompt file to Spark over a one-shot websocket.

@author: nivic ybyang7
@license: Apache Licence
@file: test.py
@time: 2023/07/23
@contact: ybyang7@iflytek.com
"""

# Copyright (c) 2022. Lorem ipsum dolor sit amet, consectetur adipiscing elit.
# Morbi non lorem porttitor neque feugiat blandit. Ut vitae ipsum eget quam lacinia accumsan.
# Etiam sed turpis ac ipsum condimentum fringilla. Maecenas magna.
# Proin dapibus sapien vel ante. Aliquam erat volutpat. Pellentesque sagittis ligula eget metus.
# Vestibulum commodo. Ut rhoncus gravida arcu.

from sparkai.api_resources import *
from sparkai.api_resources.chat_completion import *
from sparkai.schema import ChatMessage
from sparkai.models.chat import ChatBody, ChatResponse

if __name__ == '__main__':
    # NOTE(review): api_key/api_secret/app_id/api_base are not defined in this
    # file; presumably they come from one of the wildcard imports above -- confirm.
    c = SparkOnceWebsocket(api_key=api_key, api_secret=api_secret, app_id=app_id, api_base=api_base)
    # Read the prompt payload from disk (relative path; run from tests/sparkai_test).
    content = open("prompts/wrapper_write_code_spark.txt",'r').read()
    messages = [{'role': 'user', 'content': content}]
    print(messages[0]['content'])
    c.send_messages(messages)
--------------------------------------------------------------------------------
/weichat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iflytek/spark-ai-python/735eb3d56d6f0fffadb6c9cc0bb7d6530b788684/weichat.jpg