├── .gitignore ├── LICENSE ├── README.md ├── examples ├── ceval │ ├── README.md │ ├── ceval-with-answer.jsonl │ ├── chat-eval.yaml │ └── prompt-eval.yaml ├── llm_judge │ ├── README.md │ ├── ceval-llm-judge.jsonl │ └── ceval-llm-judge.yaml ├── rag-eval │ ├── .gitignore │ ├── rag-eval-zh.ipynb │ ├── rag-eval.yaml │ ├── rag.py │ ├── retrieval-eval.yaml │ └── retrieval.py └── sql-eval │ ├── dusql_sample.jsonl │ ├── sql-eval-qianfan-limiter.yaml │ ├── sql-eval-qianfan.yaml │ ├── sql-eval.yaml │ └── 洗衣机.sqlite ├── pyproject.toml ├── src └── langeval │ ├── __about__.py │ ├── __init__.py │ ├── __main__.py │ ├── cli │ ├── __init__.py │ ├── application.py │ ├── constant.py │ ├── rerun │ │ └── __init__.py │ ├── run │ │ ├── __init__.py │ │ ├── display.py │ │ └── run.py │ ├── show │ │ └── __init__.py │ └── terminal.py │ ├── config │ └── __init__.py │ ├── evaluators │ ├── __init__.py │ ├── evaluator.py │ ├── exception.py │ ├── nlp │ │ └── __init__.py │ ├── rag │ │ ├── __init__.py │ │ └── utils.py │ ├── run.py │ └── sql │ │ ├── __init__.py │ │ └── sqleval.py │ ├── models │ ├── __init__.py │ ├── embeddings.py │ ├── exception.py │ ├── llms.py │ ├── openai.py │ ├── qianfan.py │ └── types.py │ ├── providers │ ├── __init__.py │ ├── exception.py │ ├── output_parser.py │ ├── provider.py │ └── run.py │ └── tasks │ ├── __init__.py │ ├── ratelimiter.py │ ├── runner.py │ └── task.py └── tests └── __init__.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | output/ 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 
97 | #Pipfile.lock 98 | 99 | # poetry 100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 101 | # This is especially recommended for binary packages to ensure reproducibility, and is more 102 | # commonly ignored for libraries. 103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 104 | #poetry.lock 105 | 106 | # pdm 107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 108 | #pdm.lock 109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 110 | # in version control. 111 | # https://pdm.fming.dev/#use-with-ide 112 | .pdm.toml 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | 164 | .ruff_cache/ 165 | .DS_Store 166 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Tao Yang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LangEval 2 | 3 | Evaluation for AI apps and agents 4 | 5 | [![PyPI - Version](https://img.shields.io/pypi/v/langeval-cli.svg)](https://pypi.org/project/langeval-cli) 6 | [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/langeval-cli.svg)](https://pypi.org/project/langeval-cli) 7 | 8 | ```txt 9 | ▄▄▌ ▄▄▄· ▐ ▄ ▄▄ • ▄▄▄ . ▌ ▐· ▄▄▄· ▄▄▌ 10 | ██• ▐█ ▀█ •█▌▐█▐█ ▀ ▪▀▄.▀·▪█·█▌▐█ ▀█ ██• 11 | ██▪ ▄█▀▀█ ▐█▐▐▌▄█ ▀█▄▐▀▀▪▄▐█▐█•▄█▀▀█ ██▪ 12 | ▐█▌▐▌▐█ ▪▐▌██▐█▌▐█▄▪▐█▐█▄▄▌ ███ ▐█ ▪▐▌▐█▌▐▌ 13 | .▀▀▀ ▀ ▀ ▀▀ █▪·▀▀▀▀ ▀▀▀ . ▀ ▀ ▀ .▀▀▀ 14 | ``` 15 | 16 | ----- 17 | 18 | ## Table of Contents 19 | 20 | - [Installation](#installation) 21 | - [Documents](#documents) 22 | - [How to use](#how-to-use) 23 | - [Development](#development) 24 | - [License](#license) 25 | 26 | ## Installation 27 | 28 | ```console 29 | pip install langeval-cli 30 | ``` 31 | 32 | ## Documents 33 | 34 | TODOs: 35 | 36 | - Refactor RAG Eval 37 | - Support custom output parsers 38 | - Support more providers. 39 | 40 | ## How to use 41 | 42 | See `./examples` for more details. 43 | 44 | ## Development 45 | 46 | ```bash 47 | # Create virtual environment 48 | hatch env create 49 | # Activate virtual environment 50 | hatch shell 51 | # Run test 52 | hatch run test 53 | # Run lint 54 | hatch run lint:style 55 | 56 | # Version bump 57 | hatch version patch/minor/major 58 | # Build 59 | hatch build 60 | # Upload to PyPI 61 | hatch publish 62 | ``` 63 | 64 | ## License 65 | 66 | `LangEval` is distributed under the terms of the [MIT](https://spdx.org/licenses/MIT.html) license. 67 | -------------------------------------------------------------------------------- /examples/ceval/README.md: -------------------------------------------------------------------------------- 1 | # `ceval-with-answer` task, using exact match (configs: `chat-eval.yaml`, `prompt-eval.yaml`) 2 | 3 | ```bash 4 | export OPENAI_API_KEY="sk-xxxxx" 5 | # Single question 6 | langeval run chat-eval.yaml --sample 1 7 | 8 | # Full run 9 | langeval run chat-eval.yaml 10 | ``` 11 | -------------------------------------------------------------------------------- /examples/ceval/ceval-with-answer.jsonl: -------------------------------------------------------------------------------- 1 | {"id": 0, "question": "下列设备属于资源子网的是____。", "A": "计算机软件", "B": "网桥", "C": "交换机", "D": "路由器", "answer": "A", "explanation": "1. 首先,资源子网是指提供共享资源的网络,如打印机、文件服务器等。\r\n2. 其次,我们需要了解选项中设备的功能。网桥、交换机和路由器的主要功能是实现不同网络之间的通信和数据传输,是通信子网设备。而计算机软件可以提供共享资源的功能。"} 2 | {"id": 1, "question": "滑动窗口的作用是____。", "A": "流量控制", "B": "拥塞控制", "C": "路由控制", "D": "差错控制", "answer": "A", "explanation": "1. 滑动窗口是一种流量控制机制,用于控制发送方向接收方发送数据的速率,以避免接收方无法处理过多的数据而导致数据丢失或拥塞。"} 3 | {"id": 2, "question": "在OSI参考模型中,直接为会话层提供服务的是____。", "A": "应用层", "B": "表示层", "C": "传输层", "D": "网络层", "answer": "C", "explanation": "1. 直接为会话层提供服务的是会话层的下一层,即传输层。"} 4 | {"id": 3, "question": "协议是指在____之间进行通信的规则或约定。", "A": "同一结点的上下层", "B": "不同结点", "C": "相邻实体", "D": "不同结点对等实体", "answer": "D", "explanation": "1. 协议是指在不同结点对等实体之间进行通信的规则或约定。"} 5 | {"id": 4, "question": "主机甲与主机乙之间使用后退N帧协议(GBN)传输数据,甲的发送窗口尺寸为1000,数据帧长为1000字节,信道带宽为100Mbps,乙每收到一个数据帧立即利用一个短帧(忽略其传输延迟)进行确认,若甲、乙之间的单向传播延迟是50ms,则甲可以达到的最大平均数据传输速率约为____。", "A": "10Mbps", "B": "20Mbps", "C": "80Mbps", "D": "100Mbps", "answer": "C", "explanation": "1. 主机甲、乙之间采用后退N帧协议,那么因为甲、乙主机之间采用后退N帧协议传输数据,要考虑发送一个数据到接收到它的确认之前,最多能发送多少数据,所以甲的最大传输速率是这个值和信道带宽中小的那一个。\r\n2. 
甲的发送窗口的尺寸为1000,即收到第一个数据的确认之前,最多能发送1000个数据帧,也就是发送1000*1000B=1MB的内容,而从发送第一个帧到接收到它的确认的时间是一个往返时延,也就是50+50=100ms=0.1s,即在100ms中,最多能传输1MB的数据,因此,此时的最大传输速率为lMB/0.1s=10MB/s=80Mbps。\r\n3. 信道带宽为100Mbps,所以答案为min{80Mbps, 100Mbps}=80Mbps,选C。"} 6 | -------------------------------------------------------------------------------- /examples/ceval/chat-eval.yaml: -------------------------------------------------------------------------------- 1 | provider: 2 | type: "chat_completion" 3 | input_variables: ["question", "A", "B", "C", "D"] 4 | output_parser: 5 | name: match 6 | kwargs: 7 | match_key: "choice" 8 | match_re: | 9 | [ABCD](?!.*[ABCD]) 10 | settings: 11 | llm: 12 | provider: "openai" 13 | model: "gpt-3.5-turbo" 14 | kwargs: 15 | temperature: 0.1 16 | messages: 17 | - {"role": "system", "content": "你是专家,请仔细阅读并一步步的思考【问题】,在【候选答案】中选择唯一正确的答案,并按照【输出格式】输出。\n## 【输出格式】\n输出两行:\n\nReasoning: <此处填写思考过程,请注意特殊符号的转义。>\nChoice: <必填字段,此处仅填写最终选项,为 A、B、C、D 中的一个>\n\n"} 18 | - {"role": "user", "content": "【问题】:`公开发行公司债券,证监会同意注册的决定自作出之日起一定期限内有效,发行人应当该期限内发行公司债券。该期限是____。`\n【候选答案】:\nA: `6个月` \nB: `1年`\nC: `2年`\nD: `3年`\n\n"} 19 | - {"role": "assistant", "content": "Reasoning: 中国证监会同意注册的决定自作出之日起2年内有效,发行人应当在注册决定有效期内发行公司债券,并自主选择发行时点。所以正确答案是 C: 2年。\nChoice: C\n\n"} 20 | - {"role": "user", "content": "【问题】:`{{question}}` \n【候选答案】:\nA: `{{A}}` \nB: `{{B}}`\nC: `{{C}}`\nD: `{{D}}`\n\n"} 21 | input_dataset_name: "ceval-with-answer.jsonl" 22 | evaluators: 23 | - name: exact_match 24 | type: NLP 25 | settings: 26 | prediction_key: "choice" 27 | reference_key: "answer" 28 | nlp_metrics: ["exact_match"] 29 | run_config: 30 | parallelism: 6 # 并发度 31 | timeout: 120 # 超时时间 32 | rounds: 1 # 评测轮数 33 | batch_size: 2 # 批量运行/评测 34 | -------------------------------------------------------------------------------- /examples/ceval/prompt-eval.yaml: -------------------------------------------------------------------------------- 1 | provider: 2 | type: "completion" 3 | input_variables: ["question", "A", "B", "C", "D"] 4 | output_parser: 5 | name: match 6 | kwargs: 7 | match_key: "choice" 8 | match_re: | 9 | [ABCD](?!.*[ABCD]) 10 | settings: 11 | llm: 12 | provider: "openai" 13 | model: "gpt-3.5-turbo-instruct" 14 | kwargs: 15 | temperature: 0.1 16 | prompt: | 17 | 你是专家,请仔细阅读并一步步的思考【问题】,在【候选答案】中选择唯一正确的答案,并按照【输出格式】输出。 18 | 19 | ## 【输出格式】 20 | 输出两行: 21 | 22 | - 推理过程:<此处填写思考过程。> 23 | - 最终选项:<必填字段,此处仅填写最终选项,为 A、B、C、D 中的一个。> 24 | 25 | ## 【示例】 26 | 27 | 【问题】:"公开发行公司债券,证监会同意注册的决定自作出之日起一定期限内有效,发行人应当该期限内发行公司债券。该期限是____。" 28 | 【候选答案】: 29 | A: "6个月" 30 | B: "1年" 31 | C: "2年" 32 | D: "3年" 33 | 34 | 【回答】: 35 | - 推理过程:中国证监会同意注册的决定自作出之日起2年内有效,发行人应当在注册决定有效期内发行公司债券,并自主选择发行时点。所以正确答案是 C: 2年。 36 | - 最终选项:C 37 | 38 | ## 【输入输出】 39 | 40 | 【问题】:"{{question}}" 41 | 【候选答案】: 42 | A: "{{A}}" 43 | B: "{{B}}" 44 | C: "{{C}}" 45 | D: "{{D}}" 46 | 47 | 【回答】: 48 | input_dataset_name: "ceval-with-answer.jsonl" 49 | evaluators: 50 | - name: exact_match 51 | type: NLP 52 | settings: 53 | prediction_key: "choice" 54 | reference_key: "answer" 55 | nlp_metrics: ["exact_match"] 56 | run_config: 57 | parallelism: 5 # 并发度 58 | timeout: 120 # 超时时间 59 | rounds: 1 # 评测轮数 60 | batch_size: 2 # 批量运行/评测 61 | -------------------------------------------------------------------------------- /examples/llm_judge/README.md: -------------------------------------------------------------------------------- 1 | # `ceval-llm-judge` 任务,使用 GPT-4 进行判卷。 2 | 3 | ```bash 4 | export OPENAI_API_KEY="sk-xxxxx" 5 | # 单个问题 6 | langeval run ceval-llm-judge.yaml --sample 1 7 | 8 | # 全量测试 9 | langeval run 
ceval-llm-judge.yaml 10 | ``` 11 | -------------------------------------------------------------------------------- /examples/llm_judge/ceval-llm-judge.jsonl: -------------------------------------------------------------------------------- 1 | {"id": 0, "question": "计算机网络的资源主要是指____。", "A": "服务器、路由器、通信线路与用户计算机", "B": "计算机操作系统、数据库与应用软件", "C": "计算机硬件、软件与数据", "D": "Web服务器、数据库服务器与文件服务器", "answer": "", "explanation": ""} 2 | {"id": 1, "question": "____是传输层数据交换的基本单位。", "A": "位", "B": "分组", "C": "帧", "D": "报文段", "answer": "", "explanation": ""} 3 | {"id": 2, "question": "TCP使用____次握手协议建立连接。", "A": "一", "B": "二", "C": "三", "D": "四", "answer": "", "explanation": ""} 4 | {"id": 3, "question": "TCP使用慢启动算法是为了____", "A": "减小拥堵", "B": "高速传输", "C": "快速探测网络承载力", "D": "适应接收窗口的大小", "answer": "", "explanation": ""} 5 | {"id": 4, "question": "差分曼彻斯特码的原理是:每一位中间都有一个跳变,位中间跳变表示____", "A": "时钟", "B": "同步", "C": "数据", "D": "定界", "answer": "", "explanation": ""} 6 | -------------------------------------------------------------------------------- /examples/llm_judge/ceval-llm-judge.yaml: -------------------------------------------------------------------------------- 1 | provider: 2 | type: "completion" 3 | input_variables: ["question", "A", "B", "C", "D"] 4 | output_parser: 5 | name: json 6 | settings: 7 | llm: 8 | provider: "openai" 9 | model: "gpt-3.5-turbo" 10 | kwargs: 11 | temperature: 0.1 12 | prompt: | 13 | ## Instruction 14 | 15 | 仔细阅读以下问题,并选择正确答案。 16 | 17 | 问题:"{{question}}" 18 | 19 | 候选答案: 20 | A: "{{A}}" 21 | B: "{{B}}" 22 | C: "{{C}}" 23 | D: "{{D}}" 24 | 25 | ## Output format 26 | 27 | 输出为 Markdown JSON 格式,输出参考: 28 | 29 | ```json 30 | { 31 | "reasoning": "<此处填写思考过程>", 32 | "choice": "<此处仅填写最终选项,为 A、B、C、D 中的一个>" 33 | } 34 | ``` 35 | 36 | ## Output 37 | 输出(Markdown JSON 格式): 38 | input_dataset_name: "ceval-llm-judge.jsonl" 39 | evaluators: 40 | - name: "gpt-4-teacher" 41 | type: LLM_GRADE 42 | settings: 43 | eval_keys: ["reasoning", "right_answer", "score"] 44 | llm: 45 | provider: "openai" 46 | model: "gpt-4" 47 | kwargs: 48 | temperature: 0.5 49 | prompt: | 50 | 你是一名教师,接下来你需要根据【问题】和【选项】,来评价【学生答案】的评分。 51 | 52 | # 评分过程 53 | 54 | 1. 首先你应该先根据【问题】和【选项】得到你自己的答案。 55 | 2. 
给出对【学生答案】的评分。(0代表错误,1代表正确)。 56 | 57 | # 输入 58 | 59 | 问题:"{{ question }}" 60 | 选项: 61 | A: "{{A}}" 62 | B: "{{B}}" 63 | C: "{{C}}" 64 | D: "{{D}}" 65 | 学生答案:"{{ choice }}" 66 | 67 | # 输出 68 | 69 | 输出为 Markdown JSON 格式的字符串,示例: 70 | 71 | ```json 72 | { 73 | "reasoning": "<此处填写你对题目的思考过程>", 74 | "right_answer": "<此处填写你的答案,为 A、B、C、D 中的一个>" 75 | "score": <此处是你对学生答案的打分,0为错误,1为正确> 76 | } 77 | ``` 78 | 79 | 输出(Markdown JSON 格式): 80 | run_config: 81 | parallelism: 5 # 并发度 82 | timeout: 30 # 超时时间 83 | rounds: 1 # 评测轮数 84 | -------------------------------------------------------------------------------- /examples/rag-eval/.gitignore: -------------------------------------------------------------------------------- 1 | cmrc-eval-zh.faiss 2 | cmrc-eval-zh.jsonl -------------------------------------------------------------------------------- /examples/rag-eval/rag-eval.yaml: -------------------------------------------------------------------------------- 1 | provider: 2 | type: "execute" 3 | input_variables: ["question"] 4 | output_parser: 5 | name: json 6 | settings: 7 | command: "python3 rag.py" 8 | kwargs: 9 | timeout: 50 10 | input_dataset_name: "cmrc-eval-zh.jsonl" 11 | evaluators: 12 | - name: rag 13 | type: RAG 14 | settings: 15 | rag_metrics: ["retrieval_recall", "answer_correctness"] 16 | rag_llm: 17 | provider: "openai" 18 | model: "gpt-4" 19 | kwargs: 20 | temperature: 0.2 21 | run_config: 22 | parallelism: 10 # 并发度 23 | timeout: 60 # 超时时间 24 | rounds: 1 # 评测轮数 -------------------------------------------------------------------------------- /examples/rag-eval/rag.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | from operator import itemgetter 4 | from typing import List, Tuple 5 | 6 | import tiktoken 7 | from langchain.chat_models import ChatOpenAI 8 | from langchain.embeddings import OpenAIEmbeddings 9 | from langchain.prompts import ChatPromptTemplate 10 | from langchain.schema import Document 11 | from langchain.schema.runnable import RunnablePassthrough 12 | from langchain.vectorstores.faiss import FAISS 13 | 14 | # embed = HuggingFaceEmbeddings(model_name="infgrad/stella-base-zh") 15 | embed = OpenAIEmbeddings() 16 | vectorstore = FAISS.load_local("cmrc-eval-zh.faiss", embed, 17 | distance_strategy="COSINE") 18 | def _retrieve_with_scores(query: str) -> List[Tuple[Document, float]]: 19 | # 自定义的原因是 VectorStoreRetriever 不返回 score 20 | docs = vectorstore.similarity_search_with_relevance_scores( 21 | query, k=5, score_threshold=0.1) 22 | tokens = 0 23 | token_limit = 4000 24 | enc = tiktoken.encoding_for_model("gpt-3.5-turbo") 25 | for i, doc in enumerate(docs): 26 | tokens += len(enc.encode(doc[0].page_content)) 27 | if tokens > token_limit: 28 | return docs[:i] 29 | return docs 30 | 31 | template = """仅使用如下上下文回答问题: 32 | ``` 33 | {context} 34 | ``` 35 | 36 | 问题:{question} 37 | 回答: 38 | """ 39 | prompt = ChatPromptTemplate.from_template(template) 40 | 41 | model = ChatOpenAI() 42 | 43 | def _combine_documents(docs_with_scores: List[Tuple[Document, float]]): 44 | return "\n\n".join([i[0].page_content for i in docs_with_scores]) 45 | 46 | _inputs = RunnablePassthrough() 47 | 48 | retrieved_documents = { 49 | "docs_with_score": lambda x: _retrieve_with_scores(x["question"]), 50 | "question": itemgetter("question"), 51 | } 52 | # Now we construct the inputs for the final prompt 53 | final_inputs = { 54 | "context": lambda x: _combine_documents(x["docs_with_score"]), 55 | "question": itemgetter("question"), 56 | } 57 | # And 
finally, we do the part that returns the answers 58 | answer = { 59 | "answer": final_inputs | prompt | model, 60 | "docs_with_score": itemgetter("docs_with_score"), 61 | } 62 | 63 | final_chain = _inputs | retrieved_documents | answer 64 | 65 | final_results = [] 66 | for i in json.loads(sys.stdin.read()): 67 | inputs = { 68 | "question": i["question"] 69 | } 70 | result = final_chain.invoke(inputs) 71 | final_results.append(json.dumps({ 72 | "answer": result["answer"].content, 73 | "contexts": [i[0].page_content for i in result["docs_with_score"]], 74 | "contexts_scores": [i[1] for i in result["docs_with_score"]], 75 | }, ensure_ascii=False)) 76 | 77 | print(json.dumps(final_results, ensure_ascii=False, indent=2)) 78 | -------------------------------------------------------------------------------- /examples/rag-eval/retrieval-eval.yaml: -------------------------------------------------------------------------------- 1 | provider: 2 | type: "execute" 3 | input_variables: ["question"] 4 | output_parser: 5 | name: json 6 | settings: 7 | command: "python3 retrieval.py" 8 | kwargs: 9 | timeout: 50 10 | input_dataset_name: "cmrc-eval-zh.jsonl" 11 | evaluators: 12 | - name: rag 13 | type: RAG 14 | settings: 15 | rag_metrics: ["retrieval_recall"] 16 | run_config: 17 | parallelism: 10 # 并发度 18 | timeout: 60 # 超时时间 19 | rounds: 1 # 评测轮数 -------------------------------------------------------------------------------- /examples/rag-eval/retrieval.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | from typing import List, Tuple 4 | 5 | import tiktoken 6 | from langchain.embeddings import OpenAIEmbeddings 7 | from langchain.schema import Document 8 | from langchain.vectorstores.faiss import FAISS 9 | 10 | # embed = HuggingFaceEmbeddings(model_name="infgrad/stella-base-zh") 11 | embed = OpenAIEmbeddings() 12 | vectorstore = FAISS.load_local("cmrc-eval-zh.faiss", embed, 13 | distance_strategy="COSINE") 14 | def _retrieve_with_scores(query: str) -> List[Tuple[Document, float]]: 15 | # 自定义的原因是 VectorStoreRetriever 不返回 score 16 | docs = vectorstore.similarity_search_with_relevance_scores( 17 | query, k=5, score_threshold=0.1) 18 | tokens = 0 19 | token_limit = 4000 20 | enc = tiktoken.encoding_for_model("gpt-3.5-turbo") 21 | for i, doc in enumerate(docs): 22 | tokens += len(enc.encode(doc[0].page_content)) 23 | if tokens > token_limit: 24 | return docs[:i] 25 | return docs 26 | 27 | 28 | inputs = json.loads(sys.stdin.read()) 29 | results = [] 30 | for row in inputs: 31 | docs_with_score = _retrieve_with_scores(row["question"]) 32 | result_formatted = { 33 | "contexts": [i[0].page_content for i in docs_with_score], 34 | "contexts_scores": [i[1] for i in docs_with_score], 35 | } 36 | results.append(json.dumps(result_formatted, ensure_ascii=False)) 37 | 38 | print(json.dumps(results, ensure_ascii=False, indent=2)) 39 | -------------------------------------------------------------------------------- /examples/sql-eval/dusql_sample.jsonl: -------------------------------------------------------------------------------- 1 | {"origin_sql":"select 城市 from 洗衣机品牌门店 group by 城市 order by avg ( 售卖量 * 平均售价 ) desc limit 1","db_id":"洗衣机","query":"哪些城市对所有洗衣机品牌平均售卖总额最高","id":"qid000926","result_len":1,"golden_sql":"SELECT `城市`\nFROM `洗衣机品牌门店`\nGROUP BY `城市`\nORDER BY AVG (`售卖量` * `平均售价`) DESC\nLIMIT 1","create_tables":["CREATE TABLE `洗衣机品牌` (\n`词条id` number,\n`名称` text,\n`所属公司` text,\n`成立时间` time,\n`世界500强排名` number,\n`市场份额` number,\n`2018年营业额` 
number,\n`2018年利润` number,\nPRIMARY KEY (`词条id`));","CREATE TABLE `洗衣机型号` (\n`词条id` number,\n`名称` text,\n`产品类别` text,\n`洗涤容量` number,\n`能效等级` number,\n`自动化程度` text,\n`售价` number,\n`品牌id` number,\nFOREIGN KEY (`品牌id`) REFERENCES `洗衣机品牌` (`词条id`));","CREATE TABLE `洗衣机品牌门店` (\n`品牌id` number,\n`城市` text,\n`门店数量` number,\n`售卖量` number,\n`平均售价` number,\nFOREIGN KEY (`品牌id`) REFERENCES `洗衣机品牌` (`词条id`));","CREATE TABLE `洗衣机品牌平台评分` (\n`品牌id` number,\n`平台` text,\n`总评分` number,\n`性价比得分` number,\n`功能得分` number,\n`做工得分` number,\n`外观得分` number,\nFOREIGN KEY (`品牌id`) REFERENCES `洗衣机品牌` (`词条id`));"],"join_pairs":[["`洗衣机品牌门店`.`品牌id`","`洗衣机品牌`.`词条id`"],["`洗衣机品牌平台评分`.`品牌id`","`洗衣机品牌`.`词条id`"],["`洗衣机型号`.`品牌id`","`洗衣机品牌`.`词条id`"]]} 2 | {"origin_sql":"select T2.名称 , min ( T1.门店数量 ) , T2.所属公司 from 洗衣机品牌门店 as T1 join 洗衣机品牌 as T2 on 洗衣机品牌门店.品牌id == 洗衣机品牌.词条id where T2.市场份额 <= 0.102 group by T1.品牌id","db_id":"洗衣机","query":"给出市场份额不超过10.2%的洗衣机品牌及其所属公司,并找出对应的洗衣机品牌的最少门店数","id":"qid000245","result_len":3,"golden_sql":"SELECT T2.`名称`,\n MIN (T1.`门店数量`) , T2.`所属公司`\nFROM `洗衣机品牌门店` AS T1\nJOIN `洗衣机品牌` AS T2 ON T1.`品牌id` == T2.`词条id`\nWHERE T2.`市场份额` <= 0.102\nGROUP BY T1.`品牌id`","create_tables":["CREATE TABLE `洗衣机品牌` (\n`词条id` number,\n`名称` text,\n`所属公司` text,\n`成立时间` time,\n`世界500强排名` number,\n`市场份额` number,\n`2018年营业额` number,\n`2018年利润` number,\nPRIMARY KEY (`词条id`));","CREATE TABLE `洗衣机型号` (\n`词条id` number,\n`名称` text,\n`产品类别` text,\n`洗涤容量` number,\n`能效等级` number,\n`自动化程度` text,\n`售价` number,\n`品牌id` number,\nFOREIGN KEY (`品牌id`) REFERENCES `洗衣机品牌` (`词条id`));","CREATE TABLE `洗衣机品牌门店` (\n`品牌id` number,\n`城市` text,\n`门店数量` number,\n`售卖量` number,\n`平均售价` number,\nFOREIGN KEY (`品牌id`) REFERENCES `洗衣机品牌` (`词条id`));","CREATE TABLE `洗衣机品牌平台评分` (\n`品牌id` number,\n`平台` text,\n`总评分` number,\n`性价比得分` number,\n`功能得分` number,\n`做工得分` number,\n`外观得分` number,\nFOREIGN KEY (`品牌id`) REFERENCES `洗衣机品牌` (`词条id`));"],"join_pairs":[["`洗衣机品牌门店`.`品牌id`","`洗衣机品牌`.`词条id`"],["`洗衣机品牌平台评分`.`品牌id`","`洗衣机品牌`.`词条id`"],["`洗衣机型号`.`品牌id`","`洗衣机品牌`.`词条id`"]]} 3 | {"origin_sql":"select T2.名称 , sum ( T1.总评分 ) , T2.所属公司 from 洗衣机品牌平台评分 as T1 join 洗衣机品牌 as T2 on 洗衣机品牌平台评分.品牌id == 洗衣机品牌.词条id where T2.市场份额 <= 0.102 group by T1.品牌id","db_id":"洗衣机","query":"给出市场份额不超过10.2%的洗衣机品牌及所属公司,并给出对应的洗衣机品牌的总评分","id":"qid000259","result_len":3,"golden_sql":"SELECT T2.`名称`,\n SUM (T1.`总评分`) , T2.`所属公司`\nFROM `洗衣机品牌平台评分` AS T1\nJOIN `洗衣机品牌` AS T2 ON T1.`品牌id` == T2.`词条id`\nWHERE T2.`市场份额` <= 0.102\nGROUP BY T1.`品牌id`","create_tables":["CREATE TABLE `洗衣机品牌` (\n`词条id` number,\n`名称` text,\n`所属公司` text,\n`成立时间` time,\n`世界500强排名` number,\n`市场份额` number,\n`2018年营业额` number,\n`2018年利润` number,\nPRIMARY KEY (`词条id`));","CREATE TABLE `洗衣机型号` (\n`词条id` number,\n`名称` text,\n`产品类别` text,\n`洗涤容量` number,\n`能效等级` number,\n`自动化程度` text,\n`售价` number,\n`品牌id` number,\nFOREIGN KEY (`品牌id`) REFERENCES `洗衣机品牌` (`词条id`));","CREATE TABLE `洗衣机品牌门店` (\n`品牌id` number,\n`城市` text,\n`门店数量` number,\n`售卖量` number,\n`平均售价` number,\nFOREIGN KEY (`品牌id`) REFERENCES `洗衣机品牌` (`词条id`));","CREATE TABLE `洗衣机品牌平台评分` (\n`品牌id` number,\n`平台` text,\n`总评分` number,\n`性价比得分` number,\n`功能得分` number,\n`做工得分` number,\n`外观得分` number,\nFOREIGN KEY (`品牌id`) REFERENCES `洗衣机品牌` (`词条id`));"],"join_pairs":[["`洗衣机品牌门店`.`品牌id`","`洗衣机品牌`.`词条id`"],["`洗衣机品牌平台评分`.`品牌id`","`洗衣机品牌`.`词条id`"],["`洗衣机型号`.`品牌id`","`洗衣机品牌`.`词条id`"]]} 4 | {"origin_sql":"select T2.名称 , T2.所属公司 , T1.产品类别 from 洗衣机型号 as T1 join 洗衣机品牌 as T2 on 洗衣机型号.品牌id == 洗衣机品牌.词条id where T1.售价 >= 3000 and T2.市场份额 >= 
0.102","db_id":"洗衣机","query":"给出售价不低于三千块,且市场份额不低于10.2%的洗衣机品牌,以及属于哪个所属公司,产品类别是什么","id":"qid000118","result_len":1,"golden_sql":"SELECT T2.`名称`,\n T2.`所属公司`,\n T1.`产品类别`\nFROM `洗衣机型号` AS T1\nJOIN `洗衣机品牌` AS T2 ON T1.`品牌id` == T2.`词条id`\nWHERE T1.`售价` >= 3000\n AND T2.`市场份额` >= 0.102","create_tables":["CREATE TABLE `洗衣机品牌` (\n`词条id` number,\n`名称` text,\n`所属公司` text,\n`成立时间` time,\n`世界500强排名` number,\n`市场份额` number,\n`2018年营业额` number,\n`2018年利润` number,\nPRIMARY KEY (`词条id`));","CREATE TABLE `洗衣机型号` (\n`词条id` number,\n`名称` text,\n`产品类别` text,\n`洗涤容量` number,\n`能效等级` number,\n`自动化程度` text,\n`售价` number,\n`品牌id` number,\nFOREIGN KEY (`品牌id`) REFERENCES `洗衣机品牌` (`词条id`));","CREATE TABLE `洗衣机品牌门店` (\n`品牌id` number,\n`城市` text,\n`门店数量` number,\n`售卖量` number,\n`平均售价` number,\nFOREIGN KEY (`品牌id`) REFERENCES `洗衣机品牌` (`词条id`));","CREATE TABLE `洗衣机品牌平台评分` (\n`品牌id` number,\n`平台` text,\n`总评分` number,\n`性价比得分` number,\n`功能得分` number,\n`做工得分` number,\n`外观得分` number,\nFOREIGN KEY (`品牌id`) REFERENCES `洗衣机品牌` (`词条id`));"],"join_pairs":[["`洗衣机品牌门店`.`品牌id`","`洗衣机品牌`.`词条id`"],["`洗衣机品牌平台评分`.`品牌id`","`洗衣机品牌`.`词条id`"],["`洗衣机型号`.`品牌id`","`洗衣机品牌`.`词条id`"]]} 5 | {"origin_sql":"select 名称 from 洗衣机品牌 where TIME_NOW - 成立时间 != 12 and 市场份额 > 0.02","db_id":"洗衣机","query":"成立年数不等于12年市场份额大于2%的洗衣机品牌有哪些?","id":"qid000878","result_len":3,"golden_sql":"SELECT `名称`\nFROM `洗衣机品牌`\nWHERE datetime('now') - `成立时间` != 12\n AND `市场份额` > 0.02","create_tables":["CREATE TABLE `洗衣机品牌` (\n`词条id` number,\n`名称` text,\n`所属公司` text,\n`成立时间` time,\n`世界500强排名` number,\n`市场份额` number,\n`2018年营业额` number,\n`2018年利润` number,\nPRIMARY KEY (`词条id`));","CREATE TABLE `洗衣机型号` (\n`词条id` number,\n`名称` text,\n`产品类别` text,\n`洗涤容量` number,\n`能效等级` number,\n`自动化程度` text,\n`售价` number,\n`品牌id` number,\nFOREIGN KEY (`品牌id`) REFERENCES `洗衣机品牌` (`词条id`));","CREATE TABLE `洗衣机品牌门店` (\n`品牌id` number,\n`城市` text,\n`门店数量` number,\n`售卖量` number,\n`平均售价` number,\nFOREIGN KEY (`品牌id`) REFERENCES `洗衣机品牌` (`词条id`));","CREATE TABLE `洗衣机品牌平台评分` (\n`品牌id` number,\n`平台` text,\n`总评分` number,\n`性价比得分` number,\n`功能得分` number,\n`做工得分` number,\n`外观得分` number,\nFOREIGN KEY (`品牌id`) REFERENCES `洗衣机品牌` (`词条id`));"],"join_pairs":[["`洗衣机品牌门店`.`品牌id`","`洗衣机品牌`.`词条id`"],["`洗衣机品牌平台评分`.`品牌id`","`洗衣机品牌`.`词条id`"],["`洗衣机型号`.`品牌id`","`洗衣机品牌`.`词条id`"]]} 6 | -------------------------------------------------------------------------------- /examples/sql-eval/sql-eval-qianfan-limiter.yaml: -------------------------------------------------------------------------------- 1 | provider: 2 | type: "completion" 3 | input_variables: ["query", "create_tables", "join_pairs", "db_id"] 4 | output_parser: 5 | name: sql 6 | settings: 7 | llm: 8 | provider: "qianfan" 9 | model: "SQLCoder-7B" 10 | chat: false 11 | kwargs: 12 | temperature: 0.7 13 | prompt: | 14 | ## 任务 15 | 你是一名数据库专家,请生成 SQL 语句回答如下提问: 16 | `{{ query }}` 17 | 18 | ### 表结构 19 | 提问依赖的表结构如下: 20 | {% for create_table in create_tables %} 21 | {{ create_table }} 22 | {% endfor %} 23 | 24 | {% for pair0, pair1 in join_pairs %}-- {{ pair0 }} 可以和 {{ pair1 }} JOIN。 25 | {% endfor %} 26 | 27 | ### 规则 28 | 29 | - 百分比均转换为浮点数。比如 10.5% 转换为 0.105。 30 | - SQL 应该满足 SQLite 格式要求,并满足 SQL 标准。 31 | 32 | ### 输出 33 | 34 | 通过给定的表结构,输出【问题】对应的【SQL 查询语句】: 35 | 36 | 【问题】:`{{ query }}` 37 | 【SQL 查询语句】: 38 | ```sql 39 | input_dataset_name: "dusql_sample.jsonl" 40 | evaluators: 41 | - name: sqleval 42 | type: SQL 43 | settings: 44 | question_key: "query" 45 | sql_key: "sql" 46 | golden_sql_key: "golden_sql" 47 | db_url: "sqlite:///{db_name}.sqlite" 48 | db_name_key: "db_id" 49 | 
run_config: 50 | parallelism: 5 # 并发度 51 | timeout: 120 # 超时时间 52 | rounds: 1 # 评测轮数 53 | batch_size: 1 # 批量运行/评测 54 | query_per_second: 0.5 55 | -------------------------------------------------------------------------------- /examples/sql-eval/sql-eval-qianfan.yaml: -------------------------------------------------------------------------------- 1 | provider: 2 | type: "completion" 3 | input_variables: ["query", "create_tables", "join_pairs", "db_id"] 4 | output_parser: 5 | name: sql 6 | settings: 7 | llm: 8 | provider: "qianfan" 9 | model: "SQLCoder-7B" 10 | chat: false 11 | kwargs: 12 | temperature: 0.7 13 | prompt: | 14 | ## 任务 15 | 你是一名数据库专家,请生成 SQL 语句回答如下提问: 16 | `{{ query }}` 17 | 18 | ### 表结构 19 | 提问依赖的表结构如下: 20 | {% for create_table in create_tables %} 21 | {{ create_table }} 22 | {% endfor %} 23 | 24 | {% for pair0, pair1 in join_pairs %}-- {{ pair0 }} 可以和 {{ pair1 }} JOIN。 25 | {% endfor %} 26 | 27 | ### 规则 28 | 29 | - 百分比均转换为浮点数。比如 10.5% 转换为 0.105。 30 | - SQL 应该满足 SQLite 格式要求,并满足 SQL 标准。 31 | 32 | ### 输出 33 | 34 | 通过给定的表结构,输出【问题】对应的【SQL 查询语句】: 35 | 36 | 【问题】:`{{ query }}` 37 | 【SQL 查询语句】: 38 | ```sql 39 | input_dataset_name: "dusql_sample.jsonl" 40 | evaluators: 41 | - name: sqleval 42 | type: SQL 43 | settings: 44 | question_key: "query" 45 | sql_key: "sql" 46 | golden_sql_key: "golden_sql" 47 | db_url: "sqlite:///{db_name}.sqlite" 48 | db_name_key: "db_id" 49 | run_config: 50 | parallelism: 1 # 并发度 51 | timeout: 120 # 超时时间 52 | rounds: 1 # 评测轮数 53 | batch_size: 1 # 批量运行/评测 54 | -------------------------------------------------------------------------------- /examples/sql-eval/sql-eval.yaml: -------------------------------------------------------------------------------- 1 | provider: 2 | type: "completion" 3 | input_variables: ["query", "create_tables", "join_pairs", "db_id"] 4 | output_parser: 5 | name: sql 6 | settings: 7 | llm: 8 | provider: "openai" 9 | model: "gpt-3.5-turbo" 10 | kwargs: 11 | temperature: 0.1 12 | prompt: | 13 | ## 任务 14 | 你是一名数据库专家,请生成 SQL 语句回答如下提问: 15 | `{{ query }}` 16 | 17 | ### 表结构 18 | 提问依赖的表结构如下: 19 | {% for create_table in create_tables %} 20 | {{ create_table }} 21 | {% endfor %} 22 | 23 | {% for pair0, pair1 in join_pairs %}-- {{ pair0 }} 可以和 {{ pair1 }} JOIN。 24 | {% endfor %} 25 | 26 | ### 规则 27 | 28 | - 百分比均转换为浮点数。比如 10.5% 转换为 0.105。 29 | - SQL 应该满足 SQLite 格式要求,并满足 SQL 标准。 30 | 31 | ### 输出 32 | 33 | 通过给定的表结构,输出【问题】对应的【SQL 查询语句】: 34 | 35 | 【问题】:`{{ query }}` 36 | 【SQL 查询语句】: 37 | ```sql 38 | input_dataset_name: "dusql_sample.jsonl" 39 | evaluators: 40 | - name: sqleval 41 | type: SQL 42 | settings: 43 | question_key: "query" 44 | sql_key: "sql" 45 | golden_sql_key: "golden_sql" 46 | db_url: "sqlite:///{db_name}.sqlite" 47 | db_name_key: "db_id" 48 | run_config: 49 | parallelism: 5 # 并发度 50 | timeout: 120 # 超时时间 51 | rounds: 1 # 评测轮数 52 | batch_size: 1 # 批量运行/评测 53 | -------------------------------------------------------------------------------- /examples/sql-eval/洗衣机.sqlite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ninehills/langeval/b19123c2eeea375b06ef2b2b47f862cefba6aa66/examples/sql-eval/洗衣机.sqlite -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "langeval-cli" 7 | dynamic = ["version"] 8 | description = 'Evaluation for AI 
apps and agent.' 9 | readme = "README.md" 10 | requires-python = ">=3.7" 11 | license = "MIT" 12 | keywords = [] 13 | authors = [ 14 | { name = "Tao Yang", email = "swulling@gmail.com" }, 15 | ] 16 | classifiers = [ 17 | "Development Status :: 4 - Beta", 18 | "Programming Language :: Python", 19 | "Programming Language :: Python :: 3.10", 20 | "Programming Language :: Python :: 3.11", 21 | "Programming Language :: Python :: Implementation :: CPython", 22 | "Programming Language :: Python :: Implementation :: PyPy", 23 | ] 24 | dependencies = [ 25 | "click ~= 8.0", 26 | "rich ~= 13.0", 27 | "numpy ~= 1.0", 28 | "pydantic", 29 | "openai > 1.0", 30 | "Jinja2", 31 | "pandas", 32 | "PyYAML", 33 | "scikit-learn", 34 | "rouge_chinese", 35 | "jieba", 36 | "nltk", 37 | "func_timeout", 38 | "python-dotenv", 39 | "sqlalchemy" 40 | ] 41 | 42 | [project.urls] 43 | Documentation = "https://github.com/ninehills/langeval#readme" 44 | Issues = "https://github.com/ninehills/langeval/issues" 45 | Source = "https://github.com/ninehills/langeval" 46 | 47 | [project.scripts] 48 | langeval = "langeval.cli:main" 49 | 50 | [tool.hatch.build.targets.wheel] 51 | packages = ["src/langeval"] 52 | 53 | [tool.hatch.version] 54 | path = "src/langeval/__about__.py" 55 | 56 | [tool.hatch.envs.default] 57 | dependencies = [ 58 | "coverage[toml]>=6.5", 59 | "pytest", 60 | ] 61 | [tool.hatch.envs.default.scripts] 62 | test = "pytest {args:tests}" 63 | test-cov = "coverage run -m pytest {args:tests}" 64 | cov-report = [ 65 | "- coverage combine", 66 | "coverage report", 67 | ] 68 | cov = [ 69 | "test-cov", 70 | "cov-report", 71 | ] 72 | 73 | [[tool.hatch.envs.all.matrix]] 74 | python = ["3.10", "3.11"] 75 | 76 | [tool.hatch.envs.lint] 77 | detached = true 78 | dependencies = [ 79 | "black>=23.1.0", 80 | "mypy>=1.0.0", 81 | "ruff>=0.0.243", 82 | ] 83 | [tool.hatch.envs.lint.scripts] 84 | typing = "mypy --install-types --non-interactive {args:src/langeval tests}" 85 | style = [ 86 | "ruff {args:.}", 87 | "black --check --diff {args:.}", 88 | ] 89 | fmt = [ 90 | "black {args:.}", 91 | "ruff --fix {args:.}", 92 | "style", 93 | ] 94 | all = [ 95 | "style", 96 | "typing", 97 | ] 98 | 99 | [tool.black] 100 | target-version = ["py37"] 101 | line-length = 120 102 | skip-string-normalization = true 103 | 104 | [tool.ruff] 105 | target-version = "py37" 106 | line-length = 120 107 | select = [ 108 | "A", 109 | "ARG", 110 | "B", 111 | "C", 112 | "DTZ", 113 | "E", 114 | "EM", 115 | "F", 116 | "FBT", 117 | "I", 118 | "ICN", 119 | "ISC", 120 | "N", 121 | "PLC", 122 | "PLE", 123 | "PLR", 124 | "PLW", 125 | "Q", 126 | "RUF", 127 | "S", 128 | "T", 129 | "TID", 130 | "UP", 131 | "W", 132 | "YTT", 133 | ] 134 | ignore = [ 135 | # Allow non-abstract empty methods in abstract base classes 136 | "B027", 137 | # Allow boolean positional values in function calls, like `dict.get(... True)` 138 | "FBT003", 139 | # Ignore checks for possible passwords 140 | "S105", "S106", "S107", 141 | # Ignore complexity 142 | "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915", 143 | # Ignore: Exception must not use an f-string literal, assign to variable first 144 | "EM102", "EM101", 145 | # Ignore: Use `X | Y` for type annotations 146 | "UP007", 147 | # Ignore: The use of `datetime.datetime.now()` without `tz` argument is not allowed 148 | "DTZ005", "DTZ003", 149 | # Ignore: Boolean-typed positional argument in function definition 150 | "FBT001", 151 | # Ignore: String contains ambiguous `,` (FULLWIDTH COMMA). 
Did you mean `,` (COMMA) 152 | "RUF001", "RUF003", 153 | # https://docs.astral.sh/ruff/rules/boolean-default-value-positional-argument/ 154 | "FBT002", 155 | ] 156 | unfixable = [ 157 | # Don't touch unused imports 158 | "F401", 159 | ] 160 | 161 | [tool.ruff.isort] 162 | known-first-party = ["langeval"] 163 | 164 | [tool.ruff.flake8-tidy-imports] 165 | ban-relative-imports = "all" 166 | 167 | [tool.ruff.per-file-ignores] 168 | # Tests can use magic values, assertions, and relative imports 169 | "tests/**/*" = ["PLR2004", "S101", "TID252"] 170 | 171 | [tool.coverage.run] 172 | source_pkgs = ["langeval", "tests"] 173 | branch = true 174 | parallel = true 175 | omit = [ 176 | "src/langeval/__about__.py", 177 | ] 178 | 179 | [tool.coverage.paths] 180 | langeval = ["src/langeval", "*/langeval/src/langeval"] 181 | tests = ["tests", "*/langeval/tests"] 182 | 183 | [tool.coverage.report] 184 | exclude_lines = [ 185 | "no cov", 186 | "if __name__ == .__main__.:", 187 | "if TYPE_CHECKING:", 188 | ] 189 | -------------------------------------------------------------------------------- /src/langeval/__about__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2023-present Tao Yang 2 | # 3 | # SPDX-License-Identifier: MIT 4 | __version__ = "0.4.0" 5 | -------------------------------------------------------------------------------- /src/langeval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ninehills/langeval/b19123c2eeea375b06ef2b2b47f862cefba6aa66/src/langeval/__init__.py -------------------------------------------------------------------------------- /src/langeval/__main__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | if __name__ == "__main__": 4 | from langeval.cli import main 5 | 6 | sys.exit(main()) 7 | -------------------------------------------------------------------------------- /src/langeval/cli/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from typing import Optional 4 | 5 | import click 6 | from dotenv import load_dotenv 7 | 8 | from langeval.__about__ import __version__ 9 | from langeval.cli.application import Application 10 | from langeval.cli.rerun import rerun 11 | from langeval.cli.run import run 12 | from langeval.cli.show import show 13 | from langeval.config import AppEnvVars, ConfigEnvVars 14 | 15 | 16 | @click.group( 17 | context_settings={"help_option_names": ["-h", "--help"], "max_content_width": 120}, invoke_without_command=True 18 | ) 19 | @click.option( 20 | "--verbose", 21 | "-v", 22 | envvar=AppEnvVars.VERBOSE, 23 | count=True, 24 | help=f"Increase verbosity (can be used additively) [env var: `{AppEnvVars.VERBOSE}`]", 25 | ) 26 | @click.option( 27 | "--color/--no-color", 28 | default=None, 29 | help="Whether or not to display colored output (default is auto-detection) " 30 | f"[env vars: `{AppEnvVars.FORCE_COLOR}`/`{AppEnvVars.NO_COLOR}`]", 31 | ) 32 | @click.option( 33 | "--env", 34 | "env_file", 35 | envvar=ConfigEnvVars.ENV_FILE, 36 | default=".env", 37 | help=f"The path to a custom envfile to use [env var: `{ConfigEnvVars.ENV_FILE}`], default is `.env`", 38 | ) 39 | @click.version_option(version=__version__, prog_name="langeval") 40 | @click.pass_context 41 | def langeval(ctx: click.Context, verbose: int, color: Optional[bool], env_file: str): 42 | """ 43 | \b 44 | ▄▄▌ ▄▄▄· ▐ ▄ ▄▄ 
• ▄▄▄ . ▌ ▐· ▄▄▄· ▄▄▌ 45 | ██• ▐█ ▀█ •█▌▐█▐█ ▀ ▪▀▄.▀·▪█·█▌▐█ ▀█ ██• 46 | ██▪ ▄█▀▀█ ▐█▐▐▌▄█ ▀█▄▐▀▀▪▄▐█▐█•▄█▀▀█ ██▪ 47 | ▐█▌▐▌▐█ ▪▐▌██▐█▌▐█▄▪▐█▐█▄▄▌ ███ ▐█ ▪▐▌▐█▌▐▌ 48 | .▀▀▀ ▀ ▀ ▀▀ █▪·▀▀▀▀ ▀▀▀ . ▀ ▀ ▀ .▀▀▀ 49 | """ 50 | if color is None: 51 | if os.environ.get(AppEnvVars.NO_COLOR) == "1": 52 | color = False 53 | elif os.environ.get(AppEnvVars.FORCE_COLOR) == "1": 54 | color = True 55 | 56 | if verbose > 0: 57 | if verbose == 1: 58 | logging.basicConfig(level=logging.INFO) 59 | else: 60 | logging.basicConfig(level=logging.DEBUG) 61 | if os.path.exists(env_file): 62 | click.echo("Load env file: %s" % env_file) 63 | load_dotenv(env_file) 64 | app = Application(ctx.exit, verbose, color) 65 | 66 | if not ctx.invoked_subcommand: 67 | app.display_info(ctx.get_help()) 68 | return 69 | 70 | # Persist app data for sub-commands 71 | ctx.obj = app 72 | 73 | 74 | langeval.add_command(run) 75 | langeval.add_command(show) 76 | langeval.add_command(rerun) 77 | 78 | 79 | def main(): # no cov 80 | try: 81 | return langeval(prog_name="langeval", windows_expand_args=False) # type: ignore 82 | except Exception: 83 | from rich.console import Console 84 | 85 | console = Console() 86 | console.print_exception(suppress=[click]) 87 | return 1 88 | -------------------------------------------------------------------------------- /src/langeval/cli/application.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from langeval.cli.terminal import Terminal 4 | 5 | 6 | class Application(Terminal): 7 | def __init__(self, exit_func, verbosity: int, color: Optional[bool]): 8 | super().__init__(verbosity, color, False) 9 | self.__exit_func = exit_func 10 | 11 | self.verbosity = self.verbosity > 0 12 | 13 | def abort(self, text="", code=1, **kwargs): 14 | if text: 15 | self.display_error(text, **kwargs) 16 | self.__exit_func(code) 17 | -------------------------------------------------------------------------------- /src/langeval/cli/constant.py: -------------------------------------------------------------------------------- 1 | class TaskOutputVars: 2 | TaskMeta = "task.yaml" 3 | TaskOutput = "output.jsonl" 4 | TaskResult = "result.jsonl" 5 | TaskMerged = "merged.jsonl" 6 | TaskLog = "output.log" 7 | TaskStatus = "status.json" 8 | TaskStastics = "stastics.json" 9 | -------------------------------------------------------------------------------- /src/langeval/cli/rerun/__init__.py: -------------------------------------------------------------------------------- 1 | """Re-run the evaluation task.""" 2 | import json 3 | import os 4 | import shutil 5 | import uuid 6 | from datetime import datetime 7 | 8 | import click 9 | 10 | from langeval.cli.application import Application 11 | from langeval.cli.constant import TaskOutputVars 12 | from langeval.cli.run.run import run_task 13 | from langeval.tasks import EvalTask, Result 14 | 15 | 16 | @click.command(short_help="Re-run evaluation task.") 17 | @click.argument("task_dir", required=True, type=click.Path(exists=True)) 18 | @click.option( 19 | "--output", 20 | "-o", 21 | "output", 22 | help="Output directory for the evaluation files and results", 23 | type=click.Path(exists=False), 24 | ) 25 | @click.pass_obj 26 | def rerun(app: Application, task_dir, output): 27 | """Re-run evaluation task. 28 | 29 | TASK_DIR: The directory of the evaluation task. 30 | """ 31 | # 1. 
创建 output dir 32 | task_id = f"{datetime.now().strftime('%y%m%d%H%M')}-{uuid.uuid4().hex[:4]}" 33 | if not output: 34 | output = f"output/{task_id}" 35 | app.display_info(f"Output dir: {output}") 36 | if not os.path.exists(output): 37 | os.makedirs(output) 38 | app.display_info(f"Output dir created: {output}") 39 | else: 40 | app.abort(f"Output dir {output} exists, exit.") 41 | 42 | # 2. 复制 task_dir 到新的目录 43 | shutil.copytree(task_dir, output, dirs_exist_ok=True) 44 | 45 | # 3. Load task 46 | task_file = os.path.join(output, TaskOutputVars.TaskMeta) 47 | status_file = os.path.join(output, TaskOutputVars.TaskStatus) 48 | 49 | with open(task_file) as f: 50 | task_file_content = f.read() 51 | task = EvalTask.from_yaml(task_file_content, dataset_dir=output) 52 | with open(status_file) as f: 53 | # {"uuid": "2311021530-5c69", "status": "FINISHED", "progress": "1/0/1", "finished_time": 1698910215.125846} 54 | status = json.loads(f.read()) 55 | with open(os.path.join(output, TaskOutputVars.TaskResult)) as f: 56 | results = [Result.from_json(line) for line in f.readlines()] 57 | 58 | # 4. Run task 59 | run_task(app, output, task_id, task, 60 | sample=status["sample"], sample_seed=status["sample_seed"], 61 | results=results) 62 | -------------------------------------------------------------------------------- /src/langeval/cli/run/__init__.py: -------------------------------------------------------------------------------- 1 | """Run evaluation task.""" 2 | import os 3 | import uuid 4 | from datetime import datetime 5 | 6 | import click 7 | 8 | from langeval.cli.application import Application 9 | from langeval.cli.constant import TaskOutputVars 10 | from langeval.cli.run.run import run_task 11 | from langeval.tasks import EvalTask 12 | 13 | 14 | @click.command(short_help="Run evaluation task") 15 | @click.argument("task_file", required=True, type=click.File("r")) 16 | @click.option( 17 | "--output", 18 | "-o", 19 | "output", 20 | help="Output directory for the evaluation files and results", 21 | type=click.Path(exists=False), 22 | ) 23 | @click.option("--sample", "-s", "sample", type=int, help="Sample size for the evaluation.") 24 | @click.option("--sample_seed", "-ss", "sample_seed", type=int, help="Sample seed for the evaluation. Default: 42", 25 | default=42) 26 | @click.pass_obj 27 | def run(app: Application, task_file, output, sample, sample_seed): 28 | """Run evaluation task. 29 | 30 | TASK_FILE: The evaluation task yaml file. 31 | """ 32 | # 1. Load task 33 | task_file_content = task_file.read() 34 | task = EvalTask.from_yaml(task_file_content) 35 | task_id = f"{datetime.now().strftime('%y%m%d%H%M')}-{uuid.uuid4().hex[:4]}" 36 | app.display_info(f">>> Loaded task from {task_file.name} successfully, task_id: {task_id}") 37 | 38 | # 2. Create output dir 39 | if not output: 40 | output = f"output/{task_id}" 41 | app.display_info(f"Output dir: {output}") 42 | if not os.path.exists(output): 43 | os.makedirs(output) 44 | app.display_info(f"Output dir created: {output}") 45 | else: 46 | app.abort(f"Output dir {output} exists, exit.") 47 | 48 | # 3. Copy task file & input dataset to output dir 49 | with open(os.path.join(output, TaskOutputVars.TaskMeta), "w") as f: 50 | f.write(task_file_content) 51 | if task.input_dataset_name and task.input_dataset_binary: 52 | input_dataset_filename = os.path.basename(task.input_dataset_name) 53 | with open(os.path.join(output, input_dataset_filename), "wb") as f: 54 | f.write(task.input_dataset_binary) 55 | # 4. 
Run task 56 | run_task(app, output, task_id, task, sample, sample_seed) 57 | -------------------------------------------------------------------------------- /src/langeval/cli/run/display.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import pandas as pd 5 | 6 | from langeval.cli.application import Application 7 | from langeval.cli.constant import TaskOutputVars 8 | from langeval.tasks import Result, TaskRunner 9 | 10 | 11 | def save_task_merged_result(file: str, results: list[Result]): 12 | with open(file, "w") as f: 13 | for result in results: 14 | data = result.inputs 15 | data.update(result.run.outputs) 16 | f.write(json.dumps(data, ensure_ascii=False) + "\n") 17 | 18 | def save_task_result(file: str, running_stats, eval_stats): 19 | result = { 20 | # process NaN value. 21 | "running_stats": json.loads(running_stats.to_json(force_ascii=False)), 22 | "eval_stats": json.loads(eval_stats.to_json(force_ascii=False)), 23 | } 24 | with open(file, "w") as f: 25 | json.dump(result, f, indent=2, ensure_ascii=False) 26 | 27 | def show_task_result(app: Application, runner: TaskRunner, output_dir: str): 28 | result_file = os.path.join(output_dir, TaskOutputVars.TaskResult) 29 | # Display info 30 | app.display_header("Task Info") 31 | app.display_info(f"ID: {runner.uuid}") 32 | app.display_info(f"Status: {runner.status}") 33 | app.display_info(f"Progress: {runner.progress}") 34 | app.display_info(f"Output JSONL: {result_file}") 35 | 36 | # Display stastics 37 | app.display_header("Task Stastics") 38 | running_stats, eval_stats = runner.statistic() 39 | app.display_table( 40 | title="Run stats", 41 | columns=convert_running_stats_to_columns(running_stats), 42 | show_lines=True, 43 | force_ascii=True, 44 | ) 45 | app.display_table( 46 | title="Eval stats", 47 | columns=convert_eval_stats_to_columns(eval_stats), 48 | show_lines=True, 49 | force_ascii=True, 50 | ) 51 | 52 | def convert_running_stats_to_columns(df: pd.DataFrame) -> dict[str, dict[int, str]]: 53 | columns: dict[str, dict[int, str]] = {k: {} for k in df.columns} 54 | index = 0 55 | for _, r in df.iterrows(): 56 | for k, v in r.items(): 57 | columns[str(k)][index] = str(v) 58 | index += 1 59 | return columns 60 | 61 | 62 | def convert_eval_stats_to_columns(df: pd.DataFrame) -> dict[str, dict[int, str]]: 63 | columns: dict[str, dict[int, str]] = {"eval": {}} 64 | for k in df.columns: 65 | columns[k] = {} 66 | index = 0 67 | for i, r in df.iterrows(): 68 | columns["eval"][index] = str(i) 69 | for k, v in r.items(): 70 | columns[str(k)][index] = str(v) 71 | index += 1 72 | return columns 73 | -------------------------------------------------------------------------------- /src/langeval/cli/run/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Optional 3 | 4 | from langeval.cli.application import Application 5 | from langeval.cli.constant import TaskOutputVars 6 | from langeval.cli.run.display import save_task_merged_result, save_task_result, show_task_result 7 | from langeval.tasks import EvalTask, Result, TaskRunner, TaskRunnerStatus 8 | 9 | 10 | def run_task(app:Application, output: str, task_id: str, task: EvalTask, 11 | sample: int = 0, sample_seed: int = 42, results: Optional[List[Result]] = None): 12 | # 4. 
Run task 13 | log_file = os.path.join(output, TaskOutputVars.TaskLog) 14 | log_file_handler = open(log_file, "w") 15 | 16 | def log_callback(_, log): 17 | app.display_info(log) 18 | log_file_handler.write(log) 19 | log_file_handler.flush() 20 | 21 | status_file = os.path.join(output, TaskOutputVars.TaskStatus) 22 | output_file = os.path.join(output, TaskOutputVars.TaskOutput) 23 | output_file_handler = open(output_file, "w") 24 | 25 | def progress_callback(_, progress, results): 26 | app.display_info(f"Progress: {progress}") 27 | jsonl = "".join([r.to_jsonl() for r in results]) 28 | output_file_handler.write(jsonl) 29 | output_file_handler.flush() 30 | 31 | with open(status_file, "w") as f: 32 | f.write(runner.status_json()) 33 | 34 | def status_callback(_, status): 35 | app.display_info(f"Status: {status}") 36 | with open(status_file, "w") as f: 37 | f.write(runner.status_json()) 38 | 39 | runner = TaskRunner( 40 | task_id, 41 | task, 42 | sample=sample, 43 | sample_seed=sample_seed, 44 | status_callback=status_callback, 45 | log_callback=log_callback, 46 | progress_callback=progress_callback, 47 | ) 48 | if results: 49 | app.display(">>> Load previous results.") 50 | runner.results = results 51 | 52 | runner.start() 53 | app.display_waiting(f">>> Task {task_id} running...") 54 | runner.join() 55 | 56 | # 5. Task finish 57 | if runner.status == TaskRunnerStatus.FINISHED: 58 | app.display_success(f">>> Task {task_id} finish: {runner.status}.") 59 | else: 60 | app.display_error(f">>> Task {task_id} finish: {runner.status}.") 61 | 62 | log_file_handler.close() 63 | output_file_handler.close() 64 | 65 | with open(os.path.join(output, TaskOutputVars.TaskResult), "w") as f: 66 | for result in runner.results: 67 | f.write(result.to_jsonl()) 68 | 69 | # 6. Show result 70 | running_stats, eval_stats = runner.statistic() 71 | app.display_info(f"Save task result to {output}") 72 | save_task_result( 73 | os.path.join(output, TaskOutputVars.TaskStastics), 74 | running_stats, eval_stats) 75 | 76 | save_task_merged_result( 77 | os.path.join(output, TaskOutputVars.TaskMerged), 78 | runner.results 79 | ) 80 | 81 | show_task_result(app, runner, output) 82 | -------------------------------------------------------------------------------- /src/langeval/cli/show/__init__.py: -------------------------------------------------------------------------------- 1 | """Show evalution result in CLI or Web UI.""" 2 | import json 3 | import os 4 | 5 | import click 6 | 7 | from langeval.cli.application import Application 8 | from langeval.cli.constant import TaskOutputVars 9 | from langeval.cli.run.display import show_task_result 10 | from langeval.tasks import EvalTask, Result, TaskRunner 11 | 12 | 13 | @click.command(short_help="Show evaluation result.") 14 | @click.argument("task_dir", required=True, type=click.Path(exists=True)) 15 | @click.option("--web", "web", is_flag=True, help="Display the web UI for the evaluation.") 16 | @click.pass_obj 17 | def show(app: Application, task_dir, web): 18 | """Show evaluation result. 19 | 20 | TASK_DIR: The directory of the evaluation task. 
21 | """ 22 | task_file = os.path.join(task_dir, TaskOutputVars.TaskMeta) 23 | with open(task_file) as f: 24 | task_file_content = f.read() 25 | task = EvalTask.from_yaml(task_file_content, dataset_dir=task_dir) 26 | 27 | with open(os.path.join(task_dir, TaskOutputVars.TaskStatus)) as f: 28 | # {"uuid": "2311021530-5c69", "status": "FINISHED", "progress": "1/0/1", "finished_time": 1698910215.125846} 29 | status = json.loads(f.read()) 30 | 31 | with open(os.path.join(task_dir, TaskOutputVars.TaskResult)) as f: 32 | results = [Result.from_json(line) for line in f.readlines()] 33 | 34 | runner = TaskRunner(status["uuid"], task) 35 | runner.set_status(status["status"]) 36 | runner.progress = status["progress"] 37 | runner.finished_time = status["finished_time"] 38 | runner.results = results 39 | 40 | if not runner: 41 | app.abort(f"No task found in {task_dir}") 42 | show_task_result(app, runner, task_dir) 43 | -------------------------------------------------------------------------------- /src/langeval/cli/terminal.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | https://github.com/pypa/hatch/blob/master/src/hatch/cli/terminal.py 4 | """ 5 | from __future__ import annotations 6 | 7 | import os 8 | from abc import ABC, abstractmethod 9 | from functools import cached_property 10 | from textwrap import indent as indent_text 11 | from typing import Callable, Optional 12 | 13 | import click 14 | from rich.console import Console 15 | from rich.errors import StyleSyntaxError 16 | from rich.status import Status 17 | from rich.style import Style 18 | from rich.text import Text 19 | 20 | 21 | class TerminalStatus(ABC): 22 | @abstractmethod 23 | def stop(self) -> None: 24 | ... 25 | 26 | def __enter__(self) -> TerminalStatus: 27 | return self 28 | 29 | @abstractmethod 30 | def __exit__(self, exc_type, exc_val, exc_tb): 31 | ... 
32 | 33 | 34 | class NullStatus(TerminalStatus): 35 | def stop(self): 36 | pass 37 | 38 | def __exit__(self, exc_type, exc_value, traceback): 39 | pass 40 | 41 | 42 | class BorrowedStatus(TerminalStatus): 43 | def __init__( 44 | self, 45 | console: Console, 46 | *, 47 | is_interactive: bool, 48 | verbosity: int, 49 | spinner_style: str, 50 | waiting_style: Style, 51 | success_style: Style, 52 | initializer: Callable, 53 | finalizer: Callable, 54 | ): 55 | self.__console = console 56 | self.__is_interactive = is_interactive 57 | self.__verbosity = verbosity 58 | self.__spinner_style = spinner_style 59 | self.__waiting_style = waiting_style 60 | self.__success_style = success_style 61 | self.__initializer = initializer 62 | self.__finalizer = finalizer 63 | 64 | # This is the possibly active current status 65 | self.__status: Status | None = None 66 | 67 | # This is used as a stack to display the current message 68 | self.__messages: list[tuple[Text, str]] = [] 69 | 70 | def stop(self) -> None: 71 | active = self.__active() 72 | if self.__status is not None: 73 | self.__status.stop() 74 | 75 | old_message, final_text = self.__messages[-1] 76 | if self.__verbosity > 0 and active: 77 | if not final_text: 78 | final_text = old_message.plain 79 | final_text = f"Finished {final_text[:1].lower()}{final_text[1:]}" 80 | 81 | self.__output(Text(final_text, style=self.__success_style)) 82 | 83 | def __call__(self, message: str, final_text: str = "") -> BorrowedStatus: 84 | self.__messages.append((Text(message, style=self.__waiting_style), final_text)) 85 | return self 86 | 87 | def __enter__(self) -> BorrowedStatus: 88 | if not self.__messages: 89 | return self 90 | 91 | message, _ = self.__messages[-1] 92 | if not self.__is_interactive: 93 | self.__output(message) 94 | return self 95 | 96 | if self.__status is None: 97 | self.__initializer() 98 | else: 99 | self.__status.stop() 100 | 101 | self.__status = self.__console.status(message, spinner=self.__spinner_style) 102 | self.__status.start() 103 | 104 | return self 105 | 106 | def __exit__(self, exc_type, exc_val, exc_tb): 107 | old_message, final_text = self.__messages.pop() 108 | if self.__verbosity > 0 and self.__active(): 109 | if not final_text: 110 | final_text = old_message.plain 111 | final_text = f"Finished {final_text[:1].lower()}{final_text[1:]}" 112 | 113 | self.__output(Text(final_text, style=self.__success_style)) 114 | 115 | if not self.__is_interactive: 116 | return 117 | 118 | if self.__status is not None: 119 | self.__status.stop() 120 | if not self.__messages: 121 | self.__status = None 122 | self.__finalizer() 123 | else: 124 | message, _ = self.__messages[-1] 125 | self.__status = self.__console.status(message, spinner=self.__spinner_style) 126 | self.__status.start() 127 | 128 | def __active(self) -> bool: 129 | return self.__status is not None and self.__status._live.is_started 130 | 131 | def __output(self, text): 132 | self.__console.stderr = True 133 | try: 134 | self.__console.print(text, overflow="ignore", no_wrap=True, crop=False) 135 | finally: 136 | self.__console.stderr = False 137 | 138 | 139 | class Terminal: 140 | def __init__(self, verbosity, enable_color: Optional[bool], interactive: bool): 141 | self.verbosity = verbosity 142 | self.console = Console( 143 | force_terminal=enable_color, 144 | force_interactive=interactive, 145 | no_color=enable_color is False, 146 | markup=False, 147 | emoji=False, 148 | highlight=False, 149 | # Force consistent output for test assertions 150 | legacy_windows=False if 
"HATCH_SELF_TESTING" in os.environ else None, 151 | ) 152 | 153 | # Set defaults so we can pretty print before loading user config 154 | self._style_level_success: Style | str = "bold cyan" 155 | self._style_level_error: Style | str = "bold red" 156 | self._style_level_warning: Style | str = "bold yellow" 157 | self._style_level_waiting: Style | str = "bold magenta" 158 | # Default is simply bold rather than bold white for shells that have been configured with a white background 159 | self._style_level_info: Style | str = "bold" 160 | self._style_level_debug: Style | str = "bold" 161 | 162 | # Chosen as the default since it's compatible everywhere and looks nice 163 | self._style_spinner = "simpleDotsScrolling" 164 | 165 | @cached_property 166 | def kv_separator(self) -> Style: 167 | return self.style_warning("->") # type: ignore 168 | 169 | def style_success(self, text: str) -> Text: 170 | return Text(text, style=self._style_level_success) 171 | 172 | def style_error(self, text: str) -> Text: 173 | return Text(text, style=self._style_level_error) 174 | 175 | def style_warning(self, text: str) -> Text: 176 | return Text(text, style=self._style_level_warning) 177 | 178 | def style_waiting(self, text: str) -> Text: 179 | return Text(text, style=self._style_level_waiting) 180 | 181 | def style_info(self, text: str) -> Text: 182 | return Text(text, style=self._style_level_info) 183 | 184 | def style_debug(self, text: str) -> Text: 185 | return Text(text, style=self._style_level_debug) 186 | 187 | def initialize_styles(self, styles: dict): # no cov 188 | # Lazily display errors so that they use the correct style 189 | errors = [] 190 | 191 | for option, style in styles.items(): 192 | attribute = f"_style_level_{option}" 193 | 194 | default_level = getattr(self, attribute, None) 195 | if default_level: 196 | try: 197 | style = Style.parse(style) # noqa: PLW2901 198 | except StyleSyntaxError as e: # no cov 199 | errors.append(f"Invalid style definition for `{option}`, defaulting to `{default_level}`: {e}") 200 | style = Style.parse(default_level) # noqa: PLW2901 201 | else: 202 | attribute = f"_style_{option}" 203 | 204 | setattr(self, attribute, style) 205 | 206 | return errors 207 | 208 | def display(self, text="", **kwargs): 209 | self.console.print(text, style=self._style_level_info, overflow="ignore", no_wrap=True, crop=False, **kwargs) 210 | 211 | def display_critical(self, text="", **kwargs): 212 | self.console.stderr = True 213 | try: 214 | self.console.print( 215 | text, style=self._style_level_error, overflow="ignore", no_wrap=True, crop=False, **kwargs 216 | ) 217 | finally: 218 | self.console.stderr = False 219 | 220 | def display_error(self, text="", *, stderr=True, indent=None, link=None, **kwargs): 221 | if self.verbosity < -2: # noqa: PLR2004 222 | return 223 | 224 | self._output(text, self._style_level_error, stderr=stderr, indent=indent, link=link, **kwargs) 225 | 226 | def display_warning(self, text="", *, stderr=True, indent=None, link=None, **kwargs): 227 | if self.verbosity < -1: 228 | return 229 | 230 | self._output(text, self._style_level_warning, stderr=stderr, indent=indent, link=link, **kwargs) 231 | 232 | def display_info(self, text="", *, stderr=True, indent=None, link=None, **kwargs): 233 | if self.verbosity < 0: 234 | return 235 | 236 | self._output(text, self._style_level_info, stderr=stderr, indent=indent, link=link, **kwargs) 237 | 238 | def display_success(self, text="", *, stderr=True, indent=None, link=None, **kwargs): 239 | if self.verbosity < 0: 240 | return 
241 | 242 | self._output(text, self._style_level_success, stderr=stderr, indent=indent, link=link, **kwargs) 243 | 244 | def display_waiting(self, text="", *, stderr=True, indent=None, link=None, **kwargs): 245 | if self.verbosity < 0: 246 | return 247 | 248 | self._output(text, self._style_level_waiting, stderr=stderr, indent=indent, link=link, **kwargs) 249 | 250 | def display_debug(self, text="", level=1, *, stderr=True, indent=None, link=None, **kwargs): 251 | if not 1 <= level <= 3: # noqa: PLR2004 252 | error_message = "Debug output can only have verbosity levels between 1 and 3 (inclusive)" 253 | raise ValueError(error_message) 254 | elif self.verbosity < level: 255 | return 256 | 257 | self._output(text, self._style_level_debug, stderr=stderr, indent=indent, link=link, **kwargs) 258 | 259 | def display_mini_header(self, text, *, stderr=False, indent=None, link=None): 260 | if self.verbosity < 0: 261 | return 262 | 263 | self.display_info("[", stderr=stderr, indent=indent, end="") 264 | self.display_success(text, stderr=stderr, link=link, end="") 265 | self.display_info("]", stderr=stderr) 266 | 267 | def display_header(self, title="", *, stderr=False): # noqa: ARG002 268 | self.console.rule(Text(title, self._style_level_success)) 269 | 270 | def display_markdown(self, text, **kwargs): # no cov 271 | from rich.markdown import Markdown 272 | 273 | self.output(Markdown(text), **kwargs) 274 | 275 | def display_pair(self, key, value): 276 | self.output(self.style_success(key), self.kv_separator, value) 277 | 278 | def display_table(self, title, columns, *, show_lines=False, column_options=None, force_ascii=False, num_rows=0): 279 | from rich.table import Table 280 | 281 | if column_options is None: 282 | column_options = {} 283 | 284 | table_options = {} 285 | if force_ascii: 286 | from rich.box import ASCII_DOUBLE_HEAD 287 | 288 | table_options["box"] = ASCII_DOUBLE_HEAD 289 | table_options["safe_box"] = True 290 | 291 | table = Table(title=title, show_lines=show_lines, title_style="", **table_options) 292 | columns = dict(columns) 293 | 294 | for title, indices in list(columns.items()): 295 | if indices: 296 | table.add_column(title, style="bold", **column_options.get(title, {})) 297 | else: 298 | columns.pop(title) 299 | 300 | if not columns: 301 | return 302 | 303 | for i in range(num_rows or max(map(max, columns.values())) + 1): 304 | row = [] 305 | for indices in columns.values(): 306 | row.append(indices.get(i, "")) 307 | 308 | if any(row): 309 | table.add_row(*row) 310 | 311 | self.output(table) 312 | 313 | @cached_property 314 | def status(self) -> BorrowedStatus: 315 | return BorrowedStatus( 316 | self.console, 317 | is_interactive=self.console.is_interactive, 318 | verbosity=self.verbosity, 319 | spinner_style=self._style_spinner, 320 | waiting_style=self._style_level_waiting, # type: ignore 321 | success_style=self._style_level_success, # type: ignore 322 | initializer=lambda: setattr(self.platform, "displaying_status", True), # type: ignore[attr-defined] 323 | finalizer=lambda: setattr(self.platform, "displaying_status", False), # type: ignore[attr-defined] 324 | ) 325 | 326 | def status_if(self, *args, condition: bool, **kwargs) -> TerminalStatus: 327 | return self.status(*args, **kwargs) if condition else NullStatus() 328 | 329 | def _output(self, text="", style=None, *, stderr=False, indent=None, link=None, **kwargs): 330 | if indent: 331 | text = indent_text(text, indent) 332 | 333 | if link: 334 | style = style.update_link(self.platform.format_file_uri(link)) # type: 
ignore 335 | 336 | self.output(text, stderr=stderr, style=style, **kwargs) 337 | 338 | def output(self, *args, stderr=False, **kwargs): 339 | kwargs.setdefault("overflow", "ignore") 340 | kwargs.setdefault("no_wrap", True) 341 | kwargs.setdefault("crop", False) 342 | 343 | if not stderr: 344 | self.console.print(*args, **kwargs) 345 | else: 346 | self.console.stderr = True 347 | try: 348 | self.console.print(*args, **kwargs) 349 | finally: 350 | self.console.stderr = False 351 | 352 | @staticmethod 353 | def prompt(text, **kwargs): 354 | return click.prompt(text, **kwargs) 355 | 356 | @staticmethod 357 | def confirm(text, **kwargs): 358 | return click.confirm(text, **kwargs) 359 | -------------------------------------------------------------------------------- /src/langeval/config/__init__.py: -------------------------------------------------------------------------------- 1 | class AppEnvVars: 2 | """Environment variables used by the application.""" 3 | 4 | VERBOSE = "LANGEVAL_VERBOSE" 5 | # https://no-color.org 6 | NO_COLOR = "NO_COLOR" 7 | FORCE_COLOR = "FORCE_COLOR" 8 | 9 | 10 | class ConfigEnvVars: 11 | DATA_DIR = "LANGEVAL_DATA_DIR" 12 | ENV_FILE = "LANGEVAL_ENV_FILE" 13 | -------------------------------------------------------------------------------- /src/langeval/evaluators/__init__.py: -------------------------------------------------------------------------------- 1 | from .evaluator import Evaluator, EvaluatorType, EvaluatorSettings # noqa 2 | from .exception import EvalRunError # noqa 3 | -------------------------------------------------------------------------------- /src/langeval/evaluators/evaluator.py: -------------------------------------------------------------------------------- 1 | """Evaluator Types""" 2 | import copy 3 | import enum 4 | import logging 5 | from typing import Any, Optional, Union 6 | 7 | import yaml 8 | 9 | try: 10 | import pydantic.v1 as pc 11 | except ImportError: 12 | import pydantic as pc 13 | 14 | from langeval.evaluators.exception import EvalRunError 15 | from langeval.evaluators.nlp import NLP 16 | from langeval.evaluators.rag import Rag 17 | from langeval.evaluators.sql import SQLEvaluator 18 | from langeval.models import LLM, Embedding 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | def represent_multiline_text(dumper, data): 23 | if "\n" in data: 24 | return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|") 25 | else: 26 | return dumper.represent_scalar("tag:yaml.org,2002:str", data) 27 | 28 | 29 | yaml.add_representer(str, represent_multiline_text) 30 | 31 | 32 | class EvaluatorType(str, enum.Enum): 33 | """Evaluator Type""" 34 | 35 | LLM_GRADE = "LLM_GRADE" 36 | EMBEDDING_COS_SIM = "EMBEDDING_COS_SIM" 37 | PYTHON_CODE = "PYTHON_CODE" 38 | RAG = "RAG" 39 | # Some NLP metrics 40 | NLP = "NLP" 41 | # sql evaluator 42 | SQL = "SQL" 43 | 44 | 45 | class LLMGrade(pc.BaseModel): 46 | prompt: str 47 | eval_keys: list[str] 48 | llm: Optional[LLM] = None 49 | 50 | 51 | class EmbeddingCosSim(pc.BaseModel): 52 | pairs_keys: tuple[str, str] 53 | embedding: Embedding 54 | cos_sim_threshold: float 55 | 56 | 57 | class PythonCode(pc.BaseModel): 58 | code: str 59 | 60 | 61 | EvaluatorSettings = { 62 | EvaluatorType.LLM_GRADE: LLMGrade, 63 | EvaluatorType.EMBEDDING_COS_SIM: EmbeddingCosSim, 64 | EvaluatorType.PYTHON_CODE: PythonCode, 65 | EvaluatorType.RAG: Rag, 66 | EvaluatorType.NLP: NLP, 67 | EvaluatorType.SQL: SQLEvaluator, 68 | } 69 | 70 | 71 | class Evaluator(pc.BaseModel): 72 | """Evaluator""" 73 | # Name 74 | name: str 75 | # 
Type 76 | type: EvaluatorType # noqa: A003 77 | # Detail config 78 | settings: Optional[Union[Rag, LLMGrade, EmbeddingCosSim, PythonCode, NLP, SQLEvaluator]] = None 79 | # Rate limit 80 | query_per_second: float = pc.Field(default=0, ge=0.1, le=100) 81 | 82 | def to_yaml(self) -> str: 83 | return yaml.dump(self.dict(exclude_unset=True), encoding="utf-8", allow_unicode=True).decode("utf-8") 84 | 85 | @pc.validator("type") 86 | def type_must_be_valid(cls, v): # noqa: N805 87 | if EvaluatorType(v) not in EvaluatorSettings.keys(): 88 | raise ValueError(f"Invalid type: {v}") 89 | return v 90 | 91 | def batch_call(self, batch_inputs: list[dict[str, Any]], batch_outputs: list[dict[str, Any]], timeout=10, 92 | default_llm=None) -> list[dict[str, Any]]: 93 | """Do batch eval""" 94 | from langeval.evaluators.rag import eval_rag 95 | from langeval.evaluators.run import eval_embedding_cos_sim, eval_llm_grade, eval_python_code 96 | kwargs_list = [] 97 | for i, inputs in enumerate(batch_inputs): 98 | kwargs = copy.copy(inputs) 99 | kwargs.update(batch_outputs[i]) 100 | kwargs_list.append(kwargs) 101 | 102 | results = [] 103 | try: 104 | if self.type == EvaluatorType.LLM_GRADE: 105 | for kwargs in kwargs_list: 106 | results.append(eval_llm_grade(self, kwargs, timeout, default_llm)) 107 | return results 108 | elif self.type == EvaluatorType.EMBEDDING_COS_SIM: 109 | for kwargs in kwargs_list: 110 | results.append(eval_embedding_cos_sim(self, kwargs, timeout)) 111 | return results 112 | elif self.type == EvaluatorType.PYTHON_CODE: 113 | return eval_python_code(self, kwargs_list, timeout) 114 | elif self.type == EvaluatorType.RAG: 115 | if self.settings is None or not isinstance(self.settings, Rag): 116 | raise EvalRunError(f"RAG settings is not specified: {self.settings}") 117 | for kwargs in kwargs_list: 118 | results.append(eval_rag(self.settings, kwargs, timeout, default_llm)) 119 | return results 120 | elif self.type == EvaluatorType.NLP: 121 | if self.settings is None or not isinstance(self.settings, NLP): 122 | raise EvalRunError(f"NLP settings is not specified: {self.settings}") 123 | for kwargs in kwargs_list: 124 | results.append(self.settings.call(kwargs)) 125 | return results 126 | elif self.type == EvaluatorType.SQL: 127 | if self.settings is None or type(self.settings) != SQLEvaluator: 128 | raise EvalRunError(f"SQL settings is not specified: {self.settings}") 129 | for kwargs in kwargs_list: 130 | results.append(self.settings.call(kwargs, timeout)) 131 | return results 132 | 133 | except Exception as e: 134 | logger.exception(f"eval failed: {e}") 135 | logger.debug(f"evaluator {self} eval failed: {e}", exc_info=True) 136 | raise EvalRunError(f"eval failed: {e}") from e 137 | 138 | raise EvalRunError(f"eval type not supported: {self.type}") 139 | -------------------------------------------------------------------------------- /src/langeval/evaluators/exception.py: -------------------------------------------------------------------------------- 1 | class EvalRunError(RuntimeError): 2 | """Evaluator Run Exception""" 3 | 4 | pass 5 | -------------------------------------------------------------------------------- /src/langeval/evaluators/nlp/__init__.py: -------------------------------------------------------------------------------- 1 | import enum 2 | import logging 3 | from typing import Any 4 | 5 | import jieba 6 | from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu 7 | 8 | try: 9 | import pydantic.v1 as pc 10 | except ImportError: 11 | import pydantic as pc 12 | 13 | from 
rouge_chinese import Rouge 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | class NLPMetric(str, enum.Enum): 18 | # Exact Match 19 | ExactMatch = "exact_match" 20 | # Chinese BLEU 21 | BLEUChinese = "bleu_chinese" 22 | # Chinese Rouge 23 | RougeChinese = "rouge_chinese" 24 | 25 | 26 | def exact_match(prediction: str, reference: str, 27 | stripped: bool = True, ignore_case: bool = False) -> dict[str, float]: 28 | """Exact Match 29 | 30 | Args: 31 | prediction (str): prediction 32 | reference (str): reference 33 | stripped (bool): remove leading and trailing whitespaces, default True 34 | ignore_case (bool): ignore case, default False 35 | 36 | Returns: 37 | dict[str, float]: {"exact_match": 0.0/1.0} 38 | """ 39 | if stripped: 40 | prediction = prediction.strip() 41 | reference = reference.strip() 42 | if ignore_case: 43 | prediction = prediction.lower() 44 | reference = reference.lower() 45 | return { 46 | "exact_match": float(prediction == reference) 47 | } 48 | 49 | def jieba_tokenizer(text: str) -> list[str]: 50 | """Jieba Tokenizer""" 51 | return list(jieba.cut(text)) 52 | 53 | def bleu_chinese(prediction, reference, 54 | segment: bool = True): 55 | """Chinese BLEU 56 | 57 | Args: 58 | prediction (str): prediction 59 | reference (str): reference 60 | segment (bool): jieba word segment, default True 61 | 62 | Returns: 63 | dict[str, float]: {"bleu-4": 0.0-1.0} 64 | """ 65 | if segment: 66 | prediction = jieba_tokenizer(prediction) 67 | reference = jieba_tokenizer(reference) 68 | else: 69 | prediction = list[prediction] 70 | reference = list[reference] 71 | return { 72 | "bleu-4": sentence_bleu([reference], prediction, smoothing_function=SmoothingFunction().method3) 73 | } 74 | 75 | def rouge_chinese(prediction, reference): 76 | """Chinese Rouge 77 | 78 | Args: 79 | prediction (str): prediction 80 | reference (str): reference 81 | 82 | Returns: 83 | dict[str, float]: {"rouge-1": 0.0-1.0, "rouge-2": 0.0-1.0, "rouge-l": 0.0-1.0} 84 | """ 85 | prediction = jieba_tokenizer(prediction) 86 | reference = jieba_tokenizer(reference) 87 | if len(" ".join(prediction).split()) == 0 or len(" ".join(reference).split()) == 0: 88 | return {"rouge-1": 0.0, "rouge-2": 0.0, "rouge-l": 0.0} 89 | rouge = Rouge() 90 | scores = rouge.get_scores(" ".join(prediction), " ".join(reference))[0] 91 | return { 92 | "rouge-1": scores["rouge-1"]["f"], # type: ignore 93 | "rouge-2": scores["rouge-2"]["f"], # type: ignore 94 | "rouge-l": scores["rouge-l"]["f"], # type: ignore 95 | } 96 | 97 | 98 | metrics_eval_funcs = { 99 | NLPMetric.ExactMatch: exact_match, 100 | NLPMetric.BLEUChinese: bleu_chinese, 101 | NLPMetric.RougeChinese: rouge_chinese, 102 | } 103 | 104 | 105 | class NLP(pc.BaseModel): 106 | prediction_key: str 107 | reference_key: str 108 | nlp_metrics: list[NLPMetric] 109 | nlp_metrics_kwargs: dict[NLPMetric, Any] = pc.Field(default_factory=dict) 110 | 111 | def call(self, kwargs: dict[str, Any]) -> dict[str, Any]: 112 | """Evaluate call""" 113 | prediction = kwargs[self.prediction_key] 114 | reference = kwargs[self.reference_key] 115 | results = {} 116 | for metric in self.nlp_metrics: 117 | result = metrics_eval_funcs[metric](prediction, reference, 118 | **self.nlp_metrics_kwargs.get(metric, {})) 119 | results.update(result) 120 | return results 121 | -------------------------------------------------------------------------------- /src/langeval/evaluators/rag/__init__.py: -------------------------------------------------------------------------------- 1 | import enum 2 | import logging 3 | from 
typing import Any, Optional 4 | 5 | import jinja2 6 | 7 | try: 8 | import pydantic.v1 as pc 9 | except ImportError: 10 | import pydantic as pc 11 | 12 | from sklearn.metrics import ndcg_score 13 | 14 | from langeval.evaluators.exception import EvalRunError 15 | from langeval.evaluators.rag.utils import overlap_coefficient_contain 16 | from langeval.models import LLM 17 | from langeval.providers.output_parser import SimpleJsonOutputParser 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | DEFAULT_ANSWER_SIMILARITY_THRESHOLD = 0.8 22 | DEFAULT_ANSWER_CORRECTNESS_PROMPT = """ 23 | 你是一名教师,接下来你需要根据【问题】和【参考答案】,来评价【学生答案】。 24 | 25 | # 评分过程 26 | 27 | 1. 首先思考【问题】下,【参考答案】和【学生答案】的一致性。 28 | 2. 根据一致性,给出对【学生答案】的评价理由。 29 | 3. 根据评价理由,给出对【学生答案】的评分。(0代表错误,0.5代表部分正确,1代表完全正确)。 30 | 31 | # 输入 32 | 33 | 问题:"{{ question }}" 34 | 参考答案:"{{ reference_answer }}" 35 | 学生答案:"{{ answer }}" 36 | 37 | # 输出 38 | 39 | 输出为 Markdown JSON 格式的字符串,示例: 40 | 41 | ```json 42 | { 43 | "answer_correctness_reasoning": "<评价理由>", 44 | "answer_correctness": 0 45 | } 46 | ``` 47 | 48 | 注意分数只能为 0、0.5、1,0 代表错误,0.5 代表部分正确,2 代表完全正确。 49 | 50 | 输出(Markdown JSON 格式): 51 | """ 52 | 53 | class RagMetric(str, enum.Enum): 54 | # Compare contexts and reference_context 55 | RetrievalRecall = "retrieval_recall" 56 | # Compare answer and reference_answer with LLM-JUDGE correctness 57 | AnswerCorrectness = "answer_correctness" 58 | 59 | 60 | class Rag(pc.BaseModel): 61 | rag_metrics: list[RagMetric] 62 | rag_llm: Optional[LLM] = None 63 | answer_correctness_prompt: Optional[str] = None 64 | 65 | 66 | def retrieval_recall(rag: Rag, kwargs: dict[str, Any], timeout, default_llm): # noqa: ARG001 67 | contexts = kwargs["contexts"] 68 | reference_context = kwargs["reference_context"] 69 | if len(contexts) == 0: 70 | return { 71 | #"retrieval_recall_ndgc_10": 0.0, 72 | "retrieval_recall_hit_rate": 0.0, 73 | "retrieval_recall_mrr": 0.0 74 | } 75 | # 通过判断 reference_context 是否在 contexts 中,以及所在的位置来计算。 76 | # 这里使用 ndcg 算法:https://scikit-learn.org/stable/modules/generated/sklearn.metrics.ndcg_score.html 77 | # 假设 reference_context 和 contexts 是同一批语料的不同切分形式,所以不需要语义相似度,而是比对文本相似度。 78 | # 使用修改后的 Overlap Coefficient 算法计算相似度。取值范围是 0-1 79 | # 这种算法,当 contexts 包含 reference_context 时,得分是 1。 80 | #true_relevance: list[float] = [] 81 | # 如下是假设 reference_context 和 contexts 是相同语料。 82 | hit_rate = 0.0 # 代表 reference_context 是否在 contexts 中精确匹配。 83 | mrr = 0.0 # 计算 reference_context 的精确匹配位置,首位为 1.0,次位为 0.5,以此类推。 84 | for i, context in enumerate(contexts): 85 | if reference_context == context: 86 | hit_rate = 1.0 87 | if mrr == 0: 88 | # if 是避免 contexts 中有重复内容。 89 | mrr = 1.0 / (i + 1) 90 | #true_relevance.append(overlap_coefficient_contain( 91 | # context, reference_context 92 | #)) 93 | # scores 的绝对值没有意义,只要是倒序排列就行。 94 | # scores = list(range(len(contexts), 0, -1)) 95 | # k = 10 代表只统计前 10 个结果 96 | # ndgc_10 = ndcg_score([true_relevance], [scores], k=10) # type: ignore 97 | 98 | return { 99 | "retrieval_recall_hit_rate": hit_rate, 100 | "retrieval_recall_mrr": mrr, 101 | #"retrieval_recall_ndgc_10": ndgc_10, 102 | } 103 | 104 | def answer_correctness(rag: Rag, kwargs: dict[str, Any], timeout, default_llm): 105 | if rag.rag_llm: 106 | llm = rag.rag_llm 107 | else: 108 | llm = default_llm 109 | if llm is None: 110 | raise EvalRunError("llm is None, can not eval answer_correctness") 111 | 112 | prompt_tpl = DEFAULT_ANSWER_CORRECTNESS_PROMPT \ 113 | if rag.answer_correctness_prompt is None else rag.answer_correctness_prompt 114 | prompt = jinja2.Template(prompt_tpl).render(**kwargs) 
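    # The rendered judge prompt is sent to the grading LLM; the reply is expected
    # to be markdown-fenced JSON containing an "answer_correctness" score
    # (0 / 0.5 / 1) together with an "answer_correctness_reasoning" string.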
115 | text = llm.completion(prompt, timeout=timeout) 116 | eval_result = {} 117 | eval_result = SimpleJsonOutputParser().parse(text) 118 | if "answer_correctness" not in eval_result: 119 | raise EvalRunError("eval completion result missing key: answer_correctness") 120 | return eval_result 121 | 122 | 123 | metrics_eval_funcs = { 124 | RagMetric.RetrievalRecall: retrieval_recall, 125 | RagMetric.AnswerCorrectness: answer_correctness, 126 | } 127 | 128 | def eval_rag(rag: Rag, kwargs: dict[str, Any], timeout, default_llm) -> dict[str, Any]: 129 | """Rag eval""" 130 | eval_result = {} 131 | for metric in rag.rag_metrics: 132 | if metric not in metrics_eval_funcs: 133 | raise EvalRunError(f"eval rag metric not supported: {metric}") 134 | eval_func = metrics_eval_funcs[metric] 135 | r = eval_func(rag, kwargs, timeout, default_llm) 136 | eval_result.update(r) 137 | return eval_result 138 | -------------------------------------------------------------------------------- /src/langeval/evaluators/rag/utils.py: -------------------------------------------------------------------------------- 1 | def str_inx(word_, string_): 2 | return [i for i in range(len(string_)) if string_[i] == word_] 3 | 4 | 5 | def ab_max_inx(s_a, s_b): 6 | i, len_a, len_b = 0, len(s_a), len(s_b) 7 | while len_a > i and len_b > i and s_a[i] == s_b[i]: 8 | i += 1 9 | return i 10 | 11 | 12 | def lcs(s_a, s_b): 13 | """计算两个字符串的所有不重复的公共字串""" 14 | res = [] 15 | if s_a: 16 | a0_inx_in_b = str_inx(s_a[0], s_b) 17 | if a0_inx_in_b: 18 | b_end_inx, a_end_inx = -1, 0 19 | for inx in a0_inx_in_b: 20 | if b_end_inx > inx: 21 | continue 22 | this_inx = ab_max_inx(s_a, s_b[inx:]) 23 | a_end_inx = max(a_end_inx, this_inx) 24 | res.append(s_a[:this_inx]) 25 | b_end_inx = this_inx + inx 26 | res += lcs(s_a[a_end_inx:], s_b) 27 | else: 28 | res += lcs(s_a[1:], s_b) 29 | return res 30 | 31 | def overlap_coefficient_contain(s1: str, s2: str) -> float: 32 | """Compute overlap coefficient between two strings. 33 | 34 | Need find the longest common substring between the two strings. 35 | 36 | when s1 contains s2, overlap coefficient is 1. 37 | when s2 contains s1, overlap coefficient not 1. 
38 | """ 39 | s1_len = len(s1) 40 | s2_len = len(s2) 41 | if s1_len == 0 or s2_len == 0: 42 | return 0 43 | # Find the all common substrings between the two strings 44 | cs_list = lcs(s1, s2) 45 | # Calculate the weight length of all common substrings 46 | overlap_coefficient = 0 47 | for cs in cs_list: 48 | # 取平方可以增加长串的权重,降低短串的权重。 49 | # 同时过滤掉短字符 50 | if len(cs) >= min(5, s2_len): 51 | overlap_coefficient += (len(cs) / s2_len) ** 2 52 | return overlap_coefficient 53 | 54 | 55 | if __name__ == "__main__": 56 | print(overlap_coefficient_contain("你好aaa", "你好")) 57 | print(overlap_coefficient_contain("你好", "你好aaaa")) 58 | print(overlap_coefficient_contain("你a好zzzzzzzzzzzzzzzzzzz", "你好xxxxxxxxxxxxxxxxxxxxxxx")) 59 | -------------------------------------------------------------------------------- /src/langeval/evaluators/run.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Any 3 | 4 | import jinja2 5 | 6 | from langeval.evaluators.evaluator import EmbeddingCosSim, Evaluator, LLMGrade, PythonCode 7 | from langeval.evaluators.exception import EvalRunError 8 | from langeval.models import Embedding 9 | from langeval.providers.output_parser import SimpleJsonOutputParser 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def eval_python_code(evaluator: Evaluator, kwargs_list: list[dict[str, Any]], timeout) -> list[dict[str, Any]]: 15 | """Do python code eval""" 16 | if evaluator.settings is None or not isinstance(evaluator.settings, PythonCode): 17 | msg = "PYTHON_CODE not specified" 18 | raise EvalRunError(msg) 19 | 20 | # Write code to temp file, then run it 21 | import json 22 | import subprocess 23 | import tempfile 24 | 25 | with tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=True) as f: 26 | f.write(evaluator.settings.code) 27 | f.flush() 28 | result = subprocess.run( 29 | ["python3", f.name], # noqa: S603, S607 30 | input=json.dumps(kwargs_list), 31 | stdout=subprocess.PIPE, 32 | timeout=timeout, 33 | check=True, 34 | encoding="utf-8", 35 | ) 36 | return json.loads(result.stdout) 37 | 38 | def eval_embedding_cos_sim(evaluator: Evaluator, kwargs: dict[str, Any], timeout) -> dict[str, Any]: 39 | """Embedding Cosine Similarity""" 40 | if evaluator.settings is None or not isinstance(evaluator.settings, EmbeddingCosSim): 41 | raise EvalRunError("EMBEDDING_COS_SIM not specified") 42 | if len(evaluator.settings.pairs_keys) != 2: # noqa: PLR2004 43 | raise EvalRunError("EMBEDDING_COS_SIM input/output keys not specified") 44 | 45 | model = evaluator.settings.embedding 46 | key1, key2 = evaluator.settings.pairs_keys 47 | embeddings = model.embedding([kwargs[key1], kwargs[key2]], timeout=timeout) 48 | cos_sim = Embedding.cosine_similarity(embeddings[0], embeddings[1]) 49 | is_similar = cos_sim >= evaluator.settings.cos_sim_threshold 50 | return { 51 | "cos_sim": cos_sim, 52 | "is_similar": int(is_similar), 53 | } 54 | 55 | def eval_llm_grade(evaluator: Evaluator, kwargs: dict[str, Any], timeout, default_llm) -> dict[str, Any]: 56 | """Use LLM as Judge for Grade""" 57 | if evaluator.settings is None or not isinstance(evaluator.settings, LLMGrade): 58 | raise EvalRunError("LLM_GRADE not specified") 59 | if evaluator.settings.llm: 60 | llm = evaluator.settings.llm 61 | else: 62 | llm = default_llm 63 | if not llm: 64 | raise EvalRunError("LLM not specified") 65 | prompt = jinja2.Template(evaluator.settings.prompt).render(**kwargs) 66 | text = llm.completion(prompt, timeout=timeout) 67 | eval_result = {} 
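    # Parse the judge's (possibly markdown-fenced) JSON reply and copy only the
    # configured eval_keys into the result; a missing key raises EvalRunError.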
68 | result = SimpleJsonOutputParser().parse(text) 69 | for k in evaluator.settings.eval_keys: 70 | if k not in result: 71 | raise EvalRunError(f"eval completion result missing key: {k}") 72 | eval_result[k] = result[k] 73 | return eval_result 74 | -------------------------------------------------------------------------------- /src/langeval/evaluators/sql/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Any, List 3 | from string import Formatter 4 | 5 | try: 6 | import pydantic.v1 as pc 7 | except ImportError: 8 | import pydantic as pc 9 | 10 | from .sqleval import compare_query_results # noqa: TID252 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class SQLEvaluator(pc.BaseModel): 16 | """SQL Evaluator 17 | 18 | from: 19 | 20 | Output format: 21 | { 22 | "exact_match": 0/1, 23 | "correct": 0/1, 24 | "error_msg": "QUERY EXECUTION ERROR: ..." 25 | } 26 | """ 27 | question_key: str 28 | sql_key: str 29 | # golden sql, 支持用 {a.c, a.b} 表示兼容 a.c 或 a.b 30 | golden_sql_key: str 31 | # sqlalchemy url, eg: 32 | # - sqlite:///tmp/test.db 33 | # - mysql://user:pass@localhost:port/dbname 34 | # - postgresql://user:pass@localhost:port/dbname 35 | db_url: str 36 | # 支持动态生成 db_url,需要在db_url 中有 { db_name } 字段 37 | db_name_key: str = "db_id" 38 | 39 | def call(self, kwargs: dict[str, Any], timeout: int = 30) -> dict[str, Any]: 40 | """Evaluate call""" 41 | question = kwargs[self.question_key] 42 | sql = kwargs[self.sql_key] 43 | golden_sql = kwargs[self.golden_sql_key] 44 | 45 | ret = { 46 | "exact_match": 0, 47 | "correct": 0, 48 | "error_msg": "", 49 | } 50 | try: 51 | # 判断 db_url 是否需要动态生成 52 | vars = get_template_variables(self.db_url) 53 | if "db_name" in vars: 54 | db_name = kwargs[self.db_name_key] 55 | db_url = self.db_url.format(db_name=db_name) 56 | else: 57 | db_url = self.db_url 58 | exact_match, correct = compare_query_results( 59 | query_gold=golden_sql, 60 | query_gen=sql, 61 | db_url=db_url, 62 | question=question, 63 | timeout=timeout, 64 | ) 65 | ret["exact_match"] = int(exact_match) 66 | ret["correct"] = int(correct) 67 | ret["error_msg"] = "" 68 | except Exception as e: 69 | ret["error_msg"] = f"QUERY EXECUTION ERROR: {e}" 70 | return ret 71 | 72 | 73 | def get_template_variables(template: str) -> List[str]: 74 | """Get the variables from the template. 75 | 76 | Args: 77 | template: The template string. 78 | template_format: The template format. Should be one of "f-string" or "jinja2". 79 | 80 | Returns: 81 | The variables from the template. 82 | 83 | Raises: 84 | ValueError: If the template format is not supported. 
85 | """ 86 | input_variables = { 87 | v for _, v, _, _ in Formatter().parse(template) if v is not None 88 | } 89 | 90 | 91 | return sorted(input_variables) 92 | -------------------------------------------------------------------------------- /src/langeval/evaluators/sql/sqleval.py: -------------------------------------------------------------------------------- 1 | # this file contains all of the helper functions used for evaluations 2 | # from: https://github.com/defog-ai/sql-eval/blob/main/eval/eval.py 3 | # Licensed under the Apache-2.0 License 4 | 5 | import itertools 6 | import re 7 | from venv import logger 8 | 9 | import pandas as pd 10 | from func_timeout import func_timeout 11 | from pandas.testing import assert_frame_equal, assert_series_equal 12 | from sqlalchemy import create_engine 13 | 14 | # like_pattern = r"LIKE\s+'[^']*'" 15 | LIKE_PATTERN = r"LIKE[\s\S]*'" 16 | 17 | 18 | def normalize_table( 19 | df: pd.DataFrame, query_category: str, question: str 20 | ) -> pd.DataFrame: 21 | """ 22 | Normalizes a dataframe by: 23 | 1. removing all duplicate rows 24 | 2. sorting columns in alphabetical order 25 | 3. sorting rows using values from first column to last (if query_category is not 'order_by' and 26 | question does not ask for ordering) 27 | 4. resetting index 28 | """ 29 | # remove duplicate rows, if any 30 | df = df.drop_duplicates() 31 | 32 | # sort columns in alphabetical order 33 | sorted_df = df.reindex(sorted(df.columns), axis=1) 34 | 35 | # check if query_category is 'order_by' and if question asks for ordering 36 | has_order_by = False 37 | pattern = re.compile(r"(order|sort|arrange)", re.IGNORECASE) 38 | in_question = re.search(pattern, question.lower()) # true if contains 39 | if query_category == "order_by" or in_question: 40 | has_order_by = True 41 | if not has_order_by: 42 | # sort rows using values from first column to last 43 | sorted_df = sorted_df.sort_values(by=list(sorted_df.columns)) 44 | # reset index 45 | sorted_df = sorted_df.reset_index(drop=True) 46 | return sorted_df 47 | 48 | 49 | # for escaping percent signs in regex matches 50 | def escape_percent(match): 51 | # Extract the matched group 52 | group = match.group(0) 53 | # Replace '%' with '%%' within the matched group 54 | escaped_group = group.replace("%", "%%") 55 | # Return the escaped group 56 | return escaped_group 57 | 58 | 59 | # find start and end index of { } in a string. return (start, end) if found, else return (-1, -1) 60 | def find_bracket_indices(s: str, start_index: int = 0) -> "tuple[int, int]": 61 | start = s.find("{", start_index) 62 | end = s.find("}", start + 1) 63 | if start == -1 or end == -1: 64 | return (-1, -1) 65 | return (start, end) 66 | 67 | 68 | # extrapolate all possible queries from a query with { } in it 69 | def get_all_minimal_queries(query: str) -> "list[str]": 70 | """ 71 | extrapolate all possible queries 72 | - split by semicolon. this is to accommodate queries where joins to other tables are also acceptable. 73 | - expand all column permutations if there are braces { } in it. 
eg: 74 | ```sql 75 | SELECT {user.id, user.name} FROM user; 76 | ``` 77 | Would be expanded to: 78 | ```sql 79 | SELECT user.id FROM user; 80 | SELECT user.name FROM user; 81 | SELECT user.id, user.name FROM user; 82 | ``` 83 | """ 84 | queries = query.split(";") 85 | result_queries = [] 86 | for q in queries: 87 | query = q.strip() 88 | if query == "": 89 | continue 90 | start, end = find_bracket_indices(query, 0) 91 | if (start, end) == (-1, -1): 92 | result_queries.append(query) 93 | continue 94 | else: 95 | # get all possible column subsets 96 | column_options = query[start + 1 : end].split(",") 97 | column_combinations = list( 98 | itertools.chain.from_iterable( 99 | itertools.combinations(column_options, r) 100 | for r in range(1, len(column_options) + 1) 101 | ) 102 | ) 103 | for column_tuple in column_combinations: 104 | left = query[:start] 105 | column_str = ", ".join(column_tuple) 106 | right = query[end + 1 :] 107 | # change group by size dynamically if necessary 108 | if right.find("GROUP BY {}"): 109 | right = right.replace("GROUP BY {}", f"GROUP BY {column_str}") 110 | result_queries.append(left + column_str + right) 111 | return result_queries 112 | 113 | 114 | def query_db( 115 | query: str, db_url: str, timeout: float = 10.0 116 | ) -> pd.DataFrame: 117 | """ 118 | Runs query on postgres db and returns results as a dataframe. 119 | This assumes that you have the evaluation database running locally. 120 | If you don't, you can following the instructions in the README (Restoring to Postgres) to set it up. 121 | 122 | timeout: time in seconds to wait for query to finish before timing out 123 | """ 124 | try: 125 | engine = create_engine(db_url) 126 | escaped_query = re.sub( 127 | LIKE_PATTERN, escape_percent, query, flags=re.IGNORECASE 128 | ) # ignore case of LIKE 129 | results_df = func_timeout( 130 | timeout, pd.read_sql_query, args=(escaped_query, engine) 131 | ) 132 | engine.dispose() # type: ignore 133 | return results_df # type: ignore 134 | except Exception as e: 135 | if engine: # type: ignore 136 | engine.dispose() # type: ignore 137 | raise e 138 | 139 | 140 | def compare_df( 141 | df1: pd.DataFrame, df2: pd.DataFrame, query_category: str, question: str 142 | ) -> bool: 143 | """ 144 | Compares two dataframes and returns True if they are the same, else False. 
145 | """ 146 | # drop duplicates to ensure equivalence 147 | if df1.shape == df2.shape and (df1.values == df2.values).all(): 148 | return True 149 | 150 | df1 = normalize_table(df1, query_category, question) 151 | df2 = normalize_table(df2, query_category, question) 152 | 153 | if df1.shape == df2.shape and (df1.values == df2.values).all(): 154 | return True 155 | else: 156 | return False 157 | 158 | 159 | def subset_df( 160 | df_sub: pd.DataFrame, 161 | df_super: pd.DataFrame, 162 | query_category: str, 163 | question: str, 164 | verbose: bool = False, 165 | ) -> bool: 166 | """ 167 | Checks if df_sub is a subset of df_super 168 | """ 169 | if df_sub.empty: 170 | return False # handle cases for empty dataframes 171 | 172 | # make a copy of df_super so we don't modify the original while keeping track of matches 173 | df_super_temp = df_super.copy(deep=True) 174 | matched_columns = [] 175 | for col_sub_name in df_sub.columns: 176 | col_match = False 177 | for col_super_name in df_super_temp.columns: 178 | col_sub = df_sub[col_sub_name].sort_values().reset_index(drop=True) 179 | col_super = ( 180 | df_super_temp[col_super_name].sort_values().reset_index(drop=True) 181 | ) 182 | try: 183 | assert_series_equal( 184 | col_sub, col_super, check_dtype=False, check_names=False 185 | ) 186 | col_match = True 187 | matched_columns.append(col_super_name) 188 | # remove col_super_name to prevent us from matching it again 189 | df_super_temp = df_super_temp.drop(columns=[col_super_name]) 190 | break 191 | except AssertionError: 192 | continue 193 | if col_match is False: 194 | if verbose: 195 | logger.warning(f"no match for {col_sub_name}") 196 | return False 197 | df_sub_normalized = normalize_table(df_sub, query_category, question) 198 | 199 | # get matched columns from df_super, and rename them with columns from df_sub, then normalize 200 | df_super_matched = df_super[matched_columns].rename( 201 | columns=dict(zip(matched_columns, df_sub.columns)) 202 | ) 203 | df_super_matched = normalize_table(df_super_matched, query_category, question) 204 | 205 | try: 206 | assert_frame_equal(df_sub_normalized, df_super_matched, check_dtype=False) 207 | return True 208 | except AssertionError: 209 | return False 210 | 211 | 212 | def compare_query_results( 213 | query_gold: str, 214 | query_gen: str, 215 | db_url: str, 216 | question: str, 217 | timeout: float = 10.0, 218 | ) -> "tuple[bool, bool]": 219 | """ 220 | Compares the results of two queries and returns a tuple of booleans, where the first element is 221 | whether the queries produce exactly the same result, and the second element is whether the 222 | result of the gold query is a subset of the result of the generated query (still correct). 223 | We bubble up exceptions (mostly from query_postgres_db) to be handled in the runner. 
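    For example, (True, True) means the generated query reproduces the gold
    result exactly, while (False, True) means the gold result is only a subset
    of (but still contained in) the generated result.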
224 | """ 225 | # check if query contains "order by" 226 | query_category = "order_by" if "order by" in query_gold.lower() else "" 227 | queries_gold = get_all_minimal_queries(query_gold) 228 | results_gen = query_db(query_gen, db_url, timeout) 229 | correct = False 230 | for q in queries_gold: 231 | results_gold = query_db(q, db_url, timeout) 232 | if compare_df(results_gold, results_gen, query_category, question): 233 | return (True, True) 234 | elif subset_df(results_gold, results_gen, query_category, question): 235 | correct = True 236 | return (False, correct) 237 | -------------------------------------------------------------------------------- /src/langeval/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .types import Message # noqa 2 | from .llms import LLM # noqa 3 | from .embeddings import Embedding # noqa 4 | from .exception import ModelRunError # noqa 5 | -------------------------------------------------------------------------------- /src/langeval/models/embeddings.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy as np 4 | from numpy.linalg import norm 5 | 6 | try: 7 | import pydantic.v1 as pc 8 | except ImportError: 9 | import pydantic as pc 10 | 11 | from langeval.models.exception import ModelRunError 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class Embedding(pc.BaseModel): 17 | provider: str 18 | model: str 19 | # Model parameters, e.g. Qianfan has ak, sk 20 | kwargs: dict = {} 21 | 22 | @pc.validator("provider") 23 | def provider_must_be_valid(cls, v): # noqa: N805 24 | if v not in ["qianfan", "openai"]: 25 | raise ValueError(f"Invalid provider: {v}") 26 | return v 27 | 28 | def embedding(self, texts: list[str], timeout: int = 10) -> list[list[float]]: 29 | """Generate embeddings for texts""" 30 | if self.provider == "qianfan": 31 | # Cut for qianfan 384 tokens limit 32 | texts = [text[:384] for text in texts] 33 | import qianfan 34 | import qianfan.errors 35 | 36 | try: 37 | client = qianfan.Embedding(**self.kwargs) 38 | res = client.do(texts, request_timeout=float(timeout)) 39 | logger.debug(f"qianfan embedding: {texts}") 40 | if res.code != 200: # type: ignore # noqa: PLR2004 41 | raise ModelRunError(f"qianfan embedding failed: {res}") 42 | result = res.body # type: ignore 43 | if not result: 44 | raise ModelRunError(f"qianfan embedding failed: {res}") 45 | # type: ignore 46 | return [i.get("embedding", []) for i in result["data"]] 47 | except qianfan.errors.QianfanError as e: 48 | raise ModelRunError(f"qianfan embedding failed: {e.__class__.__name__}({e})") from e 49 | except Exception as e: 50 | logger.error(f"qianfan embedding failed: {e}", exc_info=True) 51 | raise ModelRunError(f"qianfan embedding failed: {e}") from e 52 | elif self.provider == "openai": 53 | try: 54 | import openai 55 | except ImportError as e: 56 | raise ValueError( 57 | "Could not import openai python package. Please install it with `pip install openai`." 
58 | ) from e 59 | try: 60 | response = openai.embeddings.create( 61 | model=self.model, input=texts, encoding_format="float", timeout=timeout, **self.kwargs 62 | ) 63 | logger.debug(f"openai embedding: {texts}") 64 | return [i.get("embedding", []) for i in response["data"]] # type: ignore 65 | except Exception as e: 66 | raise ModelRunError(f"openai call failed: {e.__class__.__name__}({e})") from e 67 | else: 68 | raise NotImplementedError() 69 | 70 | @staticmethod 71 | def cosine_similarity(vector1: list[float], vector2: list[float]) -> float: 72 | """Compute cosine similarity between two vectors""" 73 | return np.dot(vector1, vector2) / (norm(vector1) * norm(vector2)) 74 | -------------------------------------------------------------------------------- /src/langeval/models/exception.py: -------------------------------------------------------------------------------- 1 | class ModelRunError(RuntimeError): 2 | """LLM Call Exception""" 3 | 4 | pass 5 | -------------------------------------------------------------------------------- /src/langeval/models/llms.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Any 3 | 4 | try: 5 | import pydantic.v1 as pc 6 | except ImportError: 7 | import pydantic as pc 8 | 9 | from langeval.models.exception import ModelRunError 10 | from langeval.models.openai import OpenAI 11 | from langeval.models.qianfan import Qianfan 12 | from langeval.models.types import Message 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class LLM(pc.BaseModel): 18 | provider: str 19 | model: str 20 | chat: bool = True # Default to chat model 21 | # Model parameters, e.g. Qianfan has ak, sk, etc. 22 | kwargs: dict = {} 23 | instance: Any = None 24 | 25 | @pc.validator("provider") 26 | def provider_must_be_valid(cls, v): # noqa: N805 27 | if v not in ["qianfan", "openai", "langchain"]: 28 | raise ValueError(f"Invalid provider: {v}") 29 | return v 30 | 31 | def completion(self, prompt: str, timeout: int = 10) -> str: 32 | """Generate completion for prompt""" 33 | if self.provider == "qianfan": 34 | if self.instance is None: 35 | self.instance = Qianfan(self.model, self.chat) 36 | return self.instance.call(prompt, [], timeout, **self.kwargs) 37 | elif self.provider == "openai": 38 | if self.instance is None: 39 | self.instance = OpenAI(self.model, self.chat) 40 | return self.instance.call(prompt, [], timeout, **self.kwargs) 41 | elif self.provider == "langchain": 42 | try: 43 | from langchain.llms.loading import load_llm_from_config 44 | 45 | llm = load_llm_from_config( 46 | dict( 47 | _type=self.provider, 48 | model_name=self.model, 49 | **self.kwargs, 50 | ) 51 | ) 52 | except ImportError as e: 53 | raise ValueError( 54 | "Could not import langchain python package or llm package." 55 | "Please install it with `pip install langchain`." 
56 | ) from e 57 | try: 58 | response = llm.predict(prompt, request_timeout=float(timeout)) 59 | logger.debug(f"langchain completion: {prompt} -> {response}") 60 | return response 61 | except Exception as e: 62 | raise ModelRunError( 63 | f"langchain call failed: {e.__class__.__name__}({e})") from e 64 | else: 65 | raise ValueError(f"Invalid provider: {self.provider}") 66 | 67 | def chat_completion(self, messages: list[Message], timeout: int = 10) -> str: 68 | """Generate chat completion for messages""" 69 | if self.provider == "qianfan": 70 | if self.instance is None: 71 | self.instance = Qianfan(self.model, self.chat) 72 | return self.instance.call("", messages, timeout, **self.kwargs) 73 | elif self.provider == "openai": 74 | if self.instance is None: 75 | self.instance = OpenAI(self.model, self.chat) 76 | return self.instance.call("", messages, timeout, **self.kwargs) 77 | elif self.provider == "langchain": 78 | raise ValueError("langchain does not support chat_model load yet") 79 | else: 80 | raise ValueError(f"Invalid provider: {self.provider}") 81 | -------------------------------------------------------------------------------- /src/langeval/models/openai.py: -------------------------------------------------------------------------------- 1 | try: 2 | import openai 3 | except ImportError as e: 4 | raise ValueError("Could not import openai python package. Please install it with `pip install openai`.") from e 5 | 6 | import logging 7 | from typing import Any, List 8 | 9 | from langeval.models.exception import ModelRunError 10 | from langeval.models.types import Message 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class OpenAI: 16 | def __init__(self, model: str, chat: bool = True): 17 | self.model = model 18 | self.client = openai.Client() 19 | self.chat = chat 20 | 21 | def call(self, prompt: str, messages: List[Message], timeout: int, **kwargs: Any) -> str: 22 | try: 23 | if self.chat and prompt: 24 | # When chat model use prompt, then convert it to messages 25 | messages = [Message(role="user", content=prompt)] 26 | if messages: 27 | response = self.client.chat.completions.create( 28 | model=self.model, 29 | messages=[{"role": message.role, "content": message.content} for message in messages], # type: ignore 30 | timeout=float(timeout), 31 | **kwargs, 32 | ) 33 | logger.debug(f"openai completion: {messages} -> {response}") 34 | return response.choices[0].message.content 35 | else: 36 | kwargs = kwargs.copy() 37 | if "max_tokens" not in kwargs: 38 | # Default to 1024 tokens 39 | kwargs["max_tokens"] = 1024 40 | response = self.client.completions.create( 41 | model=self.model, 42 | prompt=prompt, 43 | timeout=float(timeout), 44 | **kwargs, 45 | ) 46 | logger.debug(f"openai completion: {prompt} -> {response}") 47 | return response.choices[0].text 48 | except openai.OpenAIError as e: 49 | raise ModelRunError(f"openai call failed: {e.__class__.__name__}({e})") from e 50 | -------------------------------------------------------------------------------- /src/langeval/models/qianfan.py: -------------------------------------------------------------------------------- 1 | try: 2 | import qianfan 3 | import qianfan.errors 4 | except ImportError as e: 5 | raise ValueError( 6 | "Could not import qianfan python package. Please install it with `pip install qianfan`." 
7 | ) from e 8 | 9 | from typing import Any, List 10 | 11 | from langeval.models.exception import ModelRunError 12 | from langeval.models.types import Message 13 | 14 | 15 | class Qianfan: 16 | def __init__(self, model: str, chat: bool = True): 17 | self.model = model 18 | self.chat = chat 19 | if chat: 20 | if model.startswith("endpoint/"): 21 | self.client = qianfan.ChatCompletion(endpoint=model.split("/")[-1]) 22 | else: 23 | self.client = qianfan.ChatCompletion(model=model) 24 | else: 25 | if model.startswith("endpoint/"): 26 | self.client = qianfan.Completion(endpoint=model.split("/")[-1]) 27 | else: 28 | self.client = qianfan.Completion(model=model) 29 | 30 | def call(self, prompt: str, messages: List[Message], timeout: int, **kwargs: Any) -> str: 31 | try: 32 | if prompt: 33 | messages_converted = [{"role": "user", "content": prompt}] 34 | else: 35 | system = "" 36 | messages_converted = [] 37 | for message in messages: 38 | if message.role == "system": 39 | system = message.content 40 | continue 41 | messages_converted.append({"role": message.role, "content": message.content}) 42 | if system: 43 | kwargs["system"] = system 44 | if isinstance(self.client, qianfan.ChatCompletion): 45 | res = self.client.do( 46 | messages_converted, request_timeout=float(timeout), **kwargs) 47 | else: 48 | res = self.client.do( 49 | prompt, request_timeout=float(timeout), **kwargs) 50 | if res.code != 200: # type: ignore # noqa: PLR2004 51 | raise ModelRunError(f"qianfan call failed: {res}") 52 | result = res.body.get("result", None) # type: ignore 53 | if not result: 54 | raise ModelRunError(f"qianfan call failed: {res}") 55 | return result 56 | except qianfan.errors.QianfanError as e: 57 | raise ModelRunError(f"qianfan call failed: {e.__class__.__name__}({e})") from e 58 | -------------------------------------------------------------------------------- /src/langeval/models/types.py: -------------------------------------------------------------------------------- 1 | try: 2 | import pydantic.v1 as pc 3 | except ImportError: 4 | import pydantic as pc 5 | 6 | class Message(pc.BaseModel): 7 | """ChatCompletion message""" 8 | 9 | role: str 10 | content: str 11 | -------------------------------------------------------------------------------- /src/langeval/providers/__init__.py: -------------------------------------------------------------------------------- 1 | from .provider import Provider # noqa 2 | from .exception import ProviderRunError # noqa 3 | -------------------------------------------------------------------------------- /src/langeval/providers/exception.py: -------------------------------------------------------------------------------- 1 | class ProviderRunError(RuntimeError): 2 | """Provider run error""" 3 | 4 | pass 5 | -------------------------------------------------------------------------------- /src/langeval/providers/output_parser.py: -------------------------------------------------------------------------------- 1 | """From: https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/output_parsers/json.py 2 | """ 3 | import json 4 | import re 5 | from typing import Any, Callable, Dict 6 | 7 | 8 | class OutputParserError(ValueError): 9 | """输出解析异常""" 10 | 11 | pass 12 | 13 | 14 | # Adapted from https://github.com/KillianLucas/open-interpreter/blob/main/interpreter/utils/parse_partial_json.py 15 | # MIT License 16 | def parse_partial_json(s: str, *, strict: bool = False) -> Any: 17 | """Parse a JSON string that may be missing closing braces. 
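    Unterminated strings are closed and any unbalanced braces or brackets are
    appended before re-parsing; None is returned if the input still cannot be
    decoded as JSON.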
18 | 19 | Args: 20 | s: The JSON string to parse. 21 | strict: Whether to use strict parsing. Defaults to False. 22 | 23 | Returns: 24 | The parsed JSON object as a Python dictionary. 25 | """ 26 | # Attempt to parse the string as-is. 27 | try: 28 | return json.loads(s, strict=strict) 29 | except json.JSONDecodeError: 30 | pass 31 | 32 | # Initialize variables. 33 | new_s = "" 34 | stack = [] 35 | is_inside_string = False 36 | escaped = False 37 | 38 | # Process each character in the string one at a time. 39 | for char in s: 40 | if is_inside_string: 41 | if char == '"' and not escaped: 42 | is_inside_string = False 43 | elif char == "\n" and not escaped: 44 | char = "\\n" # Replace the newline character with the escape sequence. # noqa: PLW2901 45 | elif char == "\\": 46 | escaped = not escaped 47 | else: 48 | escaped = False 49 | elif char == '"': 50 | is_inside_string = True 51 | escaped = False 52 | elif char == "{": 53 | stack.append("}") 54 | elif char == "[": 55 | stack.append("]") 56 | elif char in ("}", "]"): 57 | if stack and stack[-1] == char: 58 | stack.pop() 59 | else: 60 | # Mismatched closing character; the input is malformed. 61 | return None 62 | 63 | # Append the processed character to the new string. 64 | new_s += char 65 | 66 | # If we're still inside a string at the end of processing, 67 | # we need to close the string. 68 | if is_inside_string: 69 | new_s += '"' 70 | 71 | # Close any remaining open structures in the reverse order that they were opened. 72 | for closing_char in reversed(stack): 73 | new_s += closing_char 74 | 75 | # Attempt to parse the modified string as JSON. 76 | try: 77 | return json.loads(new_s, strict=strict) 78 | except json.JSONDecodeError: 79 | # If we still can't parse the string as JSON, return None to indicate failure. 80 | return None 81 | 82 | 83 | def _replace_new_line(match: re.Match[str]) -> str: 84 | value = match.group(2) 85 | value = re.sub(r"\n", r"\\n", value) 86 | value = re.sub(r"\r", r"\\r", value) 87 | value = re.sub(r"\t", r"\\t", value) 88 | value = re.sub(r'(? str: 94 | """ 95 | The LLM response for `action_input` may be a multiline 96 | string containing unescaped newlines, tabs or quotes. This function 97 | replaces those characters with their escaped counterparts. 98 | (newlines in JSON must be double-escaped: `\\n`) 99 | """ 100 | if isinstance(multiline_string, (bytes, bytearray)): 101 | multiline_string = multiline_string.decode() 102 | 103 | multiline_string = re.sub( 104 | r'("action_input"\:\s*")(.*)(")', 105 | _replace_new_line, 106 | multiline_string, 107 | flags=re.DOTALL, 108 | ) 109 | 110 | return multiline_string 111 | 112 | 113 | def parse_json_markdown(json_string: str, *, parser: Callable[[str], Any] = json.loads) -> Any: 114 | """ 115 | Parse a JSON string from a Markdown string. 116 | 117 | Args: 118 | json_string: The Markdown string. 119 | 120 | Returns: 121 | The parsed JSON object as a Python dictionary. 
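    Content wrapped in triple backticks (with an optional json tag) is
    extracted first; otherwise the whole string is treated as JSON.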
122 | """ 123 | # Try to find JSON string within triple backticks 124 | match = re.search(r"```(json)?(.*)```", json_string, re.DOTALL) 125 | 126 | # If no match found, assume the entire string is a JSON string 127 | if match is None: 128 | json_str = json_string 129 | else: 130 | # If match found, use the content within the backticks 131 | json_str = match.group(2) 132 | 133 | # Strip whitespace and newlines from the start and end 134 | json_str = json_str.strip() 135 | 136 | # handle newlines and other special characters inside the returned value 137 | json_str = _custom_parser(json_str) 138 | 139 | # Parse the JSON string into a Python dictionary 140 | parsed = parser(json_str) 141 | 142 | return parsed 143 | 144 | 145 | class SimpleJsonOutputParser: 146 | """Parse the output of an LLM call to a JSON object.""" 147 | 148 | def parse(self, text: str) -> Any: 149 | text = text.strip() 150 | try: 151 | return parse_json_markdown(text.strip(), parser=parse_partial_json) 152 | except json.JSONDecodeError as e: 153 | raise OutputParserError(f"Invalid json output: {text}") from e 154 | 155 | @property 156 | def _type(self) -> str: 157 | return "simple_json_output_parser" 158 | 159 | 160 | class JsonListParser: 161 | """Parse the output of an LLM call to a list of JSON objects.""" 162 | 163 | def parse(self, text: str) -> Any: 164 | text = text.strip() 165 | try: 166 | ret = parse_json_markdown(text.strip(), parser=parse_partial_json) 167 | except json.JSONDecodeError as e: 168 | raise OutputParserError(f"Invalid json listoutput: {text}") from e 169 | 170 | if not isinstance(ret, list): 171 | raise OutputParserError(f"Invalid json list output: {text}") 172 | 173 | return ret 174 | 175 | @property 176 | def _type(self) -> str: 177 | return "json_list_parser" 178 | 179 | class SQLParser: 180 | """Parse the sql output of an LLM call. 181 | 182 | sql output is expected to be in the format: 183 | --- 184 | xxx 185 | ```sql 186 | 187 | ``` 188 | xxx 189 | --- 190 | 191 | Sometime the output is in the format: 192 | --- 193 | 194 | ``` 195 | xxxxxx 196 | --- 197 | """ 198 | 199 | @property 200 | def _type(self) -> str: 201 | """Return the type key.""" 202 | return "sql_parser" 203 | 204 | def parse(self, text: str) -> Dict[str, str]: 205 | """Parse the output of an LLM call. 
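        The query is taken from the first ```sql fenced block when one is
        present; otherwise the text up to the first ``` fence (or the whole
        text) is used.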
206 | 207 | Args: 208 | text (str): output of an LLM call 209 | 210 | Returns: 211 | Dict[str, str]: a dict of parsed output 212 | - sql: the sql query 213 | - text: the original output text 214 | 215 | """ 216 | origin_text = text 217 | sql_split = text.split('```sql') 218 | if len(sql_split) > 1: 219 | text = sql_split[1] 220 | try: 221 | sql = text.split('```')[0] 222 | except IndexError: 223 | raise OutputParserError(f"Invalid sql output: {text}") 224 | return {"sql": sql, "text": origin_text} 225 | -------------------------------------------------------------------------------- /src/langeval/providers/provider.py: -------------------------------------------------------------------------------- 1 | import enum 2 | import logging 3 | from typing import Any, List, Optional, Union 4 | 5 | import yaml 6 | 7 | try: 8 | import pydantic.v1 as pc 9 | except ImportError: 10 | import pydantic as pc 11 | 12 | 13 | from langeval.models import LLM, Message 14 | from langeval.providers.exception import ProviderRunError 15 | from langeval.providers.output_parser import JsonListParser, OutputParserError, SimpleJsonOutputParser, SQLParser 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | class ProviderType(str, enum.Enum): 21 | """Provider Type""" 22 | 23 | Completion = "completion" 24 | ChatCompletion = "chat_completion" 25 | Execute = "execute" 26 | 27 | 28 | class ExecSettings(pc.BaseModel): 29 | """Exec settings""" 30 | 31 | command: str = pc.Field(..., min_length=1, max_length=1024) 32 | kwargs: dict = {} 33 | 34 | 35 | class CompletionSettings(pc.BaseModel): 36 | """Completion settings""" 37 | 38 | llm: LLM 39 | prompt: str 40 | 41 | 42 | class ChatCompletionSettings(pc.BaseModel): 43 | """ChatCompletion settings""" 44 | 45 | llm: LLM 46 | messages: List[Message] 47 | 48 | 49 | class OutputParser(pc.BaseModel): 50 | """Output Parser""" 51 | 52 | name: str = pc.Field(default="text") 53 | # json parser: 54 | # output_keys defines the keys to be extracted from the json output 55 | # if not defined, all keys will be extracted 56 | # match parser: 57 | # match_key: the key to be matched 58 | # match_re: the regex to be matched 59 | kwargs: dict = {} 60 | 61 | def parse(self, text: str) -> dict[str, Any]: 62 | if self.name == "text": 63 | return {"text": text} 64 | elif self.name == "sql": 65 | try: 66 | resp = SQLParser().parse(text) 67 | if not resp: 68 | raise ProviderRunError( 69 | f"output parser sql failed {text} -> {resp}") 70 | except OutputParserError as e: 71 | raise ProviderRunError( 72 | f"output parser sql failed {text}: {e}") from e 73 | return resp 74 | elif self.name == "json": 75 | try: 76 | resp = SimpleJsonOutputParser().parse(text) 77 | if not resp: 78 | raise ProviderRunError( 79 | f"output parser failed {text} -> {resp}") 80 | except OutputParserError as e: 81 | raise ProviderRunError( 82 | f"output parser failed {text}: {e}") from e 83 | keys = self.kwargs.get("output_keys", None) 84 | final_resp = {} 85 | if keys is not None: 86 | for key in keys: 87 | if key not in resp: 88 | raise ProviderRunError( 89 | f"output parser failed lack keys: {text} -> {resp}") 90 | final_resp[key] = resp[key] 91 | else: 92 | final_resp = resp.copy() 93 | final_resp["_text"] = text 94 | return final_resp 95 | elif self.name == "json_list": 96 | try: 97 | resp = JsonListParser().parse(text) 98 | except OutputParserError as e: 99 | raise ProviderRunError( 100 | f"output parser failed {text}: {e}") from e 101 | final_resp = { 102 | "list": resp, 103 | "_text": text, 104 | } 105 | return 
final_resp 106 | elif self.name == "match": 107 | match_key = self.kwargs.get("match_key", None) 108 | match_re = self.kwargs.get("match_re", None) 109 | if not match_key or not match_re: 110 | raise ProviderRunError( 111 | f"Invalid output parser: {self.name} kwargs: {self.kwargs}") 112 | import re 113 | matchs = re.findall(re.compile(match_re.strip()), text) 114 | if not matchs: 115 | raise ProviderRunError( 116 | f"output parser failed: {text} match '{match_re.strip()}' failed") 117 | # only match last element 118 | logger.debug(f"match_re: {match_re}, matchs: {matchs}") 119 | return {match_key: matchs[-1], "_text": text} 120 | else: 121 | raise ProviderRunError(f"Invalid output parser: {self.name}") 122 | 123 | 124 | class Provider(pc.BaseModel): 125 | """Provider Config""" 126 | 127 | # provider types, completion, chat_completion, execute 128 | type: ProviderType # noqa: A003 129 | input_variables: list[str] 130 | settings: Union[ExecSettings, CompletionSettings, ChatCompletionSettings] 131 | output_parser: OutputParser 132 | 133 | class Config: 134 | validate_assignment = True 135 | 136 | @pc.validator("type") 137 | def type_must_be_valid(cls, v): # noqa: N805 138 | if v not in [ProviderType.Completion, ProviderType.ChatCompletion, ProviderType.Execute]: 139 | raise ValueError( 140 | "type must be one of completion, chat_completion, execute") 141 | return v 142 | 143 | @classmethod 144 | def from_yaml(cls, yaml_str: str) -> Optional["Provider"]: 145 | if yaml_str == "": 146 | return None 147 | try: 148 | return cls(**yaml.safe_load(yaml_str)) 149 | except Exception as e: 150 | raise ValueError(f"Invalid yaml: {e}") from e 151 | 152 | def batch_call(self, inputs_list: list[dict[str, Any]], timeout: int): 153 | from langeval.providers.run import batch_call_exec, call_chat_completion, call_completion 154 | 155 | for key in self.input_variables: 156 | for inputs in inputs_list: 157 | if key not in inputs: 158 | raise ProviderRunError(f"Missing input variable: {key}") 159 | if self.type == "completion": 160 | return [call_completion(self, inputs, timeout) for inputs in inputs_list] 161 | elif self.type == "chat_completion": 162 | return [call_chat_completion(self, inputs, timeout) for inputs in inputs_list] 163 | elif self.type == "execute": 164 | return batch_call_exec(self, inputs_list, timeout) 165 | else: 166 | raise ProviderRunError(f"Invalid type: {self.type}") 167 | -------------------------------------------------------------------------------- /src/langeval/providers/run.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import subprocess 5 | from typing import Any 6 | 7 | import jinja2 8 | 9 | from langeval.models import ModelRunError 10 | from langeval.providers.exception import ProviderRunError 11 | from langeval.providers.provider import ChatCompletionSettings, CompletionSettings, ExecSettings, Provider 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | def call_completion(conf: Provider, inputs: dict[str, Any], timeout: int) -> dict[str, str]: 17 | if conf.settings is None or not isinstance(conf.settings, CompletionSettings): 18 | raise ProviderRunError( 19 | f"call_completion invalid provider config: {conf}") 20 | 21 | prompt = jinja2.Template(conf.settings.prompt).render(**inputs) 22 | logger.info(f"call_completion: prompt={prompt}") 23 | 24 | try: 25 | text = conf.settings.llm.completion(prompt, timeout=timeout) 26 | except ModelRunError as e: 27 | raise 
ProviderRunError(f"call_completion failed: {e}") from e 28 | except Exception as e: 29 | raise ProviderRunError(f"call_completion failed: {e}") from e 30 | 31 | return conf.output_parser.parse(text) 32 | 33 | 34 | def call_chat_completion(conf: Provider, inputs: dict[str, Any], timeout: int): 35 | if conf.settings is None or not isinstance(conf.settings, ChatCompletionSettings): 36 | raise ProviderRunError( 37 | f"call_chat_completion invalid provider config: {conf}") 38 | for message in conf.settings.messages: 39 | message.content = jinja2.Template(message.content).render(**inputs) 40 | logger.info(f"call_chat_completion: messages={conf.settings.messages}") 41 | 42 | try: 43 | text = conf.settings.llm.chat_completion(conf.settings.messages, timeout=timeout) 44 | except ModelRunError as e: 45 | raise ProviderRunError(f"call_chat_completion failed: {e}") from e 46 | except Exception as e: 47 | raise ProviderRunError(f"call_chat_completion failed: {e}") from e 48 | 49 | return conf.output_parser.parse(text) 50 | 51 | 52 | def batch_call_exec(conf: Provider, inputs_list: list[dict[str, Any]], timeout: int) -> list[dict[str, Any]]: 53 | if conf.settings is None or not isinstance(conf.settings, ExecSettings): 54 | raise ProviderRunError(f"call_exec invalid provider config: {conf}") 55 | command = conf.settings.command 56 | kwargs = conf.settings.kwargs or {} 57 | # Copy the current process environment. 58 | env = os.environ.copy() 59 | if kwargs.get("env"): 60 | env.update(kwargs["env"]) 61 | cwd = kwargs.get("cwd") or None 62 | exec_timeout = int(kwargs.get("timeout", 300)) # type: ignore 63 | timeout = min(timeout, exec_timeout) 64 | input_data = json.dumps(inputs_list, ensure_ascii=False) 65 | 66 | try: 67 | cp = subprocess.run( 68 | command, 69 | shell=True, # noqa: S602 70 | check=True, 71 | encoding="utf-8", 72 | env=env, 73 | cwd=cwd, 74 | input=input_data, 75 | stdout=subprocess.PIPE, 76 | ) 77 | except subprocess.CalledProcessError as e: 78 | raise ProviderRunError(f"call_exec failed: {e}") from e 79 | except Exception as e: 80 | raise ProviderRunError(f"call_exec failed: {e}") from e 81 | 82 | try: 83 | # stdout is expected to be a JSON list of output strings 84 | texts = json.loads(cp.stdout) 85 | except json.JSONDecodeError as e: 86 | raise ProviderRunError(f"call_exec output parser failed: {e}") from e 87 | 88 | return [conf.output_parser.parse(text) for text in texts] 89 | -------------------------------------------------------------------------------- /src/langeval/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from .task import EvalTask, Result # noqa 2 | from .runner import TaskRunner, TaskRunnerStatus # noqa 3 | -------------------------------------------------------------------------------- /src/langeval/tasks/ratelimiter.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import time 3 | from types import TracebackType 4 | from typing import Optional, Type 5 | 6 | 7 | class ThreadingRateLimiter: 8 | def __init__(self, query_per_second: float): 9 | """Thread-safe rate limiter 10 | 11 | Args: 12 | query_per_second (float): allowed queries per second; 0 disables rate limiting 13 | """ 14 | self._query_per_second = query_per_second 15 | if query_per_second < 0: 16 | raise ValueError("query_per_second must be non-negative") 17 | if query_per_second == 0: 18 | # No rate limit 19 | return 20 | if query_per_second > 1: 21 | self._query_per_period = query_per_second 22 | else: 23 | self._query_per_period = 1 24 | self._token_count = self._query_per_period 25 |
self._last_leak_timestamp = time.perf_counter() 26 | self._sync_lock = threading.Lock() 27 | 28 | def _leak(self) -> None: 29 | timestamp = time.perf_counter() 30 | delta = timestamp - self._last_leak_timestamp 31 | self._last_leak_timestamp = timestamp 32 | self._token_count = min( 33 | self._query_per_period, 34 | self._token_count + delta * self._query_per_second, 35 | ) 36 | 37 | def __enter__(self) -> None: 38 | if self._query_per_second == 0: 39 | return 40 | with self._sync_lock: 41 | while True: 42 | self._leak() 43 | if self._token_count >= 1: 44 | self._token_count -= 1 45 | return 46 | time.sleep((1 - self._token_count) / self._query_per_second) 47 | 48 | def __exit__( 49 | self, 50 | exc_type: Optional[Type[BaseException]], 51 | exc_val: Optional[BaseException], 52 | exc_tb: Optional[TracebackType], 53 | ) -> None: 54 | """ 55 | exit 56 | """ 57 | return 58 | -------------------------------------------------------------------------------- /src/langeval/tasks/runner.py: -------------------------------------------------------------------------------- 1 | import enum 2 | import json 3 | import logging 4 | import random 5 | import threading 6 | import time 7 | from datetime import datetime 8 | from typing import Optional, Tuple 9 | 10 | import pandas as pd 11 | 12 | try: 13 | import pydantic.v1 as pc 14 | except ImportError: 15 | import pydantic as pc 16 | 17 | from langeval.models.llms import LLM 18 | from langeval.tasks.task import EvalTask, Result 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | class TaskRunnerStatus(str, enum.Enum): 24 | PENDING = "PENDING" 25 | RUNNING = "RUNNING" 26 | FINISHED = "FINISHED" 27 | CANCELLED = "CANCELLED" 28 | FAILED = "FAILED" 29 | 30 | class Progress(pc.BaseModel): 31 | total: int = 0 32 | finished: int = 0 33 | failed: int = 0 34 | 35 | class TaskProgress(pc.BaseModel): 36 | run: Progress 37 | evals: dict[str, Progress] 38 | 39 | 40 | class TaskRunner: 41 | """ 42 | Run task in background thread 43 | """ 44 | 45 | def __init__( 46 | self, 47 | uuid: str, 48 | task: EvalTask, 49 | sample: int = 0, 50 | sample_seed: int = 42, 51 | default_eval_llm: Optional[LLM] = None, 52 | status_callback = None, 53 | log_callback = None, 54 | progress_callback = None, 55 | ) -> None: 56 | self.task = task 57 | self.uuid = uuid 58 | self.sample = sample 59 | self.sample_seed = sample_seed 60 | self.default_eval_llm = default_eval_llm 61 | 62 | # callback for status, log, progress updated. 
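# Expected callback signatures, inferred from how they are invoked below:
#   status_callback(uuid: str, status: TaskRunnerStatus)
#   log_callback(uuid: str, log_entry: str)
#   progress_callback(uuid: str, progress: TaskProgress, results: list[Result])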
63 | self.status_callback = status_callback 64 | self.log_callback = log_callback 65 | self.progress_callback = progress_callback 66 | 67 | # status with lock 68 | self._status_lock = threading.Lock() 69 | self._status = TaskRunnerStatus.PENDING 70 | 71 | self.thread = None 72 | self.finished_time = None 73 | self.cancel_event = threading.Event() 74 | 75 | self.progress: TaskProgress = TaskProgress(run=Progress(), evals={}) 76 | self.results: list[Result] = [] 77 | 78 | @property 79 | def status(self) -> TaskRunnerStatus: 80 | with self._status_lock: 81 | return self._status 82 | 83 | def status_json(self): 84 | return json.dumps( 85 | { 86 | "uuid": self.uuid, 87 | "status": self.status.value, 88 | "progress": self.progress.dict(), 89 | "finished_time": self.finished_time, 90 | "sample": self.sample, 91 | "sample_seed": self.sample_seed, 92 | } 93 | ) 94 | 95 | def set_status(self, status: TaskRunnerStatus): 96 | with self._status_lock: 97 | self._status = status 98 | if self.status_callback: 99 | self.status_callback(self.uuid, status) 100 | 101 | def update_task_log(self, log: str): 102 | logger.info(f"[task-{self.uuid}]{log}") 103 | log_entry = f"[{datetime.utcnow().isoformat()}][task-{self.uuid}]{log}\n" 104 | if self.log_callback: 105 | self.log_callback(self.uuid, log_entry) 106 | 107 | def update_task_progress(self, progress: TaskProgress, results: list[Result]): 108 | self.progress = progress 109 | if self.progress_callback: 110 | self.progress_callback(self.uuid, progress, results) 111 | 112 | def start(self): 113 | """ 114 | Start the task runner. 115 | """ 116 | self.update_task_log("[runner.start] start task runner") 117 | self.set_status(TaskRunnerStatus.RUNNING) 118 | try: 119 | self.thread = threading.Thread(target=self._run, daemon=True) 120 | self.thread.start() 121 | except Exception as e: 122 | logger.error(f"[task-{self.uuid}][runner.start] failed to start task runner: {e}") 123 | self.update_task_log(f"[runner.start] failed to start task runner: {e}") 124 | self.set_status(TaskRunnerStatus.FAILED) 125 | 126 | def join(self, timeout=None): 127 | """ 128 | Wait for the task run to finish. 129 | """ 130 | if self.thread: 131 | self.thread.join(timeout=timeout) 132 | 133 | def cancel(self): 134 | """ 135 | Cancel the task. 136 | """ 137 | self.update_task_log("[runner.cancel] cancel task runner") 138 | self.set_status(TaskRunnerStatus.CANCELLED) 139 | self.cancel_event.set() 140 | 141 | def _run(self) -> None: 142 | """ 143 | Run the task. 144 | """ 145 | try: 146 | self.update_task_log(f"[runner._run] start task run: {self.task.run_config}") 147 | 148 | data = self.task.split_dataset() 149 | data_lists = [] 150 | for d in data: 151 | data_lists.extend([d] * self.task.run_config.rounds) 152 | 153 | self.update_task_log( 154 | f"[runner._run] task run: {len(data)} * {self.task.run_config.rounds} = {len(data_lists)}" 155 | ) 156 | if self.sample and len(data_lists) > 0: 157 | self.update_task_log(f"[runner._run] task sampled down to {self.sample} rows.") 158 | data_lists = random.Random(self.sample_seed).sample(data_lists, self.sample) 159 | total = len(data_lists) 160 | if self.results: 161 | self.update_task_log(f"[runner._run] task resume from {len(self.results)} results.") 162 | else: 163 | self.results = [Result(inputs=d) for d in data_lists] 164 | if self.task.provider is not None: 165 | self.update_task_log("[runner._run] provider start run.") 166 | # keep already-finished results; re-run failed or empty ones 167 | new_results = [] 168 | need_run = [] 169 | for result in self.results: 170 | if result.run.error or not result.run.outputs: 171 | need_run.append(result) 172 | else: 173
| new_results.append(result) 174 | self.update_task_log(f"[runner._run] provider resume from {len(new_results)} results.") 175 | progress = TaskProgress(run=Progress(total=total, finished=len(new_results)), evals={}) 176 | 177 | for result in self.task.run_provider(need_run, self.cancel_event): 178 | if result.run.error: 179 | progress.run.failed += 1 180 | else: 181 | progress.run.finished += 1 182 | self.update_task_log(f"[runner._run] task progress {progress}, result: {result}") 183 | self.update_task_progress(progress, [result]) 184 | new_results.append(result) 185 | 186 | # Check if task be cancelled 187 | if self.cancel_event.is_set(): 188 | logger.warn(f"[task-{self.uuid}][runner._run] task be cancelled") 189 | self.set_status(TaskRunnerStatus.CANCELLED) 190 | self.update_task_log("[runner._run] task cancelled") 191 | return 192 | self.results = new_results 193 | logger.info(f"[task-{self.uuid}][runner._run] end task run") 194 | if progress.run.failed == progress.run.total: 195 | self.finished_time = time.time() 196 | self.set_status(TaskRunnerStatus.FAILED) 197 | self.update_task_log("[runner._run] task failed because all run failed") 198 | return 199 | 200 | progress = self.progress 201 | for evaluator in self.task.evaluators: 202 | self.update_task_log(f"[runner._run] evaluator {evaluator.name} start run.") 203 | # get finished result 204 | new_results = [] 205 | need_run = [] 206 | for result in self.results: 207 | eval_result = result.evals.get(evaluator.name) 208 | if not eval_result or eval_result.error or not eval_result.outputs: 209 | need_run.append(result) 210 | else: 211 | new_results.append(result) 212 | self.update_task_log(f"[runner._run] evaluator {evaluator.name} resume from {len(new_results)} results, need run {len(need_run)} results.") 213 | progress.evals[evaluator.name] = Progress(total=total, finished=len(new_results)) 214 | for result in self.task.run_eval(evaluator, need_run, self.cancel_event, 215 | default_eval_llm=self.default_eval_llm): 216 | if result.evals[evaluator.name].error: 217 | progress.evals[evaluator.name].failed += 1 218 | else: 219 | progress.evals[evaluator.name].finished += 1 220 | self.update_task_log( 221 | f"[runner._run] task eval {evaluator.name} progress " 222 | f"{progress.evals[evaluator.name]}, result: {result}") 223 | self.update_task_progress(progress, [result]) 224 | new_results.append(result) 225 | 226 | # Check if task be cancelled 227 | if self.cancel_event.is_set(): 228 | logger.warn(f"[task-{self.uuid}][runner._run] task be cancelled") 229 | self.set_status(TaskRunnerStatus.CANCELLED) 230 | self.update_task_log("[runner._run] task cancelled") 231 | return 232 | logger.info(f"[task-{self.uuid}][runner._run] end task eval") 233 | self.results = new_results 234 | 235 | except Exception as e: 236 | logger.error(f"[task-{self.uuid}][runner._run] failed to run task : {e}", exc_info=True) 237 | self.set_status(TaskRunnerStatus.FAILED) 238 | self.update_task_log(f"[runner._run] failed to run task : {e}") 239 | 240 | def statistic(self) -> Tuple[pd.DataFrame, pd.DataFrame]: 241 | df = pd.DataFrame([i.dict() for i in self.results]) 242 | total_count = len(df) 243 | run_success_rate = 0 244 | run_success_count = 0 245 | run_average_time = 0 246 | eval_success_rate = 0 247 | eval_average_time = 0 248 | if "run" in df.columns: 249 | run_success_rate = df["run"].apply(lambda x: x["error"] == "").mean() 250 | run_success_count = df["run"].apply(lambda x: x["error"] == "").sum() 251 | run_average_time = df["run"].apply(lambda x: 
x["elapsed_secs"]).mean() 252 | if "evals" in df.columns: 253 | eval_success_rate = df["evals"].apply(lambda x: all(e["error"] == "" for e in x.values())).mean() 254 | eval_average_time = df["evals"].apply(lambda x: sum(e["elapsed_secs"] for e in x.values())).mean() 255 | running_stats = pd.DataFrame( 256 | [ 257 | { 258 | "Total count": f"{total_count}", 259 | "Run success rate": f"{run_success_rate:.2%}", 260 | "Run success count": f"{run_success_count}", 261 | "Run average secs": f"{run_average_time:.2f}", 262 | "Eval success rate": f"{eval_success_rate:.2%}", 263 | "Eval average secs": f"{eval_average_time:.2f}", 264 | } 265 | ] 266 | ) 267 | eval_stats = pd.DataFrame() 268 | 269 | # "evals": {"exact_match": { 270 | # "error": "", "outputs": {"exact_match": 1.0}, "elapsed_secs": 5.728999894927256e-06}} 271 | def flatten_outputs(data_row): 272 | """ 273 | Extract and flatten the content under the 'outputs' key of each nested dict. 274 | :param data_row: a data row containing nested dicts. 275 | :return: the flattened 'outputs' dict. 276 | """ 277 | flattened = {} 278 | for key, value in data_row.items(): 279 | if "outputs" in value: 280 | for output_key, output_value in value["outputs"].items(): 281 | flattened_key = f"{key}.outputs.{output_key}" 282 | flattened[flattened_key] = output_value 283 | return flattened 284 | if "evals" in df.columns: 285 | flattened_evals = df["evals"].apply(flatten_outputs) 286 | flattened_df = pd.DataFrame(flattened_evals.tolist()) 287 | flattened_df.fillna(0.0, inplace=True) 288 | 289 | if not flattened_df.empty: 290 | eval_stats = pd.DataFrame(flattened_df).describe().T 291 | return running_stats, eval_stats 292 | -------------------------------------------------------------------------------- /src/langeval/tasks/task.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import io 3 | import json 4 | import logging 5 | import os 6 | import threading 7 | from concurrent.futures import ThreadPoolExecutor, as_completed 8 | from time import perf_counter 9 | from typing import Any, List, Optional 10 | 11 | import pandas as pd 12 | import yaml 13 | 14 | try: 15 | import pydantic.v1 as pc 16 | except ImportError: 17 | import pydantic as pc 18 | 19 | from langeval.evaluators import Evaluator 20 | from langeval.models import LLM 21 | from langeval.providers import Provider 22 | from langeval.tasks.ratelimiter import ThreadingRateLimiter 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | class RunResult(pc.BaseModel): 27 | """Result for one data run""" 28 | error: str = "" 29 | outputs: dict[str, Any] = {} 30 | elapsed_secs: float = 0 31 | 32 | class EvalResult(pc.BaseModel): 33 | """Result for one data eval""" 34 | error: str = "" 35 | outputs: dict[str, Any] = {} 36 | elapsed_secs: float = 0 37 | 38 | 39 | class Result(pc.BaseModel): 40 | # uuid: str 41 | inputs: dict[str, Any] 42 | run: RunResult = RunResult() 43 | evals: dict[str, EvalResult] = {} 44 | 45 | def to_jsonl(self) -> str: 46 | return json.dumps(self.dict(), ensure_ascii=False) + "\n" 47 | 48 | @classmethod 49 | def from_json(cls, json_str: str) -> "Result": 50 | return cls(**json.loads(json_str)) 51 | 52 | 53 | class TaskRunConfig(pc.BaseModel): 54 | parallelism: int = pc.Field(default=1, ge=1, le=30) 55 | timeout: int = pc.Field(default=30, ge=1, le=600) 56 | rounds: int = pc.Field(default=1, ge=1, le=10) 57 | batch_size: int = pc.Field(default=1, ge=1, le=10000) 58 | query_per_second: float = pc.Field(default=0, ge=0, le=100) 59 | 60 | def to_yaml(self) -> str: 61 | return yaml.safe_dump(self.dict(exclude_unset=True),
encoding="utf-8", allow_unicode=True).decode("utf-8") 62 | 63 | @staticmethod 64 | def from_yaml(yaml_str: str) -> "TaskRunConfig": 65 | try: 66 | return TaskRunConfig(**yaml.safe_load(yaml_str)) 67 | except Exception as e: 68 | raise ValueError(f"Invalid yaml: {e}") from e 69 | 70 | 71 | class EvalTask(pc.BaseModel): 72 | # config 73 | provider: Optional[Provider] = None 74 | input_dataset_binary: Optional[bytes] = None 75 | # Only jsonl, csv supported 76 | input_dataset_name: str = pc.Field(..., min_length=1, max_length=255) 77 | # evaluator 78 | evaluators: List[Evaluator] 79 | # Run config 80 | run_config: TaskRunConfig 81 | 82 | class Config: 83 | validate_assignment = True 84 | 85 | @pc.validator("input_dataset_name") 86 | def input_dataset_name_must_be_valid(cls, v): # noqa: N805 87 | if not v.endswith(".csv") and not v.endswith(".jsonl"): 88 | raise ValueError(f"Invalid input_dataset_name: {v}") 89 | return v 90 | 91 | @classmethod 92 | def from_yaml(cls, yaml_str: str, dataset_dir: Optional[str] = None) -> "EvalTask": 93 | obj = yaml.safe_load(yaml_str) 94 | input_dataset_name = obj.get("input_dataset_name") 95 | if input_dataset_name: 96 | if dataset_dir: 97 | input_dataset_base_name = os.path.basename(input_dataset_name) 98 | path = os.path.join(dataset_dir, input_dataset_base_name) 99 | else: 100 | path = input_dataset_name 101 | with open(path, "rb") as f: 102 | obj["input_dataset_binary"] = f.read() 103 | logger.debug(f"EvalTask.from_yaml obj: {obj}") 104 | task = cls(**obj) 105 | logger.debug(f"EvalTask.from_yaml task: {task}") 106 | return task 107 | 108 | def run_provider(self, data_list: list[Result], stop_event: threading.Event): 109 | """Run data list with batch""" 110 | batch_size = self.run_config.batch_size 111 | batch_data_list = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)] 112 | limiter = ThreadingRateLimiter(self.run_config.query_per_second) 113 | 114 | with ThreadPoolExecutor(max_workers=self.run_config.parallelism) as executor: 115 | # Submit tasks for execution 116 | futures = [ 117 | executor.submit(self.batch_run, batch_data=batch_data, limiter=limiter) for batch_data in batch_data_list 118 | ] 119 | 120 | # Collect results from completed tasks 121 | for future in as_completed(futures): 122 | if stop_event.is_set(): 123 | return 124 | batch_result = future.result() 125 | yield from batch_result 126 | 127 | 128 | def run_eval(self, evaluator: Evaluator, data_list: list[Result], stop_event: threading.Event, default_eval_llm: Optional[LLM] = None): 129 | """Eval data list with batch""" 130 | # TODO seperate eval run config 131 | batch_size = self.run_config.batch_size 132 | batch_data_list = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)] 133 | limiter = ThreadingRateLimiter(evaluator.query_per_second) 134 | with ThreadPoolExecutor(max_workers=self.run_config.parallelism) as executor: 135 | # Submit tasks for execution 136 | futures = [ 137 | executor.submit( 138 | self.batch_eval, 139 | evaluator=evaluator, 140 | batch_data=batch_data, 141 | limiter=limiter, 142 | default_eval_llm=default_eval_llm, 143 | ) for batch_data in batch_data_list 144 | ] 145 | 146 | # Collect results from completed tasks 147 | for future in as_completed(futures): 148 | if stop_event.is_set(): 149 | return 150 | batch_result = future.result() 151 | yield from batch_result 152 | 153 | def batch_run(self, batch_data: list[Result], limiter: ThreadingRateLimiter) -> list[Result]: 154 | """Batch run data""" 155 | start = 
perf_counter() 156 | run_error = "" 157 | if self.provider is not None: 158 | with limiter: 159 | inputs = [data.inputs for data in batch_data] 160 | try: 161 | # 1. 首先调用 LLM 162 | run_outputs = self.provider.batch_call( 163 | inputs, timeout=self.run_config.timeout) 164 | logger.debug(f"provider call: {inputs} -> {run_outputs}") 165 | for i, data in enumerate(batch_data): 166 | data.run.outputs = run_outputs[i] 167 | except Exception as e: 168 | logger.error(f"provider call failed: {e}", exc_info=True) 169 | run_error = str(e) 170 | 171 | for data in batch_data: 172 | data.run.error = run_error 173 | data.run.elapsed_secs = perf_counter() - start 174 | return batch_data 175 | 176 | def batch_eval(self, 177 | evaluator: Evaluator, 178 | batch_data: list[Result], 179 | limiter: ThreadingRateLimiter, 180 | default_eval_llm: Optional[LLM] = None) -> list[Result]: 181 | start = perf_counter() 182 | run_error = "" 183 | with limiter: 184 | for data in batch_data: 185 | data.evals[evaluator.name] = EvalResult() 186 | try: 187 | eval_outputs = evaluator.batch_call( 188 | batch_inputs=[data.inputs for data in batch_data], 189 | batch_outputs=[data.run.outputs for data in batch_data], 190 | timeout=self.run_config.timeout, 191 | default_llm=default_eval_llm, 192 | ) 193 | for i, data in enumerate(batch_data): 194 | data.evals[evaluator.name].outputs = eval_outputs[i] 195 | except Exception as e: 196 | logger.warning(f"evaluator call failed: {e}", exc_info=True) 197 | run_error = str(e) 198 | 199 | for data in batch_data: 200 | data.evals[evaluator.name].error = run_error 201 | data.evals[evaluator.name].elapsed_secs = perf_counter() - start 202 | return batch_data 203 | 204 | def split_dataset(self) -> list[dict[str, Any]]: 205 | if self.input_dataset_name.endswith(".csv"): 206 | return self.split_csv_dataset() 207 | elif self.input_dataset_name.endswith(".jsonl"): 208 | return self.split_jsonl_dataset() 209 | else: 210 | raise ValueError( 211 | f"Invalid input_dataset_name: {self.input_dataset_name}") 212 | 213 | def split_csv_dataset(self) -> list[dict[str, Any]]: 214 | if not self.input_dataset_binary: 215 | return [] 216 | data_list = [] 217 | with io.StringIO(self.input_dataset_binary.decode("utf-8")) as csvfile: 218 | reader = csv.DictReader(csvfile) 219 | for row in reader: 220 | data_list.append(row) 221 | return data_list 222 | 223 | def split_jsonl_dataset(self) -> list[dict[str, Any]]: 224 | if not self.input_dataset_binary: 225 | return [] 226 | data_list = [] 227 | for line in self.input_dataset_binary.decode("utf-8").split("\n"): 228 | if line.strip(): 229 | data_list.append(json.loads(line.strip())) 230 | return data_list 231 | 232 | def input_dataset_pd(self, limit: int = 5) -> pd.DataFrame: 233 | data = self.split_dataset() 234 | if len(data) > limit: 235 | data = data[:limit] 236 | return pd.DataFrame(data) 237 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ninehills/langeval/b19123c2eeea375b06ef2b2b47f862cefba6aa66/tests/__init__.py --------------------------------------------------------------------------------
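Usage sketch (not part of the repository): the snippet below shows how the output parsers and the rate limiter listed above can be combined. The import paths follow the src/langeval layout shown in this listing; the sample LLM response and the query_per_second value are illustrative assumptions.

# Minimal sketch, assuming the package is importable as `langeval`.
from langeval.providers.output_parser import OutputParserError, SimpleJsonOutputParser
from langeval.tasks.ratelimiter import ThreadingRateLimiter

parser = SimpleJsonOutputParser()
limiter = ThreadingRateLimiter(query_per_second=2)  # illustrative rate; 0 disables limiting

llm_output = '```json\n{"answer": "42", "score": 0.9}\n```'  # stand-in for a real LLM response
with limiter:  # used as a context manager, as in tasks/task.py
    try:
        parsed = parser.parse(llm_output)  # expected: {"answer": "42", "score": 0.9}
        print(parsed)
    except OutputParserError as e:
        print(f"parse failed: {e}")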