├── README.md
├── README_zh.md
├── data
│   ├── dataset.json
│   ├── law_library.jsonl
│   └── samples
│       ├── generated_responses.jsonl
│       ├── retrieval_Qwen2-1.5B.jsonl
│       └── rewrite_question.jsonl
└── src
    ├── config
    │   ├── config.py
    │   └── template
    │       └── prompt.txt
    ├── eval
    │   ├── evaluator.py
    │   ├── llm_as_judge
    │   │   ├── make_prompt.py
    │   │   ├── run_judge.py
    │   │   └── use_template.py
    │   ├── metrics.py
    │   └── retrieval_metrics.py
    ├── generate
    │   ├── data_processor.py
    │   ├── generator.py
    │   └── prompt_builder.py
    ├── pipeline.py
    ├── process
    │   ├── processor.py
    │   └── rewriter.py
    ├── retrieval
    │   ├── dense_retriever.py
    │   ├── lexical_matching.py
    │   └── run_retrieval.py
    └── utils
        └── utils.py
/README.md:
--------------------------------------------------------------------------------
1 |
2 |

3 |
4 |
5 | LexRAG: Benchmarking Retrieval-Augmented Generation in Multi-Turn Legal Consultation Conversation
6 |
7 | :book:中文 |
8 | English
9 |
10 | Welcome to LexiT, the dedicated toolkit for RAG in the legal domain.
11 |
12 | ## :link: Introduction
13 | To advance RAG system research in the legal domain, we propose LexiT, a modular and scalable RAG toolkit for legal researchers. Although some general-domain RAG toolkits are available, they do not support multi-turn conversations or evaluations tailored to the legal domain. LexiT consists of three components: **Data**, **Pipeline**, and **Evaluation**. It integrates all elements of the RAG process into a unified framework and also supports standalone use of each component. This modular design enhances flexibility and allows for high customizability when evaluating different legal scenarios.
14 |
15 |

16 |
17 |
18 | ## :books: Data
19 | * The data component consists of two key elements: input conversations and corpora.
20 | * The conversation format can be either single-turn or multi-turn. Multi-turn conversations provide previous dialogue history as context.
21 | The dataset ```./data/dataset.json``` contains 1,013 multi-turn conversations, each with 5 rounds of questions and responses.
22 | * For the corpora, we collect raw data from three different sources. In addition to Legal Articles, which serve as the candidate corpus in this paper, Legal Books and Legal Cases are also included in the toolkit for researchers’ convenience. Specifically, Legal Articles contains 17,228 provisions from various Chinese statutory laws.
23 | The corpus is stored in ```./data/law_library.jsonl``` (a quick look at the data layout follows below).
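
If you want to inspect the data programmatically, the sketch below shows the fields that the toolkit itself reads (see ```./src/generate/data_processor.py```); the printed values are dataset-specific.
```
# Minimal look at the dataset layout, based on the fields read by
# ./src/generate/data_processor.py; values are dataset-specific.
import json

with open("data/dataset.json", encoding="utf-8") as f:
    data = json.load(f)

sample = data[0]
print(sample["id"])                      # conversation id
turn = sample["conversation"][0]         # each conversation has 5 turns
print(turn["user"])                      # user question for this turn
print(turn["assistant"])                 # reference answer
print(turn.get("article", []))           # gold legal articles for this turn
print(turn.get("keyword", []))           # answer keywords used in evaluation

# the corpus stores one legal provision per line
with open("data/law_library.jsonl", encoding="utf-8") as f:
    provision = json.loads(f.readline())
```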
24 |
25 |

26 |
27 |
28 | ## :rocket: Pipeline
29 | ### :bookmark_tabs: Processor
30 | The processor is responsible for converting the conversation into queries used by the retriever. There are several strategies for constructing the query, including using the last question, the entire conversation context, or the entire query history. Run ```./src/pipeline.py``` :
31 | ```
32 | pipeline = ProcessorPipeline()
33 | pipeline.run_processor(
34 | process_type="process_type",
35 | original_data_path="data/dataset.json",
36 | output_path="data/current_question.jsonl"
37 | )
38 | ```
39 | ```--process_type```: the strategy for constructing the query; one of ```current_question```, ```prefix_question```, ```prefix_question_answer```, ```suffix_question```
40 | ```--original_data_path```: the path to the conversation data you want to process
41 | ```--output_path```: the path for output
42 |
43 | Moreover, we also predefined a query rewrite strategy, which employs an LLM to integrate all necessary context into a clear, standalone question.
44 | ```
45 | pipeline = ProcessorPipeline(model_type="")
46 | pipeline.run_processor(
47 | process_type="rewrite_question",
48 | original_data_path="data/dataset.json",
49 | output_path="data/rewrite_question.jsonl",
50 | max_retries=5,
51 | max_parallel=32,
52 | batch_size=20
53 | )
54 | ```
55 | ```--model_type```: We provide some default models in ```./src/config/config.py```, which can be used by filling in their configuration information. If you want to use another model, replace ```model_type=""``` with ```config=``` and supply a custom configuration (see the sketch after this parameter list).
56 | ```--max_retries``` ```--max_parallel```: parallel processing parameters
57 | ```--batch_size```: batch size
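
For example, a custom configuration for the rewrite step might look like the sketch below. The keys mirror the defaults in ```./src/config/config.py```; the model name, URL, and key are placeholders.
```
# Hypothetical custom configuration; keys follow ./src/config/config.py.
my_config = {
    "model_type": "openai",
    "model_name": "your_model_name",
    "api_base": "your_base_url",
    "api_key": "your_api_key",
    "max_retries": 10,
    "max_parallel": 32
}
pipeline = ProcessorPipeline(config=my_config)
pipeline.run_processor(
    process_type="rewrite_question",
    original_data_path="data/dataset.json",
    output_path="data/rewrite_question.jsonl",
    max_retries=5,
    max_parallel=32,
    batch_size=20
)
```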
58 |
59 | You can check the results in ```output_path```. A sample of processed data is ```./data/samples/rewrite_question.jsonl```, where the processed query appears in the ```"question"``` field.
60 |
61 | ### :bookmark_tabs: Retriever
62 | #### Dense Retrieval
63 | We support advanced embedding models such as BGE and GTE. You can encode vectors with locally loaded models or via API calls. We employ ```Faiss``` for index construction, supporting three mainstream index types: ```FlatIP```, ```HNSW```, and ```IVF``` (a minimal sketch of the ```FlatIP``` case appears at the end of this subsection).
64 | * For API calls
65 | Run ```./src/pipeline.py``` :
66 | ```
67 | openai_config = {
68 | "api_key": "your_api_key",
69 | "base_url": "your_base_url"
70 | }
71 | pipeline = RetrieverPipeline(config=openai_config)
72 | pipeline.run_retriever(
73 | model_type="openai",
74 | model_name="model_name",
75 | faiss_type="faiss_type",
76 | question_file_path="data/rewrite_question.jsonl",
77 | law_path="data/law_library.jsonl"
78 | )
79 | ```
80 | ```--model_name```: the model for embedding
81 | ```--question_file_path```: the path for processed queries (by *Processor*)
82 | ```--law_path```: the path for corpus
83 | * For loaded models
84 | ```
85 | pipeline = RetrieverPipeline()
86 | pipeline.run_retriever(
87 | model_type="BGE-base-zh",
88 | faiss_type="faiss_type",
89 | question_file_path="data/rewrite_question.jsonl",
90 | law_path="data/law_library.jsonl"
91 | )
92 | ```
93 | ```--model_type```: the locally loaded embedding model
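
For reference, the three index types correspond to standard ```Faiss``` index classes. The sketch below illustrates the ```FlatIP``` case on toy embeddings; it is only an illustration, not the toolkit's own implementation (see ```./src/retrieval/dense_retriever.py``` for that).
```
# Illustrative FlatIP index on random embeddings (not the toolkit's code).
import numpy as np
import faiss

dim = 768
article_vecs = np.random.rand(1000, dim).astype("float32")
faiss.normalize_L2(article_vecs)           # inner product on unit vectors = cosine similarity

index = faiss.IndexFlatIP(dim)             # exact search; HNSW / IVF trade accuracy for speed
index.add(article_vecs)

query_vec = np.random.rand(1, dim).astype("float32")
faiss.normalize_L2(query_vec)
scores, ids = index.search(query_vec, 5)   # top-5 article indices and similarity scores
```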
94 |
95 | #### Sparse Retrieval
96 | For lexical matching, we use the ```Pyserini``` library to implement ```BM25``` and ```QLD```, and also support the ```bm25s``` library for ```BM25```.
97 | * For BM25:
98 | ```
99 | pipeline = RetrieverPipeline()
100 | pipeline.run_retriever(
101 | model_type="bm25",
102 | bm25_backend="bm25_backend",
103 | question_file_path="data/rewrite_question.jsonl",
104 | law_path="data/law_library.jsonl"
105 | )
106 | ```
107 | ```--bm25_backend```: the backend for building the BM25 index; one of ```bm25s```, ```pyserini```
108 | ```--question_file_path```: the path for processed queries (by *Processor*)
109 | ```--law_path```: the path for corpus
110 |
111 | * For QLD:
112 | ```
113 | pipeline = RetrieverPipeline()
114 | pipeline.run_retriever(
115 | model_type="qld",
116 | question_file_path="data/rewrite_question.jsonl",
117 | law_path="data/law_library.jsonl"
118 | )
119 | ```
120 |
121 | > For Dense Retrieval, you can check the index in ```./data/retrieval/law_index_{model_type}.faiss``` and the results in ```./data/retrieval/res/retrieval_{model_type}.jsonl```. A sample of retrieval data is ```./data/samples/retrieval_Qwen2-1.5B.jsonl```, where the retrieved articles appear in the ```"recall"``` field (the record layout is sketched after these notes).
122 |
123 | > For Sparse Retrieval, you can check the index in ```./data/retrieval/pyserini_index``` (pyserini) and the results in ```./data/retrieval/res/retrieval_{model_type}_{bm25_backend}.jsonl```.
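
For orientation, each line of a retrieval result file has roughly the following structure. The field names are the ones read by ```./src/eval/evaluator.py``` and ```./src/generate/generator.py```; the top-level ```id```, the values, and any omitted fields (```...```) are illustrative assumptions.
```
{
  "id": 1,
  "conversation": [
    {
      "question": {"recall": [{"article": {"name": "...", ...}, "score": 0.87}, ...], ...},
      "article": ["..."]
    },
    ...
  ]
}
```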
124 |
125 | ### :bookmark_tabs: Generator
126 | We support mainstream LLMs for generating responses. Run ```./src/pipeline.py``` :
127 | ```
128 | pipeline = GeneratorPipeline(model_type="")
129 | pipeline.run_generator(
130 | raw_data_path="data/dataset.json",
131 | retrieval_data_path="data/samples/retrieval_Qwen2-1.5B.jsonl",
132 | max_retries=5,
133 | max_parallel=32,
134 | top_n=5,
135 | batch_size=20
136 | )
137 | ```
138 | ```--model_type```: common LLMs are supported; just enter the model name in ```model_type``` (for these models, fill in the corresponding configuration information in ```./src/config/config.py```)
139 | ```--raw_data_path```: the path for conversation data which includes queries
140 | ```--retrieval_data_path```: the path for the retrieval results
141 | ```--max_retries``` ```--max_parallel```: parallel processing parameters
142 | ```--batch_size```: batch size
143 | ```--top_n```: use the top_n retrieved articles as references (see the sketch below)
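
The reference articles are picked by sorting each turn's ```recall``` list by ```score``` and keeping the highest-scoring entries, roughly as in this minimal sketch (mirroring ```_get_top_articles``` in ```./src/generate/generator.py```):
```
# How the top_n reference articles are selected from one turn's retrieval
# results (mirrors _get_top_articles in ./src/generate/generator.py).
def top_articles(recall_list, top_n=5):
    ranked = sorted(recall_list, key=lambda x: x["score"], reverse=True)[:top_n]
    return [item["article"]["name"] for item in ranked]
```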
144 |
145 |
146 | We also support response generation with ```vllm```, ```huggingface```, and local models.
147 | * For ```vllm```
148 | ```
149 | custom_config = {
150 | "model_type": "vllm",
151 | "model_path": "vllm_model_path",
152 | "gpu_num": 2
153 | }
154 | pipeline = GeneratorPipeline(model_type="vllm", config=custom_config)
155 | pipeline.run_generator(
156 | raw_data_path="data/dataset.json",
157 | retrieval_data_path="data/samples/retrieval_Qwen2-1.5B.jsonl",
158 | max_retries=5,
159 | max_parallel=32,
160 | top_n=5,
161 | batch_size=20
162 | )
163 | ```
164 | * For ```huggingface```
165 | ```
166 | hf_config = {
167 | "model_type": "huggingface",
168 | "model_path": "hf_model_path"
169 | }
170 | pipeline = GeneratorPipeline(
171 | model_type="huggingface",
172 | config=hf_config,
173 | )
174 | pipeline.run_generator(
175 | raw_data_path="data/dataset.json",
176 | retrieval_data_path="data/samples/retrieval_Qwen2-1.5B.jsonl",
177 | max_retries=5,
178 | max_parallel=32,
179 | top_n=5,
180 | batch_size=20
181 | )
182 | ```
183 | * For local model
184 | ```
185 | local_config = {
186 | "model_type": "local",
187 | "model_path": "local_model_path"
188 | }
189 | pipeline = GeneratorPipeline(
190 | model_type="local",
191 | config=local_config,
192 | )
193 | pipeline.run_generator(
194 | raw_data_path="data/dataset.json",
195 | retrieval_data_path="data/samples/retrieval_Qwen2-1.5B.jsonl",
196 | max_retries=5,
197 | max_parallel=32,
198 | top_n=5,
199 | batch_size=20
200 | )
201 | ```
202 |
203 |
204 | We support flexible customisation of the input prompt. By default we use our predefined ```LegalPromptBuilder```; you can instead use ```CustomSystemPromptBuilder``` to customise the system content, or ```FullCustomPromptBuilder``` for full prompt customisation.
205 | * For ```CustomSystemPromptBuilder```:
206 | ```
207 | from generate.prompt_builder import LegalPromptBuilder, CustomSystemPromptBuilder, FullCustomPromptBuilder
208 | custom_prompt = CustomSystemPromptBuilder("请用一句话回答法律问题:")
209 | pipeline = GeneratorPipeline(
210 | model_type="",
211 | prompt_builder=custom_prompt
212 | )
213 | pipeline.run_generator(
214 | raw_data_path="data/dataset.json",
215 | retrieval_data_path="data/samples/retrieval_Qwen2-1.5B.jsonl",
216 | max_retries=5,
217 | max_parallel=32,
218 | top_n=5,
219 | batch_size=20
220 | )
221 | ```
222 | ```--prompt_builder```: you can use ```CustomSystemPromptBuilder(" ")``` to customise the system prompt used by the LLM
223 |
224 | * For ```FullCustomPromptBuilder```:
225 | ```
226 | def full_custom_builder(history, question, articles):
227 | return [
228 | {"role": "user", "content": f"请以“回答如下:”为开头回答\n问题:{question}(相关法条:{','.join(articles)})"}
229 | ]
230 |
231 | pipeline = GeneratorPipeline(
232 | model_type="",
233 | prompt_builder=FullCustomPromptBuilder(full_custom_builder)
234 | )
235 | pipeline.run_generator(
236 | raw_data_path="data/dataset.json",
237 | retrieval_data_path="data/samples/retrieval_Qwen2-1.5B.jsonl",
238 | max_retries=5,
239 | max_parallel=32,
240 | top_n=5,
241 | batch_size=20
242 | )
243 | ```
244 | ```--prompt_builder```: ```FullCustomPromptBuilder``` takes ```history```, ```question```, and ```articles``` as input; define your prompt strategy first, then pass it in as the prompt_builder.
245 | > ```history```: conversation history
246 | > ```question```: current query for LLM
247 | > ```articles```: the top ```--top_n``` retrieved articles used as references
248 |
249 | You can check the results in ```./data/generated_responses.jsonl```. A sample of generated data is ```./data/samples/generated_responses.jsonl```. Each line of the ```.jsonl``` file has the following format:
250 | ```
251 | {"id": "xxx", "question": "...", "response": "..."}
252 | ```
253 |
254 | ## :pencil: Evaluation
255 | ### Generation Evaluator
256 | The generation evaluator measures the consistency between generated responses and reference answers, supporting automated metrics like ROUGE, BLEU, METEOR, and BERTScore. Run ```./src/pipeline.py```:
257 | ```
258 | pipeline = EvaluatorPipeline()
259 | pipeline.run_evaluator(
260 | eval_type="generation",
261 | metrics=["bleu", "rouge", "bert-score", "keyword_accuracy", "char_scores", "meteor"],
262 | data_path="data/dataset.json",
263 | response_file="response_file_path"
264 | )
265 | ```
266 | ```--data_path```: the path to the original query dataset
267 | ```--response_file```: the path to LLM's generated responses
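
For reference, the per-turn result keys produced by ```./src/eval/metrics.py``` look roughly like the dictionary below (the numbers are placeholders):
```
# Shape of one turn's generation metrics; keys follow ./src/eval/metrics.py,
# the values here are placeholders.
turn_metrics = {
    "rouge-l": 0.0,
    "bleu-1": 0.0, "bleu-2": 0.0, "bleu-3": 0.0, "bleu-4": 0.0,
    "bert-precision": 0.0, "bert-recall": 0.0, "bert-f1": 0.0,
    "keyword_accuracy": 0.0,
    "char_precision": 0.0, "char_recall": 0.0, "char_f1": 0.0,
    "meteor": 0.0
}
```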
268 |
269 | ### Retrieval Evaluator
270 | The retrieval evaluator assesses the relevance and accuracy of retrieved documents, supporting the calculation of mainstream automated metrics such as NDCG, Recall, MRR, Precision, and F1.
271 | ```
272 | pipeline = EvaluatorPipeline()
273 | pipeline.run_evaluator(
274 | eval_type="retrieval",
275 | results_path="retrieval_results_path",
276 | metrics=["recall", "precision", "f1", "ndcg", "mrr"],
277 | k_values=[1, 3, 5]
278 | )
279 | ```
280 | ```--results_path```: the path for retrieval results
281 | ```--k_values```: the cut-off ranks k at which each metric is computed (e.g. recall@1, recall@3, recall@5)
282 | > You can check the results in ```./data/retrieval/report.jsonl```.
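
As a quick sanity check, the metric functions in ```./src/eval/retrieval_metrics.py``` can also be called directly. A small worked example with made-up article names (assuming ```./src``` is on your path, as the toolkit's own imports expect):
```
from eval.retrieval_metrics import RetrievalMetrics

# one query: ranked retrieved article names vs. the gold articles (illustrative names)
res_list = [["article_A", "article_B", "article_C"]]
label_list = [["article_A", "article_D"]]

print(RetrievalMetrics.recall(res_list, label_list, k=3))     # 1 of 2 gold articles retrieved -> 0.5
print(RetrievalMetrics.precision(res_list, label_list, k=3))  # 1 of 3 retrieved is gold -> 0.3333
print(RetrievalMetrics.mrr(res_list, label_list, k=3))        # first hit at rank 1 -> 1.0
```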
283 |
284 | ### LLM-as-a-Judge
285 | The LLM judge evaluates response quality through multidimensional chain-of-thought reasoning. The prompt we use for LLM-as-a-Judge is ```./src/config/template/prompt.txt```.
286 | ```
287 | pipeline = EvaluatorPipeline("model_type")
288 | pipeline.run_evaluator(
289 | eval_type="llm_judge",
290 | data_path="data/dataset.json",
291 | gen_path="generated_responses_path"
292 | )
293 | ```
294 | ```--model_type```: the model used as the LLM judge
295 | ```--data_path```: the path to the original query dataset
296 | ```--gen_path```: the path to LLM's generated responses
297 | > You can check the results in ```./data/results/turn{turn}/judge_results.jsonl```.
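
Each line of ```judge_results.jsonl``` pairs the prompt id with the judge's raw reply (see ```./src/eval/llm_as_judge/run_judge.py```), which should end with the score dictionary requested in the template. An illustrative line (scores are placeholders):
```
{"id": "1_turn1", "response": "...{'准确性': 7, '满足用户需求': 7, '清晰度': 8, '逻辑连贯性': 8, '完备性': 7, '综合得分': 7}"}
```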
298 |
299 |
300 |
--------------------------------------------------------------------------------
/README_zh.md:
--------------------------------------------------------------------------------
1 |
2 |

3 |
4 |
5 | LexRAG: 多轮法律咨询对话中的检索-增强生成基准测试
6 |
7 | :book:中文 |
8 | English
9 |
10 | 欢迎来到LexiT,一个法律领域的检索增强生成(RAG)专用工具包。
11 |
12 | ## :link: Introduction
13 | 为了推进法律领域的 RAG 系统研究,我们为法律研究人员提出了模块化、可扩展的 RAG 工具包 LexiT。虽然目前已有一些通用领域的 RAG 工具包,但它们并不支持多轮对话和针对法律领域的评估。LexiT 由三个部分组成: ```Data``` ```Pipeline``` ```Evaluation```。LexiT将 RAG 流程的所有要素整合到一个统一的框架中,并支持独立应用。这种模块化设计提高了灵活性,在评估不同的法律情况时具有高度的可定制性。
14 |
15 |

16 |
17 |
18 | ## :books: Data
19 | * 数据组件由两个关键要素组成:输入对话和语料库
20 | * 对话格式可以是单轮对话,也可以是多轮对话。多轮对话提供以前的对话历史作为背景。
21 | 我们提供一个确保准确性和专业性的数据集 ```./data/dataset.json``` ,包含 1,013 个多轮对话,每个对话有 5 轮问题和回答。
22 | * 对于语料库 ```./data/law_library.jsonl``` ,我们从三个不同来源收集原始数据。除了作为本文候选语料库的法律条文外,为方便研究人员使用,工具包中还包括法律文书和法律案例。具体来说,法律条文包含了中国各种成文法中的 17,228 个条文。
23 |
24 |

25 |
26 |
27 | ## :rocket: Pipeline
28 | ### :bookmark_tabs: Processor
29 | Processor将对话转换成Retriever使用的查询。我们支持几种构建查询的策略,包括使用最后一个问题、整个对话上下文或整个查询历史等。 运行 ```./src/pipeline.py``` :
30 | ```
31 | pipeline = ProcessorPipeline()
32 | pipeline.run_processor(
33 | process_type="process_type",
34 | original_data_path="data/dataset.json",
35 | output_path="data/current_question.jsonl"
36 | )
37 | ```
38 | ```--process_type```: ```current_question``` ```prefix_question``` ```prefix_question_answer``` ```suffix_question``` 构建查询的策略
39 | ```--original_data_path```: 需要进行处理的对话数据路径
40 | ```--output_path```: 输出路径
41 |
42 | 此外,我们还预定义了一种查询重写策略,利用LLM将所有必要的上下文整合为一个清晰、独立的问题。
43 | ```
44 | pipeline = ProcessorPipeline(model_type="")
45 | pipeline.run_processor(
46 | process_type="rewrite_question",
47 | original_data_path="data/dataset.json",
48 | output_path="data/rewrite_question.jsonl",
49 | max_retries=5,
50 | max_parallel=32,
51 | batch_size=20
52 | )
53 | ```
54 | ```--model_type```: 我们在 ```./src/config/config.py``` 中提供了一些默认模型,可以通过更改模型的配置信息进行使用。如果您想使用其他模型,可以将 ```model_type=""``` 替换为 ```config=``` ,并自定义配置信息。
55 | ```--max_retries``` ```--max_parallel```: 并行处理参数
56 | ```--batch_size```: 批次大小
57 |
58 | 您可以在 ```output_path```查看输出结果。 以重写查询策略输出结果 ```./data/samples/rewrite_question.jsonl``` 为例,您可以在 ```"question"``` 中查看处理后查询。
59 |
60 | ### :bookmark_tabs: Retriever
61 | #### Dense Retrieval
62 | 我们支持 BGE 和 GTE 等高级模型,您可以使用本地加载的模型或 API 调用对向量进行编码。使用```Faiss```进行索引构建,支持三种faiss类型:```FlatIP```、```HNSW```和```IVF```。
63 | * 对于API调用
64 |
65 | 运行 ```./src/pipeline.py``` :
66 | ```
67 | openai_config = {
68 | "api_key": "your_api_key",
69 | "base_url": "your_base_url"
70 | }
71 | pipeline = RetrieverPipeline(config=openai_config)
72 | pipeline.run_retriever(
73 | model_type="openai",
74 | model_name="model_name",
75 | faiss_type="faiss_type",
76 | question_file_path="data/rewrite_question.jsonl",
77 | law_path="data/law_library.jsonl"
78 | )
79 | ```
80 | ```--model_name```: embedding模型名称
81 | ```--question_file_path```: 处理后文件路径 (by *Processor*)
82 | ```--law_path```: 语料库路径
83 |
84 | * 对于加载的模型
85 | ```
86 | pipeline = RetrieverPipeline()
87 | pipeline.run_retriever(
88 | model_type="BGE-base-zh",
89 | faiss_type="faiss_type",
90 | question_file_path="data/rewrite_question.jsonl",
91 | law_path="data/law_library.jsonl"
92 | )
93 | ```
94 | ```--model_type```: embedding模型名称
95 |
96 | #### Sparse Retrieval
97 | 对于词法匹配,我们使用 ```Pyserini``` 库实现 ```BM25``` 和 ```QLD``` ,同时支持使用 ```bm25s``` 实现 ```BM25``` 。
98 | * 对于BM25:
99 | ```
100 | pipeline = RetrieverPipeline()
101 | pipeline.run_retriever(
102 | model_type="bm25",
103 | bm25_backend="bm25_backend",
104 | question_file_path="data/rewrite_question.jsonl",
105 | law_path="data/law_library.jsonl"
106 | )
107 | ```
108 | ```--bm25_backend```: ```bm25s``` ```pyserini``` 构建bm25索引的方法
109 | ```--question_file_path```: 处理后文件路径 (by *Processor*)
110 | ```--law_path```: 语料库路径
111 |
112 | * 对于QLD:
113 | ```
114 | pipeline = RetrieverPipeline()
115 | pipeline.run_retriever(
116 | model_type="qld",
117 | question_file_path="data/rewrite_question.jsonl",
118 | law_path="data/law_library.jsonl"
119 | )
120 | ```
121 |
122 | > 使用Dense Retrieval, 您可以在 ```./data/retrieval/law_index_{model_type}.faiss``` 查看索引并在 ```./data/retrieval/res/retrieval_{model_type}.jsonl``` 查看输出检索结果。 以GTE_Qwen2-1.5B模型输出结果 ```./data/samples/retrieval_Qwen2-1.5B.jsonl``` 为例,您可以在 ```"recall"``` 查看检索召回结果。
123 |
124 | > 使用Sparse Retrieval, 您可以在 ```./data/retrieval/pyserini_index```(pyserini) 查看索引并在 ```./data/retrieval/res/retrieval_{model_type}_{bm25_backend}.jsonl``` 查看输出检索结果。
125 |
126 | ### :bookmark_tabs: Generator
127 | 我们支持主流的LLMs进行回答生成。 运行 ```./src/pipeline.py``` :
128 | ```
129 | pipeline = GeneratorPipeline(model_type="")
130 | pipeline.run_generator(
131 | raw_data_path="data/dataset.json",
132 | retrieval_data_path="data/samples/retrieval_Qwen2-1.5B.jsonl",
133 | max_retries=5,
134 | max_parallel=32,
135 | top_n=5,
136 | batch_size=20
137 | )
138 | ```
139 | ```--model_type```: 支持常用的LLMs,只需要在 ```model_type``` 输入模型名称(常用模型已在 ```./src/config/config.py``` 中设置,您修改相关配置信息即可使用)
140 | ```--raw_data_path```: 包含问题的对话数据路径
141 | ```--retrieval_data_path```: 检索得到的数据路径
142 | ```--max_retries``` ```--max_parallel```: 并行处理相关参数
143 | ```--batch_size```: 批次大小
144 | ```--top_n```: 使用检索结果中的top_n法条作为回答参考
145 |
146 |
147 | 我们还支持使用 ```vllm``` ```huggingface``` 和本地模型进行回答生成。
148 | * 对于 ```vllm```
149 | ```
150 | custom_config = {
151 | "model_type": "vllm",
152 | "model_path": "vllm_model_path",
153 | "gpu_num": 2
154 | }
155 | pipeline = GeneratorPipeline(model_type="vllm", config=custom_config)
156 | pipeline.run_generator(
157 | raw_data_path="data/dataset.json",
158 | retrieval_data_path="data/samples/retrieval_Qwen2-1.5B.jsonl",
159 | max_retries=5,
160 | max_parallel=32,
161 | top_n=5,
162 | batch_size=20
163 | )
164 | ```
165 | * 对于 ```huggingface```
166 | ```
167 | hf_config = {
168 | "model_type": "huggingface",
169 | "model_path": "hf_model_path"
170 | }
171 | pipeline = GeneratorPipeline(
172 | model_type="huggingface",
173 | config=hf_config,
174 | )
175 | pipeline.run_generator(
176 | raw_data_path="data/dataset.json",
177 | retrieval_data_path="data/samples/retrieval_Qwen2-1.5B.jsonl",
178 | max_retries=5,
179 | max_parallel=32,
180 | top_n=5,
181 | batch_size=20
182 | )
183 | ```
184 | * 对于本地模型
185 | ```
186 | local_config = {
187 | "model_type": "local",
188 | "model_path": "local_model_path"
189 | }
190 | pipeline = GeneratorPipeline(
191 | model_type="local",
192 | config=local_config,
193 | )
194 | pipeline.run_generator(
195 | raw_data_path="data/dataset.json",
196 | retrieval_data_path="data/samples/retrieval_Qwen2-1.5B.jsonl",
197 | max_retries=5,
198 | max_parallel=32,
199 | top_n=5,
200 | batch_size=20
201 | )
202 | ```
203 |
204 |
205 | 我们支持使用灵活自定义的prompt构造。 默认情况使用我们预定义的 ```LegalPromptBuilder``` 进行prompt构造,您还可以选择使用 ```CustomSystemPromptBuilder``` 自定义prompt的系统角色部分,或者选择 ```FullCustomPromptBuilder``` 完全自定义prompt。
206 | * 对于 ```CustomSystemPromptBuilder```:
207 | ```
208 | from generate.prompt_builder import LegalPromptBuilder, CustomSystemPromptBuilder, FullCustomPromptBuilder
209 | custom_prompt = CustomSystemPromptBuilder("请用一句话回答法律问题:")
210 | pipeline = GeneratorPipeline(
211 | model_type="",
212 | prompt_builder=custom_prompt
213 | )
214 | pipeline.run_generator(
215 | raw_data_path="data/dataset.json",
216 | retrieval_data_path="data/samples/retrieval_Qwen2-1.5B.jsonl",
217 | max_retries=5,
218 | max_parallel=32,
219 | top_n=5,
220 | batch_size=20
221 | )
222 | ```
223 | ```--prompt_builder```: 您可以使用 ```CustomSystemPromptBuilder(" ")``` 自定义系统角色
224 |
225 | * 对于 ```FullCustomPromptBuilder```:
226 | ```
227 | def full_custom_builder(history, question, articles):
228 | return [
229 | {"role": "user", "content": f"请以“回答如下:”为开头回答\n问题:{question}(相关法条:{','.join(articles)})"}
230 | ]
231 |
232 | pipeline = GeneratorPipeline(
233 | model_type="",
234 | prompt_builder=FullCustomPromptBuilder(full_custom_builder)
235 | )
236 | pipeline.run_generator(
237 | raw_data_path="data/dataset.json",
238 | retrieval_data_path="data/samples/retrieval_Qwen2-1.5B.jsonl",
239 | max_retries=5,
240 | max_parallel=32,
241 | top_n=5,
242 | batch_size=20
243 | )
244 | ```
245 | ```--prompt_builder```: ```FullCustomPromptBuilder``` 支持输入 ```history```, ```question```, ```articles```, 您可以先自定义prompt策略,再作为prompt_builder输入。
246 | > ```history```: 对话历史
247 | > ```question```: 多轮对话的当前问题
248 | > ```articles```: 检索返回结果的 ```--top_n``` 参考法条
249 |
250 | 您可以在 ```./data/generated_responses.jsonl``` 查看输出。 示例输出 ```./data/samples/generated_responses.jsonl```。 ```.jsonl``` 文件格式如下:
251 | ```
252 | {"id": "xxx", "question": "...", "response": "..."}
253 | ```
254 |
255 | ## :pencil: Evaluation
256 | ### Generation Evaluator
257 | 生成评估器衡量生成回答与参考答案之间的一致性,支持指标如ROUGE, BLEU, METEOR, BERTScore。 运行 ```./src/pipeline.py```:
258 | ```
259 | pipeline = EvaluatorPipeline()
260 | pipeline.run_evaluator(
261 | eval_type="generation",
262 | metrics=["bleu", "rouge", "bert-score", "keyword_accuracy", "char_scores", "meteor"],
263 | data_path="data/dataset.json",
264 | response_file="response_file_path"
265 | )
266 | ```
267 | ```--data_path```: 原始包含查询问题的数据集路径
268 | ```--response_file```: LLM生成回答的路径
269 |
270 | ### Retrieval Evaluator
271 | 检索评估器评估检索文件的相关性和准确性,支持主流指标如NDCG, Recall, MRR, Precision, F1的计算。
272 | ```
273 | pipeline = EvaluatorPipeline()
274 | pipeline.run_evaluator(
275 | eval_type="retrieval",
276 | results_path="retrieval_results_path",
277 | metrics=["recall", "precision", "f1", "ndcg", "mrr"],
278 | k_values=[1, 3, 5]
279 | )
280 | ```
281 | ```--results_path```: 检索结果路径
282 | ```--k_values```: 考虑分数最高的k个结果
283 | > 您可以在 ```./data/retrieval/report.jsonl``` 查看结果。
284 |
285 | ### LLM-as-a-Judge
286 | LLM通过多维思维链推理来评估回答质量。 LLM-as-a-Judge 使用的prompt为 ```./src/config/template/prompt.txt```
287 | ```
288 | pipeline = EvaluatorPipeline("model_type")
289 | pipeline.run_evaluator(
290 | eval_type="llm_judge",
291 | data_path="data/dataset.json",
292 | gen_path="generated_responses_path"
293 | )
294 | ```
295 | ```--model_type```: 评估使用的模型名称
296 | ```--data_path```: 原始包含查询问题的数据集路径
297 | ```--gen_path```: LLM生成回答的路径
298 | > 您可以在 ```./data/results/turn{turn}/judge_results.jsonl``` 查看结果。
299 |
--------------------------------------------------------------------------------
/src/config/config.py:
--------------------------------------------------------------------------------
1 | class Config:
2 | _default_configs = {
3 | "openai": {
4 | "model_type": "openai",
5 | "model_name": "gpt-3.5-turbo",
6 | "api_base": "",
7 | "api_key": "",
8 | "max_retries": 10,
9 | "max_parallel": 32
10 | },
11 | "zhipu": {
12 | "model_type": "zhipu",
13 | "model_name": "glm-4-flash",
14 | "api_key": "",
15 | "max_retries": 10,
16 | "max_parallel": 32
17 | },
18 | "llama": {
19 | "model_type": "llama",
20 | "model_name": "llama-3.3-70b-instruct",
21 | "api_base": "",
22 | "api_key": "",
23 | "max_retries": 10,
24 | "max_parallel": 32
25 | },
26 | "qwen": {
27 | "model_type": "qwen",
28 | "model_name": "qwen2.5-72b-instruct",
29 | "api_base": "",
30 | "api_key": "",
31 | "max_retries": 10,
32 | "max_parallel": 32
33 | }
34 | }
35 |
36 | def __init__(self, model_type=None, my_config=None):
37 | if my_config:
38 | self.config = my_config
39 | elif model_type:
40 | self.config = self._default_configs.get(model_type)
41 | if not self.config:
42 | raise ValueError(f"Invalid model_type: {model_type}")
43 | else:
44 | raise ValueError("Must provide either model_type or my_config")
45 |
46 | def get(self, key, default=None):
47 | return self.config.get(key, default)
48 |
49 | @property
50 | def model_type(self):
51 | return self.config["model_type"]
52 |
--------------------------------------------------------------------------------
/src/config/template/prompt.txt:
--------------------------------------------------------------------------------
1 | 你是一位经验丰富的法律专家,专门负责评估法律咨询回复的质量。请以公正、严谨的评判者身份,对AI助手撰写的回复进行客观评估。评估时,你需要基于以下五个关键维度进行评分:
2 |
3 | 1. 准确性: 提供的信息是否准确无误,是否基于可信的法条。
4 | 2. 满足用户需求:是否满足了用户提出问题的目的和需求,是否对问题进行了全面而恰当的回应。
5 | 3. 清晰度: 是否表达清晰易懂,是否使用了简洁的语言和结构,以便用户可以轻松理解。
6 | 4. 逻辑连贯性: 是否在整体上保持一致,是否在多轮对话之间保持逻辑连贯性,避免了自相矛盾。
7 | 5. 完备性: 回答是否提供了足够的信息和细节,以满足用户的需求,是否遗漏了重要的方面。
8 |
9 | 注意:回答不是越长越好,简短并且满足上述要求的回答是最理想的。判断准确性时,需要与参考答案进行对比,只有参考答案引用的法条是正确的。
10 |
11 | 我们会给您提供用户的多轮会话,一个8分左右的参考答案,和需要你评估的AI助手对最后一轮问题的答案。当你开始你的评估时,你需要按照遵守以下的流程
12 |
13 | 在评估过程中,请遵循以下步骤:
14 | 1. 将AI助手的答案与参考答案进行比较,指出AI助手的答案有哪些不足,并进一步解释。
15 | 2. 根据上述评分标准,对每个维度进行严格评分。所有维度都应严格遵守参考答案的高标准,避免给出过高的分数。
16 | 3. 最后,综合每个维度的评估,对AI助手的回答给出一个1~10的综合分数,综合分数应反映出答案在各维度的整体表现,不会因为某一维度的优势而过度提高最终分数。
17 | 4. 你的打分需要尽可能严格,并且要遵守下面的评分规则:总的来说,模型回答的质量越高,则分数越高。
18 |
19 | 准确性和满足用户需求这两个维度的评分对最终得分至关重要,因此这些维度的评分必须严格,尤其是在法条引用和法律解释方面。与参考答案的差异将直接影响得分。
20 |
21 | 评分标准
22 | 当模型回答严重错误,提供的法律条文和法律解释完全不符或直接错误。完全未回答问题,答案与问题无关,未做任何有价值的法律分析。语言混乱、冗长或无法理解,结构极其复杂,给用户带来困扰。完全没有逻辑连贯性,推理混乱,自相矛盾,不能提供有效答案。回答严重缺失,漏掉所有关键信息,未提供任何有效细节。得分应为1到2分。
23 | 当模型回答存在重大事实错误,法条引用明显不正确,无法提供准确法律支持。回答没有解决核心问题,未提及必要的法条或关键分析。表达不清,含糊其辞,逻辑混乱,用户难以理解。存在明显逻辑错误,推理不连贯,多个部分互相矛盾。遗漏了大部分关键信息或条文,回答不完整。得分应为3到4分。
24 | 当模型回答法条引用大体准确,细节上有些微误差,仍可理解,但并不完全符合法律解释。基本满足了用户需求,提供了大部分需要的分析和法条,但仍有一些遗漏。语言清晰,结构合理,表达简洁,但部分地方可能需要改进。大部分逻辑清晰,推理大致合理,少数部分可能存在小的连贯性问题。信息较为完备,涵盖大部分内容,但仍然有少许遗漏。得分应为5到6分。
25 | 当模型回答准确,所有法条引用和法律解释与参考答案一致,极少的细节误差可接受。非常接近完美地回应了用户需求,答案非常全面,只有少量细节未完全涵盖。清晰易懂,语言简洁,结构明确,一目了然。推理严密,逻辑结构清晰,几乎无任何矛盾或不一致。答案较为完备,涵盖了绝大部分需要的信息,得分应为7到8分。
26 | 只有当模型回答质量显著超过参考答案,充分地解决了用户问题和所有需求,并且在所有维度上都接近满分的情况下,仅在此情况下可获得9-10分。
27 |
28 | 作为示例,参考答案在准确度、满足用户需求、清晰度、逻辑连贯性、完备性上可以得到8分。参考答案综合得分可以得到8分。
29 |
30 | 请在评分时,提供详细的评估说明。每个维度评分后,务必给出相应分数,所有评分均需为整数。最终,请按照以下字典格式返回评分结果:
31 | ```
32 | {{'准确性': 分数, '满足用户需求': 分数, '清晰度': 分数, '逻辑连贯性': 分数, '完备性': 分数, '综合得分': 总分}}
33 | ```
34 |
35 | [多轮法律咨询会话开始]
36 | {会话过程}
37 | [多轮法律咨询会话结束]
38 |
39 | [参考答案开始]
40 | {参考答案}
41 | [参考答案结束]
42 |
43 | [AI助手撰写的答案开始]
44 | {AI助手撰写的答案}
45 | [AI助手撰写答案结束]
46 |
47 | 请开始评估:
48 |
--------------------------------------------------------------------------------
/src/eval/evaluator.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import logging
4 | from typing import Dict, List
5 | from .metrics import UnifiedEvaluator
6 | from generate.data_processor import DataProcessor
7 | from .retrieval_metrics import RetrievalMetrics
8 |
9 | logging.basicConfig(
10 | level=logging.INFO,
11 | format="%(asctime)s - %(levelname)s - %(message)s",
12 | handlers=[
13 | logging.StreamHandler(),
14 | logging.FileHandler("log", mode='w', encoding='utf-8')
15 | ]
16 | )
17 |
18 | class BaseEvaluator:
19 | def __init__(self, config):
20 | self.config = config
21 | self.logger = logging.getLogger(self.__class__.__name__)
22 |
23 | class GenerationEvaluator(BaseEvaluator):
24 | def __init__(self, config):
25 | super().__init__(config)
26 | self.metric_calculator = UnifiedEvaluator()
27 |
28 | def evaluate(self, data_path, response_file, metrics):
29 | results = {}
30 | id_to_response = self._load_responses(response_file)
31 |
32 | # Segregated multi-stage data
33 | processor = DataProcessor()
34 | processed_data = processor.process_conversation_turns(data_path)
35 | output_dir = "data/generated_samples"
36 | os.makedirs(output_dir, exist_ok=True)
37 | for turn_num, samples in processed_data.items():
38 | output_path = f"{output_dir}/{turn_num}.json"
39 | with open(output_path, "w", encoding="utf-8") as f:
40 | json.dump(samples, f, indent=2, ensure_ascii=False)
41 |
42 | for turn_file in sorted(os.listdir(output_dir)):
43 | turn_path = os.path.join(output_dir, turn_file)
44 | with open(turn_path, "r", encoding="utf-8") as f:
45 | data = json.load(f)
46 |
47 | preds, refs, keywords = self._prepare_data(data, id_to_response)
48 |
49 | metrics_result = {}
50 | if "rouge" in metrics:
51 | metrics_result.update(self.metric_calculator._get_rouge(preds, refs))
52 | if "bert-score" in metrics:
53 | metrics_result.update(self.metric_calculator._get_bert_score(preds, refs))
54 | if "bleu" in metrics:
55 | metrics_result.update(self.metric_calculator._get_bleu(preds, refs))
56 | if "keyword_accuracy" in metrics:
57 | metrics_result.update(self.metric_calculator._get_keyword_accuracy(keywords, preds))
58 |             if "char_scores" in metrics:  # matches the metric name documented in the README
59 | metrics_result.update(self.metric_calculator._get_char_f1(preds, refs))
60 | if "meteor" in metrics:
61 | metrics_result.update(self.metric_calculator._get_meteor(preds, refs))
62 |
63 | turn_num = os.path.splitext(turn_file)[0].split('_')[-1]
64 |             results[turn_num] = metrics_result  # keep all computed scores; their keys (e.g. "rouge-l", "bleu-1") differ from the requested metric names
65 | logging.info(f"\nTurn{turn_num.upper()} Metrics:")
66 | for k, v in metrics_result.items():
67 | logging.info(f"{k.ljust(15)}: {v:.4f}")
68 |
69 | return results
70 |
71 | def _load_responses(self, response_file):
72 | id_to_response = {}
73 | with open(response_file, "r", encoding="utf-8") as f:
74 | for line in f:
75 | record = json.loads(line)
76 | id_to_response[record["id"]] = record["response"]
77 | return id_to_response
78 |
79 | def _prepare_data(self, data, id_to_response):
80 | preds, refs, keywords = [], [], []
81 | for sample in data:
82 | pred_response = id_to_response.get(sample["id"], "").strip()
83 | if pred_response:
84 | preds.append(pred_response)
85 | refs.append(sample["reference"])
86 | keywords.append(sample["keywords"])
87 | return preds, refs, keywords
88 |
89 | class LLMJudge(BaseEvaluator):
90 | def __init__(self, config):
91 | super().__init__(config)
92 | from openai import OpenAI
93 | from zhipuai import ZhipuAI
94 | import httpx
95 | from eval.llm_as_judge.run_judge import Judge
96 | if self.config["model_type"] == "openai":
97 | self.client = OpenAI(
98 | base_url=self.config["api_base"],
99 | api_key=self.config["api_key"],
100 | http_client=httpx.Client(
101 | base_url=self.config["api_base"],
102 | follow_redirects=True,
103 | ),
104 | )
105 | elif self.config["model_type"] == "zhipu":
106 | self.client = ZhipuAI(api_key=self.config["api_key"])
107 | elif self.config["model_type"] in ["qwen","llama"]:
108 | self.client = OpenAI(
109 | base_url=self.config["api_base"],
110 |                 api_key=self.config["api_key"]
111 | )
112 | self.model_name = self.config["model_name"]
113 | Judge(self.config)
114 |
115 | def evaluate(self, data_path, gen_path):
116 | from eval.llm_as_judge.make_prompt import process_model
117 | from eval.llm_as_judge.run_judge import process_turn
118 |
119 | process_model(
120 | data_path,
121 | gen_path
122 | )
123 |
124 | for turn in range(1, 6):
125 | process_turn(
126 | self.config,
127 | turn=turn
128 | )
129 |
130 | class RetrievalEvaluator(BaseEvaluator):
131 | def evaluate(self, results_path: str, metrics: List[str], k_values: List[int], report_path="data/retrieval/report.jsonl") -> Dict:
132 | with open(results_path, "r", encoding="utf-8") as f:
133 | res_data = [json.loads(line) for line in f]
134 |
135 | res_list, res_score_list, label_list = [], [], []
136 | for data in res_data:
137 | for conv in data["conversation"]:
138 | # Sort
139 | sorted_recall = sorted(conv["question"]["recall"],
140 | key=lambda x: x["score"],
141 | reverse=True)
142 |
143 | res = [law["article"]["name"] for law in sorted_recall]
144 | scores = [law["score"] for law in sorted_recall]
145 |
146 | label = conv["article"]
147 | res_list.append(res)
148 | res_score_list.append(scores)
149 | label_list.append(label)
150 |
151 | report = {"results_path": results_path}
152 | metric_functions = {
153 | "recall": RetrievalMetrics.recall,
154 | "precision": RetrievalMetrics.precision,
155 | "f1": RetrievalMetrics.f1_score,
156 | "mrr": RetrievalMetrics.mrr,
157 | "ndcg": RetrievalMetrics.ndcg
158 | }
159 |
160 | for metric in metrics:
161 | if metric not in metric_functions:
162 | continue
163 | for k in k_values:
164 | if metric == "ndcg":
165 | score = metric_functions[metric](res_list, res_score_list, label_list, k)
166 | else:
167 | score = metric_functions[metric](res_list, label_list, k)
168 | report[f"{metric}@{k}"] = score
169 | self.logger.info(f"{metric.upper()}@{k}: {score:.4f}")
170 | report_dir = os.path.dirname(report_path)
171 | os.makedirs(report_dir, exist_ok=True)
172 | with open(report_path, "a", encoding="utf-8") as f:
173 | f.write(json.dumps(report, ensure_ascii=False) + "\n")
174 | return report
175 |
--------------------------------------------------------------------------------
/src/eval/llm_as_judge/make_prompt.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from eval.llm_as_judge.use_template import use_judge_template
4 |
5 | #data_path: original data path
6 | #gen_path: (llm model)generated response path
7 | def process_model(data_path, gen_path):
8 | with open(data_path, 'r', encoding='utf-8') as f:
9 | original_data = json.load(f)
10 |
11 | with open(gen_path, 'r', encoding='utf-8') as f:
12 | generated_data = [json.loads(line) for line in f]
13 |
14 | for turn in range(5):
15 | #Path to the output prompts
16 | output_dir = f"data/prompt/turn{turn+1}"
17 | os.makedirs(output_dir, exist_ok=True)
18 | output_path = os.path.join(output_dir, "judge_prompt.jsonl")
19 |
20 | with open(output_path, 'w', encoding='utf-8') as out_file:
21 | for case in original_data:
22 | conv = case['conversation']
23 | if len(conv) <= turn:
24 | continue
25 |
26 | gen_id = f"{case['id']}_turn{turn+1}"
27 | generated = next((g for g in generated_data if g['id'] == gen_id), None)
28 | if not generated:
29 | continue
30 |
31 | prompt = use_judge_template(
32 | conversation=conv,
33 | reference_answer=conv[turn]['assistant'],
34 | generated_answer=generated['response'],
35 | current_turn=turn
36 | )
37 |
38 | out_file.write(json.dumps({
39 | "id": gen_id,
40 | "prompt": prompt
41 | }, ensure_ascii=False) + '\n')
42 |
--------------------------------------------------------------------------------
/src/eval/llm_as_judge/run_judge.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import httpx
4 | from tqdm import tqdm
5 | import concurrent.futures
6 | from openai import OpenAI
7 | from zhipuai import ZhipuAI
8 |
9 | class Judge:
10 | def __init__(self, config):
11 | self.config = config
12 | if self.config.get("model_type") == "openai":
13 | self.client = OpenAI(
14 | base_url=self.config.get("api_base"),
15 | api_key=self.config.get("api_key"),
16 | http_client=httpx.Client(
17 | base_url=self.config.get("api_base"),
18 | follow_redirects=True,
19 | ),
20 | )
21 | elif self.config.get("model_type") == "zhipu":
22 | self.client = ZhipuAI(api_key=self.config["api_key"])
23 | elif self.config.get("model_type") in ["qwen","llama"]:
24 | self.client = OpenAI(
25 | base_url=self.config.get("api_base"),
26 | api_key=self.config.get("api_key")
27 | )
28 | self.model = self.config.get("model_name")
29 | self.max_retries = self.config.get("max_retries")
30 | self.max_workers = self.config.get("max_parallel")
31 |
32 | def evaluate(self, prompt):
33 | for _ in range(self.max_retries):
34 | try:
35 | response = self.client.chat.completions.create(
36 | model=self.model,
37 | messages=[{"role": "user", "content": prompt}],
38 | temperature=0.0
39 | )
40 | return response.choices[0].message.content
41 | except Exception as e:
42 | print(f"API Error: {str(e)}")
43 | return "Evaluation Fail"
44 |
45 | def process_turn(config, turn):
46 | evaluator = Judge(config)
47 | #input(prompts made by make_prompt.py) and output path
48 | input_dir = f"data/prompt/turn{turn}"
49 | output_dir = f"data/results/turn{turn}"
50 | os.makedirs(output_dir, exist_ok=True)
51 |
52 | input_file = os.path.join(input_dir, "judge_prompt.jsonl")
53 | output_file = os.path.join(output_dir, "judge_results.jsonl")
54 |
55 | processed = set()
56 | if os.path.exists(output_file):
57 | with open(output_file, 'r', encoding="utf-8") as f:
58 | for line in f:
59 | try:
60 | processed.add(json.loads(line)['id'])
61 | except:
62 | continue
63 |
64 | with open(input_file, 'r', encoding="utf-8") as f:
65 | tasks = [json.loads(line) for line in f if json.loads(line)['id'] not in processed]
66 |
67 | with concurrent.futures.ThreadPoolExecutor(max_workers=evaluator.max_workers) as executor:
68 | task_ids = [task['id'] for task in tasks]
69 | task_prompts = [task['prompt'] for task in tasks]
70 | results = []
71 | try:
72 | results = list(tqdm(
73 | executor.map(evaluator.evaluate, task_prompts),
74 | total=len(task_prompts),
75 | desc=f"{evaluator.model} Turn{turn}"
76 | ))
77 | except Exception as e:
78 | print(f"Processing Error: {str(e)}")
79 |
80 | with open(output_file, 'a', encoding="utf-8") as f:
81 | for task_id, response in zip(task_ids, results):
82 | result = {
83 | "id": task_id,
84 | "response": response
85 | }
86 | f.write(json.dumps(result, ensure_ascii=False) + '\n')
87 |
--------------------------------------------------------------------------------
/src/eval/llm_as_judge/use_template.py:
--------------------------------------------------------------------------------
1 | def load_template(file_path):
2 | with open(file_path, 'r', encoding='utf-8') as file:
3 | return file.read()
4 |
5 | def build_conversation_history(conversation, current_turn):
6 | history = []
7 | for i in range(current_turn + 1):
8 | history.append(f"用户:{conversation[i]['user']}")
9 | if i < current_turn:
10 | history.append(f"助手:{conversation[i]['assistant']}")
11 | return "\n".join(history)
12 |
13 | def use_judge_template(conversation, reference_answer, generated_answer, current_turn):
14 | template = load_template('src/config/template/prompt.txt')
15 |
16 |     conversation_history = build_conversation_history(conversation, current_turn)
17 | 
18 |     return template.replace(
19 |         "{会话过程}", conversation_history
20 | ).replace(
21 | "{参考答案}", reference_answer
22 | ).replace(
23 | "{AI助手撰写的答案}", generated_answer
24 | )
--------------------------------------------------------------------------------
/src/eval/metrics.py:
--------------------------------------------------------------------------------
1 | from rouge_score import rouge_scorer
2 | from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
3 | from bert_score import score as bert_score
4 | from nltk.translate.meteor_score import meteor_score
5 | import nltk
6 | import jieba
7 | from transformers import AutoTokenizer
8 | import numpy as np
9 | from collections import Counter
10 | import logging
11 | import re
12 |
13 | jieba.setLogLevel(logging.INFO)
14 |
15 | class UnifiedEvaluator:
16 | def __init__(self, max_seq_length=510):
17 | self.scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=False)
18 | self.bleu_weights = [
19 | (1.0, 0, 0, 0),
20 | (0.5, 0.5, 0, 0),
21 | (0.33, 0.33, 0.33, 0),
22 | (0.25, 0.25, 0.25, 0.25)
23 | ]
24 |
25 | self.smoothie = SmoothingFunction().method1
26 | self.tokenizer = AutoTokenizer.from_pretrained(
27 | "hfl/chinese-bert-wwm",
28 | use_fast=True
29 | )
30 | self.max_seq_length = max_seq_length
31 | self._init_jieba()
32 |
33 | def _init_jieba(self):
34 | jieba.initialize()
35 |
36 | def calculate_all_metrics(self, preds, refs, keyword_lists):
37 |
38 | return {
39 | **self._get_rouge(preds, refs),
40 | **self._get_bert_score(preds, refs),
41 | **self._get_bleu(preds, refs),
42 | **self._get_keyword_accuracy(keyword_lists, preds),
43 | **self._get_char_f1(preds, refs),
44 | **self._get_meteor(preds, refs)
45 | }
46 |
47 | def _remove_punctuation(self, text):
48 | return re.sub(r'[\W]+', '', text)
49 |
50 | def _get_rouge(self, preds, refs):
51 | f_scores = []
52 | for p, r in zip(preds, refs):
53 | if not p.strip() or not r.strip():
54 | continue
55 |
56 | p = self._preprocess_text(p)
57 | r = self._preprocess_text(r)
58 |
59 | scores = self.scorer.score(r, p)
60 | f_scores.append(scores["rougeL"].fmeasure)
61 |
62 | return {"rouge-l": round(np.mean(f_scores).item(), 4)} if f_scores else {"rouge-l": 0.0}
63 |
64 | def _safe_truncate(self, text):
65 | if not isinstance(text, str) or not text.strip():
66 | return ""
67 |
68 | try:
69 | tokens = self.tokenizer.encode(
70 | text,
71 | truncation=True,
72 | max_length=self.max_seq_length,
73 | add_special_tokens=False
74 | )
75 | return self.tokenizer.decode(tokens, skip_special_tokens=True)
76 | except Exception as e:
77 | print(f"truncation error: {str(e)}")
78 | return text[:self.max_seq_length]
79 |
80 | def _get_bert_score(self, preds, refs):
81 | try:
82 | truncated_preds = [self._safe_truncate(p) for p in preds]
83 | truncated_refs = [self._safe_truncate(r) for r in refs]
84 |
85 | valid_pairs = [
86 | (p, r) for p, r in zip(truncated_preds, truncated_refs)
87 | if p.strip() and r.strip()
88 | ]
89 | if not valid_pairs:
90 | return {"bert-precision": 0.0, "bert-recall": 0.0, "bert-f1": 0.0}
91 |
92 | valid_preds, valid_refs = zip(*valid_pairs)
93 |
94 | P, R, F1 = bert_score(
95 | cands=valid_preds,
96 | refs=valid_refs,
97 | lang="zh",
98 | model_type="hfl/chinese-bert-wwm",
99 | num_layers=8,
100 | batch_size=32,
101 | nthreads=4,
102 | rescale_with_baseline=True,
103 | verbose=True
104 | )
105 |
106 | return {
107 | "bert-precision": round(P.mean().item(), 4),
108 | "bert-recall": round(R.mean().item(), 4),
109 | "bert-f1": round(F1.mean().item(), 4)
110 | }
111 | except Exception as e:
112 | print(f"BERTScore Error: {str(e)}")
113 | return {"bert-precision": 0.0, "bert-recall": 0.0, "bert-f1": 0.0}
114 |
115 | def _preprocess_text(self, text):
116 | '''preprocess'''
117 | text = re.sub(r'\*\*.*?\*\*', '', text)
118 | text = re.sub(r'\*\*|^\d+\.\s+|-\s+', '', text, flags=re.MULTILINE)
119 | text = re.sub(r'\n+|\s+', ' ', text)
120 | return text.strip()
121 |
122 | def _tokenize_text(self, text):
123 | tokens = list(jieba.cut(text))
124 | tokens = [t for t in tokens if not re.match(r'[^\w\s]', t)]
125 |
126 | return tokens
127 |
128 | def _get_bleu(self, preds, refs):
129 | bleu_scores = {f'bleu-{i+1}': [] for i in range(4)}
130 |
131 | for p, r in zip(preds, refs):
132 | if not p.strip() or not r.strip():
133 | continue
134 |
135 | p = self._preprocess_text(p)
136 | r = self._preprocess_text(r)
137 |
138 | p_tokens = self._tokenize_text(p)
139 | r_tokens = self._tokenize_text(r)
140 |
141 | try:
142 | for i, weights in enumerate(self.bleu_weights):
143 | score = sentence_bleu(
144 | [r_tokens],
145 | p_tokens,
146 | weights=weights,
147 | smoothing_function=self.smoothie
148 | )
149 |
150 | bleu_scores[f'bleu-{i+1}'].append(score)
151 | except Exception as e:
152 | print(f"ERROR: {str(e)}")
153 | continue
154 |
155 | return {k: round(np.mean(v).item(), 4) if v else 0.0
156 | for k, v in bleu_scores.items()}
157 |
158 | def _get_keyword_accuracy(self, keyword_lists, preds):
159 | accuracies = []
160 | for keywords, pred in zip(keyword_lists, preds):
161 | try:
162 | if not keywords or not isinstance(keywords, list):
163 | continue
164 |
165 | hits = sum(1 for kw in keywords if kw in pred)
166 |
167 | if len(keywords) > 0:
168 | accuracies.append(hits / len(keywords))
169 | except Exception as e:
170 | print(f"Error processing keywords {keywords} with prediction {pred}: {e}")
171 | accuracies.append(0.0)
172 | return {"keyword_accuracy": round(np.mean(accuracies).item(), 4) if accuracies else 0.0}
173 |
174 | def _get_char_f1(self, preds, refs):
175 | precision_scores, recall_scores, f1_scores = [], [], []
176 |
177 | for p, r in zip(preds, refs):
178 | try:
179 |
180 | p = self._preprocess_text(p)
181 | r = self._preprocess_text(r)
182 | p_tokens = self._tokenize_text(p)
183 | r_tokens = self._tokenize_text(r)
184 |
185 | p_tokens = list(p_tokens)
186 | r_tokens = list(r_tokens)
187 |
188 | common = Counter(p_tokens) & Counter(r_tokens)
189 | num_same = sum(common.values())
190 |
191 | if num_same == 0:
192 | continue
193 |
194 | precision = 1.0 * num_same / len(p_tokens) if p_tokens else 0
195 | recall = 1.0 * num_same / len(r_tokens) if r_tokens else 0
196 | f1 = self._safe_f1(precision, recall)
197 |
198 | precision_scores.append(precision)
199 | recall_scores.append(recall)
200 | f1_scores.append(f1)
201 | except Exception as e:
202 | logging.error(f"Char F1 Error: {str(e)}")
203 | precision_scores.append(0.0)
204 | recall_scores.append(0.0)
205 | f1_scores.append(0.0)
206 | return {
207 | "char_precision": round(np.mean(precision_scores).item(), 4) if precision_scores else 0.0,
208 | "char_recall": round(np.mean(recall_scores).item(), 4) if recall_scores else 0.0,
209 | "char_f1": round(np.mean(f1_scores).item(), 4) if f1_scores else 0.0
210 | }
211 |
212 | def _normalize(self, text):
213 | return text.lower().strip().replace(" ", "").replace("\n", "").replace("\t", "")
214 |
215 |
216 | def _safe_f1(self, p, r):
217 | denominator = p + r
218 | return 2 * p * r / denominator if denominator > 1e-9 else 0.0
219 |
220 | def _clean_text(self, text):
221 | text = re.sub(r'[\*\n\r]', '', text)
222 | text = re.sub(r'\s+', ' ', text)
223 | text = text.strip()
224 | return text
225 |
226 | def _preprocess_meteor(self, text):
227 | return ' '.join(jieba.cut(text))
228 |
229 | def _get_meteor(self, preds, refs):
230 |         nltk.download('wordnet', quiet=True)  # ensure WordNet is available without repeated download logs
231 | scores = []
232 | for p, r in zip(preds, refs):
233 | p=self._clean_text(p)
234 | r=self._clean_text(r)
235 | p_processed = self._preprocess_meteor(p).split()
236 | r_processed = self._preprocess_meteor(r).split()
237 |
238 | try:
239 | score = meteor_score(
240 | [r_processed],
241 | p_processed
242 | )
243 | scores.append(score)
244 | except Exception as e:
245 | print(f"METEOR ERROR: {str(e)}")
246 | scores.append(0.0)
247 |
248 | return {"meteor": round(np.mean(scores).item(), 4)} if scores else {"meteor": 0.0}
--------------------------------------------------------------------------------
/src/eval/retrieval_metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.metrics import ndcg_score
3 |
4 | class RetrievalMetrics:
5 | @staticmethod
6 | def recall(res_list: list[list[str]], label_list: list[list[str]], k: int) -> float:
7 | true_positives = 0
8 | false_negatives = 0
9 | for actual, predicted in zip(label_list, res_list):
10 | actual_set = set(actual)
11 | predicted_set = set(predicted[:k])
12 | true_positives += len(actual_set & predicted_set)
13 | false_negatives += len(actual_set - predicted_set)
14 | return true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0.0
15 |
16 | @staticmethod
17 | def precision(res_list: list[list[str]], label_list: list[list[str]], k: int) -> float:
18 | true_positives = 0
19 | false_positives = 0
20 | for actual, predicted in zip(label_list, res_list):
21 | actual_set = set(actual)
22 | predicted_set = set(predicted[:k])
23 | true_positives += len(actual_set & predicted_set)
24 | false_positives += len(predicted_set - actual_set)
25 | return true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0.0
26 |
27 | @staticmethod
28 | def f1_score(res_list: list[list[str]], label_list: list[list[str]], k: int) -> float:
29 | prec = RetrievalMetrics.precision(res_list, label_list, k)
30 | rec = RetrievalMetrics.recall(res_list, label_list, k)
31 | return 2 * (prec * rec) / (prec + rec) if (prec + rec) > 0 else 0.0
32 |
33 | @staticmethod
34 | def mrr(res_list: list[list[str]], label_list: list[list[str]], k: int) -> float:
35 | reciprocal_ranks = []
36 | for actual, predicted in zip(label_list, res_list):
37 | for i, item in enumerate(predicted[:k], 1):
38 | if item in actual:
39 | reciprocal_ranks.append(1 / i)
40 | break
41 | else:
42 | reciprocal_ranks.append(0)
43 | return np.mean(reciprocal_ranks)
44 |
45 | @staticmethod
46 | def ndcg(data_list: list[list[str]], score_list: list[list[float]], label_list: list[list[str]], k: int) -> float:
47 | ndcg_scores = []
48 | for data, scores, labels in zip(data_list, score_list, label_list):
49 | true_scores = [1 if item in labels else 0 for item in data]
50 | y_true = np.array([true_scores])
51 | y_score = np.array([scores])
52 | try:
53 | score = ndcg_score(y_true, y_score, k=k)
54 | except ValueError:
55 | score = 0.0
56 | ndcg_scores.append(score)
57 | return np.mean(ndcg_scores)
58 |
--------------------------------------------------------------------------------
/src/generate/data_processor.py:
--------------------------------------------------------------------------------
1 | # Data preprocessing of dataset(output generated_samples for each turn)
2 | import json
3 |
4 | class DataProcessor:
5 | @staticmethod
6 | def process_conversation_turns(raw_data_path):
7 | with open(raw_data_path, encoding="utf-8") as f:
8 | data = json.load(f)
9 |
10 | processed = {}
11 | for turn in range(1, 6):
12 | processed[f"turn_{turn}"] = []
13 |
14 | for item in data:
15 | conv = item["conversation"]
16 | for turn_num in range(1, 6):
17 | if turn_num > len(conv):
18 | continue
19 |
20 | current_turn = conv[turn_num-1]
21 |
22 | clean_history = [
23 | {
24 | "user": h["user"],
25 | "assistant": h["assistant"]
26 | }
27 | for h in conv[:turn_num-1]
28 | ]
29 |
30 | entry = {
31 | "id": f"{item['id']}_turn{turn_num}",
32 | "history": clean_history,
33 | "current_question": current_turn["user"],
34 | "reference": current_turn["assistant"],
35 | "keywords": current_turn.get("keyword", []),
36 | "article": current_turn.get("article", []),
37 | "article_context": current_turn.get("article_context", [])
38 | }
39 | processed[f"turn_{turn_num}"].append(entry)
40 |
41 | return processed
--------------------------------------------------------------------------------
/src/generate/generator.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import json
5 | import time
6 | from typing import List
7 | from tqdm import tqdm
8 | from zhipuai import ZhipuAI
9 | from openai import OpenAI
10 | import httpx
11 | from concurrent.futures import ThreadPoolExecutor, as_completed
12 | from transformers import AutoModelForCausalLM, AutoTokenizer
13 | from generate.prompt_builder import LegalPromptBuilder, CustomSystemPromptBuilder, FullCustomPromptBuilder
14 |
15 | class BaseGenerator:
16 | """Base class for all generators"""
17 | def __init__(self, config, max_retries, max_parallel, top_n, batch_size = 20):
18 | self.config = config
19 | self.max_retries = max_retries
20 | self.max_parallel = max_parallel
21 | self.top_n = top_n
22 | self.batch_size = batch_size
23 | self.failed_ids = set()
24 | self.logger = logging.getLogger(self.__class__.__name__)
25 |
26 | def _save_results(self, result_dict):
27 | with open("data/generated_responses.jsonl", "a", encoding="utf-8") as f:
28 | for item_id in sorted(result_dict.keys(), key=lambda x: int(x.split("_")[0])):
29 | f.write(json.dumps(result_dict[item_id], ensure_ascii=False) + "\n")
30 |
31 | class OpenAIGenerator(BaseGenerator):
32 | """Generator for OpenAI API models"""
33 | def __init__(self, prompt_builder, *args, **kwargs):
34 | super().__init__(*args, **kwargs)
35 | self.prompt_builder = prompt_builder
36 | self.model = self.config.get("model_name", "gpt-3.5-turbo")
37 | if self.model and self.model.startswith("gpt"):
38 | self.client = OpenAI(
39 | base_url=self.config["api_base"],
40 | api_key=self.config["api_key"],
41 | http_client=httpx.Client(
42 | base_url=self.config["api_base"],
43 | follow_redirects=True,
44 | ),
45 | )
46 | else:
47 | self.client = OpenAI(
48 | base_url=self.config["api_base"],
49 | api_key=self.config["api_key"]
50 | )
51 |
52 | def generate(self, processed_data, retrieval_data):
53 | for turn_num, samples in processed_data.items():
54 | messages_list, id_list, questions_list, articles_list, system_prompt = self._prepare_inputs(samples, retrieval_data)
55 | self._batch_call(messages_list, id_list, questions_list, articles_list, system_prompt)
56 |
57 | def _prepare_inputs(self, data, retrieval_data):
58 | messages_list = []
59 | id_list = []
60 | questions_list = []
61 | articles_list = []
62 | system_prompt = []
63 | for sample in data:
64 | articles = self._get_top_articles(sample["id"], retrieval_data, self.top_n)
65 | messages = self.prompt_builder.build_messages(
66 | sample["history"],
67 | sample["current_question"],
68 | articles
69 | )
70 |
71 | system_msg = next((msg["content"] for msg in messages if msg["role"] == "system"), "")
72 |
73 | messages_list.append([msg for msg in messages if msg["role"] != "system"])
74 | id_list.append(sample["id"])
75 | questions_list.append(sample["current_question"])
76 | articles_list.append(articles)
77 | system_prompt.append(system_msg)
78 |
79 |
80 | return messages_list, id_list, questions_list, articles_list, system_prompt
81 |
82 |
83 | def _get_top_articles(self, sample_id, retrieval_data, top_n):
84 | try:
85 | parts = sample_id.split("_")
86 | if len(parts) != 2 or not parts[1].startswith("turn"):
87 | raise ValueError(f"Invalid sample_id format: {sample_id}")
88 |
89 | dialogue_id = int(parts[0])
90 | turn_number = int(parts[1][4:])
91 | turn_index = turn_number - 1
92 | dialogue = retrieval_data.get(dialogue_id, {}).get("conversation", [])
93 | if not dialogue or turn_index >= len(dialogue):
94 | return []
95 |
96 | recall_list = dialogue[turn_index]["question"]["recall"]
97 | sorted_recall = sorted(recall_list,
98 | key=lambda x: x["score"],
99 | reverse=True)[:top_n]
100 |
101 | return [item["article"]["name"] for item in sorted_recall]
102 |
103 | except Exception as e:
104 | logging.error(f"Error processing sample {sample_id}: {str(e)}")
105 | return []
106 |
107 | def _call_api(self, messages, item_id, question, articles, system_prompt):
108 | full_messages = []
109 | if system_prompt:
110 | full_messages.append({"role": "system", "content": system_prompt})
111 | full_messages.extend(messages)
112 |
113 | for attempt in range(self.max_retries):
114 | try:
115 | response = self.client.chat.completions.create(
116 | model=self.model,
117 | messages=full_messages,
118 | temperature=0.0
119 | )
120 | return {
121 | "id": item_id,
122 | "question": question,
123 | "response": response.choices[0].message.content
124 | }
125 | except Exception as e:
126 | print(f"API Error (Attempt {attempt+1}): {str(e)}")
127 | time.sleep(2 ** attempt)
128 |
129 | self.failed_ids.add(item_id)
130 | return {"id": item_id, "question": question, "response": ""}
131 |
132 | def _batch_call(self, messages_list, id_list, questions_list, articles_list, system_prompts):
133 | result_dict = {}
134 | with ThreadPoolExecutor(max_workers=self.max_parallel) as executor:
135 | futures = []
136 |
137 | with tqdm(total=len(messages_list), desc="Generating responses") as pbar:
138 | for messages, item_id, question, articles, system_prompt in zip(messages_list, id_list, questions_list, articles_list, system_prompts):
139 | future = executor.submit(self._call_api, messages, item_id, question, articles, system_prompt)
140 | future.add_done_callback(lambda _: pbar.update(1))
141 | futures.append(future)
142 |
143 | if len(futures) >= self.batch_size:
144 | current_batch = []
145 | for future in as_completed(futures):
146 | result = future.result()
147 | result_dict[result["id"]] = result
148 | current_batch.append(result)
149 | self._save_results({r["id"]: r for r in current_batch})
150 | futures.clear()
151 |
152 | if futures:
153 | current_batch = []
154 | for future in as_completed(futures):
155 | result = future.result()
156 | result_dict[result["id"]] = result
157 | current_batch.append(result)
158 | self._save_results({r["id"]: r for r in current_batch})
159 |
160 | return [result_dict[id] for id in sorted(id_list, key=lambda x: int(x.split("_")[0]))]
161 |
162 | class ZhipuGenerator(BaseGenerator):
163 | """Generator for ZhipuAI API models"""
164 | def __init__(self, prompt_builder, *args, **kwargs):
165 | super().__init__(*args, **kwargs)
166 | self.prompt_builder = prompt_builder
167 | self.client = ZhipuAI(api_key=self.config["api_key"])
168 | self.model = self.config.get("model_name", "glm-4-flash")
169 |
170 | def generate(self, processed_data, retrieval_data):
171 | for turn_num, samples in processed_data.items():
172 | messages_list, id_list, questions_list, articles_list, system_prompt = self._prepare_inputs(samples, retrieval_data)
173 | self._batch_call(messages_list, id_list, questions_list, articles_list, system_prompt)
174 |
175 | def _prepare_inputs(self, data, retrieval_data):
176 | messages_list = []
177 | id_list = []
178 | questions_list = []
179 | articles_list = []
180 | system_prompt = []
181 | for sample in data:
182 | articles = self._get_top_articles(sample["id"], retrieval_data, self.top_n)
183 | messages = self.prompt_builder.build_messages(
184 | sample["history"],
185 | sample["current_question"],
186 | articles
187 | )
188 |
189 | system_msg = next((msg["content"] for msg in messages if msg["role"] == "system"), "")
190 |
191 | messages_list.append([msg for msg in messages if msg["role"] != "system"])
192 | id_list.append(sample["id"])
193 | questions_list.append(sample["current_question"])
194 | articles_list.append(articles)
195 | system_prompt.append(system_msg)
196 |
197 |
198 | return messages_list, id_list, questions_list, articles_list, system_prompt
199 |
200 |
201 | def _get_top_articles(self, sample_id, retrieval_data, top_n):
202 | try:
203 | parts = sample_id.split("_")
204 | if len(parts) != 2 or not parts[1].startswith("turn"):
205 | raise ValueError(f"Invalid sample_id format: {sample_id}")
206 |
207 | dialogue_id = int(parts[0])
208 | turn_number = int(parts[1][4:])
209 | turn_index = turn_number - 1
210 | dialogue = retrieval_data.get(dialogue_id, {}).get("conversation", [])
211 | if not dialogue or turn_index >= len(dialogue):
212 | return []
213 |
214 | recall_list = dialogue[turn_index]["question"]["recall"]
215 | sorted_recall = sorted(recall_list,
216 | key=lambda x: x["score"],
217 | reverse=True)[:top_n]
218 |
219 | return [item["article"]["name"] for item in sorted_recall]
220 |
221 | except Exception as e:
222 | logging.error(f"Error processing sample {sample_id}: {str(e)}")
223 | return []
224 |
225 | def _call_api(self, messages, item_id, question, articles, system_prompt):
226 | full_messages = []
227 | if system_prompt:
228 | full_messages.append({"role": "system", "content": system_prompt})
229 | full_messages.extend(messages)
230 |
231 | for attempt in range(self.max_retries):
232 | try:
233 | response = self.client.chat.completions.create(
234 | model=self.model,
235 | messages=full_messages,
236 | temperature=0.0
237 | )
238 | return {
239 | "id": item_id,
240 | "question": question,
241 | "response": response.choices[0].message.content
242 | }
243 | except Exception as e:
244 | print(f"API Error (Attempt {attempt+1}): {str(e)}")
245 | time.sleep(2 ** attempt)
246 |
247 | self.failed_ids.add(item_id)
248 | return {"id": item_id, "question": question, "response": ""}
249 |
250 | def _batch_call(self, messages_list, id_list, questions_list, articles_list, system_prompts):
251 | result_dict = {}
252 | with ThreadPoolExecutor(max_workers=self.max_parallel) as executor:
253 | futures = []
254 |
255 | with tqdm(total=len(messages_list), desc="Generating responses") as pbar:
256 | for messages, item_id, question, articles, system_prompt in zip(messages_list, id_list, questions_list, articles_list, system_prompts):
257 | future = executor.submit(self._call_api, messages, item_id, question, articles, system_prompt)
258 | future.add_done_callback(lambda _: pbar.update(1))
259 | futures.append(future)
260 |
261 | if len(futures) >= self.batch_size:
262 | current_batch = []
263 | for future in as_completed(futures):
264 | result = future.result()
265 | result_dict[result["id"]] = result
266 | current_batch.append(result)
267 | self._save_results({r["id"]: r for r in current_batch})
268 | futures.clear()
269 |
270 | if futures:
271 | current_batch = []
272 | for future in as_completed(futures):
273 | result = future.result()
274 | result_dict[result["id"]] = result
275 | current_batch.append(result)
276 | self._save_results({r["id"]: r for r in current_batch})
277 |
278 | return [result_dict[id] for id in sorted(id_list, key=lambda x: int(x.split("_")[0]))]
279 |
280 | class VLLMGenerator(BaseGenerator):
281 | """Generator for vLLM models"""
282 | def __init__(self, prompt_builder, *args, **kwargs):
283 | self.prompt_builder = prompt_builder
284 | kwargs.pop('prompt_builder', None)
285 | super().__init__(*args, **kwargs)
286 | from vllm import LLM, SamplingParams
287 | self.llm = LLM(
288 | model=self.config["model_path"],
289 | tensor_parallel_size=self.config["gpu_num"],
290 | gpu_memory_utilization=0.85
291 | )
292 | self.sampling_params = SamplingParams(
293 | temperature=0.0,
294 | max_tokens=4096
295 | )
296 |
297 | def generate(self, processed_data, retrieval_data):
298 | for turn_num, samples in processed_data.items():
299 | messages_list, id_list, questions_list, articles_list, system_prompts = self._prepare_inputs(samples, retrieval_data)
300 | prompts = self._build_vllm_prompts(messages_list, system_prompts)
301 | responses = self.llm.generate(prompts, self.sampling_params)
302 | self._save_vllm_results(responses, id_list, questions_list)
303 |
304 | def _prepare_inputs(self, data, retrieval_data):
305 | messages_list = []
306 | id_list = []
307 | questions_list = []
308 | articles_list = []
309 | system_prompt = []
310 | for sample in data:
311 | articles = self._get_top_articles(sample["id"], retrieval_data, self.top_n)
312 | messages = self.prompt_builder.build_messages(
313 | sample["history"],
314 | sample["current_question"],
315 | articles
316 | )
317 |
318 | system_msg = next((msg["content"] for msg in messages if msg["role"] == "system"), "")
319 |
320 | messages_list.append([msg for msg in messages if msg["role"] != "system"])
321 | id_list.append(sample["id"])
322 | questions_list.append(sample["current_question"])
323 | articles_list.append(articles)
324 | system_prompt.append(system_msg)
325 |
326 |
327 | return messages_list, id_list, questions_list, articles_list, system_prompt
328 |
329 | def _build_vllm_prompts(self, messages_list, system_prompts):
330 | prompts = []
331 | for messages, system_prompt in zip(messages_list, system_prompts):
332 | full_dialog = []
333 | if system_prompt:
334 | full_dialog.append(f"System: {system_prompt}")
335 |
336 |             for msg in messages[:-1]:  # history turns only; the current question is appended below
337 | if msg["role"] == "user":
338 | full_dialog.append(f"User: {msg['content']}")
339 | elif msg["role"] == "assistant":
340 | full_dialog.append(f"Assistant: {msg['content']}")
341 |
342 | last_question = messages[-1]["content"] if messages else ""
343 | full_dialog.append(f"User: {last_question}")
344 |
345 | prompt_text = "\n".join(full_dialog) + "\nAssistant:"
346 | prompts.append(prompt_text)
347 | return prompts
348 |
349 | def _get_top_articles(self, sample_id, retrieval_data, top_n):
350 | try:
351 | parts = sample_id.split("_")
352 | if len(parts) != 2 or not parts[1].startswith("turn"):
353 | raise ValueError(f"Invalid sample_id format: {sample_id}")
354 |
355 | dialogue_id = int(parts[0])
356 | turn_number = int(parts[1][4:])
357 | turn_index = turn_number - 1
358 | dialogue = retrieval_data.get(dialogue_id, {}).get("conversation", [])
359 | if not dialogue or turn_index >= len(dialogue):
360 | return []
361 |
362 | recall_list = dialogue[turn_index]["question"]["recall"]
363 | sorted_recall = sorted(recall_list,
364 | key=lambda x: x["score"],
365 | reverse=True)[:top_n]
366 |
367 | return [item["article"]["name"] for item in sorted_recall]
368 |
369 | except Exception as e:
370 | logging.error(f"Error processing sample {sample_id}: {str(e)}")
371 | return []
372 |
373 | def _save_vllm_results(self, outputs, id_list, questions_list):
374 | results = {}
375 | for output, item_id, question in zip(outputs, id_list, questions_list):
376 | results[item_id] = {
377 | "id": item_id,
378 | "question": question,
379 | "response": output.outputs[0].text.strip()
380 | }
381 | self._save_results(results)
382 |
383 | class HuggingFaceGenerator(BaseGenerator):
384 | """Generator for HuggingFace models"""
385 | def __init__(self, config, prompt_builder, **kwargs):
386 | super().__init__(config, **kwargs)
387 | self.prompt_builder = prompt_builder
388 | from transformers import AutoModelForCausalLM, AutoTokenizer
389 |
390 | self.tokenizer = AutoTokenizer.from_pretrained(
391 | config["model_path"],
392 | trust_remote_code=True
393 | )
394 | self.model = AutoModelForCausalLM.from_pretrained(
395 | config["model_path"],
396 | device_map="auto",
397 | trust_remote_code=True
398 | )
399 |
400 | def _call_api(self, messages, item_id, question, articles, system_prompt):
401 | try:
402 | full_messages = []
403 | if system_prompt:
404 | full_messages.append({"role": "system", "content": system_prompt})
405 | full_messages.extend(messages)
406 |
407 | inputs = self.tokenizer.apply_chat_template(
408 | full_messages,
409 | tokenize=False,
410 | add_generation_prompt=True
411 | )
412 | model_inputs = self.tokenizer(inputs, return_tensors="pt").to("cuda")
413 |
414 | outputs = self.model.generate(
415 | model_inputs.input_ids,
416 |                 do_sample=False, max_new_tokens=4096  # explicit cap, matching the vLLM generator
417 | )
418 | response = self.tokenizer.decode(
419 | outputs[0][len(model_inputs.input_ids[0]):],
420 | skip_special_tokens=True
421 | )
422 | return {
423 | "id": item_id,
424 | "question": question,
425 | "response": response
426 | }
427 | except Exception as e:
428 |             print(f"Generate Error: {str(e)}")
429 | return {"id": item_id, "question": question, "response": ""}
430 |
431 | def generate(self, processed_data, retrieval_data):
432 | for turn_num, samples in processed_data.items():
433 | messages_list, id_list, questions_list, articles_list, system_prompts = self._prepare_inputs(samples, retrieval_data)
434 | self._batch_call(messages_list, id_list, questions_list, articles_list, system_prompts)
435 |
436 | def _prepare_inputs(self, data, retrieval_data):
437 | messages_list = []
438 | id_list = []
439 | questions_list = []
440 | articles_list = []
441 | system_prompt = []
442 | for sample in data:
443 | articles = self._get_top_articles(sample["id"], retrieval_data, self.top_n)
444 | messages = self.prompt_builder.build_messages(
445 | sample["history"],
446 | sample["current_question"],
447 | articles
448 | )
449 |
450 | system_msg = next((msg["content"] for msg in messages if msg["role"] == "system"), "")
451 |
452 | messages_list.append([msg for msg in messages if msg["role"] != "system"])
453 | id_list.append(sample["id"])
454 | questions_list.append(sample["current_question"])
455 | articles_list.append(articles)
456 | system_prompt.append(system_msg)
457 |
458 |
459 | return messages_list, id_list, questions_list, articles_list, system_prompt
460 |
461 |
462 | def _get_top_articles(self, sample_id, retrieval_data, top_n):
463 | try:
464 | parts = sample_id.split("_")
465 | if len(parts) != 2 or not parts[1].startswith("turn"):
466 | raise ValueError(f"Invalid sample_id format: {sample_id}")
467 |
468 | dialogue_id = int(parts[0])
469 | turn_number = int(parts[1][4:])
470 | turn_index = turn_number - 1
471 | dialogue = retrieval_data.get(dialogue_id, {}).get("conversation", [])
472 | if not dialogue or turn_index >= len(dialogue):
473 | return []
474 |
475 | recall_list = dialogue[turn_index]["question"]["recall"]
476 | sorted_recall = sorted(recall_list,
477 | key=lambda x: x["score"],
478 | reverse=True)[:top_n]
479 |
480 | return [item["article"]["name"] for item in sorted_recall]
481 |
482 | except Exception as e:
483 | logging.error(f"Error processing sample {sample_id}: {str(e)}")
484 | return []
485 |
486 | def _batch_call(self, messages_list, id_list, questions_list, articles_list, system_prompts):
487 | result_dict = {}
488 | with ThreadPoolExecutor(max_workers=self.max_parallel) as executor:
489 | futures = []
490 |
491 | with tqdm(total=len(messages_list), desc="Generating responses") as pbar:
492 | for messages, item_id, question, articles, system_prompt in zip(messages_list, id_list, questions_list, articles_list, system_prompts):
493 | future = executor.submit(self._call_api, messages, item_id, question, articles, system_prompt)
494 | future.add_done_callback(lambda _: pbar.update(1))
495 | futures.append(future)
496 |
497 | if len(futures) >= self.batch_size:
498 | current_batch = []
499 | for future in as_completed(futures):
500 | result = future.result()
501 | result_dict[result["id"]] = result
502 | current_batch.append(result)
503 | self._save_results({r["id"]: r for r in current_batch})
504 | futures.clear()
505 |
506 | if futures:
507 | current_batch = []
508 | for future in as_completed(futures):
509 | result = future.result()
510 | result_dict[result["id"]] = result
511 | current_batch.append(result)
512 | self._save_results({r["id"]: r for r in current_batch})
513 |
514 | return [result_dict[id] for id in sorted(id_list, key=lambda x: int(x.split("_")[0]))]
515 |
516 | class CPULocalGenerator(HuggingFaceGenerator):  # CPU-only variant of the HuggingFace generator
517 |     def __init__(self, config, prompt_builder, **kwargs):
518 |         super(HuggingFaceGenerator, self).__init__(config, **kwargs)  # skip HuggingFaceGenerator.__init__ so no GPU model is loaded here
519 | self.prompt_builder = prompt_builder
520 |
521 | from transformers import AutoModelForCausalLM, AutoTokenizer
522 | self.tokenizer = AutoTokenizer.from_pretrained(
523 | config["model_path"],
524 | trust_remote_code=True
525 | )
526 | self.model = AutoModelForCausalLM.from_pretrained(
527 | config["model_path"],
528 | device_map="cpu",
529 | torch_dtype="auto",
530 | trust_remote_code=True
531 | )
532 |
533 | def _call_api(self, messages, item_id, question, articles, system_prompt):
534 | try:
535 | full_messages = []
536 | if system_prompt:
537 | full_messages.append({"role": "system", "content": system_prompt})
538 | full_messages.extend(messages)
539 |
540 | inputs = self.tokenizer.apply_chat_template(
541 | full_messages,
542 | tokenize=False,
543 | add_generation_prompt=True
544 | )
545 | model_inputs = self.tokenizer(inputs, return_tensors="pt")
546 |
547 | outputs = self.model.generate(
548 | model_inputs.input_ids,
549 |             do_sample=False, max_new_tokens=4096  # explicit cap, matching the vLLM generator
550 | )
551 | response = self.tokenizer.decode(
552 | outputs[0][len(model_inputs.input_ids[0]):],
553 | skip_special_tokens=True
554 | )
555 | return {
556 | "id": item_id,
557 | "question": question,
558 | "response": response
559 | }
560 | except Exception as e:
561 |             print(f"Generate Error: {str(e)}")
562 | return {"id": item_id, "question": question, "response": ""}
563 |
564 | class LocalGenerator(BaseGenerator):
565 |     """Generator for locally stored HuggingFace models (CUDA)"""
566 | def __init__(self, prompt_builder, *args, **kwargs):
567 | super().__init__(*args, **kwargs)
568 | self.prompt_builder = prompt_builder
569 | self._load_local_model()
570 |
571 |     def _load_local_model(self):
572 |         from transformers import AutoModelForCausalLM, AutoTokenizer
573 |         try:
574 |             self.tokenizer = AutoTokenizer.from_pretrained(
575 |                 self.config["model_path"],
576 |                 trust_remote_code=True, local_files_only=True
577 |             )
578 | self.model = AutoModelForCausalLM.from_pretrained(
579 | self.config["model_path"],
580 | device_map="cuda",
581 | trust_remote_code=True,
582 | local_files_only=True
583 | )
584 | except Exception as e:
585 |             raise RuntimeError(f"Load error! Check model path: {self.config['model_path']}") from e
586 |
587 | def generate(self, processed_data, retrieval_data):
588 | for turn_num, samples in processed_data.items():
589 | messages_list, id_list, questions_list, articles_list, system_prompts = self._prepare_inputs(samples, retrieval_data)
590 | self._batch_call(messages_list, id_list, questions_list, articles_list, system_prompts)
591 |
592 | def _prepare_inputs(self, data, retrieval_data):
593 | messages_list = []
594 | id_list = []
595 | questions_list = []
596 | articles_list = []
597 | system_prompts = []
598 | for sample in data:
599 | articles = self._get_top_articles(sample["id"], retrieval_data, self.top_n)
600 | messages = self.prompt_builder.build_messages(
601 | sample["history"],
602 | sample["current_question"],
603 | articles
604 | )
605 |
606 | system_msg = next((msg["content"] for msg in messages if msg["role"] == "system"), "")
607 |
608 | messages_list.append([msg for msg in messages if msg["role"] != "system"])
609 | id_list.append(sample["id"])
610 | questions_list.append(sample["current_question"])
611 | articles_list.append(articles)
612 | system_prompts.append(system_msg)
613 |
614 | return messages_list, id_list, questions_list, articles_list, system_prompts
615 |
616 | def _get_top_articles(self, sample_id, retrieval_data, top_n):
617 | try:
618 | parts = sample_id.split("_")
619 | if len(parts) != 2 or not parts[1].startswith("turn"):
620 | raise ValueError(f"Invalid sample_id format: {sample_id}")
621 |
622 | dialogue_id = int(parts[0])
623 | turn_number = int(parts[1][4:])
624 | turn_index = turn_number - 1
625 | dialogue = retrieval_data.get(dialogue_id, {}).get("conversation", [])
626 | if not dialogue or turn_index >= len(dialogue):
627 | return []
628 |
629 | recall_list = dialogue[turn_index]["question"]["recall"]
630 | sorted_recall = sorted(recall_list,
631 | key=lambda x: x["score"],
632 | reverse=True)[:top_n]
633 |
634 | return [item["article"]["name"] for item in sorted_recall]
635 |
636 | except Exception as e:
637 |             logging.error(f"Error processing sample {sample_id}: {str(e)}")
638 | return []
639 |
640 | def _call_api(self, messages, item_id, question, articles, system_prompt):
641 | try:
642 | full_messages = []
643 | if system_prompt:
644 | full_messages.append({"role": "system", "content": system_prompt})
645 | full_messages.extend(messages)
646 |
647 | inputs = self.tokenizer.apply_chat_template(
648 | full_messages,
649 | tokenize=False,
650 | add_generation_prompt=True
651 | )
652 | model_inputs = self.tokenizer(inputs, return_tensors="pt").to(self.model.device)
653 |
654 | outputs = self.model.generate(
655 | model_inputs.input_ids,
656 |             do_sample=False, max_new_tokens=4096  # explicit cap, matching the vLLM generator
657 | )
658 | response = self.tokenizer.decode(
659 | outputs[0][len(model_inputs.input_ids[0]):],
660 | skip_special_tokens=True
661 | )
662 |
663 | return {
664 | "id": item_id,
665 | "question": question,
666 | "response": response.strip()
667 | }
668 | except Exception as e:
669 |             logging.error(f"Generate Error for ID {item_id}: {str(e)}")
670 | return {"id": item_id, "question": question, "response": ""}
671 |
672 | def _batch_call(self, messages_list, id_list, questions_list, articles_list, system_prompts):
673 | result_dict = {}
674 | with ThreadPoolExecutor(max_workers=self.max_parallel) as executor:
675 | futures = []
676 |
677 | with tqdm(total=len(messages_list), desc="Local Model Generate") as pbar:
678 | for messages, item_id, question, articles, system_prompt in zip(
679 | messages_list, id_list, questions_list, articles_list, system_prompts
680 | ):
681 | future = executor.submit(
682 | self._call_api,
683 | messages,
684 | item_id,
685 | question,
686 | articles,
687 | system_prompt
688 | )
689 | future.add_done_callback(lambda _: pbar.update(1))
690 | futures.append(future)
691 |
692 | if len(futures) >= self.batch_size:
693 | self._process_batch(futures, result_dict)
694 | futures.clear()
695 |
696 | if futures:
697 | self._process_batch(futures, result_dict)
698 |
699 | return [result_dict[id] for id in sorted(id_list, key=lambda x: int(x.split("_")[0]))]
700 |
701 | def _process_batch(self, futures, result_dict):
702 | current_batch = []
703 | for future in as_completed(futures):
704 | result = future.result()
705 | result_dict[result["id"]] = result
706 | current_batch.append(result)
707 | self._save_results({r["id"]: r for r in current_batch})
--------------------------------------------------------------------------------
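All four generators resolve retrieved articles the same way: sample ids follow the `<dialogue_id>_turn<k>` convention, and the retrieval results are keyed by dialogue id with a scored `recall` list per turn. A minimal sketch of that lookup (the article names and scores are purely illustrative):

```
retrieval_data = {
    12: {"conversation": [
        {"question": {"recall": [
            {"article": {"name": "《中华人民共和国民法典》第1079条"}, "score": 0.91},
            {"article": {"name": "《中华人民共和国民法典》第1084条"}, "score": 0.84},
        ]}}
    ]}
}

sample_id = "12_turn1"                      # dialogue 12, first turn
dialogue_id, turn = sample_id.split("_")
recall = retrieval_data[int(dialogue_id)]["conversation"][int(turn[4:]) - 1]["question"]["recall"]
top_articles = [r["article"]["name"] for r in sorted(recall, key=lambda r: r["score"], reverse=True)[:1]]
print(top_articles)                         # ['《中华人民共和国民法典》第1079条']
```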
/src/generate/prompt_builder.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import List, Dict
3 |
4 | class PromptBuilder(ABC):
5 | @abstractmethod
6 | def build_messages(self,
7 | history: List[Dict],
8 | current_question: str,
9 | articles: List[str]) -> List[Dict]:
10 | pass
11 |
12 | class LegalPromptBuilder(PromptBuilder):
13 | """Default PromptBuilder"""
14 | def build_messages(self, history, current_question, articles):
15 | messages = []
16 | system_msg = {
17 | "role": "system",
18 | "content": "你是一位精通法律知识的专家,致力于为用户提供准确、专业的法律咨询。你的回复应确保严谨、高效,并在风格上与前几轮的回答保持一致(如有)。若用户的问题涉及具体法律条文,应尽可能引用相关法条,以增强回答的权威性。同时,避免提供无关信息,确保回复简明、直接且切中要害。"
19 | }
20 |
21 | if articles:
22 | system_msg["content"] += "\n\n以下是你可以参考的法条:\n" + "\n".join(
23 | [f"{i+1}. {art}" for i, art in enumerate(articles)]
24 | )
25 |
26 | messages.append(system_msg)
27 | for h in history:
28 | messages.extend([
29 | {"role": "user", "content": h["user"]},
30 | {"role": "assistant", "content": h["assistant"]}
31 | ])
32 | messages.append({"role": "user", "content": current_question})
33 | return messages
34 |
35 | class CustomSystemPromptBuilder(PromptBuilder):
36 |     """PromptBuilder with a user-supplied system prompt"""
37 | def __init__(self, system_template: str):
38 | self.system_template = system_template
39 |
40 | def build_messages(self, history, current_question, articles):
41 | messages = []
42 | system_content = self.system_template
43 | if articles:
44 | system_content += "\n\n以下是你可以参考的法条:\n" + "\n".join(articles)
45 |
46 | messages.append({"role": "system", "content": system_content})
47 | for h in history:
48 | messages.extend([
49 | {"role": "user", "content": h["user"]},
50 | {"role": "assistant", "content": h["assistant"]}
51 | ])
52 | messages.append({"role": "user", "content": current_question})
53 | return messages
54 |
55 | class FullCustomPromptBuilder(PromptBuilder):
56 | """Fully customisable PromptBuilder"""
57 | def __init__(self, build_fn):
58 |         self.build_fn = build_fn  # user-supplied function that builds the full message list
59 |
60 | def build_messages(self, history, current_question, articles):
61 | return self.build_fn(history, current_question, articles)
--------------------------------------------------------------------------------
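To change how prompts are assembled, pass one of these builders to `GeneratorPipeline`. A short sketch (the system prompt text and model type are placeholders; it assumes API credentials are configured in `src/config`):

```
from generate.prompt_builder import CustomSystemPromptBuilder, FullCustomPromptBuilder
from pipeline import GeneratorPipeline

# Replace only the system prompt; retrieved articles are still appended to it.
builder = CustomSystemPromptBuilder("You are a legal assistant. Cite the provided provisions where relevant.")

# Or take full control of message construction.
def qa_only(history, current_question, articles):
    context = "\n".join(articles)
    return [{"role": "user", "content": f"{context}\n\n{current_question}"}]

pipeline = GeneratorPipeline(model_type="zhipu", prompt_builder=builder)
# pipeline = GeneratorPipeline(model_type="zhipu", prompt_builder=FullCustomPromptBuilder(qa_only))
```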
/src/pipeline.py:
--------------------------------------------------------------------------------
1 | from utils.utils import create_generator, create_processor, create_evaluator, create_retriever
2 | from generate.data_processor import DataProcessor
3 | from config.config import Config
4 | import json
5 | import os
6 |
7 | class GeneratorPipeline:
8 | def __init__(self, model_type=None, config=None, prompt_builder=None):
9 | self.model_type = model_type
10 | from config.config import Config
11 | if isinstance(config, Config):
12 | self.config = config.config
13 | elif config:
14 | self.config = config
15 | else:
16 | cfg = Config(model_type=model_type)
17 | self.config = cfg.config
18 |
19 | from generate.prompt_builder import LegalPromptBuilder
20 | if prompt_builder:
21 | self.prompt_builder = prompt_builder
22 | else:
23 | self.prompt_builder = LegalPromptBuilder()
24 |
25 | self.data_processor = DataProcessor()
26 |
27 | def run_generator(self,
28 | raw_data_path,
29 | retrieval_data_path,
30 | top_n,
31 | max_retries=None,
32 | max_parallel=None,
33 | batch_size=None
34 | ):
35 | processed_data = self.data_processor.process_conversation_turns(raw_data_path)
36 | retrieval_data = self._load_retrieval_data(retrieval_data_path)
37 |
38 | generator = create_generator(
39 | model_type=self.model_type,
40 | config=self.config,
41 | prompt_builder=self.prompt_builder,
42 | max_retries=max_retries,
43 | max_parallel=max_parallel,
44 | top_n=top_n,
45 | batch_size=batch_size
46 | )
47 | generator.generate(processed_data, retrieval_data)
48 |
49 | def _load_retrieval_data(self, data_path):
50 | llm_data = {}
51 | with open(data_path, "r", encoding="utf-8") as f:
52 | for line in f:
53 | entry = json.loads(line)
54 | llm_data[entry["id"]] = entry
55 | return llm_data
56 |
57 | class ProcessorPipeline:
58 |
59 | def __init__(self, model_type=None, config=None):
60 |         if config:
61 |             self.config = config if isinstance(config, Config) else Config(config_dict=config)
62 |         elif model_type:
63 |             self.config = Config(model_type=model_type)
64 |         else:
65 |             # config is only required by the LLM-based "rewrite_question" strategy
66 |             self.config = None
67 | 
68 | def run_processor(self,
69 | process_type: str,
70 | original_data_path: str,
71 | output_path: str,
72 | max_retries: int = None,
73 | max_parallel: int = None,
74 | batch_size: int = None):
75 |
76 | if process_type == "rewrite_question":
77 | processor = create_processor(
78 | "rewrite_question",
79 | config=self.config,
80 | max_retries=max_retries,
81 | max_parallel=max_parallel,
82 | batch_size=batch_size
83 | )
84 | processor.batch_process(original_data_path, output_path)
85 | else:
86 | processor = create_processor(
87 | process_type=process_type
88 | )
89 | processor.run_process(original_data_path, output_path)
90 |
91 | class EvaluatorPipeline:
92 | def __init__(self, model_type: str = None):
93 | self.model_type = model_type
94 | self.config = Config._default_configs[model_type] if model_type else None
95 |
96 | def run_evaluator(self,
97 | eval_type,
98 | metrics=None,
99 | data_path: str = None,
100 | gen_path: str = None,
101 | response_file: str = None,
102 | results_path: str = None,
103 | k_values=None):
104 | evaluator = create_evaluator(eval_type, self.config)
105 |
106 | if eval_type == "generation":
107 | return evaluator.evaluate(data_path, response_file, metrics)
108 | elif eval_type == "llm_judge":
109 | return evaluator.evaluate(data_path, gen_path)
110 | elif eval_type == "retrieval":
111 | return evaluator.evaluate(results_path, metrics, k_values)
112 | raise ValueError(f"Unsupported evaluator type: {eval_type}")
113 |
114 | class RetrieverPipeline:
115 | def __init__(self, config=None):
116 | self.config = config
117 | self.retriever = create_retriever(config=self.config)
118 |
119 | def run_retriever(self, model_type, question_file_path, law_path, **kwargs):
120 | self.retriever.run_retriever(
121 | model_type=model_type,
122 | question_file_path=question_file_path,
123 | law_path=law_path,
124 | **kwargs
125 | )
126 |
--------------------------------------------------------------------------------
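Putting the pieces together, a typical retrieval → generation → evaluation run looks roughly like this. Paths, `top_n`, the concurrency settings, and the metric names are illustrative; see `src/eval/` for the metrics actually supported:

```
from pipeline import RetrieverPipeline, GeneratorPipeline, EvaluatorPipeline

# 1. Retrieve candidate articles for every processed question.
RetrieverPipeline().run_retriever(
    model_type="bm25",
    question_file_path="data/current_question.jsonl",
    law_path="data/law_library.jsonl",
)

# 2. Generate answers conditioned on the top-5 retrieved articles.
GeneratorPipeline(model_type="zhipu").run_generator(
    raw_data_path="data/dataset.json",
    retrieval_data_path="data/retrieval/res/retrieval_bm25_bm25s.jsonl",
    top_n=5, max_retries=3, max_parallel=4, batch_size=16,
)

# 3. Score the retrieval run.
EvaluatorPipeline().run_evaluator(
    eval_type="retrieval",
    metrics=["recall"],          # assumed metric name
    results_path="data/retrieval/res/retrieval_bm25_bm25s.jsonl",
    k_values=[1, 5, 10],
)
```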
/src/process/processor.py:
--------------------------------------------------------------------------------
1 | import json
2 | from pathlib import Path
3 | from tqdm import tqdm
4 |
5 | class QuestionGenerator:
6 | def __init__(self, process_type: str):
7 | self.process_type = process_type
8 | self.process_methods = {
9 | "current_question": self._current_question,
10 | "prefix_question": self._prefix_question,
11 | "prefix_question_answer": self._prefix_question_answer,
12 | "suffix_question": self._suffix_question
13 | }
14 |
15 | def run_process(self, original_data_path: str, output_path: str):
16 | method = self.process_methods.get(self.process_type)
17 | if not method:
18 | raise ValueError(f"Unsupported process type: {self.process_type}")
19 |
20 | Path(output_path).parent.mkdir(parents=True, exist_ok=True)
21 | with open(original_data_path, "r", encoding="utf-8") as f:
22 | data_list = json.load(f)
23 |
24 | method(data_list, output_path)
25 |
26 | def _current_question(self, data_list, file_path):
27 | for data in data_list:
28 | for conversation in data["conversation"]:
29 | conversation["question"] = {
30 | "type": "current_question",
31 | "content": conversation["user"],
32 | }
33 | self._save_output(data_list, file_path)
34 |
35 | def _prefix_question(self, data_list, file_path):
36 | for data in data_list:
37 | prefix_question = ""
38 | for conversation in data["conversation"]:
39 | prefix_question += conversation["user"]
40 | conversation["question"] = {
41 | "type": "prefix_question",
42 | "content": prefix_question,
43 | }
44 | self._save_output(data_list, file_path)
45 |
46 | def _prefix_question_answer(self, data_list, file_path):
47 | for data in data_list:
48 | prefix_question_answer = ""
49 | for conversation in data["conversation"]:
50 | prefix_question_answer += f" {conversation['user']}\n\n"
51 | conversation["question"] = {
52 | "type": "prefix_question_answer",
53 | "content": prefix_question_answer,
54 | }
55 | prefix_question_answer += f"{conversation['assistant']}\n\n"
56 | self._save_output(data_list, file_path)
57 |
58 | def _suffix_question(self, data_list, file_path):
59 | for data in data_list:
60 | suffix_question = ""
61 | for idx in range(len(data["conversation"]) - 1, -1, -1):
62 | suffix_question += f"{data['conversation'][idx]['user']}\n\n"
63 | data["conversation"][idx]["question"] = {
64 | "type": "suffix_question",
65 | "content": suffix_question,
66 | }
67 | self._save_output(data_list, file_path)
68 |
69 | def _save_output(self, data_list, file_path):
70 | with open(file_path, "w", encoding="utf-8") as f:
71 | for data in data_list:
72 | f.write(json.dumps(data, ensure_ascii=False) + "\n")
73 |
--------------------------------------------------------------------------------
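Used directly, `QuestionGenerator` annotates every conversation turn with the constructed query; for example, the `prefix_question` strategy concatenates all user turns seen so far (the output path is illustrative):

```
from process.processor import QuestionGenerator

QuestionGenerator("prefix_question").run_process(
    original_data_path="data/dataset.json",
    output_path="data/prefix_question.jsonl",
)
# Each turn in the output JSONL gains a field of the form
# "question": {"type": "prefix_question", "content": "<user turn 1><user turn 2>..."}
```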
/src/process/rewriter.py:
--------------------------------------------------------------------------------
1 | import json
2 | import time
3 | import httpx
4 | from tqdm import tqdm
5 | from pathlib import Path
6 | from openai import OpenAI
7 | from concurrent.futures import ThreadPoolExecutor, as_completed
8 | from zhipuai import ZhipuAI
9 |
10 | class Rewriter:
11 | def __init__(self, config, max_retries, max_parallel, batch_size):
12 | self.config = config
13 | self.max_retries = max_retries
14 | self.max_parallel = max_parallel
15 | self.batch_size = batch_size
16 | self.model = self.config.get("model_name")
17 |
18 | if self.config.get("model_type") == "openai":
19 | self.client = OpenAI(
20 | base_url=self.config.get("api_base"),
21 | api_key=self.config.get("api_key"),
22 | http_client=httpx.Client(
23 | base_url=self.config.get("api_base"),
24 | follow_redirects=True,
25 | ),
26 | )
27 | elif self.config.get("model_type") == "zhipu":
28 | self.client = ZhipuAI(api_key=self.config.get("api_key"))
29 | elif self.config.get("model_type") in ["qwen","llama"]:
30 | self.client = OpenAI(
31 | base_url=self.config.get("api_base"),
32 | api_key=self.config.get("api_key")
33 | )
34 |
35 | self.results = {}
36 | self.failed_ids = set()
37 |
38 | def generate_prompt(self, history, question):
39 | if not history.strip() or history == "无历史对话":
40 | return question
41 | return f"""给定以下对话(包括历史对话和当前问题),请将用户的当前问题改写为一个独立的问题,使其无需依赖对话历史即可理解用户的意图。
42 | 在改写用户的当前问题时,请避免不必要的措辞修改或引入对话中未提及的新术语或概念。改写应尽可能接近用户当前问题的结构和含义。
43 |
44 | 历史对话:
45 | {history}
46 |
47 | 当前问题:{question}
48 |
49 | 请输出改写结果:"""
50 |
51 | def process_single(self, data_id, conv_idx, history, question):
52 | if not history.strip() or history == "无历史对话":
53 | return (data_id, conv_idx, question)
54 | unique_id = f"{data_id}_{conv_idx}"
55 | for attempt in range(self.max_retries):
56 | try:
57 | response = self.client.chat.completions.create(
58 | model=self.model,
59 | messages=[{
60 | "role": "user",
61 | "content": self.generate_prompt(history, question)
62 | }],
63 | temperature=0.0,
64 | stream=False
65 | )
66 |
67 | reworded = response.choices[0].message.content.strip()
68 |
69 | reworded = reworded.replace('"', '').replace("```", "").strip()
70 | return (data_id, conv_idx, reworded)
71 |
72 | except Exception as e:
73 | print(f"[{unique_id}] Attempt {attempt+1} failed: {str(e)}")
74 | time.sleep(2 ** attempt)
75 |
76 | self.failed_ids.add(unique_id)
77 | return (data_id, conv_idx, question)
78 |
79 |     def batch_process(self, original_data_path, output_path):
80 |         with open(original_data_path, "r", encoding="utf-8") as f:
81 |             original_data = json.load(f)
82 |
83 | tasks = []
84 | for data in original_data:
85 | data_id = data["id"]
86 | for conv_idx, conv in enumerate(data["conversation"]):
87 | history = "\n".join(
88 | [f"用户:{c['user']}\n助理:{c['assistant']}"
89 | for c in data["conversation"][:conv_idx]]
90 | )
91 | if not history.strip():
92 | history = "无历史对话"
93 | tasks.append((
94 | data_id,
95 | conv_idx,
96 | history,
97 | conv["user"]
98 | ))
99 |
100 | completed_count = 0
101 |
102 | with ThreadPoolExecutor(max_workers=self.max_parallel) as executor:
103 | futures = []
104 |
105 | for task in tasks:
106 | future = executor.submit(self.process_single, *task)
107 | futures.append(future)
108 |
109 | with tqdm(total=len(tasks), desc="Processing") as pbar:
110 | for future in as_completed(futures):
111 | data_id, conv_idx, reworded = future.result()
112 | unique_id = f"{data_id}_{conv_idx}"
113 | self.results[unique_id] = reworded
114 | completed_count += 1
115 | pbar.update(1)
116 | if completed_count % self.batch_size == 0:
117 | self.generate_output_file(original_data, self.results, output_path)
118 |
119 | self.generate_output_file(original_data, self.results, output_path)
120 | print(f"Process completed. Failed: {len(self.failed_ids)}")
121 |
122 | def generate_output_file(self, original_data, results, output_path):
123 | output_lines = []
124 | for data in original_data:
125 | data_id = data["id"]
126 | for conv_idx, conv in enumerate(data["conversation"]):
127 | unique_id = f"{data_id}_{conv_idx}"
128 | conv["question"] = {
129 | "type": "llm_question",
130 | "content": results.get(unique_id, conv["user"])
131 | }
132 | output_lines.append(json.dumps(data, ensure_ascii=False))
133 |
134 | with open(output_path, "w", encoding="utf-8") as f:
135 | f.write("\n".join(output_lines) + "\n")
136 |
--------------------------------------------------------------------------------
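The rewriter only needs a plain config dict carrying the API fields it reads above; a minimal sketch in which every value is a placeholder:

```
from process.rewriter import Rewriter

config = {
    "model_type": "openai",                 # "openai", "zhipu", "qwen", or "llama"
    "model_name": "gpt-4o-mini",            # placeholder model name
    "api_base": "https://api.example.com/v1",
    "api_key": "sk-...",
}
rewriter = Rewriter(config, max_retries=3, max_parallel=4, batch_size=16)
rewriter.batch_process("data/dataset.json", "data/rewrite_question.jsonl")
```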
/src/retrieval/dense_retriever.py:
--------------------------------------------------------------------------------
1 | from sentence_transformers import SentenceTransformer
2 | from modelscope import snapshot_download
3 | from openai import OpenAI
4 | import httpx
5 | import faiss
6 | from tqdm import tqdm
7 | import numpy as np
8 |
9 | class DenseRetriever:
10 | def __init__(self, api_key=None, base_url=None):
11 | self.api_key = api_key
12 | self.base_url = base_url
13 | self.model = None
14 |
15 | def load_model(self, model_type):
16 | if model_type == "BGE-base-zh":
17 | model_dir = snapshot_download(
18 | "AI-ModelScope/bge-base-zh-v1.5", revision="master"
19 | )
20 | self.model = SentenceTransformer(model_dir, trust_remote_code=True)
21 | elif model_type == "Qwen2-1.5B": #GTE model
22 | model_dir = snapshot_download("iic/gte_Qwen2-1.5B-instruct")
23 | self.model = SentenceTransformer(model_dir, trust_remote_code=True)
24 | elif model_type == "openai":
25 | self.model = None
26 |
27 | def _BGE_embedding(self, texts: list):
28 | embeddings = self.model.encode(texts)
29 | return embeddings
30 |
31 | def _Qwen2_embedding(self, texts: list):
32 | embeddings = self.model.encode(texts)
33 | return embeddings
34 |
35 | def _openai_embedding(self, texts: list, model_name):
36 | client = OpenAI(
37 | base_url=self.base_url,
38 | api_key=self.api_key,
39 | http_client=httpx.Client(
40 | base_url=self.base_url,
41 | follow_redirects=True,
42 | ),
43 | )
44 |
45 | response = client.embeddings.create(
46 | input=texts,
47 | model=model_name,
48 | )
49 | embeddings = [data.embedding for data in response.data]
50 | return embeddings
51 |
52 | def embed(self, texts, model_type, model_name=None, batch_size=8):
53 | if model_type in ["BGE-base-zh", "Qwen2-1.5B"]:
54 | self.load_model(model_type)
55 | embeddings = []
56 | for i in tqdm(range(0, len(texts), batch_size)):
57 | batch = texts[i : i + batch_size]
58 | embeddings.extend(self.model.encode(batch))
59 | return np.array(embeddings)
60 | elif model_type == "openai":
61 | embeddings = []
62 | for i in tqdm(range(0, len(texts), batch_size)):
63 | batch = texts[i : i + batch_size]
64 |                 batch_embeddings = self._openai_embedding(batch, model_name)
65 | embeddings.extend(batch_embeddings)
66 | return np.array(embeddings)
67 |
68 | else:
69 | raise ValueError(f"Unsupported model type: {model_type}")
70 |
71 | @staticmethod
72 | def save_faiss(embeddings, faiss_type, save_path="index.faiss"):
73 | dim = embeddings.shape[1]
74 |
75 | if faiss_type == "FlatIP":
76 | index = faiss.IndexFlatIP(dim)
77 |         elif faiss_type == "HNSW":
78 |             index = faiss.IndexHNSWFlat(dim, 64, faiss.METRIC_INNER_PRODUCT)
79 |         elif faiss_type == "IVF":
80 |             nlist = min(128, int(np.sqrt(len(embeddings))))
81 |             quantizer = faiss.IndexFlatIP(dim)
82 |             index = faiss.IndexIVFFlat(quantizer, dim, nlist, faiss.METRIC_INNER_PRODUCT)
83 | index.train(embeddings.astype('float32'))
84 | index.nprobe = min(8, nlist//4)
85 | else:
86 | raise ValueError(f"Unsupported FAISS type: {faiss_type}")
87 |
88 | index.add(embeddings.astype('float32'))
89 | faiss.write_index(index, save_path)
90 |
--------------------------------------------------------------------------------
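A minimal sketch of indexing the statute corpus with the BGE embedder and persisting a flat inner-product index (the index path follows the pattern used in run_retrieval.py):

```
import json, os
from retrieval.dense_retriever import DenseRetriever

with open("data/law_library.jsonl", encoding="utf-8") as f:
    laws = [json.loads(line) for line in f]
corpus = [law["name"] + law["content"] for law in laws]

os.makedirs("data/retrieval", exist_ok=True)
retriever = DenseRetriever()
embeddings = retriever.embed(corpus, model_type="BGE-base-zh")
DenseRetriever.save_faiss(embeddings, "FlatIP", "data/retrieval/law_index_BGE-base-zh.faiss")
```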
/src/retrieval/lexical_matching.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import jieba
4 | import langid
5 | import bm25s
6 | import numpy as np
7 | from tqdm import tqdm
8 | from pyserini.search.lucene import LuceneSearcher
9 | from bm25s.tokenization import Tokenized
10 | import math
11 | import subprocess
12 |
13 | def judge_zh(text: str) -> bool:
14 | return langid.classify(text)[0] == 'zh'
15 |
16 | class LexicalRetriever:
17 | def __init__(self, bm25_backend=None):
18 | self.bm25_backend = bm25_backend
19 | self.searcher = None
20 |
21 | def _bm25s_tokenize(
22 | self,
23 | texts,
24 | return_ids: bool = True,
25 | show_progress: bool = True,
26 | leave: bool = False,
27 | ):
28 | if isinstance(texts, str):
29 | texts = [texts]
30 |
31 | corpus_ids = []
32 | token_to_index = {}
33 |
34 | for text in tqdm(
35 | texts, desc="Split strings", leave=leave, disable=not show_progress
36 | ):
37 |
38 | splitted = jieba.lcut(text)
39 | doc_ids = []
40 |
41 | for token in splitted:
42 | if token not in token_to_index:
43 | token_to_index[token] = len(token_to_index)
44 |
45 | token_id = token_to_index[token]
46 | doc_ids.append(token_id)
47 |
48 | corpus_ids.append(doc_ids)
49 |
50 | unique_tokens = list(token_to_index.keys())
51 | vocab_dict = token_to_index
52 |
53 | if return_ids:
54 | return Tokenized(ids=corpus_ids, vocab=vocab_dict)
55 |
56 | else:
57 | reverse_dict = unique_tokens
58 | for i, token_ids in enumerate(
59 | tqdm(
60 | corpus_ids,
61 | desc="Reconstructing token strings",
62 | leave=leave,
63 | disable=not show_progress,
64 | )
65 | ):
66 | corpus_ids[i] = [reverse_dict[token_id] for token_id in token_ids]
67 |
68 | return corpus_ids
69 |
70 | def _bm25s_search(self, corpus, query_list, k=10):
71 | bm25s.tokenize = self._bm25s_tokenize
72 |         corpus_token = bm25s.tokenize(corpus)
73 |         retriever = bm25s.BM25()
74 |         retriever.index(corpus_token)
75 |
76 | query_token_list = [bm25s.tokenize(query) for query in query_list]
77 | scores = []
78 | result_idx_list = []
79 | for query_token in query_token_list:
80 | result, score = retriever.retrieve(query_token, k=k)
81 | scores.append(score.tolist())
82 | result_idx_list.append(result.tolist())
83 | print(np.array(result_idx_list).shape, np.array(scores).shape)
84 | return result_idx_list, scores
85 |
86 | def _build_pyserini_index(self, corpus_path, folder_path, index_dir):
87 | temp_path = "data/law_library/temp.jsonl"
88 |
89 | args = [
90 | "-collection", "JsonCollection",
91 | "-input", folder_path,
92 | "-index", index_dir,
93 | "-generator", "DefaultLuceneDocumentGenerator",
94 | "-threads", "1",
95 | ]
96 |
97 | # Detecting Chinese
98 | with open(temp_path) as f:
99 | sample_text = json.loads(next(f))['contents']
100 | lang = 'zh' if judge_zh(sample_text) else 'en'
101 | if lang == 'zh':
102 | args += ["-language", "zh"]
103 |
104 | os.makedirs(index_dir, exist_ok=True)
105 | subprocess.run(["python", "-m", "pyserini.index.lucene"] + args)
106 | self.searcher = LuceneSearcher(index_dir)
107 | if lang == 'zh':
108 | self.searcher.set_language('zh')
109 |
110 | def _bm25_search(self, queries, k=10):
111 | results = []
112 | scores = []
113 | for query in tqdm(queries, desc="Pyserini BM25 Searching"):
114 | hits = self.searcher.search(query, k=k)
115 | results.append([hit.docid for hit in hits])
116 | scores.append([hit.score for hit in hits])
117 | return results, scores
118 |
119 | def _qld_search(self, queries, k=10, index_dir="data/pyserini_index"):
120 | results = []
121 | scores = []
122 | self.searcher.set_qld()
123 | for query in tqdm(queries, desc="QLD Searching"):
124 | hits = self.searcher.search(query, k=k)
125 | results.append([hit.docid for hit in hits])
126 | scores.append([hit.score for hit in hits])
127 | return results, scores
128 |
129 | def search(self, corpus, law_path, queries, k=10, method="bm25"):
130 | output_directory = "data/law_library"
131 | os.makedirs(output_directory, exist_ok=True)
132 | temp_file_path = os.path.join(output_directory, 'temp.jsonl')
133 | with open(law_path, 'r', encoding='utf-8') as file:
134 | lines = file.readlines()
135 | with open(temp_file_path, 'w', encoding='utf-8') as temp_file:
136 | for line in lines:
137 | data = json.loads(line)
138 | if 'content' in data:
139 | data['contents'] = data.pop('content')
140 | temp_file.write(json.dumps(data, ensure_ascii=False) + '\n')
141 |
142 | folder_path = output_directory
143 | if method == "bm25":
144 | if self.bm25_backend == "bm25s":
145 | return self._bm25s_search(corpus, queries, k)
146 | elif self.bm25_backend == "pyserini":
147 | self._build_pyserini_index(law_path, folder_path, "data/retrieval/pyserini_index")
148 | return self._bm25_search(queries, k)
149 | elif method == "qld":
150 | self._build_pyserini_index(law_path, folder_path, "data/retrieval/qld_index")
151 | return self._qld_search(queries, k)
152 | else:
153 | raise ValueError(f"Unsupported method: {method}")
--------------------------------------------------------------------------------
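A short sketch of sparse retrieval with the bm25s backend (the query is illustrative; pyserini is only exercised by the `pyserini` and `qld` paths, although it is imported at module level):

```
import json
from retrieval.lexical_matching import LexicalRetriever

with open("data/law_library.jsonl", encoding="utf-8") as f:
    laws = [json.loads(line) for line in f]
corpus = [law["name"] + law["content"] for law in laws]

retriever = LexicalRetriever(bm25_backend="bm25s")
result_idx_list, scores = retriever.search(corpus, "data/law_library.jsonl",
                                           ["离婚后子女的抚养权如何确定?"], k=5)
# result_idx_list[0][0] holds the corpus indices of the top-5 articles for the query.
```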
/src/retrieval/run_retrieval.py:
--------------------------------------------------------------------------------
1 | import faiss
2 | import json
3 | from tqdm import tqdm
4 | import numpy as np
5 | from pathlib import Path
6 | import os
7 | from retrieval.dense_retriever import DenseRetriever
8 | from retrieval.lexical_matching import LexicalRetriever
9 |
10 | class Pipeline:
11 | def __init__(self, config=None):
12 | self.openai_config = config or {}
13 | self.init_dir()
14 |
15 | def run_retriever(self, model_type, question_file_path, law_path,
16 | bm25_backend="bm25s", faiss_type="FlatIP", model_name=None):
17 | if model_type == "bm25":
18 | self.pipeline_bm25(question_file_path, law_path, bm25_backend)
19 | elif model_type == "qld":
20 | self.pipeline_qld(question_file_path, law_path)
21 | else:
22 | self.pipeline_law(law_path, model_type, faiss_type, model_name)
23 | self.pipeline_question(question_file_path, model_type, model_name)
24 | self.pipeline_search(question_file_path, law_path, model_type, faiss_type)
25 |
26 | def pipeline_bm25(self, question_path, law_path, backend):
27 | res_path = f"data/retrieval/res/retrieval_bm25_{backend}.jsonl"
28 |
29 | with open(question_path, "r", encoding="utf-8") as f:
30 | data = [json.loads(line) for line in f]
31 | with open(law_path, "r", encoding="utf-8") as f:
32 | laws = [json.loads(line) for line in f]
33 | corpus = [law["name"] + law["content"] for law in laws]
34 |
35 | retriever = LexicalRetriever(bm25_backend=backend)
36 | queries = [conv["question"]["content"] for d in data for conv in d["conversation"]]
37 |
38 | if backend == "bm25s":
39 | result_idx_list, scores = retriever.search(corpus, law_path, queries, k=10)
40 | idx = 0
41 | for d in data:
42 | for conv in d["conversation"]:
43 | tmp_laws = []
44 | for result_idx, score in zip(result_idx_list[idx][0], scores[idx][0]):
45 | tmp_laws.append({
46 | "article": laws[result_idx],
47 | "score": float(score)
48 | })
49 | conv["question"]["recall"] = tmp_laws
50 | idx += 1
51 |
52 | elif backend == "pyserini":
53 | results, scores = retriever.search(corpus, law_path, queries, k=10)
54 | idx = 0
55 | for d in data:
56 | for conv in d["conversation"]:
57 | tmp_laws = []
58 | for doc_id, score in zip(results[idx], scores[idx]):
59 | tmp_laws.append({
60 | "article": laws[int(doc_id)],
61 | "score": float(score)
62 | })
63 | conv["question"]["recall"] = tmp_laws
64 | idx += 1
65 |
66 | with open(res_path, "w", encoding="utf-8") as f:
67 | for d in data:
68 | f.write(json.dumps(d, ensure_ascii=False) + "\n")
69 |
70 | def pipeline_qld(self, question_path, law_path):
71 | res_path = "data/retrieval/res/retrieval_qld.jsonl"
72 |
73 | with open(question_path, "r", encoding="utf-8") as f:
74 | data = [json.loads(line) for line in f]
75 | with open(law_path, "r", encoding="utf-8") as f:
76 | laws = [json.loads(line) for line in f]
77 | corpus = [law["name"] + law["content"] for law in laws]
78 |
79 | retriever = LexicalRetriever()
80 | queries = [conv["question"]["content"] for d in data for conv in d["conversation"]]
81 |
82 | results, scores = retriever.search(corpus, law_path, queries, k=10, method="qld")
83 | idx = 0
84 | for d in data:
85 | for conv in d["conversation"]:
86 | tmp_laws = []
87 | for doc_id, score in zip(results[idx], scores[idx]):
88 | tmp_laws.append({
89 | "article": laws[int(doc_id)],
90 | "score": float(score)
91 | })
92 | conv["question"]["recall"] = tmp_laws
93 | idx += 1
94 |
95 | with open(res_path, "w", encoding="utf-8") as f:
96 | for d in data:
97 | f.write(json.dumps(d, ensure_ascii=False) + "\n")
98 |
99 | def pipeline_law(self, law_path, model_type, faiss_type, model_name):
100 | law_index_path = f"data/retrieval/law_index_{model_type}.faiss"
101 | if os.path.exists(law_index_path):
102 | return
103 |
104 | with open(law_path) as f:
105 | laws = [json.loads(line)["name"]+json.loads(line)["content"] for line in f]
106 |
107 | emb_model = DenseRetriever(**self.openai_config)
108 | embeddings = emb_model.embed(laws, model_type, model_name)
109 | emb_model.save_faiss(embeddings, faiss_type, law_index_path)
110 |
111 | def pipeline_question(self, question_path, model_type, model_name):
112 | question_emb_path = f"data/retrieval/npy/retrieval_{model_type}.npy"
113 | if os.path.exists(question_emb_path):
114 | return
115 |
116 | with open(question_path) as f:
117 | data = [json.loads(line) for line in f]
118 | questions = [q["question"]["content"] for d in data for q in d["conversation"]]
119 |
120 | emb_model = DenseRetriever(**self.openai_config)
121 | embeddings = emb_model.embed(questions, model_type, model_name)
122 | np.save(question_emb_path, embeddings)
123 |
124 | def pipeline_search(self, question_path, law_path, model_type, faiss_type):
125 | res_path = f"data/retrieval/res/retrieval_{model_type}.jsonl"
126 | law_index_path = f"data/retrieval/law_index_{model_type}.faiss"
127 | question_emb_path = f"data/retrieval/npy/retrieval_{model_type}.npy"
128 |
129 | index = faiss.read_index(law_index_path)
130 | question_embeds = np.load(question_emb_path)
131 | D, I = index.search(question_embeds.astype('float32'), 10)
132 |
133 | with open(law_path) as f:
134 | laws = [json.loads(line) for line in f]
135 |
136 | with open(question_path) as f:
137 | data = [json.loads(line) for line in f]
138 |
139 | self.incorporate_dense_results(data, laws, D, I, res_path)
140 |
141 | def incorporate_dense_results(self, data, laws, D, I, res_path):
142 | idx = 0
143 | for d in data:
144 | for conv in d["conversation"]:
145 | tmp_laws = []
146 | for i in range(len(I[idx])):
147 | tmp_laws.append({
148 | "article": laws[I[idx][i]],
149 | "score": float(D[idx][i])
150 | })
151 | conv["question"]["recall"] = tmp_laws
152 | idx += 1
153 |
154 | with open(res_path, "w") as f:
155 | for d in data:
156 | f.write(json.dumps(d, ensure_ascii=False) + "\n")
157 |
158 | def init_dir(self):
159 | os.makedirs("data/retrieval/res", exist_ok=True)
160 | os.makedirs("data/retrieval/npy", exist_ok=True)
161 |
--------------------------------------------------------------------------------
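The same pipeline drives dense retrieval; with the GTE Qwen2-1.5B embedder the results are written to data/retrieval/res/retrieval_Qwen2-1.5B.jsonl (cf. the sample file shipped under data/samples):

```
from retrieval.run_retrieval import Pipeline

Pipeline().run_retriever(
    model_type="Qwen2-1.5B",
    question_file_path="data/current_question.jsonl",
    law_path="data/law_library.jsonl",
    faiss_type="FlatIP",
)
```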
/src/utils/utils.py:
--------------------------------------------------------------------------------
1 | from config.config import Config
2 | from generate.generator import OpenAIGenerator, ZhipuGenerator, VLLMGenerator
3 | from process.rewriter import Rewriter
4 | from process.processor import QuestionGenerator
5 | from eval.evaluator import GenerationEvaluator, LLMJudge, RetrievalEvaluator
6 | 
7 | 
8 | def create_generator(model_type, config, **kwargs):
9 | from generate.prompt_builder import LegalPromptBuilder
10 | from generate.generator import (
11 | OpenAIGenerator,
12 | ZhipuGenerator,
13 | VLLMGenerator,
14 | HuggingFaceGenerator,
15 | LocalGenerator
16 | )
17 |
18 | prompt_builder = kwargs.pop("prompt_builder", LegalPromptBuilder())
19 | common_params = {
20 | "config": config,
21 | "prompt_builder": prompt_builder,
22 | "max_retries": kwargs.get("max_retries"),
23 | "max_parallel": kwargs.get("max_parallel"),
24 | "top_n": kwargs.get("top_n"),
25 | "batch_size": kwargs.get("batch_size")
26 | }
27 |
28 | if model_type in ["openai", "qwen", "llama"]:
29 | return OpenAIGenerator(**common_params)
30 | elif model_type == "zhipu":
31 | return ZhipuGenerator(**common_params)
32 | elif model_type == "vllm":
33 | return VLLMGenerator(**common_params)
34 | elif model_type == "huggingface":
35 | return HuggingFaceGenerator(**common_params)
36 | elif model_type == "local":
37 | return LocalGenerator(**common_params)
38 | elif "model_path" in config: # Customised model paths
39 | if "huggingface.co" in config["model_path"]:
40 | return HuggingFaceGenerator(**common_params)
41 | else:
42 | return LocalGenerator(**common_params)
43 | else:
44 | raise ValueError(f"Unsupported model type: {model_type}")
45 |
46 | def create_processor(process_type: str, **kwargs):
47 | if process_type == "rewrite_question":
48 | required_params = ['config', 'max_retries', 'max_parallel', 'batch_size']
49 | for param in required_params:
50 | if param not in kwargs:
51 | raise ValueError(f"Missing required parameter: {param}")
52 | return Rewriter(
53 | config=kwargs['config'],
54 | max_retries=kwargs['max_retries'],
55 | max_parallel=kwargs['max_parallel'],
56 | batch_size=kwargs['batch_size']
57 | )
58 | elif process_type in ["current_question", "prefix_question", "prefix_question_answer", "suffix_question"]:
59 | return QuestionGenerator(process_type, **kwargs)
60 | else:
61 | raise ValueError(f"Unsupported processor type: {process_type}")
62 |
63 | def create_evaluator(eval_type: str, config):
64 | if eval_type == "generation":
65 | return GenerationEvaluator(config)
66 | elif eval_type == "llm_judge":
67 | return LLMJudge(config)
68 | elif eval_type == "retrieval":
69 | return RetrievalEvaluator(config)
70 | raise ValueError(f"Unsupported evaluator type: {eval_type}")
71 |
72 | def create_retriever(config=None):
73 | from retrieval.run_retrieval import Pipeline as RetrievalPipeline
74 | return RetrievalPipeline(config=config)
--------------------------------------------------------------------------------
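The factory functions can also be used on their own when the pipeline wrappers are not needed. A sketch with placeholder config values (it assumes BaseGenerator accepts these keyword arguments, as GeneratorPipeline passes them the same way):

```
from utils.utils import create_processor, create_generator
from generate.prompt_builder import LegalPromptBuilder

processor = create_processor("suffix_question")
generator = create_generator(
    model_type="zhipu",
    config={"api_key": "your-key", "model_name": "glm-4-flash"},
    prompt_builder=LegalPromptBuilder(),
    max_retries=3, max_parallel=4, top_n=5, batch_size=16,
)
```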