├── README.md ├── computer1.py ├── computer2.py ├── filesfolder ├── output │ └── result.txt └── screenshot.png ├── pp_weights ├── ch_PP-OCRv4_det_infer │ ├── inference.pdiparams │ ├── inference.pdiparams.info │ └── inference.pdmodel ├── ch_PP-OCRv4_rec_infer │ ├── inference.pdiparams │ ├── inference.pdiparams.info │ └── inference.pdmodel └── ch_ppocr_mobile_v2.0_cls_slim_infer │ ├── inference.pdiparams │ ├── inference.pdiparams.info │ ├── inference.pdmodel │ └── paddle_infer.log ├── requirements_computer1.txt ├── requirements_computer2.txt ├── utils ├── __pycache__ │ ├── call_llm.cpython-310.pyc │ └── ocr_func.cpython-310.pyc ├── call_llm.py └── ocr_func.py └── zhipuai_sdk ├── .github ├── actions │ └── poetry_setup │ │ └── action.yml └── workflows │ ├── _release.yml │ └── _test_release.yml ├── .gitignore ├── Makefile ├── README.md ├── poetry.toml ├── pyproject.toml ├── setup.py ├── tests ├── agent │ ├── glm3_agent.py │ ├── test_agent.py │ ├── test_create_structured_funcation_call.py │ └── test_prompt_parser_agent.py ├── conftest.py ├── integration_tests │ ├── demo.jsonl │ ├── img │ │ └── MetaGLM.png │ ├── test_charglm3.py │ ├── test_chat.py │ ├── test_embedding.py │ ├── test_file.py │ ├── test_finetuning.py │ └── test_images.py └── unit_tests │ └── test_sdk_import.py └── zhipuai ├── __init__.py ├── __version__.py ├── _client.py ├── api_resource ├── __init__.py ├── chat │ ├── __init__.py │ ├── async_completions.py │ ├── chat.py │ └── completions.py ├── embeddings.py ├── files.py ├── fine_tuning │ ├── __init__.py │ ├── fine_tuning.py │ ├── jobs │ │ ├── __init__.py │ │ └── jobs.py │ └── models │ │ ├── __init__.py │ │ └── fine_tuned_models.py └── images.py ├── core ├── __init__.py ├── _base_api.py ├── _base_compat.py ├── _base_models.py ├── _base_type.py ├── _errors.py ├── _files.py ├── _http_client.py ├── _jwt_token.py ├── _request_opt.py ├── _response.py ├── _sse_client.py ├── _utils │ ├── __init__.py │ ├── _typing.py │ └── _utils.py └── logs.py └── types ├── 
__init__.py ├── chat ├── __init__.py ├── async_chat_completion.py ├── chat_completion.py ├── chat_completion_chunk.py └── chat_completions_create_param.py ├── embeddings.py ├── file_object.py ├── fine_tuning ├── __init__.py ├── fine_tuning_job.py ├── fine_tuning_job_event.py ├── job_create_params.py └── models │ ├── __init__.py │ └── fine_tuned_models.py └── image.py /README.md: -------------------------------------------------------------------------------- 1 | # 面试助手 2 | 面试助手, 使用ocr+大模型反击八股文, 使用本代码,你需要有两台电脑,一台面试用电脑用于获取屏幕截屏,一台用于请求大模型获取code,便于照着打字你输入(上一版本是自动输入,但是速度太快,不能模拟人的手速) 3 | 4 | ## 快捷键信息 5 | 6 | ### 电脑2,获取答案用电脑 7 | pip install -r requirements_computer2.txt 8 | 9 | 填上api key 和 大模型类型(如果是openai,填对应的gpt的模型名) 10 | key = "xxxxx" 11 | llm_type = "glm-4" 12 | 13 | 运行 python computer2.py 14 | 获取当前的请求url, 一般是192.168.x.x:9210, 这里我设的端口是9210 15 | 16 | 17 | ### 电脑1,面试用电脑 18 | pip install -r requirements_computer1.txt 19 | 20 | 填入 电脑2运行后的url = "http://xxxx:9210/get_code" 21 | 22 | 运行 python computer1.py 23 | 24 | 快捷键 'Shift+Z' 25 | 截图并请求答案, 这边需要先把鼠标移到要截取的区域右下角,会从屏幕左上角(0,0)截到该位置 26 | 27 | 快捷键 'Shift+X' 28 | 重新识别上一次的图片,重新请求大模型获取答案, 如果要再次截图识别获取答案,请再次按快捷键 'Shift+Z' , 注意鼠标位置!!!!! 
29 | 30 | ## 本项目旨在反击无脑八股文面试,请勿滥用哟 31 | -------------------------------------------------------------------------------- /computer1.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pyautogui 3 | import keyboard 4 | import requests 5 | 6 | 7 | url = "http://XXXXX:9210/get_code" 8 | 9 | savefolder = './filesfolder' 10 | 11 | def get_screen_and_rquest_llm(): 12 | print('start screenshot !') 13 | x, y = pyautogui.position() 14 | img = pyautogui.screenshot(region=[0,0,x,y]) # x,y,w,h 15 | img_save_path = os.path.join(savefolder, 'screenshot.png') 16 | img.save(img_save_path) 17 | files={'input_file': open(img_save_path,'rb')} 18 | response = requests.request("POST", url, files=files) 19 | if response.json()['msg'] == 'success': 20 | print('please see the computer2 screen and get code !') 21 | else: 22 | print('please "Shift+Z" get new img or use "Shift+R" get code again !') 23 | 24 | ## don't screenshot, continue ocr the old picture 25 | def get_code_again(): 26 | img_save_path = os.path.join(savefolder, 'screenshot.png') 27 | files={'input_file': open(img_save_path,'rb')} 28 | response = requests.request("POST", url, files=files) 29 | if response.json()['msg'] == 'success': 30 | print('please see the computer2 screen and get code !') 31 | else: 32 | print('please "Shift+Z" get new img or use "Shift+R" get code again !') 33 | 34 | 35 | keyboard.add_hotkey('Shift+Z', get_screen_and_rquest_llm) 36 | keyboard.add_hotkey('Shift+X', get_code_again) 37 | 38 | 39 | # 进入监听状态 40 | try: 41 | keyboard.wait('ctrl+c') 42 | except KeyboardInterrupt: 43 | pass -------------------------------------------------------------------------------- /computer2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | from flask import Flask, request, make_response 5 | from utils.ocr_func import ocr_init, ocr_infer 6 | from utils.call_llm import 
query_llm 7 | 8 | 9 | ########## flask func ########## 10 | app = Flask(__name__) 11 | 12 | class JsonEncoder(json.JSONEncoder): 13 | """Convert numpy classes to JSON serializable objects.""" 14 | 15 | def default(self, obj): 16 | if isinstance(obj, (np.integer, np.floating, np.bool_)): 17 | return obj.item() 18 | elif isinstance(obj, np.ndarray): 19 | return obj.tolist() 20 | else: 21 | return super(JsonEncoder, self).default(obj) 22 | 23 | 24 | def json_dumps(data): 25 | return json.dumps(data, ensure_ascii=False, cls=JsonEncoder) 26 | 27 | def wrap_resp(res, status=400, headers={'Content-Type': 'application/json'}): 28 | resp = make_response(json_dumps(res), status) 29 | resp.headers = headers 30 | return resp 31 | 32 | ########## ocr init ######### 33 | paddle_ocr = ocr_init() 34 | 35 | ######### llm api key ######### 36 | key = "XXXXXX" 37 | llm_type = "glm-4" 38 | 39 | ######## mian func ########### 40 | savefolder = './filesfolder' 41 | 42 | def ocr_screen_and_get_answer(image, out_folder): 43 | print('start ocr !') 44 | texts = ocr_infer(image, paddle_ocr) 45 | print('ocr down, start request llm !') 46 | command = query_llm(key, llm_type, texts) 47 | if '```python' in command: 48 | command = command.split('```python')[1].split('```')[0] 49 | elif '```' in command: 50 | command = command.split('```')[1].split('```')[0] 51 | else: 52 | print('the result is wrong, please query llm again !') 53 | print("~~温馨提示,代码已生成~~") 54 | with open(os.path.join(out_folder, 'result.txt'), 'w', encoding='utf-8') as fw: 55 | fw.write(command) 56 | print(command) 57 | return command 58 | 59 | @app.route('/get_code', methods = ['POST']) 60 | def get_code(): 61 | inf = request.files.get('input_file') 62 | out_folder = os.path.join(savefolder, 'output') 63 | os.makedirs(out_folder, exist_ok=True) 64 | os.makedirs(savefolder, exist_ok=True) 65 | input_file = os.path.join(savefolder, inf.filename) 66 | inf.save(input_file) 67 | try: 68 | out = ocr_screen_and_get_answer(input_file, 
out_folder) 69 | res = { 70 | 'msg': 'success', 71 | 'code': 0, 72 | 'data': out, 73 | } 74 | return wrap_resp(res, 200) 75 | except Exception as e: 76 | res = { 77 | 'msg': 'fail', 78 | 'code': -1, 79 | 'data': str(e) 80 | } 81 | return wrap_resp(res) 82 | 83 | 84 | if __name__ == '__main__': 85 | app.run('0.0.0.0', 9210, debug=False) 86 | 87 | 88 | -------------------------------------------------------------------------------- /filesfolder/output/result.txt: -------------------------------------------------------------------------------- 1 | 2 | def trap(heights): 3 | if not heights: 4 | return 0 5 | 6 | n = len(heights) 7 | left_max = [0] * n 8 | right_max = [0] * n 9 | water_trapped = 0 10 | 11 | left_max[0] = heights[0] 12 | for i in range(1, n): 13 | left_max[i] = max(heights[i], left_max[i - 1]) 14 | 15 | right_max[n - 1] = heights[n - 1] 16 | for i in range(n - 2, -1, -1): 17 | right_max[i] = max(heights[i], right_max[i + 1]) 18 | 19 | for i in range(1, n - 1): 20 | water_trapped += min(left_max[i], right_max[i]) - heights[i] 21 | 22 | return water_trapped 23 | -------------------------------------------------------------------------------- /filesfolder/screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxygoodjob/Hit-leetcode/b073d0855c1efacae59710ab9dcf6fb40325283a/filesfolder/screenshot.png -------------------------------------------------------------------------------- /pp_weights/ch_PP-OCRv4_det_infer/inference.pdiparams: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxygoodjob/Hit-leetcode/b073d0855c1efacae59710ab9dcf6fb40325283a/pp_weights/ch_PP-OCRv4_det_infer/inference.pdiparams -------------------------------------------------------------------------------- /pp_weights/ch_PP-OCRv4_det_infer/inference.pdiparams.info: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lxygoodjob/Hit-leetcode/b073d0855c1efacae59710ab9dcf6fb40325283a/pp_weights/ch_PP-OCRv4_det_infer/inference.pdiparams.info -------------------------------------------------------------------------------- /pp_weights/ch_PP-OCRv4_det_infer/inference.pdmodel: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxygoodjob/Hit-leetcode/b073d0855c1efacae59710ab9dcf6fb40325283a/pp_weights/ch_PP-OCRv4_det_infer/inference.pdmodel -------------------------------------------------------------------------------- /pp_weights/ch_PP-OCRv4_rec_infer/inference.pdiparams: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxygoodjob/Hit-leetcode/b073d0855c1efacae59710ab9dcf6fb40325283a/pp_weights/ch_PP-OCRv4_rec_infer/inference.pdiparams -------------------------------------------------------------------------------- /pp_weights/ch_PP-OCRv4_rec_infer/inference.pdiparams.info: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxygoodjob/Hit-leetcode/b073d0855c1efacae59710ab9dcf6fb40325283a/pp_weights/ch_PP-OCRv4_rec_infer/inference.pdiparams.info -------------------------------------------------------------------------------- /pp_weights/ch_PP-OCRv4_rec_infer/inference.pdmodel: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxygoodjob/Hit-leetcode/b073d0855c1efacae59710ab9dcf6fb40325283a/pp_weights/ch_PP-OCRv4_rec_infer/inference.pdmodel -------------------------------------------------------------------------------- /pp_weights/ch_ppocr_mobile_v2.0_cls_slim_infer/inference.pdiparams: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lxygoodjob/Hit-leetcode/b073d0855c1efacae59710ab9dcf6fb40325283a/pp_weights/ch_ppocr_mobile_v2.0_cls_slim_infer/inference.pdiparams -------------------------------------------------------------------------------- /pp_weights/ch_ppocr_mobile_v2.0_cls_slim_infer/inference.pdiparams.info: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxygoodjob/Hit-leetcode/b073d0855c1efacae59710ab9dcf6fb40325283a/pp_weights/ch_ppocr_mobile_v2.0_cls_slim_infer/inference.pdiparams.info -------------------------------------------------------------------------------- /pp_weights/ch_ppocr_mobile_v2.0_cls_slim_infer/inference.pdmodel: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxygoodjob/Hit-leetcode/b073d0855c1efacae59710ab9dcf6fb40325283a/pp_weights/ch_ppocr_mobile_v2.0_cls_slim_infer/inference.pdmodel -------------------------------------------------------------------------------- /pp_weights/ch_ppocr_mobile_v2.0_cls_slim_infer/paddle_infer.log: -------------------------------------------------------------------------------- 1 | ➜ PaddleOCR git:(dygraph) ✗ python deploy/slim/quantization/export_model.py -c configs/cls/cls_mv3.yml -o Global.checkpoints=cls_qat/best_accuracy Global.save_model_dir=./cls_qat_inference_model 2 | [2021/01/13 21:11:21] root INFO: Architecture : 3 | [2021/01/13 21:11:21] root INFO: Backbone : 4 | [2021/01/13 21:11:21] root INFO: model_name : small 5 | [2021/01/13 21:11:21] root INFO: name : MobileNetV3 6 | [2021/01/13 21:11:21] root INFO: scale : 0.35 7 | [2021/01/13 21:11:21] root INFO: Head : 8 | [2021/01/13 21:11:21] root INFO: class_dim : 2 9 | [2021/01/13 21:11:21] root INFO: name : ClsHead 10 | [2021/01/13 21:11:21] root INFO: Neck : None 11 | [2021/01/13 21:11:21] root INFO: Transform : None 12 | [2021/01/13 21:11:21] root INFO: algorithm : CLS 13 | [2021/01/13 21:11:21] root INFO: 
model_type : cls 14 | [2021/01/13 21:11:21] root INFO: Eval : 15 | [2021/01/13 21:11:21] root INFO: dataset : 16 | [2021/01/13 21:11:21] root INFO: data_dir : ./train_data/cls 17 | [2021/01/13 21:11:21] root INFO: label_file_list : ['./train_data/cls/test.txt'] 18 | [2021/01/13 21:11:21] root INFO: name : SimpleDataSet 19 | [2021/01/13 21:11:21] root INFO: transforms : 20 | [2021/01/13 21:11:21] root INFO: DecodeImage : 21 | [2021/01/13 21:11:21] root INFO: channel_first : False 22 | [2021/01/13 21:11:21] root INFO: img_mode : BGR 23 | [2021/01/13 21:11:21] root INFO: ClsLabelEncode : None 24 | [2021/01/13 21:11:21] root INFO: ClsResizeImg : 25 | [2021/01/13 21:11:21] root INFO: image_shape : [3, 48, 192] 26 | [2021/01/13 21:11:21] root INFO: KeepKeys : 27 | [2021/01/13 21:11:21] root INFO: keep_keys : ['image', 'label'] 28 | [2021/01/13 21:11:21] root INFO: loader : 29 | [2021/01/13 21:11:21] root INFO: batch_size_per_card : 512 30 | [2021/01/13 21:11:21] root INFO: drop_last : False 31 | [2021/01/13 21:11:21] root INFO: num_workers : 4 32 | [2021/01/13 21:11:21] root INFO: shuffle : False 33 | [2021/01/13 21:11:21] root INFO: Global : 34 | [2021/01/13 21:11:21] root INFO: cal_metric_during_train : True 35 | [2021/01/13 21:11:21] root INFO: checkpoints : cls_qat/best_accuracy 36 | [2021/01/13 21:11:21] root INFO: debug : False 37 | [2021/01/13 21:11:21] root INFO: distributed : False 38 | [2021/01/13 21:11:21] root INFO: epoch_num : 100 39 | [2021/01/13 21:11:21] root INFO: eval_batch_step : [0, 1000] 40 | [2021/01/13 21:11:21] root INFO: infer_img : doc/imgs_words_en/word_10.png 41 | [2021/01/13 21:11:21] root INFO: label_list : ['0', '180'] 42 | [2021/01/13 21:11:21] root INFO: log_smooth_window : 20 43 | [2021/01/13 21:11:21] root INFO: pretrained_model : None 44 | [2021/01/13 21:11:21] root INFO: print_batch_step : 10 45 | [2021/01/13 21:11:21] root INFO: save_epoch_step : 3 46 | [2021/01/13 21:11:21] root INFO: save_inference_dir : None 47 | [2021/01/13 
21:11:21] root INFO: save_model_dir : ./cls_qat_inference_model 48 | [2021/01/13 21:11:21] root INFO: use_gpu : True 49 | [2021/01/13 21:11:21] root INFO: use_visualdl : False 50 | [2021/01/13 21:11:21] root INFO: Loss : 51 | [2021/01/13 21:11:21] root INFO: name : ClsLoss 52 | [2021/01/13 21:11:21] root INFO: Metric : 53 | [2021/01/13 21:11:21] root INFO: main_indicator : acc 54 | [2021/01/13 21:11:21] root INFO: name : ClsMetric 55 | [2021/01/13 21:11:21] root INFO: Optimizer : 56 | [2021/01/13 21:11:21] root INFO: beta1 : 0.9 57 | [2021/01/13 21:11:21] root INFO: beta2 : 0.999 58 | [2021/01/13 21:11:21] root INFO: lr : 59 | [2021/01/13 21:11:21] root INFO: learning_rate : 0.001 60 | [2021/01/13 21:11:21] root INFO: name : Cosine 61 | [2021/01/13 21:11:21] root INFO: name : Adam 62 | [2021/01/13 21:11:21] root INFO: regularizer : 63 | [2021/01/13 21:11:21] root INFO: factor : 0 64 | [2021/01/13 21:11:21] root INFO: name : L2 65 | [2021/01/13 21:11:21] root INFO: PostProcess : 66 | [2021/01/13 21:11:21] root INFO: name : ClsPostProcess 67 | [2021/01/13 21:11:21] root INFO: Train : 68 | [2021/01/13 21:11:21] root INFO: dataset : 69 | [2021/01/13 21:11:21] root INFO: data_dir : ./train_data/cls 70 | [2021/01/13 21:11:21] root INFO: label_file_list : ['./train_data/cls/train.txt'] 71 | [2021/01/13 21:11:21] root INFO: name : SimpleDataSet 72 | [2021/01/13 21:11:21] root INFO: transforms : 73 | [2021/01/13 21:11:21] root INFO: DecodeImage : 74 | [2021/01/13 21:11:21] root INFO: channel_first : False 75 | [2021/01/13 21:11:21] root INFO: img_mode : BGR 76 | [2021/01/13 21:11:21] root INFO: ClsLabelEncode : None 77 | [2021/01/13 21:11:21] root INFO: RecAug : 78 | [2021/01/13 21:11:21] root INFO: use_tia : False 79 | [2021/01/13 21:11:21] root INFO: RandAugment : None 80 | [2021/01/13 21:11:21] root INFO: ClsResizeImg : 81 | [2021/01/13 21:11:21] root INFO: image_shape : [3, 48, 192] 82 | [2021/01/13 21:11:21] root INFO: KeepKeys : 83 | [2021/01/13 21:11:21] root INFO: 
keep_keys : ['image', 'label'] 84 | [2021/01/13 21:11:21] root INFO: loader : 85 | [2021/01/13 21:11:21] root INFO: batch_size_per_card : 512 86 | [2021/01/13 21:11:21] root INFO: drop_last : True 87 | [2021/01/13 21:11:21] root INFO: num_workers : 8 88 | [2021/01/13 21:11:21] root INFO: shuffle : True 89 | [2021/01/13 21:11:21] root INFO: train with paddle 0.0.0 and device CUDAPlace(0) 90 | W0113 21:11:21.374989 83577 device_context.cc:320] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.0, Runtime API Version: 10.0 91 | W0113 21:11:21.399595 83577 device_context.cc:330] device: 0, cuDNN Version: 7.6. 92 | [2021/01/13 21:11:32] root INFO: resume from cls_qat/best_accuracy 93 | [2021/01/13 21:11:32] root INFO: Initialize indexs of datasets:['./train_data/cls/test.txt'] 94 | eval model:: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 803/803 [04:26<00:00, 3.02it/s] 95 | [2021/01/13 21:16:00] root INFO: metric eval *************** 96 | [2021/01/13 21:16:00] root INFO: acc:0.9449500187460135 97 | [2021/01/13 21:16:00] root INFO: fps:1979.0335715470349 98 | -------------------------------------------------------------------------------- /requirements_computer1.txt: -------------------------------------------------------------------------------- 1 | pyautogui 2 | keyboard 3 | -------------------------------------------------------------------------------- /requirements_computer2.txt: -------------------------------------------------------------------------------- 1 | flask 2 | paddlepaddle 3 | paddleocr 4 | openai 5 | zhipuai -------------------------------------------------------------------------------- /utils/__pycache__/call_llm.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lxygoodjob/Hit-leetcode/b073d0855c1efacae59710ab9dcf6fb40325283a/utils/__pycache__/call_llm.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/ocr_func.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxygoodjob/Hit-leetcode/b073d0855c1efacae59710ab9dcf6fb40325283a/utils/__pycache__/ocr_func.cpython-310.pyc -------------------------------------------------------------------------------- /utils/call_llm.py: -------------------------------------------------------------------------------- 1 | from zhipuai_sdk.zhipuai import ZhipuAI 2 | from openai import OpenAI 3 | 4 | 5 | def query_llm(api_key, model_name, user_input): 6 | if 'glm' in model_name.lower(): 7 | client = ZhipuAI(api_key=api_key) 8 | elif 'gpt' in model_name.lower(): 9 | client = OpenAI(api_key=api_key) 10 | else: 11 | return None 12 | 13 | chat_completion = client.chat.completions.create( 14 | messages=[ 15 | { 16 | "role": "system", 17 | "content": "你是一个撰写代码的专家,请依据用户输入信息,来撰写代码,直接输出代码函数,代码中不要包含注释。" 18 | }, 19 | { 20 | "role": "user", 21 | "content": user_input, 22 | } 23 | ], 24 | model=model_name, 25 | top_p=0.7, 26 | temperature=0.9 27 | ) 28 | 29 | return chat_completion.choices[0].message.content 30 | 31 | if __name__ == '__main__': 32 | key = "" 33 | result = query_llm(key, "glm-4", "冒泡排序") 34 | print(result) 35 | -------------------------------------------------------------------------------- /utils/ocr_func.py: -------------------------------------------------------------------------------- 1 | from paddleocr import PaddleOCR 2 | import numpy as np 3 | 4 | # paddle weights 5 | det_model_path = "./pp_weights/ch_PP-OCRv4_det_infer" 6 | rec_model_path = "./pp_weights/ch_PP-OCRv4_rec_infer" 7 | cls_model_path = "./pp_weights/ch_ppocr_mobile_v2.0_cls_slim_infer" 8 | 9 | def ocr_init(): 10 | paddle_ocr = 
PaddleOCR(use_angle_cls=True, lang='ch', 11 | det_model_dir=det_model_path, 12 | rec_model_dir=rec_model_path, 13 | cls_model_dir=cls_model_path, 14 | det_limit_type='max', 15 | drop_score=0.4) 16 | print('ocr init down !') 17 | return paddle_ocr 18 | 19 | def check_merge(text1, text2): 20 | x1, y1, x2, y2, x3, y3, x4, y4 = text1[:8] 21 | p1, q1, p2, q2, p3, q3, p4, q4 = text2[:8] 22 | t1 = np.array(text1[:8]).reshape((-1, 2)) 23 | t2 = np.array(text2[:8]).reshape((-1, 2)) 24 | if min(t1[:, 0]) <= min(t2[:, 0]): 25 | top_border, bottom_border = max(y2, q1), min(y3, q4) 26 | else: 27 | top_border, bottom_border = max(q2, y1), min(q3, y4) 28 | 29 | center1 = (max(t1[:, 1]) + min(t1[:, 1])) / 2 30 | center2 = (max(t2[:, 1]) + min(t2[:, 1])) / 2 31 | 32 | if max(bottom_border,top_border) - min(bottom_border, top_border) >= \ 33 | max(center1, center2) - min(center1, center2): 34 | merge = True 35 | else: 36 | merge = False 37 | return merge 38 | 39 | def get_line_dict(coord_lists): 40 | for i in range(len(coord_lists)): 41 | t_ = np.array(coord_lists[i][:8]).reshape((-1, 2)) 42 | centerx, centery = (max(t_[:, 0]) + min(t_[:, 0])) / 2, (max(t_[:, 1]) + min(t_[:, 1])) / 2 43 | coord_lists[i].extend([centerx, centery]) 44 | coord_lists = sorted(coord_lists, key=lambda s: (s[-1], s[-2])) 45 | rec_dict = dict() 46 | rec_dict[0] = [coord_lists[0]] 47 | n = 0 48 | for i in range(1, len(coord_lists)): 49 | rec_dict[n] = sorted(rec_dict[n], key=lambda s: s[-2]) 50 | merge = check_merge(rec_dict[n][-1], coord_lists[i]) 51 | if merge: 52 | rec_dict[n].append(coord_lists[i]) 53 | else: 54 | n += 1 55 | rec_dict[n] = [coord_lists[i]] 56 | return rec_dict 57 | 58 | def ocr_infer(img_path, paddle_ocr): 59 | total_text = '' 60 | result = paddle_ocr.ocr(img_path, cls=True) 61 | texts = [] 62 | for line_ in result: 63 | for line in line_: 64 | text = [] 65 | bbox = np.array(line[0]) 66 | bbox = bbox.reshape((-1)).tolist() 67 | t_ = line[1][0] 68 | text.extend(bbox) 69 | text.append(t_) 
70 | texts.append(text) 71 | rec_dict = get_line_dict(texts) 72 | for key in rec_dict.keys(): 73 | rec_dict[key] = sorted(rec_dict[key], key=lambda s: s[-2]) 74 | lines = rec_dict[key] 75 | t_ = '' 76 | for item in lines: 77 | t_ += item[8] 78 | total_text += t_ + '\n' 79 | return total_text 80 | 81 | if __name__ == '__main__': 82 | im = r'C:\Users\xiaoyuli\Desktop\task\screen_imgs\237d6db5-86ba-4f7b-9930-36936352c7ab.png' 83 | paddle_ocr = ocr_init() 84 | ocr_infer(im, paddle_ocr) -------------------------------------------------------------------------------- /zhipuai_sdk/.github/actions/poetry_setup/action.yml: -------------------------------------------------------------------------------- 1 | # An action for setting up poetry install with caching. 2 | # Using a custom action since the default action does not 3 | # take poetry install groups into account. 4 | # Action code from: 5 | # https://github.com/actions/setup-python/issues/505#issuecomment-1273013236 6 | name: poetry-install-with-caching 7 | description: Poetry install with support for caching of dependency groups. 
8 | 9 | inputs: 10 | python-version: 11 | description: Python version, supporting MAJOR.MINOR only 12 | required: true 13 | 14 | poetry-version: 15 | description: Poetry version 16 | required: true 17 | 18 | cache-key: 19 | description: Cache key to use for manual handling of caching 20 | required: true 21 | 22 | working-directory: 23 | description: Directory whose poetry.lock file should be cached 24 | required: true 25 | 26 | runs: 27 | using: composite 28 | steps: 29 | - uses: actions/setup-python@v5 30 | name: Setup python ${{ inputs.python-version }} 31 | id: setup-python 32 | with: 33 | python-version: ${{ inputs.python-version }} 34 | 35 | - uses: actions/cache@v4 36 | id: cache-bin-poetry 37 | name: Cache Poetry binary - Python ${{ inputs.python-version }} 38 | env: 39 | SEGMENT_DOWNLOAD_TIMEOUT_MIN: "1" 40 | with: 41 | path: | 42 | /opt/pipx/venvs/poetry 43 | # This step caches the poetry installation, so make sure it's keyed on the poetry version as well. 44 | key: bin-poetry-${{ runner.os }}-${{ runner.arch }}-py-${{ inputs.python-version }}-${{ inputs.poetry-version }} 45 | 46 | - name: Refresh shell hashtable and fixup softlinks 47 | if: steps.cache-bin-poetry.outputs.cache-hit == 'true' 48 | shell: bash 49 | env: 50 | POETRY_VERSION: ${{ inputs.poetry-version }} 51 | PYTHON_VERSION: ${{ inputs.python-version }} 52 | run: | 53 | set -eux 54 | 55 | # Refresh the shell hashtable, to ensure correct `which` output. 56 | hash -r 57 | 58 | # `actions/cache@v3` doesn't always seem able to correctly unpack softlinks. 59 | # Delete and recreate the softlinks pipx expects to have. 60 | rm /opt/pipx/venvs/poetry/bin/python 61 | cd /opt/pipx/venvs/poetry/bin 62 | ln -s "$(which "python$PYTHON_VERSION")" python 63 | chmod +x python 64 | cd /opt/pipx_bin/ 65 | ln -s /opt/pipx/venvs/poetry/bin/poetry poetry 66 | chmod +x poetry 67 | 68 | # Ensure everything got set up correctly. 
69 | /opt/pipx/venvs/poetry/bin/python --version 70 | /opt/pipx_bin/poetry --version 71 | 72 | - name: Install poetry 73 | if: steps.cache-bin-poetry.outputs.cache-hit != 'true' 74 | shell: bash 75 | env: 76 | POETRY_VERSION: ${{ inputs.poetry-version }} 77 | PYTHON_VERSION: ${{ inputs.python-version }} 78 | # Install poetry using the python version installed by setup-python step. 79 | run: pipx install "poetry==$POETRY_VERSION" --python '${{ steps.setup-python.outputs.python-path }}' --verbose 80 | 81 | - name: Restore pip and poetry cached dependencies 82 | uses: actions/cache@v4 83 | env: 84 | SEGMENT_DOWNLOAD_TIMEOUT_MIN: "4" 85 | WORKDIR: ${{ inputs.working-directory == '' && '.' || inputs.working-directory }} 86 | with: 87 | path: | 88 | ~/.cache/pip 89 | ~/.cache/pypoetry/virtualenvs 90 | ~/.cache/pypoetry/cache 91 | ~/.cache/pypoetry/artifacts 92 | ${{ env.WORKDIR }}/.venv 93 | key: py-deps-${{ runner.os }}-${{ runner.arch }}-py-${{ inputs.python-version }}-poetry-${{ inputs.poetry-version }}-${{ inputs.cache-key }}-${{ hashFiles(format('{0}/**/poetry.lock', env.WORKDIR)) }} -------------------------------------------------------------------------------- /zhipuai_sdk/.github/workflows/_release.yml: -------------------------------------------------------------------------------- 1 | name: release 2 | run-name: Release ${{ inputs.working-directory }} by @${{ github.actor }} 3 | on: 4 | workflow_call: 5 | inputs: 6 | working-directory: 7 | required: true 8 | type: string 9 | description: "From which folder this pipeline executes" 10 | workflow_dispatch: 11 | inputs: 12 | working-directory: 13 | required: true 14 | type: string 15 | default: '.' 
16 | description: "From which folder this pipeline executes" 17 | env: 18 | PYTHON_VERSION: "3.8" 19 | POETRY_VERSION: "1.7.1" 20 | 21 | jobs: 22 | build: 23 | if: github.ref == 'refs/heads/main' 24 | environment: Scheduled testing publish 25 | runs-on: ubuntu-latest 26 | 27 | outputs: 28 | pkg-name: ${{ steps.check-version.outputs.pkg-name }} 29 | version: ${{ steps.check-version.outputs.version }} 30 | 31 | steps: 32 | - uses: actions/checkout@v4 33 | 34 | - name: Set up Python + Poetry ${{ env.POETRY_VERSION }} 35 | uses: "./.github/actions/poetry_setup" 36 | with: 37 | python-version: ${{ env.PYTHON_VERSION }} 38 | poetry-version: ${{ env.POETRY_VERSION }} 39 | working-directory: ${{ inputs.working-directory }} 40 | cache-key: release 41 | 42 | # We want to keep this build stage *separate* from the release stage, 43 | # so that there's no sharing of permissions between them. 44 | # The release stage has trusted publishing and GitHub repo contents write access, 45 | # and we want to keep the scope of that access limited just to the release job. 46 | # Otherwise, a malicious `build` step (e.g. via a compromised dependency) 47 | # could get access to our GitHub or PyPI credentials. 48 | # 49 | # Per the trusted publishing GitHub Action: 50 | # > It is strongly advised to separate jobs for building [...] 51 | # > from the publish job. 
52 | # https://github.com/pypa/gh-action-pypi-publish#non-goals 53 | - name: Build project for distribution 54 | run: poetry build 55 | working-directory: ${{ inputs.working-directory }} 56 | 57 | - name: Upload build 58 | uses: actions/upload-artifact@v4 59 | with: 60 | name: dist 61 | path: ${{ inputs.working-directory }}/dist/ 62 | 63 | - name: Check Version 64 | id: check-version 65 | shell: bash 66 | working-directory: ${{ inputs.working-directory }} 67 | run: | 68 | echo pkg-name="$(poetry version | cut -d ' ' -f 1)" >> $GITHUB_OUTPUT 69 | echo version="$(poetry version --short)" >> $GITHUB_OUTPUT 70 | 71 | test-pypi-publish: 72 | needs: 73 | - build 74 | uses: 75 | ./.github/workflows/_test_release.yml 76 | with: 77 | working-directory: ${{ inputs.working-directory }} 78 | secrets: inherit 79 | 80 | pre-release-checks: 81 | needs: 82 | - build 83 | - test-pypi-publish 84 | environment: Scheduled testing publish 85 | runs-on: ubuntu-latest 86 | steps: 87 | - uses: actions/checkout@v4 88 | 89 | # We explicitly *don't* set up caching here. This ensures our tests are 90 | # maximally sensitive to catching breakage. 91 | # 92 | # For example, here's a way that caching can cause a falsely-passing test: 93 | # - Make the langchain package manifest no longer list a dependency package 94 | # as a requirement. This means it won't be installed by `pip install`, 95 | # and attempting to use it would cause a crash. 96 | # - That dependency used to be required, so it may have been cached. 97 | # When restoring the venv packages from cache, that dependency gets included. 98 | # - Tests pass, because the dependency is present even though it wasn't specified. 99 | # - The package is published, and it breaks on the missing dependency when 100 | # used in the real world. 
101 | 102 | - name: Set up Python + Poetry ${{ env.POETRY_VERSION }} 103 | uses: "./.github/actions/poetry_setup" 104 | with: 105 | python-version: ${{ env.PYTHON_VERSION }} 106 | poetry-version: ${{ env.POETRY_VERSION }} 107 | working-directory: ${{ inputs.working-directory }} 108 | 109 | - name: Import published package 110 | shell: bash 111 | working-directory: ${{ inputs.working-directory }} 112 | env: 113 | PKG_NAME: ${{ needs.build.outputs.pkg-name }} 114 | VERSION: ${{ needs.build.outputs.version }} 115 | # Here we use: 116 | # - The default regular PyPI index as the *primary* index, meaning 117 | # that it takes priority (https://pypi.org/simple) 118 | # - The test PyPI index as an extra index, so that any dependencies that 119 | # are not found on test PyPI can be resolved and installed anyway. 120 | # (https://test.pypi.org/simple). This will include the PKG_NAME==VERSION 121 | # package because VERSION will not have been uploaded to regular PyPI yet. 122 | # - attempt install again after 5 seconds if it fails because there is 123 | # sometimes a delay in availability on test pypi 124 | run: | 125 | poetry run pip install \ 126 | --extra-index-url https://test.pypi.org/simple/ \ 127 | "$PKG_NAME==$VERSION" || \ 128 | ( \ 129 | sleep 5 && \ 130 | poetry run pip install \ 131 | --extra-index-url https://test.pypi.org/simple/ \ 132 | "$PKG_NAME==$VERSION" \ 133 | ) 134 | 135 | # Replace all dashes in the package name with underscores, 136 | # since that's how Python imports packages with dashes in the name. 137 | IMPORT_NAME="$(echo "$PKG_NAME" | sed s/-/_/g)" 138 | 139 | poetry run python -c "import $IMPORT_NAME; print(dir($IMPORT_NAME))" 140 | 141 | - name: Import test dependencies 142 | run: poetry install --with test 143 | working-directory: ${{ inputs.working-directory }} 144 | 145 | # Overwrite the local version of the package with the test PyPI version. 
146 | - name: Import published package (again) 147 | working-directory: ${{ inputs.working-directory }} 148 | shell: bash 149 | env: 150 | PKG_NAME: ${{ needs.build.outputs.pkg-name }} 151 | VERSION: ${{ needs.build.outputs.version }} 152 | run: | 153 | poetry run pip install \ 154 | --extra-index-url https://test.pypi.org/simple/ \ 155 | "$PKG_NAME==$VERSION" 156 | 157 | - name: Run unit tests 158 | run: make tests 159 | env: 160 | ZHIPUAI_API_KEY: ${{ secrets.ZHIPUAI_API_KEY }} 161 | ZHIPUAI_BASE_URL: ${{ secrets.ZHIPUAI_BASE_URL }} 162 | working-directory: ${{ inputs.working-directory }} 163 | 164 | - name: Run integration tests 165 | env: 166 | ZHIPUAI_API_KEY: ${{ secrets.ZHIPUAI_API_KEY }} 167 | ZHIPUAI_BASE_URL: ${{ secrets.ZHIPUAI_BASE_URL }} 168 | run: make integration_tests 169 | working-directory: ${{ inputs.working-directory }} 170 | 171 | publish: 172 | needs: 173 | - build 174 | - test-pypi-publish 175 | - pre-release-checks 176 | environment: Scheduled testing publish 177 | runs-on: ubuntu-latest 178 | 179 | defaults: 180 | run: 181 | working-directory: ${{ inputs.working-directory }} 182 | 183 | steps: 184 | - uses: actions/checkout@v4 185 | 186 | - name: Set up Python + Poetry ${{ env.POETRY_VERSION }} 187 | uses: "./.github/actions/poetry_setup" 188 | with: 189 | python-version: ${{ env.PYTHON_VERSION }} 190 | poetry-version: ${{ env.POETRY_VERSION }} 191 | working-directory: ${{ inputs.working-directory }} 192 | cache-key: release 193 | 194 | - uses: actions/download-artifact@v4 195 | with: 196 | name: dist 197 | path: ${{ inputs.working-directory }}/dist/ 198 | 199 | - name: Publish package distributions to PyPI 200 | uses: pypa/gh-action-pypi-publish@release/v1 201 | 202 | with: 203 | packages-dir: ${{ inputs.working-directory }}/dist/ 204 | verbose: true 205 | print-hash: true 206 | user: __token__ 207 | password: ${{ secrets.PYPI_API_TOKEN }} 208 | # We overwrite any existing distributions with the same name and version. 
209 | # This is *only for CI use* and is *extremely dangerous* otherwise! 210 | # https://github.com/pypa/gh-action-pypi-publish#tolerating-release-package-file-duplicates 211 | skip-existing: true 212 | 213 | mark-release: 214 | needs: 215 | - build 216 | - test-pypi-publish 217 | - pre-release-checks 218 | - publish 219 | environment: Scheduled testing publish 220 | runs-on: ubuntu-latest 221 | permissions: 222 | # This permission is needed by `ncipollo/release-action` to 223 | # create the GitHub release. 224 | contents: write 225 | id-token: none 226 | 227 | defaults: 228 | run: 229 | working-directory: ${{ inputs.working-directory }} 230 | 231 | steps: 232 | - uses: actions/checkout@v4 233 | 234 | - name: Set up Python + Poetry ${{ env.POETRY_VERSION }} 235 | uses: "./.github/actions/poetry_setup" 236 | with: 237 | python-version: ${{ env.PYTHON_VERSION }} 238 | poetry-version: ${{ env.POETRY_VERSION }} 239 | working-directory: ${{ inputs.working-directory }} 240 | cache-key: release 241 | 242 | - uses: actions/download-artifact@v4 243 | with: 244 | name: dist 245 | path: ${{ inputs.working-directory }}/dist/ 246 | 247 | - name: Create Release 248 | uses: ncipollo/release-action@v1 249 | if: ${{ inputs.working-directory == '.' 
}} 250 | with: 251 | artifacts: "dist/*" 252 | token: ${{ secrets.GITHUB_TOKEN }} 253 | draft: false 254 | generateReleaseNotes: true 255 | tag: v${{ needs.build.outputs.version }} 256 | commit: main 257 | -------------------------------------------------------------------------------- /zhipuai_sdk/.github/workflows/_test_release.yml: -------------------------------------------------------------------------------- 1 | name: test-release 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | working-directory: 7 | required: true 8 | type: string 9 | description: "From which folder this pipeline executes" 10 | 11 | env: 12 | POETRY_VERSION: "1.7.1" 13 | PYTHON_VERSION: "3.8" 14 | 15 | jobs: 16 | build: 17 | if: github.ref == 'refs/heads/main' 18 | runs-on: ubuntu-latest 19 | 20 | outputs: 21 | pkg-name: ${{ steps.check-version.outputs.pkg-name }} 22 | version: ${{ steps.check-version.outputs.version }} 23 | 24 | steps: 25 | - uses: actions/checkout@v4 26 | 27 | - name: Set up Python + Poetry ${{ env.POETRY_VERSION }} 28 | uses: "./.github/actions/poetry_setup" 29 | with: 30 | python-version: ${{ env.PYTHON_VERSION }} 31 | poetry-version: ${{ env.POETRY_VERSION }} 32 | working-directory: ${{ inputs.working-directory }} 33 | cache-key: release 34 | 35 | # We want to keep this build stage *separate* from the release stage, 36 | # so that there's no sharing of permissions between them. 37 | # The release stage has trusted publishing and GitHub repo contents write access, 38 | # and we want to keep the scope of that access limited just to the release job. 39 | # Otherwise, a malicious `build` step (e.g. via a compromised dependency) 40 | # could get access to our GitHub or PyPI credentials. 41 | # 42 | # Per the trusted publishing GitHub Action: 43 | # > It is strongly advised to separate jobs for building [...] 44 | # > from the publish job. 
45 | # https://github.com/pypa/gh-action-pypi-publish#non-goals 46 | - name: Build project for distribution 47 | run: poetry build 48 | working-directory: ${{ inputs.working-directory }} 49 | 50 | - name: Upload build 51 | uses: actions/upload-artifact@v4 52 | with: 53 | name: test-dist 54 | path: ${{ inputs.working-directory }}/dist/ 55 | 56 | - name: Check Version 57 | id: check-version 58 | shell: bash 59 | working-directory: ${{ inputs.working-directory }} 60 | run: | 61 | echo pkg-name="$(poetry version | cut -d ' ' -f 1)" >> $GITHUB_OUTPUT 62 | echo version="$(poetry version --short)" >> $GITHUB_OUTPUT 63 | 64 | publish: 65 | needs: 66 | - build 67 | runs-on: ubuntu-latest 68 | environment: Scheduled testing publish 69 | # permissions: 70 | # id-token: none # This is required for requesting the JWT 71 | 72 | steps: 73 | - uses: actions/checkout@v4 74 | 75 | - uses: actions/download-artifact@v4 76 | with: 77 | name: test-dist 78 | path: ${{ inputs.working-directory }}/dist/ 79 | 80 | - name: Publish to test PyPI 81 | uses: pypa/gh-action-pypi-publish@release/v1 82 | with: 83 | user: __token__ 84 | password: ${{ secrets.TEST_PYPI_API_TOKEN }} 85 | packages-dir: ${{ inputs.working-directory }}/dist/ 86 | verbose: true 87 | print-hash: true 88 | repository-url: https://test.pypi.org/legacy/ 89 | # We overwrite any existing distributions with the same name and version. 90 | # This is *only for CI use* and is *extremely dangerous* otherwise! 
91 | # https://github.com/pypa/gh-action-pypi-publish#tolerating-release-package-file-duplicates 92 | skip-existing: true 93 | -------------------------------------------------------------------------------- /zhipuai_sdk/.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | .idea 3 | *.pyc 4 | __pycache__/ 5 | .pytest_cache/ 6 | .DS_Store/ 7 | 8 | *.swp 9 | *~ 10 | 11 | build/ 12 | dist/ 13 | eggs/ 14 | .eggs/ 15 | *.egg-info/ 16 | test/ 17 | examples/ 18 | .pypirc 19 | /poetry.lock 20 | /logs/local_1714381518619/zhipuai.log 21 | /logs/local_1714381521241/zhipuai.log 22 | /logs/local_1714381523244/zhipuai.log 23 | /logs/local_1714381534486/zhipuai.log 24 | /logs/local_1714381536161/zhipuai.log 25 | /logs/local_1714381537762/zhipuai.log 26 | /logs/local_1714381540228/zhipuai.log 27 | /logs/local_1714381541675/zhipuai.log 28 | /logs/local_1714381544183/zhipuai.log 29 | /logs/local_1714381554442/zhipuai.log 30 | /logs/local_1714381557457/zhipuai.log 31 | /logs/local_1714381560830/zhipuai.log 32 | /logs/local_1714381561389/zhipuai.log 33 | /logs/local_1714381570760/zhipuai.log 34 | /logs/local_1714381572227/zhipuai.log 35 | /logs/local_1714381574060/zhipuai.log 36 | /logs/local_1714381574645/zhipuai.log 37 | /logs/local_1714381586122/zhipuai.log 38 | /logs/local_1714381596219/zhipuai.log 39 | /logs/local_1714381598857/zhipuai.log 40 | /logs/local_1714381599471/zhipuai.log 41 | /logs/local_1714381609065/zhipuai.log 42 | /logs/local_1714381621513/zhipuai.log 43 | /logs/local_1714381622121/zhipuai.log 44 | /logs/local_1714381833346/zhipuai.log 45 | /logs/local_1714381918707/zhipuai.log 46 | /logs/local_1714382871154/zhipuai.log 47 | /logs/local_1714382884150/zhipuai.log 48 | /logs/local_1714382885846/zhipuai.log 49 | /logs/local_1714382887557/zhipuai.log 50 | /logs/local_1714382889632/zhipuai.log 51 | /logs/local_1714382891424/zhipuai.log 52 | /logs/local_1714382893632/zhipuai.log 53 | 
/logs/local_1714382904235/zhipuai.log 54 | /logs/local_1714382907134/zhipuai.log 55 | /logs/local_1714382908915/zhipuai.log 56 | /logs/local_1714382912290/zhipuai.log 57 | /logs/local_1714382915919/zhipuai.log 58 | /logs/local_1714382935679/zhipuai.log 59 | /logs/local_1714382945693/zhipuai.log 60 | /logs/local_1714382951968/zhipuai.log 61 | /logs/local_1714382953236/zhipuai.log 62 | /logs/local_1714382954392/zhipuai.log 63 | /logs/local_1714382957260/zhipuai.log 64 | /logs/local_1714382958299/zhipuai.log 65 | /logs/local_1714382959178/zhipuai.log 66 | /logs/local_1714382963052/zhipuai.log 67 | /logs/local_1714382964289/zhipuai.log 68 | /logs/local_1714382975167/zhipuai.log 69 | /logs/local_1714382979034/zhipuai.log 70 | -------------------------------------------------------------------------------- /zhipuai_sdk/Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: all format lint test tests test_watch integration_tests docker_tests help extended_tests 2 | 3 | # Default target executed when no arguments are given to make. 4 | all: help 5 | 6 | ###################### 7 | # TESTING AND COVERAGE 8 | ###################### 9 | 10 | # Define a variable for the test file path. 11 | TEST_FILE ?= tests/unit_tests/ 12 | 13 | 14 | test tests: 15 | poetry run pytest --disable-socket --allow-unix-socket $(TEST_FILE) 16 | integration_tests: 17 | poetry run pytest tests/integration_tests 18 | 19 | 20 | ###################### 21 | # LINTING AND FORMATTING 22 | ###################### 23 | 24 | # Define a variable for Python and notebook files. 25 | PYTHON_FILES=. 26 | MYPY_CACHE=.mypy_cache 27 | lint format: PYTHON_FILES=. 
28 | lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative --name-only --diff-filter=d main | grep -E '\.py$$|\.ipynb$$') 29 | lint_package: PYTHON_FILES=zhipuai 30 | lint_tests: PYTHON_FILES=tests 31 | lint_tests: MYPY_CACHE=.mypy_cache_test 32 | # PYTHON_FILES may be empty (e.g. lint_diff with no changed files); the per-file steps below are guarded so they are skipped in that case. 33 | lint lint_diff lint_package lint_tests: 34 | ./scripts/check_pydantic.sh . 35 | ./scripts/lint_imports.sh 36 | poetry run ruff . 37 | [ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff 38 | [ "$(PYTHON_FILES)" = "" ] || poetry run ruff --select I $(PYTHON_FILES) 39 | [ "$(PYTHON_FILES)" = "" ] || { mkdir -p $(MYPY_CACHE) && poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE); } 40 | 41 | format format_diff: 42 | [ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) 43 | [ "$(PYTHON_FILES)" = "" ] || poetry run ruff --select I --fix $(PYTHON_FILES) 44 | 45 | 46 | ###################### 47 | # HELP 48 | ###################### 49 | 50 | help: 51 | @echo '-- LINTING --' 52 | @echo 'format - run code formatters' 53 | @echo 'lint - run linters' 54 | @echo '-- TESTS --' 55 | @echo 'test - run unit tests' 56 | @echo 'tests - run unit tests (alias for "make test")' 57 | @echo 'test TEST_FILE= - run all tests in file' 58 | -------------------------------------------------------------------------------- /zhipuai_sdk/README.md: -------------------------------------------------------------------------------- 1 | # 智谱大模型开放接口SDK 2 | 3 | [![PyPI version](https://img.shields.io/pypi/v/zhipuai.svg)](https://pypi.org/project/zhipuai/) 4 | 5 | 智谱[开放平台](https://open.bigmodel.cn/dev/api)大模型接口 Python SDK(Big Model API SDK in Python),让开发者更便捷的调用智谱开放API 6 | 7 | 8 | ## 简介 9 | - 对所有接口进行了类型封装。 10 | - 初始化client并调用成员函数,无需关注http调用过程的各种细节,所见即所得。 11 | - 默认缓存token。 12 | 13 | ## 安装 14 | 15 | 16 | - 运行环境: [**Python>=3.7**](https://www.python.org/) 17 | 18 | - 使用 pip 安装 `zhipuai` 软件包及其依赖 19 | 20 | ```sh 21 | pip install zhipuai 22 | ``` 23 | 24 | ## 使用 25 | 26 | - 调用流程: 27 | 1.
使用 APISecretKey 创建 Client 28 | 2. 调用 Client 对应的成员方法 29 | - 开放平台[接口文档](https://open.bigmodel.cn/dev/api)以及[使用指南](https://open.bigmodel.cn/dev/howuse/)中有更多的 demo 示例,请在 demo 中使用自己的 ApiKey 进行测试。 30 | 31 | ### 创建Client 32 | SDK 支持通过环境变量配置 APIKey 33 | - env 34 | 35 | `ZHIPUAI_API_KEY`: 您的APIKey 36 | 37 | `ZHIPUAI_BASE_URL`: 您的API地址 38 | 39 | - 也支持通过代码传入APIKey 40 | ```python 41 | from zhipuai import ZhipuAI 42 | 43 | client = ZhipuAI( 44 | api_key="", # 填写您的 APIKey 45 | ) 46 | ``` 47 | ### 客户端网络连接配置 48 | 在`core/_http_client.py`中,可以配置网络连接的超时时间、重试次数、连接数限制等参数 49 | ```python 50 | # 通过 `Timeout` 控制接口`connect` 和 `read` 超时时间,默认为`timeout=300.0, connect=8.0` 51 | ZHIPUAI_DEFAULT_TIMEOUT = httpx.Timeout(timeout=300.0, connect=8.0) 52 | # 通过 `max_retries` 参数控制重试次数,默认为3次 53 | ZHIPUAI_DEFAULT_MAX_RETRIES = 3 54 | # 通过 `Limits` 控制最大连接数和保持连接数,默认为`max_connections=50, max_keepalive_connections=10` 55 | ZHIPUAI_DEFAULT_LIMITS = httpx.Limits(max_connections=50, max_keepalive_connections=10) 56 | 57 | ``` 58 | 同样在`ZhipuAI`入参中可以配置 59 | ```python 60 | client = ZhipuAI( 61 | timeout= httpx.Timeout(timeout=300.0, connect=8.0), 62 | max_retries=3, 63 | ) 64 | ``` 65 | 66 | 67 | ### 同步调用 68 | 69 | ```python 70 | from zhipuai import ZhipuAI 71 | 72 | client = ZhipuAI() # 填写您自己的APIKey 73 | response = client.chat.completions.create( 74 | model="glm-4", # 填写需要调用的模型名称 75 | messages=[ 76 | {"role": "user", "content": "作为一名营销专家,请为我的产品创作一个吸引人的slogan"}, 77 | {"role": "assistant", "content": "当然,为了创作一个吸引人的slogan,请告诉我一些关于您产品的信息"}, 78 | {"role": "user", "content": "智谱AI开放平台"}, 79 | {"role": "assistant", "content": "智启未来,谱绘无限一智谱AI,让创新触手可及!"}, 80 | {"role": "user", "content": "创造一个更精准、吸引人的slogan"} 81 | ], 82 | tools=[ 83 | { 84 | "type": "web_search", 85 | "web_search": { 86 | "search_query": "帮我看看清华的升学率", 87 | "search_result": True, 88 | } 89 | } 90 | ], 91 | # 拓展参数 92 | extra_body={"temperature": 0.5, "max_tokens": 50}, 93 | ) 94 | print(response) 95 | ``` 96 | 97 | ### SSE 调用 98 | 99 | ```python 100 | from zhipuai
import ZhipuAI 101 | client = ZhipuAI(api_key="") # 请填写您自己的APIKey 102 | response = client.chat.completions.create( 103 | model="", # 填写需要调用的模型名称 104 | messages=[ 105 | {"role": "system", "content": "你是一个人工智能助手,你叫叫chatGLM"}, 106 | {"role": "user", "content": "你好!你叫什么名字"}, 107 | ], 108 | stream=True, 109 | ) 110 | for chunk in response: 111 | print(chunk.choices[0].delta) 112 | ``` 113 | 114 | ### 多模态 115 | ```python 116 | 117 | 118 | # Function to encode the image 119 | def encode_image(image_path): 120 | import base64 121 | with open(image_path, "rb") as image_file: 122 | return base64.b64encode(image_file.read()).decode('utf-8') 123 | 124 | 125 | def test_completions_vis(): 126 | client = ZhipuAI() # 填写您自己的APIKey 127 | base64_image = encode_image("img/MetaGLM.png") 128 | response = client.chat.completions.create( 129 | model="glm-4v", # 填写需要调用的模型名称 130 | extra_body={"temperature": 0.5, "max_tokens": 50}, 131 | messages=[ 132 | { 133 | "role": "user", 134 | "content": [ 135 | { 136 | "type": "text", 137 | "text": "图里有什么" 138 | }, 139 | 140 | # { 141 | # "type": "image_url", 142 | # "image_url": { 143 | # "url": "https://img1.baidu.com/it/u=1369931113,3388870256&fm=253&app=138&size=w931&n=0&f=JPEG&fmt=auto?sec=1703696400&t=f3028c7a1dca43a080aeb8239f09cc2f" 144 | # } 145 | # }, 146 | { 147 | "type": "image_url", 148 | "image_url": { 149 | "url": f"data:image/jpeg;base64,{base64_image}" 150 | } 151 | } 152 | ] 153 | } 154 | ] 155 | ) 156 | print(response) 157 | 158 | test_completions_vis() 159 | ``` 160 | 161 | ### 角色扮演 162 | > 提供能力的模型名称,请从官网获取 163 | ```python 164 | 165 | def test_completions_charglm(): 166 | client = ZhipuAI() # 请填写您自己的APIKey 167 | response = client.chat.completions.create( 168 | model="charglm-3", # 填写需要调用的模型名称 169 | messages=[ 170 | { 171 | "role": "user", 172 | "content": "请问你在做什么" 173 | } 174 | ], 175 | meta={ 176 | "user_info": "我是陆星辰,是一个男性,是一位知名导演,也是苏梦远的合作导演。我擅长拍摄音乐题材的电影。苏梦远对我的态度是尊敬的,并视我为良师益友。", 177 | "bot_info": 
"苏梦远,本名苏远心,是一位当红的国内女歌手及演员。在参加选秀节目后,凭借独特的嗓音及出众的舞台魅力迅速成名,进入娱乐圈。她外表美丽动人,但真正的魅力在于她的才华和勤奋。苏梦远是音乐学院毕业的优秀生,善于创作,拥有多首热门原创歌曲。除了音乐方面的成就,她还热衷于慈善事业,积极参加公益活动,用实际行动传递正能量。在工作中,她对待工作非常敬业,拍戏时总是全身心投入角色,赢得了业内人士的赞誉和粉丝的喜爱。虽然在娱乐圈,但她始终保持低调、谦逊的态度,深得同行尊重。在表达时,苏梦远喜欢使用“我们”和“一起”,强调团队精神。", 178 | "bot_name": "苏梦远", 179 | "user_name": "陆星辰" 180 | }, 181 | ) 182 | print(response) 183 | test_completions_charglm() 184 | ``` 185 | 186 | 187 | 188 | ### 异常处理 189 | 190 | 模块定义了一些统一的参数返回(例如:响应错误,网络超时错误) 191 | 192 | 业务定义了http错误的响应类 (在接口返回,40x或者50x), 会抛出 `zhipuai.APIStatusError` ,包含 `status_code` 和 `response` 属性. 它们都是继承 `zhipuai.APIStatusError`. 193 | 其它Exception,属于不可预知的错误 194 | ```python 195 | from zhipuai import ZhipuAI 196 | import zhipuai 197 | client = ZhipuAI() # 填写您自己的APIKey 198 | try: 199 | response = client.chat.completions.create( 200 | model="glm-4", # 填写需要调用的模型名称 201 | messages=[ 202 | {"role": "user", "content": "作为一名营销专家,请为我的产品创作一个吸引人的slogan"}, 203 | {"role": "assistant", "content": "当然,为了创作一个吸引人的slogan,请告诉我一些关于您产品的信息"}, 204 | {"role": "user", "content": "智谱AI开放平台"}, 205 | {"role": "assistant", "content": "智启未来,谱绘无限一智谱AI,让创新触手可及!"}, 206 | {"role": "user", "content": "创造一个更精准、吸引人的slogan"} 207 | ] 208 | ) 209 | print(response) 210 | 211 | except zhipuai.APIStatusError as err: 212 | print(err) 213 | except zhipuai.APITimeoutError as err: 214 | print(err) 215 | ``` 216 | 217 | Error codes are as followed: 218 | 219 | | Status Code | Error Type | 220 | |-------------| -------------------------- | 221 | | 400 | `APIRequestFailedError` | 222 | | 401 | `APIAuthenticationError` | 223 | | 429 | `APIReachLimitError` | 224 | | 500 | `APIInternalError` | 225 | | 503 | `APIServerFlowExceedError` | 226 | | N/A | `APIStatusError` | 227 | 228 | 229 | 230 | ### 更新日志 231 | 232 | `2024-4-29` 233 | - 一些 `python3.7` 的代码适配问题, 234 | - 接口失败重试机制,通过 `retry` 参数控制重试次数,默认为3次 235 | - 接口超时策略调整,通过 `Timeout` 控制接口`connect` 和 `read` 超时时间,默认为`timeout=300.0, connect=8.0` 236 | - 对话模块增加超拟人大模型参数支持,`model="charglm-3"`, `meta`参数支持 
237 | 238 | `2024-4-23` 239 | - 一些兼容 `pydantic<3,>=1.9.0` 的代码, 240 | - 报文处理的业务请求参数和响应参数可通过配置扩充 241 | - 兼容了参数 `top_p:1`、`temperature:0`(此时 do_sample 被重写为 false,top_p 和 temperature 参数不生效) 242 | - 图像理解部分,兼容 image_url 参数的 base64 内容中包含 `data:image/jpeg;base64` 前缀的情况 243 | - 删除jwt认证逻辑 244 | -------------------------------------------------------------------------------- /zhipuai_sdk/poetry.toml: -------------------------------------------------------------------------------- 1 | [virtualenvs] 2 | in-project = true 3 | 4 | [installer] 5 | modern-installation = false 6 | -------------------------------------------------------------------------------- /zhipuai_sdk/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "zhipuai" 3 | version = "2.0.1.20240429" 4 | description = "A SDK library for accessing big model apis from ZhipuAI" 5 | authors = ["Zhipu AI"] 6 | readme = "README.md" 7 | 8 | [tool.poetry.dependencies] 9 | python = ">=3.7.1,<3.12,!=3.9.7 " 10 | httpx = ">=0.23.0" 11 | pydantic = ">=1.9.0,<3.0" 12 | cachetools = ">=4.2.2" 13 | pyjwt = "~=2.8.0" 14 | 15 | 16 | [tool.poetry.group.test.dependencies] 17 | # The only dependencies that should be added are 18 | # dependencies used for running tests (e.g., pytest, freezegun, responses). 19 | # Any dependencies that do not meet these criteria will be removed.
20 | pytest = "^7.3.0" 21 | pytest-cov = "^4.0.0" 22 | pytest-dotenv = "^0.5.2" 23 | duckdb-engine = "^0.9.2" 24 | pytest-watcher = "^0.2.6" 25 | freezegun = "^1.2.2" 26 | responses = "^0.22.0" 27 | pytest-asyncio = { version = "^0.23.2", python = "^3.8" } 28 | lark = "^1.1.5" 29 | pandas = { version = "^2.0.0", python = "^3.8" } 30 | pytest-mock = "^3.10.0" 31 | pytest-socket = "^0.6.0" 32 | syrupy = { version = "^4.0.2", python = ">=3.8.1,<3.9.7 || >3.9.7,<3.12" } 33 | requests-mock = "^1.11.0" 34 | 35 | [tool.poetry.group.test_langchain.dependencies] 36 | pytest = "^7.3.0" 37 | pytest-cov = "^4.0.0" 38 | pytest-dotenv = "^0.5.2" 39 | duckdb-engine = "^0.9.2" 40 | pytest-watcher = "^0.2.6" 41 | freezegun = "^1.2.2" 42 | responses = "^0.22.0" 43 | pytest-asyncio = { version = "^0.23.2", python = "^3.8" } 44 | lark = "^1.1.5" 45 | pandas = { version = "^2.0.0", python = "^3.8" } 46 | pytest-mock = "^3.10.0" 47 | pytest-socket = "^0.6.0" 48 | syrupy = { version = "^4.0.2", python = ">=3.8.1,<3.9.7 || >3.9.7,<3.12" } 49 | requests-mock = "^1.11.0" 50 | langchain = { version = "0.1.5", python = ">=3.8.1,<3.9.7 || >3.9.7,<3.12" } 51 | 52 | 53 | [tool.poetry.group.lint] 54 | optional = true 55 | 56 | [tool.poetry.group.lint.dependencies] 57 | ruff = "^0.1.5" 58 | 59 | [tool.poetry.extras] 60 | cli = ["typer"] 61 | # An extra used to be able to add extended testing. 62 | # Please use new-line on formatting to make it easier to add new packages without 63 | # merge-conflicts 64 | extended_testing = [ 65 | "langchain", 66 | ] 67 | 68 | [tool.ruff.lint] 69 | select = [ 70 | "E", # pycodestyle 71 | "F", # pyflakes 72 | "I", # isort 73 | "T201", # print 74 | ] 75 | 76 | 77 | [tool.coverage.run] 78 | omit = [ 79 | "tests/*", 80 | ] 81 | 82 | [build-system] 83 | requires = ["poetry-core>=1.0.0", "poetry-plugin-pypi-mirror==0.4.2"] 84 | build-backend = "poetry.core.masonry.api" 85 | 86 | [tool.pytest.ini_options] 87 | # --strict-markers will raise errors on unknown marks. 
88 | # https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks 89 | # 90 | # https://docs.pytest.org/en/7.1.x/reference/reference.html 91 | # --strict-config any warnings encountered while parsing the `pytest` 92 | # section of the configuration file raise errors. 93 | # 94 | # https://github.com/tophat/syrupy 95 | # --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite. 96 | addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused -svv" 97 | # Registering custom markers. 98 | # https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers 99 | markers = [ 100 | "requires: mark tests as requiring a specific library", 101 | "scheduled: mark tests to run in scheduled testing", 102 | "compile: mark placeholder test used to compile integration tests without running them" 103 | ] 104 | asyncio_mode = "auto" 105 | 106 | # https://python-poetry.org/docs/repositories/ 107 | [[tool.poetry.source]] 108 | name = "tsinghua" 109 | url = "https://pypi.tuna.tsinghua.edu.cn/simple/" 110 | priority = "default" 111 | -------------------------------------------------------------------------------- /zhipuai_sdk/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | with open("README.md", encoding="utf-8") as readme_file: 4 | readme = readme_file.read() 5 | 6 | with open("requirements.txt") as requirements_file: 7 | requirements = requirements_file.read().splitlines() 8 | 9 | setup( 10 | name="zhipuai", 11 | version="v2.0.1", 12 | description="A SDK library for accessing big model apis from ZhipuAI", 13 | long_description=readme, 14 | long_description_content_type="text/markdown", 15 | author="Zhipu AI", 16 | url="https://open.bigmodel.cn/", 17 | packages=find_packages(exclude=['test', 'examples']), 18 | include_package_data=True, 19 | install_requires=requirements, 20 | python_requires=">=3.7", 
21 | classifiers=[ 22 | "Programming Language :: Python :: 3", 23 | "License :: OSI Approved :: MIT License", 24 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 25 | "Topic :: Software Development :: Libraries :: Python Modules", 26 | ], 27 | ) 28 | -------------------------------------------------------------------------------- /zhipuai_sdk/tests/agent/glm3_agent.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is a modified version for ChatGLM3-6B the original glm3_agent.py file from the langchain repo. 3 | """ 4 | 5 | import json 6 | import logging 7 | import typing 8 | from typing import Sequence, Optional, Union 9 | 10 | import langchain_core.prompts 11 | import langchain_core.messages 12 | from langchain_core.runnables import Runnable, RunnablePassthrough 13 | from langchain.agents.agent import AgentOutputParser 14 | from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser 15 | from langchain.prompts.chat import ChatPromptTemplate 16 | from langchain.output_parsers import OutputFixingParser 17 | from langchain.schema import AgentAction, AgentFinish, OutputParserException 18 | from langchain.schema.language_model import BaseLanguageModel 19 | from langchain.tools.base import BaseTool 20 | from pydantic.v1 import Field 21 | 22 | from pydantic.v1.schema import model_schema 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | SYSTEM_PROMPT = "Answer the following questions as best as you can. You have access to the following tools:\n{tools}" 27 | HUMAN_MESSAGE = "Let's start! Human:{input}\n\n{agent_scratchpad}" 28 | 29 | 30 | class StructuredGLM3ChatOutputParser(AgentOutputParser): 31 | """ 32 | Output parser with retries for the structured chat agent. 
33 | """ 34 | base_parser: AgentOutputParser = Field(default_factory=StructuredChatOutputParser) 35 | output_fixing_parser: Optional[OutputFixingParser] = None 36 | 37 | def parse(self, text: str) -> Union[AgentAction, AgentFinish]: 38 | print(text) 39 | 40 | special_tokens = ["Action:", "<|observation|>"] 41 | first_index = min([text.find(token) if token in text else len(text) for token in special_tokens]) 42 | text = text[:first_index] 43 | 44 | if "tool_call" in text: 45 | action_end = text.find("```") 46 | action = text[:action_end].strip() 47 | params_str_start = text.find("(") + 1 48 | params_str_end = text.rfind(")") 49 | params_str = text[params_str_start:params_str_end] 50 | 51 | params_pairs = [param.split("=") for param in params_str.split(",") if "=" in param] 52 | params = {pair[0].strip(): pair[1].strip().strip("'\"") for pair in params_pairs} 53 | 54 | action_json = { 55 | "action": action, 56 | "action_input": params 57 | } 58 | else: 59 | action_json = { 60 | "action": "Final Answer", 61 | "action_input": text 62 | } 63 | action_str = f""" 64 | Action: 65 | ``` 66 | {json.dumps(action_json, ensure_ascii=False)} 67 | ```""" 68 | try: 69 | if self.output_fixing_parser is not None: 70 | parsed_obj: Union[ 71 | AgentAction, AgentFinish 72 | ] = self.output_fixing_parser.parse(action_str) 73 | else: 74 | parsed_obj = self.base_parser.parse(action_str) 75 | return parsed_obj 76 | except Exception as e: 77 | raise OutputParserException(f"Could not parse LLM output: {text}") from e 78 | 79 | @property 80 | def _type(self) -> str: 81 | return "StructuredGLM3ChatOutputParser" 82 | 83 | 84 | def create_structured_glm3_chat_agent( 85 | llm: BaseLanguageModel, tools: Sequence[BaseTool] 86 | ) -> Runnable: 87 | tools_json = [] 88 | for tool in tools: 89 | tool_schema = model_schema(tool.args_schema) if tool.args_schema else {} 90 | description = tool.description.split(" - ")[ 91 | 1].strip() if tool.description and " - " in tool.description else tool.description 
92 | parameters = {k: {sub_k: sub_v for sub_k, sub_v in v.items() if sub_k != 'title'} for k, v in 93 | tool_schema.get("properties", {}).items()} 94 | simplified_config_langchain = { 95 | "name": tool.name, 96 | "description": description, 97 | "parameters": parameters 98 | } 99 | tools_json.append(simplified_config_langchain) 100 | tools = "\n".join([str(tool) for tool in tools_json]) 101 | 102 | prompt = ChatPromptTemplate( 103 | input_variables=["input", "agent_scratchpad"], 104 | input_types={'chat_history': typing.List[typing.Union[ 105 | langchain_core.messages.ai.AIMessage, 106 | langchain_core.messages.human.HumanMessage, 107 | langchain_core.messages.chat.ChatMessage, 108 | langchain_core.messages.system.SystemMessage, 109 | langchain_core.messages.function.FunctionMessage, 110 | langchain_core.messages.tool.ToolMessage]] 111 | }, 112 | messages=[ 113 | langchain_core.prompts.SystemMessagePromptTemplate( 114 | prompt=langchain_core.prompts.PromptTemplate( 115 | input_variables=['tools'], 116 | template=SYSTEM_PROMPT) 117 | ), 118 | langchain_core.prompts.MessagesPlaceholder( 119 | variable_name='chat_history', 120 | optional=True 121 | ), 122 | langchain_core.prompts.HumanMessagePromptTemplate( 123 | prompt=langchain_core.prompts.PromptTemplate( 124 | input_variables=['agent_scratchpad', 'input'], 125 | template=HUMAN_MESSAGE 126 | ) 127 | ) 128 | ] 129 | 130 | ).partial(tools=tools) 131 | 132 | llm_with_stop = llm.bind(stop=["<|observation|>"]) 133 | agent = ( 134 | RunnablePassthrough.assign( 135 | agent_scratchpad=lambda x: x["intermediate_steps"], 136 | ) 137 | | prompt 138 | | llm_with_stop 139 | | StructuredGLM3ChatOutputParser() 140 | ) 141 | return agent 142 | -------------------------------------------------------------------------------- /zhipuai_sdk/tests/agent/test_agent.py: -------------------------------------------------------------------------------- 1 | from langchain.agents import AgentExecutor, create_tool_calling_agent 2 | from 
langchain_community.chat_models import ChatZhipuAI 3 | from langchain_community.tools import ShellTool 4 | from langchain_community.tools.tavily_search import TavilySearchResults 5 | from langchain_core.prompts import ChatPromptTemplate 6 | 7 | if __name__ == "__main__": 8 | tools = [ShellTool()] 9 | 10 | prompt = ChatPromptTemplate.from_messages( 11 | [ 12 | ( 13 | "system", 14 | "You are a helpful assistant. Make sure to use the terminal tool for information.", 15 | ), 16 | ("placeholder", "{chat_history}"), 17 | ("human", "{input}"), 18 | ("placeholder", "{agent_scratchpad}"), 19 | ] 20 | ) 21 | 22 | llm = ChatZhipuAI() 23 | # Construct the Tools agent 24 | agent = create_tool_calling_agent(llm, tools, prompt) 25 | 26 | # Create an agent executor by passing in the agent and tools 27 | agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True) 28 | agent_executor.invoke({"input": "查看本地目录"}) 29 | -------------------------------------------------------------------------------- /zhipuai_sdk/tests/agent/test_create_structured_funcation_call.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List 2 | 3 | from langchain.chains.openai_functions import create_structured_output_runnable 4 | from langchain_community.chat_models import ChatOpenAI, ChatZhipuAI 5 | from langchain_core.prompts import ChatPromptTemplate 6 | from langchain_core.pydantic_v1 import BaseModel, Field 7 | import logging 8 | import langchain 9 | 10 | langchain.verbose = True 11 | logger = logging.getLogger(__name__) 12 | logger.setLevel(logging.DEBUG) 13 | logging.basicConfig(level=logging.DEBUG) 14 | # 控制台打印 15 | handler = logging.StreamHandler() 16 | handler.setLevel(logging.DEBUG) 17 | 18 | logger.addHandler(handler) 19 | 20 | 21 | class Personality(BaseModel): 22 | """性格信息.""" 23 | 24 | personality: str = Field(..., description="性格评价") 25 | 26 | 27 | def test_create_structured_funcation_call() -> None: 28 | """Test 
create_structured_output_runnable. 测试创建结构化输出可运行对象。""" 29 | llm = ChatZhipuAI() 30 | prompt = ChatPromptTemplate.from_messages( 31 | [ 32 | ("system", """根据您提供的信息,您的性格特点可以总结如下: 33 | 34 | 1. 热情和温柔:您在描述天气和气氛时使用了"温柔长裙风"这样的形容词,表现出您对温暖和舒适的情感。 35 | 36 | 2. 情感表达:您在文本中表达了对一个叫"宝宝"的角色的期待和关心,这显示了您的感性和情感表达能力。 37 | 38 | 3. 好奇心和幽默感:您提到了要做大胆的事情,并且以"嘻嘻"结束,这暗示了您对新奇事物的好奇心和幽默感。 39 | 40 | 4. 关心家人和亲情:您提到了弟弟给了三颗糖,表现出您关心家人的情感。 41 | 42 | 5. 乐于分享和帮助:您提到要给宝宝剥虾并询问宝宝是否想知道小鱼在说什么,显示出您愿意分享和帮助他人的特点。 43 | 44 | 6. 可能有一些难以理解的部分:在文本中也出现了一些不太清楚的情节,如呼救情节和提到"小肚小肚",这可能表现出您的思维有时候会有些混乱或不太连贯。 45 | 46 | 总的来说,您的性格特点包括热情、情感表达能力、好奇心、幽默感、亲情关怀以及乐于分享和帮助他人。 47 | 48 | """), 49 | ("ai", """personality 50 | 热情、情感表达能力、好奇心、幽默感、亲情关怀以及乐于分享和帮助他人"""), 51 | ("human", "{input}"), 52 | ] 53 | ) 54 | chain = create_structured_output_runnable(Personality, llm, prompt) 55 | out = chain.invoke({"input": """根据您提供的信息,您的性格特点可以总结如下: 56 | 57 | 1. 热情和温柔:您在描述天气和气氛时使用了"温柔长裙风"这样的形容词,表现出您对温暖和舒适的情感。 58 | 59 | 2. 情感表达:您在文本中表达了对一个叫"宝宝"的角色的期待和关心,这显示了您的感性和情感表达能力。 60 | 61 | 3. 好奇心和幽默感:您提到了要做大胆的事情,并且以"嘻嘻"结束,这暗示了您对新奇事物的好奇心和幽默感。 62 | 63 | 4. 关心家人和亲情:您提到了弟弟给了三颗糖,表现出您关心家人的情感。 64 | 65 | 5. 乐于分享和帮助:您提到要给宝宝剥虾并询问宝宝是否想知道小鱼在说什么,显示出您愿意分享和帮助他人的特点。 66 | 67 | 6. 
可能有一些难以理解的部分:在文本中也出现了一些不太清楚的情节,如呼救情节和提到"小肚小肚",这可能表现出您的思维有时候会有些混乱或不太连贯。 68 | 69 | 总的来说,您的性格特点包括热情、情感表达能力、好奇心、幽默感、亲情关怀以及乐于分享和帮助他人。"""}) 70 | 71 | logger.info(out) 72 | 73 | 74 | if __name__ == "__main__": 75 | test_create_structured_funcation_call() -------------------------------------------------------------------------------- /zhipuai_sdk/tests/agent/test_prompt_parser_agent.py: -------------------------------------------------------------------------------- 1 | from langchain.agents import AgentExecutor 2 | from langchain_community.chat_models import ChatZhipuAI 3 | from langchain_community.tools import ShellTool 4 | 5 | from tests.agent.glm3_agent import create_structured_glm3_chat_agent 6 | 7 | if __name__ == "__main__": 8 | tools = [ShellTool()] 9 | 10 | llm = ChatZhipuAI() 11 | # Construct the Tools agent 12 | agent = create_structured_glm3_chat_agent(llm, tools) 13 | 14 | # Create an agent executor by passing in the agent and tools 15 | agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True) 16 | out = agent_executor.invoke({"input": "查看本地目录"}) 17 | print(out) -------------------------------------------------------------------------------- /zhipuai_sdk/tests/conftest.py: -------------------------------------------------------------------------------- 1 | """Configuration for unit tests.""" 2 | import logging 3 | from importlib import util 4 | from typing import Dict, List, Sequence 5 | 6 | import pytest 7 | from pytest import Config, Function, Parser 8 | from zhipuai.core.logs import ( 9 | get_config_dict, 10 | get_log_file, 11 | get_timestamp_ms, 12 | ) 13 | 14 | 15 | 16 | def pytest_addoption(parser: Parser) -> None: 17 | """Add custom command line options to pytest.""" 18 | parser.addoption( 19 | "--only-extended", 20 | action="store_true", 21 | help="Only run extended tests. Does not allow skipping any extended tests.", 22 | ) 23 | parser.addoption( 24 | "--only-core", 25 | action="store_true", 26 | help="Only run core tests. 
Never runs any extended tests.", 27 | ) 28 | 29 | 30 | def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> None: 31 | """Add implementations for handling custom markers. 32 | 33 | At the moment, this adds support for a custom `requires` marker. 34 | 35 | The `requires` marker is used to denote tests that require one or more packages 36 | to be installed to run. If the package is not installed, the test is skipped. 37 | 38 | The `requires` marker syntax is: 39 | 40 | .. code-block:: python 41 | 42 | @pytest.mark.requires("package1", "package2") 43 | def test_something(): 44 | ... 45 | """ 46 | # Mapping from the name of a package to whether it is installed or not. 47 | # Used to avoid repeated calls to `util.find_spec` 48 | required_pkgs_info: Dict[str, bool] = {} 49 | 50 | only_extended = config.getoption("--only-extended") or False 51 | only_core = config.getoption("--only-core") or False 52 | 53 | if only_extended and only_core: 54 | raise ValueError("Cannot specify both `--only-extended` and `--only-core`.") 55 | 56 | for item in items: 57 | requires_marker = item.get_closest_marker("requires") 58 | if requires_marker is not None: 59 | if only_core: 60 | item.add_marker(pytest.mark.skip(reason="Skipping not a core test.")) 61 | continue 62 | 63 | # Iterate through the list of required packages 64 | required_pkgs = requires_marker.args 65 | for pkg in required_pkgs: 66 | # If we haven't yet checked whether the pkg is installed 67 | # let's check it and store the result. 68 | if pkg not in required_pkgs_info: 69 | try: 70 | installed = util.find_spec(pkg) is not None 71 | except Exception: 72 | installed = False 73 | required_pkgs_info[pkg] = installed 74 | 75 | if not required_pkgs_info[pkg]: 76 | if only_extended: 77 | pytest.fail( 78 | f"Package `{pkg}` is not installed but is required for " 79 | f"extended tests. 
Please install the given package and " 80 | f"try again.", 81 | ) 82 | 83 | else: 84 | # If the package is not installed, we immediately break 85 | # and mark the test as skipped. 86 | item.add_marker( 87 | pytest.mark.skip(reason=f"Requires pkg: `{pkg}`") 88 | ) 89 | break 90 | else: 91 | if only_extended: 92 | item.add_marker( 93 | pytest.mark.skip(reason="Skipping not an extended test.") 94 | ) 95 | 96 | 97 | @pytest.fixture 98 | def logging_conf() -> dict: 99 | return get_config_dict( 100 | "DEBUG", 101 | get_log_file(log_path="logs", sub_dir=f"local_{get_timestamp_ms()}"), 102 | 1024*1024, 103 | 1024*1024*1024, 104 | ) 105 | 106 | @pytest.fixture 107 | def test_file_path(request) -> str: 108 | from pathlib import Path 109 | import os 110 | # 当前执行目录 111 | # 获取当前测试文件的路径 112 | test_file_path = Path(str(request.fspath)).parent 113 | print("test_file_path:",test_file_path) 114 | return test_file_path 115 | -------------------------------------------------------------------------------- /zhipuai_sdk/tests/integration_tests/demo.jsonl: -------------------------------------------------------------------------------- 1 | {"messages": [{"role": "system", "content": "你是一位乐于助人,知识渊博的全能AI助手。"},{"role": "user", "content": "请介绍一下AlphaGo的基本原理。"}, {"role": "assistant", "content": "AlphaGo利用深度神经网络与树搜索算法相结合,通过对弈数据进行训练,逐步提升游戏水平,能够对围棋局面进行判断并选择最佳落子位置。"}]} 2 | {"messages": [{"role": "system", "content": "你是一位乐于助人,知识渊博的全能AI助手。"},{"role": "user", "content": "请列举一些健康饮食的重要原则。"}, {"role": "assistant", "content": "健康饮食的重要原则包括:膳食应营养均衡,合理摄入各种营养素;多吃新鲜蔬果,适量饮水;少吃糖、油、盐,忌吃过量快餐;养成规律生活作息习惯,适量运动。"}]} 3 | {"messages": [{"role": "system", "content": "你是一位乐于助人,知识渊博的全能AI助手。"},{"role": "user", "content": "请给出一些改善睡眠质量的建议。"}, {"role": "assistant", "content": "改善睡眠质量的建议包括:保证睡眠时间,避免睡眠不足;睡前放松身心,远离刺激;养成规律作息时间,不要频繁熬夜;适量运动,但不要运动过度;睡前可以喝一杯热牛奶等温和饮料。"}]} -------------------------------------------------------------------------------- /zhipuai_sdk/tests/integration_tests/img/MetaGLM.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxygoodjob/Hit-leetcode/b073d0855c1efacae59710ab9dcf6fb40325283a/zhipuai_sdk/tests/integration_tests/img/MetaGLM.png -------------------------------------------------------------------------------- /zhipuai_sdk/tests/integration_tests/test_charglm3.py: -------------------------------------------------------------------------------- 1 | from zhipuai import ZhipuAI 2 | import zhipuai 3 | 4 | import logging 5 | import logging.config 6 | 7 | 8 | def test_completions_charglm(logging_conf): 9 | logging.config.dictConfig(logging_conf) # type: ignore 10 | client = ZhipuAI() # 请填写您自己的APIKey 11 | try: 12 | response = client.chat.completions.create( 13 | model="charglm-3", # 填写需要调用的模型名称 14 | messages=[ 15 | { 16 | "role": "user", 17 | "content": "请问你在做什么" 18 | } 19 | ], 20 | meta={ 21 | "user_info": "我是陆星辰,是一个男性,是一位知名导演,也是苏梦远的合作导演。我擅长拍摄音乐题材的电影。苏梦远对我的态度是尊敬的,并视我为良师益友。", 22 | "bot_info": "苏梦远,本名苏远心,是一位当红的国内女歌手及演员。在参加选秀节目后,凭借独特的嗓音及出众的舞台魅力迅速成名,进入娱乐圈。她外表美丽动人,但真正的魅力在于她的才华和勤奋。苏梦远是音乐学院毕业的优秀生,善于创作,拥有多首热门原创歌曲。除了音乐方面的成就,她还热衷于慈善事业,积极参加公益活动,用实际行动传递正能量。在工作中,她对待工作非常敬业,拍戏时总是全身心投入角色,赢得了业内人士的赞誉和粉丝的喜爱。虽然在娱乐圈,但她始终保持低调、谦逊的态度,深得同行尊重。在表达时,苏梦远喜欢使用“我们”和“一起”,强调团队精神。", 23 | "bot_name": "苏梦远", 24 | "user_name": "陆星辰" 25 | }, 26 | ) 27 | print(response) 28 | 29 | 30 | except zhipuai.core._errors.APIRequestFailedError as err: 31 | print(err) 32 | except zhipuai.core._errors.APIInternalError as err: 33 | print(err) 34 | except zhipuai.core._errors.APIStatusError as err: 35 | print(err) 36 | 37 | 38 | 39 | def test_async_completions(): 40 | client = ZhipuAI() # 请填写您自己的APIKey 41 | try: 42 | response = client.chat.asyncCompletions.create( 43 | model="charglm", # 填写需要调用的模型名称 44 | messages=[ 45 | { 46 | "role": "user", 47 | "content": "请问你在做什么" 48 | } 49 | ], 50 | meta={ 51 | "user_info": "我是陆星辰,是一个男性,是一位知名导演,也是苏梦远的合作导演。我擅长拍摄音乐题材的电影。苏梦远对我的态度是尊敬的,并视我为良师益友。", 52 | "bot_info": 
"苏梦远,本名苏远心,是一位当红的国内女歌手及演员。在参加选秀节目后,凭借独特的嗓音及出众的舞台魅力迅速成名,进入娱乐圈。她外表美丽动人,但真正的魅力在于她的才华和勤奋。苏梦远是音乐学院毕业的优秀生,善于创作,拥有多首热门原创歌曲。除了音乐方面的成就,她还热衷于慈善事业,积极参加公益活动,用实际行动传递正能量。在工作中,她对待工作非常敬业,拍戏时总是全身心投入角色,赢得了业内人士的赞誉和粉丝的喜爱。虽然在娱乐圈,但她始终保持低调、谦逊的态度,深得同行尊重。在表达时,苏梦远喜欢使用“我们”和“一起”,强调团队精神。", 53 | "bot_name": "苏梦远", 54 | "user_name": "陆星辰" 55 | }, 56 | ) 57 | print(response) 58 | 59 | 60 | 61 | except zhipuai.core._errors.APIRequestFailedError as err: 62 | print(err) 63 | except zhipuai.core._errors.APIInternalError as err: 64 | print(err) 65 | except zhipuai.core._errors.APIStatusError as err: 66 | print(err) 67 | 68 | def test_retrieve_completion_result(): 69 | client = ZhipuAI() # 请填写您自己的APIKey 70 | try: 71 | response = client.chat.asyncCompletions.retrieve_completion_result(id="1014908592669352541650991") 72 | print(response) 73 | 74 | 75 | except zhipuai.core._errors.APIRequestFailedError as err: 76 | print(err) 77 | except zhipuai.core._errors.APIInternalError as err: 78 | print(err) 79 | 80 | if __name__ == "__main__": 81 | test_retrieve_completion_result() -------------------------------------------------------------------------------- /zhipuai_sdk/tests/integration_tests/test_chat.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | 3 | from zhipuai import ZhipuAI 4 | import zhipuai 5 | import time 6 | 7 | import logging 8 | import logging.config 9 | 10 | 11 | def test_completions_temp0(logging_conf): 12 | logging.config.dictConfig(logging_conf) # type: ignore 13 | client = ZhipuAI() # 填写您自己的APIKey 14 | try: 15 | # 生成request_id 16 | request_id = time.time() 17 | print(f"request_id:{request_id}") 18 | response = client.chat.completions.create( 19 | request_id=request_id, 20 | model="glm-4", 21 | messages=[ 22 | { 23 | "role": "user", 24 | "content": "tell me a joke" 25 | } 26 | ], 27 | top_p=0.7, 28 | temperature=0, 29 | max_tokens=2000, 30 | ) 31 | print(response) 32 | 33 | except zhipuai.core._errors.APIRequestFailedError as err: 34 
| print(err) 35 | except zhipuai.core._errors.APIInternalError as err: 36 | print(err) 37 | except zhipuai.core._errors.APIStatusError as err: 38 | print(err) 39 | 40 | 41 | def test_completions_temp1(logging_conf): 42 | logging.config.dictConfig(logging_conf) # type: ignore 43 | client = ZhipuAI() # 填写您自己的APIKey 44 | try: 45 | # 生成request_id 46 | request_id = time.time() 47 | print(f"request_id:{request_id}") 48 | response = client.chat.completions.create( 49 | request_id=request_id, 50 | model="glm-4", 51 | messages=[ 52 | { 53 | "role": "user", 54 | "content": "tell me a joke" 55 | } 56 | ], 57 | top_p=0.7, 58 | temperature=1, 59 | max_tokens=2000, 60 | ) 61 | print(response) 62 | 63 | 64 | 65 | except zhipuai.core._errors.APIRequestFailedError as err: 66 | print(err) 67 | except zhipuai.core._errors.APIInternalError as err: 68 | print(err) 69 | except zhipuai.core._errors.APIStatusError as err: 70 | print(err) 71 | 72 | 73 | def test_completions_top0(logging_conf): 74 | logging.config.dictConfig(logging_conf) # type: ignore 75 | client = ZhipuAI() # 填写您自己的APIKey 76 | try: 77 | # 生成request_id 78 | request_id = time.time() 79 | print(f"request_id:{request_id}") 80 | response = client.chat.completions.create( 81 | request_id=request_id, 82 | model="glm-4", 83 | messages=[ 84 | { 85 | "role": "user", 86 | "content": "tell me a joke" 87 | } 88 | ], 89 | top_p=0, 90 | temperature=0.9, 91 | max_tokens=2000, 92 | ) 93 | print(response) 94 | 95 | 96 | 97 | except zhipuai.core._errors.APIRequestFailedError as err: 98 | print(err) 99 | except zhipuai.core._errors.APIInternalError as err: 100 | print(err) 101 | except zhipuai.core._errors.APIStatusError as err: 102 | print(err) 103 | 104 | 105 | def test_completions_top1(logging_conf): 106 | logging.config.dictConfig(logging_conf) # type: ignore 107 | client = ZhipuAI() # 填写您自己的APIKey 108 | try: 109 | # 生成request_id 110 | request_id = time.time() 111 | print(f"request_id:{request_id}") 112 | response = 
client.chat.completions.create( 113 | request_id=request_id, 114 | model="glm-4", 115 | messages=[ 116 | { 117 | "role": "user", 118 | "content": "tell me a joke" 119 | } 120 | ], 121 | top_p=1, 122 | temperature=0.9, 123 | max_tokens=2000, 124 | ) 125 | print(response) 126 | 127 | 128 | 129 | except zhipuai.core._errors.APIRequestFailedError as err: 130 | print(err) 131 | except zhipuai.core._errors.APIInternalError as err: 132 | print(err) 133 | except zhipuai.core._errors.APIStatusError as err: 134 | print(err) 135 | 136 | 137 | def test_completions(logging_conf): 138 | logging.config.dictConfig(logging_conf) # type: ignore 139 | client = ZhipuAI() # 填写您自己的APIKey 140 | try: 141 | # 生成request_id 142 | request_id = time.time() 143 | print(f"request_id:{request_id}") 144 | response = client.chat.completions.create( 145 | request_id=request_id, 146 | model="glm-4", # 填写需要调用的模型名称 147 | messages=[ 148 | {"role": "user", "content": "作为一名营销专家,请为我的产品创作一个吸引人的slogan"}, 149 | {"role": "assistant", "content": "当然,为了创作一个吸引人的slogan,请告诉我一些关于您产品的信息"}, 150 | {"role": "user", "content": "智谱AI开放平台"}, 151 | {"role": "assistant", "content": "智启未来,谱绘无限一智谱AI,让创新触手可及!"}, 152 | {"role": "user", "content": "创造一个更精准、吸引人的slogan"} 153 | ], 154 | tools=[ 155 | { 156 | "type": "web_search", 157 | "web_search": { 158 | "search_query": "帮我看看清华的升学率", 159 | "search_result": True, 160 | } 161 | } 162 | ], 163 | extra_body={"temperature": 0.5, "max_tokens": 50}, 164 | ) 165 | print(response) 166 | 167 | 168 | 169 | except zhipuai.core._errors.APIRequestFailedError as err: 170 | print(err) 171 | except zhipuai.core._errors.APIInternalError as err: 172 | print(err) 173 | except zhipuai.core._errors.APIStatusError as err: 174 | print(err) 175 | 176 | 177 | def test_completions_disenable_web_search(logging_conf): 178 | logging.config.dictConfig(logging_conf) # type: ignore 179 | client = ZhipuAI() # 填写您自己的APIKey 180 | try: 181 | # 生成request_id 182 | request_id = time.time() 183 | 
print(f"request_id:{request_id}") 184 | response = client.chat.completions.create( 185 | request_id=request_id, 186 | model="glm-4", # 填写需要调用的模型名称 187 | messages=[ 188 | {"role": "user", "content": "作为一名营销专家,请为我的产品创作一个吸引人的slogan"}, 189 | {"role": "assistant", "content": "当然,为了创作一个吸引人的slogan,请告诉我一些关于您产品的信息"}, 190 | {"role": "user", "content": "智谱AI开放平台"}, 191 | {"role": "assistant", "content": "智启未来,谱绘无限一智谱AI,让创新触手可及!"}, 192 | {"role": "user", "content": "创造一个更精准、吸引人的slogan"} 193 | ], 194 | tools=[ 195 | { 196 | "type": "web_search", 197 | "web_search": { 198 | "search_query": "帮我看看清华的升学率", 199 | "search_result": True, 200 | "enable": False, 201 | } 202 | } 203 | ], 204 | extra_body={"temperature": 0.5, "max_tokens": 50}, 205 | ) 206 | print(response) 207 | 208 | 209 | 210 | except zhipuai.core._errors.APIRequestFailedError as err: 211 | print(err) 212 | except zhipuai.core._errors.APIInternalError as err: 213 | print(err) 214 | except zhipuai.core._errors.APIStatusError as err: 215 | print(err) 216 | 217 | 218 | def test_completions_enable_web_search(logging_conf): 219 | logging.config.dictConfig(logging_conf) # type: ignore 220 | client = ZhipuAI() # 填写您自己的APIKey 221 | try: 222 | # 生成request_id 223 | request_id = time.time() 224 | print(f"request_id:{request_id}") 225 | response = client.chat.completions.create( 226 | request_id=request_id, 227 | model="glm-4", # 填写需要调用的模型名称 228 | messages=[ 229 | {"role": "user", "content": "作为一名营销专家,请为我的产品创作一个吸引人的slogan"}, 230 | {"role": "assistant", "content": "当然,为了创作一个吸引人的slogan,请告诉我一些关于您产品的信息"}, 231 | {"role": "user", "content": "智谱AI开放平台"}, 232 | {"role": "assistant", "content": "智启未来,谱绘无限一智谱AI,让创新触手可及!"}, 233 | {"role": "user", "content": "创造一个更精准、吸引人的slogan"} 234 | ], 235 | tools=[ 236 | { 237 | "type": "web_search", 238 | "web_search": { 239 | "search_query": "帮我看看清华的升学率", 240 | "search_result": True, 241 | "enable": True, 242 | } 243 | } 244 | ], 245 | extra_body={"temperature": 0.5, "max_tokens": 50}, 246 | ) 247 | 
print(response) 248 | 249 | 250 | 251 | except zhipuai.core._errors.APIRequestFailedError as err: 252 | print(err) 253 | except zhipuai.core._errors.APIInternalError as err: 254 | print(err) 255 | except zhipuai.core._errors.APIStatusError as err: 256 | print(err) 257 | 258 | 259 | def test_completions_stream(logging_conf): 260 | logging.config.dictConfig(logging_conf) # type: ignore 261 | client = ZhipuAI() # 填写您自己的APIKey 262 | try: 263 | # 生成request_id 264 | request_id = time.time() 265 | print(f"request_id:{request_id}") 266 | response = client.chat.completions.create( 267 | request_id=request_id, 268 | model="glm-4", # 填写需要调用的模型名称 269 | stream=True, 270 | messages=[ 271 | {"role": "user", "content": "作为一名营销专家,请为我的产品创作一个吸引人的slogan"}, 272 | {"role": "assistant", "content": "当然,为了创作一个吸引人的slogan,请告诉我一些关于您产品的信息"}, 273 | {"role": "user", "content": "智谱AI开放平台"}, 274 | {"role": "assistant", "content": "智启未来,谱绘无限一智谱AI,让创新触手可及!"}, 275 | {"role": "user", "content": "创造一个更精准、吸引人的slogan"} 276 | ], 277 | extra_body={"temperature": 0.5, "max_tokens": 50}, 278 | ) 279 | for item in response: 280 | print(item) 281 | 282 | 283 | except zhipuai.core._errors.APIRequestFailedError as err: 284 | print(err) 285 | except zhipuai.core._errors.APIInternalError as err: 286 | print(err) 287 | except zhipuai.core._errors.APIStatusError as err: 288 | print(err) 289 | 290 | 291 | # Function to encode the image 292 | def encode_image(image_path): 293 | import base64 294 | with open(image_path, "rb") as image_file: 295 | return base64.b64encode(image_file.read()).decode('utf-8') 296 | 297 | 298 | def test_completions_vis(logging_conf): 299 | logging.config.dictConfig(logging_conf) # type: ignore 300 | client = ZhipuAI() # 填写您自己的APIKey 301 | try: 302 | # 生成request_id 303 | request_id = time.time() 304 | print(f"request_id:{request_id}") 305 | response = client.chat.completions.create( 306 | request_id=request_id, 307 | model="glm-4v", # 填写需要调用的模型名称 308 | extra_body={"temperature": 0.5, 
"max_tokens": 50}, 309 | messages=[ 310 | { 311 | "role": "user", 312 | "content": [ 313 | { 314 | "type": "text", 315 | "text": "图里有什么" 316 | }, 317 | 318 | { 319 | "type": "image_url", 320 | "image_url": { 321 | "url": "https://img1.baidu.com/it/u=1369931113,3388870256&fm=253&app=138&size=w931&n=0&f=JPEG&fmt=auto?sec=1703696400&t=f3028c7a1dca43a080aeb8239f09cc2f" 322 | } 323 | } 324 | ] 325 | } 326 | ] 327 | ) 328 | print(response) 329 | 330 | 331 | except zhipuai.core._errors.APIRequestFailedError as err: 332 | print(err) 333 | except zhipuai.core._errors.APIInternalError as err: 334 | print(err) 335 | except zhipuai.core._errors.APIStatusError as err: 336 | print(err) 337 | 338 | 339 | def test_completions_vis_base64(test_file_path, logging_conf): 340 | logging.config.dictConfig(logging_conf) # type: ignore 341 | client = ZhipuAI() # 填写您自己的APIKey 342 | try: 343 | base64_image = encode_image(os.path.join(test_file_path, "img/MetaGLM.png")) 344 | # 生成request_id 345 | request_id = time.time() 346 | print(f"request_id:{request_id}") 347 | response = client.chat.completions.create( 348 | request_id=request_id, 349 | model="glm-4v", # 填写需要调用的模型名称 350 | extra_body={"temperature": 0.5, "max_tokens": 50}, 351 | messages=[ 352 | { 353 | "role": "user", 354 | "content": [ 355 | { 356 | "type": "text", 357 | "text": "图里有什么" 358 | }, 359 | 360 | # { 361 | # "type": "image_url", 362 | # "image_url": { 363 | # "url": "https://img1.baidu.com/it/u=1369931113,3388870256&fm=253&app=138&size=w931&n=0&f=JPEG&fmt=auto?sec=1703696400&t=f3028c7a1dca43a080aeb8239f09cc2f" 364 | # } 365 | # }, 366 | { 367 | "type": "image_url", 368 | "image_url": { 369 | "url": f"data:image/jpeg;base64,{base64_image}" 370 | } 371 | } 372 | ] 373 | } 374 | ] 375 | ) 376 | print(response) 377 | 378 | 379 | except zhipuai.core._errors.APIRequestFailedError as err: 380 | print(err) 381 | except zhipuai.core._errors.APIInternalError as err: 382 | print(err) 383 | except zhipuai.core._errors.APIStatusError as 
err: 384 | print(err) 385 | 386 | 387 | def test_async_completions(logging_conf): 388 | logging.config.dictConfig(logging_conf) # type: ignore 389 | client = ZhipuAI() # 请填写您自己的APIKey 390 | try: 391 | # 生成request_id 392 | request_id = time.time() 393 | print(f"request_id:{request_id}") 394 | response = client.chat.asyncCompletions.create( 395 | request_id=request_id, 396 | model="glm-4", # 填写需要调用的模型名称 397 | messages=[ 398 | {"role": "user", "content": "作为一名营销专家,请为我的产品创作一个吸引人的slogan"}, 399 | {"role": "assistant", "content": "当然,为了创作一个吸引人的slogan,请告诉我一些关于您产品的信息"}, 400 | {"role": "user", "content": "智谱AI开放平台"}, 401 | {"role": "assistant", "content": "智启未来,谱绘无限一智谱AI,让创新触手可及!"}, 402 | {"role": "user", "content": "创造一个更精准、吸引人的slogan"} 403 | ], 404 | tools=[ 405 | { 406 | "type": "web_search", 407 | "web_search": { 408 | "search_query": "帮我看看清华的升学率", 409 | "search_result": True, 410 | } 411 | } 412 | ], 413 | extra_body={"temperature": 0.5, "max_tokens": 50}, 414 | ) 415 | print(response) 416 | 417 | 418 | 419 | except zhipuai.core._errors.APIRequestFailedError as err: 420 | print(err) 421 | except zhipuai.core._errors.APIInternalError as err: 422 | print(err) 423 | except zhipuai.core._errors.APIStatusError as err: 424 | print(err) 425 | 426 | 427 | def test_retrieve_completion_result(logging_conf): 428 | logging.config.dictConfig(logging_conf) # type: ignore 429 | client = ZhipuAI() # 请填写您自己的APIKey 430 | try: 431 | response = client.chat.asyncCompletions.retrieve_completion_result(id="1014908592669352541651237") 432 | print(response) 433 | 434 | 435 | except zhipuai.core._errors.APIRequestFailedError as err: 436 | print(err) 437 | except zhipuai.core._errors.APIInternalError as err: 438 | print(err) 439 | except zhipuai.core._errors.APIStatusError as err: 440 | print(err) 441 | 442 | 443 | if __name__ == '__main__': 444 | test_completions_temp0() 445 | -------------------------------------------------------------------------------- 
/zhipuai_sdk/tests/integration_tests/test_embedding.py: -------------------------------------------------------------------------------- 1 | from zhipuai import ZhipuAI 2 | import zhipuai 3 | 4 | import logging 5 | import logging.config 6 | 7 | 8 | def test_embeddings(logging_conf): 9 | logging.config.dictConfig(logging_conf) # type: ignore 10 | 11 | client = ZhipuAI() 12 | try: 13 | response = client.embeddings.create( 14 | model="embedding-2", #填写需要调用的模型名称 15 | input="你好", 16 | extra_body={"model_version": "v1"} 17 | ) 18 | print(response) 19 | 20 | 21 | except zhipuai.core._errors.APIRequestFailedError as err: 22 | print(err) 23 | except zhipuai.core._errors.APIInternalError as err: 24 | print(err) 25 | except zhipuai.core._errors.APIStatusError as err: 26 | print(err) 27 | -------------------------------------------------------------------------------- /zhipuai_sdk/tests/integration_tests/test_file.py: -------------------------------------------------------------------------------- 1 | from zhipuai import ZhipuAI 2 | import zhipuai 3 | import os 4 | 5 | import logging 6 | import logging.config 7 | 8 | 9 | def test_files(test_file_path, logging_conf): 10 | logging.config.dictConfig(logging_conf) # type: ignore 11 | client = ZhipuAI() 12 | try: 13 | result = client.files.create( 14 | file=open(os.path.join(test_file_path,"demo.jsonl"), "rb"), 15 | purpose="fine-tune" 16 | ) 17 | print(result) 18 | # "file-20240418025911536-6dqgr" 19 | 20 | 21 | except zhipuai.core._errors.APIRequestFailedError as err: 22 | print(err) 23 | except zhipuai.core._errors.APIInternalError as err: 24 | print(err) 25 | except zhipuai.core._errors.APIStatusError as err: 26 | print(err) 27 | 28 | 29 | def test_files_validation(test_file_path, logging_conf): 30 | logging.config.dictConfig(logging_conf) # type: ignore 31 | client = ZhipuAI() 32 | try: 33 | result = client.files.create( 34 | file=open(os.path.join(test_file_path,"demo.jsonl"), "rb"), 35 | purpose="fine-tune" 36 | ) 37 | 
print(result) 38 | # "file-20240418025931214-c87tj" 39 | 40 | 41 | 42 | except zhipuai.core._errors.APIRequestFailedError as err: 43 | print(err) 44 | except zhipuai.core._errors.APIInternalError as err: 45 | print(err) 46 | except zhipuai.core._errors.APIStatusError as err: 47 | print(err) 48 | 49 | def test_files_list(logging_conf): 50 | logging.config.dictConfig(logging_conf) # type: ignore 51 | client = ZhipuAI() 52 | try: 53 | list = client.files.list() 54 | print(list) 55 | 56 | 57 | 58 | except zhipuai.core._errors.APIRequestFailedError as err: 59 | print(err) 60 | except zhipuai.core._errors.APIInternalError as err: 61 | print(err) 62 | except zhipuai.core._errors.APIStatusError as err: 63 | print(err) 64 | -------------------------------------------------------------------------------- /zhipuai_sdk/tests/integration_tests/test_finetuning.py: -------------------------------------------------------------------------------- 1 | from zhipuai import ZhipuAI 2 | import zhipuai 3 | 4 | import logging 5 | import logging.config 6 | 7 | 8 | def test_finetuning_create(logging_conf): 9 | logging.config.dictConfig(logging_conf) # type: ignore 10 | client = ZhipuAI() # 请填写您自己的APIKey 11 | try: 12 | job = client.fine_tuning.jobs.create( 13 | model="chatglm3-6b", 14 | training_file="file-20240428021923715-xjng4", # 请填写已成功上传的文件id 15 | validation_file="file-20240428021923715-xjng4", # 请填写已成功上传的文件id 16 | suffix="demo_test", 17 | 18 | ) 19 | job_id = job.id 20 | print(job_id) 21 | fine_tuning_job = client.fine_tuning.jobs.retrieve(fine_tuning_job_id=job_id) 22 | print(fine_tuning_job) 23 | # ftjob-20240418110039323-j8lh2 24 | 25 | 26 | 27 | except zhipuai.core._errors.APIRequestFailedError as err: 28 | print(err) 29 | except zhipuai.core._errors.APIInternalError as err: 30 | print(err) 31 | except zhipuai.core._errors.APIStatusError as err: 32 | print(err) 33 | 34 | def test_finetuning_retrieve(logging_conf): 35 | logging.config.dictConfig(logging_conf) # type: ignore 36 | 
client = ZhipuAI() # 请填写您自己的APIKey 37 | try: 38 | fine_tuning_job = client.fine_tuning.jobs.retrieve(fine_tuning_job_id="ftjob-20240429112551154-48vq7") 39 | print(fine_tuning_job) 40 | 41 | 42 | 43 | except zhipuai.core._errors.APIRequestFailedError as err: 44 | print(err) 45 | except zhipuai.core._errors.APIInternalError as err: 46 | print(err) 47 | except zhipuai.core._errors.APIStatusError as err: 48 | print(err) 49 | 50 | 51 | def test_finetuning_job_list(logging_conf): 52 | logging.config.dictConfig(logging_conf) # type: ignore 53 | client = ZhipuAI() # 请填写您自己的APIKey 54 | try: 55 | job_list = client.fine_tuning.jobs.list() 56 | 57 | print(job_list) 58 | 59 | 60 | 61 | except zhipuai.core._errors.APIRequestFailedError as err: 62 | print(err) 63 | except zhipuai.core._errors.APIInternalError as err: 64 | print(err) 65 | except zhipuai.core._errors.APIStatusError as err: 66 | print(err) 67 | 68 | def test_finetuning_job_cancel(logging_conf): 69 | logging.config.dictConfig(logging_conf) # type: ignore 70 | client = ZhipuAI() # 请填写您自己的APIKey 71 | try: 72 | cancel = client.fine_tuning.jobs.cancel(fine_tuning_job_id="ftjob-20240429112551154-48vq7") 73 | 74 | print(cancel) 75 | 76 | 77 | 78 | except zhipuai.core._errors.APIRequestFailedError as err: 79 | print(err) 80 | except zhipuai.core._errors.APIInternalError as err: 81 | print(err) 82 | except zhipuai.core._errors.APIStatusError as err: 83 | print(err) 84 | 85 | def test_finetuning_job_delete(logging_conf): 86 | logging.config.dictConfig(logging_conf) # type: ignore 87 | client = ZhipuAI() # 请填写您自己的APIKey 88 | try: 89 | delete = client.fine_tuning.jobs.delete(fine_tuning_job_id="ftjob-20240126113041678-cs6s9") 90 | 91 | print(delete) 92 | 93 | 94 | 95 | except zhipuai.core._errors.APIRequestFailedError as err: 96 | print(err) 97 | except zhipuai.core._errors.APIInternalError as err: 98 | print(err) 99 | except zhipuai.core._errors.APIStatusError as err: 100 | print(err) 101 | 102 | def 
test_model_check(logging_conf): 103 | logging.config.dictConfig(logging_conf) # type: ignore 104 | client = ZhipuAI() # 填写您自己的APIKey 105 | try: 106 | response = client.chat.completions.create( 107 | model="chatglm3-6b-8572905046912426020-demo_test", # 填写需要调用的模型名称 108 | messages=[ 109 | {"role": "user", "content": "你是一位乐于助人,知识渊博的全能AI助手。"}, 110 | {"role": "user", "content": "创造一个更精准、吸引人的slogan"} 111 | ], 112 | extra_body={"temperature": 0.5, "max_tokens": 50}, 113 | ) 114 | print(response.choices[0].message) 115 | 116 | 117 | 118 | except zhipuai.core._errors.APIRequestFailedError as err: 119 | print(err) 120 | except zhipuai.core._errors.APIInternalError as err: 121 | print(err) 122 | except zhipuai.core._errors.APIStatusError as err: 123 | print(err) 124 | 125 | if __name__ == "__main__": 126 | test_finetuning_create() -------------------------------------------------------------------------------- /zhipuai_sdk/tests/integration_tests/test_images.py: -------------------------------------------------------------------------------- 1 | from zhipuai import ZhipuAI 2 | import zhipuai 3 | 4 | import logging 5 | import logging.config 6 | 7 | 8 | def test_images(logging_conf): 9 | logging.config.dictConfig(logging_conf) # type: ignore 10 | client = ZhipuAI() # 填写您自己的APIKey 11 | try: 12 | response = client.images.generations( 13 | model="cogview-3", #填写需要调用的模型名称 14 | prompt="一只可爱的小猫咪", 15 | extra_body={"user_id": "1222212"}, 16 | ) 17 | print(response) 18 | 19 | 20 | 21 | except zhipuai.core._errors.APIRequestFailedError as err: 22 | print(err) 23 | except zhipuai.core._errors.APIInternalError as err: 24 | print(err) 25 | except zhipuai.core._errors.APIStatusError as err: 26 | print(err) 27 | -------------------------------------------------------------------------------- /zhipuai_sdk/tests/unit_tests/test_sdk_import.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | def test_sdk_import_unit(): 4 | import zhipuai 5 | 6 | 
print(zhipuai.__version__) 7 | 8 | 9 | def test_os_import_unit(): 10 | import os 11 | print(os) 12 | 13 | def test_sdk_import(): 14 | from zhipuai import ZhipuAI 15 | 16 | client = ZhipuAI() # 请填写您自己的APIKey 17 | 18 | if client is not None: 19 | print("SDK导入成功") 20 | else: 21 | print("SDK导入失败") 22 | 23 | -------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from ._client import ZhipuAI 3 | 4 | from .core import ( 5 | ZhipuAIError, 6 | APIStatusError, 7 | APIRequestFailedError, 8 | APIAuthenticationError, 9 | APIReachLimitError, 10 | APIInternalError, 11 | APIServerFlowExceedError, 12 | APIResponseError, 13 | APIResponseValidationError, 14 | APIConnectionError, 15 | APITimeoutError, 16 | ) 17 | 18 | from .__version__ import __version__ 19 | -------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/__version__.py: -------------------------------------------------------------------------------- 1 | 2 | __version__ = 'v2.0.1' -------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/_client.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Union, Mapping 4 | 5 | from typing_extensions import override 6 | 7 | from .core import _jwt_token 8 | from .core import ZhipuAIError 9 | from .core import HttpClient, ZHIPUAI_DEFAULT_MAX_RETRIES 10 | from .core import NotGiven, NOT_GIVEN 11 | from . 
import api_resource 12 | import os 13 | import httpx 14 | from httpx import Timeout 15 | 16 | 17 | class ZhipuAI(HttpClient): 18 | chat: api_resource.chat.Chat 19 | api_key: str 20 | _disable_token_cache: bool = True 21 | 22 | def __init__( 23 | self, 24 | *, 25 | api_key: str | None = None, 26 | base_url: str | httpx.URL | None = None, 27 | timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN, 28 | max_retries: int = ZHIPUAI_DEFAULT_MAX_RETRIES, 29 | http_client: httpx.Client | None = None, 30 | custom_headers: Mapping[str, str] | None = None, 31 | disable_token_cache: bool = True 32 | ) -> None: 33 | if api_key is None: 34 | api_key = os.environ.get("ZHIPUAI_API_KEY") 35 | if api_key is None: 36 | raise ZhipuAIError("未提供api_key,请通过参数或环境变量提供") 37 | self.api_key = api_key 38 | self._disable_token_cache = disable_token_cache 39 | 40 | if base_url is None: 41 | base_url = os.environ.get("ZHIPUAI_BASE_URL") 42 | if base_url is None: 43 | base_url = f"https://open.bigmodel.cn/api/paas/v4" 44 | from .__version__ import __version__ 45 | super().__init__( 46 | version=__version__, 47 | base_url=base_url, 48 | max_retries=max_retries, 49 | timeout=timeout, 50 | custom_httpx_client=http_client, 51 | custom_headers=custom_headers, 52 | ) 53 | self.chat = api_resource.chat.Chat(self) 54 | self.images = api_resource.images.Images(self) 55 | self.embeddings = api_resource.embeddings.Embeddings(self) 56 | self.files = api_resource.files.Files(self) 57 | self.fine_tuning = api_resource.fine_tuning.FineTuning(self) 58 | 59 | @property 60 | @override 61 | def auth_headers(self) -> dict[str, str]: 62 | api_key = self.api_key 63 | if self._disable_token_cache: 64 | return {"Authorization": f"{api_key}"} 65 | else: 66 | return {"Authorization": f"{_jwt_token.generate_token(api_key)}"} 67 | 68 | def __del__(self) -> None: 69 | if (not hasattr(self, "_has_custom_http_client") 70 | or not hasattr(self, "close") 71 | or not hasattr(self, "_client")): 72 | # if the '__init__' method 
raised an error, self would not have client attr 73 | return 74 | 75 | if self._has_custom_http_client: 76 | return 77 | 78 | self.close() 79 | -------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/api_resource/__init__.py: -------------------------------------------------------------------------------- 1 | from .chat import ( 2 | AsyncCompletions, 3 | Chat, 4 | Completions, 5 | ) 6 | from .images import ( 7 | Images 8 | ) 9 | from .embeddings import ( 10 | Embeddings 11 | ) 12 | from .files import ( 13 | Files 14 | ) 15 | from .fine_tuning import ( 16 | FineTuning 17 | ) 18 | 19 | __all__ = [ 20 | 'AsyncCompletions', 21 | 'Chat', 22 | 'Completions', 23 | 'Images', 24 | 'Embeddings', 25 | 'Files', 26 | 'FineTuning' 27 | 28 | ] -------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/api_resource/chat/__init__.py: -------------------------------------------------------------------------------- 1 | from .async_completions import ( 2 | AsyncCompletions 3 | ) 4 | 5 | from .chat import ( 6 | Chat 7 | ) 8 | 9 | from .completions import ( 10 | Completions 11 | ) 12 | 13 | __all__ = [ 14 | 'AsyncCompletions' 15 | 'Chat' 16 | 'Completions' 17 | ] -------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/api_resource/chat/async_completions.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Union, List, Optional, TYPE_CHECKING 4 | 5 | import httpx 6 | from typing_extensions import Literal 7 | 8 | from ...core import BaseAPI 9 | from ...core import NotGiven, NOT_GIVEN, Headers, Body 10 | from ...core import make_request_options 11 | from ...types.chat.async_chat_completion import AsyncTaskStatus, AsyncCompletion 12 | 13 | if TYPE_CHECKING: 14 | from ..._client import ZhipuAI 15 | 16 | 17 | class AsyncCompletions(BaseAPI): 18 | def 
__init__(self, client: "ZhipuAI") -> None: 19 | super().__init__(client) 20 | 21 | 22 | def create( 23 | self, 24 | *, 25 | model: str, 26 | request_id: Optional[str] | NotGiven = NOT_GIVEN, 27 | do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, 28 | temperature: Optional[float] | NotGiven = NOT_GIVEN, 29 | top_p: Optional[float] | NotGiven = NOT_GIVEN, 30 | max_tokens: int | NotGiven = NOT_GIVEN, 31 | seed: int | NotGiven = NOT_GIVEN, 32 | messages: Union[str, List[str], List[int], List[List[int]], None], 33 | stop: Optional[Union[str, List[str], None]] | NotGiven = NOT_GIVEN, 34 | sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, 35 | tools: Optional[object] | NotGiven = NOT_GIVEN, 36 | tool_choice: str | NotGiven = NOT_GIVEN, 37 | meta: Optional[Dict[str,str]] | NotGiven = NOT_GIVEN, 38 | extra_headers: Headers | None = None, 39 | extra_body: Body | None = None, 40 | disable_strict_validation: Optional[bool] | None = None, 41 | timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, 42 | ) -> AsyncTaskStatus: 43 | _cast_type = AsyncTaskStatus 44 | 45 | if disable_strict_validation: 46 | _cast_type = object 47 | return self._post( 48 | "/async/chat/completions", 49 | body={ 50 | "model": model, 51 | "request_id": request_id, 52 | "temperature": temperature, 53 | "top_p": top_p, 54 | "do_sample": do_sample, 55 | "max_tokens": max_tokens, 56 | "seed": seed, 57 | "messages": messages, 58 | "stop": stop, 59 | "sensitive_word_check": sensitive_word_check, 60 | "tools": tools, 61 | "tool_choice": tool_choice, 62 | "meta": meta, 63 | }, 64 | options=make_request_options( 65 | extra_headers=extra_headers, extra_body=extra_body, timeout=timeout 66 | ), 67 | cast_type=_cast_type, 68 | stream=False, 69 | ) 70 | 71 | def retrieve_completion_result( 72 | self, 73 | id: str, 74 | extra_headers: Headers | None = None, 75 | extra_body: Body | None = None, 76 | disable_strict_validation: Optional[bool] | None = None, 77 | timeout: 
float | httpx.Timeout | None | NotGiven = NOT_GIVEN, 78 | ) -> Union[AsyncCompletion, AsyncTaskStatus]: 79 | _cast_type = Union[AsyncCompletion,AsyncTaskStatus] 80 | if disable_strict_validation: 81 | _cast_type = object 82 | return self._get( 83 | path=f"/async-result/{id}", 84 | cast_type=_cast_type, 85 | options=make_request_options( 86 | extra_headers=extra_headers, extra_body=extra_body, timeout=timeout 87 | ), 88 | ) 89 | 90 | 91 | -------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/api_resource/chat/chat.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | from .completions import Completions 3 | from .async_completions import AsyncCompletions 4 | from ...core import BaseAPI, cached_property 5 | 6 | if TYPE_CHECKING: 7 | from ..._client import ZhipuAI 8 | 9 | 10 | class Chat(BaseAPI): 11 | 12 | @cached_property 13 | def completions(self) -> Completions: 14 | return Completions(self._client) 15 | 16 | @cached_property 17 | def asyncCompletions(self) -> AsyncCompletions: 18 | return AsyncCompletions(self._client) -------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/api_resource/chat/completions.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Union, List, Optional, TYPE_CHECKING 4 | 5 | import httpx 6 | import logging 7 | from typing_extensions import Literal 8 | 9 | from ...core import BaseAPI 10 | from ...core import NotGiven, NOT_GIVEN, Headers, Query, Body 11 | from ...core import make_request_options 12 | from ...core import StreamResponse 13 | from ...types.chat.chat_completion import Completion 14 | from ...types.chat.chat_completion_chunk import ChatCompletionChunk 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | if TYPE_CHECKING: 19 | from ..._client import ZhipuAI 20 | 21 | 
22 | class Completions(BaseAPI): 23 | def __init__(self, client: "ZhipuAI") -> None: 24 | super().__init__(client) 25 | 26 | def create( 27 | self, 28 | *, 29 | model: str, 30 | request_id: Optional[str] | NotGiven = NOT_GIVEN, 31 | do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, 32 | stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, 33 | temperature: Optional[float] | NotGiven = NOT_GIVEN, 34 | top_p: Optional[float] | NotGiven = NOT_GIVEN, 35 | max_tokens: int | NotGiven = NOT_GIVEN, 36 | seed: int | NotGiven = NOT_GIVEN, 37 | messages: Union[str, List[str], List[int], object, None], 38 | stop: Optional[Union[str, List[str], None]] | NotGiven = NOT_GIVEN, 39 | sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, 40 | tools: Optional[object] | NotGiven = NOT_GIVEN, 41 | tool_choice: str | NotGiven = NOT_GIVEN, 42 | meta: Optional[Dict[str,str]] | NotGiven = NOT_GIVEN, 43 | extra_headers: Headers | None = None, 44 | extra_body: Body | None = None, 45 | timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, 46 | ) -> Completion | StreamResponse[ChatCompletionChunk]: 47 | logger.info(f"temperature:{temperature}, top_p:{top_p}") 48 | if temperature is not None and temperature != NOT_GIVEN: 49 | 50 | if temperature <= 0: 51 | do_sample = False 52 | temperature = 0.01 53 | logger.warning("temperature:取值范围是:(0.0, 1.0) 开区间,do_sample重写为:false(参数top_p temperture不生效)") 54 | if temperature >= 1: 55 | do_sample = False 56 | temperature = 0.99 57 | logger.warning("temperature:取值范围是:(0.0, 1.0) 开区间,do_sample重写为:false(参数top_p temperture不生效)") 58 | if top_p is not None and top_p != NOT_GIVEN: 59 | 60 | if top_p >= 1: 61 | top_p = 0.99 62 | logger.warning("top_p:取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1") 63 | if top_p <= 0: 64 | top_p = 0.01 65 | logger.warning("top_p:取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1") 66 | 67 | logger.info(f"temperature:{temperature}, top_p:{top_p}") 68 | if isinstance(messages, List): 69 | for item in 
messages: 70 | if item.get('content'): 71 | item['content'] = self._drop_prefix_image_data(item['content']) 72 | 73 | return self._post( 74 | "/chat/completions", 75 | body={ 76 | "model": model, 77 | "request_id": request_id, 78 | "temperature": temperature, 79 | "top_p": top_p, 80 | "do_sample": do_sample, 81 | "max_tokens": max_tokens, 82 | "seed": seed, 83 | "messages": messages, 84 | "stop": stop, 85 | "sensitive_word_check": sensitive_word_check, 86 | "stream": stream, 87 | "tools": tools, 88 | "tool_choice": tool_choice, 89 | "meta": meta, 90 | }, 91 | options=make_request_options( 92 | extra_headers=extra_headers, extra_body=extra_body, timeout=timeout 93 | ), 94 | cast_type=Completion, 95 | stream=stream or False, 96 | stream_cls=StreamResponse[ChatCompletionChunk], 97 | ) 98 | 99 | def _drop_prefix_image_data(self, content: Union[str,List[dict]]) -> Union[str,List[dict]]: 100 | """ 101 | 删除 ;base64, 前缀 102 | :param image_data: 103 | :return: 104 | """ 105 | if isinstance(content, List): 106 | for data in content: 107 | if data.get('type') == 'image_url': 108 | image_data = data.get("image_url").get("url") 109 | if image_data.startswith("data:image/"): 110 | image_data = image_data.split("base64,")[-1] 111 | data["image_url"]["url"] = image_data 112 | 113 | return content 114 | 115 | -------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/api_resource/embeddings.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Union, List, Optional, TYPE_CHECKING 4 | 5 | import httpx 6 | 7 | from ..core import BaseAPI 8 | from ..core import NotGiven, NOT_GIVEN, Headers 9 | from ..core import make_request_options 10 | from ..types.embeddings import EmbeddingsResponded 11 | 12 | if TYPE_CHECKING: 13 | from .._client import ZhipuAI 14 | 15 | 16 | class Embeddings(BaseAPI): 17 | def __init__(self, client: "ZhipuAI") -> None: 18 | 
super().__init__(client) 19 | 20 | def create( 21 | self, 22 | *, 23 | input: Union[str, List[str], List[int], List[List[int]]], 24 | model: Union[str], 25 | encoding_format: str | NotGiven = NOT_GIVEN, 26 | user: str | NotGiven = NOT_GIVEN, 27 | request_id: Optional[str] | NotGiven = NOT_GIVEN, 28 | sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, 29 | extra_headers: Headers | None = None, 30 | extra_body: Body | None = None, 31 | disable_strict_validation: Optional[bool] | None = None, 32 | timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, 33 | ) -> EmbeddingsResponded: 34 | _cast_type = EmbeddingsResponded 35 | if disable_strict_validation: 36 | _cast_type = object 37 | return self._post( 38 | "/embeddings", 39 | body={ 40 | "input": input, 41 | "model": model, 42 | "encoding_format": encoding_format, 43 | "user": user, 44 | "request_id": request_id, 45 | "sensitive_word_check": sensitive_word_check, 46 | }, 47 | options=make_request_options( 48 | extra_headers=extra_headers, extra_body=extra_body, timeout=timeout 49 | ), 50 | cast_type=_cast_type, 51 | stream=False, 52 | ) 53 | -------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/api_resource/files.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | import httpx 6 | 7 | from ..core import BaseAPI 8 | from ..core import NOT_GIVEN, Body, Query, Headers, NotGiven, FileTypes 9 | from ..core import is_file_content 10 | from ..core import ( 11 | make_request_options, 12 | ) 13 | from ..types.file_object import FileObject, ListOfFileObject 14 | 15 | if TYPE_CHECKING: 16 | from .._client import ZhipuAI 17 | 18 | __all__ = ["Files"] 19 | 20 | 21 | class Files(BaseAPI): 22 | 23 | def __init__(self, client: "ZhipuAI") -> None: 24 | super().__init__(client) 25 | 26 | def create( 27 | self, 28 | *, 29 | file: FileTypes, 30 | 
purpose: str, 31 | extra_headers: Headers | None = None, 32 | extra_body: Body | None = None, 33 | timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, 34 | ) -> FileObject: 35 | if not is_file_content(file): 36 | prefix = f"Expected file input `{file!r}`" 37 | raise RuntimeError( 38 | f"{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(file)} instead." 39 | ) from None 40 | files = [("file", file)] 41 | 42 | extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} 43 | 44 | return self._post( 45 | "/files", 46 | body={ 47 | "purpose": purpose, 48 | }, 49 | files=files, 50 | options=make_request_options( 51 | extra_headers=extra_headers, extra_body=extra_body, timeout=timeout 52 | ), 53 | cast_type=FileObject, 54 | ) 55 | 56 | def list( 57 | self, 58 | *, 59 | purpose: str | NotGiven = NOT_GIVEN, 60 | limit: int | NotGiven = NOT_GIVEN, 61 | after: str | NotGiven = NOT_GIVEN, 62 | order: str | NotGiven = NOT_GIVEN, 63 | extra_headers: Headers | None = None, 64 | extra_body: Body | None = None, 65 | timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, 66 | ) -> ListOfFileObject: 67 | return self._get( 68 | "/files", 69 | cast_type=ListOfFileObject, 70 | options=make_request_options( 71 | extra_headers=extra_headers, 72 | extra_body=extra_body, 73 | timeout=timeout, 74 | query={ 75 | "purpose": purpose, 76 | "limit": limit, 77 | "after": after, 78 | "order": order, 79 | }, 80 | ), 81 | ) 82 | -------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/api_resource/fine_tuning/__init__.py: -------------------------------------------------------------------------------- 1 | from .jobs import ( 2 | Jobs 3 | ) 4 | 5 | from .models import ( 6 | FineTunedModels 7 | ) 8 | 9 | from .fine_tuning import ( 10 | FineTuning 11 | ) 12 | 13 | 14 | __all__ = [ 15 | 'Jobs', 16 | 'FineTunedModels', 17 | 'FineTuning' 18 | ] 
-------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/api_resource/fine_tuning/fine_tuning.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | from .jobs import Jobs 3 | from .models import FineTunedModels 4 | from ...core import BaseAPI, cached_property 5 | 6 | if TYPE_CHECKING: 7 | from ..._client import ZhipuAI 8 | 9 | 10 | class FineTuning(BaseAPI): 11 | 12 | @cached_property 13 | def jobs(self) -> Jobs: 14 | return Jobs(self._client) 15 | 16 | @cached_property 17 | def models(self) -> FineTunedModels: 18 | return FineTunedModels(self._client) 19 | 20 | -------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/api_resource/fine_tuning/jobs/__init__.py: -------------------------------------------------------------------------------- 1 | from .jobs import Jobs 2 | 3 | __all__ = [ 4 | "Jobs" 5 | ] -------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/api_resource/fine_tuning/jobs/jobs.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Optional, TYPE_CHECKING 4 | 5 | import httpx 6 | 7 | from ....core import BaseAPI 8 | from ....core import NOT_GIVEN, Headers, NotGiven, Body 9 | from ....core import ( 10 | make_request_options, 11 | ) 12 | from ....types.fine_tuning import ( 13 | FineTuningJob, 14 | job_create_params, 15 | ListOfFineTuningJob, 16 | FineTuningJobEvent, 17 | ) 18 | 19 | if TYPE_CHECKING: 20 | from ...._client import ZhipuAI 21 | 22 | __all__ = ["Jobs"] 23 | 24 | 25 | class Jobs(BaseAPI): 26 | 27 | def __init__(self, client: "ZhipuAI") -> None: 28 | super().__init__(client) 29 | 30 | def create( 31 | self, 32 | *, 33 | model: str, 34 | training_file: str, 35 | hyperparameters: job_create_params.Hyperparameters | NotGiven = NOT_GIVEN, 36 | 
suffix: Optional[str] | NotGiven = NOT_GIVEN, 37 | request_id: Optional[str] | NotGiven = NOT_GIVEN, 38 | validation_file: Optional[str] | NotGiven = NOT_GIVEN, 39 | extra_headers: Headers | None = None, 40 | extra_body: Body | None = None, 41 | timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, 42 | ) -> FineTuningJob: 43 | return self._post( 44 | "/fine_tuning/jobs", 45 | body={ 46 | "model": model, 47 | "training_file": training_file, 48 | "hyperparameters": hyperparameters, 49 | "suffix": suffix, 50 | "validation_file": validation_file, 51 | "request_id": request_id, 52 | }, 53 | options=make_request_options( 54 | extra_headers=extra_headers, extra_body=extra_body, timeout=timeout 55 | ), 56 | cast_type=FineTuningJob, 57 | ) 58 | 59 | def retrieve( 60 | self, 61 | fine_tuning_job_id: str, 62 | *, 63 | extra_headers: Headers | None = None, 64 | extra_body: Body | None = None, 65 | timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, 66 | ) -> FineTuningJob: 67 | return self._get( 68 | f"/fine_tuning/jobs/{fine_tuning_job_id}", 69 | options=make_request_options( 70 | extra_headers=extra_headers, extra_body=extra_body, timeout=timeout 71 | ), 72 | cast_type=FineTuningJob, 73 | ) 74 | 75 | def list( 76 | self, 77 | *, 78 | after: str | NotGiven = NOT_GIVEN, 79 | limit: int | NotGiven = NOT_GIVEN, 80 | extra_headers: Headers | None = None, 81 | extra_body: Body | None = None, 82 | timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, 83 | ) -> ListOfFineTuningJob: 84 | return self._get( 85 | "/fine_tuning/jobs", 86 | cast_type=ListOfFineTuningJob, 87 | options=make_request_options( 88 | extra_headers=extra_headers, 89 | extra_body=extra_body, 90 | timeout=timeout, 91 | query={ 92 | "after": after, 93 | "limit": limit, 94 | }, 95 | ), 96 | ) 97 | 98 | def cancel( 99 | self, 100 | fine_tuning_job_id: str, 101 | *, 102 | # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. 
103 | # The extra values given here take precedence over values defined on the client or passed to this method. 104 | extra_headers: Headers | None = None, 105 | extra_query: Query | None = None, 106 | extra_body: Body | None = None, 107 | timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, 108 | ) -> FineTuningJob: 109 | 110 | if not fine_tuning_job_id: 111 | raise ValueError(f"Expected a non-empty value for `fine_tuning_job_id` but received {fine_tuning_job_id!r}") 112 | return self._post( 113 | f"/fine_tuning/jobs/{fine_tuning_job_id}/cancel", 114 | options=make_request_options( 115 | extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout 116 | ), 117 | cast_type=FineTuningJob, 118 | ) 119 | 120 | def list_events( 121 | self, 122 | fine_tuning_job_id: str, 123 | *, 124 | after: str | NotGiven = NOT_GIVEN, 125 | limit: int | NotGiven = NOT_GIVEN, 126 | extra_headers: Headers | None = None, 127 | extra_body: Body | None = None, 128 | timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, 129 | ) -> FineTuningJobEvent: 130 | 131 | return self._get( 132 | f"/fine_tuning/jobs/{fine_tuning_job_id}/events", 133 | cast_type=FineTuningJobEvent, 134 | options=make_request_options( 135 | extra_headers=extra_headers, 136 | extra_body=extra_body, 137 | timeout=timeout, 138 | query={ 139 | "after": after, 140 | "limit": limit, 141 | }, 142 | ), 143 | ) 144 | 145 | def delete( 146 | self, 147 | fine_tuning_job_id: str, 148 | *, 149 | extra_headers: Headers | None = None, 150 | extra_query: Query | None = None, 151 | extra_body: Body | None = None, 152 | timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, 153 | ) -> FineTuningJob: 154 | 155 | if not fine_tuning_job_id: 156 | raise ValueError(f"Expected a non-empty value for `fine_tuning_job_id` but received {fine_tuning_job_id!r}") 157 | return self._delete( 158 | f"/fine_tuning/jobs/{fine_tuning_job_id}", 159 | options=make_request_options( 160 | 
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout 161 | ), 162 | cast_type=FineTuningJob, 163 | ) -------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/api_resource/fine_tuning/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .fine_tuned_models import FineTunedModels 2 | 3 | __all__ = ['FineTunedModels'] -------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/api_resource/fine_tuning/models/fine_tuned_models.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Optional, TYPE_CHECKING 4 | 5 | import httpx 6 | 7 | from ....core import BaseAPI 8 | from ....core import NOT_GIVEN, Headers, NotGiven, Body 9 | from ....core import ( 10 | make_request_options, 11 | ) 12 | 13 | from ....types.fine_tuning.models import ( 14 | FineTunedModelsStatus 15 | ) 16 | 17 | if TYPE_CHECKING: 18 | from ...._client import ZhipuAI 19 | 20 | __all__ = ["FineTunedModels"] 21 | 22 | 23 | class FineTunedModels(BaseAPI): 24 | 25 | def __init__(self, client: "ZhipuAI") -> None: 26 | super().__init__(client) 27 | 28 | def delete( 29 | self, 30 | fine_tuned_model: str, 31 | *, 32 | extra_headers: Headers | None = None, 33 | extra_query: Query | None = None, 34 | extra_body: Body | None = None, 35 | timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, 36 | ) -> FineTunedModelsStatus: 37 | 38 | if not fine_tuning_job_id: 39 | raise ValueError(f"Expected a non-empty value for `fine_tuning_job_id` but received {fine_tuning_job_id!r}") 40 | return self._delete( 41 | f"fine_tuning/fine_tuned_models/{fine_tuned_model}", 42 | options=make_request_options( 43 | extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout 44 | ), 45 | cast_type=FineTunedModelsStatus, 46 | ) 
-------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/api_resource/images.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Union, List, Optional, TYPE_CHECKING 4 | 5 | import httpx 6 | 7 | from ..core import BaseAPI 8 | from ..core import NotGiven, NOT_GIVEN, Headers, Body 9 | from ..core import make_request_options 10 | from ..types.image import ImagesResponded 11 | 12 | if TYPE_CHECKING: 13 | from .._client import ZhipuAI 14 | 15 | 16 | class Images(BaseAPI): 17 | def __init__(self, client: "ZhipuAI") -> None: 18 | super().__init__(client) 19 | 20 | def generations( 21 | self, 22 | *, 23 | prompt: str, 24 | model: str | NotGiven = NOT_GIVEN, 25 | n: Optional[int] | NotGiven = NOT_GIVEN, 26 | quality: Optional[str] | NotGiven = NOT_GIVEN, 27 | response_format: Optional[str] | NotGiven = NOT_GIVEN, 28 | size: Optional[str] | NotGiven = NOT_GIVEN, 29 | style: Optional[str] | NotGiven = NOT_GIVEN, 30 | user: str | NotGiven = NOT_GIVEN, 31 | request_id: Optional[str] | NotGiven = NOT_GIVEN, 32 | extra_headers: Headers | None = None, 33 | extra_body: Body | None = None, 34 | disable_strict_validation: Optional[bool] | None = None, 35 | timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, 36 | ) -> ImagesResponded: 37 | _cast_type = ImagesResponded 38 | if disable_strict_validation: 39 | _cast_type = object 40 | return self._post( 41 | "/images/generations", 42 | body={ 43 | "prompt": prompt, 44 | "model": model, 45 | "n": n, 46 | "quality": quality, 47 | "response_format": response_format, 48 | "size": size, 49 | "style": style, 50 | "user": user, 51 | "request_id": request_id, 52 | }, 53 | options=make_request_options( 54 | extra_headers=extra_headers, extra_body=extra_body, timeout=timeout 55 | ), 56 | cast_type=_cast_type, 57 | stream=False, 58 | ) 59 | 
-------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/core/__init__.py: -------------------------------------------------------------------------------- 1 | from ._base_models import ( 2 | BaseModel, 3 | construct_type 4 | ) 5 | from ._base_api import BaseAPI 6 | from ._base_type import ( 7 | NOT_GIVEN, 8 | Headers, 9 | NotGiven, 10 | Body, 11 | IncEx, 12 | ModelT, 13 | Query, 14 | FileTypes, 15 | 16 | ) 17 | from ._base_compat import ( 18 | PYDANTIC_V2, 19 | ConfigDict, 20 | GenericModel, 21 | get_args, 22 | is_union, 23 | parse_obj, 24 | get_origin, 25 | is_literal_type, 26 | get_model_config, 27 | get_model_fields, 28 | field_get_default, 29 | cached_property, 30 | ) 31 | from ._files import is_file_content 32 | from ._errors import ( 33 | ZhipuAIError, 34 | APIStatusError, 35 | APIRequestFailedError, 36 | APIAuthenticationError, 37 | APIReachLimitError, 38 | APIInternalError, 39 | APIServerFlowExceedError, 40 | APIResponseError, 41 | APIResponseValidationError, 42 | APIConnectionError, 43 | APITimeoutError, 44 | ) 45 | from ._http_client import ( 46 | make_request_options, 47 | HttpClient, 48 | ZHIPUAI_DEFAULT_TIMEOUT, 49 | ZHIPUAI_DEFAULT_MAX_RETRIES, 50 | ZHIPUAI_DEFAULT_LIMITS, 51 | 52 | ) 53 | from ._utils import (is_list, is_mapping, parse_date, parse_datetime, is_given) 54 | 55 | from ._sse_client import StreamResponse 56 | 57 | __all__ = [ 58 | "BaseModel", 59 | "construct_type", 60 | "BaseAPI", 61 | "NOT_GIVEN", 62 | "Headers", 63 | "NotGiven", 64 | "Body", 65 | "IncEx", 66 | "ModelT", 67 | "Query", 68 | "FileTypes", 69 | 70 | "PYDANTIC_V2", 71 | "ConfigDict", 72 | "GenericModel", 73 | "get_args", 74 | "is_union", 75 | "parse_obj", 76 | "get_origin", 77 | "is_literal_type", 78 | "get_model_config", 79 | "get_model_fields", 80 | "field_get_default", 81 | 82 | "is_file_content", 83 | 84 | "ZhipuAIError", 85 | "APIStatusError", 86 | "APIRequestFailedError", 87 | "APIAuthenticationError", 88 | 
"APIReachLimitError", 89 | "APIInternalError", 90 | "APIServerFlowExceedError", 91 | "APIResponseError", 92 | "APIResponseValidationError", 93 | "APITimeoutError", 94 | 95 | "make_request_options", 96 | "HttpClient", 97 | "ZHIPUAI_DEFAULT_TIMEOUT", 98 | "ZHIPUAI_DEFAULT_MAX_RETRIES", 99 | "ZHIPUAI_DEFAULT_LIMITS", 100 | 101 | "is_list", 102 | "is_mapping", 103 | "parse_date", 104 | "parse_datetime", 105 | "is_given", 106 | 107 | "StreamResponse", 108 | 109 | ] 110 | -------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/core/_base_api.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import TYPE_CHECKING 3 | 4 | if TYPE_CHECKING: 5 | from .._client import ZhipuAI 6 | 7 | 8 | class BaseAPI: 9 | _client: ZhipuAI 10 | 11 | def __init__(self, client: ZhipuAI) -> None: 12 | self._client = client 13 | self._delete = client.delete 14 | self._get = client.get 15 | self._post = client.post 16 | self._put = client.put 17 | self._patch = client.patch 18 | -------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/core/_base_compat.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast, overload 4 | from datetime import date, datetime 5 | from typing_extensions import Self 6 | 7 | import pydantic 8 | from pydantic.fields import FieldInfo 9 | 10 | from zhipuai.core._base_type import StrBytesIntFloat 11 | 12 | _T = TypeVar("_T") 13 | _ModelT = TypeVar("_ModelT", bound=pydantic.BaseModel) 14 | 15 | # --------------- Pydantic v2 compatibility --------------- 16 | 17 | # Pyright incorrectly reports some of our functions as overriding a method when they don't 18 | # pyright: reportIncompatibleMethodOverride=false 19 | 20 | PYDANTIC_V2 = 
pydantic.VERSION.startswith("2.") 21 | 22 | # v1 re-exports 23 | if TYPE_CHECKING: 24 | 25 | def parse_date(value: date | StrBytesIntFloat) -> date: # noqa: ARG001 26 | ... 27 | 28 | def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime: # noqa: ARG001 29 | ... 30 | 31 | def get_args(t: type[Any]) -> tuple[Any, ...]: # noqa: ARG001 32 | ... 33 | 34 | def is_union(tp: type[Any] | None) -> bool: # noqa: ARG001 35 | ... 36 | 37 | def get_origin(t: type[Any]) -> type[Any] | None: # noqa: ARG001 38 | ... 39 | 40 | def is_literal_type(type_: type[Any]) -> bool: # noqa: ARG001 41 | ... 42 | 43 | def is_typeddict(type_: type[Any]) -> bool: # noqa: ARG001 44 | ... 45 | 46 | else: 47 | if PYDANTIC_V2: 48 | from pydantic.v1.typing import ( 49 | get_args as get_args, 50 | is_union as is_union, 51 | get_origin as get_origin, 52 | is_typeddict as is_typeddict, 53 | is_literal_type as is_literal_type, 54 | ) 55 | from pydantic.v1.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime 56 | else: 57 | from pydantic.typing import ( 58 | get_args as get_args, 59 | is_union as is_union, 60 | get_origin as get_origin, 61 | is_typeddict as is_typeddict, 62 | is_literal_type as is_literal_type, 63 | ) 64 | from pydantic.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime 65 | 66 | 67 | # refactored config 68 | if TYPE_CHECKING: 69 | from pydantic import ConfigDict as ConfigDict 70 | else: 71 | if PYDANTIC_V2: 72 | from pydantic import ConfigDict 73 | else: 74 | # TODO: provide an error message here? 
75 | ConfigDict = None 76 | 77 | 78 | # renamed methods / properties 79 | def parse_obj(model: type[_ModelT], value: object) -> _ModelT: 80 | if PYDANTIC_V2: 81 | return model.model_validate(value) 82 | else: 83 | return cast(_ModelT, model.parse_obj(value)) # pyright: ignore[reportDeprecated, reportUnnecessaryCast] 84 | 85 | 86 | def field_is_required(field: FieldInfo) -> bool: 87 | if PYDANTIC_V2: 88 | return field.is_required() 89 | return field.required # type: ignore 90 | 91 | 92 | def field_get_default(field: FieldInfo) -> Any: 93 | value = field.get_default() 94 | if PYDANTIC_V2: 95 | from pydantic_core import PydanticUndefined 96 | 97 | if value == PydanticUndefined: 98 | return None 99 | return value 100 | return value 101 | 102 | 103 | def field_outer_type(field: FieldInfo) -> Any: 104 | if PYDANTIC_V2: 105 | return field.annotation 106 | return field.outer_type_ # type: ignore 107 | 108 | 109 | def get_model_config(model: type[pydantic.BaseModel]) -> Any: 110 | if PYDANTIC_V2: 111 | return model.model_config 112 | return model.__config__ # type: ignore 113 | 114 | 115 | def get_model_fields(model: type[pydantic.BaseModel]) -> dict[str, FieldInfo]: 116 | if PYDANTIC_V2: 117 | return model.model_fields 118 | return model.__fields__ # type: ignore 119 | 120 | 121 | def model_copy(model: _ModelT) -> _ModelT: 122 | if PYDANTIC_V2: 123 | return model.model_copy() 124 | return model.copy() # type: ignore 125 | 126 | 127 | def model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> str: 128 | if PYDANTIC_V2: 129 | return model.model_dump_json(indent=indent) 130 | return model.json(indent=indent) # type: ignore 131 | 132 | 133 | def model_dump( 134 | model: pydantic.BaseModel, 135 | *, 136 | exclude_unset: bool = False, 137 | exclude_defaults: bool = False, 138 | ) -> dict[str, Any]: 139 | if PYDANTIC_V2: 140 | return model.model_dump( 141 | exclude_unset=exclude_unset, 142 | exclude_defaults=exclude_defaults, 143 | ) 144 | return cast( 145 | 
"dict[str, Any]", 146 | model.dict( # pyright: ignore[reportDeprecated, reportUnnecessaryCast] 147 | exclude_unset=exclude_unset, 148 | exclude_defaults=exclude_defaults, 149 | ), 150 | ) 151 | 152 | 153 | def model_parse(model: type[_ModelT], data: Any) -> _ModelT: 154 | if PYDANTIC_V2: 155 | return model.model_validate(data) 156 | return model.parse_obj(data) # pyright: ignore[reportDeprecated] 157 | 158 | 159 | # generic models 160 | if TYPE_CHECKING: 161 | 162 | class GenericModel(pydantic.BaseModel): 163 | ... 164 | 165 | else: 166 | if PYDANTIC_V2: 167 | # there no longer needs to be a distinction in v2 but 168 | # we still have to create our own subclass to avoid 169 | # inconsistent MRO ordering errors 170 | class GenericModel(pydantic.BaseModel): 171 | ... 172 | 173 | else: 174 | import pydantic.generics 175 | 176 | class GenericModel(pydantic.generics.GenericModel, pydantic.BaseModel): 177 | ... 178 | 179 | 180 | # cached properties 181 | if TYPE_CHECKING: 182 | cached_property = property 183 | 184 | # we define a separate type (copied from typeshed) 185 | # that represents that `cached_property` is `set`able 186 | # at runtime, which differs from `@property`. 187 | # 188 | # this is a separate type as editors likely special case 189 | # `@property` and we don't want to cause issues just to have 190 | # more helpful internal types. 191 | 192 | class typed_cached_property(Generic[_T]): 193 | func: Callable[[Any], _T] 194 | attrname: str | None 195 | 196 | def __init__(self, func: Callable[[Any], _T]) -> None: 197 | ... 198 | 199 | @overload 200 | def __get__(self, instance: None, owner: type[Any] | None = None) -> Self: 201 | ... 202 | 203 | @overload 204 | def __get__(self, instance: object, owner: type[Any] | None = None) -> _T: 205 | ... 206 | 207 | def __get__(self, instance: object, owner: type[Any] | None = None) -> _T | Self: 208 | raise NotImplementedError() 209 | 210 | def __set_name__(self, owner: type[Any], name: str) -> None: 211 | ... 
212 | 213 | # __set__ is not defined at runtime, but @cached_property is designed to be settable 214 | def __set__(self, instance: object, value: _T) -> None: 215 | ... 216 | else: 217 | try: 218 | from functools import cached_property as cached_property 219 | except ImportError: 220 | from cached_property import cached_property as cached_property 221 | 222 | typed_cached_property = cached_property 223 | -------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/core/_base_models.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import inspect 4 | from typing import TYPE_CHECKING, Any, Type, Generic, TypeVar, cast, Dict 5 | from datetime import date, datetime 6 | from typing_extensions import ( 7 | Literal, 8 | ClassVar, 9 | Protocol, 10 | override, 11 | runtime_checkable, 12 | ) 13 | 14 | import pydantic 15 | import pydantic.generics 16 | from pydantic.fields import FieldInfo 17 | from pydantic import root_validator 18 | from ._base_type import ( 19 | IncEx, 20 | ModelT, 21 | ) 22 | from ._utils import is_list, is_mapping, parse_date, parse_datetime 23 | from ._base_compat import ( 24 | PYDANTIC_V2, 25 | ConfigDict, 26 | GenericModel as BaseGenericModel, 27 | get_args, 28 | is_union, 29 | parse_obj, 30 | get_origin, 31 | is_literal_type, 32 | get_model_config, 33 | get_model_fields, 34 | field_get_default, 35 | ) 36 | 37 | __all__ = ["BaseModel", "GenericModel"] 38 | 39 | _T = TypeVar("_T") 40 | 41 | @runtime_checkable 42 | class _ConfigProtocol(Protocol): 43 | allow_population_by_field_name: bool 44 | 45 | 46 | class BaseModel(pydantic.BaseModel): 47 | 48 | if PYDANTIC_V2: 49 | model_config: ClassVar[ConfigDict] = ConfigDict(extra="allow") 50 | else: 51 | 52 | @property 53 | @override 54 | def model_fields_set(self) -> set[str]: 55 | # a forwards-compat shim for pydantic v2 56 | return self.__fields_set__ # type: ignore 57 | 58 | class 
Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated] 59 | extra: Any = pydantic.Extra.allow # type: ignore 60 | 61 | @override 62 | def __str__(self) -> str: 63 | # mypy complains about an invalid self arg 64 | return f'{self.__repr_name__()}({self.__repr_str__(", ")})' # type: ignore[misc] 65 | 66 | # Override the 'construct' method in a way that supports recursive parsing without validation. 67 | # Based on https://github.com/samuelcolvin/pydantic/issues/1168#issuecomment-817742836. 68 | @classmethod 69 | @override 70 | def construct( 71 | cls: Type[ModelT], 72 | _fields_set: set[str] | None = None, 73 | **values: object, 74 | ) -> ModelT: 75 | m = cls.__new__(cls) 76 | fields_values: dict[str, object] = {} 77 | 78 | config = get_model_config(cls) 79 | populate_by_name = ( 80 | config.allow_population_by_field_name 81 | if isinstance(config, _ConfigProtocol) 82 | else config.get("populate_by_name") 83 | ) 84 | 85 | if _fields_set is None: 86 | _fields_set = set() 87 | 88 | model_fields = get_model_fields(cls) 89 | for name, field in model_fields.items(): 90 | key = field.alias 91 | if key is None or (key not in values and populate_by_name): 92 | key = name 93 | 94 | if key in values: 95 | fields_values[name] = _construct_field(value=values[key], field=field, key=key) 96 | _fields_set.add(name) 97 | else: 98 | fields_values[name] = field_get_default(field) 99 | 100 | _extra = {} 101 | for key, value in values.items(): 102 | if key not in model_fields: 103 | if PYDANTIC_V2: 104 | _extra[key] = value 105 | else: 106 | _fields_set.add(key) 107 | fields_values[key] = value 108 | 109 | object.__setattr__(m, "__dict__", fields_values) 110 | 111 | if PYDANTIC_V2: 112 | # these properties are copied from Pydantic's `model_construct()` method 113 | object.__setattr__(m, "__pydantic_private__", None) 114 | object.__setattr__(m, "__pydantic_extra__", _extra) 115 | object.__setattr__(m, "__pydantic_fields_set__", _fields_set) 116 | else: 117 | # 
init_private_attributes() does not exist in v2 118 | m._init_private_attributes() # type: ignore 119 | 120 | # copied from Pydantic v1's `construct()` method 121 | object.__setattr__(m, "__fields_set__", _fields_set) 122 | 123 | return m 124 | 125 | if not TYPE_CHECKING: 126 | # type checkers incorrectly complain about this assignment 127 | # because the type signatures are technically different 128 | # although not in practice 129 | model_construct = construct 130 | 131 | if not PYDANTIC_V2: 132 | # we define aliases for some of the new pydantic v2 methods so 133 | # that we can just document these methods without having to specify 134 | # a specific pydantic version as some users may not know which 135 | # pydantic version they are currently using 136 | 137 | @override 138 | def model_dump( 139 | self, 140 | *, 141 | mode: Literal["json", "python"] | str = "python", 142 | include: IncEx = None, 143 | exclude: IncEx = None, 144 | by_alias: bool = False, 145 | exclude_unset: bool = False, 146 | exclude_defaults: bool = False, 147 | exclude_none: bool = False, 148 | round_trip: bool = False, 149 | warnings: bool = True, 150 | ) -> dict[str, Any]: 151 | """Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump 152 | 153 | Generate a dictionary representation of the model, optionally specifying which fields to include or exclude. 154 | 155 | Args: 156 | mode: The mode in which `to_python` should run. 157 | If mode is 'json', the dictionary will only contain JSON serializable types. 158 | If mode is 'python', the dictionary may contain any Python objects. 159 | include: A list of fields to include in the output. 160 | exclude: A list of fields to exclude from the output. 161 | by_alias: Whether to use the field's alias in the dictionary key if defined. 162 | exclude_unset: Whether to exclude fields that are unset or None from the output. 163 | exclude_defaults: Whether to exclude fields that are set to their default value from the output. 
164 | exclude_none: Whether to exclude fields that have a value of `None` from the output. 165 | round_trip: Whether to enable serialization and deserialization round-trip support. 166 | warnings: Whether to log warnings when invalid fields are encountered. 167 | 168 | Returns: 169 | A dictionary representation of the model. 170 | """ 171 | if mode != "python": 172 | raise ValueError("mode is only supported in Pydantic v2") 173 | if round_trip != False: 174 | raise ValueError("round_trip is only supported in Pydantic v2") 175 | if warnings != True: 176 | raise ValueError("warnings is only supported in Pydantic v2") 177 | return super().dict( # pyright: ignore[reportDeprecated] 178 | include=include, 179 | exclude=exclude, 180 | by_alias=by_alias, 181 | exclude_unset=exclude_unset, 182 | exclude_defaults=exclude_defaults, 183 | exclude_none=exclude_none, 184 | ) 185 | 186 | @override 187 | def model_dump_json( 188 | self, 189 | *, 190 | indent: int | None = None, 191 | include: IncEx = None, 192 | exclude: IncEx = None, 193 | by_alias: bool = False, 194 | exclude_unset: bool = False, 195 | exclude_defaults: bool = False, 196 | exclude_none: bool = False, 197 | round_trip: bool = False, 198 | warnings: bool = True, 199 | ) -> str: 200 | """Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump_json 201 | 202 | Generates a JSON representation of the model using Pydantic's `to_json` method. 203 | 204 | Args: 205 | indent: Indentation to use in the JSON output. If None is passed, the output will be compact. 206 | include: Field(s) to include in the JSON output. Can take either a string or set of strings. 207 | exclude: Field(s) to exclude from the JSON output. Can take either a string or set of strings. 208 | by_alias: Whether to serialize using field aliases. 209 | exclude_unset: Whether to exclude fields that have not been explicitly set. 210 | exclude_defaults: Whether to exclude fields that have the default value. 
211 | exclude_none: Whether to exclude fields that have a value of `None`. 212 | round_trip: Whether to use serialization/deserialization between JSON and class instance. 213 | warnings: Whether to show any warnings that occurred during serialization. 214 | 215 | Returns: 216 | A JSON string representation of the model. 217 | """ 218 | if round_trip != False: 219 | raise ValueError("round_trip is only supported in Pydantic v2") 220 | if warnings != True: 221 | raise ValueError("warnings is only supported in Pydantic v2") 222 | return super().json( # type: ignore[reportDeprecated] 223 | indent=indent, 224 | include=include, 225 | exclude=exclude, 226 | by_alias=by_alias, 227 | exclude_unset=exclude_unset, 228 | exclude_defaults=exclude_defaults, 229 | exclude_none=exclude_none, 230 | ) 231 | 232 | 233 | def _construct_field(value: object, field: FieldInfo, key: str) -> object: 234 | if value is None: 235 | return field_get_default(field) 236 | 237 | if PYDANTIC_V2: 238 | type_ = field.annotation 239 | else: 240 | type_ = cast(type, field.outer_type_) # type: ignore 241 | 242 | if type_ is None: 243 | raise RuntimeError(f"Unexpected field type is None for {key}") 244 | 245 | return construct_type(value=value, type_=type_) 246 | 247 | 248 | def is_basemodel(type_: type) -> bool: 249 | """Returns whether or not the given type is either a `BaseModel` or a union of `BaseModel`""" 250 | origin = get_origin(type_) or type_ 251 | if is_union(type_): 252 | for variant in get_args(type_): 253 | if is_basemodel(variant): 254 | return True 255 | 256 | return False 257 | 258 | return issubclass(origin, BaseModel) or issubclass(origin, GenericModel) 259 | 260 | 261 | def construct_type(*, value: object, type_: type) -> object: 262 | """Loose coercion to the expected type with construction of nested values. 263 | 264 | If the given value does not match the expected type then it is returned as-is. 
265 | """ 266 | 267 | # we need to use the origin class for any types that are subscripted generics 268 | # e.g. Dict[str, object] 269 | origin = get_origin(type_) or type_ 270 | args = get_args(type_) 271 | 272 | if is_union(origin): 273 | try: 274 | return validate_type(type_=cast("type[object]", type_), value=value) 275 | except Exception: 276 | pass 277 | 278 | # if the data is not valid, use the first variant that doesn't fail while deserializing 279 | for variant in args: 280 | try: 281 | return construct_type(value=value, type_=variant) 282 | except Exception: 283 | continue 284 | 285 | raise RuntimeError(f"Could not convert data into a valid instance of {type_}") 286 | 287 | if origin == dict: 288 | if not is_mapping(value): 289 | return value 290 | 291 | _, items_type = get_args(type_) # Dict[_, items_type] 292 | return {key: construct_type(value=item, type_=items_type) for key, item in value.items()} 293 | 294 | if not is_literal_type(type_) and (issubclass(origin, BaseModel) or issubclass(origin, GenericModel)): 295 | if is_list(value): 296 | return [cast(Any, type_).construct(**entry) if is_mapping(entry) else entry for entry in value] 297 | 298 | if is_mapping(value): 299 | if issubclass(type_, BaseModel): 300 | return type_.construct(**value) # type: ignore[arg-type] 301 | 302 | return cast(Any, type_).construct(**value) 303 | 304 | if origin == list: 305 | if not is_list(value): 306 | return value 307 | 308 | inner_type = args[0] # List[inner_type] 309 | return [construct_type(value=entry, type_=inner_type) for entry in value] 310 | 311 | if origin == float: 312 | if isinstance(value, int): 313 | coerced = float(value) 314 | if coerced != value: 315 | return value 316 | return coerced 317 | 318 | return value 319 | 320 | if type_ == datetime: 321 | try: 322 | return parse_datetime(value) # type: ignore 323 | except Exception: 324 | return value 325 | 326 | if type_ == date: 327 | try: 328 | return parse_date(value) # type: ignore 329 | except 
Exception: 330 | return value 331 | 332 | return value 333 | 334 | 335 | def validate_type(*, type_: type[_T], value: object) -> _T: 336 | """Strict validation that the given value matches the expected type""" 337 | if inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel): 338 | return cast(_T, parse_obj(type_, value)) 339 | 340 | return cast(_T, _validate_non_model_type(type_=type_, value=value)) 341 | 342 | 343 | # our use of subclasssing here causes weirdness for type checkers, 344 | # so we just pretend that we don't subclass 345 | if TYPE_CHECKING: 346 | GenericModel = BaseModel 347 | else: 348 | 349 | class GenericModel(BaseGenericModel, BaseModel): 350 | pass 351 | 352 | 353 | if PYDANTIC_V2: 354 | from pydantic import TypeAdapter 355 | 356 | 357 | def _validate_non_model_type(*, type_: type[_T], value: object) -> _T: 358 | return TypeAdapter(type_).validate_python(value) 359 | 360 | elif not TYPE_CHECKING: 361 | 362 | class TypeAdapter(Generic[_T]): 363 | """Used as a placeholder to easily convert runtime types to a Pydantic format 364 | to provide validation. 
365 | 366 | For example: 367 | ```py 368 | validated = RootModel[int](__root__="5").__root__ 369 | # validated: 5 370 | ``` 371 | """ 372 | def __init__(self, type_: Type[_T]): 373 | self.type_ = type_ 374 | 375 | def validate_python(self, value: Any) -> _T: 376 | if not isinstance(value, self.type_): 377 | raise ValueError(f"Invalid type: {value} is not of type {self.type_}") 378 | return value 379 | 380 | def _validate_non_model_type(*, type_: type[_T], value: object) -> _T: 381 | return TypeAdapter(type_).validate_python(value) 382 | 383 | -------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/core/_base_type.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from os import PathLike 4 | from typing import ( 5 | IO, 6 | TYPE_CHECKING, 7 | Any, 8 | Dict, 9 | List, 10 | Type, 11 | Tuple, 12 | Union, 13 | Mapping, 14 | TypeVar, 15 | Callable, 16 | Optional, 17 | Sequence, 18 | ) 19 | import pydantic 20 | from httpx import Response 21 | from typing_extensions import Literal, Protocol, TypeAlias, TypedDict, override, runtime_checkable 22 | Query = Mapping[str, object] 23 | Body = object 24 | AnyMapping = Mapping[str, object] 25 | PrimitiveData = Union[str, int, float, bool, None] 26 | Data = Union[PrimitiveData, List[Any], Tuple[Any], "Mapping[str, Any]"] 27 | ModelT = TypeVar("ModelT", bound=pydantic.BaseModel) 28 | _T = TypeVar("_T") 29 | 30 | if TYPE_CHECKING: 31 | NoneType: Type[None] 32 | else: 33 | NoneType = type(None) 34 | 35 | 36 | # Sentinel class used until PEP 0661 is accepted 37 | class NotGiven: 38 | """ 39 | A sentinel singleton class used to distinguish omitted keyword arguments 40 | from those passed in with the value None (which may have different behavior). 41 | 42 | For example: 43 | 44 | ```py 45 | def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response: ... 
46 | 47 | get(timeout=1) # 1s timeout 48 | get(timeout=None) # No timeout 49 | get() # Default timeout behavior, which may not be statically known at the method definition. 50 | ``` 51 | """ 52 | 53 | def __bool__(self) -> Literal[False]: 54 | return False 55 | 56 | @override 57 | def __repr__(self) -> str: 58 | return "NOT_GIVEN" 59 | 60 | 61 | NotGivenOr = Union[_T, NotGiven] 62 | NOT_GIVEN = NotGiven() 63 | 64 | 65 | class Omit: 66 | """In certain situations you need to be able to represent a case where a default value has 67 | to be explicitly removed and `None` is not an appropriate substitute, for example: 68 | 69 | ```py 70 | # as the default `Content-Type` header is `application/json` that will be sent 71 | client.post('/upload/files', files={'file': b'my raw file content'}) 72 | 73 | # you can't explicitly override the header as it has to be dynamically generated 74 | # to look something like: 'multipart/form-data; boundary=0d8382fcf5f8c3be01ca2e11002d2983' 75 | client.post(..., headers={'Content-Type': 'multipart/form-data'}) 76 | 77 | # instead you can remove the default `application/json` header by passing Omit 78 | client.post(..., headers={'Content-Type': Omit()}) 79 | ``` 80 | """ 81 | 82 | def __bool__(self) -> Literal[False]: 83 | return False 84 | 85 | 86 | @runtime_checkable 87 | class ModelBuilderProtocol(Protocol): 88 | @classmethod 89 | def build( 90 | cls: type[_T], 91 | *, 92 | response: Response, 93 | data: object, 94 | ) -> _T: 95 | ... 96 | 97 | 98 | Headers = Mapping[str, Union[str, Omit]] 99 | 100 | 101 | class HeadersLikeProtocol(Protocol): 102 | def get(self, __key: str) -> str | None: 103 | ... 
104 | 105 | 106 | HeadersLike = Union[Headers, HeadersLikeProtocol] 107 | 108 | ResponseT = TypeVar( 109 | "ResponseT", 110 | bound="Union[str, None, BaseModel, List[Any], Dict[str, Any], Response, UnknownResponse, ModelBuilderProtocol, BinaryResponseContent]", 111 | ) 112 | 113 | StrBytesIntFloat = Union[str, bytes, int, float] 114 | 115 | # Note: copied from Pydantic 116 | # https://github.com/pydantic/pydantic/blob/32ea570bf96e84234d2992e1ddf40ab8a565925a/pydantic/main.py#L49 117 | IncEx: TypeAlias = "set[int] | set[str] | dict[int, Any] | dict[str, Any] | None" 118 | 119 | PostParser = Callable[[Any], Any] 120 | 121 | 122 | @runtime_checkable 123 | class InheritsGeneric(Protocol): 124 | """Represents a type that has inherited from `Generic` 125 | 126 | The `__orig_bases__` property can be used to determine the resolved 127 | type variable for a given base class. 128 | """ 129 | 130 | __orig_bases__: tuple[_GenericAlias] 131 | 132 | 133 | class _GenericAlias(Protocol): 134 | __origin__: type[object] 135 | 136 | 137 | class HttpxSendArgs(TypedDict, total=False): 138 | auth: httpx.Auth 139 | 140 | # for user input files 141 | if TYPE_CHECKING: 142 | FileContent = Union[IO[bytes], bytes, PathLike[str]] 143 | else: 144 | FileContent = Union[IO[bytes], bytes, PathLike] 145 | 146 | FileTypes = Union[ 147 | FileContent, # file content 148 | Tuple[str, FileContent], # (filename, file) 149 | Tuple[str, FileContent, str], # (filename, file , content_type) 150 | Tuple[str, FileContent, str, Mapping[str, str]], # (filename, file , content_type, headers) 151 | ] 152 | 153 | RequestFiles = Union[Mapping[str, FileTypes], Sequence[Tuple[str, FileTypes]]] 154 | 155 | # for httpx client supported files 156 | 157 | HttpxFileContent = Union[bytes, IO[bytes]] 158 | HttpxFileTypes = Union[ 159 | FileContent, # file content 160 | Tuple[str, HttpxFileContent], # (filename, file) 161 | Tuple[str, HttpxFileContent, str], # (filename, file , content_type) 162 | Tuple[str, 
HttpxFileContent, str, Mapping[str, str]], # (filename, file , content_type, headers) 163 | ] 164 | 165 | HttpxRequestFiles = Union[Mapping[str, HttpxFileTypes], Sequence[Tuple[str, HttpxFileTypes]]] 166 | -------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/core/_errors.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import httpx 4 | 5 | __all__ = [ 6 | "ZhipuAIError", 7 | "APIStatusError", 8 | "APIRequestFailedError", 9 | "APIAuthenticationError", 10 | "APIReachLimitError", 11 | "APIInternalError", 12 | "APIServerFlowExceedError", 13 | "APIResponseError", 14 | "APIResponseValidationError", 15 | "APITimeoutError", 16 | ] 17 | 18 | 19 | class ZhipuAIError(Exception): 20 | def __init__(self, message: str, ) -> None: 21 | super().__init__(message) 22 | 23 | 24 | class APIStatusError(ZhipuAIError): 25 | response: httpx.Response 26 | status_code: int 27 | 28 | def __init__(self, message: str, *, response: httpx.Response) -> None: 29 | super().__init__(message) 30 | self.response = response 31 | self.status_code = response.status_code 32 | 33 | 34 | class APIRequestFailedError(APIStatusError): 35 | ... 36 | 37 | 38 | class APIAuthenticationError(APIStatusError): 39 | ... 40 | 41 | 42 | class APIReachLimitError(APIStatusError): 43 | ... 44 | 45 | 46 | class APIInternalError(APIStatusError): 47 | ... 48 | 49 | 50 | class APIServerFlowExceedError(APIStatusError): 51 | ... 
52 | 53 | 54 | class APIResponseError(ZhipuAIError): 55 | message: str 56 | request: httpx.Request 57 | json_data: object 58 | 59 | def __init__(self, message: str, request: httpx.Request, json_data: object): 60 | self.message = message 61 | self.request = request 62 | self.json_data = json_data 63 | super().__init__(message) 64 | 65 | 66 | class APIResponseValidationError(APIResponseError): 67 | status_code: int 68 | response: httpx.Response 69 | 70 | def __init__( 71 | self, 72 | response: httpx.Response, 73 | json_data: object | None, *, 74 | message: str | None = None 75 | ) -> None: 76 | super().__init__( 77 | message=message or "Data returned by API invalid for expected schema.", 78 | request=response.request, 79 | json_data=json_data 80 | ) 81 | self.response = response 82 | self.status_code = response.status_code 83 | 84 | 85 | class APIConnectionError(APIResponseError): 86 | def __init__(self, *, message: str = "Connection error.", request: httpx.Request) -> None: 87 | super().__init__(message, request, json_data=None) 88 | 89 | 90 | class APITimeoutError(APIConnectionError): 91 | def __init__(self, request: httpx.Request) -> None: 92 | super().__init__(message="Request timed out.", request=request) 93 | -------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/core/_files.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import io 4 | import os 5 | from pathlib import Path 6 | from typing import Mapping, Sequence 7 | 8 | from ._base_type import ( 9 | FileTypes, 10 | HttpxFileTypes, 11 | HttpxRequestFiles, 12 | RequestFiles, 13 | ) 14 | 15 | 16 | def is_file_content(obj: object) -> bool: 17 | return isinstance(obj, (bytes, tuple, io.IOBase, os.PathLike)) 18 | 19 | 20 | def _transform_file(file: FileTypes) -> HttpxFileTypes: 21 | if is_file_content(file): 22 | if isinstance(file, os.PathLike): 23 | path = Path(file) 24 | return 
path.name, path.read_bytes() 25 | else: 26 | return file 27 | if isinstance(file, tuple): 28 | if isinstance(file[1], os.PathLike): 29 | return (file[0], Path(file[1]).read_bytes(), *file[2:]) 30 | else: 31 | return (file[0], file[1], *file[2:]) 32 | else: 33 | raise TypeError(f"Unexpected input file with type {type(file)},Expected FileContent type or tuple type") 34 | 35 | 36 | def make_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None: 37 | if files is None: 38 | return None 39 | 40 | if isinstance(files, Mapping): 41 | files = {key: _transform_file(file) for key, file in files.items()} 42 | elif isinstance(files, Sequence): 43 | files = [(key, _transform_file(file)) for key, file in files] 44 | else: 45 | raise TypeError(f"Unexpected input file with type {type(files)}, excepted Mapping or Sequence") 46 | return files 47 | -------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/core/_jwt_token.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | import time 3 | 4 | import cachetools.func 5 | import jwt 6 | 7 | # 缓存时间 3分钟 8 | CACHE_TTL_SECONDS = 3 * 60 9 | 10 | # token 有效期比缓存时间 多30秒 11 | API_TOKEN_TTL_SECONDS = CACHE_TTL_SECONDS + 30 12 | 13 | 14 | @cachetools.func.ttl_cache(maxsize=10, ttl=CACHE_TTL_SECONDS) 15 | def generate_token(apikey: str): 16 | try: 17 | api_key, secret = apikey.split(".") 18 | except Exception as e: 19 | raise Exception("invalid api_key", e) 20 | 21 | payload = { 22 | "api_key": api_key, 23 | "exp": int(round(time.time() * 1000)) + API_TOKEN_TTL_SECONDS * 1000, 24 | "timestamp": int(round(time.time() * 1000)), 25 | } 26 | ret = jwt.encode( 27 | payload, 28 | secret, 29 | algorithm="HS256", 30 | headers={"alg": "HS256", "sign_type": "SIGN"}, 31 | ) 32 | return ret 33 | -------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/core/_request_opt.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Union, Any, cast, TYPE_CHECKING 4 | 5 | from ._base_compat import ConfigDict, PYDANTIC_V2 6 | import pydantic.generics 7 | from httpx import Timeout 8 | from typing_extensions import ( 9 | final, Unpack, ClassVar, TypedDict 10 | 11 | ) 12 | 13 | from ._base_type import Body, NotGiven, Headers, HttpxRequestFiles, Query, AnyMapping 14 | from ._utils import remove_notgiven_indict, strip_not_given 15 | 16 | 17 | class UserRequestInput(TypedDict, total=False): 18 | headers: Headers 19 | max_retries: int 20 | timeout: float | Timeout | None 21 | params: Query 22 | extra_json: AnyMapping 23 | 24 | class FinalRequestOptionsInput(TypedDict, total=False): 25 | method: Required[str] 26 | url: Required[str] 27 | params: Query 28 | headers: Headers 29 | max_retries: int 30 | timeout: float | Timeout | None 31 | files: HttpxRequestFiles | None 32 | json_data: Body 33 | extra_json: AnyMapping 34 | 35 | @final 36 | class FinalRequestOptions(pydantic.BaseModel): 37 | method: str 38 | url: str 39 | params: Query = {} 40 | headers: Union[Headers, NotGiven] = NotGiven() 41 | max_retries: Union[int, NotGiven] = NotGiven() 42 | timeout: Union[float, Timeout, None, NotGiven] = NotGiven() 43 | files: Union[HttpxRequestFiles, None] = None 44 | idempotency_key: Union[str, None] = None 45 | post_parser: Union[Callable[[Any], Any], NotGiven] = NotGiven() 46 | 47 | # It should be noted that we cannot use `json` here as that would override 48 | # a BaseModel method in an incompatible fashion. 
49 | json_data: Union[Body, None] = None 50 | extra_json: Union[AnyMapping, None] = None 51 | 52 | 53 | if PYDANTIC_V2: 54 | model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True) 55 | else: 56 | 57 | class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated] 58 | arbitrary_types_allowed: bool = True 59 | 60 | def get_max_retries(self, max_retries: int) -> int: 61 | if isinstance(self.max_retries, NotGiven): 62 | return max_retries 63 | return self.max_retries 64 | 65 | def _strip_raw_response_header(self) -> None: 66 | if not is_given(self.headers): 67 | return 68 | 69 | if self.headers.get(RAW_RESPONSE_HEADER): 70 | self.headers = {**self.headers} 71 | self.headers.pop(RAW_RESPONSE_HEADER) 72 | 73 | # override the `construct` method so that we can run custom transformations. 74 | # this is necessary as we don't want to do any actual runtime type checking 75 | # (which means we can't use validators) but we do want to ensure that `NotGiven` 76 | # values are not present 77 | # 78 | # type ignore required because we're adding explicit types to `**values` 79 | @classmethod 80 | def construct( # type: ignore 81 | cls, 82 | _fields_set: set[str] | None = None, 83 | **values: Unpack[UserRequestInput], 84 | ) -> FinalRequestOptions: 85 | kwargs: dict[str, Any] = { 86 | # we unconditionally call `strip_not_given` on any value 87 | # as it will just ignore any non-mapping types 88 | key: strip_not_given(value) 89 | for key, value in values.items() 90 | } 91 | if PYDANTIC_V2: 92 | return super().model_construct(_fields_set, **kwargs) 93 | return cast(FinalRequestOptions, super().construct(_fields_set, **kwargs)) # pyright: ignore[reportDeprecated] 94 | 95 | if not TYPE_CHECKING: 96 | # type checkers incorrectly complain about this assignment 97 | model_construct = construct 98 | -------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/core/_response.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import datetime 4 | from typing import TypeVar, Generic, cast, Any, TYPE_CHECKING 5 | 6 | import httpx 7 | import logging 8 | from typing_extensions import ParamSpec, get_origin, get_args 9 | 10 | from ._base_type import NoneType 11 | from ._sse_client import StreamResponse 12 | from ._base_models import is_basemodel 13 | 14 | if TYPE_CHECKING: 15 | from ._http_client import HttpClient 16 | 17 | P = ParamSpec("P") 18 | R = TypeVar("R") 19 | log: logging.Logger = logging.getLogger(__name__) 20 | 21 | class HttpResponse(Generic[R]): 22 | _cast_type: type[R] 23 | _client: "HttpClient" 24 | _parsed: R | None 25 | _stream: bool 26 | _stream_cls: type[StreamResponse[Any]] 27 | http_response: httpx.Response 28 | 29 | def __init__( 30 | self, 31 | *, 32 | raw_response: httpx.Response, 33 | cast_type: type[R], 34 | client: "HttpClient", 35 | stream: bool = False, 36 | stream_cls: type[StreamResponse[Any]] | None = None, 37 | ) -> None: 38 | self._cast_type = cast_type 39 | self._client = client 40 | self._parsed = None 41 | self._stream_cls = stream_cls 42 | self._stream = stream 43 | self.http_response = raw_response 44 | 45 | def parse(self) -> R: 46 | self._parsed = self._parse() 47 | return self._parsed 48 | 49 | def _parse(self) -> R: 50 | if self._stream: 51 | self._parsed = cast( 52 | R, 53 | self._stream_cls( 54 | cast_type=cast(type, get_args(self._stream_cls)[0]), 55 | response=self.http_response, 56 | client=self._client 57 | ) 58 | ) 59 | return self._parsed 60 | cast_type = self._cast_type 61 | if cast_type is NoneType: 62 | return cast(R, None) 63 | http_response = self.http_response 64 | if cast_type == str: 65 | return cast(R, http_response.text) 66 | 67 | content_type, *_ = http_response.headers.get("content-type", "application/json").split(";") 68 | origin = get_origin(cast_type) or cast_type 69 | if content_type != 
"application/json": 70 | if is_basemodel(cast_type): 71 | try: 72 | data = http_response.json() 73 | except Exception as exc: 74 | log.debug("Could not read JSON from response data due to %s - %s", type(exc), exc) 75 | else: 76 | return self._client._process_response_data( 77 | data=data, 78 | cast_type=cast_type, # type: ignore 79 | response=http_response, 80 | ) 81 | 82 | return http_response.text 83 | 84 | data = http_response.json() 85 | 86 | return self._client._process_response_data( 87 | data=data, 88 | cast_type=cast_type, # type: ignore 89 | response=http_response, 90 | ) 91 | 92 | @property 93 | def headers(self) -> httpx.Headers: 94 | return self.http_response.headers 95 | 96 | @property 97 | def http_request(self) -> httpx.Request: 98 | return self.http_response.request 99 | 100 | @property 101 | def status_code(self) -> int: 102 | return self.http_response.status_code 103 | 104 | @property 105 | def url(self) -> httpx.URL: 106 | return self.http_response.url 107 | 108 | @property 109 | def method(self) -> str: 110 | return self.http_request.method 111 | 112 | @property 113 | def content(self) -> bytes: 114 | return self.http_response.content 115 | 116 | @property 117 | def text(self) -> str: 118 | return self.http_response.text 119 | 120 | @property 121 | def http_version(self) -> str: 122 | return self.http_response.http_version 123 | 124 | @property 125 | def elapsed(self) -> datetime.timedelta: 126 | return self.http_response.elapsed 127 | -------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/core/_sse_client.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | from __future__ import annotations 3 | 4 | import json 5 | from typing import Generic, Iterator, TYPE_CHECKING, Mapping 6 | 7 | import httpx 8 | 9 | from ._base_type import ResponseT 10 | from ._errors import APIResponseError 11 | 12 | _FIELD_SEPARATOR = ":" 13 | 14 | if 
TYPE_CHECKING: 15 | from ._http_client import HttpClient 16 | 17 | 18 | class StreamResponse(Generic[ResponseT]): 19 | 20 | response: httpx.Response 21 | _cast_type: type[ResponseT] 22 | 23 | def __init__( 24 | self, 25 | *, 26 | cast_type: type[ResponseT], 27 | response: httpx.Response, 28 | client: HttpClient, 29 | ) -> None: 30 | self.response = response 31 | self._cast_type = cast_type 32 | self._data_process_func = client._process_response_data 33 | self._stream_chunks = self.__stream__() 34 | 35 | def __next__(self) -> ResponseT: 36 | return self._stream_chunks.__next__() 37 | 38 | def __iter__(self) -> Iterator[ResponseT]: 39 | for item in self._stream_chunks: 40 | yield item 41 | 42 | def __stream__(self) -> Iterator[ResponseT]: 43 | 44 | sse_line_parser = SSELineParser() 45 | iterator = sse_line_parser.iter_lines(self.response.iter_lines()) 46 | 47 | for sse in iterator: 48 | if sse.data.startswith("[DONE]"): 49 | break 50 | 51 | if sse.event is None: 52 | data = sse.json_data() 53 | if isinstance(data, Mapping) and data.get("error"): 54 | raise APIResponseError( 55 | message="An error occurred during streaming", 56 | request=self.response.request, 57 | json_data=data["error"], 58 | ) 59 | 60 | yield self._data_process_func(data=data, cast_type=self._cast_type, response=self.response) 61 | for sse in iterator: 62 | pass 63 | 64 | 65 | class Event(object): 66 | def __init__( 67 | self, 68 | event: str | None = None, 69 | data: str | None = None, 70 | id: str | None = None, 71 | retry: int | None = None 72 | ): 73 | self._event = event 74 | self._data = data 75 | self._id = id 76 | self._retry = retry 77 | 78 | def __repr__(self): 79 | data_len = len(self._data) if self._data else 0 80 | return f"Event(event={self._event}, data={self._data} ,data_length={data_len}, id={self._id}, retry={self._retry}" 81 | 82 | @property 83 | def event(self): return self._event 84 | 85 | @property 86 | def data(self): return self._data 87 | 88 | def json_data(self): return 
json.loads(self._data) 89 | 90 | @property 91 | def id(self): return self._id 92 | 93 | @property 94 | def retry(self): return self._retry 95 | 96 | 97 | class SSELineParser: 98 | _data: list[str] 99 | _event: str | None 100 | _retry: int | None 101 | _id: str | None 102 | 103 | def __init__(self): 104 | self._event = None 105 | self._data = [] 106 | self._id = None 107 | self._retry = None 108 | 109 | def iter_lines(self, lines: Iterator[str]) -> Iterator[Event]: 110 | for line in lines: 111 | line = line.rstrip('\n') 112 | if not line: 113 | if self._event is None and \ 114 | not self._data and \ 115 | self._id is None and \ 116 | self._retry is None: 117 | continue 118 | sse_event = Event( 119 | event=self._event, 120 | data='\n'.join(self._data), 121 | id=self._id, 122 | retry=self._retry 123 | ) 124 | self._event = None 125 | self._data = [] 126 | self._id = None 127 | self._retry = None 128 | 129 | yield sse_event 130 | self.decode_line(line) 131 | 132 | def decode_line(self, line: str): 133 | if line.startswith(":") or not line: 134 | return 135 | 136 | field, _p, value = line.partition(":") 137 | 138 | if value.startswith(' '): 139 | value = value[1:] 140 | if field == "data": 141 | self._data.append(value) 142 | elif field == "event": 143 | self._event = value 144 | elif field == "retry": 145 | try: 146 | self._retry = int(value) 147 | except (TypeError, ValueError): 148 | pass 149 | return 150 | -------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/core/_utils/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from ._utils import ( 3 | remove_notgiven_indict as remove_notgiven_indict, 4 | flatten as flatten, 5 | is_dict as is_dict, 6 | is_list as is_list, 7 | is_given as is_given, 8 | is_tuple as is_tuple, 9 | is_mapping as is_mapping, 10 | is_tuple_t as is_tuple_t, 11 | parse_date as parse_date, 12 | is_iterable as is_iterable, 13 | is_sequence as 
is_sequence, 14 | coerce_float as coerce_float, 15 | is_mapping_t as is_mapping_t, 16 | removeprefix as removeprefix, 17 | removesuffix as removesuffix, 18 | extract_files as extract_files, 19 | is_sequence_t as is_sequence_t, 20 | required_args as required_args, 21 | coerce_boolean as coerce_boolean, 22 | coerce_integer as coerce_integer, 23 | file_from_path as file_from_path, 24 | parse_datetime as parse_datetime, 25 | strip_not_given as strip_not_given, 26 | deepcopy_minimal as deepcopy_minimal, 27 | get_async_library as get_async_library, 28 | maybe_coerce_float as maybe_coerce_float, 29 | get_required_header as get_required_header, 30 | maybe_coerce_boolean as maybe_coerce_boolean, 31 | maybe_coerce_integer as maybe_coerce_integer, 32 | ) 33 | 34 | 35 | from ._typing import ( 36 | is_list_type as is_list_type, 37 | is_union_type as is_union_type, 38 | extract_type_arg as extract_type_arg, 39 | is_iterable_type as is_iterable_type, 40 | is_required_type as is_required_type, 41 | is_annotated_type as is_annotated_type, 42 | strip_annotated_type as strip_annotated_type, 43 | extract_type_var_from_base as extract_type_var_from_base, 44 | ) -------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/core/_utils/_typing.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any, TypeVar, Iterable, cast 4 | from collections import abc as _c_abc 5 | from typing_extensions import Required, Annotated, get_args, get_origin 6 | 7 | from .._base_type import InheritsGeneric 8 | from zhipuai.core._base_compat import is_union as _is_union 9 | 10 | 11 | def is_annotated_type(typ: type) -> bool: 12 | return get_origin(typ) == Annotated 13 | 14 | 15 | def is_list_type(typ: type) -> bool: 16 | return (get_origin(typ) or typ) == list 17 | 18 | 19 | def is_iterable_type(typ: type) -> bool: 20 | """If the given type is `typing.Iterable[T]`""" 21 | 
origin = get_origin(typ) or typ 22 | return origin == Iterable or origin == _c_abc.Iterable 23 | 24 | 25 | def is_union_type(typ: type) -> bool: 26 | return _is_union(get_origin(typ)) 27 | 28 | 29 | def is_required_type(typ: type) -> bool: 30 | return get_origin(typ) == Required 31 | 32 | 33 | def is_typevar(typ: type) -> bool: 34 | # type ignore is required because type checkers 35 | # think this expression will always return False 36 | return type(typ) == TypeVar # type: ignore 37 | 38 | 39 | # Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]] 40 | def strip_annotated_type(typ: type) -> type: 41 | if is_required_type(typ) or is_annotated_type(typ): 42 | return strip_annotated_type(cast(type, get_args(typ)[0])) 43 | 44 | return typ 45 | 46 | 47 | def extract_type_arg(typ: type, index: int) -> type: 48 | args = get_args(typ) 49 | try: 50 | return cast(type, args[index]) 51 | except IndexError as err: 52 | raise RuntimeError(f"Expected type {typ} to have a type argument at index {index} but it did not") from err 53 | 54 | 55 | def extract_type_var_from_base( 56 | typ: type, 57 | *, 58 | generic_bases: tuple[type, ...], 59 | index: int, 60 | failure_message: str | None = None, 61 | ) -> type: 62 | """Given a type like `Foo[T]`, returns the generic type variable `T`. 63 | 64 | This also handles the case where a concrete subclass is given, e.g. 65 | ```py 66 | class MyResponse(Foo[bytes]): 67 | ... 68 | 69 | extract_type_var(MyResponse, bases=(Foo,), index=0) -> bytes 70 | ``` 71 | 72 | And where a generic subclass is given: 73 | ```py 74 | _T = TypeVar('_T') 75 | class MyResponse(Foo[_T]): 76 | ... 
77 | 78 | extract_type_var(MyResponse[bytes], bases=(Foo,), index=0) -> bytes 79 | ``` 80 | """ 81 | cls = cast(object, get_origin(typ) or typ) 82 | if cls in generic_bases: 83 | # we're given the class directly 84 | return extract_type_arg(typ, index) 85 | 86 | # if a subclass is given 87 | # --- 88 | # this is needed as __orig_bases__ is not present in the typeshed stubs 89 | # because it is intended to be for internal use only, however there does 90 | # not seem to be a way to resolve generic TypeVars for inherited subclasses 91 | # without using it. 92 | if isinstance(cls, InheritsGeneric): 93 | target_base_class: Any | None = None 94 | for base in cls.__orig_bases__: 95 | if base.__origin__ in generic_bases: 96 | target_base_class = base 97 | break 98 | 99 | if target_base_class is None: 100 | raise RuntimeError( 101 | "Could not find the generic base class;\n" 102 | "This should never happen;\n" 103 | f"Does {cls} inherit from one of {generic_bases} ?" 104 | ) 105 | 106 | extracted = extract_type_arg(target_base_class, index) 107 | if is_typevar(extracted): 108 | # If the extracted type argument is itself a type variable 109 | # then that means the subclass itself is generic, so we have 110 | # to resolve the type argument from the class itself, not 111 | # the base class. 112 | # 113 | # Note: if there is more than 1 type argument, the subclass could 114 | # change the ordering of the type arguments, this is not currently 115 | # supported. 
116 | return extract_type_arg(typ, index) 117 | 118 | return extracted 119 | 120 | raise RuntimeError(failure_message or f"Could not resolve inner type variable at index {index} for {typ}") 121 | -------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/core/_utils/_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | import re 5 | import inspect 6 | import functools 7 | from typing import ( 8 | Any, 9 | Tuple, 10 | Mapping, 11 | TypeVar, 12 | Callable, 13 | Iterable, 14 | Sequence, 15 | cast, 16 | overload, 17 | ) 18 | from pathlib import Path 19 | from typing_extensions import TypeGuard 20 | 21 | import sniffio 22 | 23 | from .._base_type import Headers, NotGiven, FileTypes, NotGivenOr, HeadersLike 24 | from .._base_compat import parse_date as parse_date, parse_datetime as parse_datetime 25 | 26 | 27 | def remove_notgiven_indict(obj): 28 | if obj is None or (not isinstance(obj, Mapping)): 29 | return obj 30 | return {key: value for key, value in obj.items() if not isinstance(value, NotGiven)} 31 | 32 | 33 | _T = TypeVar("_T") 34 | _TupleT = TypeVar("_TupleT", bound=Tuple[object, ...]) 35 | _MappingT = TypeVar("_MappingT", bound=Mapping[str, object]) 36 | _SequenceT = TypeVar("_SequenceT", bound=Sequence[object]) 37 | CallableT = TypeVar("CallableT", bound=Callable[..., Any]) 38 | 39 | 40 | def flatten(t: Iterable[Iterable[_T]]) -> list[_T]: 41 | return [item for sublist in t for item in sublist] 42 | 43 | 44 | def extract_files( 45 | # TODO: this needs to take Dict but variance issues..... 46 | # create protocol type ? 47 | query: Mapping[str, object], 48 | *, 49 | paths: Sequence[Sequence[str]], 50 | ) -> list[tuple[str, FileTypes]]: 51 | """Recursively extract files from the given dictionary based on specified paths. 52 | 53 | A path may look like this ['foo', 'files', '', 'data']. 
54 | 55 | Note: this mutates the given dictionary. 56 | """ 57 | files: list[tuple[str, FileTypes]] = [] 58 | for path in paths: 59 | files.extend(_extract_items(query, path, index=0, flattened_key=None)) 60 | return files 61 | 62 | 63 | def _extract_items( 64 | obj: object, 65 | path: Sequence[str], 66 | *, 67 | index: int, 68 | flattened_key: str | None, 69 | ) -> list[tuple[str, FileTypes]]: 70 | try: 71 | key = path[index] 72 | except IndexError: 73 | if isinstance(obj, NotGiven): 74 | # no value was provided - we can safely ignore 75 | return [] 76 | 77 | # cyclical import 78 | from .._files import assert_is_file_content 79 | 80 | # We have exhausted the path, return the entry we found. 81 | assert_is_file_content(obj, key=flattened_key) 82 | assert flattened_key is not None 83 | return [(flattened_key, cast(FileTypes, obj))] 84 | 85 | index += 1 86 | if is_dict(obj): 87 | try: 88 | # We are at the last entry in the path so we must remove the field 89 | if (len(path)) == index: 90 | item = obj.pop(key) 91 | else: 92 | item = obj[key] 93 | except KeyError: 94 | # Key was not present in the dictionary, this is not indicative of an error 95 | # as the given path may not point to a required field. We also do not want 96 | # to enforce required fields as the API may differ from the spec in some cases. 97 | return [] 98 | if flattened_key is None: 99 | flattened_key = key 100 | else: 101 | flattened_key += f"[{key}]" 102 | return _extract_items( 103 | item, 104 | path, 105 | index=index, 106 | flattened_key=flattened_key, 107 | ) 108 | elif is_list(obj): 109 | if key != "": 110 | return [] 111 | 112 | return flatten( 113 | [ 114 | _extract_items( 115 | item, 116 | path, 117 | index=index, 118 | flattened_key=flattened_key + "[]" if flattened_key is not None else "[]", 119 | ) 120 | for item in obj 121 | ] 122 | ) 123 | 124 | # Something unexpected was passed, just ignore it. 
125 | return [] 126 | 127 | 128 | def is_given(obj: NotGivenOr[_T]) -> TypeGuard[_T]: 129 | return not isinstance(obj, NotGiven) 130 | 131 | 132 | # Type safe methods for narrowing types with TypeVars. 133 | # The default narrowing for isinstance(obj, dict) is dict[unknown, unknown], 134 | # however this cause Pyright to rightfully report errors. As we know we don't 135 | # care about the contained types we can safely use `object` in it's place. 136 | # 137 | # There are two separate functions defined, `is_*` and `is_*_t` for different use cases. 138 | # `is_*` is for when you're dealing with an unknown input 139 | # `is_*_t` is for when you're narrowing a known union type to a specific subset 140 | 141 | 142 | def is_tuple(obj: object) -> TypeGuard[tuple[object, ...]]: 143 | return isinstance(obj, tuple) 144 | 145 | 146 | def is_tuple_t(obj: _TupleT | object) -> TypeGuard[_TupleT]: 147 | return isinstance(obj, tuple) 148 | 149 | 150 | def is_sequence(obj: object) -> TypeGuard[Sequence[object]]: 151 | return isinstance(obj, Sequence) 152 | 153 | 154 | def is_sequence_t(obj: _SequenceT | object) -> TypeGuard[_SequenceT]: 155 | return isinstance(obj, Sequence) 156 | 157 | 158 | def is_mapping(obj: object) -> TypeGuard[Mapping[str, object]]: 159 | return isinstance(obj, Mapping) 160 | 161 | 162 | def is_mapping_t(obj: _MappingT | object) -> TypeGuard[_MappingT]: 163 | return isinstance(obj, Mapping) 164 | 165 | 166 | def is_dict(obj: object) -> TypeGuard[dict[object, object]]: 167 | return isinstance(obj, dict) 168 | 169 | 170 | def is_list(obj: object) -> TypeGuard[list[object]]: 171 | return isinstance(obj, list) 172 | 173 | 174 | def is_iterable(obj: object) -> TypeGuard[Iterable[object]]: 175 | return isinstance(obj, Iterable) 176 | 177 | 178 | def deepcopy_minimal(item: _T) -> _T: 179 | """Minimal reimplementation of copy.deepcopy() that will only copy certain object types: 180 | 181 | - mappings, e.g. 
`dict` 182 | - list 183 | 184 | This is done for performance reasons. 185 | """ 186 | if is_mapping(item): 187 | return cast(_T, {k: deepcopy_minimal(v) for k, v in item.items()}) 188 | if is_list(item): 189 | return cast(_T, [deepcopy_minimal(entry) for entry in item]) 190 | return item 191 | 192 | 193 | # copied from https://github.com/Rapptz/RoboDanny 194 | def human_join(seq: Sequence[str], *, delim: str = ", ", final: str = "or") -> str: 195 | size = len(seq) 196 | if size == 0: 197 | return "" 198 | 199 | if size == 1: 200 | return seq[0] 201 | 202 | if size == 2: 203 | return f"{seq[0]} {final} {seq[1]}" 204 | 205 | return delim.join(seq[:-1]) + f" {final} {seq[-1]}" 206 | 207 | 208 | def quote(string: str) -> str: 209 | """Add single quotation marks around the given string. Does *not* do any escaping.""" 210 | return f"'{string}'" 211 | 212 | 213 | def required_args(*variants: Sequence[str]) -> Callable[[CallableT], CallableT]: 214 | """Decorator to enforce a given set of arguments or variants of arguments are passed to the decorated function. 215 | 216 | Useful for enforcing runtime validation of overloaded functions. 217 | 218 | Example usage: 219 | ```py 220 | @overload 221 | def foo(*, a: str) -> str: 222 | ... 223 | 224 | 225 | @overload 226 | def foo(*, b: bool) -> str: 227 | ... 228 | 229 | 230 | # This enforces the same constraints that a static type checker would 231 | # i.e. that either a or b must be passed to the function 232 | @required_args(["a"], ["b"]) 233 | def foo(*, a: str | None = None, b: bool | None = None) -> str: 234 | ... 
235 | ``` 236 | """ 237 | 238 | def inner(func: CallableT) -> CallableT: 239 | params = inspect.signature(func).parameters 240 | positional = [ 241 | name 242 | for name, param in params.items() 243 | if param.kind 244 | in { 245 | param.POSITIONAL_ONLY, 246 | param.POSITIONAL_OR_KEYWORD, 247 | } 248 | ] 249 | 250 | @functools.wraps(func) 251 | def wrapper(*args: object, **kwargs: object) -> object: 252 | given_params: set[str] = set() 253 | for i, _ in enumerate(args): 254 | try: 255 | given_params.add(positional[i]) 256 | except IndexError: 257 | raise TypeError( 258 | f"{func.__name__}() takes {len(positional)} argument(s) but {len(args)} were given" 259 | ) from None 260 | 261 | for key in kwargs.keys(): 262 | given_params.add(key) 263 | 264 | for variant in variants: 265 | matches = all((param in given_params for param in variant)) 266 | if matches: 267 | break 268 | else: # no break 269 | if len(variants) > 1: 270 | variations = human_join( 271 | ["(" + human_join([quote(arg) for arg in variant], final="and") + ")" for variant in variants] 272 | ) 273 | msg = f"Missing required arguments; Expected either {variations} arguments to be given" 274 | else: 275 | # TODO: this error message is not deterministic 276 | missing = list(set(variants[0]) - given_params) 277 | if len(missing) > 1: 278 | msg = f"Missing required arguments: {human_join([quote(arg) for arg in missing])}" 279 | else: 280 | msg = f"Missing required argument: {quote(missing[0])}" 281 | raise TypeError(msg) 282 | return func(*args, **kwargs) 283 | 284 | return wrapper # type: ignore 285 | 286 | return inner 287 | 288 | 289 | _K = TypeVar("_K") 290 | _V = TypeVar("_V") 291 | 292 | 293 | @overload 294 | def strip_not_given(obj: None) -> None: 295 | ... 296 | 297 | 298 | @overload 299 | def strip_not_given(obj: Mapping[_K, _V | NotGiven]) -> dict[_K, _V]: 300 | ... 301 | 302 | 303 | @overload 304 | def strip_not_given(obj: object) -> object: 305 | ... 
306 | 307 | 308 | def strip_not_given(obj: object | None) -> object: 309 | """Remove all top-level keys where their values are instances of `NotGiven`""" 310 | if obj is None: 311 | return None 312 | 313 | if not is_mapping(obj): 314 | return obj 315 | 316 | return {key: value for key, value in obj.items() if not isinstance(value, NotGiven)} 317 | 318 | 319 | def coerce_integer(val: str) -> int: 320 | return int(val, base=10) 321 | 322 | 323 | def coerce_float(val: str) -> float: 324 | return float(val) 325 | 326 | 327 | def coerce_boolean(val: str) -> bool: 328 | return val == "true" or val == "1" or val == "on" 329 | 330 | 331 | def maybe_coerce_integer(val: str | None) -> int | None: 332 | if val is None: 333 | return None 334 | return coerce_integer(val) 335 | 336 | 337 | def maybe_coerce_float(val: str | None) -> float | None: 338 | if val is None: 339 | return None 340 | return coerce_float(val) 341 | 342 | 343 | def maybe_coerce_boolean(val: str | None) -> bool | None: 344 | if val is None: 345 | return None 346 | return coerce_boolean(val) 347 | 348 | 349 | def removeprefix(string: str, prefix: str) -> str: 350 | """Remove a prefix from a string. 351 | 352 | Backport of `str.removeprefix` for Python < 3.9 353 | """ 354 | if string.startswith(prefix): 355 | return string[len(prefix):] 356 | return string 357 | 358 | 359 | def removesuffix(string: str, suffix: str) -> str: 360 | """Remove a suffix from a string. 
361 | 362 | Backport of `str.removesuffix` for Python < 3.9 363 | """ 364 | if string.endswith(suffix): 365 | return string[: -len(suffix)] 366 | return string 367 | 368 | 369 | def file_from_path(path: str) -> FileTypes: 370 | contents = Path(path).read_bytes() 371 | file_name = os.path.basename(path) 372 | return (file_name, contents) 373 | 374 | 375 | def get_required_header(headers: HeadersLike, header: str) -> str: 376 | lower_header = header.lower() 377 | if isinstance(headers, Mapping): 378 | headers = cast(Headers, headers) 379 | for k, v in headers.items(): 380 | if k.lower() == lower_header and isinstance(v, str): 381 | return v 382 | 383 | """ to deal with the case where the header looks like Stainless-Event-Id """ 384 | intercaps_header = re.sub(r"([^\w])(\w)", lambda pat: pat.group(1) + pat.group(2).upper(), header.capitalize()) 385 | 386 | for normalized_header in [header, lower_header, header.upper(), intercaps_header]: 387 | value = headers.get(normalized_header) 388 | if value: 389 | return value 390 | 391 | raise ValueError(f"Could not find {header} header") 392 | 393 | 394 | def get_async_library() -> str: 395 | try: 396 | return sniffio.current_async_library() 397 | except Exception: 398 | return "false" 399 | -------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/core/logs.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import time 4 | 5 | logger = logging.getLogger(__name__) 6 | 7 | 8 | class LoggerNameFilter(logging.Filter): 9 | def filter(self, record): 10 | # return record.name.startswith("loom_core") or record.name in "ERROR" or ( 11 | # record.name.startswith("uvicorn.error") 12 | # and record.getMessage().startswith("Uvicorn running on") 13 | # ) 14 | return True 15 | 16 | 17 | def get_log_file(log_path: str, sub_dir: str): 18 | """ 19 | sub_dir should contain a timestamp. 
20 | """ 21 | log_dir = os.path.join(log_path, sub_dir) 22 | # Here should be creating a new directory each time, so `exist_ok=False` 23 | os.makedirs(log_dir, exist_ok=False) 24 | return os.path.join(log_dir, "zhipuai.log") 25 | 26 | 27 | def get_config_dict( 28 | log_level: str, log_file_path: str, log_backup_count: int, log_max_bytes: int 29 | ) -> dict: 30 | # for windows, the path should be a raw string. 31 | log_file_path = ( 32 | log_file_path.encode("unicode-escape").decode() 33 | if os.name == "nt" 34 | else log_file_path 35 | ) 36 | log_level = log_level.upper() 37 | config_dict = { 38 | "version": 1, 39 | "disable_existing_loggers": False, 40 | "formatters": { 41 | "formatter": { 42 | "format": ( 43 | "%(asctime)s %(name)-12s %(process)d %(levelname)-8s %(message)s" 44 | ) 45 | }, 46 | }, 47 | "filters": { 48 | "logger_name_filter": { 49 | "()": __name__ + ".LoggerNameFilter", 50 | }, 51 | }, 52 | "handlers": { 53 | "stream_handler": { 54 | "class": "logging.StreamHandler", 55 | "formatter": "formatter", 56 | "level": log_level, 57 | # "stream": "ext://sys.stdout", 58 | # "filters": ["logger_name_filter"], 59 | }, 60 | "file_handler": { 61 | "class": "logging.handlers.RotatingFileHandler", 62 | "formatter": "formatter", 63 | "level": log_level, 64 | "filename": log_file_path, 65 | "mode": "a", 66 | "maxBytes": log_max_bytes, 67 | "backupCount": log_backup_count, 68 | "encoding": "utf8", 69 | }, 70 | }, 71 | "loggers": { 72 | "loom_core": { 73 | "handlers": ["stream_handler", "file_handler"], 74 | "level": log_level, 75 | "propagate": False, 76 | } 77 | }, 78 | "root": { 79 | "level": log_level, 80 | "handlers": ["stream_handler", "file_handler"], 81 | }, 82 | } 83 | return config_dict 84 | 85 | 86 | def get_timestamp_ms(): 87 | t = time.time() 88 | return int(round(t * 1000)) 89 | -------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/types/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxygoodjob/Hit-leetcode/b073d0855c1efacae59710ab9dcf6fb40325283a/zhipuai_sdk/zhipuai/types/__init__.py -------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/types/chat/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxygoodjob/Hit-leetcode/b073d0855c1efacae59710ab9dcf6fb40325283a/zhipuai_sdk/zhipuai/types/chat/__init__.py -------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/types/chat/async_chat_completion.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | from ...core import BaseModel 4 | from .chat_completion import CompletionChoice, CompletionUsage 5 | 6 | __all__ = ["AsyncTaskStatus", "AsyncCompletion"] 7 | 8 | 9 | class AsyncTaskStatus(BaseModel): 10 | id: Optional[str] = None 11 | request_id: Optional[str] = None 12 | model: Optional[str] = None 13 | task_status: Optional[str] = None 14 | 15 | 16 | class AsyncCompletion(BaseModel): 17 | id: Optional[str] = None 18 | request_id: Optional[str] = None 19 | model: Optional[str] = None 20 | task_status: str 21 | choices: List[CompletionChoice] 22 | usage: CompletionUsage 23 | -------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/types/chat/chat_completion.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | from ...core import BaseModel 4 | 5 | __all__ = ["Completion", "CompletionUsage"] 6 | 7 | 8 | class Function(BaseModel): 9 | arguments: str 10 | name: str 11 | 12 | 13 | class CompletionMessageToolCall(BaseModel): 14 | id: str 15 | function: Function 16 | type: str 17 | 18 | 19 | class CompletionMessage(BaseModel): 20 | 
content: Optional[str] = None 21 | role: str 22 | tool_calls: Optional[List[CompletionMessageToolCall]] = None 23 | 24 | 25 | class CompletionUsage(BaseModel): 26 | prompt_tokens: int 27 | completion_tokens: int 28 | total_tokens: int 29 | 30 | 31 | class CompletionChoice(BaseModel): 32 | index: int 33 | finish_reason: str 34 | message: CompletionMessage 35 | 36 | 37 | class Completion(BaseModel): 38 | model: Optional[str] = None 39 | created: Optional[int] = None 40 | choices: List[CompletionChoice] 41 | request_id: Optional[str] = None 42 | id: Optional[str] = None 43 | usage: CompletionUsage 44 | 45 | -------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/types/chat/chat_completion_chunk.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Dict, Any 2 | 3 | from ...core import BaseModel 4 | 5 | __all__ = [ 6 | "ChatCompletionChunk", 7 | "Choice", 8 | "ChoiceDelta", 9 | "ChoiceDeltaFunctionCall", 10 | "ChoiceDeltaToolCall", 11 | "ChoiceDeltaToolCallFunction", 12 | ] 13 | 14 | 15 | class ChoiceDeltaFunctionCall(BaseModel): 16 | arguments: Optional[str] = None 17 | name: Optional[str] = None 18 | 19 | 20 | class ChoiceDeltaToolCallFunction(BaseModel): 21 | arguments: Optional[str] = None 22 | name: Optional[str] = None 23 | 24 | 25 | class ChoiceDeltaToolCall(BaseModel): 26 | index: int 27 | id: Optional[str] = None 28 | function: Optional[ChoiceDeltaToolCallFunction] = None 29 | type: Optional[str] = None 30 | 31 | 32 | class ChoiceDelta(BaseModel): 33 | content: Optional[str] = None 34 | role: Optional[str] = None 35 | tool_calls: Optional[List[ChoiceDeltaToolCall]] = None 36 | 37 | 38 | class Choice(BaseModel): 39 | delta: ChoiceDelta 40 | finish_reason: Optional[str] = None 41 | index: int 42 | 43 | 44 | class CompletionUsage(BaseModel): 45 | prompt_tokens: int 46 | completion_tokens: int 47 | total_tokens: int 48 | 49 | 50 | class 
ChatCompletionChunk(BaseModel): 51 | id: Optional[str] = None 52 | choices: List[Choice] 53 | created: Optional[int] = None 54 | model: Optional[str] = None 55 | usage: Optional[CompletionUsage] = None 56 | extra_json: Dict[str, Any] 57 | -------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/types/chat/chat_completions_create_param.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from typing_extensions import TypedDict 4 | 5 | 6 | class Reference(TypedDict, total=False): 7 | enable: Optional[bool] 8 | search_query: Optional[str] 9 | -------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/types/embeddings.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Optional, List 4 | 5 | from ..core import BaseModel 6 | from .chat.chat_completion import CompletionUsage 7 | __all__ = ["Embedding", "EmbeddingsResponded"] 8 | 9 | 10 | class Embedding(BaseModel): 11 | object: str 12 | index: Optional[int] = None 13 | embedding: List[float] 14 | 15 | 16 | class EmbeddingsResponded(BaseModel): 17 | object: str 18 | data: List[Embedding] 19 | model: str 20 | usage: CompletionUsage 21 | -------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/types/file_object.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List 2 | 3 | from ..core import BaseModel 4 | 5 | __all__ = ["FileObject", "ListOfFileObject"] 6 | 7 | 8 | class FileObject(BaseModel): 9 | 10 | id: Optional[str] = None 11 | bytes: Optional[int] = None 12 | created_at: Optional[int] = None 13 | filename: Optional[str] = None 14 | object: Optional[str] = None 15 | purpose: Optional[str] = None 16 | status: Optional[str] = None 17 | status_details: 
Optional[str] = None 18 | 19 | 20 | class ListOfFileObject(BaseModel): 21 | 22 | object: Optional[str] = None 23 | data: List[FileObject] 24 | has_more: Optional[bool] = None 25 | -------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/types/fine_tuning/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from .fine_tuning_job import FineTuningJob as FineTuningJob 4 | from .fine_tuning_job import ListOfFineTuningJob as ListOfFineTuningJob 5 | from .fine_tuning_job_event import FineTuningJobEvent as FineTuningJobEvent 6 | -------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/types/fine_tuning/fine_tuning_job.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union, Optional, Dict, Any 2 | 3 | from ...core import BaseModel 4 | 5 | __all__ = ["FineTuningJob", "Error", "Hyperparameters", "ListOfFineTuningJob" ] 6 | 7 | 8 | class Error(BaseModel): 9 | code: str 10 | message: str 11 | param: Optional[str] = None 12 | 13 | 14 | class Hyperparameters(BaseModel): 15 | n_epochs: Union[str, int, None] = None 16 | 17 | 18 | class FineTuningJob(BaseModel): 19 | id: Optional[str] = None 20 | 21 | request_id: Optional[str] = None 22 | 23 | created_at: Optional[int] = None 24 | 25 | error: Optional[Error] = None 26 | 27 | fine_tuned_model: Optional[str] = None 28 | 29 | finished_at: Optional[int] = None 30 | 31 | hyperparameters: Optional[Hyperparameters] = None 32 | 33 | model: Optional[str] = None 34 | 35 | object: Optional[str] = None 36 | 37 | result_files: List[str] 38 | 39 | status: str 40 | 41 | trained_tokens: Optional[int] = None 42 | 43 | training_file: str 44 | 45 | validation_file: Optional[str] = None 46 | 47 | 48 | class ListOfFineTuningJob(BaseModel): 49 | object: Optional[str] = None 50 | data: List[FineTuningJob] 51 | 
has_more: Optional[bool] = None 52 | -------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/types/fine_tuning/fine_tuning_job_event.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union, Optional 2 | 3 | from ...core import BaseModel 4 | 5 | __all__ = ["FineTuningJobEvent", "Metric", "JobEvent"] 6 | 7 | 8 | class Metric(BaseModel): 9 | epoch: Optional[Union[str, int, float]] = None 10 | current_steps: Optional[int] = None 11 | total_steps: Optional[int] = None 12 | elapsed_time: Optional[str] = None 13 | remaining_time: Optional[str] = None 14 | trained_tokens: Optional[int] = None 15 | loss: Optional[Union[str, int, float]] = None 16 | eval_loss: Optional[Union[str, int, float]] = None 17 | acc: Optional[Union[str, int, float]] = None 18 | eval_acc: Optional[Union[str, int, float]] = None 19 | learning_rate: Optional[Union[str, int, float]] = None 20 | 21 | 22 | class JobEvent(BaseModel): 23 | object: Optional[str] = None 24 | id: Optional[str] = None 25 | type: Optional[str] = None 26 | created_at: Optional[int] = None 27 | level: Optional[str] = None 28 | message: Optional[str] = None 29 | data: Optional[Metric] = None 30 | 31 | 32 | class FineTuningJobEvent(BaseModel): 33 | object: Optional[str] = None 34 | data: List[JobEvent] 35 | has_more: Optional[bool] = None 36 | -------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/types/fine_tuning/job_create_params.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Union 4 | 5 | from typing_extensions import Literal, TypedDict 6 | 7 | __all__ = ["Hyperparameters"] 8 | 9 | 10 | class Hyperparameters(TypedDict, total=False): 11 | batch_size: Union[Literal["auto"], int] 12 | 13 | learning_rate_multiplier: Union[Literal["auto"], float] 14 | 15 | n_epochs: 
Union[Literal["auto"], int] 16 | -------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/types/fine_tuning/models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .fine_tuned_models import FineTunedModelsStatus -------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/types/fine_tuning/models/fine_tuned_models.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union, Optional, ClassVar 2 | 3 | from ....core import BaseModel, PYDANTIC_V2, ConfigDict 4 | 5 | __all__ = ["FineTunedModelsStatus"] 6 | 7 | 8 | class FineTunedModelsStatus(BaseModel): 9 | if PYDANTIC_V2: 10 | model_config: ClassVar[ConfigDict] = ConfigDict(extra="allow", protected_namespaces=()) 11 | request_id: str #请求id 12 | model_name: str #模型名称 13 | delete_status: str #删除状态 deleting(删除中), deleted (已删除) 14 | -------------------------------------------------------------------------------- /zhipuai_sdk/zhipuai/types/image.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Optional, List 4 | 5 | from ..core import BaseModel 6 | 7 | __all__ = ["GeneratedImage", "ImagesResponded"] 8 | 9 | 10 | class GeneratedImage(BaseModel): 11 | b64_json: Optional[str] = None 12 | url: Optional[str] = None 13 | revised_prompt: Optional[str] = None 14 | 15 | 16 | class ImagesResponded(BaseModel): 17 | created: int 18 | data: List[GeneratedImage] 19 | --------------------------------------------------------------------------------