├── zhipuai
├── types
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── chat_completions_create_param.py
│ │ ├── agents_completion.py
│ │ └── agents_completion_chunk.py
│ ├── chat
│ │ ├── __init__.py
│ │ ├── chat_completions_create_param.py
│ │ ├── async_chat_completion.py
│ │ ├── chat_completion.py
│ │ ├── chat_completion_chunk.py
│ │ └── code_geex
│ │ │ └── code_geex_params.py
│ ├── file_parser
│ │ ├── __init__.py
│ │ ├── file_parser_resp.py
│ │ └── file_parser_create_params.py
│ ├── moderation
│ │ ├── __init__.py
│ │ └── moderation_completion.py
│ ├── fine_tuning
│ │ ├── models
│ │ │ ├── __init__.py
│ │ │ └── fine_tuned_models.py
│ │ ├── __init__.py
│ │ ├── job_create_params.py
│ │ ├── fine_tuning_job_event.py
│ │ └── fine_tuning_job.py
│ ├── assistant
│ │ ├── message
│ │ │ ├── __init__.py
│ │ │ ├── text_content_block.py
│ │ │ ├── message_content.py
│ │ │ ├── tools_delta_block.py
│ │ │ └── tools
│ │ │ │ ├── drawing_tool_delta_block.py
│ │ │ │ ├── function_delta_block.py
│ │ │ │ ├── tools_type.py
│ │ │ │ ├── code_interpreter_delta_block.py
│ │ │ │ ├── retrieval_delta_black.py
│ │ │ │ └── web_browser_delta_block.py
│ │ ├── __init__.py
│ │ ├── assistant_conversation_params.py
│ │ ├── assistant_support_resp.py
│ │ ├── assistant_conversation_resp.py
│ │ ├── assistant_create_params.py
│ │ └── assistant_completion.py
│ ├── video
│ │ ├── __init__.py
│ │ ├── video_object.py
│ │ └── video_create_params.py
│ ├── web_search
│ │ ├── __init__.py
│ │ ├── web_search_create_params.py
│ │ └── web_search_resp.py
│ ├── sensitive_word_check
│ │ ├── __init__.py
│ │ └── sensitive_word_check.py
│ ├── knowledge
│ │ ├── __init__.py
│ │ ├── document
│ │ │ ├── __init__.py
│ │ │ ├── document_list_resp.py
│ │ │ ├── document_list_params.py
│ │ │ ├── document_edit_params.py
│ │ │ └── document.py
│ │ ├── knowledge_list_resp.py
│ │ ├── knowledge_list_params.py
│ │ ├── knowledge_used.py
│ │ ├── knowledge.py
│ │ └── knowledge_create_params.py
│ ├── files
│ │ ├── file_deleted.py
│ │ ├── __init__.py
│ │ ├── upload_detail.py
│ │ ├── file_object.py
│ │ └── file_create_params.py
│ ├── batch_request_counts.py
│ ├── audio
│ │ ├── __init__.py
│ │ ├── transcriptions_create_param.py
│ │ ├── audio_speech_params.py
│ │ ├── audio_customization_param.py
│ │ └── audio_speech_chunk.py
│ ├── tools
│ │ ├── __init__.py
│ │ ├── web_search_chunk.py
│ │ ├── tools_web_search_params.py
│ │ └── web_search.py
│ ├── image.py
│ ├── batch_list_params.py
│ ├── batch_error.py
│ ├── embeddings.py
│ ├── batch_create_params.py
│ └── batch.py
├── __version__.py
├── api_resource
│ ├── tools
│ │ ├── __init__.py
│ │ └── tools.py
│ ├── knowledge
│ │ ├── __init__.py
│ │ └── document
│ │ │ └── __init__.py
│ ├── file_parser
│ │ ├── __init__.py
│ │ └── file_parser.py
│ ├── fine_tuning
│ │ ├── jobs
│ │ │ └── __init__.py
│ │ ├── models
│ │ │ ├── __init__.py
│ │ │ └── fine_tuned_models.py
│ │ ├── __init__.py
│ │ └── fine_tuning.py
│ ├── web_search
│ │ ├── __init__.py
│ │ └── web_search.py
│ ├── videos
│ │ ├── __init__.py
│ │ └── videos.py
│ ├── agents
│ │ ├── __init__.py
│ │ └── agents.py
│ ├── moderation
│ │ ├── __init__.py
│ │ └── moderations.py
│ ├── assistant
│ │ └── __init__.py
│ ├── audio
│ │ ├── __init__.py
│ │ ├── transcriptions.py
│ │ └── audio.py
│ ├── chat
│ │ ├── __init__.py
│ │ └── chat.py
│ ├── __init__.py
│ ├── embeddings.py
│ └── images.py
├── __init__.py
├── core
│ ├── _base_api.py
│ ├── _constants.py
│ ├── _jwt_token.py
│ ├── pagination.py
│ ├── _utils
│ │ ├── __init__.py
│ │ └── _typing.py
│ ├── _errors.py
│ ├── _files.py
│ ├── __init__.py
│ ├── logs.py
│ └── _request_opt.py
└── _client.py
├── tests
├── unit_tests
│ ├── response_model
│ │ └── __init__.py
│ ├── sse_client
│ │ ├── __init__.py
│ │ └── test_stream.py
│ ├── batchinput.jsonl
│ ├── test_sdk_import.py
│ ├── maybe
│ │ └── test_maybe_transform.py
│ ├── test_jwt.py
│ ├── test_audio.py
│ ├── test_request_opt.py
│ ├── test_agents.py
│ ├── test_streaming.py
│ └── test_response.py
├── integration_tests
│ ├── asr1.wav
│ ├── file.xlsx
│ ├── speech.wav
│ ├── img
│ │ └── MetaGLM.png
│ ├── batchinput.jsonl
│ ├── demo.jsonl
│ ├── test.py
│ ├── test_moderation.py
│ ├── test_web_search.py
│ ├── test_emohaa.py
│ ├── test_file_parser.py
│ ├── test_tools.py
│ ├── test_embedding.py
│ ├── test_images.py
│ ├── test_transcriptions.py
│ ├── test_videos.py
│ ├── test_file.py
│ ├── test_vlm_thinking.py
│ ├── test_agents.py
│ ├── test_charglm3.py
│ ├── test_audio.py
│ ├── test_assistant.py
│ ├── test_code_geex.py
│ └── test_finetuning.py
└── conftest.py
├── .env
├── poetry.toml
├── .gitignore
├── .github
├── workflows
│ ├── lint-pr.yaml
│ ├── _test.yml
│ ├── _integration_test.yml
│ └── _test_release.yml
├── PULL_REQUEST_TEMPLATE.md
├── ISSUE_TEMPLATE
│ ├── bug-report.yml
│ └── feature-request.yml
└── actions
│ └── poetry_setup
│ └── action.yml
├── LICENSE
├── CONTRIBUTING.md
├── Makefile
├── pyproject.toml
├── Release-Note.md
└── CODE_OF_CONDUCT.md
/zhipuai/types/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/zhipuai/types/agents/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/zhipuai/types/chat/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/zhipuai/types/file_parser/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/zhipuai/types/moderation/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/unit_tests/response_model/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/unit_tests/sse_client/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/zhipuai/__version__.py:
--------------------------------------------------------------------------------
1 | __version__ = 'v2.1.5.20250725'
--------------------------------------------------------------------------------
/zhipuai/api_resource/tools/__init__.py:
--------------------------------------------------------------------------------
1 | from .tools import Tools
2 |
3 | __all__ = ['Tools']
--------------------------------------------------------------------------------
/.env:
--------------------------------------------------------------------------------
1 | PYTHON_KEYRING_BACKEND=keyring.backends.null.Keyring
2 | ZHIPUAI_API_KEY={YOUR API KEY}
3 |
--------------------------------------------------------------------------------
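A minimal sketch of how the key above is typically consumed; reading `ZHIPUAI_API_KEY` explicitly is shown for clarity (passing `api_key` directly, as the unit tests do, works the same way):

    import os

    from zhipuai import ZhipuAI

    # Build a client from the key that .env (or the CI environment) provides.
    client = ZhipuAI(api_key=os.environ["ZHIPUAI_API_KEY"])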
/zhipuai/types/fine_tuning/models/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from .fine_tuned_models import FineTunedModelsStatus
--------------------------------------------------------------------------------
/poetry.toml:
--------------------------------------------------------------------------------
1 | [virtualenvs]
2 | in-project = true
3 |
4 | [installer]
5 | modern-installation = false
6 |
--------------------------------------------------------------------------------
/zhipuai/api_resource/knowledge/__init__.py:
--------------------------------------------------------------------------------
1 | from .knowledge import Knowledge
2 |
3 | __all__ = ['Knowledge']
--------------------------------------------------------------------------------
/zhipuai/api_resource/file_parser/__init__.py:
--------------------------------------------------------------------------------
1 | from .file_parser import FileParser
2 |
3 | __all__ = ['FileParser']
--------------------------------------------------------------------------------
/zhipuai/api_resource/fine_tuning/jobs/__init__.py:
--------------------------------------------------------------------------------
1 | from .jobs import Jobs
2 |
3 | __all__ = [
4 | "Jobs"
5 | ]
--------------------------------------------------------------------------------
/zhipuai/api_resource/knowledge/document/__init__.py:
--------------------------------------------------------------------------------
1 | from .document import Document
2 |
3 |
4 | __all__ = ['Document']
--------------------------------------------------------------------------------
/zhipuai/api_resource/web_search/__init__.py:
--------------------------------------------------------------------------------
1 | from .web_search import WebSearchApi
2 |
3 | __all__ = ['WebSearchApi']
--------------------------------------------------------------------------------
/zhipuai/api_resource/videos/__init__.py:
--------------------------------------------------------------------------------
1 | from .videos import (
2 | Videos,
3 | )
4 | __all__ = [
5 | 'Videos',
6 |
7 | ]
--------------------------------------------------------------------------------
/tests/integration_tests/asr1.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MetaGLM/zhipuai-sdk-python-v4/HEAD/tests/integration_tests/asr1.wav
--------------------------------------------------------------------------------
/tests/integration_tests/file.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MetaGLM/zhipuai-sdk-python-v4/HEAD/tests/integration_tests/file.xlsx
--------------------------------------------------------------------------------
/tests/integration_tests/speech.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MetaGLM/zhipuai-sdk-python-v4/HEAD/tests/integration_tests/speech.wav
--------------------------------------------------------------------------------
/zhipuai/api_resource/fine_tuning/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .fine_tuned_models import FineTunedModels
2 |
3 | __all__ = ['FineTunedModels']
--------------------------------------------------------------------------------
/tests/integration_tests/img/MetaGLM.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MetaGLM/zhipuai-sdk-python-v4/HEAD/tests/integration_tests/img/MetaGLM.png
--------------------------------------------------------------------------------
/zhipuai/api_resource/agents/__init__.py:
--------------------------------------------------------------------------------
1 | from zhipuai.api_resource.agents.agents import Agents
2 |
3 | __all__= [
4 | "Agents"
5 | ]
6 |
--------------------------------------------------------------------------------
/zhipuai/api_resource/moderation/__init__.py:
--------------------------------------------------------------------------------
1 | from .moderations import (
2 | Moderations
3 | )
4 |
5 | __all__ = [
6 | 'Moderations'
7 | ]
--------------------------------------------------------------------------------
/zhipuai/types/assistant/message/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from .message_content import MessageContent
3 |
4 | __all__ = [
5 | "MessageContent"
6 | ]
7 |
--------------------------------------------------------------------------------
/zhipuai/types/assistant/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | from .assistant_completion import AssistantCompletion
4 |
5 | __all__ = [
6 | 'AssistantCompletion',
7 | ]
8 |
--------------------------------------------------------------------------------
/zhipuai/api_resource/assistant/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | from zhipuai.api_resource.assistant.assistant import Assistant
4 |
5 | __all__= [
6 | "Assistant"
7 | ]
8 |
--------------------------------------------------------------------------------
/zhipuai/types/video/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from .video_object import (
3 | VideoObject,
4 | VideoResult
5 | )
6 |
7 | __all__ = ["VideoObject", "VideoResult"]
8 |
--------------------------------------------------------------------------------
/zhipuai/types/web_search/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from .web_search_create_params import (
3 | WebSearchCreatParams
4 | )
5 |
6 | __all__ = ["WebSearchCreatParams"]
7 |
--------------------------------------------------------------------------------
/zhipuai/types/sensitive_word_check/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | from .sensitive_word_check import SensitiveWordCheckRequest
4 |
5 |
6 | __all__ = [
7 | "SensitiveWordCheckRequest"
8 | ]
9 |
--------------------------------------------------------------------------------
/zhipuai/api_resource/audio/__init__.py:
--------------------------------------------------------------------------------
1 | from .audio import (
2 | Audio
3 | )
4 |
5 | from .transcriptions import (
6 | Transcriptions
7 | )
8 |
9 |
10 | __all__ = [
11 | 'Audio',
12 | 'Transcriptions'
13 | ]
--------------------------------------------------------------------------------
/zhipuai/types/knowledge/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from .knowledge import KnowledgeInfo
3 | from .knowledge_used import KnowledgeStatistics, KnowledgeUsed
4 | __all__ = [
5 | 'KnowledgeInfo',
6 |
7 | "KnowledgeStatistics",
8 | "KnowledgeUsed",
9 | ]
--------------------------------------------------------------------------------
/zhipuai/types/agents/chat_completions_create_param.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from typing_extensions import TypedDict
4 |
5 |
6 | class Reference(TypedDict, total=False):
7 | enable: Optional[bool]
8 | search_query: Optional[str]
9 |
--------------------------------------------------------------------------------
/zhipuai/types/chat/chat_completions_create_param.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from typing_extensions import TypedDict
4 |
5 |
6 | class Reference(TypedDict, total=False):
7 | enable: Optional[bool]
8 | search_query: Optional[str]
9 |
--------------------------------------------------------------------------------
/zhipuai/types/assistant/assistant_conversation_params.py:
--------------------------------------------------------------------------------
1 | from typing import TypedDict, List, Optional, Union
2 |
3 |
4 | class ConversationParameters(TypedDict, total=False):
5 | assistant_id: str  # agent (assistant) ID
6 | page: int  # current page number
7 | page_size: int  # page size
8 |
--------------------------------------------------------------------------------
/tests/unit_tests/batchinput.jsonl:
--------------------------------------------------------------------------------
1 | {"custom_id": "request-1", "method": "POST", "url": "/v4/chat/completions", "body": {"model": "glm-4", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}
--------------------------------------------------------------------------------
/zhipuai/types/files/file_deleted.py:
--------------------------------------------------------------------------------
1 |
2 | from typing_extensions import Literal
3 |
4 | from ...core import BaseModel
5 | __all__ = ["FileDeleted"]
6 |
7 |
8 | class FileDeleted(BaseModel):
9 | id: str
10 |
11 | deleted: bool
12 |
13 | object: Literal["file"]
14 |
--------------------------------------------------------------------------------
/zhipuai/types/knowledge/document/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from .document import DocumentData, DocumentObject, DocumentSuccessinfo, DocumentFailedInfo
3 |
4 |
5 | __all__ = [
6 | "DocumentData",
7 | "DocumentObject",
8 | "DocumentSuccessinfo",
9 | "DocumentFailedInfo",
10 | ]
--------------------------------------------------------------------------------
/zhipuai/types/fine_tuning/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from .fine_tuning_job import FineTuningJob as FineTuningJob
4 | from .fine_tuning_job import ListOfFineTuningJob as ListOfFineTuningJob
5 | from .fine_tuning_job_event import FineTuningJobEvent as FineTuningJobEvent
6 |
--------------------------------------------------------------------------------
/zhipuai/types/files/__init__.py:
--------------------------------------------------------------------------------
1 | from .file_object import FileObject, ListOfFileObject
2 | from .upload_detail import UploadDetail
3 | from .file_deleted import FileDeleted
4 |
5 | __all__ = [
6 | "FileObject",
7 | "ListOfFileObject",
8 | "UploadDetail",
9 | "FileDeleted"
10 | ]
11 |
--------------------------------------------------------------------------------
/zhipuai/api_resource/chat/__init__.py:
--------------------------------------------------------------------------------
1 | from .async_completions import (
2 | AsyncCompletions
3 | )
4 |
5 | from .chat import (
6 | Chat
7 | )
8 |
9 | from .completions import (
10 | Completions
11 | )
12 |
13 | __all__ = [
14 | 'AsyncCompletions',
15 | 'Chat',
16 | 'Completions',
17 | ]
--------------------------------------------------------------------------------
/zhipuai/api_resource/fine_tuning/__init__.py:
--------------------------------------------------------------------------------
1 | from .jobs import (
2 | Jobs
3 | )
4 |
5 | from .models import (
6 | FineTunedModels
7 | )
8 |
9 | from .fine_tuning import (
10 | FineTuning
11 | )
12 |
13 |
14 | __all__ = [
15 | 'Jobs',
16 | 'FineTunedModels',
17 | 'FineTuning'
18 | ]
--------------------------------------------------------------------------------
/zhipuai/types/batch_request_counts.py:
--------------------------------------------------------------------------------
1 | from ..core import BaseModel
2 |
3 | __all__ = ["BatchRequestCounts"]
4 |
5 |
6 | class BatchRequestCounts(BaseModel):
7 | completed: int
8 | """这个数字表示已经完成的请求。"""
9 |
10 | failed: int
11 | """这个数字表示失败的请求。"""
12 |
13 | total: int
14 | """这个数字表示总的请求。"""
15 |
--------------------------------------------------------------------------------
/zhipuai/types/moderation/moderation_completion.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Union, Dict
2 |
3 | from ...core import BaseModel
4 |
5 | __all__ = ["Completion"]
6 |
7 | class Completion(BaseModel):
8 | model: Optional[str] = None
9 | input: Optional[Union[str, List[str], Dict]] = None  # newly added input field
10 |
11 |
12 |
--------------------------------------------------------------------------------
/zhipuai/types/audio/__init__.py:
--------------------------------------------------------------------------------
1 | from .audio_speech_params import(
2 | AudioSpeechParams
3 | )
4 |
5 | from .audio_customization_param import(
6 | AudioCustomizationParam
7 | )
8 | from .transcriptions_create_param import(
9 | TranscriptionsParam
10 | )
11 |
12 | __all__ = ["AudioSpeechParams","AudioCustomizationParam","TranscriptionsParam"]
13 |
--------------------------------------------------------------------------------
/zhipuai/types/knowledge/knowledge_list_resp.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Dict, Optional, List
4 |
5 | from . import KnowledgeInfo
6 | from ...core import BaseModel
7 |
8 | __all__ = [
9 | "KnowledgePage"
10 | ]
11 |
12 |
13 | class KnowledgePage(BaseModel):
14 | list: List[KnowledgeInfo]
15 | object: str
16 |
--------------------------------------------------------------------------------
/zhipuai/types/knowledge/document/document_list_resp.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Dict, Optional, List
4 |
5 | from . import DocumentData
6 | from ....core import BaseModel
7 |
8 | __all__ = [
9 | "DocumentPage"
10 | ]
11 |
12 |
13 | class DocumentPage(BaseModel):
14 | list: List[DocumentData]
15 | object: str
16 |
--------------------------------------------------------------------------------
/zhipuai/types/assistant/message/text_content_block.py:
--------------------------------------------------------------------------------
1 | from typing_extensions import Literal
2 |
3 | from ....core import BaseModel
4 |
5 | __all__ = ["TextContentBlock"]
6 |
7 |
8 | class TextContentBlock(BaseModel):
9 | content: str
10 |
11 | role: str = "assistant"
12 |
13 | type: Literal["content"] = "content"
14 | """Always `content`."""
15 |
--------------------------------------------------------------------------------
/zhipuai/types/file_parser/file_parser_resp.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 |
3 | from zhipuai.core import BaseModel
4 |
5 | __all__ = [
6 | "FileParserTaskCreateResp"
7 | ]
8 |
9 |
10 | class FileParserTaskCreateResp(BaseModel):
11 | task_id: str
12 | # task id
13 | message: str
14 | # message
15 | success: bool
16 | # whether the parse succeeded
17 |
18 |
19 |
--------------------------------------------------------------------------------
/zhipuai/types/tools/__init__.py:
--------------------------------------------------------------------------------
1 | from .web_search import (
2 | WebSearch,
3 | SearchIntent,
4 | SearchResult,
5 | SearchRecommend,
6 | )
7 |
8 | from .web_search_chunk import (
9 | WebSearchChunk
10 | )
11 |
12 | __all__ = [
13 | 'WebSearch',
14 | 'SearchIntent',
15 | 'SearchResult',
16 | 'SearchRecommend',
17 | 'WebSearchChunk'
18 | ]
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode
2 | .idea
3 | *.pyc
4 | __pycache__/
5 | .pytest_cache/
6 | .DS_Store/
7 |
8 | *.swp
9 | *~
10 |
11 | build/
12 | dist/
13 | eggs/
14 | .eggs/
15 | *.egg-info/
16 | test/
17 | examples/
18 | .pypirc
19 | /poetry.lock
20 | logs
21 | /tests/integration_tests/batchoutput.jsonl
22 | /tests/integration_tests/content_batchoutput.jsonl
23 | /tests/integration_tests/write_to_file_batchoutput.jsonl
24 |
--------------------------------------------------------------------------------
/zhipuai/types/knowledge/knowledge_list_params.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Dict, Optional
4 | from typing_extensions import Literal, Required, TypedDict
5 |
6 | __all__ = ["KnowledgeListParams"]
7 |
8 |
9 | class KnowledgeListParams(TypedDict, total=False):
10 | page: int = 1
11 | """ 页码,默认 1,第一页
12 | """
13 |
14 | size: int = 10
15 | """每页数量 默认10
16 | """
17 |
--------------------------------------------------------------------------------
/tests/unit_tests/test_sdk_import.py:
--------------------------------------------------------------------------------
1 | def test_sdk_import_unit():
2 | import zhipuai
3 |
4 | print(zhipuai.__version__)
5 |
6 |
7 | def test_os_import_unit():
8 | import os
9 |
10 | print(os)
11 |
12 |
13 | def test_sdk_import():
14 | from zhipuai import ZhipuAI
15 |
16 | client = ZhipuAI(api_key='empty')  # fill in your own API key
17 |
18 | if client is not None:
19 | print('SDK import succeeded')
20 | else:
21 | print('SDK import failed')
22 |
--------------------------------------------------------------------------------
/zhipuai/types/sensitive_word_check/sensitive_word_check.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from typing_extensions import TypedDict
4 |
5 |
6 | class SensitiveWordCheckRequest(TypedDict, total=False):
7 | type: Optional[str]
8 | """敏感词类型,当前仅支持ALL"""
9 | status: Optional[str]
10 | """敏感词启用禁用状态
11 | 启用:ENABLE
12 | 禁用:DISABLE
13 | 备注:默认开启敏感词校验,如果要关闭敏感词校验,需联系商务获取对应权限,否则敏感词禁用不生效。
14 | """
--------------------------------------------------------------------------------
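Several request param types in this package (audio, video, web search) carry an optional `sensitive_word_check` field of this shape. A small sketch of the dict a caller would pass, following the docstrings above:

    from zhipuai.types.sensitive_word_check import SensitiveWordCheckRequest

    # Only "ALL" is supported for type; disabling requires the corresponding permission.
    check: SensitiveWordCheckRequest = {
        "type": "ALL",
        "status": "DISABLE",
    }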
/zhipuai/types/files/upload_detail.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List, Dict
2 |
3 | from ...core import BaseModel
4 |
5 |
6 | class UploadDetail(BaseModel):
7 | url: str
8 | knowledge_type: int
9 | file_name: Optional[str] = None
10 | sentence_size: Optional[int] = None
11 | custom_separator: Optional[List[str]] = None
12 | callback_url: Optional[str] = None
13 | callback_header: Optional[Dict[str,str]] = None
14 |
--------------------------------------------------------------------------------
/zhipuai/types/fine_tuning/job_create_params.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Union
4 |
5 | from typing_extensions import Literal, TypedDict
6 |
7 | __all__ = ["Hyperparameters"]
8 |
9 |
10 | class Hyperparameters(TypedDict, total=False):
11 | batch_size: Union[Literal["auto"], int]
12 |
13 | learning_rate_multiplier: Union[Literal["auto"], float]
14 |
15 | n_epochs: Union[Literal["auto"], int]
16 |
--------------------------------------------------------------------------------
/.github/workflows/lint-pr.yaml:
--------------------------------------------------------------------------------
1 | name: "Lint PR"
2 |
3 | on:
4 | pull_request_target:
5 | types:
6 | - opened
7 | - edited
8 | - reopened
9 |
10 | jobs:
11 | lint-pr:
12 | name: Validate PR title
13 | runs-on: ubuntu-latest
14 | permissions:
15 | pull-requests: read
16 | steps:
17 | - uses: amannn/action-semantic-pull-request@v5
18 | env:
19 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
--------------------------------------------------------------------------------
/zhipuai/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from ._client import ZhipuAI
3 |
4 | from .core import (
5 | ZhipuAIError,
6 | APIStatusError,
7 | APIRequestFailedError,
8 | APIAuthenticationError,
9 | APIReachLimitError,
10 | APIInternalError,
11 | APIServerFlowExceedError,
12 | APIResponseError,
13 | APIResponseValidationError,
14 | APIConnectionError,
15 | APITimeoutError,
16 | )
17 |
18 | from .__version__ import __version__
19 |
--------------------------------------------------------------------------------
/zhipuai/types/image.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Optional, List
4 |
5 | from ..core import BaseModel
6 |
7 | __all__ = ["GeneratedImage", "ImagesResponded"]
8 |
9 |
10 | class GeneratedImage(BaseModel):
11 | b64_json: Optional[str] = None
12 | url: Optional[str] = None
13 | revised_prompt: Optional[str] = None
14 |
15 |
16 | class ImagesResponded(BaseModel):
17 | created: int
18 | data: List[GeneratedImage]
19 |
--------------------------------------------------------------------------------
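A hedged usage sketch for these response models; the `client.images.generations(...)` call and its parameters are assumptions inferred from the `api_resource/images.py` entry in the tree, not confirmed by this listing:

    from zhipuai import ZhipuAI

    client = ZhipuAI()

    # Assumed call shape; the response is expected to parse into ImagesResponded.
    response = client.images.generations(
        model="cogview-3",
        prompt="a watercolor lighthouse at dawn",
    )
    for image in response.data:  # each item is a GeneratedImage
        print(image.url or image.b64_json)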
/zhipuai/types/batch_list_params.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing_extensions import TypedDict
4 |
5 | __all__ = ["BatchListParams"]
6 |
7 |
8 | class BatchListParams(TypedDict, total=False):
9 | after: str
10 | """分页的游标,用于获取下一页的数据。
11 |
12 | `after` 是一个指向当前页面的游标,用于获取下一页的数据。如果没有提供 `after`,则返回第一页的数据。
13 | list.
14 | """
15 |
16 | limit: int
17 | """这个参数用于限制返回的结果数量。
18 |
19 | Limit 用于限制返回的结果数量。默认值为 10
20 | """
21 |
--------------------------------------------------------------------------------
/tests/unit_tests/maybe/test_maybe_transform.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from zhipuai.core import maybe_transform
3 | from zhipuai.types import batch_create_params
4 |
5 |
6 | def test_response_joblist_model_cast() -> None:
7 | params = maybe_transform(
8 | {
9 | 'completion_window': '/v1/chat/completions',
10 | 'endpoint': None,
11 | 'metadata': {'key': 'value'},
12 | },
13 | batch_create_params.BatchCreateParams,
14 | )
15 | assert isinstance(params, dict)
16 |
--------------------------------------------------------------------------------
/zhipuai/types/assistant/message/message_content.py:
--------------------------------------------------------------------------------
1 |
2 | from typing import Union
3 | from typing_extensions import Annotated, TypeAlias
4 |
5 | from ....core._utils import PropertyInfo
6 | from .tools_delta_block import ToolsDeltaBlock
7 | from .text_content_block import TextContentBlock
8 |
9 | __all__ = ["MessageContent"]
10 |
11 |
12 | MessageContent: TypeAlias = Annotated[
13 | Union[ToolsDeltaBlock, TextContentBlock],
14 | PropertyInfo(discriminator="type"),
15 | ]
--------------------------------------------------------------------------------
/tests/integration_tests/batchinput.jsonl:
--------------------------------------------------------------------------------
1 | {"custom_id": "request-1", "method": "POST", "url": "/v4/chat/completions", "body": {"model": "glm-4", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}
2 | {"custom_id": "request-2", "method": "POST", "url": "/v4/chat/completions", "body": {"model": "glm-4", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}
--------------------------------------------------------------------------------
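Each line above is one chat-completions request for the batch API. A hypothetical submission sketch; `client.files.create` and `client.batches.create`, and their parameters, are assumed from the `file_create_params`/`batch_create_params` type modules in the tree rather than shown in this listing:

    from zhipuai import ZhipuAI

    client = ZhipuAI()

    # Upload the JSONL as batch input, then create the batch job (assumed method shapes).
    with open("tests/integration_tests/batchinput.jsonl", "rb") as f:
        batch_file = client.files.create(file=f, purpose="batch")

    batch = client.batches.create(
        input_file_id=batch_file.id,
        endpoint="/v4/chat/completions",
        completion_window="24h",
    )
    print(batch)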
/zhipuai/types/fine_tuning/models/fine_tuned_models.py:
--------------------------------------------------------------------------------
1 | from typing import List, Union, Optional, ClassVar
2 |
3 | from ....core import BaseModel, PYDANTIC_V2, ConfigDict
4 |
5 | __all__ = ["FineTunedModelsStatus"]
6 |
7 |
8 | class FineTunedModelsStatus(BaseModel):
9 | if PYDANTIC_V2:
10 | model_config: ClassVar[ConfigDict] = ConfigDict(extra="allow", protected_namespaces=())
11 | request_id: str  # request id
12 | model_name: str  # model name
13 | delete_status: str  # deletion status: deleting (in progress), deleted (completed)
14 |
--------------------------------------------------------------------------------
/zhipuai/types/assistant/message/tools_delta_block.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List
2 | from typing_extensions import Literal
3 |
4 | from .tools.tools_type import ToolsType
5 | from ....core import BaseModel
6 |
7 | __all__ = ["ToolsDeltaBlock"]
8 |
9 |
10 | class ToolsDeltaBlock(BaseModel):
11 | tool_calls: List[ToolsType]
12 | """The index of the content part in the message."""
13 |
14 | role: str = "tool"
15 |
16 | type: Literal["tool_calls"] = "tool_calls"
17 | """Always `tool_calls`."""
18 |
--------------------------------------------------------------------------------
/zhipuai/types/batch_error.py:
--------------------------------------------------------------------------------
1 | # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
2 |
3 | from typing import Optional
4 |
5 | from ..core import BaseModel
6 |
7 | __all__ = ["BatchError"]
8 |
9 |
10 | class BatchError(BaseModel):
11 | code: Optional[str] = None
12 | """定义的业务错误码"""
13 |
14 | line: Optional[int] = None
15 | """文件中的行号"""
16 |
17 | message: Optional[str] = None
18 | """关于对话文件中的错误的描述"""
19 |
20 | param: Optional[str] = None
21 | """参数名称,如果有的话"""
22 |
--------------------------------------------------------------------------------
/zhipuai/api_resource/fine_tuning/fine_tuning.py:
--------------------------------------------------------------------------------
1 | from typing import TYPE_CHECKING
2 | from .jobs import Jobs
3 | from .models import FineTunedModels
4 | from ...core import BaseAPI, cached_property
5 |
6 | if TYPE_CHECKING:
7 | from ..._client import ZhipuAI
8 |
9 |
10 | class FineTuning(BaseAPI):
11 |
12 | @cached_property
13 | def jobs(self) -> Jobs:
14 | return Jobs(self._client)
15 |
16 | @cached_property
17 | def models(self) -> FineTunedModels:
18 | return FineTunedModels(self._client)
19 |
20 |
--------------------------------------------------------------------------------
/zhipuai/types/embeddings.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Optional, List
4 |
5 | from ..core import BaseModel
6 | from .chat.chat_completion import CompletionUsage
7 | __all__ = ["Embedding", "EmbeddingsResponded"]
8 |
9 |
10 | class Embedding(BaseModel):
11 | object: str
12 | index: Optional[int] = None
13 | embedding: List[float]
14 |
15 |
16 | class EmbeddingsResponded(BaseModel):
17 | object: str
18 | data: List[Embedding]
19 | model: str
20 | usage: CompletionUsage
21 |
--------------------------------------------------------------------------------
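A sketch of how these models are meant to be consumed; the `client.embeddings.create(...)` call shape is an assumption based on the `api_resource/embeddings.py` entry in the tree:

    from zhipuai import ZhipuAI

    client = ZhipuAI()

    # Assumed call shape; the result should parse into EmbeddingsResponded.
    response = client.embeddings.create(model="embedding-2", input="hello world")
    vector = response.data[0].embedding  # List[float]
    print(response.model, len(vector))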
/zhipuai/core/_base_api.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import TYPE_CHECKING
3 |
4 | if TYPE_CHECKING:
5 | from .._client import ZhipuAI
6 |
7 |
8 | class BaseAPI:
9 | _client: ZhipuAI
10 |
11 | def __init__(self, client: ZhipuAI) -> None:
12 | self._client = client
13 | self._delete = client.delete
14 | self._get = client.get
15 | self._post = client.post
16 | self._put = client.put
17 | self._patch = client.patch
18 | self._get_api_list = client.get_api_list
19 |
--------------------------------------------------------------------------------
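The class does nothing but bind the client's HTTP verbs onto the resource, so every concrete resource (Chat, FineTuning, Tools, and so on) can call `self._get`, `self._post`, etc. A toy illustration with a stand-in client object, just to show the wiring:

    from zhipuai.core import BaseAPI


    class _FakeClient:
        """Stand-in exposing the attributes BaseAPI.__init__ reads."""

        def delete(self, *args, **kwargs): ...
        def get(self, *args, **kwargs): return {"ok": True}
        def post(self, *args, **kwargs): ...
        def put(self, *args, **kwargs): ...
        def patch(self, *args, **kwargs): ...
        def get_api_list(self, *args, **kwargs): ...


    api = BaseAPI(_FakeClient())
    print(api._get("/anything"))  # {'ok': True} -- _get is simply the client's get method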
/zhipuai/core/_constants.py:
--------------------------------------------------------------------------------
1 | import httpx
2 |
3 | RAW_RESPONSE_HEADER = "X-Stainless-Raw-Response"
4 | # `Timeout` controls the connect and read timeouts; defaults to `timeout=300.0, connect=8.0`
5 | ZHIPUAI_DEFAULT_TIMEOUT = httpx.Timeout(timeout=300.0, connect=8.0)
6 | # The retry setting controls the number of retries; defaults to 3
7 | ZHIPUAI_DEFAULT_MAX_RETRIES = 3
8 | # `Limits` controls the maximum number of connections and keep-alive connections; defaults to `max_connections=50, max_keepalive_connections=10`
9 | ZHIPUAI_DEFAULT_LIMITS = httpx.Limits(max_connections=50, max_keepalive_connections=10)
10 |
11 | INITIAL_RETRY_DELAY = 0.5
12 | MAX_RETRY_DELAY = 8.0
13 |
--------------------------------------------------------------------------------
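A sketch of overriding these defaults when constructing the client. The `timeout=` and `max_retries=` keyword names are assumptions consistent with the comments above; they are not confirmed by this listing:

    import httpx

    from zhipuai import ZhipuAI

    # Tighter connect timeout and a single retry instead of the defaults above (assumed kwargs).
    client = ZhipuAI(
        timeout=httpx.Timeout(timeout=60.0, connect=3.0),
        max_retries=1,
    )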
/zhipuai/api_resource/chat/chat.py:
--------------------------------------------------------------------------------
1 | from typing import TYPE_CHECKING
2 | from .completions import Completions
3 | from .async_completions import AsyncCompletions
4 | from ...core import BaseAPI, cached_property
5 |
6 | if TYPE_CHECKING:
7 | from ..._client import ZhipuAI
8 |
9 |
10 | class Chat(BaseAPI):
11 |
12 | @cached_property
13 | def completions(self) -> Completions:
14 | return Completions(self._client)
15 |
16 | @cached_property
17 | def asyncCompletions(self) -> AsyncCompletions:
18 | return AsyncCompletions(self._client)
--------------------------------------------------------------------------------
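`Chat` only routes to its cached sub-resources; the request methods themselves live in `completions.py` and `async_completions.py`, which are not included in this listing. A hedged usage sketch, with the `create(...)` parameters mirroring the request bodies in `batchinput.jsonl`:

    from zhipuai import ZhipuAI

    client = ZhipuAI()

    # Synchronous completion through the cached `completions` sub-resource
    # (the exact create() signature is not shown in this listing).
    response = client.chat.completions.create(
        model="glm-4",
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello world!"},
        ],
    )
    print(response.choices[0].message)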
/zhipuai/types/assistant/message/tools/drawing_tool_delta_block.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | from typing_extensions import Literal
4 |
5 | from .....core import BaseModel
6 |
7 | __all__ = ["DrawingToolBlock"]
8 |
9 |
10 | class DrawingToolOutput(BaseModel):
11 | image: str
12 |
13 |
14 | class DrawingTool(BaseModel):
15 | input: str
16 | outputs: List[DrawingToolOutput]
17 |
18 |
19 | class DrawingToolBlock(BaseModel):
20 | drawing_tool: DrawingTool
21 |
22 | type: Literal["drawing_tool"]
23 | """Always `drawing_tool`."""
24 |
--------------------------------------------------------------------------------
/zhipuai/types/knowledge/knowledge_used.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | from typing import Optional
4 |
5 | from ...core import BaseModel
6 |
7 | __all__ = [
8 | "KnowledgeStatistics",
9 | "KnowledgeUsed"
10 | ]
11 |
12 |
13 | class KnowledgeStatistics(BaseModel):
14 | """
15 | Usage statistics
16 | """
17 | word_num: Optional[int] = None
18 | length: Optional[int] = None
19 |
20 |
21 | class KnowledgeUsed(BaseModel):
22 |
23 | used: Optional[KnowledgeStatistics] = None
24 | """已使用量"""
25 | total: Optional[KnowledgeStatistics] = None
26 | """知识库总量"""
27 |
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | # Description
2 |
3 | Thank you for opening a Pull Request!
4 | Before submitting your PR, there are a few things you can do to make sure it goes smoothly:
5 |
6 | - [ ] Follow the [`CONTRIBUTING` Guide](https://github.com/MetaGLM/zhipuai-sdk-python-v4/blob/main/CONTRIBUTING.md).
7 | - [ ] Make sure your Pull Request title follows the specification.
8 | - [ ] Ensure the tests pass (run the test suite from the repository root)
9 | - [ ] Appropriate docs were updated (if necessary)
10 |
11 | Fixes #
12 |
--------------------------------------------------------------------------------
/zhipuai/types/assistant/message/tools/function_delta_block.py:
--------------------------------------------------------------------------------
1 | from typing import List, Union
2 |
3 | from typing_extensions import Literal
4 |
5 | __all__ = ["FunctionToolBlock"]
6 |
7 | from .....core import BaseModel
8 |
9 |
10 | class FunctionToolOutput(BaseModel):
11 | content: str
12 |
13 |
14 | class FunctionTool(BaseModel):
15 | name: str
16 | arguments: Union[str,dict]
17 | outputs: List[FunctionToolOutput]
18 |
19 |
20 | class FunctionToolBlock(BaseModel):
21 | function: FunctionTool
22 |
23 | type: Literal["function"]
24 | """Always `drawing_tool`."""
--------------------------------------------------------------------------------
/tests/unit_tests/test_jwt.py:
--------------------------------------------------------------------------------
1 | import jwt
2 | import pytest
3 |
4 | from zhipuai.core._jwt_token import generate_token
5 |
6 |
7 | def test_token() -> None:
8 | # generate a token
9 | token = generate_token('12345678.abcdefg')
10 | assert token is not None
11 |
12 | # decode the token
13 | payload = jwt.decode(
14 | token,
15 | 'abcdefg',
16 | algorithms='HS256',
17 | options={'verify_signature': False},
18 | )
19 | assert payload is not None
20 | assert payload.get('api_key') == '12345678'
21 |
22 | apikey = 'invalid_api_key'
23 | with pytest.raises(Exception):
24 | generate_token(apikey)
25 |
--------------------------------------------------------------------------------
/zhipuai/types/audio/transcriptions_create_param.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import List, Optional
4 |
5 | from typing_extensions import Literal, Required, TypedDict
6 | __all__ = ["TranscriptionsParam"]
7 |
8 | from ..sensitive_word_check import SensitiveWordCheckRequest
9 |
10 | class TranscriptionsParam(TypedDict, total=False):
11 | model: str
12 | """模型编码"""
13 | temperature:float
14 | """采样温度"""
15 | stream: bool
16 | """是否流式输出"""
17 | sensitive_word_check: Optional[SensitiveWordCheckRequest]
18 | request_id: str
19 | """由用户端传参,需保证唯一性;用于区分每次请求的唯一标识,用户端不传时平台会默认生成。"""
20 | user_id: str
21 | """用户端。"""
--------------------------------------------------------------------------------
/zhipuai/types/files/file_object.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List
2 |
3 | from ...core import BaseModel
4 |
5 | __all__ = ["FileObject", "ListOfFileObject"]
6 |
7 |
8 | class FileObject(BaseModel):
9 |
10 | id: Optional[str] = None
11 | bytes: Optional[int] = None
12 | created_at: Optional[int] = None
13 | filename: Optional[str] = None
14 | object: Optional[str] = None
15 | purpose: Optional[str] = None
16 | status: Optional[str] = None
17 | status_details: Optional[str] = None
18 |
19 |
20 | class ListOfFileObject(BaseModel):
21 |
22 | object: Optional[str] = None
23 | data: List[FileObject]
24 | has_more: Optional[bool] = None
25 |
--------------------------------------------------------------------------------
/zhipuai/types/web_search/web_search_create_params.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import annotations
3 |
4 | from typing import Union, Optional, List
5 | from typing_extensions import Literal, Required, TypedDict
6 | from ...types.sensitive_word_check import SensitiveWordCheckRequest
7 |
8 |
9 | __all__ = ["WebSearchCreatParams"]
10 |
11 |
12 | class WebSearchCreatParams(TypedDict):
13 |
14 | search_engine: str
15 | """搜索引擎"""
16 |
17 | search_query: str
18 | """搜索query文本"""
19 |
20 | request_id: str
21 | """由用户端传参,需保证唯一性;用于区分每次请求的唯一标识,用户端不传时平台会默认生成。"""
22 |
23 | user_id: str
24 | """用户端。"""
25 |
26 | sensitive_word_check: Optional[SensitiveWordCheckRequest]
27 |
--------------------------------------------------------------------------------
/zhipuai/types/chat/async_chat_completion.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 |
3 | from ...core import BaseModel
4 | from .chat_completion import CompletionChoice, CompletionUsage
5 |
6 | __all__ = ["AsyncTaskStatus", "AsyncCompletion"]
7 |
8 |
9 | class AsyncTaskStatus(BaseModel):
10 | id: Optional[str] = None
11 | request_id: Optional[str] = None
12 | model: Optional[str] = None
13 | task_status: Optional[str] = None
14 |
15 |
16 | class AsyncCompletion(BaseModel):
17 | id: Optional[str] = None
18 | request_id: Optional[str] = None
19 | model: Optional[str] = None
20 | task_status: str
21 | choices: List[CompletionChoice]
22 | usage: CompletionUsage
23 |
--------------------------------------------------------------------------------
/tests/integration_tests/demo.jsonl:
--------------------------------------------------------------------------------
1 | {"messages": [{"role": "system", "content": "你是一位乐于助人,知识渊博的全能AI助手。"},{"role": "user", "content": "请介绍一下AlphaGo的基本原理。"}, {"role": "assistant", "content": "AlphaGo利用深度神经网络与树搜索算法相结合,通过对弈数据进行训练,逐步提升游戏水平,能够对围棋局面进行判断并选择最佳落子位置。"}]}
2 | {"messages": [{"role": "system", "content": "你是一位乐于助人,知识渊博的全能AI助手。"},{"role": "user", "content": "请列举一些健康饮食的重要原则。"}, {"role": "assistant", "content": "健康饮食的重要原则包括:膳食应营养均衡,合理摄入各种营养素;多吃新鲜蔬果,适量饮水;少吃糖、油、盐,忌吃过量快餐;养成规律生活作息习惯,适量运动。"}]}
3 | {"messages": [{"role": "system", "content": "你是一位乐于助人,知识渊博的全能AI助手。"},{"role": "user", "content": "请给出一些改善睡眠质量的建议。"}, {"role": "assistant", "content": "改善睡眠质量的建议包括:保证睡眠时间,避免睡眠不足;睡前放松身心,远离刺激;养成规律作息时间,不要频繁熬夜;适量运动,但不要运动过度;睡前可以喝一杯热牛奶等温和饮料。"}]}
--------------------------------------------------------------------------------
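Each line above is one training conversation in the chat format used for fine-tuning. A hypothetical sketch of turning it into a job; the `client.files.create` and `client.fine_tuning.jobs.create` shapes and the model name are assumptions, with only the `hyperparameters` keys grounded in `job_create_params.py`:

    from zhipuai import ZhipuAI

    client = ZhipuAI()

    # Upload the training data, then reference it when creating the job (assumed shapes).
    with open("tests/integration_tests/demo.jsonl", "rb") as f:
        training_file = client.files.create(file=f, purpose="fine-tune")

    job = client.fine_tuning.jobs.create(
        model="chatglm3-6b",                   # illustrative model name
        training_file=training_file.id,
        hyperparameters={"n_epochs": "auto"},  # keys from Hyperparameters in job_create_params.py
    )
    print(job)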
/zhipuai/types/video/video_object.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List
2 |
3 | from ...core import BaseModel
4 |
5 | __all__ = ["VideoObject", "VideoResult"]
6 |
7 |
8 | class VideoResult(BaseModel):
9 | url: str
10 | """视频url"""
11 | cover_image_url: str
12 | """预览图"""
13 |
14 |
15 | class VideoObject(BaseModel):
16 | id: Optional[str] = None
17 | """智谱 AI 开放平台生成的任务订单号,调用请求结果接口时请使用此订单号"""
18 |
19 | model: str
20 | """模型名称"""
21 |
22 | video_result: List[VideoResult]
23 | """视频生成结果"""
24 |
25 | task_status: str
26 | """处理状态,PROCESSING(处理中),SUCCESS(成功),FAIL(失败)
27 | 注:处理中状态需通过查询获取结果"""
28 |
29 | request_id: str
30 | """用户在客户端请求时提交的任务编号或者平台生成的任务编号"""
31 |
--------------------------------------------------------------------------------
/tests/integration_tests/test.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | from zhipuai import ZhipuAI
4 |
5 | client = ZhipuAI()  # fill in your own API key
6 |
7 | response = client.videos.generations(
8 | model='cogvideo',
9 | prompt='一个年轻的艺术家在一片彩虹上用调色板作画。',
10 | # prompt="一只卡通狐狸在森林里跳着欢快的爵士舞。"
11 | # prompt="这是一部汽车广告片,描述了一位30岁的汽车赛车手戴着红色头盔的赛车冒险。背景是蔚蓝的天空和苛刻的沙漠环境,电影风格使用35毫米胶片拍摄,色彩鲜艳夺目。"
12 | )
13 | print(response)
14 | task_id = response.id
15 | task_status = response.task_status
16 | get_cnt = 0
17 |
18 | while task_status == 'PROCESSING' and get_cnt <= 40:
19 | result_response = client.videos.retrieve_videos_result(id=task_id)
20 | print(result_response)
21 | task_status = result_response.task_status
22 |
23 | time.sleep(2)
24 | get_cnt += 1
25 |
--------------------------------------------------------------------------------
/zhipuai/types/file_parser/file_parser_create_params.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing_extensions import Literal, Required, TypedDict
4 | from ...core import NOT_GIVEN, Body, Headers, NotGiven, FileTypes
5 |
6 |
7 | __all__ = ["FileParserCreateParams", "FileParserDownloadParams"]
8 |
9 |
10 | class FileParserCreateParams(TypedDict):
11 | file: FileTypes
12 | """上传的文件"""
13 | file_type: str
14 | """文件类型"""
15 | tool_type: Literal["simple", "doc2x", "tencent", "zhipu-pro"]
16 | """工具类型"""
17 |
18 |
19 | class FileParserDownloadParams(TypedDict):
20 | task_id: str
21 | """解析任务id"""
22 | format_type: Literal["text", "download_link"]
23 | """结果返回类型"""
24 |
25 |
--------------------------------------------------------------------------------
/zhipuai/types/audio/audio_speech_params.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import List, Optional
4 |
5 | from typing_extensions import Literal, Required, TypedDict
6 |
7 | __all__ = ["AudioSpeechParams"]
8 |
9 | from ..sensitive_word_check import SensitiveWordCheckRequest
10 |
11 |
12 | class AudioSpeechParams(TypedDict, total=False):
13 | model: str
14 | """模型编码"""
15 | input: str
16 | """需要生成语音的文本"""
17 | voice: str
18 | """需要生成语音的音色"""
19 | response_format: str
20 | """需要生成语音文件的格式"""
21 | sensitive_word_check: Optional[SensitiveWordCheckRequest]
22 | request_id: str
23 | """由用户端传参,需保证唯一性;用于区分每次请求的唯一标识,用户端不传时平台会默认生成。"""
24 | user_id: str
25 | """用户端。"""
26 |
--------------------------------------------------------------------------------
/zhipuai/types/assistant/message/tools/tools_type.py:
--------------------------------------------------------------------------------
1 |
2 | from typing import Union
3 | from typing_extensions import Annotated, TypeAlias
4 |
5 | from .code_interpreter_delta_block import CodeInterpreterToolBlock
6 | from .retrieval_delta_black import RetrievalToolBlock
7 | from .web_browser_delta_block import WebBrowserToolBlock
8 | from .....core._utils import PropertyInfo
9 | from .drawing_tool_delta_block import DrawingToolBlock
10 | from .function_delta_block import FunctionToolBlock
11 |
12 | __all__ = ["ToolsType"]
13 |
14 |
15 | ToolsType: TypeAlias = Annotated[
16 | Union[DrawingToolBlock, CodeInterpreterToolBlock, WebBrowserToolBlock, RetrievalToolBlock, FunctionToolBlock],
17 | PropertyInfo(discriminator="type"),
18 | ]
--------------------------------------------------------------------------------
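`ToolsType` is a discriminated union: `PropertyInfo(discriminator="type")` lets the model layer pick the concrete block class from each payload's `type` field. Spelled out by hand, the mapping looks roughly like the sketch below; the `drawing_tool`, `code_interpreter`, and `function` values are confirmed by the blocks shown in this listing, while `web_browser` and `retrieval` are assumed from the file names:

    # Hand-written version of the dispatch the discriminator performs.
    from zhipuai.types.assistant.message.tools.code_interpreter_delta_block import CodeInterpreterToolBlock
    from zhipuai.types.assistant.message.tools.drawing_tool_delta_block import DrawingToolBlock
    from zhipuai.types.assistant.message.tools.function_delta_block import FunctionToolBlock
    from zhipuai.types.assistant.message.tools.retrieval_delta_black import RetrievalToolBlock
    from zhipuai.types.assistant.message.tools.web_browser_delta_block import WebBrowserToolBlock

    BLOCK_BY_TYPE = {
        "drawing_tool": DrawingToolBlock,
        "code_interpreter": CodeInterpreterToolBlock,
        "function": FunctionToolBlock,
        "web_browser": WebBrowserToolBlock,  # assumed discriminator value
        "retrieval": RetrievalToolBlock,     # assumed discriminator value
    }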
/zhipuai/types/audio/audio_customization_param.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import List, Optional
4 |
5 | from typing_extensions import Literal, Required, TypedDict
6 | __all__ = ["AudioCustomizationParam"]
7 |
8 | from ..sensitive_word_check import SensitiveWordCheckRequest
9 |
10 | class AudioCustomizationParam(TypedDict, total=False):
11 | model: str
12 | """模型编码"""
13 | input: str
14 | """需要生成语音的文本"""
15 | voice_text: str
16 | """需要生成语音的音色"""
17 | response_format: str
18 | """需要生成语音文件的格式"""
19 | sensitive_word_check: Optional[SensitiveWordCheckRequest]
20 | request_id: str
21 | """由用户端传参,需保证唯一性;用于区分每次请求的唯一标识,用户端不传时平台会默认生成。"""
22 | user_id: str
23 | """用户端。"""
24 |
25 |
26 |
--------------------------------------------------------------------------------
/zhipuai/types/assistant/assistant_support_resp.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Dict, Any
2 |
3 | from .message import MessageContent
4 | from ...core import BaseModel
5 |
6 | __all__ = [
7 | "AssistantSupportResp"
8 | ]
9 |
10 |
11 | class AssistantSupport(BaseModel):
12 | assistant_id: str  # Assistant id of the agent, used for agent conversations
13 | created_at: int  # creation time
14 | updated_at: int  # update time
15 | name: str  # agent name
16 | avatar: str  # agent avatar
17 | description: str  # agent description
18 | status: str  # agent status; currently only "publish"
19 | tools: List[str]  # names of the tools the agent supports
20 | starter_prompts: List[str]  # recommended starter prompts for the agent
21 |
22 |
23 | class AssistantSupportResp(BaseModel):
24 | code: int
25 | msg: str
26 | data: List[AssistantSupport]  # list of agents
27 |
--------------------------------------------------------------------------------
/zhipuai/types/video/video_create_params.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import List, Optional
4 |
5 | from typing_extensions import Literal, Required, TypedDict
6 |
7 | __all__ = ["VideoCreateParams"]
8 |
9 | from ..sensitive_word_check import SensitiveWordCheckRequest
10 |
11 |
12 | class VideoCreateParams(TypedDict, total=False):
13 | model: str
14 | """Model code."""
15 | prompt: str
16 | """Text description of the desired video."""
17 | image_url: str
18 | """Supports a URL or Base64; when an image is passed, image-to-video generation is performed.
19 | * Image format:
20 | * Image size:"""
21 | sensitive_word_check: Optional[SensitiveWordCheckRequest]
22 |
23 | request_id: str
24 | """Passed by the client and must be unique; used to distinguish each request. If not provided, the platform generates one by default."""
25 |
26 | user_id: str
27 | """End-user ID provided by the client."""
28 |
--------------------------------------------------------------------------------
/zhipuai/types/assistant/message/tools/code_interpreter_delta_block.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | from typing_extensions import Literal
4 |
5 | __all__ = ["CodeInterpreterToolBlock"]
6 |
7 | from .....core import BaseModel
8 |
9 |
10 | class CodeInterpreterToolOutput(BaseModel):
11 | """代码工具输出结果"""
12 | type: str # 代码执行日志,目前只有 logs
13 | logs: str # 代码执行的日志结果
14 | error_msg: str # 错误信息
15 |
16 |
17 | class CodeInterpreter(BaseModel):
18 | """代码解释器"""
19 | input: str # 生成的代码片段,输入给代码沙盒
20 | outputs: List[CodeInterpreterToolOutput] # 代码执行后的输出结果
21 |
22 |
23 | class CodeInterpreterToolBlock(BaseModel):
24 | """代码工具块"""
25 | code_interpreter: CodeInterpreter # 代码解释器对象
26 | type: Literal["code_interpreter"] # 调用工具的类型,始终为 `code_interpreter`
27 |
--------------------------------------------------------------------------------
/tests/integration_tests/test_moderation.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import logging.config
3 | import time
4 |
5 | import zhipuai
6 | from zhipuai import ZhipuAI
7 |
8 |
9 | def test_completions_temp0(logging_conf):
10 | logging.config.dictConfig(logging_conf) # type: ignore
11 | client = ZhipuAI(disable_token_cache=False)  # fill in your own API key
12 | try:
13 | # generate a request_id
14 | request_id = time.time()
15 | print(f'request_id:{request_id}')
16 | response = client.moderations.create(
17 | model='moderations', input={'type': 'text', 'text': 'hello world '}
18 | )
19 | print(response)
20 |
21 | except zhipuai.core._errors.APIRequestFailedError as err:
22 | print(err)
23 | except zhipuai.core._errors.APIInternalError as err:
24 | print(err)
25 | except zhipuai.core._errors.APIStatusError as err:
26 | print(err)
27 |
--------------------------------------------------------------------------------
/zhipuai/types/knowledge/knowledge.py:
--------------------------------------------------------------------------------
1 |
2 | from typing import Optional
3 |
4 | from ...core import BaseModel
5 |
6 | __all__ = ["KnowledgeInfo"]
7 |
8 |
9 | class KnowledgeInfo(BaseModel):
10 | id: Optional[str] = None
11 | """知识库唯一 id"""
12 | embedding_id: Optional[str] = None # 知识库绑定的向量化模型 见模型列表 [内部服务开放接口文档](https://lslfd0slxc.feishu.cn/docx/YauWdbBiMopV0FxB7KncPWCEn8f#H15NduiQZo3ugmxnWQFcfAHpnQ4)
13 | name: Optional[str] = None # 知识库名称 100字限制
14 | customer_identifier: Optional[str] = None # 用户标识 长度32位以内
15 | description: Optional[str] = None # 知识库描述 500字限制
16 | background: Optional[str] = None # 背景颜色(给枚举)'blue', 'red', 'orange', 'purple', 'sky'
17 | icon: Optional[str] = None # 知识库图标(给枚举) question: 问号、book: 书籍、seal: 印章、wrench: 扳手、tag: 标签、horn: 喇叭、house: 房子
18 | bucket_id: Optional[str] = None # 桶id 限制32位
--------------------------------------------------------------------------------
/tests/integration_tests/test_web_search.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import logging.config
3 |
4 | import zhipuai
5 | from zhipuai import ZhipuAI
6 |
7 |
8 | def test_web_search(logging_conf):
9 | logging.config.dictConfig(logging_conf) # type: ignore
10 | client = ZhipuAI()  # fill in your own API key
11 | try:
12 | response = client.web_search.web_search(
13 | search_engine='search-std',
14 | search_query='2025特朗普向中国加征了多少关税',
15 | count=50,
16 | search_domain_filter='finance.sina.com.cn',
17 | search_recency_filter='oneYear',
18 | content_size='high',
19 | search_intent=True,
20 | )
21 | print(response)
22 |
23 | except zhipuai.core._errors.APIRequestFailedError as err:
24 | print(err)
25 | except zhipuai.core._errors.APIInternalError as err:
26 | print(err)
27 | except zhipuai.core._errors.APIStatusError as err:
28 | print(err)
29 |
--------------------------------------------------------------------------------
/zhipuai/types/audio/audio_speech_chunk.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Dict, Any
2 |
3 | from ...core import BaseModel
4 |
5 | __all__ = [
6 | "AudioSpeechChunk",
7 | "AudioError",
8 | "AudioSpeechChoice",
9 | "AudioSpeechDelta"
10 | ]
11 |
12 |
13 | class AudioSpeechDelta(BaseModel):
14 | content: Optional[str] = None
15 | role: Optional[str] = None
16 |
17 |
18 | class AudioSpeechChoice(BaseModel):
19 | delta: AudioSpeechDelta
20 | finish_reason: Optional[str] = None
21 | index: int
22 |
23 | class AudioError(BaseModel):
24 | code: Optional[str] = None
25 | message: Optional[str] = None
26 |
27 |
28 | class AudioSpeechChunk(BaseModel):
29 | choices: List[AudioSpeechChoice]
30 | request_id: Optional[str] = None
31 | created: Optional[int] = None
32 | error: Optional[AudioError] = None
--------------------------------------------------------------------------------
/zhipuai/types/knowledge/document/document_list_params.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Dict, Optional
4 | from typing_extensions import Literal, Required, TypedDict
5 |
6 |
7 | class DocumentListParams(TypedDict, total=False):
8 | """
9 | 文件查询参数类型定义
10 |
11 | Attributes:
12 | purpose (Optional[str]): 文件用途
13 | knowledge_id (Optional[str]): 当文件用途为 retrieval 时,需要提供查询的知识库ID
14 | page (Optional[int]): 页,默认1
15 | limit (Optional[int]): 查询文件列表数,默认10
16 | after (Optional[str]): 查询指定fileID之后的文件列表(当文件用途为 fine-tune 时需要)
17 | order (Optional[str]): 排序规则,可选值['desc', 'asc'],默认desc(当文件用途为 fine-tune 时需要)
18 | """
19 | purpose: Optional[str]
20 | knowledge_id: Optional[str]
21 | page: Optional[int]
22 | limit: Optional[int]
23 | after: Optional[str]
24 | order: Optional[str]
--------------------------------------------------------------------------------
/zhipuai/core/_jwt_token.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | import time
3 |
4 | import cachetools.func
5 | import jwt
6 |
7 | # cache TTL: 3 minutes
8 | CACHE_TTL_SECONDS = 3 * 60
9 |
10 | # the token stays valid 30 seconds longer than the cache TTL
11 | API_TOKEN_TTL_SECONDS = CACHE_TTL_SECONDS + 30
12 |
13 |
14 | @cachetools.func.ttl_cache(maxsize=10, ttl=CACHE_TTL_SECONDS)
15 | def generate_token(apikey: str):
16 | try:
17 | api_key, secret = apikey.split(".")
18 | except Exception as e:
19 | raise Exception("invalid api_key", e)
20 |
21 | payload = {
22 | "api_key": api_key,
23 | "exp": int(round(time.time() * 1000)) + API_TOKEN_TTL_SECONDS * 1000,
24 | "timestamp": int(round(time.time() * 1000)),
25 | }
26 | ret = jwt.encode(
27 | payload,
28 | secret,
29 | algorithm="HS256",
30 | headers={"alg": "HS256", "sign_type": "SIGN"},
31 | )
32 | return ret
33 |
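
A minimal usage sketch (not part of the repository), assuming an API key in the "id.secret" form split above; the key value is hypothetical:

from zhipuai.core._jwt_token import generate_token

token = generate_token("my_key_id.my_key_secret")  # hypothetical key in "id.secret" form
# Calls within CACHE_TTL_SECONDS (3 minutes) return the cached token, which stays
# valid because the JWT expiry is 30 seconds longer than the cache TTL.
assert generate_token("my_key_id.my_key_secret") == token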
--------------------------------------------------------------------------------
/zhipuai/types/assistant/assistant_conversation_resp.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Dict, Any
2 |
3 | from ...core import BaseModel
4 |
5 | __all__ = [
6 | "ConversationUsageListResp"
7 | ]
8 |
9 | class Usage(BaseModel):
10 |     prompt_tokens: int # Number of tokens in the user input
11 |     completion_tokens: int # Number of tokens in the model output
12 |     total_tokens: int # Total number of tokens
13 |
14 |
15 | class ConversationUsage(BaseModel):
16 |     id: str # Conversation ID
17 |     assistant_id: str # Assistant ID
18 |     create_time: int # Creation time
19 |     update_time: int # Update time
20 |     usage: Usage # Token usage statistics for the conversation
21 |
22 |
23 | class ConversationUsageList(BaseModel):
24 |     assistant_id: str # Assistant ID
25 |     has_more: bool # Whether more pages are available
26 |     conversation_list: List[ConversationUsage] # Returned list of conversation usage records
27 |
28 |
29 | class ConversationUsageListResp(BaseModel):
30 | code: int
31 | msg: str
32 | data: ConversationUsageList
33 |
--------------------------------------------------------------------------------
/zhipuai/types/web_search/web_search_resp.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 |
3 | from ...core import BaseModel
4 |
5 | __all__ = [
6 | "SearchIntentResp",
7 | "SearchResultResp",
8 | ]
9 |
10 |
11 | class SearchIntentResp(BaseModel):
12 |     query: str
13 |     # Optimized search query
14 |     intent: str
15 |     # Detected intent type
16 |     keywords: str
17 |     # Search keywords
18 |
19 |
20 | class SearchResultResp(BaseModel):
21 |     title: str
22 |     # Title
23 |     link: str
24 |     # Link
25 |     content: str
26 |     # Content
27 |     icon: str
28 |     # Icon
29 |     media: str
30 |     # Source media
31 |     refer: str
32 |     # Reference marker, e.g. [ref_1]
33 |     publish_date: str
34 |     # Publish date
35 |
36 | class WebSearchResp(BaseModel):
37 | created: Optional[int] = None
38 | request_id: Optional[str] = None
39 | id: Optional[str] = None
40 | search_intent: Optional[SearchIntentResp]
41 | search_result: Optional[SearchResultResp]
42 |
43 |
--------------------------------------------------------------------------------
/zhipuai/types/tools/web_search_chunk.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Dict, Any
2 |
3 | from .web_search import SearchIntent, SearchResult, SearchRecommend
4 | from ...core import BaseModel
5 |
6 | __all__ = [
7 | "WebSearchChunk"
8 | ]
9 |
10 |
11 | class ChoiceDeltaToolCall(BaseModel):
12 | index: int
13 | id: Optional[str] = None
14 |
15 | search_intent: Optional[SearchIntent] = None
16 | search_result: Optional[SearchResult] = None
17 | search_recommend: Optional[SearchRecommend] = None
18 | type: Optional[str] = None
19 |
20 |
21 | class ChoiceDelta(BaseModel):
22 | role: Optional[str] = None
23 | tool_calls: Optional[List[ChoiceDeltaToolCall]] = None
24 |
25 |
26 | class Choice(BaseModel):
27 | delta: ChoiceDelta
28 | finish_reason: Optional[str] = None
29 | index: int
30 |
31 |
32 | class WebSearchChunk(BaseModel):
33 | id: Optional[str] = None
34 | choices: List[Choice]
35 | created: Optional[int] = None
36 |
--------------------------------------------------------------------------------
/zhipuai/api_resource/moderation/moderations.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Union, List, TYPE_CHECKING, Dict
4 |
5 | import logging
6 | from ...core import BaseAPI, deepcopy_minimal
7 | from ...types.moderation.moderation_completion import Completion
8 |
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 | if TYPE_CHECKING:
13 | from ..._client import ZhipuAI
14 |
15 | __all__ = ["Moderations"]
16 | class Moderations(BaseAPI):
17 | def __init__(self, client: "ZhipuAI") -> None:
18 | super().__init__(client)
19 |
20 | def create(
21 | self,
22 | *,
23 | model: str,
24 | input: Union[str, List[str], Dict],
25 | ) -> Completion:
26 |
27 | body = deepcopy_minimal({
28 | "model": model,
29 | "input": input
30 | })
31 | return self._post(
32 | "/moderations",
33 | body=body,
34 | cast_type=Completion
35 | )
36 |
37 |
38 |
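
A minimal usage sketch (not part of the repository); the `client.moderations` attribute name and the model name are assumptions, not confirmed by this file:

from zhipuai import ZhipuAI

client = ZhipuAI()  # Fill in your own API key
result = client.moderations.create(  # assumed attribute name for the Moderations resource
    model="moderation",              # hypothetical model name
    input="Text to be checked for policy violations.",
)
print(result)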
--------------------------------------------------------------------------------
/zhipuai/types/agents/agents_completion.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 |
3 | from ...core import BaseModel
4 |
5 | __all__ = ["AgentsCompletion", "AgentsCompletionUsage"]
6 |
7 | class AgentsCompletionMessage(BaseModel):
8 | content: Optional[object] = None
9 | role: str
10 |
11 | class AgentsCompletionUsage(BaseModel):
12 | prompt_tokens: int
13 | completion_tokens: int
14 | total_tokens: int
15 |
16 |
17 | class AgentsCompletionChoice(BaseModel):
18 | index: int
19 | finish_reason: str
20 | message: AgentsCompletionMessage
21 |
22 | class AgentsError(BaseModel):
23 | code: Optional[str] = None
24 | message: Optional[str] = None
25 |
26 |
27 | class AgentsCompletion(BaseModel):
28 | agent_id: Optional[str] = None
29 | conversation_id: Optional[str] = None
30 | status: Optional[str] = None
31 | choices: List[AgentsCompletionChoice]
32 | request_id: Optional[str] = None
33 | id: Optional[str] = None
34 | usage: Optional[AgentsCompletionUsage] = None
35 | error: Optional[AgentsError] = None
36 |
37 |
--------------------------------------------------------------------------------
/tests/unit_tests/test_audio.py:
--------------------------------------------------------------------------------
1 | def test_audio_error_field():
2 | from zhipuai.types.audio.audio_speech_chunk import AudioSpeechChunk, AudioError, AudioSpeechChoice, AudioSpeechDelta
3 |
4 |     # Construct an AudioError
5 |     error = AudioError(code="500", message="Internal Error")
6 |
7 |     # Construct a complete AudioSpeechChunk
8 | chunk = AudioSpeechChunk(
9 | choices=[
10 | AudioSpeechChoice(
11 | delta=AudioSpeechDelta(content="audio", role="system"),
12 | finish_reason="error",
13 | index=0
14 | )
15 | ],
16 | request_id="req_2",
17 | created=123456,
18 | error=error
19 | )
20 |
21 |     # Check that the error field is an AudioError instance
22 | assert isinstance(chunk.error, AudioError)
23 | assert chunk.error.code == "500"
24 | assert chunk.error.message == "Internal Error"
25 |
26 |     # Check serialization
27 | as_dict = chunk.model_dump()
28 | assert as_dict["error"]["code"] == "500"
29 | assert as_dict["error"]["message"] == "Internal Error"
30 | print("test_audio_error_field passed.")
31 |
--------------------------------------------------------------------------------
/zhipuai/types/agents/agents_completion_chunk.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Dict, Any
2 |
3 | from ...core import BaseModel
4 |
5 | __all__ = [
6 | "AgentsCompletionUsage",
7 | "AgentsCompletionChunk",
8 | "AgentsChoice",
9 | "AgentsChoiceDelta"
10 | ]
11 |
12 |
13 | class AgentsChoiceDelta(BaseModel):
14 | content: Optional[object] = None
15 | role: Optional[str] = None
16 |
17 |
18 | class AgentsChoice(BaseModel):
19 | delta: AgentsChoiceDelta
20 | finish_reason: Optional[str] = None
21 | index: int
22 |
23 |
24 | class AgentsCompletionUsage(BaseModel):
25 | prompt_tokens: int
26 | completion_tokens: int
27 | total_tokens: int
28 |
29 | class AgentsError(BaseModel):
30 | code: Optional[str] = None
31 | message: Optional[str] = None
32 |
33 |
34 | class AgentsCompletionChunk(BaseModel):
35 | agent_id: Optional[str] = None
36 | conversation_id: Optional[str] = None
37 | id: Optional[str] = None
38 | choices: List[AgentsChoice]
39 | usage: Optional[AgentsCompletionUsage] = None
40 | error: Optional[AgentsError] = None
--------------------------------------------------------------------------------
/zhipuai/types/knowledge/document/document_edit_params.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List, Literal, TypedDict, Dict
2 |
3 |
4 | __all__ = ["DocumentEditParams"]
5 |
6 |
7 | class DocumentEditParams(TypedDict):
8 | """
9 | 知识参数类型定义
10 |
11 | Attributes:
12 | id (str): 知识ID
13 | knowledge_type (int): 知识类型:
14 | 1:文章知识: 支持pdf,url,docx
15 | 2.问答知识-文档: 支持pdf,url,docx
16 | 3.问答知识-表格: 支持xlsx
17 | 4.商品库-表格: 支持xlsx
18 | 5.自定义: 支持pdf,url,docx
19 | custom_separator (Optional[List[str]]): 当前知识类型为自定义(knowledge_type=5)时的切片规则,默认\n
20 | sentence_size (Optional[int]): 当前知识类型为自定义(knowledge_type=5)时的切片字数,取值范围: 20-2000,默认300
21 | callback_url (Optional[str]): 回调地址
22 | callback_header (Optional[dict]): 回调时携带的header
23 | """
24 | id: str
25 | knowledge_type: int
26 | custom_separator: Optional[List[str]]
27 | sentence_size: Optional[int]
28 | callback_url: Optional[str]
29 | callback_header: Optional[Dict[str, str]]
--------------------------------------------------------------------------------
/zhipuai/types/chat/chat_completion.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 |
3 | from ...core import BaseModel
4 |
5 | __all__ = ["Completion", "CompletionUsage"]
6 |
7 |
8 | class Function(BaseModel):
9 | arguments: str
10 | name: str
11 |
12 |
13 | class CompletionMessageToolCall(BaseModel):
14 | id: str
15 | function: Function
16 | type: str
17 |
18 |
19 | class CompletionMessage(BaseModel):
20 | content: Optional[str] = None
21 | role: str
22 | reasoning_content: Optional[str] = None
23 | tool_calls: Optional[List[CompletionMessageToolCall]] = None
24 |
25 |
26 | class CompletionUsage(BaseModel):
27 | prompt_tokens: int
28 | completion_tokens: int
29 | total_tokens: int
30 |
31 |
32 | class CompletionChoice(BaseModel):
33 | index: int
34 | finish_reason: str
35 | message: CompletionMessage
36 |
37 |
38 | class Completion(BaseModel):
39 | model: Optional[str] = None
40 | created: Optional[int] = None
41 | choices: List[CompletionChoice]
42 | request_id: Optional[str] = None
43 | id: Optional[str] = None
44 | usage: CompletionUsage
45 |
46 |
--------------------------------------------------------------------------------
/zhipuai/types/knowledge/knowledge_create_params.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Dict, Optional
4 | from typing_extensions import Literal, Required, TypedDict
5 |
6 | __all__ = ["KnowledgeBaseParams"]
7 |
8 |
9 | class KnowledgeBaseParams(TypedDict):
10 | """
11 | 知识库参数类型定义
12 |
13 | Attributes:
14 | embedding_id (int): 知识库绑定的向量化模型ID
15 | name (str): 知识库名称,限制100字
16 | customer_identifier (Optional[str]): 用户标识,长度32位以内
17 | description (Optional[str]): 知识库描述,限制500字
18 | background (Optional[Literal['blue', 'red', 'orange', 'purple', 'sky']]): 背景颜色
19 | icon (Optional[Literal['question', 'book', 'seal', 'wrench', 'tag', 'horn', 'house']]): 知识库图标
20 | bucket_id (Optional[str]): 桶ID,限制32位
21 | """
22 | embedding_id: int
23 | name: str
24 | customer_identifier: Optional[str]
25 | description: Optional[str]
26 |     background: Optional[Literal['blue', 'red', 'orange', 'purple', 'sky']]
27 |     icon: Optional[Literal['question', 'book', 'seal', 'wrench', 'tag', 'horn', 'house']]
28 | bucket_id: Optional[str]
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Zhipu, Inc.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/zhipuai/types/assistant/assistant_create_params.py:
--------------------------------------------------------------------------------
1 | from typing import TypedDict, List, Optional, Union
2 |
3 |
4 | class AssistantAttachments:
5 | file_id: str
6 |
7 |
8 | class MessageTextContent:
9 |     type: str # Currently only type = text is supported
10 | text: str
11 |
12 |
13 | MessageContent = Union[MessageTextContent]
14 |
15 |
16 | class ConversationMessage(TypedDict):
17 | """会话消息体"""
18 | role: str # 用户的输入角色,例如 'user'
19 | content: List[MessageContent] # 会话消息体的内容
20 |
21 |
22 | class AssistantParameters(TypedDict, total=False):
23 | """智能体参数类"""
24 | assistant_id: str # 智能体 ID
25 | conversation_id: Optional[str] # 会话 ID,不传则创建新会话
26 | model: str # 模型名称,默认为 'GLM-4-Assistant'
27 | stream: bool # 是否支持流式 SSE,需要传入 True
28 | messages: List[ConversationMessage] # 会话消息体
29 | attachments: Optional[List[AssistantAttachments]] # 会话指定的文件,非必填
30 | metadata: Optional[dict] # 元信息,拓展字段,非必填
31 |
32 | class TranslateParameters(TypedDict, total=False):
33 | from_language: str
34 | to_language: str
35 |
36 | class ExtraParameters(TypedDict, total=False):
37 | translate: TranslateParameters
38 |
39 |
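
A minimal sketch (not part of the repository) of filling in AssistantParameters; the assistant ID is hypothetical, and the message content is written as plain dicts mirroring MessageTextContent rather than class instances:

from zhipuai.types.assistant.assistant_create_params import AssistantParameters

params: AssistantParameters = {
    "assistant_id": "65940acff94777010aa6b796",  # hypothetical assistant ID
    "model": "GLM-4-Assistant",
    "stream": True,
    "messages": [
        {"role": "user", "content": [{"type": "text", "text": "你好"}]},
    ],
}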
--------------------------------------------------------------------------------
/zhipuai/types/files/file_create_params.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import annotations
3 |
4 | from typing import List, Optional
5 |
6 | from typing_extensions import Literal, Required, TypedDict
7 |
8 |
9 | __all__ = ["FileCreateParams"]
10 |
11 | from . import UploadDetail
12 | from ...core import FileTypes
13 |
14 |
15 | class FileCreateParams(TypedDict, total=False):
16 | file: FileTypes
17 | """file和 upload_detail二选一必填"""
18 |
19 | upload_detail: List[UploadDetail]
20 | """file和 upload_detail二选一必填"""
21 |
22 | purpose: Required[Literal["fine-tune", "retrieval", "batch"]]
23 | """
24 | 上传文件的用途,支持 "fine-tune和 "retrieval"
25 | retrieval支持上传Doc、Docx、PDF、Xlsx、URL类型文件,且单个文件的大小不超过 5MB。
26 | fine-tune支持上传.jsonl文件且当前单个文件的大小最大可为 100 MB ,文件中语料格式需满足微调指南中所描述的格式。
27 | """
28 | custom_separator: Optional[List[str]]
29 | """
30 | 当 purpose 为 retrieval 且文件类型为 pdf, url, docx 时上传,切片规则默认为 `\n`。
31 | """
32 | knowledge_id: str
33 | """
34 | 当文件上传目的为 retrieval 时,需要指定知识库ID进行上传。
35 | """
36 |
37 | sentence_size: int
38 | """
39 | 当文件上传目的为 retrieval 时,需要指定知识库ID进行上传。
40 | """
41 |
--------------------------------------------------------------------------------
/tests/integration_tests/test_emohaa.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import logging
4 | import logging.config
5 |
6 | import zhipuai
7 | from zhipuai import ZhipuAI
8 |
9 |
10 | def test_completions_emohaa(logging_conf):
11 | logging.config.dictConfig(logging_conf) # type: ignore
12 |     client = ZhipuAI() # Fill in your own API key
13 | try:
14 | response = client.chat.completions.create(
15 |             model='emohaa', # Specify the model to call
16 | messages=[
17 | {
18 | 'role': 'assistant',
19 | 'content': '你好,我是Emohaa,很高兴见到你。请问有什么我可以帮忙的吗?',
20 | },
21 | {
22 | 'role': 'user',
23 | 'content': '今天我休息,决定去西安保密逛逛,心情很好地喷上了我最爱的烟熏木制香',
24 | },
25 | {
26 | 'role': 'assistant',
27 | 'content': '今天我休息,决定去西安保密逛逛,心情很好地喷上了我最爱的烟熏木制香',
28 | },
29 | ],
30 | meta={
31 | 'user_info': '30岁的男性软件工程师,兴趣包括阅读、徒步和编程',
32 | 'bot_info': 'Emohaa是一款基于Hill助人理论的情感支持AI,拥有专业的心理咨询话术能力',
33 | 'bot_name': 'Emohaa',
34 | 'user_name': '陆星辰',
35 | },
36 | )
37 | print(response)
38 |
39 | except zhipuai.core._errors.APIRequestFailedError as err:
40 | print(err)
41 | except zhipuai.core._errors.APIInternalError as err:
42 | print(err)
43 | except zhipuai.core._errors.APIStatusError as err:
44 | print(err)
45 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug-report.yml:
--------------------------------------------------------------------------------
1 | name: 🐞 Bug Report
2 | description: File a bug report
3 | title: "[Bug]: "
4 | type: "Bug"
5 | body:
6 | - type: markdown
7 | attributes:
8 | value: |
9 | Thanks for stopping by to let us know something could be better!
10 | - type: textarea
11 | id: what-happened
12 | attributes:
13 | label: What happened?
14 | description: Also tell us what you expected to happen and how to reproduce the issue.
15 | placeholder: Tell us what you see!
16 | value: "A bug happened!"
17 | validations:
18 | required: true
19 | - type: textarea
20 | id: logs
21 | attributes:
22 | label: Relevant log output
23 | description: Please copy and paste any relevant log output. This will be automatically formatted into code, so no need for backticks.
24 | render: shell
25 | - type: checkboxes
26 | id: terms
27 | attributes:
28 | label: Code of Conduct
29 | description: By submitting this issue, you agree to follow our [Code of Conduct](https://github.com/MetaGLM/zhipuai-sdk-python-v4/blob/main/CODE_OF_CONDUCT.md)
30 | options:
31 | - label: I agree to follow this project's Code of Conduct
32 | required: true
33 |
--------------------------------------------------------------------------------
/zhipuai/types/fine_tuning/fine_tuning_job_event.py:
--------------------------------------------------------------------------------
1 | from typing import List, Union, Optional
2 |
3 | from ...core import BaseModel
4 |
5 | __all__ = ["FineTuningJobEvent", "Metric", "JobEvent"]
6 |
7 |
8 | class Metric(BaseModel):
9 | epoch: Optional[Union[str, int, float]] = None
10 | current_steps: Optional[int] = None
11 | total_steps: Optional[int] = None
12 | elapsed_time: Optional[str] = None
13 | remaining_time: Optional[str] = None
14 | trained_tokens: Optional[int] = None
15 | loss: Optional[Union[str, int, float]] = None
16 | eval_loss: Optional[Union[str, int, float]] = None
17 | acc: Optional[Union[str, int, float]] = None
18 | eval_acc: Optional[Union[str, int, float]] = None
19 | learning_rate: Optional[Union[str, int, float]] = None
20 |
21 |
22 | class JobEvent(BaseModel):
23 | object: Optional[str] = None
24 | id: Optional[str] = None
25 | type: Optional[str] = None
26 | created_at: Optional[int] = None
27 | level: Optional[str] = None
28 | message: Optional[str] = None
29 | data: Optional[Metric] = None
30 |
31 |
32 | class FineTuningJobEvent(BaseModel):
33 | object: Optional[str] = None
34 | data: List[JobEvent]
35 | has_more: Optional[bool] = None
36 |
--------------------------------------------------------------------------------
/zhipuai/types/fine_tuning/fine_tuning_job.py:
--------------------------------------------------------------------------------
1 | from typing import List, Union, Optional, Dict, Any
2 |
3 | from ...core import BaseModel
4 |
5 | __all__ = ["FineTuningJob", "Error", "Hyperparameters", "ListOfFineTuningJob" ]
6 |
7 |
8 | class Error(BaseModel):
9 | code: str
10 | message: str
11 | param: Optional[str] = None
12 |
13 |
14 | class Hyperparameters(BaseModel):
15 | n_epochs: Union[str, int, None] = None
16 |
17 |
18 | class FineTuningJob(BaseModel):
19 | id: Optional[str] = None
20 |
21 | request_id: Optional[str] = None
22 |
23 | created_at: Optional[int] = None
24 |
25 | error: Optional[Error] = None
26 |
27 | fine_tuned_model: Optional[str] = None
28 |
29 | finished_at: Optional[int] = None
30 |
31 | hyperparameters: Optional[Hyperparameters] = None
32 |
33 | model: Optional[str] = None
34 |
35 | object: Optional[str] = None
36 |
37 | result_files: List[str]
38 |
39 | status: str
40 |
41 | trained_tokens: Optional[int] = None
42 |
43 | training_file: str
44 |
45 | validation_file: Optional[str] = None
46 |
47 |
48 | class ListOfFineTuningJob(BaseModel):
49 | object: Optional[str] = None
50 | data: List[FineTuningJob]
51 | has_more: Optional[bool] = None
52 |
--------------------------------------------------------------------------------
/zhipuai/types/tools/tools_web_search_params.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import annotations
3 |
4 | from typing import Union, Optional, List
5 | from typing_extensions import Literal, Required, TypedDict
6 |
7 | __all__ = ["WebSearchParams"]
8 |
9 |
10 | class WebSearchParams(TypedDict):
11 | """
12 | 工具名:web-search-pro参数类型定义
13 |
14 | Attributes:
15 | :param model: str, 模型名称
16 | :param request_id: Optional[str], 请求ID
17 | :param stream: Optional[bool], 是否流式
18 | :param messages: Union[str, List[str], List[int], object, None],
19 | 包含历史对话上下文的内容,按照 {"role": "user", "content": "你好"} 的json 数组形式进行传参
20 | 当前版本仅支持 User Message 单轮对话,工具会理解User Message并进行搜索,
21 | 请尽可能传入不带指令格式的用户原始提问,以提高搜索准确率。
22 | :param scope: Optional[str], 指定搜索范围,全网、学术等,默认全网
23 | :param location: Optional[str], 指定搜索用户地区 location 提高相关性
24 | :param recent_days: Optional[int],支持指定返回 N 天(1-30)更新的搜索结果
25 |
26 |
27 | """
28 | model: str
29 | request_id: Optional[str]
30 | stream: Optional[bool]
31 | messages: Union[str, List[str], List[int], object, None]
32 |     scope: Optional[str]
33 |     location: Optional[str]
34 |     recent_days: Optional[int]
--------------------------------------------------------------------------------
/zhipuai/types/batch_create_params.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Dict, Optional
4 | from typing_extensions import Literal, Required, TypedDict
5 |
6 | __all__ = ["BatchCreateParams"]
7 |
8 |
9 | class BatchCreateParams(TypedDict, total=False):
10 | completion_window: Required[str]
11 | """The time frame within which the batch should be processed.
12 |
13 | Currently only `24h` is supported.
14 | """
15 |
16 | endpoint: Required[Literal["/v1/chat/completions", "/v1/embeddings"]]
17 | """The endpoint to be used for all requests in the batch.
18 |
19 | Currently `/v1/chat/completions` and `/v1/embeddings` are supported.
20 | """
21 |
22 | input_file_id: Required[str]
23 | """The ID of an uploaded file that contains requests for the new batch.
24 |
25 | See [upload file](https://platform.openai.com/docs/api-reference/files/create)
26 | for how to upload a file.
27 |
28 | Your input file must be formatted as a
29 | [JSONL file](https://platform.openai.com/docs/api-reference/batch/requestInput),
30 | and must be uploaded with the purpose `batch`.
31 | """
32 |
33 | metadata: Optional[Dict[str, str]]
34 | """Optional custom metadata for the batch."""
35 |
36 | auto_delete_input_file: Optional[bool]
37 |
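
A minimal sketch (not part of the repository) of a parameter dict for BatchCreateParams; the input file ID is hypothetical:

from zhipuai.types.batch_create_params import BatchCreateParams

params: BatchCreateParams = {
    "completion_window": "24h",
    "endpoint": "/v1/chat/completions",
    "input_file_id": "file-abc123",  # hypothetical ID of a file uploaded with purpose "batch"
    "metadata": {"job": "nightly-eval"},
}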
--------------------------------------------------------------------------------
/tests/integration_tests/test_file_parser.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import logging
4 | import logging.config
5 | import os
6 |
7 | import pytest
8 |
9 | import zhipuai
10 | from zhipuai import ZhipuAI
11 |
12 |
13 | def test_file_parser_create(logging_conf):
14 | logging.config.dictConfig(logging_conf) # type: ignore
15 |     client = ZhipuAI() # Fill in your own API key
16 | try:
17 | response = client.file_parser.create(file=open('hitsuyoushorui-cn.pdf', 'rb'), file_type='pdf', tool_type='zhipu_pro')
18 | print(response)
19 |
20 | except zhipuai.core._errors.APIRequestFailedError as err:
21 | print(err)
22 | except zhipuai.core._errors.APIInternalError as err:
23 | print(err)
24 | except zhipuai.core._errors.APIStatusError as err:
25 | print(err)
26 |
27 | def test_file_parser_content(logging_conf):
28 | logging.config.dictConfig(logging_conf) # type: ignore
29 |     client = ZhipuAI() # Fill in your own API key
30 | try:
31 | response = client.file_parser.content(task_id="66e8f7ab884448c8b4190f251f6c2982-1", format_type="text")
32 | print(response.content.decode('utf-8'))
33 |
34 | except zhipuai.core._errors.APIRequestFailedError as err:
35 | print(err)
36 | except zhipuai.core._errors.APIInternalError as err:
37 | print(err)
38 | except zhipuai.core._errors.APIStatusError as err:
39 | print(err)
40 |
41 |
--------------------------------------------------------------------------------
/zhipuai/types/assistant/assistant_completion.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Dict, Any
2 |
3 | from .message import MessageContent
4 | from ...core import BaseModel
5 |
6 | __all__ = ["AssistantCompletion", "CompletionUsage"]
7 |
8 |
9 | class ErrorInfo(BaseModel):
10 |     code: str # Error code
11 |     message: str # Error message
12 |
13 |
14 | class AssistantChoice(BaseModel):
15 |     index: int # Result index
16 |     delta: MessageContent # Output message body for the current turn
17 |     finish_reason: str
18 |     """
19 |     # Reason the generation stopped. `stop` means generation ended naturally or a stop word was hit. `sensitive` means the output was intercepted by the safety review interface; please review such content yourself and decide whether to withdraw anything already published.
20 |     # `network_error` means the inference service failed.
21 |     """
22 |     metadata: dict # Metadata, extension field
23 |
24 |
25 | class CompletionUsage(BaseModel):
26 |     prompt_tokens: int # Number of input tokens
27 |     completion_tokens: int # Number of output tokens
28 |     total_tokens: int # Total number of tokens
29 |
30 |
31 | class AssistantCompletion(BaseModel):
32 |     id: str # Request ID
33 |     conversation_id: str # Conversation ID
34 |     assistant_id: str # Assistant ID
35 |     created: int # Request creation time, Unix timestamp
36 |     status: str # Return status: `completed` means generation finished, `in_progress` means generating, `failed` means generation failed
37 |     last_error: Optional[ErrorInfo] # Error information
38 |     choices: List[AssistantChoice] # Incrementally returned messages
39 |     metadata: Optional[Dict[str, Any]] # Metadata, extension field
40 |     usage: Optional[CompletionUsage] # Token usage statistics
41 |
--------------------------------------------------------------------------------
/zhipuai/api_resource/__init__.py:
--------------------------------------------------------------------------------
1 | from .chat import (
2 | AsyncCompletions,
3 | Chat,
4 | Completions,
5 | )
6 | from .images import (
7 | Images
8 | )
9 | from .embeddings import (
10 | Embeddings
11 | )
12 | from .files import (
13 | Files,
14 | FilesWithRawResponse
15 | )
16 | from .fine_tuning import (
17 | FineTuning
18 | )
19 |
20 | from .batches import (
21 | Batches
22 | )
23 |
24 | from .knowledge import (
25 | Knowledge
26 | )
27 | from .tools import (
28 | Tools
29 | )
30 | from .videos import (
31 | Videos,
32 | )
33 | from .assistant import (
34 | Assistant,
35 | )
36 | from .audio import (
37 | Audio
38 | )
39 |
40 | from .moderation import (
41 | Moderations
42 | )
43 |
44 | from .web_search import (
45 | WebSearchApi
46 | )
47 |
48 | from .agents import (
49 | Agents
50 | )
51 |
52 | from .file_parser import (
53 |     FileParser
54 | )
55 |
56 | __all__ = [
57 |     'Videos',
58 |     'AsyncCompletions',
59 |     'Chat',
60 |     'Completions',
61 |     'Images',
62 |     'Embeddings',
63 |     'Files',
64 |     'FilesWithRawResponse',
65 |     'FineTuning',
66 |     'Batches',
67 |     'Knowledge',
68 |     'Tools',
69 |     'Assistant',
70 |     'Audio',
71 |     'Moderations',
72 |     'FileParser'
73 | ]
74 |
--------------------------------------------------------------------------------
/tests/unit_tests/test_request_opt.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from zhipuai.core._request_opt import FinalRequestOptions, NotGiven
4 |
5 |
6 | # Test Initialization and Default Values
7 | def test_initialization():
8 | params = FinalRequestOptions.construct(
9 | method='GET', url='http://example.com'
10 | )
11 | assert isinstance(params.max_retries, NotGiven)
12 | assert isinstance(params.timeout, NotGiven)
13 | assert isinstance(params.headers, NotGiven)
14 | assert params.json_data is None
15 |
16 |
17 | # Test get_max_retries Method
18 | @pytest.mark.parametrize(
19 | 'max_retries_input, expected',
20 | [
21 | (NotGiven(), 5), # Default case
22 | (3, 3), # Specific number
23 | ],
24 | )
25 | def test_get_max_retries(max_retries_input, expected):
26 | params = FinalRequestOptions.construct(
27 | method='GET', url='http://example.com', max_retries=max_retries_input
28 | )
29 | assert params.get_max_retries(5) == expected
30 |
31 |
32 | # Test construct Method
33 | def test_construct():
34 | input_data = {
35 | 'max_retries': 3,
36 | 'timeout': 10.0,
37 | 'headers': {'Content-Type': 'application/json'},
38 | }
39 | params = FinalRequestOptions.construct(**input_data)
40 | assert params.max_retries == input_data['max_retries']
41 | assert params.timeout == input_data['timeout']
42 | assert params.headers == input_data['headers']
43 |
--------------------------------------------------------------------------------
/tests/unit_tests/test_agents.py:
--------------------------------------------------------------------------------
1 | def test_agents_completion_error_field():
2 | from zhipuai.types.agents.agents_completion import AgentsCompletion, AgentsError, AgentsCompletionChoice, AgentsCompletionMessage, AgentsCompletionUsage
3 |
4 |     # Construct an AgentsError
5 |     error = AgentsError(code="404", message="Not Found")
6 |
7 |     # Construct a complete AgentsCompletion
8 | completion = AgentsCompletion(
9 | agent_id="test_agent",
10 | conversation_id="conv_1",
11 | status="failed",
12 | choices=[
13 | AgentsCompletionChoice(
14 | index=0,
15 | finish_reason="error",
16 | message=AgentsCompletionMessage(content="error", role="system")
17 | )
18 | ],
19 | request_id="req_1",
20 | id="id_1",
21 | usage=AgentsCompletionUsage(prompt_tokens=1, completion_tokens=1, total_tokens=2),
22 | error=error
23 | )
24 |
25 |     # Check that the error field is an AgentsError instance
26 | assert isinstance(completion.error, AgentsError)
27 | assert completion.error.code == "404"
28 | assert completion.error.message == "Not Found"
29 |
30 |     # Check serialization
31 | as_dict = completion.model_dump()
32 | assert as_dict["error"]["code"] == "404"
33 | assert as_dict["error"]["message"] == "Not Found"
34 | print("test_agents_completion_error_field passed.")
--------------------------------------------------------------------------------
/tests/integration_tests/test_tools.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import logging.config
3 |
4 | import zhipuai
5 | from zhipuai import ZhipuAI
6 |
7 |
8 | def test_tools(logging_conf):
9 | logging.config.dictConfig(logging_conf) # type: ignore
10 |     client = ZhipuAI() # Fill in your own API key
11 | try:
12 | response = client.tools.web_search(
13 | model='web-search-pro',
14 | messages=[
15 | {
16 | 'content': '你好',
17 | 'role': 'user',
18 | }
19 | ],
20 | stream=False,
21 | )
22 | print(response)
23 |
24 | except zhipuai.core._errors.APIRequestFailedError as err:
25 | print(err)
26 | except zhipuai.core._errors.APIInternalError as err:
27 | print(err)
28 | except zhipuai.core._errors.APIStatusError as err:
29 | print(err)
30 |
31 |
32 | def test_tools_stream(logging_conf):
33 | logging.config.dictConfig(logging_conf) # type: ignore
34 |     client = ZhipuAI() # Fill in your own API key
35 | try:
36 | response = client.tools.web_search(
37 | model='web-search-pro',
38 | messages=[
39 | {
40 | 'content': '你好',
41 | 'role': 'user',
42 | }
43 | ],
44 | stream=True,
45 | )
46 | for item in response:
47 | print(item)
48 |
49 | except zhipuai.core._errors.APIRequestFailedError as err:
50 | print(err)
51 | except zhipuai.core._errors.APIInternalError as err:
52 | print(err)
53 | except zhipuai.core._errors.APIStatusError as err:
54 | print(err)
55 |
--------------------------------------------------------------------------------
/zhipuai/types/assistant/message/tools/retrieval_delta_black.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | from typing_extensions import Literal
4 |
5 | from .....core import BaseModel
6 |
7 |
8 | class RetrievalToolOutput(BaseModel):
9 | """
10 | This class represents the output of a retrieval tool.
11 |
12 | Attributes:
13 | - text (str): The text snippet retrieved from the knowledge base.
14 | - document (str): The name of the document from which the text snippet was retrieved, returned only in intelligent configuration.
15 | """
16 | text: str
17 | document: str
18 |
19 |
20 | class RetrievalTool(BaseModel):
21 | """
22 | This class represents the outputs of a retrieval tool.
23 |
24 | Attributes:
25 | - outputs (List[RetrievalToolOutput]): A list of text snippets and their respective document names retrieved from the knowledge base.
26 | """
27 | outputs: List[RetrievalToolOutput]
28 |
29 |
30 | class RetrievalToolBlock(BaseModel):
31 | """
32 | This class represents a block for invoking the retrieval tool.
33 |
34 | Attributes:
35 | - retrieval (RetrievalTool): An instance of the RetrievalTool class containing the retrieval outputs.
36 | - type (Literal["retrieval"]): The type of tool being used, always set to "retrieval".
37 | """
38 | retrieval: RetrievalTool
39 | type: Literal["retrieval"]
40 | """Always `retrieval`."""
41 |
42 |
--------------------------------------------------------------------------------
/tests/integration_tests/test_embedding.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import logging.config
3 |
4 | import zhipuai
5 | from zhipuai import ZhipuAI
6 |
7 |
8 | def test_embeddings(logging_conf):
9 | logging.config.dictConfig(logging_conf) # type: ignore
10 |
11 | client = ZhipuAI()
12 | try:
13 | response = client.embeddings.create(
14 |             model='embedding-2', # Specify the model to call
15 | input='你好',
16 | extra_body={'model_version': 'v1'},
17 | )
18 | print(response)
19 |
20 | except zhipuai.core._errors.APIRequestFailedError as err:
21 | print(err)
22 | except zhipuai.core._errors.APIInternalError as err:
23 | print(err)
24 | except zhipuai.core._errors.APIStatusError as err:
25 | print(err)
26 |
27 |
28 | def test_embeddings_dimensions(logging_conf):
29 | logging.config.dictConfig(logging_conf) # type: ignore
30 |
31 | client = ZhipuAI()
32 | try:
33 | response = client.embeddings.create(
34 |             model='embedding-3', # Specify the model to call
35 | input='你好',
36 | dimensions=512,
37 | extra_body={'model_version': 'v1'},
38 | )
39 | assert response.data[0].object == 'embedding'
40 | assert len(response.data[0].embedding) == 512
41 | print(len(response.data[0].embedding))
42 |
43 | except zhipuai.core._errors.APIRequestFailedError as err:
44 | print(err)
45 | except zhipuai.core._errors.APIInternalError as err:
46 | print(err)
47 | except zhipuai.core._errors.APIStatusError as err:
48 | print(err)
49 |
--------------------------------------------------------------------------------
/tests/integration_tests/test_images.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import logging.config
3 |
4 | import zhipuai
5 | from zhipuai import ZhipuAI
6 |
7 |
8 | def test_images(logging_conf):
9 | logging.config.dictConfig(logging_conf) # type: ignore
10 |     client = ZhipuAI() # Fill in your own API key
11 | try:
12 | response = client.images.generations(
13 |             model='cogview-3', # Specify the model to call
14 | prompt='一只可爱的小猫咪',
15 | extra_body={'user_id': '1222212'},
16 | user_id='12345678',
17 | )
18 | print(response)
19 |
20 | except zhipuai.core._errors.APIRequestFailedError as err:
21 | print(err)
22 | except zhipuai.core._errors.APIInternalError as err:
23 | print(err)
24 | except zhipuai.core._errors.APIStatusError as err:
25 | print(err)
26 |
27 |
28 | def test_images_sensitive_word_check(logging_conf):
29 | logging.config.dictConfig(logging_conf) # type: ignore
30 |     client = ZhipuAI() # Fill in your own API key
31 | try:
32 | response = client.images.generations(
33 |             model='cogview-3', # Specify the model to call
34 | prompt='一只可爱的小猫咪',
35 | sensitive_word_check={'type': 'ALL', 'status': 'DISABLE'},
36 | extra_body={'user_id': '1222212'},
37 | user_id='12345678',
38 | )
39 | print(response)
40 |
41 | except zhipuai.core._errors.APIRequestFailedError as err:
42 | print(err)
43 | except zhipuai.core._errors.APIInternalError as err:
44 | print(err)
45 | except zhipuai.core._errors.APIStatusError as err:
46 | print(err)
47 |
--------------------------------------------------------------------------------
/tests/integration_tests/test_transcriptions.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import logging.config
3 | from pathlib import Path
4 |
5 | import zhipuai
6 | from zhipuai import ZhipuAI
7 |
8 |
9 | def test_transcriptions(logging_conf):
10 | logging.config.dictConfig(logging_conf) # type: ignore
11 |     client = ZhipuAI() # Fill in your own API key
12 | try:
13 | with open(Path(__file__).parent / 'asr1.wav', 'rb') as audio_file:
14 | transcriptResponse = client.audio.transcriptions.create(
15 | model='glm-asr', file=audio_file, stream=False
16 | )
17 | print(transcriptResponse)
18 | except zhipuai.core._errors.APIRequestFailedError as err:
19 | print(err)
20 | except zhipuai.core._errors.APIInternalError as err:
21 | print(err)
22 | except zhipuai.core._errors.APIStatusError as err:
23 | print(err)
24 |
25 |
26 | def test_transcriptions_stream(logging_conf):
27 | logging.config.dictConfig(logging_conf) # type: ignore
28 |     client = ZhipuAI() # Fill in your own API key
29 | try:
30 | with open(Path(__file__).parent / 'asr1.wav', 'rb') as audio_file:
31 | transcriptResponse = client.audio.transcriptions.create(
32 | model='glm-asr', file=audio_file, stream=True
33 | )
34 | for item in transcriptResponse:
35 | print(item)
36 | except zhipuai.core._errors.APIRequestFailedError as err:
37 | print(err)
38 | except zhipuai.core._errors.APIInternalError as err:
39 | print(err)
40 | except zhipuai.core._errors.APIStatusError as err:
41 | print(err)
42 |
--------------------------------------------------------------------------------
/zhipuai/api_resource/fine_tuning/models/fine_tuned_models.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Optional, TYPE_CHECKING
4 |
5 | import httpx
6 |
7 | from ....core import BaseAPI
8 | from ....core import NOT_GIVEN, Headers, NotGiven, Body
9 | from ....core import (
10 | make_request_options,
11 | )
12 |
13 | from ....types.fine_tuning.models import (
14 | FineTunedModelsStatus
15 | )
16 |
17 | if TYPE_CHECKING:
18 | from ...._client import ZhipuAI
19 |
20 | __all__ = ["FineTunedModels"]
21 |
22 |
23 | class FineTunedModels(BaseAPI):
24 |
25 | def __init__(self, client: "ZhipuAI") -> None:
26 | super().__init__(client)
27 |
28 | def delete(
29 | self,
30 | fine_tuned_model: str,
31 | *,
32 | extra_headers: Headers | None = None,
33 | extra_body: Body | None = None,
34 | timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
35 | ) -> FineTunedModelsStatus:
36 |
37 | if not fine_tuned_model:
38 | raise ValueError(f"Expected a non-empty value for `fine_tuned_model` but received {fine_tuned_model!r}")
39 | return self._delete(
40 | f"fine_tuning/fine_tuned_models/{fine_tuned_model}",
41 | options=make_request_options(
42 | extra_headers=extra_headers, extra_body=extra_body, timeout=timeout
43 | ),
44 | cast_type=FineTunedModelsStatus,
45 | )
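
A minimal usage sketch (not part of the repository); the `client.fine_tuning.models` attribute path and the model name below are assumptions, not confirmed by this file:

from zhipuai import ZhipuAI

client = ZhipuAI()  # Fill in your own API key
# Assumed attribute path for the FineTunedModels resource; the model name is hypothetical.
status = client.fine_tuning.models.delete("ft-glm-4-hypothetical-model")
print(status)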
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to contribute
2 |
3 | We'd love to accept your patches and contributions to this project.
4 |
5 | ## Contribution process
6 |
7 | ### Code reviews
8 |
9 | All submissions, including submissions by project members, require review. We
10 | use GitHub pull requests for this purpose. Consult
11 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
12 | information on using pull requests.
13 |
14 | ### Contributor Guide
15 |
16 | You may follow these steps to contribute:
17 |
18 | 1. **Fork the official repository.** This will create a copy of the official repository in your own account.
19 | 2. **Sync the branches.** This will ensure that your copy of the repository is up-to-date with the latest changes from the official repository.
20 | 3. **Work on your forked repository's feature branch.** This is where you will make your changes to the code.
21 | 4. **Commit your updates on your forked repository's feature branch.** This will save your changes to your copy of the repository.
22 | 5. **Submit a pull request to the official repository's main branch.** This will request that your changes be merged into the official repository.
23 | 6. **Resolve any linting errors.** This will ensure that your changes are formatted correctly.
24 |
25 | Here are some additional things to keep in mind during the process:
26 |
27 | - **Test your changes.** Before you submit a pull request, make sure that your changes work as expected.
28 | - **Be patient.** It may take some time for your pull request to be reviewed and merged.
29 |
30 |
31 |
32 |
33 |
34 | Have Fun!
35 | ---
36 |
--------------------------------------------------------------------------------
/zhipuai/types/assistant/message/tools/web_browser_delta_block.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | from typing_extensions import Literal
4 |
5 | from .....core import BaseModel
6 | __all__ = ["WebBrowserToolBlock"]
7 |
8 |
9 | class WebBrowserOutput(BaseModel):
10 | """
11 | This class represents the output of a web browser search result.
12 |
13 | Attributes:
14 | - title (str): The title of the search result.
15 | - link (str): The URL link to the search result's webpage.
16 | - content (str): The textual content extracted from the search result.
17 | - error_msg (str): Any error message encountered during the search or retrieval process.
18 | """
19 | title: str
20 | link: str
21 | content: str
22 | error_msg: str
23 |
24 |
25 | class WebBrowser(BaseModel):
26 | """
27 | This class represents the input and outputs of a web browser search.
28 |
29 | Attributes:
30 | - input (str): The input query for the web browser search.
31 | - outputs (List[WebBrowserOutput]): A list of search results returned by the web browser.
32 | """
33 | input: str
34 | outputs: List[WebBrowserOutput]
35 |
36 |
37 | class WebBrowserToolBlock(BaseModel):
38 | """
39 | This class represents a block for invoking the web browser tool.
40 |
41 | Attributes:
42 | - web_browser (WebBrowser): An instance of the WebBrowser class containing the search input and outputs.
43 | - type (Literal["web_browser"]): The type of tool being used, always set to "web_browser".
44 | """
45 | web_browser: WebBrowser
46 | type: Literal["web_browser"]
47 |
--------------------------------------------------------------------------------
/zhipuai/types/knowledge/document/document.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List
2 |
3 | from ....core import BaseModel
4 |
5 | __all__ = ["DocumentData", "DocumentObject", "DocumentSuccessinfo", "DocumentFailedInfo"]
6 |
7 |
8 | class DocumentSuccessinfo(BaseModel):
9 | documentId: Optional[str] = None
10 | """文件id"""
11 | filename: Optional[str] = None
12 | """文件名称"""
13 |
14 |
15 | class DocumentFailedInfo(BaseModel):
16 | failReason: Optional[str] = None
17 | """上传失败的原因,包括:文件格式不支持、文件大小超出限制、知识库容量已满、容量上限为 50 万字。"""
18 | filename: Optional[str] = None
19 | """文件名称"""
20 | documentId: Optional[str] = None
21 | """知识库id"""
22 |
23 |
24 | class DocumentObject(BaseModel):
25 | """文档信息"""
26 |
27 | successInfos: Optional[List[DocumentSuccessinfo]] = None
28 | """上传成功的文件信息"""
29 | failedInfos: Optional[List[DocumentFailedInfo]] = None
30 | """上传失败的文件信息"""
31 |
32 |
33 | class DocumentDataFailInfo(BaseModel):
34 | """失败原因"""
35 |
36 | embedding_code: Optional[int] = None # 失败码 10001:知识不可用,知识库空间已达上限 10002:知识不可用,知识库空间已达上限(字数超出限制)
37 | embedding_msg: Optional[str] = None # 失败原因
38 |
39 |
40 | class DocumentData(BaseModel):
41 |     id: str = None # Unique knowledge ID
42 |     custom_separator: List[str] = None # Chunking rule
43 |     sentence_size: str = None # Chunk size
44 |     length: int = None # File size in bytes
45 |     word_num: int = None # Word count of the file
46 |     name: str = None # File name
47 |     url: str = None # File download URL
48 |     embedding_stat: int = None # 0: vectorizing, 1: vectorization complete, 2: vectorization failed
49 |     failInfo: Optional[DocumentDataFailInfo] = None # Failure details, present when vectorization failed (embedding_stat=2)
50 |
--------------------------------------------------------------------------------
/zhipuai/types/tools/web_search.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 |
3 | from ..chat.chat_completion import Function
4 | from ...core import BaseModel
5 |
6 | __all__ = [
7 | "WebSearch",
8 | "SearchIntent",
9 | "SearchResult",
10 | "SearchRecommend",
11 | ]
12 |
13 |
14 | class SearchIntent(BaseModel):
15 |     index: int
16 |     # Search round, defaults to 0
17 |     query: str
18 |     # Optimized search query
19 |     intent: str
20 |     # Detected intent type
21 |     keywords: str
22 |     # Search keywords
23 |
24 |
25 | class SearchResult(BaseModel):
26 |     index: int
27 |     # Search round, defaults to 0
28 |     title: str
29 |     # Title
30 |     link: str
31 |     # Link
32 |     content: str
33 |     # Content
34 |     icon: str
35 |     # Icon
36 |     media: str
37 |     # Source media
38 |     refer: str
39 |     # Reference marker, e.g. [ref_1]
40 |
41 |
42 | class SearchRecommend(BaseModel):
43 |     index: int
44 |     # Search round, defaults to 0
45 |     query: str
46 |     # Recommended query
47 |
48 |
49 | class WebSearchMessageToolCall(BaseModel):
50 | id: str
51 | search_intent: Optional[SearchIntent]
52 | search_result: Optional[SearchResult]
53 | search_recommend: Optional[SearchRecommend]
54 | type: str
55 |
56 |
57 | class WebSearchMessage(BaseModel):
58 | role: str
59 | tool_calls: Optional[List[WebSearchMessageToolCall]] = None
60 |
61 |
62 | class WebSearchChoice(BaseModel):
63 | index: int
64 | finish_reason: str
65 | message: WebSearchMessage
66 |
67 |
68 | class WebSearch(BaseModel):
69 | created: Optional[int] = None
70 | choices: List[WebSearchChoice]
71 | request_id: Optional[str] = None
72 | id: Optional[str] = None
73 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature-request.yml:
--------------------------------------------------------------------------------
1 | name: 💡 Feature Request
2 | description: Suggest an idea for this repository
3 | title: "[Feat]: "
4 | type: "Feature"
5 | body:
6 | - type: markdown
7 | attributes:
8 | value: |
9 | Thanks for stopping by to let us know something could be better!
10 | - type: textarea
11 | id: problem
12 | attributes:
13 | label: Is your feature request related to a problem? Please describe.
14 | description: A clear and concise description of what the problem is.
15 | placeholder: Ex. I'm always frustrated when [...]
16 | - type: textarea
17 | id: describe
18 | attributes:
19 | label: Describe the solution you'd like
20 | description: A clear and concise description of what you want to happen.
21 | validations:
22 | required: true
23 | - type: textarea
24 | id: alternatives
25 | attributes:
26 | label: Describe alternatives you've considered
27 | description: A clear and concise description of any alternative solutions or features you've considered.
28 | - type: textarea
29 | id: context
30 | attributes:
31 | label: Additional context
32 | description: Add any other context or screenshots about the feature request here.
33 | - type: checkboxes
34 | id: terms
35 | attributes:
36 | label: Code of Conduct
37 | description: By submitting this issue, you agree to follow our [Code of Conduct](https://github.com/MetaGLM/zhipuai-sdk-python-v4/blob/main/CODE_OF_CONDUCT.md)
38 | options:
39 | - label: I agree to follow this project's Code of Conduct
40 | required: true
41 |
--------------------------------------------------------------------------------
/zhipuai/core/pagination.py:
--------------------------------------------------------------------------------
1 | # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
2 |
3 | from typing import Any, List, Generic, TypeVar, Optional, cast
4 | from typing_extensions import Protocol, override, runtime_checkable
5 |
6 | from ._http_client import BasePage, PageInfo, BaseSyncPage
7 |
8 | __all__ = ["SyncPage", "SyncCursorPage"]
9 |
10 | _T = TypeVar("_T")
11 |
12 |
13 | @runtime_checkable
14 | class CursorPageItem(Protocol):
15 | id: Optional[str]
16 |
17 |
18 | class SyncPage(BaseSyncPage[_T], BasePage[_T], Generic[_T]):
19 | """Note: no pagination actually occurs yet, this is for forwards-compatibility."""
20 |
21 | data: List[_T]
22 | object: str
23 |
24 | @override
25 | def _get_page_items(self) -> List[_T]:
26 | data = self.data
27 | if not data:
28 | return []
29 | return data
30 |
31 | @override
32 | def next_page_info(self) -> None:
33 | """
34 | This page represents a response that isn't actually paginated at the API level
35 | so there will never be a next page.
36 | """
37 | return None
38 |
39 |
40 | class SyncCursorPage(BaseSyncPage[_T], BasePage[_T], Generic[_T]):
41 | data: List[_T]
42 |
43 | @override
44 | def _get_page_items(self) -> List[_T]:
45 | data = self.data
46 | if not data:
47 | return []
48 | return data
49 |
50 | @override
51 | def next_page_info(self) -> Optional[PageInfo]:
52 | data = self.data
53 | if not data:
54 | return None
55 |
56 | item = cast(Any, data[-1])
57 | if not isinstance(item, CursorPageItem) or item.id is None:
58 | # TODO emit warning log
59 | return None
60 |
61 | return PageInfo(params={"after": item.id})
62 |
--------------------------------------------------------------------------------
/.github/workflows/_test.yml:
--------------------------------------------------------------------------------
1 | name: test
2 |
3 | on:
4 | workflow_dispatch:
5 | inputs:
6 | working-directory:
7 | required: true
8 | type: string
9 | default: '.'
10 | description: "From which folder this pipeline executes"
11 | pull_request:
12 | branches:
13 | - main
14 | push:
15 | branches:
16 | - 'action*'
17 |
18 | env:
19 | POETRY_VERSION: "1.8.2"
20 |
21 | jobs:
22 | build:
23 | defaults:
24 | run:
25 | working-directory: ${{ inputs.working-directory || '.' }}
26 | runs-on: ubuntu-latest
27 | strategy:
28 | matrix:
29 | python-version:
30 | - "3.8"
31 | - "3.9"
32 | - "3.10"
33 | - "3.11"
34 | - "3.12"
35 | name: "make test #${{ matrix.python-version }}"
36 | steps:
37 | - uses: actions/checkout@v4
38 |
39 | - name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }}
40 | uses: "./.github/actions/poetry_setup"
41 | with:
42 | python-version: ${{ matrix.python-version }}
43 | poetry-version: ${{ env.POETRY_VERSION }}
44 | working-directory: ${{ inputs.working-directory || '.' }}
45 | cache-key: core
46 |
47 |
48 | - name: Import test dependencies
49 | run: poetry install --with test
50 | working-directory: ${{ inputs.working-directory || '.' }}
51 |
52 | - name: Run core tests
53 | shell: bash
54 | run: |
55 | make test
56 |
57 | - name: Ensure the tests did not create any additional files
58 | shell: bash
59 | run: |
60 | set -eu
61 |
62 | STATUS="$(git status)"
63 | echo "$STATUS"
64 |
65 | # grep will exit non-zero if the target message isn't found,
66 | # and `set -e` above will cause the step to fail.
67 | echo "$STATUS" | grep 'nothing to commit, working tree clean'
68 |
--------------------------------------------------------------------------------
/zhipuai/core/_utils/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from ._utils import (
3 | remove_notgiven_indict as remove_notgiven_indict,
4 | flatten as flatten,
5 | is_dict as is_dict,
6 | is_list as is_list,
7 | is_given as is_given,
8 | is_tuple as is_tuple,
9 | is_mapping as is_mapping,
10 | is_tuple_t as is_tuple_t,
11 | parse_date as parse_date,
12 | is_iterable as is_iterable,
13 | is_sequence as is_sequence,
14 | coerce_float as coerce_float,
15 | is_mapping_t as is_mapping_t,
16 | removeprefix as removeprefix,
17 | removesuffix as removesuffix,
18 | extract_files as extract_files,
19 | is_sequence_t as is_sequence_t,
20 | required_args as required_args,
21 | coerce_boolean as coerce_boolean,
22 | coerce_integer as coerce_integer,
23 | file_from_path as file_from_path,
24 | parse_datetime as parse_datetime,
25 | strip_not_given as strip_not_given,
26 | deepcopy_minimal as deepcopy_minimal,
27 | get_async_library as get_async_library,
28 | maybe_coerce_float as maybe_coerce_float,
29 | get_required_header as get_required_header,
30 | maybe_coerce_boolean as maybe_coerce_boolean,
31 | maybe_coerce_integer as maybe_coerce_integer,
32 | drop_prefix_image_data as drop_prefix_image_data,
33 | )
34 |
35 |
36 | from ._typing import (
37 | is_list_type as is_list_type,
38 | is_union_type as is_union_type,
39 | extract_type_arg as extract_type_arg,
40 | is_iterable_type as is_iterable_type,
41 | is_required_type as is_required_type,
42 | is_annotated_type as is_annotated_type,
43 | strip_annotated_type as strip_annotated_type,
44 | extract_type_var_from_base as extract_type_var_from_base,
45 | )
46 |
47 | from ._transform import (
48 | PropertyInfo as PropertyInfo,
49 | transform as transform,
50 | async_transform as async_transform,
51 | maybe_transform as maybe_transform,
52 | async_maybe_transform as async_maybe_transform,
53 | )
54 |
--------------------------------------------------------------------------------
/tests/unit_tests/sse_client/test_stream.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from typing import Iterable, Type, cast
3 |
4 | import httpx
5 |
6 | from zhipuai.core import HttpClient, StreamResponse, get_args
7 | from zhipuai.core._base_type import ResponseT
8 | from zhipuai.types.chat.chat_completion_chunk import ChatCompletionChunk
9 |
10 |
11 | class MockClient:
12 | _strict_response_validation: bool = False
13 |
14 | def _process_response_data(
15 | self,
16 | *,
17 | data: object,
18 | cast_type: Type[ResponseT],
19 | response: httpx.Response,
20 | ) -> ResponseT:
21 | pass
22 |
23 |
24 | def test_stream_cls_chunk() -> None:
25 | MockClient._process_response_data = HttpClient._process_response_data
26 |
27 | def body() -> Iterable[bytes]:
28 | yield b'data: {"id":"8635243129834723621","created":1715329207,"model":\
29 | "glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":"1"}}]}\n\n'
30 | yield b'data: {"id":"8635243129834723621","created":1715329207,"model":\
31 | "glm-4","choices":[{"index":0,"delta":{"role":"assistant","content":"2"}}]}\n\n'
32 |
33 | _stream_cls = StreamResponse[ChatCompletionChunk]
34 | http_response = httpx.Response(status_code=200, content=body())
35 |
36 | stream_cls = _stream_cls(
37 | cast_type=cast(type, get_args(_stream_cls)[0]),
38 | response=http_response,
39 | client=MockClient(),
40 | )
41 | chat_completion_chunk1 = next(stream_cls)
42 |
43 | assert chat_completion_chunk1.choices[0].delta.content == '1'
44 | assert chat_completion_chunk1.choices[0].delta.role == 'assistant'
45 | assert chat_completion_chunk1.choices[0].index == 0
46 | assert chat_completion_chunk1.model == 'glm-4'
47 | chat_completion_chunk2 = next(stream_cls)
48 | assert chat_completion_chunk2.choices[0].delta.content == '2'
49 | assert chat_completion_chunk2.choices[0].delta.role == 'assistant'
50 | assert chat_completion_chunk2.choices[0].index == 0
51 | assert chat_completion_chunk2.model == 'glm-4'
52 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: all format lint test tests test_watch integration_tests docker_tests help extended_tests
2 |
3 | # Default target executed when no arguments are given to make.
4 | all: help
5 |
6 | ######################
7 | # TESTING AND COVERAGE
8 | ######################
9 |
10 | # Define a variable for the test file path.
11 | TEST_FILE ?= tests/unit_tests/
12 |
13 |
14 | test tests:
15 | poetry run pytest --disable-socket --allow-unix-socket $(TEST_FILE)
16 | integration_tests:
17 | poetry run pytest tests/integration_tests
18 |
19 |
20 | ######################
21 | # LINTING AND FORMATTING
22 | ######################
23 |
24 | # Define a variable for Python and notebook files.
25 | PYTHON_FILES=.
26 | MYPY_CACHE=.mypy_cache
27 | lint format: PYTHON_FILES=.
28 | lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/langchain --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
29 | lint_package: PYTHON_FILES=zhipuai
30 | lint_tests: PYTHON_FILES=tests
31 | lint_tests: MYPY_CACHE=.mypy_cache_test
32 |
33 | lint lint_diff lint_package lint_tests:
34 | ./scripts/check_pydantic.sh .
35 | ./scripts/lint_imports.sh
36 | poetry run ruff .
37 | [ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
38 | [ "$(PYTHON_FILES)" = "" ] || poetry run ruff --select I $(PYTHON_FILES)
39 | [ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
40 |
41 | format format_diff:
42 | [ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES)
43 | [ "$(PYTHON_FILES)" = "" ] || poetry run ruff --select I --fix $(PYTHON_FILES)
44 |
45 |
46 | ######################
47 | # HELP
48 | ######################
49 |
50 | help:
51 | @echo '-- LINTING --'
52 | @echo 'format - run code formatters'
53 | @echo 'lint - run linters'
54 | @echo '-- TESTS --'
55 | @echo 'test - run unit tests'
56 | @echo 'tests - run unit tests (alias for "make test")'
57 | 	@echo 'test TEST_FILE=<test_file> - run all tests in file'
58 |
--------------------------------------------------------------------------------
/zhipuai/api_resource/embeddings.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Union, List, Optional, TYPE_CHECKING
4 |
5 | import httpx
6 |
7 | from ..core import BaseAPI, Body
8 | from ..core import NotGiven, NOT_GIVEN, Headers
9 | from ..core import make_request_options
10 | from ..types.embeddings import EmbeddingsResponded
11 |
12 | if TYPE_CHECKING:
13 | from .._client import ZhipuAI
14 |
15 |
16 | class Embeddings(BaseAPI):
17 | def __init__(self, client: "ZhipuAI") -> None:
18 | super().__init__(client)
19 |
20 | def create(
21 | self,
22 | *,
23 | input: Union[str, List[str], List[int], List[List[int]]],
24 | model: str,
25 | dimensions: int | NotGiven = NOT_GIVEN,
26 | encoding_format: str | NotGiven = NOT_GIVEN,
27 | user: str | NotGiven = NOT_GIVEN,
28 | request_id: Optional[str] | NotGiven = NOT_GIVEN,
29 | sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN,
30 | extra_headers: Headers | None = None,
31 | extra_body: Body | None = None,
32 | disable_strict_validation: Optional[bool] | None = None,
33 | timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
34 | ) -> EmbeddingsResponded:
35 | _cast_type = EmbeddingsResponded
36 | if disable_strict_validation:
37 | _cast_type = object
38 | return self._post(
39 | "/embeddings",
40 | body={
41 | "input": input,
42 | "model": model,
43 | "dimensions": dimensions,
44 | "encoding_format": encoding_format,
45 | "user": user,
46 | "request_id": request_id,
47 | "sensitive_word_check": sensitive_word_check,
48 | },
49 | options=make_request_options(
50 | extra_headers=extra_headers, extra_body=extra_body, timeout=timeout
51 | ),
52 | cast_type=_cast_type,
53 | stream=False,
54 | )
55 |
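A minimal usage sketch for this resource, assuming the client exposes it as client.embeddings (the same way other resources are used in the integration tests) and that 'embedding-2' is a valid model name; both are assumptions, and the response shape follows the EmbeddingsResponded model defined elsewhere in the package:

    from zhipuai import ZhipuAI

    client = ZhipuAI()  # reads ZHIPUAI_API_KEY from the environment, as in the CI workflow
    resp = client.embeddings.create(
        model='embedding-2',  # hypothetical model name
        input='hello world',
        dimensions=1024,      # optional; forwarded as-is in the request body
    )
    print(resp)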
--------------------------------------------------------------------------------
/zhipuai/types/chat/chat_completion_chunk.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Dict, Any
2 |
3 | from ...core import BaseModel
4 |
5 | __all__ = [
6 | "CompletionUsage",
7 | "ChatCompletionChunk",
8 | "Choice",
9 | "ChoiceDelta",
10 | "ChoiceDeltaFunctionCall",
11 | "ChoiceDeltaToolCall",
12 | "ChoiceDeltaToolCallFunction",
13 | "AudioCompletionChunk"
14 | ]
15 |
16 |
17 | class ChoiceDeltaFunctionCall(BaseModel):
18 | arguments: Optional[str] = None
19 | name: Optional[str] = None
20 |
21 |
22 | class ChoiceDeltaToolCallFunction(BaseModel):
23 | arguments: Optional[str] = None
24 | name: Optional[str] = None
25 |
26 |
27 | class ChoiceDeltaToolCall(BaseModel):
28 | index: int
29 | id: Optional[str] = None
30 | function: Optional[ChoiceDeltaToolCallFunction] = None
31 | type: Optional[str] = None
32 |
33 | class AudioCompletionChunk(BaseModel):
34 | id: Optional[str] = None
35 | data: Optional[str] = None
36 | expires_at: Optional[int] = None
37 |
38 |
39 | class ChoiceDelta(BaseModel):
40 | content: Optional[str] = None
41 | role: Optional[str] = None
42 | reasoning_content: Optional[str] = None
43 | tool_calls: Optional[List[ChoiceDeltaToolCall]] = None
44 | audio: Optional[AudioCompletionChunk] = None
45 |
46 |
47 | class Choice(BaseModel):
48 | delta: ChoiceDelta
49 | finish_reason: Optional[str] = None
50 | index: int
51 |
52 | class PromptTokensDetails(BaseModel):
53 | cached_tokens: int
54 |
55 | class CompletionTokensDetails(BaseModel):
56 | reasoning_tokens: int
57 |
58 | class CompletionUsage(BaseModel):
59 | prompt_tokens: int
60 | prompt_tokens_details: Optional[PromptTokensDetails] = None
61 | completion_tokens: int
62 | completion_tokens_details: Optional[CompletionTokensDetails] = None
63 | total_tokens: int
64 |
65 | class ChatCompletionChunk(BaseModel):
66 | id: Optional[str] = None
67 | choices: List[Choice]
68 | created: Optional[int] = None
69 | model: Optional[str] = None
70 | usage: Optional[CompletionUsage] = None
71 | extra_json: Dict[str, Any]
72 |
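A sketch of consuming these chunks from a streamed chat completion, assuming the chat resource is exposed as client.chat.completions as it is in the integration tests:

    from zhipuai import ZhipuAI

    client = ZhipuAI()
    response = client.chat.completions.create(
        model='glm-4',  # model name taken from the stream unit test above
        messages=[{'role': 'user', 'content': 'hello'}],
        stream=True,
    )
    answer = ''
    for chunk in response:  # each item is a ChatCompletionChunk
        delta = chunk.choices[0].delta
        if delta.content:
            answer += delta.content
    print(answer)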
--------------------------------------------------------------------------------
/zhipuai/types/chat/code_geex/code_geex_params.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 | from typing_extensions import Literal, Required, TypedDict
3 |
4 | __all__ = [
5 | "CodeGeexTarget",
6 | "CodeGeexContext",
7 | "CodeGeexExtra",
8 | ]
9 |
10 |
11 | class CodeGeexTarget(TypedDict, total=False):
12 | """补全的内容参数"""
13 | path: Optional[str]
14 | """文件路径"""
15 | language: Required[Literal[
16 | "c", "c++", "cpp", "c#", "csharp", "c-sharp", "css", "cuda", "dart", "lua",
17 | "objectivec", "objective-c", "objective-c++", "python", "perl", "prolog",
18 | "swift", "lisp", "java", "scala", "tex", "jsx", "tsx", "vue", "markdown",
19 | "html", "php", "js", "javascript", "typescript", "go", "shell", "rust",
20 | "sql", "kotlin", "vb", "ruby", "pascal", "r", "fortran", "lean", "matlab",
21 | "delphi", "scheme", "basic", "assembly", "groovy", "abap", "gdscript",
22 | "haskell", "julia", "elixir", "excel", "clojure", "actionscript", "solidity",
23 | "powershell", "erlang", "cobol", "alloy", "awk", "thrift", "sparql", "augeas",
24 | "cmake", "f-sharp", "stan", "isabelle", "dockerfile", "rmarkdown",
25 | "literate-agda", "tcl", "glsl", "antlr", "verilog", "racket", "standard-ml",
26 | "elm", "yaml", "smalltalk", "ocaml", "idris", "visual-basic", "protocol-buffer",
27 | "bluespec", "applescript", "makefile", "tcsh", "maple", "systemverilog",
28 | "literate-coffeescript", "vhdl", "restructuredtext", "sas", "literate-haskell",
29 | "java-server-pages", "coffeescript", "emacs-lisp", "mathematica", "xslt",
30 | "zig", "common-lisp", "stata", "agda", "ada"
31 | ]]
32 | """代码语言类型,如python"""
33 | code_prefix: Required[str]
34 | """补全位置的前文"""
35 | code_suffix: Required[str]
36 | """补全位置的后文"""
37 |
38 |
39 | class CodeGeexContext(TypedDict, total=False):
40 | """附加代码"""
41 | path: Required[str]
42 | """附加代码文件的路径"""
43 | code: Required[str]
44 | """附加的代码内容"""
45 |
46 |
47 | class CodeGeexExtra(TypedDict, total=False):
48 | target: Required[CodeGeexTarget]
49 | """补全的内容参数"""
50 | contexts: Optional[List[CodeGeexContext]]
51 | """附加代码"""
52 |
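An illustrative value for these TypedDicts; all field values below are made up, and how the resulting extra payload is attached to a chat request is outside the scope of this file:

    from zhipuai.types.chat.code_geex.code_geex_params import CodeGeexExtra

    extra: CodeGeexExtra = {
        'target': {
            'path': 'src/app.py',
            'language': 'python',
            'code_prefix': 'def add(a, b):\n    ',
            'code_suffix': '\n',
        },
        'contexts': [
            {'path': 'src/util.py', 'code': 'def helper():\n    return 1\n'},
        ],
    }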
--------------------------------------------------------------------------------
/tests/unit_tests/test_streaming.py:
--------------------------------------------------------------------------------
1 | from typing import Iterator
2 |
3 | import pytest
4 |
5 | from zhipuai.core._sse_client import SSELineParser
6 |
7 |
8 | def test_basic() -> None:
9 | def body() -> Iterator[str]:
10 | yield 'event: completion'
11 | yield 'data: {"foo":true}'
12 | yield ''
13 |
14 | it = SSELineParser().iter_lines(body())
15 | sse = next(it)
16 | assert sse.event == 'completion'
17 | assert sse.json_data() == {'foo': True}
18 |
19 | with pytest.raises(StopIteration):
20 | next(it)
21 |
22 |
23 | def test_data_missing_event() -> None:
24 | def body() -> Iterator[str]:
25 | yield 'data: {"foo":true}'
26 | yield ''
27 |
28 | it = SSELineParser().iter_lines(body())
29 | sse = next(it)
30 | assert sse.event is None
31 | assert sse.json_data() == {'foo': True}
32 |
33 | with pytest.raises(StopIteration):
34 | next(it)
35 |
36 |
37 | def test_event_missing_data() -> None:
38 | def body() -> Iterator[str]:
39 | yield 'event: ping'
40 | yield ''
41 |
42 | it = SSELineParser().iter_lines(body())
43 | sse = next(it)
44 | assert sse.event == 'ping'
45 | assert sse.data == ''
46 |
47 | with pytest.raises(StopIteration):
48 | next(it)
49 |
50 |
51 | def test_multiple_events() -> None:
52 | def body() -> Iterator[str]:
53 | yield 'event: ping'
54 | yield ''
55 | yield 'event: completion'
56 | yield ''
57 |
58 | it = SSELineParser().iter_lines(body())
59 |
60 | sse = next(it)
61 | assert sse.event == 'ping'
62 | assert sse.data == ''
63 |
64 | sse = next(it)
65 | assert sse.event == 'completion'
66 | assert sse.data == ''
67 |
68 | with pytest.raises(StopIteration):
69 | next(it)
70 |
71 |
72 | def test_multiple_events_with_data() -> None:
73 | def body() -> Iterator[str]:
74 | yield 'event: ping'
75 | yield 'data: {"foo":true}'
76 | yield ''
77 | yield 'event: completion'
78 | yield 'data: {"bar":false}'
79 | yield ''
80 |
81 | it = SSELineParser().iter_lines(body())
82 |
83 | sse = next(it)
84 | assert sse.event == 'ping'
85 | assert sse.json_data() == {'foo': True}
86 |
87 | sse = next(it)
88 | assert sse.event == 'completion'
89 | assert sse.json_data() == {'bar': False}
90 |
91 | with pytest.raises(StopIteration):
92 | next(it)
93 |
--------------------------------------------------------------------------------
/.github/workflows/_integration_test.yml:
--------------------------------------------------------------------------------
1 | name: integration_test
2 |
3 | on:
4 | workflow_dispatch:
5 | inputs:
6 | working-directory:
7 | required: true
8 | type: string
9 | default: '.'
10 | description: "From which folder this pipeline executes"
11 |
12 | env:
13 | POETRY_VERSION: "1.8.2"
14 |
15 | jobs:
16 | build:
17 | if: github.ref == 'refs/heads/main'
18 | runs-on: ubuntu-latest
19 |
20 | environment: Scheduled testing publish
21 | outputs:
22 | pkg-name: ${{ steps.check-version.outputs.pkg-name }}
23 | version: ${{ steps.check-version.outputs.version }}
24 | strategy:
25 | matrix:
26 | python-version:
27 | - "3.8"
28 | - "3.9"
29 | - "3.10"
30 | - "3.11"
31 | - "3.12"
32 | name: "make integration_test #${{ matrix.python-version }}"
33 | steps:
34 | - uses: actions/checkout@v4
35 |
36 | - name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }}
37 | uses: "./.github/actions/poetry_setup"
38 | with:
39 | python-version: ${{ matrix.python-version }}
40 | poetry-version: ${{ env.POETRY_VERSION }}
41 | working-directory: ${{ inputs.working-directory }}
42 | cache-key: core
43 |
44 | - name: Install test dependencies
45 | run: poetry install --with test
46 | working-directory: ${{ inputs.working-directory }}
47 |
48 | - name: Run integration tests
49 | shell: bash
50 | env:
51 | ZHIPUAI_API_KEY: ${{ secrets.ZHIPUAI_API_KEY }}
52 | ZHIPUAI_BASE_URL: ${{ secrets.ZHIPUAI_BASE_URL }}
53 | run: |
54 | make integration_tests
55 |
56 | - name: Ensure the tests did not create any additional files
57 | shell: bash
58 | run: |
59 | set -eu
60 |
61 | STATUS="$(git status)"
62 | echo "$STATUS"
63 |
64 | # grep will exit non-zero if the target message isn't found,
65 | # and `set -e` above will cause the step to fail.
66 | echo "$STATUS" | grep 'nothing to commit, working tree clean'
67 |
--------------------------------------------------------------------------------
/zhipuai/api_resource/tools/tools.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import TYPE_CHECKING, List, Union, Dict, Optional
4 | from typing_extensions import Literal
5 |
6 | from ...core import NOT_GIVEN, Body, Headers, NotGiven, BaseAPI, maybe_transform, StreamResponse, deepcopy_minimal
7 |
8 | import httpx
9 |
10 | from ...core import (
11 | make_request_options,
12 | )
13 | import logging
14 |
15 | from ...types.tools import tools_web_search_params, WebSearch, WebSearchChunk
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 | if TYPE_CHECKING:
20 | from ..._client import ZhipuAI
21 |
22 | __all__ = ["Tools"]
23 |
24 |
25 | class Tools(BaseAPI):
26 |
27 | def __init__(self, client: "ZhipuAI") -> None:
28 | super().__init__(client)
29 |
30 | def web_search(
31 | self,
32 | *,
33 | model: str,
34 | request_id: Optional[str] | NotGiven = NOT_GIVEN,
35 | stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
36 | messages: Union[str, List[str], List[int], object, None],
37 | scope: Optional[str] | NotGiven = NOT_GIVEN,
38 | location: Optional[str] | NotGiven = NOT_GIVEN,
39 | recent_days: Optional[int] | NotGiven = NOT_GIVEN,
40 | extra_headers: Headers | None = None,
41 | extra_body: Body | None = None,
42 | timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
43 | ) -> WebSearch | StreamResponse[WebSearchChunk]:
44 |
45 | body = deepcopy_minimal(
46 | {
47 | "model": model,
48 | "request_id": request_id,
49 | "messages": messages,
50 | "stream": stream,
51 | "scope": scope,
52 | "location": location,
53 | "recent_days": recent_days,
54 | })
55 | return self._post(
56 | "/tools",
57 | body=maybe_transform(body, tools_web_search_params.WebSearchParams),
58 | options=make_request_options(
59 | extra_headers=extra_headers, extra_body=extra_body, timeout=timeout
60 | ),
61 | cast_type=WebSearch,
62 | stream=stream or False,
63 | stream_cls=StreamResponse[WebSearchChunk],
64 | )
65 |
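A usage sketch, assuming the Tools resource is mounted on the client as client.tools and that 'web-search-pro' is an accepted model name (both assumptions):

    from zhipuai import ZhipuAI

    client = ZhipuAI()
    result = client.tools.web_search(
        model='web-search-pro',  # hypothetical model name
        messages=[{'role': 'user', 'content': 'latest news about GLM-4'}],
        stream=False,
    )
    print(result)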
--------------------------------------------------------------------------------
/zhipuai/types/batch.py:
--------------------------------------------------------------------------------
1 | # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
2 |
3 | import builtins
4 | from typing import List, Optional
5 | from typing_extensions import Literal
6 |
7 | from ..core import BaseModel
8 | from .batch_error import BatchError
9 | from .batch_request_counts import BatchRequestCounts
10 |
11 | __all__ = ["Batch", "Errors"]
12 |
13 |
14 | class Errors(BaseModel):
15 | data: Optional[List[BatchError]] = None
16 |
17 | object: Optional[str] = None
18 | """这个类型,一直是`list`。"""
19 |
20 |
21 | class Batch(BaseModel):
22 | id: str
23 |
24 | completion_window: str
25 | """用于执行请求的地址信息。"""
26 |
27 | created_at: int
28 | """这是 Unix timestamp (in seconds) 表示的创建时间。"""
29 |
30 | endpoint: str
31 | """这是ZhipuAI endpoint的地址。"""
32 |
33 | input_file_id: str
34 | """标记为batch的输入文件的ID。"""
35 |
36 | object: Literal["batch"]
37 | """这个类型,一直是`batch`."""
38 |
39 | status: Literal[
40 | "validating", "failed", "in_progress", "finalizing", "completed", "expired", "cancelling", "cancelled"
41 | ]
42 | """batch 的状态。"""
43 |
44 | cancelled_at: Optional[int] = None
45 | """Unix timestamp (in seconds) 表示的取消时间。"""
46 |
47 | cancelling_at: Optional[int] = None
48 | """Unix timestamp (in seconds) 表示发起取消的请求时间 """
49 |
50 | completed_at: Optional[int] = None
51 | """Unix timestamp (in seconds) 表示的完成时间。"""
52 |
53 | error_file_id: Optional[str] = None
54 | """这个文件id包含了执行请求失败的请求的输出。"""
55 |
56 | errors: Optional[Errors] = None
57 |
58 | expired_at: Optional[int] = None
59 | """Unix timestamp (in seconds) 表示的将在过期时间。"""
60 |
61 | expires_at: Optional[int] = None
62 | """Unix timestamp (in seconds) 触发过期"""
63 |
64 | failed_at: Optional[int] = None
65 | """Unix timestamp (in seconds) 表示的失败时间。"""
66 |
67 | finalizing_at: Optional[int] = None
68 | """Unix timestamp (in seconds) 表示的最终时间。"""
69 |
70 | in_progress_at: Optional[int] = None
71 | """Unix timestamp (in seconds) 表示的开始处理时间。"""
72 |
73 | metadata: Optional[builtins.object] = None
74 | """
75 | key:value形式的元数据,以便将信息存储
76 | 结构化格式。键的长度是64个字符,值最长512个字符
77 | """
78 |
79 | output_file_id: Optional[str] = None
80 | """完成请求的输出文件的ID。"""
81 |
82 | request_counts: Optional[BatchRequestCounts] = None
83 | """批次中不同状态的请求计数"""
84 |
--------------------------------------------------------------------------------
/tests/integration_tests/test_videos.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import logging.config
3 |
4 | import zhipuai
5 | from zhipuai import ZhipuAI
6 |
7 |
8 | def test_videos(logging_conf):
9 | logging.config.dictConfig(logging_conf) # type: ignore
10 | client = ZhipuAI() # Fill in your own API key
11 | try:
12 | response = client.videos.generations(
13 | model='cogvideox', prompt='一个开船的人', user_id='1212222'
14 | )
15 | print(response)
16 |
17 | except zhipuai.core._errors.APIRequestFailedError as err:
18 | print(err)
19 | except zhipuai.core._errors.APIInternalError as err:
20 | print(err)
21 | except zhipuai.core._errors.APIStatusError as err:
22 | print(err)
23 |
24 |
25 | def test_videos_sensitive_word_check(logging_conf):
26 | logging.config.dictConfig(logging_conf) # type: ignore
27 | client = ZhipuAI() # Fill in your own API key
28 | try:
29 | response = client.videos.generations(
30 | model='cogvideo',
31 | prompt='一个开船的人',
32 | sensitive_word_check={'type': 'ALL', 'status': 'DISABLE'},
33 | user_id='1212222',
34 | )
35 | print(response)
36 |
37 | except zhipuai.core._errors.APIRequestFailedError as err:
38 | print(err)
39 | except zhipuai.core._errors.APIInternalError as err:
40 | print(err)
41 | except zhipuai.core._errors.APIStatusError as err:
42 | print(err)
43 |
44 |
45 | def test_videos_image_url(logging_conf):
46 | logging.config.dictConfig(logging_conf) # type: ignore
47 | client = ZhipuAI() # Fill in your own API key
48 | try:
49 | response = client.videos.generations(
50 | model='cogvideo',
51 | image_url='https://cdn.bigmodel.cn/static/platform/images/solutions/car/empowerment/icon-metric.png',
52 | prompt='一些相信光的人,举着奥特曼',
53 | user_id='12222211',
54 | )
55 | print(response)
56 |
57 | except zhipuai.core._errors.APIRequestFailedError as err:
58 | print(err)
59 | except zhipuai.core._errors.APIInternalError as err:
60 | print(err)
61 | except zhipuai.core._errors.APIStatusError as err:
62 | print(err)
63 |
64 |
65 | def test_retrieve_videos_result(logging_conf):
66 | logging.config.dictConfig(logging_conf) # type: ignore
67 | client = ZhipuAI() # 填写您自己的APIKey
68 | try:
69 | response = client.videos.retrieve_videos_result(id='1014908869548405238276203')
70 | print(response)
71 |
72 | except zhipuai.core._errors.APIRequestFailedError as err:
73 | print(err)
74 | except zhipuai.core._errors.APIInternalError as err:
75 | print(err)
76 | except zhipuai.core._errors.APIStatusError as err:
77 | print(err)
78 |
--------------------------------------------------------------------------------
/tests/integration_tests/test_file.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import logging
4 | import logging.config
5 | import os
6 |
7 | import pytest
8 |
9 | import zhipuai
10 | from zhipuai import ZhipuAI
11 |
12 |
13 | @pytest.fixture(scope='class')
14 | def test_server():
15 | class SharedData:
16 | client = ZhipuAI()
17 | file_id1 = None
18 | file_id2 = None
19 |
20 | return SharedData()
21 |
22 |
23 | class TestZhipuAIFileServer:
24 | def test_logs(self, logging_conf):
25 | logging.config.dictConfig(logging_conf) # type: ignore
26 |
27 | def test_files(self, test_server, test_file_path):
28 | try:
29 | result = test_server.client.files.create(
30 | file=open(os.path.join(test_file_path, 'demo.jsonl'), 'rb'),
31 | purpose='fine-tune',
32 | )
33 | print(result)
34 | test_server.file_id1 = result.id
35 |
36 | except zhipuai.core._errors.APIRequestFailedError as err:
37 | print(err)
38 | except zhipuai.core._errors.APIInternalError as err:
39 | print(err)
40 | except zhipuai.core._errors.APIStatusError as err:
41 | print(err)
42 |
43 | def test_files_validation(self, test_server, test_file_path):
44 | try:
45 | result = test_server.client.files.create(
46 | file=open(os.path.join(test_file_path, 'demo.jsonl'), 'rb'),
47 | purpose='fine-tune',
48 | )
49 | print(result)
50 |
51 | test_server.file_id2 = result.id
52 |
53 | except zhipuai.core._errors.APIRequestFailedError as err:
54 | print(err)
55 | except zhipuai.core._errors.APIInternalError as err:
56 | print(err)
57 | except zhipuai.core._errors.APIStatusError as err:
58 | print(err)
59 |
60 | def test_files_list(self, test_server):
61 | try:
62 | list = test_server.client.files.list()
63 | print(list)
64 |
65 | except zhipuai.core._errors.APIRequestFailedError as err:
66 | print(err)
67 | except zhipuai.core._errors.APIInternalError as err:
68 | print(err)
69 | except zhipuai.core._errors.APIStatusError as err:
70 | print(err)
71 |
72 | def test_delete_files(self, test_server):
73 | try:
74 | delete1 = test_server.client.files.delete(file_id=test_server.file_id1)
75 | print(delete1)
76 |
77 | delete2 = test_server.client.files.delete(file_id=test_server.file_id2)
78 | print(delete2)
79 |
80 | except zhipuai.core._errors.APIRequestFailedError as err:
81 | print(err)
82 | except zhipuai.core._errors.APIInternalError as err:
83 | print(err)
84 | except zhipuai.core._errors.APIStatusError as err:
85 | print(err)
86 |
--------------------------------------------------------------------------------
/tests/integration_tests/test_vlm_thinking.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import logging.config
3 | import time
4 |
5 | import zhipuai
6 | from zhipuai import ZhipuAI
7 |
8 |
9 | def test_completions_vlm_thinking(logging_conf):
10 | logging.config.dictConfig(logging_conf) # type: ignore
11 | client = ZhipuAI() # Fill in your own API key
12 | try:
13 | # Generate a request_id
14 | request_id = time.time()
15 | print(f'request_id:{request_id}')
16 | response = client.chat.completions.create(
17 | request_id=request_id,
18 | model='glm-4.1v-thinking-flash', # Fill in the model name to call
19 | messages=[
20 | {
21 | 'role': 'user',
22 | 'content': [
23 | {'type': 'text', 'text': '图里有什么'},
24 | {
25 | 'type': 'image_url',
26 | 'image_url': {
27 | 'url': 'https://img1.baidu.com/it/u=1369931113,3388870256&fm=253&app=138&size=w931&n=0&f=JPEG&fmt=auto?sec=1703696400&t=f3028c7a1dca43a080aeb8239f09cc2f'
28 | },
29 | },
30 | ],
31 | }
32 | ],
33 | temperature=0.5,
34 | max_tokens=1024,
35 | user_id='12345678',
36 | )
37 | print(response)
38 |
39 | except zhipuai.core._errors.APIRequestFailedError as err:
40 | print(err)
41 | except zhipuai.core._errors.APIInternalError as err:
42 | print(err)
43 | except zhipuai.core._errors.APIStatusError as err:
44 | print(err)
45 |
46 |
47 | def test_completions_vlm_thinking_stream(logging_conf):
48 | logging.config.dictConfig(logging_conf) # type: ignore
49 | client = ZhipuAI() # Fill in your own API key
50 | try:
51 | # Generate a request_id
52 | request_id = time.time()
53 | print(f'request_id:{request_id}')
54 | response = client.chat.completions.create(
55 | request_id=request_id,
56 | model='glm-4.1v-thinking-flash', # Fill in the model name to call
57 | messages=[
58 | {
59 | 'role': 'user',
60 | 'content': [
61 | {'type': 'text', 'text': '图里有什么'},
62 | {
63 | 'type': 'image_url',
64 | 'image_url': {
65 | 'url': 'https://img1.baidu.com/it/u=1369931113,3388870256&fm=253&app=138&size=w931&n=0&f=JPEG&fmt=auto?sec=1703696400&t=f3028c7a1dca43a080aeb8239f09cc2f'
66 | },
67 | },
68 | ],
69 | }
70 | ],
71 | temperature=0.5,
72 | max_tokens=1024,
73 | user_id='12345678',
74 | stream=True,
75 | )
76 | for item in response:
77 | print(item)
78 |
79 | except zhipuai.core._errors.APIRequestFailedError as err:
80 | print(err)
81 | except zhipuai.core._errors.APIInternalError as err:
82 | print(err)
83 | except zhipuai.core._errors.APIStatusError as err:
84 | print(err)
85 |
--------------------------------------------------------------------------------
/zhipuai/api_resource/images.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Union, List, Optional, TYPE_CHECKING
4 |
5 | import httpx
6 |
7 | from ..core import BaseAPI
8 | from ..core import NotGiven, NOT_GIVEN, Headers, Body
9 | from ..core import make_request_options
10 | from ..types.image import ImagesResponded
11 | from ..types.sensitive_word_check import SensitiveWordCheckRequest
12 |
13 | if TYPE_CHECKING:
14 | from .._client import ZhipuAI
15 |
16 |
17 | class Images(BaseAPI):
18 | def __init__(self, client: "ZhipuAI") -> None:
19 | super().__init__(client)
20 |
21 | def generations(
22 | self,
23 | *,
24 | prompt: str,
25 | model: str | NotGiven = NOT_GIVEN,
26 | n: Optional[int] | NotGiven = NOT_GIVEN,
27 | quality: Optional[str] | NotGiven = NOT_GIVEN,
28 | response_format: Optional[str] | NotGiven = NOT_GIVEN,
29 | size: Optional[str] | NotGiven = NOT_GIVEN,
30 | style: Optional[str] | NotGiven = NOT_GIVEN,
31 | sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN,
32 | user: str | NotGiven = NOT_GIVEN,
33 | request_id: Optional[str] | NotGiven = NOT_GIVEN,
34 | user_id: Optional[str] | NotGiven = NOT_GIVEN,
35 | extra_headers: Headers | None = None,
36 | extra_body: Body | None = None,
37 | disable_strict_validation: Optional[bool] | None = None,
38 | timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
39 | ) -> ImagesResponded:
40 | _cast_type = ImagesResponded
41 | if disable_strict_validation:
42 | _cast_type = object
43 | return self._post(
44 | "/images/generations",
45 | body={
46 | "prompt": prompt,
47 | "model": model,
48 | "n": n,
49 | "quality": quality,
50 | "response_format": response_format,
51 | "sensitive_word_check": sensitive_word_check,
52 | "size": size,
53 | "style": style,
54 | "user": user,
55 | "user_id": user_id,
56 | "request_id": request_id,
57 | },
58 | options=make_request_options(
59 | extra_headers=extra_headers, extra_body=extra_body, timeout=timeout
60 | ),
61 | cast_type=_cast_type,
62 | stream=False,
63 | )
64 |
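A usage sketch, assuming the Images resource is mounted as client.images and that 'cogview-3' and the size value are accepted by the API (assumptions not confirmed by this file):

    from zhipuai import ZhipuAI

    client = ZhipuAI()
    resp = client.images.generations(
        model='cogview-3',  # hypothetical model name
        prompt='a lighthouse at sunset, watercolor style',
        size='1024x1024',   # hypothetical size value
    )
    print(resp)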
--------------------------------------------------------------------------------
/zhipuai/core/_errors.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import httpx
4 |
5 | __all__ = [
6 | "ZhipuAIError",
7 | "APIStatusError",
8 | "APIRequestFailedError",
9 | "APIAuthenticationError",
10 | "APIReachLimitError",
11 | "APIInternalError",
12 | "APIServerFlowExceedError",
13 | "APIResponseError",
14 | "APIResponseValidationError",
15 | "APITimeoutError",
16 | "APIConnectionError",
17 | ]
18 |
19 |
20 | class ZhipuAIError(Exception):
21 | def __init__(self, message: str, ) -> None:
22 | super().__init__(message)
23 |
24 |
25 | class APIStatusError(ZhipuAIError):
26 | response: httpx.Response
27 | status_code: int
28 |
29 | def __init__(self, message: str, *, response: httpx.Response) -> None:
30 | super().__init__(message)
31 | self.response = response
32 | self.status_code = response.status_code
33 |
34 |
35 | class APIRequestFailedError(APIStatusError):
36 | ...
37 |
38 |
39 | class APIAuthenticationError(APIStatusError):
40 | ...
41 |
42 |
43 | class APIReachLimitError(APIStatusError):
44 | ...
45 |
46 |
47 | class APIInternalError(APIStatusError):
48 | ...
49 |
50 |
51 | class APIServerFlowExceedError(APIStatusError):
52 | ...
53 |
54 |
55 | class APIResponseError(ZhipuAIError):
56 | message: str
57 | request: httpx.Request
58 | json_data: object
59 |
60 | def __init__(self, message: str, request: httpx.Request, json_data: object):
61 | self.message = message
62 | self.request = request
63 | self.json_data = json_data
64 | super().__init__(message)
65 |
66 |
67 | class APIResponseValidationError(APIResponseError):
68 | status_code: int
69 | response: httpx.Response
70 |
71 | def __init__(
72 | self,
73 | response: httpx.Response,
74 | json_data: object | None, *,
75 | message: str | None = None
76 | ) -> None:
77 | super().__init__(
78 | message=message or "Data returned by API invalid for expected schema.",
79 | request=response.request,
80 | json_data=json_data
81 | )
82 | self.response = response
83 | self.status_code = response.status_code
84 |
85 |
86 | class APIConnectionError(APIResponseError):
87 | def __init__(self, *, message: str = "Connection error.", request: httpx.Request) -> None:
88 | super().__init__(message, request, json_data=None)
89 |
90 |
91 | class APITimeoutError(APIConnectionError):
92 | def __init__(self, request: httpx.Request) -> None:
93 | super().__init__(message="Request timed out.", request=request)
94 |
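The integration tests in this repository catch these exceptions from most specific to most general; a minimal sketch of the same pattern around a chat call:

    import zhipuai
    from zhipuai import ZhipuAI

    client = ZhipuAI()
    try:
        response = client.chat.completions.create(
            model='glm-4',
            messages=[{'role': 'user', 'content': 'hello'}],
        )
        print(response)
    except zhipuai.core._errors.APIRequestFailedError as err:
        print(err)
    except zhipuai.core._errors.APIInternalError as err:
        print(err)
    except zhipuai.core._errors.APIStatusError as err:
        print(err)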
--------------------------------------------------------------------------------
/tests/integration_tests/test_agents.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import logging.config
3 | import time
4 |
5 | import zhipuai
6 | from zhipuai import ZhipuAI
7 |
8 |
9 | def test_completions_sync(logging_conf):
10 | logging.config.dictConfig(logging_conf) # type: ignore
11 | client = ZhipuAI() # Fill in your own API key
12 | try:
13 | # Generate a request_id
14 | request_id = time.time()
15 | print(f'request_id:{request_id}')
16 | response = client.agents.invoke(
17 | request_id=request_id,
18 | agent_id='general_translation',
19 | messages=[{'role': 'user', 'content': 'tell me a joke'}],
20 | user_id='12345678',
21 | )
22 | print(response)
23 |
24 | except zhipuai.core._errors.APIRequestFailedError as err:
25 | print(err)
26 | except zhipuai.core._errors.APIInternalError as err:
27 | print(err)
28 | except zhipuai.core._errors.APIStatusError as err:
29 | print(err)
30 |
31 |
32 | def test_completions_stream(logging_conf):
33 | logging.config.dictConfig(logging_conf) # type: ignore
34 | client = ZhipuAI() # Fill in your own API key
35 | try:
36 | # Generate a request_id
37 | request_id = time.time()
38 | print(f'request_id:{request_id}')
39 | response = client.agents.invoke(
40 | request_id=request_id,
41 | agent_id='general_translation',
42 | messages=[{'role': 'user', 'content': 'tell me a joke'}],
43 | user_id='12345678',
44 | stream=True,
45 | )
46 | for item in response:
47 | print(item)
48 |
49 | except zhipuai.core._errors.APIRequestFailedError as err:
50 | print(err)
51 | except zhipuai.core._errors.APIInternalError as err:
52 | print(err)
53 | except zhipuai.core._errors.APIStatusError as err:
54 | print(err)
55 |
56 | def test_correction():
57 | client = ZhipuAI()  # Replace with your actual API key
58 |
59 | response = client.agents.invoke(
60 | agent_id="intelligent_education_correction_agent",
61 | messages=[
62 | {
63 | "role": "user",
64 | "content": [
65 | {
66 | "type": "image_url",
67 | "image_url": "https://b0.bdstatic.com/e24937f1f6b9c0ff6895e1012c981515.jpg"
68 | }
69 | ]
70 | }
71 | ]
72 | )
73 | print(response)
74 |
75 | def test_correction_result(image_id,uuids,trace_id):
76 | client = ZhipuAI()
77 |
78 | response = client.agents.async_result(
79 | agent_id="intelligent_education_correction_polling",
80 | custom_variables={
81 | "images": [
82 | {
83 | "image_id": image_id,
84 | "uuids": uuids
85 | }
86 | ],
87 | "trace_id": trace_id
88 | }
89 | )
90 | print(response)
91 |
92 | def main():
93 | test_completions_sync()
94 | test_completions_stream()
95 | # test_correction()
96 | # test_correction_result(image_id,uuids,trace_id)
97 |
98 | if __name__ == "__main__":
99 | main()
--------------------------------------------------------------------------------
/tests/integration_tests/test_charglm3.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import logging
4 | import logging.config
5 |
6 | import zhipuai
7 | from zhipuai import ZhipuAI
8 |
9 |
10 | def test_completions_charglm(logging_conf):
11 | logging.config.dictConfig(logging_conf) # type: ignore
12 | client = ZhipuAI() # Please fill in your own API key
13 | try:
14 | response = client.chat.completions.create(
15 | model='charglm-3', # Fill in the model name to call
16 | messages=[{'role': 'user', 'content': '请问你在做什么'}],
17 | meta={
18 | 'user_info': '我是陆星辰,是一个男性,是一位知名导演,也是苏梦远的合作导演。我擅长拍摄音乐题材的电影。苏梦远对我的态度是尊敬的,并视我为良师益友。',
19 | 'bot_info': '苏梦远,本名苏远心,是一位当红的国内女歌手及演员。在参加选秀节目后,凭借独特的嗓音及出众的舞台魅力迅速成名,进入娱乐圈。她外表美丽动人,但真正的魅力在于她的才华和勤奋。苏梦远是音乐学院毕业的优秀生,善于创作,拥有多首热门原创歌曲。除了音乐方面的成就,她还热衷于慈善事业,积极参加公益活动,用实际行动传递正能量。在工作中,她对待工作非常敬业,拍戏时总是全身心投入角色,赢得了业内人士的赞誉和粉丝的喜爱。虽然在娱乐圈,但她始终保持低调、谦逊的态度,深得同行尊重。在表达时,苏梦远喜欢使用“我们”和“一起”,强调团队精神。',
20 | 'bot_name': '苏梦远',
21 | 'user_name': '陆星辰',
22 | },
23 | user_id='12345678',
24 | )
25 | print(response)
26 |
27 | except zhipuai.core._errors.APIRequestFailedError as err:
28 | print(err)
29 | except zhipuai.core._errors.APIInternalError as err:
30 | print(err)
31 | except zhipuai.core._errors.APIStatusError as err:
32 | print(err)
33 |
34 |
35 | def test_async_completions():
36 | client = ZhipuAI() # Please fill in your own API key
37 | try:
38 | response = client.chat.asyncCompletions.create(
39 | model='charglm', # Fill in the model name to call
40 | messages=[{'role': 'user', 'content': '请问你在做什么'}],
41 | meta={
42 | 'user_info': '我是陆星辰,是一个男性,是一位知名导演,也是苏梦远的合作导演。我擅长拍摄音乐题材的电影。苏梦远对我的态度是尊敬的,并视我为良师益友。',
43 | 'bot_info': '苏梦远,本名苏远心,是一位当红的国内女歌手及演员。在参加选秀节目后,凭借独特的嗓音及出众的舞台魅力迅速成名,进入娱乐圈。她外表美丽动人,但真正的魅力在于她的才华和勤奋。苏梦远是音乐学院毕业的优秀生,善于创作,拥有多首热门原创歌曲。除了音乐方面的成就,她还热衷于慈善事业,积极参加公益活动,用实际行动传递正能量。在工作中,她对待工作非常敬业,拍戏时总是全身心投入角色,赢得了业内人士的赞誉和粉丝的喜爱。虽然在娱乐圈,但她始终保持低调、谦逊的态度,深得同行尊重。在表达时,苏梦远喜欢使用“我们”和“一起”,强调团队精神。',
44 | 'bot_name': '苏梦远',
45 | 'user_name': '陆星辰',
46 | },
47 | )
48 | print(response)
49 |
50 | except zhipuai.core._errors.APIRequestFailedError as err:
51 | print(err)
52 | except zhipuai.core._errors.APIInternalError as err:
53 | print(err)
54 | except zhipuai.core._errors.APIStatusError as err:
55 | print(err)
56 |
57 |
58 | # def test_retrieve_completion_result():
59 | # client = ZhipuAI() # 请填写您自己的APIKey
60 | # try:
61 | # response = client.chat.asyncCompletions.retrieve_completion_result(id="1014908592669352541650991")
62 | # print(response)
63 | #
64 | #
65 | # except zhipuai.core._errors.APIRequestFailedError as err:
66 | # print(err)
67 | # except zhipuai.core._errors.APIInternalError as err:
68 | # print(err)
69 |
70 | # if __name__ == "__main__":
71 | # test_retrieve_completion_result()
72 |
--------------------------------------------------------------------------------
/zhipuai/core/_files.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import io
4 | import os
5 | import pathlib
6 | from typing import Mapping, Sequence, overload
7 | from typing_extensions import TypeGuard
8 |
9 | from ._base_type import (
10 | FileTypes,
11 | HttpxFileTypes,
12 | HttpxRequestFiles,
13 | RequestFiles,
14 | Base64FileInput, FileContent, HttpxFileContent,
15 | )
16 | from ._utils import is_tuple_t, is_mapping_t, is_sequence_t
17 |
18 |
19 | def is_base64_file_input(obj: object) -> TypeGuard[Base64FileInput]:
20 | return isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike)
21 |
22 |
23 | def is_file_content(obj: object) -> TypeGuard[FileContent]:
24 | return (
25 | isinstance(obj, bytes) or isinstance(obj, tuple)
26 | or isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike)
27 | )
28 |
29 |
30 | def assert_is_file_content(obj: object, *, key: str | None = None) -> None:
31 | if not is_file_content(obj):
32 | prefix = f"Expected entry at `{key}`" if key is not None else f"Expected file input `{obj!r}`"
33 | raise RuntimeError(
34 | f"{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(obj)} instead. See https://github.com/openai/openai-python/tree/main#file-uploads"
35 | ) from None
36 |
37 |
38 | @overload
39 | def to_httpx_files(files: None) -> None:
40 | ...
41 |
42 |
43 | @overload
44 | def to_httpx_files(files: RequestFiles) -> HttpxRequestFiles:
45 | ...
46 |
47 |
48 | def to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None:
49 | if files is None:
50 | return None
51 |
52 | if is_mapping_t(files):
53 | files = {key: _transform_file(file) for key, file in files.items()}
54 | elif is_sequence_t(files):
55 | files = [(key, _transform_file(file)) for key, file in files]
56 | else:
57 | raise TypeError(f"Unexpected file type input {type(files)}, expected mapping or sequence")
58 |
59 | return files
60 |
61 |
62 | def _transform_file(file: FileTypes) -> HttpxFileTypes:
63 | if is_file_content(file):
64 | if isinstance(file, os.PathLike):
65 | path = pathlib.Path(file)
66 | return (path.name, path.read_bytes())
67 |
68 | return file
69 |
70 | if is_tuple_t(file):
71 | return (file[0], _read_file_content(file[1]), *file[2:])
72 |
73 | raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple")
74 |
75 |
76 | def _read_file_content(file: FileContent) -> HttpxFileContent:
77 | if isinstance(file, os.PathLike):
78 | return pathlib.Path(file).read_bytes()
79 | return file
80 |
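A small sketch of how to_httpx_files normalizes inputs: PathLike values are read into (filename, bytes) tuples, while bytes, open file objects, and explicit tuples pass through largely unchanged; the file names below are illustrative.

    import pathlib

    from zhipuai.core._files import to_httpx_files

    # Mapping form: field name -> file content.
    files = to_httpx_files({'file': pathlib.Path('demo.jsonl')})
    # -> {'file': ('demo.jsonl', b'...file bytes...')}

    # Sequence form: list of (field name, file content) pairs.
    files = to_httpx_files([('file', b'{"prompt": "hi"}\n')])
    print(files)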
--------------------------------------------------------------------------------
/tests/integration_tests/test_audio.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import json
3 | import logging
4 | import logging.config
5 | from pathlib import Path
6 |
7 | import zhipuai
8 | from zhipuai import ZhipuAI
9 |
10 |
11 | def test_audio_speech(logging_conf):
12 | logging.config.dictConfig(logging_conf) # type: ignore
13 | client = ZhipuAI() # Fill in your own API key
14 | try:
15 | speech_file_path = Path(__file__).parent / 'speech.wav'
16 | response = client.audio.speech(
17 | model='cogtts',
18 | input='你好呀,欢迎来到智谱开放平台',
19 | voice='tongtong',
20 | stream=False,
21 | response_format='wav',
22 | )
23 | response.stream_to_file(speech_file_path)
24 |
25 | except zhipuai.core._errors.APIRequestFailedError as err:
26 | print(err)
27 | except zhipuai.core._errors.APIInternalError as err:
28 | print(err)
29 | except zhipuai.core._errors.APIStatusError as err:
30 | print(err)
31 |
32 | def test_audio_speech_streaming(logging_conf):
33 | logging.config.dictConfig(logging_conf) # type: ignore
34 | client = ZhipuAI() # Fill in your own API key
35 | try:
36 | response = client.audio.speech(
37 | model='cogtts',
38 | input='你好呀,欢迎来到智谱开放平台',
39 | voice='tongtong',
40 | stream=True,
41 | response_format='mp3',
42 | encode_format='hex'
43 | )
44 | with open("output.mp3", "wb") as f:
45 | for item in response:
46 | choice = item.choices[0]
47 | index = choice.index
48 | finish_reason = choice.finish_reason
49 | audio_delta = choice.delta.content
50 | if finish_reason is not None:
51 | break
52 | f.write(bytes.fromhex(audio_delta))
53 | print(f"audio delta: {audio_delta[:64]}..., 长度:{len(audio_delta)}")
54 | except zhipuai.core._errors.APIRequestFailedError as err:
55 | print(err)
56 | except zhipuai.core._errors.APIInternalError as err:
57 | print(err)
58 | except zhipuai.core._errors.APIStatusError as err:
59 | print(err)
60 | except Exception as e:
61 | print(e)
62 |
63 |
64 | def test_audio_customization(logging_conf):
65 | logging.config.dictConfig(logging_conf)
66 | client = ZhipuAI() # Fill in your own API key
67 | with open(Path(__file__).parent / 'speech.wav', 'rb') as file:
68 | try:
69 | speech_file_path = Path(__file__).parent / 'speech.wav'
70 | response = client.audio.customization(
71 | model='cogtts',
72 | input='你好呀,欢迎来到智谱开放平台',
73 | voice_text='这是一条测试用例',
74 | voice_data=file,
75 | response_format='wav',
76 | )
77 | response.stream_to_file(speech_file_path)
78 |
79 | except zhipuai.core._errors.APIRequestFailedError as err:
80 | print(err)
81 | except zhipuai.core._errors.APIInternalError as err:
82 | print(err)
83 | except zhipuai.core._errors.APIStatusError as err:
84 | print(err)
85 |
--------------------------------------------------------------------------------
/zhipuai/api_resource/web_search/web_search.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import TYPE_CHECKING, List, Union, Dict, Optional
4 | from ...types.sensitive_word_check import SensitiveWordCheckRequest
5 | from ...core import NOT_GIVEN, Body, Headers, NotGiven, BaseAPI, maybe_transform, StreamResponse, deepcopy_minimal
6 |
7 | import httpx
8 |
9 | from ...core import (
10 | make_request_options,
11 | )
12 | import logging
13 |
14 | from ...types.web_search import web_search_create_params
15 | from ...types.web_search.web_search_resp import WebSearchResp
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 | if TYPE_CHECKING:
20 | from ..._client import ZhipuAI
21 |
22 | __all__ = ["WebSearchApi"]
23 |
24 |
25 | class WebSearchApi(BaseAPI):
26 | def __init__(self, client: "ZhipuAI") -> None:
27 | super().__init__(client)
28 |
29 | def web_search(
30 | self,
31 | *,
32 | request_id: Optional[str] | NotGiven = NOT_GIVEN,
33 | search_engine: Optional[str] | NotGiven = NOT_GIVEN,
34 | search_query: Optional[str] | NotGiven = NOT_GIVEN,
35 | user_id: Optional[str] | NotGiven = NOT_GIVEN,
36 | sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN,
37 | count: Optional[int] | NotGiven = NOT_GIVEN,
38 | search_domain_filter: Optional[str] | NotGiven = NOT_GIVEN,
39 | search_recency_filter: Optional[str] | NotGiven = NOT_GIVEN,
40 | content_size: Optional[str] | NotGiven = NOT_GIVEN,
41 | search_intent: Optional[bool] | NotGiven = NOT_GIVEN,
42 | location: Optional[str] | NotGiven = NOT_GIVEN,
43 | extra_headers: Headers | None = None,
44 | extra_body: Body | None = None,
45 | timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
46 | ) -> WebSearchResp:
47 |
48 | body = deepcopy_minimal(
49 | {
50 | "request_id": request_id,
51 | "search_engine": search_engine,
52 | "search_query": search_query,
53 | "user_id": user_id,
54 | "sensitive_word_check": sensitive_word_check,
55 | "count":count,
56 | "search_domain_filter": search_domain_filter,
57 | "search_recency_filter": search_recency_filter,
58 | "content_size": content_size,
59 | "search_intent": search_intent,
60 | "location": location
61 | })
62 | return self._post(
63 | "/web_search",
64 | body=maybe_transform(body, web_search_create_params.WebSearchCreatParams),
65 | options=make_request_options(
66 | extra_headers=extra_headers, extra_body=extra_body, timeout=timeout
67 | ),
68 | cast_type=WebSearchResp
69 | )
70 |
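A usage sketch, assuming the resource is mounted as client.web_search and that 'search_std' is a valid search_engine value (both assumptions):

    from zhipuai import ZhipuAI

    client = ZhipuAI()
    resp = client.web_search.web_search(
        search_engine='search_std',  # hypothetical engine name
        search_query='zhipuai glm-4 release notes',
        count=5,
    )
    print(resp)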
--------------------------------------------------------------------------------
/zhipuai/core/__init__.py:
--------------------------------------------------------------------------------
1 | from ._base_models import (
2 | BaseModel,
3 | construct_type
4 | )
5 | from ._base_api import BaseAPI
6 | from ._base_type import (
7 | NOT_GIVEN,
8 | Headers,
9 | NotGiven,
10 | Body,
11 | IncEx,
12 | ModelT,
13 | Query,
14 | FileTypes,
15 |
16 | )
17 | from ._base_compat import (
18 | PYDANTIC_V2,
19 | ConfigDict,
20 | GenericModel,
21 | get_args,
22 | is_union,
23 | parse_obj,
24 | get_origin,
25 | is_literal_type,
26 | get_model_config,
27 | get_model_fields,
28 | field_get_default,
29 | cached_property,
30 | )
31 | from ._files import (
32 | is_file_content
33 | )
34 | from ._errors import (
35 | ZhipuAIError,
36 | APIStatusError,
37 | APIRequestFailedError,
38 | APIAuthenticationError,
39 | APIReachLimitError,
40 | APIInternalError,
41 | APIServerFlowExceedError,
42 | APIResponseError,
43 | APIResponseValidationError,
44 | APIConnectionError,
45 | APITimeoutError,
46 | )
47 | from ._http_client import (
48 | make_request_options,
49 | HttpClient
50 |
51 | )
52 | from ._utils import (
53 | is_list,
54 | is_mapping,
55 | parse_date,
56 | parse_datetime,
57 | is_given,
58 | maybe_transform,
59 | deepcopy_minimal,
60 | extract_files,
61 | drop_prefix_image_data,
62 | )
63 |
64 | from ._sse_client import StreamResponse
65 |
66 | from ._constants import (
67 |
68 | ZHIPUAI_DEFAULT_TIMEOUT,
69 | ZHIPUAI_DEFAULT_MAX_RETRIES,
70 | ZHIPUAI_DEFAULT_LIMITS,
71 | )
72 | __all__ = [
73 | "BaseModel",
74 | "construct_type",
75 | "BaseAPI",
76 | "NOT_GIVEN",
77 | "Headers",
78 | "NotGiven",
79 | "Body",
80 | "IncEx",
81 | "ModelT",
82 | "Query",
83 | "FileTypes",
84 |
85 | "PYDANTIC_V2",
86 | "ConfigDict",
87 | "GenericModel",
88 | "get_args",
89 | "is_union",
90 | "parse_obj",
91 | "get_origin",
92 | "is_literal_type",
93 | "get_model_config",
94 | "get_model_fields",
95 | "field_get_default",
96 |
97 | "is_file_content",
98 |
99 | "ZhipuAIError",
100 | "APIStatusError",
101 | "APIRequestFailedError",
102 | "APIAuthenticationError",
103 | "APIReachLimitError",
104 | "APIInternalError",
105 | "APIServerFlowExceedError",
106 | "APIResponseError",
107 | "APIResponseValidationError",
108 | "APITimeoutError",
109 |
110 | "make_request_options",
111 | "HttpClient",
112 | "ZHIPUAI_DEFAULT_TIMEOUT",
113 | "ZHIPUAI_DEFAULT_MAX_RETRIES",
114 | "ZHIPUAI_DEFAULT_LIMITS",
115 |
116 | "is_list",
117 | "is_mapping",
118 | "parse_date",
119 | "parse_datetime",
120 | "is_given",
121 | "maybe_transform",
122 |
123 | "deepcopy_minimal",
124 | "extract_files",
125 |
126 | "StreamResponse",
127 |
128 | ]
129 |
--------------------------------------------------------------------------------
/zhipuai/core/logs.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import time
4 |
5 | logger = logging.getLogger(__name__)
6 |
7 |
8 | class LoggerNameFilter(logging.Filter):
9 | def filter(self, record):
10 | # return record.name.startswith("loom_core") or record.name in "ERROR" or (
11 | # record.name.startswith("uvicorn.error")
12 | # and record.getMessage().startswith("Uvicorn running on")
13 | # )
14 | return True
15 |
16 |
17 | def get_log_file(log_path: str, sub_dir: str):
18 | """
19 | sub_dir should contain a timestamp.
20 | """
21 | log_dir = os.path.join(log_path, sub_dir)
22 | # A new directory should be created each time, so `exist_ok=False`.
23 | os.makedirs(log_dir, exist_ok=False)
24 | return os.path.join(log_dir, "zhipuai.log")
25 |
26 |
27 | def get_config_dict(
28 | log_level: str, log_file_path: str, log_backup_count: int, log_max_bytes: int
29 | ) -> dict:
30 | # for windows, the path should be a raw string.
31 | log_file_path = (
32 | log_file_path.encode("unicode-escape").decode()
33 | if os.name == "nt"
34 | else log_file_path
35 | )
36 | log_level = log_level.upper()
37 | config_dict = {
38 | "version": 1,
39 | "disable_existing_loggers": False,
40 | "formatters": {
41 | "formatter": {
42 | "format": (
43 | "%(asctime)s %(name)-12s %(process)d %(levelname)-8s %(message)s"
44 | )
45 | },
46 | },
47 | "filters": {
48 | "logger_name_filter": {
49 | "()": __name__ + ".LoggerNameFilter",
50 | },
51 | },
52 | "handlers": {
53 | "stream_handler": {
54 | "class": "logging.StreamHandler",
55 | "formatter": "formatter",
56 | "level": log_level,
57 | # "stream": "ext://sys.stdout",
58 | # "filters": ["logger_name_filter"],
59 | },
60 | "file_handler": {
61 | "class": "logging.handlers.RotatingFileHandler",
62 | "formatter": "formatter",
63 | "level": log_level,
64 | "filename": log_file_path,
65 | "mode": "a",
66 | "maxBytes": log_max_bytes,
67 | "backupCount": log_backup_count,
68 | "encoding": "utf8",
69 | },
70 | },
71 | "loggers": {
72 | "loom_core": {
73 | "handlers": ["stream_handler", "file_handler"],
74 | "level": log_level,
75 | "propagate": False,
76 | }
77 | },
78 | "root": {
79 | "level": log_level,
80 | "handlers": ["stream_handler", "file_handler"],
81 | },
82 | }
83 | return config_dict
84 |
85 |
86 | def get_timestamp_ms():
87 | t = time.time()
88 | return int(round(t * 1000))
89 |
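A sketch of wiring these helpers into the standard logging machinery, which is the same mechanism the integration tests use via their logging_conf fixture; the paths below are illustrative:

    import logging
    import logging.config

    from zhipuai.core.logs import get_config_dict, get_log_file, get_timestamp_ms

    # get_log_file creates the sub-directory with exist_ok=False, so sub_dir must be unique.
    log_file = get_log_file(log_path='logs', sub_dir=f'run_{get_timestamp_ms()}')
    config = get_config_dict(
        log_level='info',
        log_file_path=log_file,
        log_backup_count=3,
        log_max_bytes=10 * 1024 * 1024,
    )
    logging.config.dictConfig(config)
    logging.getLogger(__name__).info('logging configured')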
--------------------------------------------------------------------------------
/zhipuai/api_resource/audio/transcriptions.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import TYPE_CHECKING, List, Mapping, cast, Optional, Dict
4 | from ...types.audio import transcriptions_create_param
5 |
6 | import httpx
7 | import logging
8 | from typing_extensions import Literal
9 |
10 | from ...core import BaseAPI, deepcopy_minimal, maybe_transform, drop_prefix_image_data
11 | from ...core import make_request_options
12 | from ...core import StreamResponse
13 | from ...types.chat.chat_completion import Completion
14 | from ...types.chat.chat_completion_chunk import ChatCompletionChunk
15 | from ...types.sensitive_word_check import SensitiveWordCheckRequest
16 | from ...core import NOT_GIVEN, Body, Headers, NotGiven, FileTypes
17 | from zhipuai.core._utils import extract_files
18 |
19 |
20 | logger = logging.getLogger(__name__)
21 |
22 | if TYPE_CHECKING:
23 | from ..._client import ZhipuAI
24 |
25 |
26 | __all__ = ["Transcriptions"]
27 |
28 | class Transcriptions(BaseAPI):
29 | def __init__(self, client: "ZhipuAI") -> None:
30 | super().__init__(client)
31 |
32 | def create(
33 | self,
34 | *,
35 | file: FileTypes,
36 | model: str,
37 | request_id: Optional[str] | NotGiven = NOT_GIVEN,
38 | user_id: Optional[str] | NotGiven = NOT_GIVEN,
39 | stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
40 | temperature: Optional[float] | NotGiven = NOT_GIVEN,
41 | sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN,
42 | extra_headers: Headers | None = None,
43 | extra_body: Body | None = None,
44 | timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN
45 | ) -> Completion | StreamResponse[ChatCompletionChunk]:
46 | if temperature is not None and temperature != NOT_GIVEN:
47 | if temperature <= 0:
48 | temperature = 0.01
49 | if temperature >= 1:
50 | temperature = 0.99
51 |
52 | body = deepcopy_minimal({
53 | "model": model,
54 | "file": file,
55 | "request_id": request_id,
56 | "user_id": user_id,
57 | "temperature": temperature,
58 | "sensitive_word_check": sensitive_word_check,
59 | "stream": stream
60 | })
61 | files = extract_files(cast(Mapping[str, object], body), paths=[["file"]])
62 | if files:
63 | extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
64 | return self._post(
65 | "/audio/transcriptions",
66 | body=maybe_transform(body, transcriptions_create_param.TranscriptionsParam),
67 | files=files,
68 | options=make_request_options(
69 | extra_headers=extra_headers, extra_body=extra_body, timeout=timeout
70 | ),
71 | cast_type=Completion,
72 | stream=stream or False,
73 | stream_cls=StreamResponse[ChatCompletionChunk],
74 | )
75 |
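A usage sketch, assuming the resource is reachable as client.audio.transcriptions and that 'glm-asr' is a valid model name (both assumptions); the file is sent as multipart/form-data, as handled above:

    from zhipuai import ZhipuAI

    client = ZhipuAI()
    with open('speech.wav', 'rb') as f:
        result = client.audio.transcriptions.create(
            model='glm-asr',  # hypothetical model name
            file=f,
            stream=False,
        )
    print(result)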
--------------------------------------------------------------------------------
/tests/integration_tests/test_assistant.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import logging.config
3 |
4 | import zhipuai
5 | from zhipuai import ZhipuAI
6 |
7 |
8 | def test_assistant(logging_conf) -> None:
9 | logging.config.dictConfig(logging_conf) # type: ignore
10 | client = ZhipuAI() # Fill in your own API key
11 | try:
12 | generate = client.assistant.conversation(
13 | assistant_id='659e54b1b8006379b4b2abd6',
14 | messages=[
15 | {
16 | 'role': 'user',
17 | 'content': [{'type': 'text', 'text': '帮我搜索下智谱的cogvideox发布时间'}],
18 | }
19 | ],
20 | stream=True,
21 | attachments=None,
22 | metadata=None,
23 | request_id='request_1790291013237211136',
24 | user_id='12345678',
25 | )
26 | for assistant in generate:
27 | print(assistant)
28 |
29 | except zhipuai.core._errors.APIRequestFailedError as err:
30 | print(err)
31 | except zhipuai.core._errors.APIInternalError as err:
32 | print(err)
33 | except zhipuai.core._errors.APIStatusError as err:
34 | print(err)
35 |
36 |
37 | def test_assistant_query_support(logging_conf) -> None:
38 | logging.config.dictConfig(logging_conf) # type: ignore
39 | client = ZhipuAI() # Fill in your own API key
40 | try:
41 | response = client.assistant.query_support(
42 | assistant_id_list=[],
43 | request_id='request_1790291013237211136',
44 | user_id='12345678',
45 | )
46 | print(response)
47 |
48 | except zhipuai.core._errors.APIRequestFailedError as err:
49 | print(err)
50 | except zhipuai.core._errors.APIInternalError as err:
51 | print(err)
52 | except zhipuai.core._errors.APIStatusError as err:
53 | print(err)
54 |
55 |
56 | def test_assistant_query_conversation_usage(logging_conf) -> None:
57 | logging.config.dictConfig(logging_conf) # type: ignore
58 | client = ZhipuAI() # 填写您自己的APIKey
59 | try:
60 | response = client.assistant.query_conversation_usage(
61 | assistant_id='659e54b1b8006379b4b2abd6',
62 | request_id='request_1790291013237211136',
63 | user_id='12345678',
64 | )
65 | print(response)
66 | except zhipuai.core._errors.APIRequestFailedError as err:
67 | print(err)
68 | except zhipuai.core._errors.APIInternalError as err:
69 | print(err)
70 | except zhipuai.core._errors.APIStatusError as err:
71 | print(err)
72 |
73 |
74 | def test_translate_api(logging_conf) -> None:
75 | logging.config.dictConfig(logging_conf) # type: ignore
76 | client = ZhipuAI() # Fill in your own API key
77 | try:
78 | translate_response = client.assistant.conversation(
79 | assistant_id='9996ijk789lmn012o345p999',
80 | messages=[{'role': 'user', 'content': [{'type': 'text', 'text': '你好呀'}]}],
81 | stream=True,
82 | attachments=None,
83 | metadata=None,
84 | request_id='request_1790291013237211136',
85 | user_id='12345678',
86 | extra_parameters={'translate': {'from': 'zh', 'to': 'en'}},
87 | )
88 | for chunk in translate_response:
89 | print(chunk.choices[0].delta)
90 | # print(translate_response)
91 | except zhipuai.core._errors.APIRequestFailedError as err:
92 | print(err)
93 | except zhipuai.core._errors.APIInternalError as err:
94 | print(err)
95 | except zhipuai.core._errors.APIStatusError as err:
96 | print(err)
97 |
--------------------------------------------------------------------------------
/zhipuai/api_resource/videos/videos.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 |
4 | from typing import TYPE_CHECKING, List, Mapping, cast, Optional, Dict
5 | from typing_extensions import Literal
6 |
7 | from ...types.sensitive_word_check import SensitiveWordCheckRequest
8 | from ...types.video import video_create_params
9 | from ...types.video import VideoObject
10 | from ...core import BaseAPI, maybe_transform
11 | from ...core import NOT_GIVEN, Body, Headers, NotGiven
12 |
13 | import httpx
14 |
15 | from ...core import (
16 | make_request_options,
17 | )
18 | from ...core import deepcopy_minimal, extract_files
19 |
20 | if TYPE_CHECKING:
21 | from ..._client import ZhipuAI
22 |
23 | __all__ = ["Videos"]
24 |
25 |
26 | class Videos(BaseAPI):
27 |
28 | def __init__(self, client: "ZhipuAI") -> None:
29 | super().__init__(client)
30 |
31 | def generations(
32 | self,
33 | model: str,
34 | *,
35 | prompt: str = None,
36 | image_url: object = None,
37 | quality: str = None,
38 | with_audio: bool = None,
39 | size: str = None,
40 | duration: int = None,
41 | fps: int = None,
42 | sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN,
43 | request_id: str = None,
44 | user_id: str = None,
45 | extra_headers: Headers | None = None,
46 | extra_body: Body | None = None,
47 | timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
48 | ) -> VideoObject:
49 |
50 | if not model and not prompt:
51 | raise ValueError("At least one of `model` and `prompt` must be provided.")
52 | body = deepcopy_minimal(
53 | {
54 | "model": model,
55 | "prompt": prompt,
56 | "image_url": image_url,
57 | "sensitive_word_check": sensitive_word_check,
58 | "request_id": request_id,
59 | "user_id": user_id,
60 | "quality": quality,
61 | "with_audio": with_audio,
62 | "size": size,
63 | "duration": duration,
64 | "fps": fps
65 | }
66 | )
67 | return self._post(
68 | "/videos/generations",
69 | body=maybe_transform(body, video_create_params.VideoCreateParams),
70 | options=make_request_options(
71 | extra_headers=extra_headers, extra_body=extra_body, timeout=timeout
72 | ),
73 | cast_type=VideoObject,
74 | )
75 |
76 | def retrieve_videos_result(
77 | self,
78 | id: str,
79 | *,
80 | extra_headers: Headers | None = None,
81 | extra_body: Body | None = None,
82 | timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
83 | ) -> VideoObject:
84 |
85 | if not id:
86 | raise ValueError("At least one of `id` must be provided.")
87 |
88 | return self._get(
89 | f"/async-result/{id}",
90 | options=make_request_options(
91 | extra_headers=extra_headers, extra_body=extra_body, timeout=timeout
92 | ),
93 | cast_type=VideoObject,
94 | )
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "zhipuai"
3 | version = "2.1.5.20250825"
4 | description = "An SDK library for accessing big model APIs from ZhipuAI"
5 | authors = ["Zhipu AI"]
6 | readme = "README.md"
7 |
8 | [tool.poetry.dependencies]
9 | python = ">=3.8,<4.0.0,!=3.9.7"
10 | httpx = ">=0.23.0"
11 | pydantic = ">=1.9.0,<3.0"
12 | pydantic-core = ">=2.14.6"
13 | cachetools = ">=4.2.2"
14 | pyjwt = "~=2.8.0"
15 |
16 |
17 | [tool.poetry.group.test.dependencies]
18 | # The only dependencies that should be added are
19 | # dependencies used for running tests (e.g., pytest, freezegun, responses).
20 | # Any dependencies that do not meet that criteria will be removed.
21 | pytest = "^7.3.0"
22 | pytest-cov = "^4.0.0"
23 | pytest-dotenv = "^0.5.2"
24 | duckdb-engine = "^0.9.2"
25 | pytest-watcher = "^0.2.6"
26 | freezegun = "^1.2.2"
27 | responses = "^0.22.0"
28 | pytest-asyncio = { version = "^0.23.2", python = "^3.8" }
29 | lark = "^1.1.5"
30 | pytest-mock = "^3.10.0"
31 | pytest-socket = { version = "^0.6.0", python = ">=3.8.1,<3.9.7 || >3.9.7,<4.0" }
32 | syrupy = { version = "^4.0.2", python = ">=3.8.1,<3.9.7 || >3.9.7,<4.0" }
33 | requests-mock = "^1.11.0"
34 | respx = "0.21.1"
35 |
36 |
37 | [tool.poetry.group.lint]
38 | optional = true
39 |
40 | [tool.poetry.group.lint.dependencies]
41 | ruff = "^0.1.5"
42 |
43 | [tool.poetry.extras]
44 | cli = ["typer"]
45 | # An extra used to enable extended testing.
46 | # Please put each package on its own line to make it easier to add new packages
47 | # without merge conflicts.
48 | extended_testing = [
49 | "langchain",
50 | ]
51 |
52 | [tool.ruff.lint]
53 | select = [
54 | "E", # pycodestyle
55 | "F", # pyflakes
56 | "I", # isort
57 | "T201", # print
58 | ]
59 |
60 | [tool.ruff]
61 | line-length = 88
62 |
63 | [tool.ruff.format]
64 | quote-style = "single"
65 | indent-style = "tab"
66 | docstring-code-format = true
67 |
68 | [tool.coverage.run]
69 | omit = [
70 | "tests/*",
71 | ]
72 |
73 | [build-system]
74 | requires = ["poetry-core>=1.0.0", "poetry-plugin-pypi-mirror==0.4.2"]
75 | build-backend = "poetry.core.masonry.api"
76 |
77 | [tool.pytest.ini_options]
78 | # --strict-markers will raise errors on unknown marks.
79 | # https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
80 | #
81 | # https://docs.pytest.org/en/7.1.x/reference/reference.html
82 | # --strict-config any warnings encountered while parsing the `pytest`
83 | # section of the configuration file raise errors.
84 | #
85 | # https://github.com/tophat/syrupy
86 | # --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite.
87 | addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused -svv"
88 | # Registering custom markers.
89 | # https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
90 | markers = [
91 | "requires: mark tests as requiring a specific library",
92 | "scheduled: mark tests to run in scheduled testing",
93 | "compile: mark placeholder test used to compile integration tests without running them"
94 | ]
95 | asyncio_mode = "auto"
96 |
97 |
98 | # https://python-poetry.org/docs/repositories/
99 | #[[tool.poetry.source]]
100 | #name = "tsinghua"
101 | #url = "https://pypi.tuna.tsinghua.edu.cn/simple/"
102 | #priority = "default"
103 |
--------------------------------------------------------------------------------
/.github/workflows/_test_release.yml:
--------------------------------------------------------------------------------
1 | name: test-release
2 |
3 | on:
4 | workflow_call:
5 | inputs:
6 | working-directory:
7 | required: true
8 | type: string
9 | description: "From which folder this pipeline executes"
10 |
11 | env:
12 | POETRY_VERSION: "1.8.2"
13 | PYTHON_VERSION: "3.9"
14 |
15 | jobs:
16 | build:
17 | if: github.ref == 'refs/heads/main'
18 | runs-on: ubuntu-latest
19 |
20 | outputs:
21 | pkg-name: ${{ steps.check-version.outputs.pkg-name }}
22 | version: ${{ steps.check-version.outputs.version }}
23 |
24 | steps:
25 | - uses: actions/checkout@v4
26 |
27 | - name: Set up Python + Poetry ${{ env.POETRY_VERSION }}
28 | uses: "./.github/actions/poetry_setup"
29 | with:
30 | python-version: ${{ env.PYTHON_VERSION }}
31 | poetry-version: ${{ env.POETRY_VERSION }}
32 | working-directory: ${{ inputs.working-directory }}
33 | cache-key: release
34 |
35 | # We want to keep this build stage *separate* from the release stage,
36 | # so that there's no sharing of permissions between them.
37 | # The release stage has trusted publishing and GitHub repo contents write access,
38 | # and we want to keep the scope of that access limited just to the release job.
39 | # Otherwise, a malicious `build` step (e.g. via a compromised dependency)
40 | # could get access to our GitHub or PyPI credentials.
41 | #
42 | # Per the trusted publishing GitHub Action:
43 | # > It is strongly advised to separate jobs for building [...]
44 | # > from the publish job.
45 | # https://github.com/pypa/gh-action-pypi-publish#non-goals
46 | - name: Build project for distribution
47 | run: poetry build
48 | working-directory: ${{ inputs.working-directory }}
49 |
50 | - name: Upload build
51 | uses: actions/upload-artifact@v4
52 | with:
53 | name: test-dist
54 | path: ${{ inputs.working-directory }}/dist/
55 |
56 | - name: Check Version
57 | id: check-version
58 | shell: bash
59 | working-directory: ${{ inputs.working-directory }}
60 | run: |
61 | echo pkg-name="$(poetry version | cut -d ' ' -f 1)" >> $GITHUB_OUTPUT
62 | echo version="$(poetry version --short)" >> $GITHUB_OUTPUT
63 |
64 | publish:
65 | needs:
66 | - build
67 | runs-on: ubuntu-latest
68 | environment: Scheduled testing publish
69 | # permissions:
70 | # id-token: none # This is required for requesting the JWT
71 |
72 | steps:
73 | - uses: actions/checkout@v4
74 |
75 | - uses: actions/download-artifact@v4
76 | with:
77 | name: test-dist
78 | path: ${{ inputs.working-directory }}/dist/
79 |
80 | - name: Publish to test PyPI
81 | uses: pypa/gh-action-pypi-publish@release/v1
82 | with:
83 | user: __token__
84 | password: ${{ secrets.TEST_PYPI_API_TOKEN }}
85 | packages-dir: ${{ inputs.working-directory }}/dist/
86 | verbose: true
87 | print-hash: true
88 | repository-url: https://test.pypi.org/legacy/
89 |           # Tolerate duplicates: if a distribution with the same name and version
90 |           # already exists, skip it instead of failing. This is *only for CI use*!
91 | # https://github.com/pypa/gh-action-pypi-publish#tolerating-release-package-file-duplicates
92 | skip-existing: true
93 |
--------------------------------------------------------------------------------
/zhipuai/api_resource/agents/agents.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Union, List, Optional, TYPE_CHECKING, Dict
4 |
5 | import httpx
6 | import logging
7 | from typing_extensions import Literal
8 |
9 | from ...core import BaseAPI, deepcopy_minimal
10 | from ...core import NotGiven, NOT_GIVEN, Headers, Query, Body
11 | from ...core import make_request_options
12 | from ...core import StreamResponse
13 | from ...types.agents.agents_completion import AgentsCompletion
14 | from ...types.agents.agents_completion_chunk import AgentsCompletionChunk
15 | from ...types.sensitive_word_check import SensitiveWordCheckRequest
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 | if TYPE_CHECKING:
20 | from ..._client import ZhipuAI
21 |
22 |
23 | class Agents(BaseAPI):
24 |
25 | def __init__(self, client: "ZhipuAI") -> None:
26 | super().__init__(client)
27 |
28 | def invoke(
29 | self,
30 | agent_id: Optional[str] | NotGiven = NOT_GIVEN,
31 | request_id: Optional[str] | NotGiven = NOT_GIVEN,
32 | stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
33 | messages: Union[str, List[str], List[int], object, None] | NotGiven = NOT_GIVEN,
34 | user_id: Optional[str] | NotGiven = NOT_GIVEN,
35 | custom_variables: object = NOT_GIVEN,
36 | sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN,
37 | extra_headers: Headers | None = None,
38 | extra_body: Body | None = None,
39 | timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
40 | ) -> AgentsCompletion | StreamResponse[AgentsCompletionChunk]:
41 | body = deepcopy_minimal({
42 | "agent_id": agent_id,
43 | "request_id": request_id,
44 | "user_id": user_id,
45 | "messages": messages,
46 | "sensitive_word_check": sensitive_word_check,
47 | "stream": stream,
48 | "custom_variables": custom_variables
49 | })
50 |
51 | return self._post(
52 | "/v1/agents",
53 | body=body,
54 | options=make_request_options(
55 | extra_headers=extra_headers, extra_body=extra_body, timeout=timeout
56 | ),
57 | cast_type=AgentsCompletion,
58 | stream=stream or False,
59 | stream_cls=StreamResponse[AgentsCompletionChunk],
60 | )
61 |
62 | def async_result(
63 | self,
64 | agent_id: Optional[str] | NotGiven = NOT_GIVEN,
65 | async_id: Optional[str] | NotGiven = NOT_GIVEN,
66 | conversation_id: Optional[str] | NotGiven = NOT_GIVEN,
67 | custom_variables: object = NOT_GIVEN,
68 | extra_headers: Headers | None = None,
69 | extra_body: Body | None = None,
70 | timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
71 | ) -> AgentsCompletion:
72 | body = deepcopy_minimal({
73 | "agent_id": agent_id,
74 | "async_id": async_id,
75 | "conversation_id": conversation_id,
76 | "custom_variables": custom_variables
77 | })
78 | return self._post(
79 | "/v1/agents/async-result",
80 | body=body,
81 | options=make_request_options(
82 | extra_headers=extra_headers, extra_body=extra_body, timeout=timeout
83 | ),
84 | cast_type=AgentsCompletion,
85 | )
86 |
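
A hedged usage sketch for the resource above; the agent id and message shape are placeholders:

    from zhipuai import ZhipuAI

    client = ZhipuAI()

    # Blocking call: returns an AgentsCompletion.
    completion = client.agents.invoke(
        agent_id="your-agent-id",  # placeholder
        messages=[{"role": "user", "content": "Hello"}],
        stream=False,
    )
    print(completion)

    # Streaming call: yields AgentsCompletionChunk objects.
    for chunk in client.agents.invoke(
        agent_id="your-agent-id",
        messages=[{"role": "user", "content": "Hello"}],
        stream=True,
    ):
        print(chunk)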
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | """Configuration for unit tests."""
2 |
3 | from importlib import util
4 | from pathlib import Path
5 | from typing import Dict, Sequence
6 |
7 | import pytest
8 | from pytest import Config, Function, Parser
9 |
10 | from zhipuai.core.logs import (
11 | get_config_dict,
12 | get_log_file,
13 | get_timestamp_ms,
14 | )
15 |
16 |
17 | def pytest_addoption(parser: Parser) -> None:
18 | """Add custom command line options to pytest."""
19 | parser.addoption(
20 | '--only-extended',
21 | action='store_true',
22 | help='Only run extended tests. Does not allow skipping any extended tests.',
23 | )
24 | parser.addoption(
25 | '--only-core',
26 | action='store_true',
27 | help='Only run core tests. Never runs any extended tests.',
28 | )
29 |
30 |
31 | def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> None:
32 | """Add implementations for handling custom markers.
33 |
34 | At the moment, this adds support for a custom `requires` marker.
35 |
36 | The `requires` marker is used to denote tests that require one or more packages
37 | to be installed to run. If the package is not installed, the test is skipped.
38 |
39 | The `requires` marker syntax is:
40 |
41 | .. code-block:: python
42 |
43 | @pytest.mark.requires('package1', 'package2')
44 | def test_something(): ...
45 | """
46 | # Mapping from the name of a package to whether it is installed or not.
47 | # Used to avoid repeated calls to `util.find_spec`
48 | required_pkgs_info: Dict[str, bool] = {}
49 |
50 | only_extended = config.getoption('--only-extended') or False
51 | only_core = config.getoption('--only-core') or False
52 |
53 | if only_extended and only_core:
54 | raise ValueError('Cannot specify both `--only-extended` and `--only-core`.')
55 |
56 | for item in items:
57 | requires_marker = item.get_closest_marker('requires')
58 | if requires_marker is not None:
59 | if only_core:
60 | item.add_marker(pytest.mark.skip(reason='Skipping not a core test.'))
61 | continue
62 |
63 | # Iterate through the list of required packages
64 | required_pkgs = requires_marker.args
65 | for pkg in required_pkgs:
66 | # If we haven't yet checked whether the pkg is installed
67 | # let's check it and store the result.
68 | if pkg not in required_pkgs_info:
69 | try:
70 | installed = util.find_spec(pkg) is not None
71 | except Exception:
72 | installed = False
73 | required_pkgs_info[pkg] = installed
74 |
75 | if not required_pkgs_info[pkg]:
76 | if only_extended:
77 | pytest.fail(
78 | f'Package `{pkg}` is not installed but is required for '
79 | f'extended tests. Please install the given package and '
80 | f'try again.',
81 | )
82 |
83 | else:
84 | # If the package is not installed, we immediately break
85 | # and mark the test as skipped.
86 | item.add_marker(pytest.mark.skip(reason=f'Requires pkg: `{pkg}`'))
87 | break
88 | else:
89 | if only_extended:
90 | item.add_marker(pytest.mark.skip(reason='Skipping not an extended test.'))
91 |
92 |
93 | @pytest.fixture
94 | def logging_conf() -> dict:
95 | return get_config_dict(
96 | 'info',
97 | get_log_file(log_path='logs', sub_dir=f'local_{get_timestamp_ms()}'),
98 | 1024 * 1024,
99 | 1024 * 1024 * 1024,
100 | )
101 |
102 |
103 | @pytest.fixture
104 | def test_file_path(request) -> Path:
105 | from pathlib import Path
106 |
107 |     # Current execution directory
108 |     # Get the directory of the current test file
109 | test_file_path = Path(str(request.fspath)).parent
110 | print('test_file_path:', test_file_path)
111 | return test_file_path
112 |
--------------------------------------------------------------------------------
/.github/actions/poetry_setup/action.yml:
--------------------------------------------------------------------------------
1 | # An action for setting up poetry install with caching.
2 | # Using a custom action since the default action does not
3 | # take poetry install groups into account.
4 | # Action code from:
5 | # https://github.com/actions/setup-python/issues/505#issuecomment-1273013236
6 | name: poetry-install-with-caching
7 | description: Poetry install with support for caching of dependency groups.
8 |
9 | inputs:
10 | python-version:
11 | description: Python version, supporting MAJOR.MINOR only
12 | required: true
13 |
14 | poetry-version:
15 | description: Poetry version
16 | required: true
17 |
18 | cache-key:
19 | description: Cache key to use for manual handling of caching
20 | required: true
21 |
22 | working-directory:
23 | description: Directory whose poetry.lock file should be cached
24 | required: true
25 |
26 | runs:
27 | using: composite
28 | steps:
29 | - uses: actions/setup-python@v5
30 | name: Setup python ${{ inputs.python-version }}
31 | id: setup-python
32 | with:
33 | python-version: ${{ inputs.python-version }}
34 |
35 | # - uses: actions/cache@v4
36 | # id: cache-bin-poetry
37 | # name: Cache Poetry binary - Python ${{ inputs.python-version }}
38 | # env:
39 | # SEGMENT_DOWNLOAD_TIMEOUT_MIN: "1"
40 | # with:
41 | # path: |
42 | # /opt/pipx/venvs/poetry
43 | # # This step caches the poetry installation, so make sure it's keyed on the poetry version as well.
44 | # key: bin-poetry-${{ runner.os }}-${{ runner.arch }}-py-${{ inputs.python-version }}-${{ inputs.poetry-version }}
45 |
46 | - name: Refresh shell hashtable and fixup softlinks
47 | if: steps.cache-bin-poetry.outputs.cache-hit == 'true'
48 | shell: bash
49 | env:
50 | POETRY_VERSION: ${{ inputs.poetry-version }}
51 | PYTHON_VERSION: ${{ inputs.python-version }}
52 | run: |
53 | set -eux
54 |
55 | # Refresh the shell hashtable, to ensure correct `which` output.
56 | hash -r
57 |
58 | # `actions/cache@v3` doesn't always seem able to correctly unpack softlinks.
59 | # Delete and recreate the softlinks pipx expects to have.
60 | rm /opt/pipx/venvs/poetry/bin/python
61 | cd /opt/pipx/venvs/poetry/bin
62 | ln -s "$(which "python$PYTHON_VERSION")" python
63 | chmod +x python
64 | cd /opt/pipx_bin/
65 | ln -s /opt/pipx/venvs/poetry/bin/poetry poetry
66 | chmod +x poetry
67 |
68 | # Ensure everything got set up correctly.
69 | /opt/pipx/venvs/poetry/bin/python --version
70 | /opt/pipx_bin/poetry --version
71 |
72 | - name: Install poetry via pip
73 | if: steps.cache-bin-poetry.outputs.cache-hit != 'true'
74 | shell: bash
75 | env:
76 | POETRY_VERSION: ${{ inputs.poetry-version }}
77 | run: |
78 | python -m pip install --upgrade pip
79 | python -m pip install "poetry==$POETRY_VERSION"
80 | poetry --version
81 |
82 | - name: Restore pip and poetry cached dependencies
83 | uses: actions/cache@v4
84 | env:
85 | SEGMENT_DOWNLOAD_TIMEOUT_MIN: "4"
86 | WORKDIR: ${{ inputs.working-directory == '' && '.' || inputs.working-directory }}
87 | with:
88 | path: |
89 | ~/.cache/pip
90 | ~/.cache/pypoetry/virtualenvs
91 | ~/.cache/pypoetry/cache
92 | ~/.cache/pypoetry/artifacts
93 | ${{ env.WORKDIR }}/.venv
94 | key: py-deps-${{ runner.os }}-${{ runner.arch }}-py-${{ inputs.python-version }}-poetry-${{ inputs.poetry-version }}-${{ inputs.cache-key }}-${{ hashFiles(format('{0}/**/poetry.lock', env.WORKDIR)) }}
95 |
--------------------------------------------------------------------------------
/zhipuai/api_resource/file_parser/file_parser.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import TYPE_CHECKING, Mapping, cast
4 |
5 | import httpx
6 | from typing_extensions import Literal
7 |
8 | from ...core import BaseAPI, maybe_transform
9 | from ...core import NOT_GIVEN, Body, Headers, NotGiven, FileTypes
10 | from ...core import _legacy_binary_response
11 | from ...core import _legacy_response
12 | from ...core import deepcopy_minimal, extract_files
13 | from ...core import (
14 | make_request_options,
15 | )
16 | from ...types.file_parser.file_parser_create_params import FileParserCreateParams
17 | from ...types.file_parser.file_parser_resp import FileParserTaskCreateResp
18 |
19 | if TYPE_CHECKING:
20 | from ..._client import ZhipuAI
21 |
22 | __all__ = ["FileParser"]
23 |
24 |
25 | class FileParser(BaseAPI):
26 |
27 | def __init__(self, client: "ZhipuAI") -> None:
28 | super().__init__(client)
29 |
30 | def create(
31 | self,
32 | *,
33 | file: FileTypes = None,
34 | file_type: str = None,
35 | tool_type: Literal["simple", "doc2x", "tencent", "zhipu-pro"],
36 | extra_headers: Headers | None = None,
37 | extra_body: Body | None = None,
38 | timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
39 | ) -> FileParserTaskCreateResp:
40 |
41 | if not file:
42 | raise ValueError("At least one `file` must be provided.")
43 | body = deepcopy_minimal(
44 | {
45 | "file": file,
46 | "file_type": file_type,
47 | "tool_type": tool_type,
48 | }
49 | )
50 |
51 | files = extract_files(cast(Mapping[str, object], body), paths=[["file"]])
52 | if files:
53 | # It should be noted that the actual Content-Type header that will be
54 | # sent to the server will contain a `boundary` parameter, e.g.
55 | # multipart/form-data; boundary=---abc--
56 | extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
57 | return self._post(
58 | "/files/parser/create",
59 | body=maybe_transform(body, FileParserCreateParams),
60 | files=files,
61 | options=make_request_options(
62 | extra_headers=extra_headers, extra_body=extra_body, timeout=timeout
63 | ),
64 | cast_type=FileParserTaskCreateResp,
65 | )
66 |
67 | def content(
68 | self,
69 | task_id: str,
70 | *,
71 | format_type: Literal["text", "download_link"],
72 | extra_headers: Headers | None = None,
73 | extra_body: Body | None = None,
74 | timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
75 | ) -> _legacy_response.HttpxBinaryResponseContent:
76 | """
77 | Returns the contents of the specified file.
78 |
79 | Args:
80 | extra_headers: Send extra headers
81 |
82 | extra_body: Add additional JSON properties to the request
83 |
84 | timeout: Override the client-level default timeout for this request, in seconds
85 | """
86 | if not task_id:
87 | raise ValueError(f"Expected a non-empty value for `task_id` but received {task_id!r}")
88 | extra_headers = {"Accept": "application/binary", **(extra_headers or {})}
89 | return self._get(
90 | f"/files/parser/result/{task_id}/{format_type}",
91 | options=make_request_options(
92 | extra_headers=extra_headers, extra_body=extra_body, timeout=timeout
93 | ),
94 | cast_type=_legacy_binary_response.HttpxBinaryResponseContent,
95 | )
96 |
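
A sketch of the two-step parse flow defined above; the file name, the `task_id` attribute on the create response, and the `.text` accessor on the binary wrapper are assumptions:

    from zhipuai import ZhipuAI

    client = ZhipuAI()

    # Step 1: create a parsing task from a local file.
    with open("report.pdf", "rb") as f:
        task = client.file_parser.create(file=f, file_type="pdf", tool_type="simple")

    # Step 2: fetch the parsed result once the task has completed.
    content = client.file_parser.content(task_id=task.task_id, format_type="text")
    print(content.text)  # accessor assumed from the legacy binary response wrapper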
--------------------------------------------------------------------------------
/Release-Note.md:
--------------------------------------------------------------------------------
1 | ## Release Notes
2 |
3 | ### 2024-08-12
4 |
5 | **New Features:**
6 | - Modified video prompt to be optional, added file deletion functionality
7 | - Added Assistant business logic
8 | - Fixed the embedding-3 dimensions issue
9 |
10 | ### 2024-07-25
11 |
12 | **Bug Fixes:**
13 | - Fixed cogvideo related issues
14 |
15 | ### 2024-07-12
16 |
17 | **New Features:**
18 | - Added advanced search tool Web search business logic
19 | - Specified Python versions support (3.8, 3.9, 3.10, 3.11, 3.12)
20 | - Integrated cogvideo business functionality
21 |
22 | ### 2024-05-20
23 |
24 | **Improvements:**
25 | - Fixed some `python3.12` dependency issues
26 | - Added pagination processing code, rewrote instantiation rules for some response classes
27 | - Added type conversion validation
28 | - Added batch task related APIs
29 | - Added file stream response wrapper
30 |
31 | ### 2024-04-29
32 |
33 | **Improvements:**
34 | - Fixed some `python3.7` code compatibility issues
35 | - Added interface failure retry mechanism, controlled by `retry` parameter with default of 3 retries
36 | - Adjusted interface timeout strategy, controlled by `Timeout` for interface `connect` and `read` timeout, default `timeout=300.0, connect=8.0`
37 | - Added support for super-humanoid large model parameters in conversation module, `model="charglm-3"`, `meta` parameter support
38 |
39 | ### 2024-04-23
40 |
41 | **Improvements:**
42 | - Fixed some compatibility issues with `pydantic<3,>=1.9.0`
43 | - Message processing business request and response parameters can be extended through configuration
44 | - Compatible with some parameters `top_p:1`, `temperature:0` (do_sample rewritten to false, parameters top_p temperature do not take effect)
45 | - Image understanding part, image_url parameter base64 content containing `data:image/jpeg;base64` compatibility
46 | - Removed JWT authentication logic
47 |
48 | ---
49 |
50 | ## Migration Guide
51 |
52 | For users upgrading from older versions, please note the following breaking changes:
53 |
54 | ### From v3.x to v4.x
55 |
56 | - API key configuration has been simplified
57 | - Some method signatures have changed for better type safety
58 | - Error handling has been improved with more specific exception types (see the sketch below)
59 |
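
As a hedged illustration of the more specific exception types, with names taken from the tests in this repository; the model name is a placeholder:

    import zhipuai
    from zhipuai import ZhipuAI

    client = ZhipuAI(api_key="your-api-key")
    try:
        response = client.chat.completions.create(
            model="glm-4",  # placeholder model name
            messages=[{"role": "user", "content": "Hello"}],
        )
        print(response.choices[0].message)
    except zhipuai.core._errors.APIRequestFailedError as err:
        print(err)  # the API rejected the request
    except zhipuai.core._errors.APIInternalError as err:
        print(err)  # server-side failure
    except zhipuai.core._errors.APIStatusError as err:
        print(err)  # any other non-2xx status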
60 | ## Support
61 |
62 | For questions about specific versions or upgrade assistance, please visit our [documentation](https://open.bigmodel.cn/) or contact our support team.
63 |
64 | ---
65 |
66 | ## 版本更新
67 |
68 | ### 2024-08-12
69 |
70 | **新功能:**
71 | - ✅ 视频提示词设为可选,新增文件删除功能
72 | - ✅ 智能助手业务逻辑
73 | - 🔧 修复 embedding 3 维度问题
74 |
75 | ### 2024-07-25
76 |
77 | **问题修复:**
78 | - 🔧 修复 cogvideo 相关问题
79 |
80 | ### 2024-07-12
81 |
82 | **新功能:**
83 | - ✅ 高级搜索工具 Web search 业务逻辑
84 | - ✅ 指定 Python 版本支持 (3.8, 3.9, 3.10, 3.11, 3.12)
85 | - ✅ 集成 cogvideo 业务功能
86 |
87 | ### 2024-05-20
88 |
89 | **改进优化:**
90 | - 🔧 修复部分 `python3.12` 依赖问题
91 | - ✅ 新增分页处理代码,重写部分响应类实例化规则
92 | - ✅ 新增类型转换校验
93 | - ✅ 批处理任务相关 API
94 | - ✅ 文件流响应包装器
95 |
96 | ### 2024-04-29
97 |
98 | **改进优化:**
99 | - 🔧 修复部分 `python3.7` 代码兼容性问题
100 | - ✅ 接口失败重试机制,通过 `retry` 参数控制重试次数,默认 3 次
101 | - ⏱️ 调整接口超时策略,通过 `Timeout` 控制接口 `connect` 和 `read` 超时时间,默认 `timeout=300.0, connect=8.0`
102 | - ✅ 对话模块新增超拟人大模型参数支持,`model="charglm-3"`,`meta` 参数支持
103 |
104 | ### 2024-04-23
105 |
106 | **改进优化:**
107 | - 🔧 修复部分 `pydantic<3,>=1.9.0` 兼容性问题
108 | - ✅ 报文处理的业务请求参数和响应参数可通过配置扩充
109 | - ✅ 兼容部分参数 `top_p:1`,`temperature:0`(do_sample 重写为 false,参数 top_p temperature 不生效)
110 | - ✅ 图像理解部分,image_url 参数 base64 内容包含 `data:image/jpeg;base64` 兼容性
111 | - 🔄 删除 JWT 认证逻辑
112 |
113 | ---
114 |
115 | ## 迁移指南
116 |
117 | 对于从旧版本升级的用户,请注意以下重大变更:
118 |
119 | ### 从 v3.x 到 v4.x
120 |
121 | - API 密钥配置已简化
122 | - 部分方法签名已更改以提供更好的类型安全性
123 | - 错误处理已改进,提供更具体的异常类型
124 |
125 | ## 技术支持
126 |
127 | 如有特定版本问题或升级协助需求,请访问我们的[文档](https://open.bigmodel.cn/)或联系我们的支持团队。
--------------------------------------------------------------------------------
/zhipuai/_client.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Union, Mapping
4 |
5 | from typing_extensions import override
6 |
7 | from .core import _jwt_token
8 | from .core import ZhipuAIError
9 | from .core import HttpClient, ZHIPUAI_DEFAULT_MAX_RETRIES
10 | from .core import NotGiven, NOT_GIVEN
11 | from . import api_resource
12 | import os
13 | import httpx
14 | from httpx import Timeout
15 |
16 |
17 | class ZhipuAI(HttpClient):
18 | chat: api_resource.chat.Chat
19 | api_key: str
20 | _disable_token_cache: bool = True
21 | source_channel: str
22 |
23 | def __init__(
24 | self,
25 | *,
26 | api_key: str | None = None,
27 | base_url: str | httpx.URL | None = None,
28 | timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
29 | max_retries: int = ZHIPUAI_DEFAULT_MAX_RETRIES,
30 | http_client: httpx.Client | None = None,
31 | custom_headers: Mapping[str, str] | None = None,
32 | disable_token_cache: bool = True,
33 | _strict_response_validation: bool = False,
34 | source_channel: str | None = None
35 | ) -> None:
36 | if api_key is None:
37 | api_key = os.environ.get("ZHIPUAI_API_KEY")
38 | if api_key is None:
39 |             raise ZhipuAIError("No api_key provided; pass it as an argument or set the ZHIPUAI_API_KEY environment variable")
40 | self.api_key = api_key
41 | self.source_channel = source_channel
42 | self._disable_token_cache = disable_token_cache
43 |
44 | if base_url is None:
45 | base_url = os.environ.get("ZHIPUAI_BASE_URL")
46 | if base_url is None:
47 |             base_url = "https://open.bigmodel.cn/api/paas/v4"
48 | from .__version__ import __version__
49 | super().__init__(
50 | version=__version__,
51 | base_url=base_url,
52 | max_retries=max_retries,
53 | timeout=timeout,
54 | custom_httpx_client=http_client,
55 | custom_headers=custom_headers,
56 | _strict_response_validation=_strict_response_validation,
57 | )
58 | self.chat = api_resource.chat.Chat(self)
59 | self.images = api_resource.images.Images(self)
60 | self.embeddings = api_resource.embeddings.Embeddings(self)
61 | self.files = api_resource.files.Files(self)
62 | self.fine_tuning = api_resource.fine_tuning.FineTuning(self)
63 | self.batches = api_resource.Batches(self)
64 | self.knowledge = api_resource.Knowledge(self)
65 | self.tools = api_resource.Tools(self)
66 | self.videos = api_resource.Videos(self)
67 | self.assistant = api_resource.Assistant(self)
68 | self.web_search = api_resource.WebSearchApi(self)
69 | self.audio = api_resource.audio.Audio(self)
70 | self.moderations = api_resource.moderation.Moderations(self)
71 | self.agents = api_resource.agents.Agents(self)
72 | self.file_parser = api_resource.file_parser.FileParser(self)
73 |
74 | @property
75 | @override
76 | def auth_headers(self) -> dict[str, str]:
77 | api_key = self.api_key
78 | source_channel = self.source_channel or "python-sdk"
79 | if self._disable_token_cache:
80 | return {"Authorization": f"Bearer {api_key}","x-source-channel": source_channel}
81 | else:
82 | return {"Authorization": f"Bearer {_jwt_token.generate_token(api_key)}","x-source-channel": source_channel}
83 |
84 | def __del__(self) -> None:
85 | if (not hasattr(self, "_has_custom_http_client")
86 | or not hasattr(self, "close")
87 | or not hasattr(self, "_client")):
88 | # if the '__init__' method raised an error, self would not have client attr
89 | return
90 |
91 | if self._has_custom_http_client:
92 | return
93 |
94 | self.close()
95 |
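
A hedged construction sketch for the client above; the timeout values echo the defaults stated in Release-Note.md and the model name is a placeholder:

    import httpx
    from zhipuai import ZhipuAI

    client = ZhipuAI(
        api_key="your-api-key",                              # or set ZHIPUAI_API_KEY
        base_url="https://open.bigmodel.cn/api/paas/v4",     # same as the built-in default
        timeout=httpx.Timeout(timeout=300.0, connect=8.0),
        max_retries=3,
        disable_token_cache=True,  # True sends the raw key as a Bearer token; False signs a JWT
    )

    response = client.chat.completions.create(
        model="glm-4",  # placeholder model name
        messages=[{"role": "user", "content": "Hello"}],
    )
    print(response.choices[0].message)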
--------------------------------------------------------------------------------
/zhipuai/core/_request_opt.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Union, Any, cast, TYPE_CHECKING
4 |
5 | from ._constants import RAW_RESPONSE_HEADER
6 | from ._utils import is_given
7 | from ._base_compat import ConfigDict, PYDANTIC_V2
8 | import pydantic.generics
9 | from httpx import Timeout
10 | from typing_extensions import (
11 | final, Unpack, ClassVar, TypedDict, Required, Callable
12 |
13 | )
14 |
15 | from ._base_type import Body, NotGiven, Headers, HttpxRequestFiles, Query, AnyMapping
16 | from ._utils import remove_notgiven_indict, strip_not_given
17 |
18 |
19 | class UserRequestInput(TypedDict, total=False):
20 | headers: Headers
21 | max_retries: int
22 | timeout: float | Timeout | None
23 | params: Query
24 | extra_json: AnyMapping
25 |
26 |
27 | class FinalRequestOptionsInput(TypedDict, total=False):
28 | method: Required[str]
29 | url: Required[str]
30 | params: Query
31 | headers: Headers
32 | max_retries: int
33 | timeout: float | Timeout | None
34 | files: HttpxRequestFiles | None
35 | json_data: Body
36 | extra_json: AnyMapping
37 |
38 |
39 | @final
40 | class FinalRequestOptions(pydantic.BaseModel):
41 | method: str
42 | url: str
43 | params: Query = {}
44 | headers: Union[Headers, NotGiven] = NotGiven()
45 | max_retries: Union[int, NotGiven] = NotGiven()
46 | timeout: Union[float, Timeout, None, NotGiven] = NotGiven()
47 | files: Union[HttpxRequestFiles, None] = None
48 | idempotency_key: Union[str, None] = None
49 | post_parser: Union[Callable[[Any], Any], NotGiven] = NotGiven()
50 |
51 | # It should be noted that we cannot use `json` here as that would override
52 | # a BaseModel method in an incompatible fashion.
53 | json_data: Union[Body, None] = None
54 | extra_json: Union[AnyMapping, None] = None
55 |
56 | if PYDANTIC_V2:
57 | model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)
58 | else:
59 |
60 | class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated]
61 | arbitrary_types_allowed: bool = True
62 |
63 | def get_max_retries(self, max_retries: int) -> int:
64 | if isinstance(self.max_retries, NotGiven):
65 | return max_retries
66 | return self.max_retries
67 |
68 | def _strip_raw_response_header(self) -> None:
69 | if not is_given(self.headers):
70 | return
71 |
72 | if self.headers.get(RAW_RESPONSE_HEADER):
73 | self.headers = {**self.headers}
74 | self.headers.pop(RAW_RESPONSE_HEADER)
75 |
76 | # override the `construct` method so that we can run custom transformations.
77 | # this is necessary as we don't want to do any actual runtime type checking
78 | # (which means we can't use validators) but we do want to ensure that `NotGiven`
79 | # values are not present
80 | #
81 | # type ignore required because we're adding explicit types to `**values`
82 | @classmethod
83 | def construct( # type: ignore
84 | cls,
85 | _fields_set: set[str] | None = None,
86 | **values: Unpack[UserRequestInput],
87 | ) -> FinalRequestOptions:
88 | kwargs: dict[str, Any] = {
89 | # we unconditionally call `strip_not_given` on any value
90 | # as it will just ignore any non-mapping types
91 | key: strip_not_given(value)
92 | for key, value in values.items()
93 | }
94 | if PYDANTIC_V2:
95 | return super().model_construct(_fields_set, **kwargs)
96 | return cast(FinalRequestOptions, super().construct(_fields_set, **kwargs)) # pyright: ignore[reportDeprecated]
97 |
98 | if not TYPE_CHECKING:
99 | # type checkers incorrectly complain about this assignment
100 | model_construct = construct
101 |
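
A small sketch of the overridden `construct`, mirroring how the unit tests in this repository build request options; the URL is illustrative:

    from zhipuai.core._request_opt import FinalRequestOptions

    # Fields not supplied keep their NotGiven defaults; supplied values pass through.
    opts = FinalRequestOptions.construct(method="get", url="/videos/generations")
    print(opts.method, opts.url)    # get /videos/generations
    print(opts.get_max_retries(2))  # max_retries is NotGiven, so the caller's default (2) is used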
--------------------------------------------------------------------------------
/tests/unit_tests/test_response.py:
--------------------------------------------------------------------------------
1 | import httpx
2 | from httpx import URL, ByteStream, Headers, Request, Response
3 | from typing_extensions import Dict, Type
4 |
5 | from zhipuai.core import StreamResponse
6 | from zhipuai.core._base_type import ResponseT
7 | from zhipuai.core._http_client import HttpClient
8 | from zhipuai.core._request_opt import FinalRequestOptions
9 | from zhipuai.core._response import APIResponse
10 |
11 |
12 | # Mock objects for HttpClient and StreamResponse if necessary
13 | class MockHttpClient:
14 | _strict_response_validation: bool = False
15 |
16 | # Implement necessary mock methods or attributes
17 | def _process_response_data(
18 | self,
19 | *,
20 | data: object,
21 | cast_type: Type[ResponseT],
22 | response: httpx.Response,
23 | ) -> ResponseT:
24 | return data
25 |
26 |
27 | class MockStreamResponse(StreamResponse[ResponseT]):
28 | # Implement necessary mock methods or attributes
29 | def __init__(
30 | self,
31 | *,
32 | cast_type: Type[ResponseT],
33 | response: httpx.Response,
34 | client: HttpClient,
35 | ) -> None:
36 | super().__init__(cast_type=cast_type, response=response, client=client)
37 | self.response = response
38 | self._cast_type = cast_type
39 | # self._data_process_func = client._process_response_data
40 | # self._strem_chunks = self.__stream__()
41 |
42 | def __iter__(self):
43 | for item in self.response.iter_lines():
44 | yield item
45 |
46 |
47 | # Test Initialization
48 | def test_http_response_initialization():
49 | raw_response = Response(200)
50 | opts = FinalRequestOptions.construct(method='get', url='path')
51 | http_response = APIResponse(
52 | raw=raw_response,
53 | cast_type=str,
54 | client=MockHttpClient(),
55 | stream=False,
56 | options=opts,
57 | )
58 | assert http_response.http_response == raw_response
59 |
60 |
61 | # Test parse Method
62 | def test_parse_method():
63 | raw_response = Response(
64 | 200,
65 | headers=Headers({'content-type': 'application/json'}),
66 | content=b'{"key": "value"}',
67 | )
68 | opts = FinalRequestOptions.construct(method='get', url='path')
69 |
70 | http_response = APIResponse(
71 | raw=raw_response,
72 | cast_type=Dict[str, object],
73 | client=MockHttpClient(),
74 | stream=False,
75 | options=opts,
76 | )
77 | parsed_data = http_response.parse()
78 | assert parsed_data == {'key': 'value'}
79 | http_response = APIResponse(
80 | raw=raw_response,
81 | cast_type=str,
82 | client=MockHttpClient(),
83 | stream=False,
84 | options=opts,
85 | )
86 | parsed_data = http_response.parse()
87 | assert parsed_data == '{"key": "value"}'
88 |
89 | raw_response = Response(
90 | 200,
91 | content=b'{"key": "value"}',
92 | stream=ByteStream(b'{"key": "value"}\n"foo"\n"boo"\n'),
93 | )
94 | http_response = APIResponse(
95 | raw=raw_response,
96 | cast_type=str,
97 | client=MockHttpClient(),
98 | stream=True,
99 | options=opts,
100 | stream_cls=MockStreamResponse[str],
101 | )
102 | parsed_data = http_response.parse()
103 |     expected_data = ['{"key": "value"}', '"foo"', '"boo"']
104 |     data = [chunk.strip() for chunk in parsed_data]
105 |     assert data == expected_data
106 |
107 |
108 | # Test properties
109 | def test_properties():
110 | opts = FinalRequestOptions.construct(method='get', url='path')
111 | headers = Headers({'content-type': 'application/json'})
112 | request = Request(method='GET', url='http://example.com')
113 | raw_response = Response(200, headers=headers, request=request)
114 | http_response = APIResponse(
115 | raw=raw_response,
116 | cast_type=str,
117 | client=MockHttpClient(),
118 | stream=False,
119 | options=opts,
120 | )
121 |
122 | assert http_response.headers == headers
123 | assert http_response.http_request == request
124 | assert http_response.status_code == 200
125 | assert http_response.url == URL('http://example.com')
126 | assert http_response.method == 'GET'
127 |
--------------------------------------------------------------------------------
/tests/integration_tests/test_code_geex.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import logging.config
3 | import time
4 |
5 | import zhipuai
6 | from zhipuai import ZhipuAI
7 |
8 |
9 | def test_code_geex(logging_conf):
10 | logging.config.dictConfig(logging_conf) # type: ignore
11 |     client = ZhipuAI()  # Fill in your own API key
12 |     try:
13 |         # Generate a request_id
14 | request_id = time.time()
15 | print(f'request_id:{request_id}')
16 | response = client.chat.completions.create(
17 | request_id=request_id,
18 | model='codegeex-4',
19 | messages=[
20 | {
21 | 'role': 'system',
22 | 'content': """你是一位智能编程助手,你叫CodeGeeX。你会为用户回答关于编程、代码、计算机方面的任何问题,并提供格式规范、可以执行、准确安全的代码,并在必要时提供详细的解释。
23 | 任务:请为输入代码提供格式规范的注释,包含多行注释和单行注释,请注意不要改动原始代码,只需要添加注释。
24 | 请用中文回答。""",
25 | },
26 | {'role': 'user', 'content': """写一个快速排序函数"""},
27 | ],
28 | top_p=0.7,
29 | temperature=0.9,
30 | max_tokens=2000,
31 | stop=['<|endoftext|>', '<|user|>', '<|assistant|>', '<|observation|>'],
32 | extra={
33 | 'target': {
34 | 'path': '11111',
35 | 'language': 'Python',
36 | 'code_prefix': 'EventSource.Factory factory = EventSources.createFactory(OkHttpUtils.getInstance());',
37 | 'code_suffix': 'TaskMonitorLocal taskMonitorLocal = getTaskMonitorLocal(algoMqReq);',
38 | },
39 | 'contexts': [
40 | {
41 | 'path': '/1/2',
42 | 'code': 'if(!sensitiveUser){ZpTraceUtils.addAsyncAttribute(algoMqReq.getTaskOrderNo(), ApiTraceProperty.request_params.getCode(), modelSendMap);',
43 | }
44 | ],
45 | },
46 | )
47 | print(response)
48 |
49 | except zhipuai.core._errors.APIRequestFailedError as err:
50 | print(err)
51 | except zhipuai.core._errors.APIInternalError as err:
52 | print(err)
53 | except zhipuai.core._errors.APIStatusError as err:
54 | print(err)
55 |
56 |
57 | def test_code_geex_async(logging_conf):
58 | logging.config.dictConfig(logging_conf) # type: ignore
59 |     client = ZhipuAI()  # Fill in your own API key
60 |     try:
61 |         # Generate a request_id
62 | request_id = time.time()
63 | print(f'request_id:{request_id}')
64 | response = client.chat.asyncCompletions.create(
65 | request_id=request_id,
66 | model='codegeex-4',
67 | messages=[
68 | {
69 | 'role': 'system',
70 | 'content': """你是一位智能编程助手,你叫CodeGeeX。你会为用户回答关于编程、代码、计算机方面的任何问题,并提供格式规范、可以执行、准确安全的代码,并在必要时提供详细的解释。
71 | 任务:请为输入代码提供格式规范的注释,包含多行注释和单行注释,请注意不要改动原始代码,只需要添加注释。
72 | 请用中文回答。""",
73 | },
74 | {'role': 'user', 'content': """写一个快速排序函数"""},
75 | ],
76 | top_p=0.7,
77 | temperature=0.9,
78 | max_tokens=2000,
79 | stop=['<|endoftext|>', '<|user|>', '<|assistant|>', '<|observation|>'],
80 | extra={
81 | 'target': {
82 | 'path': '11111',
83 | 'language': 'Python',
84 | 'code_prefix': 'EventSource.Factory factory = EventSources.createFactory(OkHttpUtils.getInstance());',
85 | 'code_suffix': 'TaskMonitorLocal taskMonitorLocal = getTaskMonitorLocal(algoMqReq);',
86 | },
87 | 'contexts': [
88 | {
89 | 'path': '/1/2',
90 | 'code': 'if(!sensitiveUser){ZpTraceUtils.addAsyncAttribute(algoMqReq.getTaskOrderNo(), ApiTraceProperty.request_params.getCode(), modelSendMap);',
91 | }
92 | ],
93 | },
94 | )
95 | print(response)
96 |
97 | except zhipuai.core._errors.APIRequestFailedError as err:
98 | print(err)
99 | except zhipuai.core._errors.APIInternalError as err:
100 | print(err)
101 | except zhipuai.core._errors.APIStatusError as err:
102 | print(err)
103 |
104 |
105 | def test_geex_result(logging_conf):
106 | logging.config.dictConfig(logging_conf) # type: ignore
107 |     client = ZhipuAI()  # Please fill in your own API key
108 | try:
109 | response = client.chat.asyncCompletions.retrieve_completion_result(
110 | id='1014908807577524653187108'
111 | )
112 | print(response)
113 |
114 | except zhipuai.core._errors.APIRequestFailedError as err:
115 | print(err)
116 | except zhipuai.core._errors.APIInternalError as err:
117 | print(err)
118 | except zhipuai.core._errors.APIStatusError as err:
119 | print(err)
120 |
--------------------------------------------------------------------------------
/zhipuai/core/_utils/_typing.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Any, TypeVar, Iterable, cast
4 | from collections import abc as _c_abc
5 | from typing_extensions import Required, Annotated, get_args, get_origin
6 |
7 | from .._base_type import InheritsGeneric
8 | from zhipuai.core._base_compat import is_union as _is_union
9 |
10 |
11 | def is_annotated_type(typ: type) -> bool:
12 | return get_origin(typ) == Annotated
13 |
14 |
15 | def is_list_type(typ: type) -> bool:
16 | return (get_origin(typ) or typ) == list
17 |
18 |
19 | def is_iterable_type(typ: type) -> bool:
20 | """If the given type is `typing.Iterable[T]`"""
21 | origin = get_origin(typ) or typ
22 | return origin == Iterable or origin == _c_abc.Iterable
23 |
24 |
25 | def is_union_type(typ: type) -> bool:
26 | return _is_union(get_origin(typ))
27 |
28 |
29 | def is_required_type(typ: type) -> bool:
30 | return get_origin(typ) == Required
31 |
32 |
33 | def is_typevar(typ: type) -> bool:
34 | # type ignore is required because type checkers
35 | # think this expression will always return False
36 | return type(typ) == TypeVar # type: ignore
37 |
38 |
39 | # Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]]
40 | def strip_annotated_type(typ: type) -> type:
41 | if is_required_type(typ) or is_annotated_type(typ):
42 | return strip_annotated_type(cast(type, get_args(typ)[0]))
43 |
44 | return typ
45 |
46 |
47 | def extract_type_arg(typ: type, index: int) -> type:
48 | args = get_args(typ)
49 | try:
50 | return cast(type, args[index])
51 | except IndexError as err:
52 | raise RuntimeError(f"Expected type {typ} to have a type argument at index {index} but it did not") from err
53 |
54 |
55 | def extract_type_var_from_base(
56 | typ: type,
57 | *,
58 | generic_bases: tuple[type, ...],
59 | index: int,
60 | failure_message: str | None = None,
61 | ) -> type:
62 | """Given a type like `Foo[T]`, returns the generic type variable `T`.
63 |
64 | This also handles the case where a concrete subclass is given, e.g.
65 | ```py
66 | class MyResponse(Foo[bytes]):
67 | ...
68 |
69 | extract_type_var(MyResponse, bases=(Foo,), index=0) -> bytes
70 | ```
71 |
72 | And where a generic subclass is given:
73 | ```py
74 | _T = TypeVar('_T')
75 | class MyResponse(Foo[_T]):
76 | ...
77 |
78 | extract_type_var(MyResponse[bytes], bases=(Foo,), index=0) -> bytes
79 | ```
80 | """
81 | cls = cast(object, get_origin(typ) or typ)
82 | if cls in generic_bases:
83 | # we're given the class directly
84 | return extract_type_arg(typ, index)
85 |
86 | # if a subclass is given
87 | # ---
88 | # this is needed as __orig_bases__ is not present in the typeshed stubs
89 | # because it is intended to be for internal use only, however there does
90 | # not seem to be a way to resolve generic TypeVars for inherited subclasses
91 | # without using it.
92 | if isinstance(cls, InheritsGeneric):
93 | target_base_class: Any | None = None
94 | for base in cls.__orig_bases__:
95 | if base.__origin__ in generic_bases:
96 | target_base_class = base
97 | break
98 |
99 | if target_base_class is None:
100 | raise RuntimeError(
101 | "Could not find the generic base class;\n"
102 | "This should never happen;\n"
103 | f"Does {cls} inherit from one of {generic_bases} ?"
104 | )
105 |
106 | extracted = extract_type_arg(target_base_class, index)
107 | if is_typevar(extracted):
108 | # If the extracted type argument is itself a type variable
109 | # then that means the subclass itself is generic, so we have
110 | # to resolve the type argument from the class itself, not
111 | # the base class.
112 | #
113 | # Note: if there is more than 1 type argument, the subclass could
114 | # change the ordering of the type arguments, this is not currently
115 | # supported.
116 | return extract_type_arg(typ, index)
117 |
118 | return extracted
119 |
120 | raise RuntimeError(failure_message or f"Could not resolve inner type variable at index {index} for {typ}")
121 |
--------------------------------------------------------------------------------
/tests/integration_tests/test_finetuning.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import logging.config
3 |
4 | import zhipuai
5 | from zhipuai import ZhipuAI
6 |
7 |
8 | def test_finetuning_create(logging_conf):
9 | logging.config.dictConfig(logging_conf) # type: ignore
10 |     client = ZhipuAI()  # Please fill in your own API key
11 | try:
12 | job = client.fine_tuning.jobs.create(
13 | model='chatglm3-6b',
14 |             training_file='file-20240428021923715-xjng4',  # Fill in the id of a successfully uploaded file
15 |             validation_file='file-20240428021923715-xjng4',  # Fill in the id of a successfully uploaded file
16 | suffix='demo_test',
17 | )
18 | job_id = job.id
19 | print(job_id)
20 | fine_tuning_job = client.fine_tuning.jobs.retrieve(fine_tuning_job_id=job_id)
21 | print(fine_tuning_job)
22 | # ftjob-20240418110039323-j8lh2
23 |
24 | except zhipuai.core._errors.APIRequestFailedError as err:
25 | print(err)
26 | except zhipuai.core._errors.APIInternalError as err:
27 | print(err)
28 | except zhipuai.core._errors.APIStatusError as err:
29 | print(err)
30 |
31 |
32 | def test_finetuning_retrieve(logging_conf):
33 | logging.config.dictConfig(logging_conf) # type: ignore
34 |     client = ZhipuAI()  # Please fill in your own API key
35 | try:
36 | fine_tuning_job = client.fine_tuning.jobs.retrieve(
37 | fine_tuning_job_id='ftjob-20240429112551154-48vq7'
38 | )
39 | print(fine_tuning_job)
40 |
41 | except zhipuai.core._errors.APIRequestFailedError as err:
42 | print(err)
43 | except zhipuai.core._errors.APIInternalError as err:
44 | print(err)
45 | except zhipuai.core._errors.APIStatusError as err:
46 | print(err)
47 |
48 |
49 | def test_finetuning_job_list(logging_conf):
50 | logging.config.dictConfig(logging_conf) # type: ignore
51 |     client = ZhipuAI()  # Please fill in your own API key
52 | try:
53 | job_list = client.fine_tuning.jobs.list()
54 |
55 | print(job_list)
56 |
57 | except zhipuai.core._errors.APIRequestFailedError as err:
58 | print(err)
59 | except zhipuai.core._errors.APIInternalError as err:
60 | print(err)
61 | except zhipuai.core._errors.APIStatusError as err:
62 | print(err)
63 |
64 |
65 | def test_finetuning_job_cancel(logging_conf):
66 | logging.config.dictConfig(logging_conf) # type: ignore
67 |     client = ZhipuAI()  # Please fill in your own API key
68 | try:
69 | cancel = client.fine_tuning.jobs.cancel(fine_tuning_job_id='ftjob-20240429112551154-48vq7')
70 |
71 | print(cancel)
72 |
73 | except zhipuai.core._errors.APIRequestFailedError as err:
74 | print(err)
75 | except zhipuai.core._errors.APIInternalError as err:
76 | print(err)
77 | except zhipuai.core._errors.APIStatusError as err:
78 | print(err)
79 |
80 |
81 | def test_finetuning_job_delete(logging_conf):
82 | logging.config.dictConfig(logging_conf) # type: ignore
83 |     client = ZhipuAI()  # Please fill in your own API key
84 | try:
85 | delete = client.fine_tuning.jobs.delete(fine_tuning_job_id='ftjob-20240126113041678-cs6s9')
86 |
87 | print(delete)
88 |
89 | except zhipuai.core._errors.APIRequestFailedError as err:
90 | print(err)
91 | except zhipuai.core._errors.APIInternalError as err:
92 | print(err)
93 | except zhipuai.core._errors.APIStatusError as err:
94 | print(err)
95 |
96 |
97 | def test_model_check(logging_conf):
98 | logging.config.dictConfig(logging_conf) # type: ignore
99 |     client = ZhipuAI()  # Fill in your own API key
100 | try:
101 | response = client.chat.completions.create(
102 |         model='chatglm3-6b-8572905046912426020-demo_test',  # Fill in the name of the model to call
103 | messages=[
104 | {'role': 'user', 'content': '你是一位乐于助人,知识渊博的全能AI助手。'},
105 | {'role': 'user', 'content': '创造一个更精准、吸引人的slogan'},
106 | ],
107 | extra_body={'temperature': 0.5, 'max_tokens': 50},
108 | )
109 | print(response.choices[0].message)
110 |
111 | except zhipuai.core._errors.APIRequestFailedError as err:
112 | print(err)
113 | except zhipuai.core._errors.APIInternalError as err:
114 | print(err)
115 | except zhipuai.core._errors.APIStatusError as err:
116 | print(err)
117 |
118 |
119 | def test_model_delete(logging_conf):
120 | logging.config.dictConfig(logging_conf) # type: ignore
121 |     client = ZhipuAI()  # Fill in your own API key
122 | try:
123 | delete = client.fine_tuning.models.delete(
124 | fine_tuned_model='chatglm3-6b-8572905046912426020-demo_test'
125 | )
126 |
127 | print(delete)
128 |
129 | except zhipuai.core._errors.APIRequestFailedError as err:
130 | print(err)
131 | except zhipuai.core._errors.APIInternalError as err:
132 | print(err)
133 | except zhipuai.core._errors.APIStatusError as err:
134 | print(err)
135 |
136 |
137 | if __name__ == '__main__':
138 |     test_finetuning_create({'version': 1})  # minimal logging config for a direct run
139 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to making participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, gender identity and expression, level of
9 | experience, education, socio-economic status, nationality, personal appearance,
10 | race, religion, or sexual identity and orientation.
11 |
12 | ## Our Standards
13 |
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 |
17 | * Using welcoming and inclusive language
18 | * Being respectful of differing viewpoints and experiences
19 | * Gracefully accepting constructive criticism
20 | * Focusing on what is best for the community
21 | * Showing empathy towards other community members
22 |
23 | Examples of unacceptable behavior by participants include:
24 |
25 | * The use of sexualized language or imagery and unwelcome sexual attention or
26 | advances
27 | * Trolling, insulting/derogatory comments, and personal or political attacks
28 | * Public or private harassment
29 | * Publishing others' private information, such as a physical or electronic
30 | address, without explicit permission
31 | * Other conduct which could reasonably be considered inappropriate in a
32 | professional setting
33 |
34 | ## Our Responsibilities
35 |
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 |
40 | Project maintainers have the right and responsibility to remove, edit, or reject
41 | comments, commits, code, wiki edits, issues, and other contributions that are
42 | not aligned to this Code of Conduct, or to ban temporarily or permanently any
43 | contributor for other behaviors that they deem inappropriate, threatening,
44 | offensive, or harmful.
45 |
46 | ## Scope
47 |
48 | This Code of Conduct applies both within project spaces and in public spaces
49 | when an individual is representing the project or its community. Examples of
50 | representing a project or community include using an official project e-mail
51 | address, posting via an official social media account, or acting as an appointed
52 | representative at an online or offline event. Representation of a project may be
53 | further defined and clarified by project maintainers.
54 |
55 | This Code of Conduct also applies outside the project spaces when the Project
56 | Steward has a reasonable belief that an individual's behavior may have a
57 | negative impact on the project or its community.
58 |
59 | ## Conflict Resolution
60 |
61 | We do not believe that all conflict is bad; healthy debate and disagreement
62 | often yield positive results. However, it is never okay to be disrespectful or
63 | to engage in behavior that violates the project’s code of conduct.
64 |
65 | If you see someone violating the code of conduct, you are encouraged to address
66 | the behavior directly with those involved. Many issues can be resolved quickly
67 | and easily, and this gives people more control over the outcome of their
68 | dispute. If you are unable to resolve the matter for any reason, or if the
69 | behavior is threatening or harassing, report it. We are dedicated to providing
70 | an environment where participants feel welcome and safe.
71 |
72 | Reports should be directed to *Weijun Zheng (weijun.zheng@aminer.cn)*, the
73 | Project Steward(s) for *zhipuai-sdk-python-v4*. It is the Project Steward’s duty to
74 | receive and address reported violations of the code of conduct. They will then
75 | work with a committee consisting of representatives from the Open Source
76 | Programs Office and the Z.ai Open Source Strategy team.
77 |
78 | We will investigate every complaint, but you may not receive a direct response.
79 | We will use our discretion in determining when and how to follow up on reported
80 | incidents, which may range from not taking action to permanent expulsion from
81 | the project and project-sponsored spaces. We will notify the accused of the
82 | report and provide them an opportunity to discuss it before any action is taken.
83 | The identity of the reporter will be omitted from the details of the report
84 | supplied to the accused. In potentially harmful situations, such as ongoing
85 | harassment or threats to anyone's safety, we may take action without notice.
86 |
87 | ## Attribution
88 |
89 | This Code of Conduct is adapted from the Contributor Covenant, version 1.4,
90 | available at
91 | https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
92 |
--------------------------------------------------------------------------------
/zhipuai/api_resource/audio/audio.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import TYPE_CHECKING, List, Mapping, cast, Optional, Dict
4 | from .transcriptions import Transcriptions
5 |
6 | from zhipuai.core._utils import extract_files
7 |
8 | from zhipuai.types.sensitive_word_check import SensitiveWordCheckRequest
9 | from zhipuai.types.audio import AudioSpeechParams
10 | from ...types.audio import audio_customization_param
11 |
12 | from zhipuai.core import BaseAPI, maybe_transform, StreamResponse
13 | from zhipuai.core import NOT_GIVEN, Body, Headers, NotGiven, FileTypes
14 | from zhipuai.core import _legacy_response
15 |
16 | import httpx
17 | from ...core import BaseAPI, cached_property
18 |
19 | from zhipuai.core import (
20 | make_request_options,
21 | )
22 | from zhipuai.core import deepcopy_minimal
23 | from ...types.audio.audio_speech_chunk import AudioSpeechChunk
24 |
25 | if TYPE_CHECKING:
26 | from zhipuai._client import ZhipuAI
27 |
28 | __all__ = ["Audio"]
29 |
30 |
31 | class Audio(BaseAPI):
32 |
33 | @cached_property
34 | def transcriptions(self) -> Transcriptions:
35 | return Transcriptions(self._client)
36 |
37 | def __init__(self, client: "ZhipuAI") -> None:
38 | super().__init__(client)
39 |
40 | def speech(
41 | self,
42 | *,
43 | model: str,
44 | input: str = None,
45 | voice: str = None,
46 | response_format: str = None,
47 | sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN,
48 | request_id: str = None,
49 | user_id: str = None,
50 | stream: bool = False,
51 | extra_headers: Headers | None = None,
52 | extra_body: Body | None = None,
53 | timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
54 | encode_format: str,
55 | ) -> _legacy_response.HttpxBinaryResponseContent | StreamResponse[AudioSpeechChunk]:
56 | body = deepcopy_minimal(
57 | {
58 | "model": model,
59 | "input": input,
60 | "voice": voice,
61 | "stream": stream,
62 | "response_format": response_format,
63 | "sensitive_word_check": sensitive_word_check,
64 | "request_id": request_id,
65 | "user_id": user_id,
66 | "encode_format": encode_format
67 | }
68 | )
69 | return self._post(
70 | "/audio/speech",
71 | body=body,
72 | options=make_request_options(
73 | extra_headers=extra_headers, extra_body=extra_body, timeout=timeout
74 | ),
75 | cast_type=_legacy_response.HttpxBinaryResponseContent,
76 |             stream=stream or False,
77 | stream_cls=StreamResponse[AudioSpeechChunk]
78 | )
79 |
80 | def customization(
81 | self,
82 | *,
83 | model: str,
84 | input: str = None,
85 | voice_text: str = None,
86 | voice_data: FileTypes = None,
87 | response_format: str = None,
88 | sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN,
89 | request_id: str = None,
90 | user_id: str = None,
91 | extra_headers: Headers | None = None,
92 | extra_body: Body | None = None,
93 | timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
94 | ) -> _legacy_response.HttpxBinaryResponseContent:
95 | body = deepcopy_minimal(
96 | {
97 | "model": model,
98 | "input": input,
99 | "voice_text": voice_text,
100 | "voice_data": voice_data,
101 | "response_format": response_format,
102 | "sensitive_word_check": sensitive_word_check,
103 | "request_id": request_id,
104 | "user_id": user_id
105 | }
106 | )
107 | files = extract_files(cast(Mapping[str, object], body), paths=[["voice_data"]])
108 |
109 | if files:
110 | extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
111 | return self._post(
112 | "/audio/customization",
113 | body=maybe_transform(body, audio_customization_param.AudioCustomizationParam),
114 | files=files,
115 | options=make_request_options(
116 | extra_headers=extra_headers, extra_body=extra_body, timeout=timeout
117 | ),
118 | cast_type=_legacy_response.HttpxBinaryResponseContent
119 | )
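
A hedged speech-synthesis sketch for the resource above; the model, voice, and encode_format values are placeholders, and the `.content` accessor on the binary wrapper is an assumption:

    from zhipuai import ZhipuAI

    client = ZhipuAI()

    speech = client.audio.speech(
        model="cogtts",          # placeholder model name
        input="Hello from the SDK",
        voice="tongtong",        # placeholder voice name
        response_format="wav",
        encode_format="wav",     # required keyword in the signature above; value is illustrative
        stream=False,
    )

    with open("speech.wav", "wb") as f:
        f.write(speech.content)  # accessor assumed from the legacy binary response wrapper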
--------------------------------------------------------------------------------