├── LICENSE
├── README.md
├── assets
│   ├── main_results.jpg
│   ├── pg_gen.jpg
│   └── pgr.jpg
├── data
│   └── CRUD_RAG
│       ├── QAGeneration_gpt-4-1106-preview_1doc.json
│       ├── QAGeneration_gpt-4-1106-preview_2docs.json
│       └── QAGeneration_gpt-4-1106-preview_3docs.json
├── pgrag
│   ├── configs
│   │   ├── __pycache__
│   │   │   └── real_config.cpython-39.pyc
│   │   └── real_config.py
│   ├── data
│   │   └── eval
│   │       └── eval_data_with_qe_and_qdse_example.json
│   ├── llms
│   │   ├── __pycache__
│   │   │   ├── base.cpython-39.pyc
│   │   │   ├── local.cpython-39.pyc
│   │   │   ├── qwen13b.cpython-39.pyc
│   │   │   ├── qwen14b.cpython-39.pyc
│   │   │   ├── remote.cpython-311.pyc
│   │   │   ├── remote.cpython-38.pyc
│   │   │   └── remote.cpython-39.pyc
│   │   ├── base.py
│   │   └── remote.py
│   ├── mindmap_generator.py
│   ├── prompts
│   │   ├── extract_fact_verification_items.txt
│   │   ├── gen_mindmap.txt
│   │   ├── query_deconstruction.txt
│   │   └── topic_extract.txt
│   ├── pseudo_graph_constructor.py
│   ├── seed_context_recall.py
│   └── sub_pseudo_graph_retriever.py
└── requirements.txt
/LICENSE:
--------------------------------------------------------------------------------
1 | Waiting for updates
2 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # PG-RAG: Empowering Large Language Models to Set up a Knowledge Retrieval Indexer via Self-Learning
3 |
4 |
5 | 
6 |
7 | ## Introduction
8 |
9 | PG-RAG proposes a pre-retrieval augmented generation method that introduces a _refinement_ step before the _indexing-retrieval-generation_ process, ensuring the accuracy of retrieved content from the outset. We leverage the self-learning capabilities of LLMs to transform documents into easily understandable and retrievable hierarchical indexes. This process naturally filters out noise and enhances information readability. By establishing connections between similar or complementary pieces of knowledge, we enable the retriever to function across multiple documents. During the knowledge retrieval phase, we use _pseudo-answers_ to assist the retriever in locating relevant information and to perform a structured walk over the matrices, achieving accurate and rapid knowledge localization. Finally, we assemble the retrieved fact paths into a structured context, providing rich background information for LLMs to generate knowledge-grounded responses.
10 |
11 | 
12 |
13 | Models supported by the PG-RAG framework
14 |
15 | | Model Type | Loading Method | Example Models | References |
16 | |------------|------------------------------|------------------------------------|------------------------------------------------------------------------------------------------------------------------|
17 | | `api` | `requests` | OpenAI models | [OpenAI API](https://platform.openai.com/docs/introduction)|
18 | | `local` | HuggingFace `Sentence transformers` | `BAAI/bge-large-zh-v1.5` | [HuggingFace `Sentence transformers`](https://sbert.net/) |
19 | | `remote` | `requests` | Internal use only | Not available to the public |
20 |
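For the `local` loading method, the embedding model is loaded with the Sentence Transformers library. A minimal sketch (the model id and the sample sentence below are illustrative; in the pipeline the model path comes from `emb_model_name` in the configuration):

```python
from sentence_transformers import SentenceTransformer

# Load the embedding model; a local path such as the one set in emb_model_name also works.
emb_model = SentenceTransformer("BAAI/bge-large-zh-v1.5")

# normalize_embeddings=True matches how the constructor and retriever encode text,
# so plain dot products behave like cosine similarities.
emb = emb_model.encode("国家医疗器械监督抽检结果公布", normalize_embeddings=True)
print(emb.shape)  # (1024,) for bge-large-zh-v1.5; bge-base-zh-v1.5 yields 768 dimensions
```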
21 |
22 |
23 | A specific example of a main topic and its fact-checking items.
24 |
25 | ```json
26 | {
27 | "Main Topic": "Announcement of the results of national medical device supervision sampling",
28 | "Fact-Checking Items": {
29 | "Date and Time": "2023-07-28",
30 | "Issuing Organization": "National Medical Products Administration",
31 | "Types of Products Sampled": "Dental low-voltage electric motors, various medical patches (far infrared therapy patches, magnetic therapy patches, acupoint magnetic therapy patches), among other five types",
32 | "Sampling Results": "A total of 12 batches (units) of products did not meet the standard requirements",
33 | "Specific Non-compliant Products and Issues": {
34 | "Dental Low-Voltage Electric Motor (1 unit)": "Produced by Guangdong Jingmei Medical Technology Co., Ltd., with issues related to leakage current and patient auxiliary current (under working temperature), and no-load speed not meeting the standard requirements.",
35 | "Vertical Pressure Steam Sterilizer (1 unit)": "Produced by Hefei Huatai Medical Equipment Co., Ltd., involving 'permissible limit values of accessible parts under normal conditions' and limit values under single fault condition (ground fault) not meeting the standard requirements.",
36 | "Electric Suction Device (1 unit)": "Produced by Suzhou Bein Technology Co., Ltd., involving 'network-powered, movable high negative pressure/high flow equipment' not meeting the standard requirements.",
37 | "Patch-type Medical Devices (far infrared therapy patch, magnetic therapy patch, acupoint magnetic therapy patch) 6 batches": "Produced by Jiujiang Gaoke Pharmaceutical Technology Co., Ltd., Zhengzhou Zhongyuan Fuli Industrial & Trade Co., Ltd., Ulanqab Qiao's Weiye Medical Device Co., Ltd., Hunan Dexi Medical Technology Co., Ltd., and Chongqing Zhengren Medical Device Co., Ltd., with issues involving detection of 'pharmaceutical ingredients that should not be detected according to supplementary testing methods.'",
38 | "Human Blood and Blood Component Plastic Bag Containers (blood bags) 3 batches": "Produced by Nanjing Sailjin Biomedical Co., Ltd., with issues involving non-compliance of the blood bag transfusion ports with the standards."
39 | }
40 | }
41 | }
42 | ```
43 |
44 |
45 |
46 | A specific example of the mind map generated from the main topic and fact-checking items.
47 |
48 | ```json
49 | {
50 | "Announcement of the results of national medical device supervision sampling": {
51 | "Announcement Date and Time": "July 28, 2023",
52 | "Sampled Items": [
53 | "Dental Low-Voltage Electric Motors",
54 | "Patch-Type Medical Devices (including Far Infrared Therapy Patches, Magnetic Therapy Patches, Acupoint Magnetic Therapy Patches)"
55 | ],
56 | "Total Number of Sampled Products": 12,
57 | "Number of Product Types Not Meeting Standard Requirements": 5,
58 | "Non-Compliant Medical Devices and Manufacturer Information": {
59 | "Dental Low-Voltage Electric Motor": {
60 | "Quantity": 1,
61 | "Manufacturer": "Guangdong Jingmei Medical Technology Co., Ltd.",
62 | "Issue Description": "Involving leakage current and patient auxiliary current, no-load speed not meeting the standard requirements"
63 | },
64 | "Vertical Pressure Steam Sterilizer": {
65 | "Quantity": 1,
66 | "Manufacturer": "Hefei Huatai Medical Equipment Co., Ltd.",
67 | "Issue Description": "Involving the allowable limit values of accessible parts under normal conditions and limit values under single fault condition (ground fault) not meeting the standard requirements"
68 | },
69 | "Electric Suction Device": {
70 | "Quantity": 1,
71 | "Manufacturer": "Suzhou Bein Technology Co., Ltd.",
72 | "Issue Description": "Involving network-powered, movable high negative pressure/high flow equipment not meeting the standard requirements"
73 | },
74 | "Patch-Type Medical Devices": {
75 | "Sub-Types": [
76 | "Far Infrared Therapy Patch",
77 | "Magnetic Therapy Patch",
78 | "Acupoint Magnetic Therapy Patch"
79 | ],
80 | "Batch Quantity": 6,
81 | "List of Manufacturers": [
82 | "Jiujiang Gaoke Pharmaceutical Technology Co., Ltd.",
83 | "Zhengzhou Zhongyuan Fuli Industrial & Trade Co., Ltd.",
84 | "Ulangab Qiao's Weiye Medical Device Co., Ltd.",
85 | "Hunan Dexi Medical Technology Co., Ltd.",
86 | "Chongqing Zhengren Medical Device Co., Ltd."
87 | ],
88 | "Issue Description": "Involving detection of pharmaceutical ingredients that should not be detected according to supplementary testing methods"
89 | },
90 | "Human Blood and Blood Component Plastic Bag Containers (Blood Bags)": {
91 | "Quantity": 3,
92 | "Manufacturer": "Nanjing Sailjin Biomedical Co., Ltd.",
93 | "Issue Description": "Involving blood bag transfusion ports not meeting the standard requirements"
94 | }
95 | }
96 | }
97 | }
98 |
99 | ```
100 |
101 |
102 |
103 | A specific example of a fact path in the pseudo-graph (PG).
104 |
105 | "Announcement of the results of national medical device supervision sampling" > "Non-Compliant Medical Devices and Manufacturer Information" > "Dental Low-Voltage Electric Motor" > "Manufacturer" > "Guangdong Jingmei Medical Technology Co., Ltd."
106 |
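For illustration, such a fact path can be folded back into the nested JSON context handed to the generator; this mirrors what `convert_paths_to_json` in `pgrag/sub_pseudo_graph_retriever.py` does with the paths returned by Neo4j. A minimal sketch:

```python
# A fact path is an ordered list of node names from the Topic down to the Content (fact) node.
fact_path = [
    "Announcement of the results of national medical device supervision sampling",
    "Non-Compliant Medical Devices and Manufacturer Information",
    "Dental Low-Voltage Electric Motor",
    "Manufacturer",
    "Guangdong Jingmei Medical Technology Co., Ltd.",
]

context = {}
level = context
for node in fact_path:
    level = level.setdefault(node, {})  # each path segment becomes a nested key

print(context)
```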
107 |
108 |
109 | Project structure
110 |
111 | ```bash
112 | .
113 | ├── .github
114 | ├── .gitignore
115 | ├── LICENSE
116 | ├── README.md
117 | ├── assets # Static files like images used in documentation
118 | ├── data # Datasets (e.g., 1-Document QA)
119 | ├── output # Stores the retrieved contexts for queries
120 | ├── requirements.txt
121 | └── pgrag # Source code for the project
122 | ├── configs # Scripts for initializing model loading parameters
123 | ├── data # Intermediate result data
124 |         ├── eval ── eval_data_with_qe_and_qdse_example.json # Evaluation dataset sample
125 |         ├── raw_news # Original documents used to build the retrieval library
126 | ├── pg_gen # Data during the pseudo-graph construction process
127 | └── context_recall # Data during the pseudo-graph retrieval process
128 | ├── mindmap_generator.py # Scripts for generating mind maps
129 | ├── pseudo_graph_constructor.py # Scripts for pseudo-graph construction
130 | ├── seed_context_recall.py # Scripts for seed contexts recall
131 | ├── sub_pseudo_graph_retriever.py # Scripts for structured contexts recall
132 |     ├── llms # Methods for calling LLMs
133 |     └── prompts # Prompt engineering templates
134 | ```
135 |
136 |
137 |
138 | ## Installation
139 |
140 | Before using PG-RAG:
141 |
142 | 1. Ensure you have Python 3.9.0+
143 | 2. Install the required packages:
144 |
145 | ```bash
146 | pip install -r requirements.txt
147 | ```
148 |
149 | 3. Prerequisites
150 |
151 | - **JDK 17**: Ensure you have JDK 17 installed on your machine.
152 | - **Neo4j**: Download and install Neo4j 5.17.0. Start the Neo4j console using the following command:
153 |
154 | ```bash
155 | neo4j console
156 | ```
157 |
158 | ## PG-RAG Pipeline
159 |
160 | This repository contains the implementation of the PG-RAG (Pseudo-Graph Retrieval Augmented Generation) pipeline. The pipeline consists of four main stages: mind map generation, pseudo-graph construction, seed context recall, and final context extension and generation.
161 |
162 | ### Common Setup
163 |
164 | Before running the individual scripts, ensure the following configuration:
165 |
166 | ```python
167 | graph_uri = "bolt://localhost:7687"
168 | graph_auth = ("neo4j", "password")
169 | emb_model_name = "/path/to/your/local/bge-base-zh-1.5"
170 | num_threads = 20
171 | topK = 8
172 |
173 | # Parameters for Pseudo-Graph Construction
174 | model_name = 'gpt-3.5-turbo' # Model name can be: gpt-3.5-turbo, gpt-4-0613, gpt-4-1106-preview
175 | raw_news_files_dir = 'data/raw_news/batch0'
176 | title_files_dir = 'data/pg_gen/batch0/title'
177 | fcis_files_dir = "data/pg_gen/batch0/textToVerificationText/"
178 | mindmaps_str_files_dir = "data/pg_gen/batch0/mindmap_str/"
179 | mindmaps_json_files_dir = "data/pg_gen/batch0/mindmap_json/"
180 |
181 | # Parameters for Knowledge Recall via Pseudo-Graph Retrieval
182 | eval_data_with_qe_and_qdse_file = 'data/eval/eval_data_with_qe_and_qdse.json'
183 | seed_topic_file = 'data/context_recall/pgrag/seed_topics.json'
184 | candidate_topic_file = 'data/context_recall/pgrag/candidate_topics.json'
185 | matrix_templates_file = 'data/context_recall/pgrag/matrix_templates.json'
186 | matrix_templates_with_sim_file = 'data/context_recall/pgrag/matrix_templates_with_sim.json'
187 | contexts_ids_file = 'data/context_recall/pgrag/contexts_ids.json'
188 | final_contexts_file = 'data/context_recall/pgrag/final_contexts.json'
189 |
190 | recall_top_m = 3
191 | walk_top_m = 6
192 | ```
193 |
194 | Ensure the Neo4j database is running and accessible at the specified `graph_uri`. Adjust the file paths and model names to match your environment, and modify the parameters as needed to fine-tune performance and results.
195 |
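Before launching the pipeline, it can be useful to sanity-check the connection settings. A minimal, optional check using `py2neo` (already a project dependency); the URI and credentials below mirror the configuration above and should be replaced with your own:

```python
from py2neo import Graph

graph_uri = "bolt://localhost:7687"
graph_auth = ("neo4j", "password")  # replace with your own credentials

# A trivial round-trip query confirms that the Bolt endpoint and credentials work
# before the constructor and retriever classes try to use them.
graph = Graph(graph_uri, auth=graph_auth)
print(graph.run("RETURN 1 AS ok").evaluate())  # expected output: 1
```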
196 | ## Steps to Execute
197 |
198 | 1. **Generate Mind Maps**: Generate mind maps for original texts using `pgrag/mindmap_generator.py`.
199 | 2. **Build Pseudo-Graph**: Construct a pseudo-graph through `pgrag/pseudo_graph_constructor.py`.
200 | 3. **Seed Context Recall**: Recall seed topics using `pgrag/seed_context_recall.py`.
201 | 4. **Final Context Extensions**: Perform final context extensions and context generation using `pgrag/sub_pseudo_graph_retriever.py`.
202 |
203 | ### 1. Generate Mind Maps
204 |
205 | To generate mind maps from the original texts:
206 |
207 | ```python
208 | # pgrag/mindmap_generator.py
209 |
210 | mindmap_generation = MindmapGeneration(model_name, num_threads, raw_news_files_dir, title_files_dir, fcis_files_dir, mindmaps_str_files_dir, mindmaps_json_files_dir)
211 | mindmap_generation.execute()
212 | ```
213 |
214 | ### 2. Build Pseudo-Graph
215 |
216 | To build the pseudo-graph:
217 |
218 | ```python
219 | # pgrag/pseudo_graph_constructor.py
220 |
221 | inserter = Neo4jDataInserter(graph_uri, graph_auth, emb_model_name, num_threads)
222 | inserter.execute(raw_news_files_dir, mindmaps_json_files_dir)
223 | ```
224 |
225 | To ensure efficient querying and retrieval of topic embeddings and fact embeddings, you need to manually create the following vector indexes in Neo4j:
226 |
227 | ```cypher
228 | CREATE VECTOR INDEX `topic-embeddings` IF NOT EXISTS
229 | FOR (m:Topic)
230 | ON m.主题嵌入
231 | OPTIONS {indexConfig: {
232 |   `vector.dimensions`: 1024,
233 |   `vector.similarity_function`: 'cosine'
234 | }};
235 |
236 | CREATE VECTOR INDEX `fact-embeddings` IF NOT EXISTS
237 | FOR (m:Content)
238 | ON m.路径嵌入
239 | OPTIONS {indexConfig: {
240 |   `vector.dimensions`: 1024,
241 |   `vector.similarity_function`: 'cosine'
242 | }};
243 | ```
244 |
245 | These commands will create the necessary indexes for the `Topic` and `Content` nodes, respectively, with a vector dimension of 1024 and using the cosine similarity function for the embeddings.
246 |
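If you prefer to create the indexes from Python instead of the Neo4j Browser, a minimal sketch using `py2neo` might look like the following; the index names and embedding properties mirror the Cypher above and assume Neo4j 5.x:

```python
from py2neo import Graph

graph = Graph("bolt://localhost:7687", auth=("neo4j", "password"))

# (index name, node label, embedding property), mirroring the Cypher statements above
index_specs = [
    ("topic-embeddings", "Topic", "主题嵌入"),
    ("fact-embeddings", "Content", "路径嵌入"),
]

for index_name, label, prop in index_specs:
    graph.run(
        f"CREATE VECTOR INDEX `{index_name}` IF NOT EXISTS "
        f"FOR (m:{label}) ON m.{prop} "
        "OPTIONS {indexConfig: {"
        "`vector.dimensions`: 1024, "
        "`vector.similarity_function`: 'cosine'}}"
    )
```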
247 | Please run the statements above in your Neo4j database before executing the main application to ensure all functionalities work as expected. Then, fuse similar topics and contents into super-nodes:
248 |
249 | ```python
250 | fusion = TopicAndContentFusion(graph_uri, graph_auth, emb_model_name)
251 | fusion.fuse_topics_and_contents()
252 | ```
253 |
254 | ### 3. Seed Context Recall
255 |
256 | To recall seed contexts:
257 |
258 | ```python
259 | # pgrag/seed_context_recall.py
260 |
261 | seed_context_recall = SeedContextRecall(graph_uri, graph_auth, emb_model_name, eval_data_with_qe_and_qdse_file, seed_topic_file, candidate_topic_file, recall_top_m, walk_top_m, num_threads, topK)
262 | seed_context_recall.execute()
263 | ```
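`SeedContextRecall.execute()` reads `eval_data_with_qe_and_qdse_file` and, for each entry, accesses the question text, the question embedding (`qe_bge-base-zh`), and the embeddings of the deconstructed key points (`qdse_bge-base-zh`). Based on that access pattern (and the example file under `pgrag/data/eval/`), an entry is expected to look roughly like the following; the vectors are truncated here and are 768-dimensional for `bge-base-zh-v1.5`:

```json
[
  {
    "question": "<question text>",
    "<question text>": {
      "qe_bge-base-zh": [0.0123, -0.0456],
      "qdse_bge-base-zh": [
        [0.0211, 0.0034],
        [-0.0178, 0.0092]
      ]
    }
  }
]
```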
264 |
265 | ### 4. Final Context Extensions
266 |
267 | To perform final context extensions and context generation:
268 |
269 | ```python
270 | # pgrag/sub_pseudo_graph_retriever.py
271 |
272 | processor = PG_RAG_Processor(graph_uri, graph_auth, candidate_topic_file, matrix_templates_file, matrix_templates_with_sim_file, topK)
273 | processor.create_matrix_templates()
274 | processor.compute_similarity_matrices()
275 | processor.process_top_k_ids(contexts_ids_file, final_contexts_file)
276 | ```
277 |
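The recalled contexts in `final_contexts.json` can then be passed to an LLM to produce the final, knowledge-grounded answers. A generation script is not included in this repository; a minimal sketch using the bundled `GPT_transit` wrapper (run from inside `pgrag/`, assuming `GPT_transit_token` and `GPT_transit_url` are set in `configs/real_config.py`; the prompt wording is illustrative, not the exact prompt used in the paper):

```python
import ast
from llms.remote import GPT_transit

llm = GPT_transit(model_name='gpt-3.5-turbo')

# Each line of final_contexts.json is a str(dict) written by sub_pseudo_graph_retriever.py.
with open('data/context_recall/pgrag/final_contexts.json', 'r', encoding='utf-8') as f:
    for line in f:
        record = ast.literal_eval(line)
        question = record['question']
        context = record['top_k_contexts']  # nested fact-path JSON

        prompt = (
            "Answer the question using only the background knowledge below.\n"
            f"Background knowledge: {context}\n"
            f"Question: {question}\n"
        )
        print(question, '->', llm.safe_request(prompt))
```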
278 | ## Experimental Results
279 |
280 | ![Main results](assets/main_results.jpg)
281 |
--------------------------------------------------------------------------------
/assets/main_results.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IAAR-Shanghai/PGRAG/c2bded31b00c02baaba87fda50e93d5fc0e86409/assets/main_results.jpg
--------------------------------------------------------------------------------
/assets/pg_gen.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IAAR-Shanghai/PGRAG/c2bded31b00c02baaba87fda50e93d5fc0e86409/assets/pg_gen.jpg
--------------------------------------------------------------------------------
/assets/pgr.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IAAR-Shanghai/PGRAG/c2bded31b00c02baaba87fda50e93d5fc0e86409/assets/pgr.jpg
--------------------------------------------------------------------------------
/pgrag/configs/__pycache__/real_config.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IAAR-Shanghai/PGRAG/c2bded31b00c02baaba87fda50e93d5fc0e86409/pgrag/configs/__pycache__/real_config.cpython-39.pyc
--------------------------------------------------------------------------------
/pgrag/configs/real_config.py:
--------------------------------------------------------------------------------
1 | GPT_transit_token = '' # openai.api_key
2 | GPT_transit_url = '' # openai.base_url
3 |
--------------------------------------------------------------------------------
/pgrag/llms/__pycache__/base.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IAAR-Shanghai/PGRAG/c2bded31b00c02baaba87fda50e93d5fc0e86409/pgrag/llms/__pycache__/base.cpython-39.pyc
--------------------------------------------------------------------------------
/pgrag/llms/__pycache__/local.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IAAR-Shanghai/PGRAG/c2bded31b00c02baaba87fda50e93d5fc0e86409/pgrag/llms/__pycache__/local.cpython-39.pyc
--------------------------------------------------------------------------------
/pgrag/llms/__pycache__/qwen13b.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IAAR-Shanghai/PGRAG/c2bded31b00c02baaba87fda50e93d5fc0e86409/pgrag/llms/__pycache__/qwen13b.cpython-39.pyc
--------------------------------------------------------------------------------
/pgrag/llms/__pycache__/qwen14b.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IAAR-Shanghai/PGRAG/c2bded31b00c02baaba87fda50e93d5fc0e86409/pgrag/llms/__pycache__/qwen14b.cpython-39.pyc
--------------------------------------------------------------------------------
/pgrag/llms/__pycache__/remote.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IAAR-Shanghai/PGRAG/c2bded31b00c02baaba87fda50e93d5fc0e86409/pgrag/llms/__pycache__/remote.cpython-311.pyc
--------------------------------------------------------------------------------
/pgrag/llms/__pycache__/remote.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IAAR-Shanghai/PGRAG/c2bded31b00c02baaba87fda50e93d5fc0e86409/pgrag/llms/__pycache__/remote.cpython-38.pyc
--------------------------------------------------------------------------------
/pgrag/llms/__pycache__/remote.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IAAR-Shanghai/PGRAG/c2bded31b00c02baaba87fda50e93d5fc0e86409/pgrag/llms/__pycache__/remote.cpython-39.pyc
--------------------------------------------------------------------------------
/pgrag/llms/base.py:
--------------------------------------------------------------------------------
1 | import jieba
2 | from text2vec import Similarity
3 | import evaluate
4 | import shutil
5 | import os
6 | import copy
7 | import json
8 | from abc import ABC, abstractmethod
9 |
10 | from loguru import logger
11 |
12 |
13 | class BaseLLM(ABC):
14 | def __init__(
15 | self,
16 | model_name: str = None,
17 | temperature: float = 1.0,
18 | max_new_tokens: int = 4096,
19 | top_p: float = 0.9,
20 | top_k: int = 5,
21 | **more_params
22 | ):
23 | self.params = {
24 | 'model_name': model_name if model_name else self.__class__.__name__,
25 | 'temperature': temperature,
26 | 'max_new_tokens': max_new_tokens,
27 | 'top_p': top_p,
28 | 'top_k': top_k,
29 | **more_params
30 | }
31 | self.post_init()
32 |
33 | def post_init(self):
34 | """Post initialization method for subclasses.
35 | Normally, this method should initialize the model and tokenizer.
36 | """
37 | ...
38 |
39 | def update_params(self, inplace: bool = True, **params):
40 | if inplace:
41 | self.params.update(params)
42 | return self
43 | else:
44 | new_obj = copy.deepcopy(self)
45 | new_obj.params.update(params)
46 | return new_obj
47 |
48 | @abstractmethod
49 | def request(self, query:str) -> str:
50 | return ''
51 |
52 | def safe_request(self, query: str) -> str:
53 | """Safely make a request to the language model, handling exceptions."""
54 | try:
55 | response = self.request(query)
56 | except Exception as e:
57 | logger.warning(repr(e))
58 | response = ''
59 | return response
60 |
61 | def safe_request_with_prompt(self, text, prompt_file) -> str:
62 | template = self._read_prompt_template(prompt_file)
63 | query = template.format(text=text)
64 | query = query.replace('{{', '{').replace('}}', '}')
65 | respond = self.safe_request(query)
66 | return respond
67 |
68 | def extract_fact_verification_items(self, news_body) -> str:
69 | prompt_file = 'extract_fact_verification_items.txt'
70 | respond = self.safe_request_with_prompt(news_body, prompt_file)
71 | return respond
72 |
73 | def extract_title(self, text) -> str:
74 | prompt_file = 'topic_extract.txt'
75 | respond = self.safe_request_with_prompt(text, prompt_file)
76 | return respond
77 |
78 | def gen_mindmap(self, title, text) -> str:
79 | prompt_file = 'gen_mindmap.txt'
80 | template = self._read_prompt_template(prompt_file)
81 | query = template.format(title=title, text=text)
82 | query = query.replace('{{', '{').replace('}}', '}')
83 | respond = self.safe_request(query)
84 | return respond
85 |
86 | def process_input_output_pair(self, line_no, output_data, output_dir):
87 | os.makedirs(output_dir, exist_ok=True)
88 | filename_output = os.path.join(output_dir, f"{line_no}.txt")
89 |
90 | with open(filename_output, 'w', encoding='UTF-8') as output_file:
91 | output_file.write(output_data)
92 |
93 |
94 | def process_efvi(self, file_path, gpt_instance, output_dir):
95 | line_no = os.path.basename(file_path).replace('.txt', '')
96 | with open(file_path, 'r', encoding='UTF-8') as file:
97 | news_body = file.read()
98 | output_data = gpt_instance.extract_fact_verification_items(news_body)
99 | rougeL_score = self.rougeL_score(output_data, news_body)
100 |         print(f"ROUGE-L: {rougeL_score}")
101 |         # Compute a semantic similarity (BERTScore-style) score between the extraction and the source text
102 |         bert_score = self.bert_score(output_data, news_body)
103 |         print(f"BERTScore: {bert_score}")
104 | if rougeL_score >= 0.15 and bert_score >= 0.85:
105 | gpt_instance.process_input_output_pair(line_no, output_data, output_dir)
106 | else:
107 | new_file_path = "data/raw_news/regen/"
108 | os.makedirs(new_file_path, exist_ok=True)
109 | shutil.move(file_path, new_file_path)
110 |
111 | def process_et(self, file_path, gpt_instance, output_dir):
112 | line_no = os.path.basename(file_path).replace('.txt', '')
113 | with open(file_path, 'r', encoding='UTF-8') as file:
114 | news_body = file.read()
115 | output_data = gpt_instance.extract_title(news_body)
116 | gpt_instance.process_input_output_pair(line_no, output_data, output_dir)
117 |
118 |
119 | def process_gm(self, title_files_dir, fcis_files_dir, gpt_instance, output_dir):
120 | line_no = os.path.basename(title_files_dir).replace('.txt', '')
121 | with open(title_files_dir, 'r', encoding='UTF-8') as file:
122 | title = file.read()
123 | ttv_path = os.path.join(fcis_files_dir, f"{line_no}.txt")
124 | with open(ttv_path, 'r', encoding='UTF-8') as file:
125 | text = file.read()
126 | output_data = gpt_instance.gen_mindmap(title, text)
127 | gpt_instance.process_input_output_pair(line_no, output_data, output_dir)
128 |
129 | def mindmap_str_to_json(self, mindmap_str_file_path, mindmap_json_dir):
130 | mindmap_json_file_path = os.path.join(mindmap_json_dir, os.path.basename(mindmap_str_file_path).replace('.txt', '.json'))
131 | with open(mindmap_str_file_path, 'r', encoding='utf-8') as file:
132 | mindmap_str = file.read()
133 | try:
134 | if '```json' in mindmap_str:
135 | real_content = mindmap_str.replace('```json', '').replace('```', '').strip()
136 | mindmap = json.loads(real_content)
137 | else:
138 |                 # otherwise, parse the raw response string directly
139 | mindmap = json.loads(mindmap_str)
140 | with open(mindmap_json_file_path, 'w', encoding='utf-8') as f:
141 | json.dump(mindmap, f, ensure_ascii=False, indent=4)
142 | except json.JSONDecodeError as e:
143 |             print(f'JSON parse error in file {mindmap_str_file_path}:', e)
144 |
145 |
146 | def query_deconstruction(self, question):
147 | template = self._read_prompt_template('query_deconstruction.txt')
148 |         query = template.format(text=question)
149 | respond = self.safe_request(query)
150 | return respond
151 |
152 |
153 | @staticmethod
154 | def _read_prompt_template(filename: str) -> str:
155 | path = os.path.join('prompts/', filename)
156 | if os.path.exists(path):
157 | with open(path, encoding='utf-8') as f:
158 | return f.read()
159 | else:
160 | logger.error(f'Prompt template not found at {path}')
161 | return ''
162 |
163 | def rougeL_score(self,
164 | continuation: str,
165 | reference: str
166 | ) -> float:
167 | f = lambda text: list(jieba.cut(text))
168 | # rouge = evaluate.load('/path/to/local/rouge')
169 | rouge = evaluate.load('rouge')
170 | results = rouge.compute(predictions=[continuation], references=[[reference]], tokenizer=f, rouge_types=['rougeL'])
171 | score = results['rougeL']
172 | return score
173 |
174 | def bert_score(self,
175 | continuation: str,
176 | reference: str
177 | ) -> float:
178 | from text2vec import Similarity
179 | sim = Similarity(model_name_or_path="/path/to/local/text2vec-base-chinese")
180 | score = sim.get_score(continuation, reference)
181 | return score
182 |
--------------------------------------------------------------------------------
/pgrag/llms/remote.py:
--------------------------------------------------------------------------------
1 | import json
2 | import requests
3 | from llms.base import BaseLLM
4 | from configs import real_config as conf
5 |
6 | class GPT_transit(BaseLLM):
7 | def __init__(self, model_name='gpt-3.5-turbo', temperature=1.0, max_new_tokens=4096, report=False):
8 | super().__init__(model_name, temperature, max_new_tokens)
9 | self.report = report
10 |
11 | def request(self, query: str) -> str:
12 | url = conf.GPT_transit_url
13 | payload = json.dumps({
14 | "model": self.params['model_name'],
15 | "messages": [{"role": "user", "content": query}],
16 | "temperature": self.params['temperature'],
17 | 'max_tokens': self.params['max_new_tokens'],
18 | "top_p": self.params['top_p'],
19 | })
20 | headers = {
21 | 'Authorization': 'Bearer {}'.format(conf.GPT_transit_token),
22 | 'Content-Type': 'application/json',
23 | }
24 | res = requests.request("POST", url, headers=headers, data=payload,timeout=300)
25 | print('res:', res.text)
26 | res = res.json()
27 | real_res = res["choices"][0]["message"]["content"]
28 | return real_res
29 |
--------------------------------------------------------------------------------
/pgrag/mindmap_generator.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | from llms.remote import GPT_transit
4 | import concurrent.futures
5 | from tqdm import tqdm
6 |
7 | class MindmapGeneration:
8 | def __init__(self, model_name, num_threads, raw_news_files_dir, title_files_dir, fcis_files_dir, mindmaps_str_files_dir, mindmaps_json_files_dir):
9 | self.gpt = GPT_transit(model_name=model_name, report=True)
10 | self.num_threads = num_threads
11 | self.raw_news_files_dir = raw_news_files_dir
12 | self.title_files_dir = title_files_dir
13 | self.fcis_files_dir = fcis_files_dir
14 | self.mindmaps_str_files_dir = mindmaps_str_files_dir
15 | self.mindmaps_json_files_dir = mindmaps_json_files_dir
16 |
17 | def extract_mt(self):
18 | raw_news_files_to_process = [os.path.join(self.raw_news_files_dir, file) for file in os.listdir(self.raw_news_files_dir) if file.endswith('.txt')]
19 | with concurrent.futures.ThreadPoolExecutor(max_workers=self.num_threads) as executor:
20 | list(tqdm(executor.map(self.gpt.process_et, raw_news_files_to_process, [self.gpt] * len(raw_news_files_to_process), [self.title_files_dir] * len(raw_news_files_to_process)), total=len(raw_news_files_to_process)))
21 |
22 | def extract_fcis(self):
23 | raw_news_files_to_process = [os.path.join(self.raw_news_files_dir, file) for file in os.listdir(self.raw_news_files_dir) if file.endswith('.txt')]
24 | with concurrent.futures.ThreadPoolExecutor(max_workers=self.num_threads) as executor:
25 | list(tqdm(executor.map(self.gpt.process_efvi, raw_news_files_to_process, [self.gpt] * len(raw_news_files_to_process), [self.fcis_files_dir] * len(raw_news_files_to_process)), total=len(raw_news_files_to_process)))
26 |
27 | def generate_mindmaps_str(self):
28 | title_files_to_process = [os.path.join(self.title_files_dir, file) for file in os.listdir(self.title_files_dir) if file.endswith('.txt')]
 29 |         # process_gm expects the FCIs *directory* (it derives each per-document file name itself), so pass self.fcis_files_dir here
 30 |         with concurrent.futures.ThreadPoolExecutor(max_workers=self.num_threads) as executor:
 31 |             list(tqdm(executor.map(self.gpt.process_gm, title_files_to_process, [self.fcis_files_dir] * len(title_files_to_process), [self.gpt] * len(title_files_to_process), [self.mindmaps_str_files_dir] * len(title_files_to_process)), total=len(title_files_to_process)))
32 |
33 | def generate_mindmaps_json(self):
34 | mindmap_str_files_to_process = [os.path.join(self.mindmaps_str_files_dir, file) for file in os.listdir(self.mindmaps_str_files_dir) if file.endswith('.txt')]
35 | with concurrent.futures.ThreadPoolExecutor(max_workers=self.num_threads) as executor:
36 | list(tqdm(executor.map(self.gpt.mindmap_str_to_json, mindmap_str_files_to_process, [self.mindmaps_json_files_dir] * len(mindmap_str_files_to_process)), total=len(mindmap_str_files_to_process)))
37 |
38 | def execute(self):
39 | self.extract_mt()
40 | self.extract_fcis()
41 | self.generate_mindmaps_str()
42 | self.generate_mindmaps_json()
43 |
44 |
--------------------------------------------------------------------------------
/pgrag/prompts/extract_fact_verification_items.txt:
--------------------------------------------------------------------------------
1 | 你是一名优秀的自然语言处理专家,请重新整理以下文本,梳理出所有信息项和对应内容,用于文本内容的幻觉纠正。
2 | 注意,确保细节内容不被丢失
3 |
4 | 文本:
5 | '''
6 | {text}
7 | '''
--------------------------------------------------------------------------------
/pgrag/prompts/gen_mindmap.txt:
--------------------------------------------------------------------------------
1 | 你是一名杰出的自然语言处理专家,你的任务是将下面的文本按提供的主题进行结构化,并转换成一个层次分明的思维导图。请确保你的输出严格遵循JSON格式。
2 |
3 | 在创建思维导图时,请遵循以下要求:
4 | (1)保证内容的层次性和逻辑性:确保相似或相关的信息被归纳在同一子主题下。
5 | (2)保证所有层级清晰分明,以帮助理解文本内容的结构。
6 | (3)并列内容应直接以列表形式存在于最后一层。
7 | (4)请直接输出JSON,不要输出任何其他内容。
8 |
9 | 请根据以下主题和文本信息,完成你的任务:
10 |
11 | 文本主题:
12 | '''
13 | {title}
14 | '''
15 |
16 | 文本内容:
17 | '''
18 | {text}
19 | '''
20 |
21 | 请将你的结果以如下格式输出:
22 | ```json
23 | {{
24 | "文本主题": {{
25 | "子主题": {{
26 | "子主题": "内容1",
27 | "子主题": "内容2",
28 | "子主题": "",
29 | "子主题": [
30 | "内容1",
31 | "内容2"
32 | ]
33 | }},
34 | "子主题": {{
35 | // 以此类推
36 | }}
37 | }}
38 | }}
39 | ```
--------------------------------------------------------------------------------
/pgrag/prompts/query_deconstruction.txt:
--------------------------------------------------------------------------------
1 | 问题: {text}
2 | 请提供此问题的回答要点:精简且准确地给出回答这个问题所需的关键信息或具体信息。
3 |
4 | 注意,不要出现幻觉和重复描述。仅输出回答问题的要点即可。
5 |
6 | 示例:
7 |
8 | 干眼症的发生趋势,是否有明显增加或减少。
9 |
10 |
11 | 哪些环境因素可能导致孩子们干眼症症状加重,如电子屏幕使用时间、室内空气质量等。
12 |
--------------------------------------------------------------------------------
/pgrag/prompts/topic_extract.txt:
--------------------------------------------------------------------------------
1 | 你是一名优秀的新闻工作者,请为以下文本生成一个准确、简洁、客观的标题。仅输出文本标题即可。
2 |
3 | 文本:
4 | '''
5 | {text}
6 | '''
--------------------------------------------------------------------------------
/pgrag/pseudo_graph_constructor.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import json
3 | from py2neo import Graph
4 | import os
5 | from sentence_transformers import SentenceTransformer
6 | from concurrent.futures import ThreadPoolExecutor
7 | from tqdm import tqdm
8 |
9 | class Neo4jDataInserter:
10 | def __init__(self, graph_uri, graph_auth, emb_model_name, max_workers=20):
11 | self.graph = Graph(graph_uri, auth=graph_auth)
12 | self.emb_model = SentenceTransformer(emb_model_name)
13 | self.max_workers = max_workers
14 |         print('Initialization complete!')
15 |
16 | def recursive_json_iterator(self, json_data, path='', topic_paths=None):
17 | if topic_paths is None:
18 | topic_paths = []
19 |
20 | if isinstance(json_data, dict):
21 | for key, value in json_data.items():
22 | current_path = f"{path} '{key}'".lstrip()
23 | self.recursive_json_iterator(value, current_path, topic_paths)
24 | elif isinstance(json_data, list):
25 | topic_path = f"{path} '{' '.join(map(str, json_data))}'".lstrip()
26 | topic_paths.append(topic_path)
27 | else:
28 | topic_path = f"{path} '{json_data}'".lstrip()
29 | topic_paths.append(topic_path)
30 | return topic_paths
31 |
32 | def load_json_files(self, mindmap_json_dir, raw_doc_dir):
33 | file_names = [file for file in os.listdir(mindmap_json_dir) if file.endswith('.json')]
34 | all_json_contents = {}
35 | all_news = []
36 |         print('Total number of files:', len(file_names))
37 |
38 |         for file_name in file_names:
39 |             mindmap_json_file_path = os.path.join(mindmap_json_dir, file_name)
40 |             base_name, _ = os.path.splitext(file_name)
41 |             raw_data_path = os.path.join(raw_doc_dir, base_name + ".txt")
42 |             with open(mindmap_json_file_path, 'r', encoding='utf-8') as file:
43 |                 content = file.read()
44 |             try:
45 |                 parsed_line = json.loads(content)
46 |             except json.JSONDecodeError as e:
47 |                 print(f'JSON parse error in file {file_name}:', e)
48 |                 continue  # skip this document so all_news stays aligned with all_json_contents
49 |             with open(raw_data_path, 'r', encoding='utf-8') as file:
50 |                 all_news.append(file.read())
51 |             all_json_contents[file_name] = parsed_line
52 |
53 | return all_json_contents, all_news
54 |
55 | def process_and_insert_single_data(self, raw_data, json_data):
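        # Given one mind map (json_data), merge a Topic node keyed by the map title and store its
        # 主题嵌入 embedding; then materialise every root-to-leaf path as a chain of SubTopic nodes
        # (路标) joined by 基础链接 relationships and terminated by a Content node holding the fact
        # text (事实) together with the embedding of the full path (路径嵌入).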
56 | title = list(json_data.keys())[0]
57 | emb_t = self.emb_model.encode(title, normalize_embeddings=True)
58 | document_properties = {
59 | "主题": title,
60 | "主题嵌入": emb_t.tolist()
61 | }
62 | label = "Topic"
63 | primary_key = "主题"
64 | topic_paths = self.recursive_json_iterator(json_data.get(title, {}))
65 |
66 | for key, value in document_properties.items():
67 | if value:
68 | update_query = f"MERGE (d:{label} {{{primary_key}: $primary_key_value}}) SET d.{key} = $value RETURN d"
69 | updated_node = self.graph.run(update_query,
70 | parameters={"primary_key_value": document_properties[primary_key],
71 | "value": value}).evaluate()
72 |
73 | topic_name = updated_node.get('主题')
74 |         print(f'Topic node "{topic_name}" inserted successfully!')
75 | for topic_path in topic_paths:
76 | split_parts = topic_path.strip().strip("'").split("' '")
77 | parts = [part.strip() for part in split_parts if part.strip()]
78 | print('-------------------------------')
79 |             print('Topic path to insert:', parts)
80 |
81 | for j, sub_topic_type in enumerate(parts[:-1]):
82 | if j == 0:
83 | create_TST_query = "MATCH (d:Topic {主题: $topic_name}) MERGE (d)-[r:基础链接]->(st:SubTopic {路标: $sub_topic_type}) RETURN d, st"
84 | TST_result = self.graph.run(create_TST_query,
85 | parameters={"topic_name": topic_name,
86 | "sub_topic_type": sub_topic_type}).data()
87 |                     print('Topic-to-SubTopic link inserted. Result:', TST_result)
88 | else:
89 | match_query_parts = [f"-[r{k}:基础链接]->(pst{k}:SubTopic {{路标: $part{k}}})" for k, part in
90 | enumerate(parts[:j], 1)]
91 | match_query = "MATCH (d:Topic {主题: $topic_name}) " + ''.join(match_query_parts)
92 | merge_query = f" WITH pst{j} MERGE (pst{j})-[r:基础链接]->(st:SubTopic {{路标: $sub_topic_type}}) RETURN st"
93 | create_STST_query = match_query + merge_query
94 |
95 | params = {"topic_name": topic_name, "sub_topic_type": sub_topic_type}
96 | for k, part in enumerate(parts[:j], 1):
97 | params[f"part{k}"] = part
98 |
99 | STST_result = self.graph.run(create_STST_query, parameters=params).data()
100 |                     print('Inserted topic path:', STST_result)
101 | if j == len(parts) - 2:
102 | emb_fp = self.emb_model.encode(parts[0] + ' '.join(parts[1:]), normalize_embeddings=True)
103 | fact = parts[j + 1]
104 | match_query_parts = [f"-[r{k}:基础链接]->(pst{k}:SubTopic {{路标: $part{k}}})" for k, part in
105 | enumerate(parts[:j + 1], 1)]
106 | match_query = f"MATCH (d:Topic {{主题: $topic_name}}) " + ''.join(match_query_parts)
107 | merge_query = f" MERGE (c:Content {{事实: $fact, 路径嵌入: $fp}}) WITH pst{j + 1}, c MERGE (pst{j + 1})-[r:基础链接]->(c) RETURN c"
108 |
109 | create_STC_query = match_query + merge_query
110 |
111 | params = {"topic_name": topic_name, "fact": fact, "fp": emb_fp.tolist()}
112 |
113 | for k, part in enumerate(parts[:j + 1], 1):
114 | params[f"part{k}"] = part
115 |
116 | STC_result = self.graph.run(create_STC_query, parameters=params).data()
117 |                         print("Full path inserted successfully. Result:", STC_result)
118 |
119 | def chunked_data(self, data, size):
120 |         """Split a list into chunks of the given size."""
121 | for i in range(0, len(data), size):
122 | yield data[i:i + size]
123 |
124 | def process_and_insert_data(self, raw_doc_dir, mindmap_json_dir, start_batch=0, batch_size=20):
125 | result, all_news = self.load_json_files(mindmap_json_dir, raw_doc_dir)
126 | all_batches = zip(self.chunked_data(all_news, batch_size), self.chunked_data(list(result.values()), batch_size))
127 | for i, (batch_news, batch_json_data) in enumerate(all_batches):
128 |             if i < start_batch:  # skip batches that have already been processed
129 | continue
130 | with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
131 | list(executor.map(self.process_and_insert_single_data, batch_news, batch_json_data))
132 |             print(f"Finished batch {i + 1} (batch size: {batch_size})")
133 |
134 | def update_single_subtopic_embedding(self, subtopic_path):
135 | """
136 |         Update the routing embedding of a single SubTopic node.
137 | """
138 | subtopic = subtopic_path['st']
139 | path_names = subtopic_path['path_names']
140 |
141 | path_str = ' '.join(path_names)
142 | emb_p = self.emb_model.encode(path_str, normalize_embeddings=True)
143 | print(path_str)
144 | query_update_embedding = """
145 | MATCH path = (t:Topic)-[:基础链接*]->(st:SubTopic {路标: $subtopic_label})
146 | WHERE [node IN nodes(path) | CASE WHEN node:Topic THEN node.主题 WHEN node:SubTopic THEN node.路标 END] = $path_names
147 | SET st.路由嵌入 = $embedding
148 | RETURN count(st) as updated
149 | """
150 | result = self.graph.run(query_update_embedding, subtopic_label=subtopic['路标'], path_names=path_names,
151 | embedding=emb_p.tolist()).evaluate()
152 |
153 | if result == 0:
154 | print(f"Failed to update routing embedding for path: {path_str}")
155 |
156 | def update_subtopic_embeddings(self):
157 | query_subtopics_paths = """
158 | MATCH path = (t:Topic)-[:基础链接*]->(st:SubTopic)
159 | RETURN st, [node IN nodes(path) | CASE WHEN node:Topic THEN node.主题 WHEN node:SubTopic THEN node.路标 END] AS path_names
160 | """
161 | subtopics_paths = self.graph.run(query_subtopics_paths).data()
162 | with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
163 | executor.map(self.update_single_subtopic_embedding, subtopics_paths)
164 |
165 | def execute(self, raw_doc_dir, mindmap_json_dir, start_batch=0, batch_size=20):
166 | self.process_and_insert_data(raw_doc_dir, mindmap_json_dir, start_batch, batch_size)
167 | self.update_subtopic_embeddings()
168 |
169 | class TopicAndContentFusion:
170 | def __init__(self, graph_uri, graph_auth, emb_model_name, topic_threshold=0.92, content_threshold=0.98):
171 | self.graph = Graph(graph_uri, auth=graph_auth)
172 | self.emb_model = SentenceTransformer(emb_model_name)
173 | self.topic_threshold = topic_threshold
174 | self.content_threshold = content_threshold
175 | self.topic_clusters = []
176 | self.content_clusters = []
177 | self.processed_topic_nodes = set()
178 | self.processed_content_nodes = set()
179 |
180 | def get_topic_node_ids(self):
181 | query = """
182 | MATCH (n:Topic)
183 | RETURN ID(n) AS topicNodeID
184 | """
185 | return self.graph.run(query).data()
186 |
187 | def get_content_node_ids(self):
188 | query = """
189 | MATCH (n:Content)
190 | RETURN ID(n) AS contentNodeID
191 | """
192 | return self.graph.run(query).data()
193 |
194 | def cluster_nodes(self, node_type, threshold):
195 | if node_type == 'Topic':
196 | node_id_list = self.get_topic_node_ids()
197 | processed_nodes = self.processed_topic_nodes
198 | clusters = self.topic_clusters
199 | embedding_field = 'n.主题嵌入'
200 | embedding_index = 'topic-embeddings'
201 | super_node_label = 'SuperTopic'
202 | elif node_type == 'Content':
203 | node_id_list = self.get_content_node_ids()
204 | processed_nodes = self.processed_content_nodes
205 | clusters = self.content_clusters
206 | embedding_field = 'n.路径嵌入'
207 | embedding_index = 'fact-embeddings'
208 | super_node_label = 'SuperContent'
209 | else:
210 | raise ValueError("Unsupported node type. Use 'Topic' or 'Content'.")
211 |
212 | print(f"{node_type} node count: {len(node_id_list)}")
213 | print(node_id_list)
214 |
215 | for node_dict in tqdm(node_id_list, desc=f"Clustering {node_type}s"):
216 | node_id = node_dict[f'{node_type.lower()}NodeID']
217 | if node_id in processed_nodes:
218 | continue
219 |
220 | query = f"""
221 | MATCH (n:{node_type}) WHERE ID(n) = {node_id}
222 | WITH {embedding_field} AS embedding
223 | CALL db.index.vector.queryNodes('{embedding_index}', {min(len(node_id_list), 100)}, embedding) YIELD node, score
224 | WHERE score > {threshold}
225 | RETURN ID(node) AS similarNodeId
226 | """
227 | similar_nodes = self.graph.run(query).data()
228 | similar_node_ids = {item['similarNodeId'] for item in similar_nodes}
229 |
230 | new_cluster = list(similar_node_ids)
231 | clusters.append(new_cluster)
232 | processed_nodes.update(similar_node_ids)
233 | processed_nodes.add(node_id)
234 |
235 | print(f"Total clusters for {node_type}: {len(clusters)}")
236 | for i, cluster in enumerate(clusters):
237 | print(f"{node_type} Cluster {i+1}: {cluster}")
238 |
239 | super_node_id = f"{super_node_label}_{i}"
240 | self.graph.run(f"MERGE (:{super_node_label} {{id: $id}})", id=super_node_id)
241 |
242 | matches = " ".join(f"MATCH (n{idx}) WHERE ID(n{idx}) = {node_id}" for idx, node_id in enumerate(cluster))
243 | creates = " ".join(f"CREATE (n{idx})-[:相似链接]->(st)" for idx in range(len(cluster)))
244 |
245 | query = f"""
246 | {matches}
247 | MATCH (st:{super_node_label} {{id: '{super_node_id}'}})
248 | {creates}
249 | RETURN NULL
250 | """
251 | self.graph.run(query)
252 |
253 | def fuse_topics_and_contents(self):
254 | self.cluster_nodes('Topic', self.topic_threshold)
255 | self.cluster_nodes('Content', self.content_threshold)
256 | print("Fusion of topics and contents completed.")
257 |
258 |
--------------------------------------------------------------------------------
/pgrag/seed_context_recall.py:
--------------------------------------------------------------------------------
1 | import json
2 | from concurrent.futures import ThreadPoolExecutor
3 | from tqdm import tqdm
4 | from py2neo import Graph
5 | from sentence_transformers import SentenceTransformer
6 | import numpy as np
7 | import torch
8 | def tensor(lst):
9 | return torch.tensor(lst)
10 |
11 | class SeedContextRecall:
12 | def __init__(self, graph_uri, graph_auth, emb_model_name,
13 | eval_data_with_qe_and_qdse_file, seed_topic_file, candidate_topic_file,
14 | recall_top_m, walk_top_m, num_threads, top_k):
15 | self.graph = Graph(graph_uri, auth=graph_auth)
16 | self.emb_model = SentenceTransformer(emb_model_name)
17 | self.eval_data_with_qe_and_qdse_file = eval_data_with_qe_and_qdse_file
18 | self.seed_topic_file = seed_topic_file
19 | self.candidate_topic_file = candidate_topic_file
20 | self.recall_top_m = recall_top_m
21 | self.walk_top_m = walk_top_m
22 | self.num_threads = num_threads
23 | self.top_k = top_k
24 |         print('Initialization complete!')
25 |
26 | def seed_topic_recall_base_tn(self, ci_embedding):
27 | query_seed_topic = """
28 | CALL db.index.vector.queryNodes('topic-embeddings', $M, $emb)
29 | YIELD node AS similarTopic, score
30 | MATCH (similarTopic)
31 | RETURN ID(similarTopic) AS topic_id, score
32 | """
33 | similar_topics_with_score = self.graph.run(query_seed_topic, M=self.recall_top_m, emb=ci_embedding).data()
34 | seed_topics_with_score = {item['topic_id']: item['score'] for item in similar_topics_with_score}
35 | return seed_topics_with_score
36 |
37 | def topic_walking(self, seed_topic_ids, qe):
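        # Starting from the recalled seed topics, walk across 相似链接 edges through their shared
        # SuperTopic hub nodes to sibling Topic nodes, rank those siblings by the dot product between
        # the question embedding (qe) and their 主题嵌入 embeddings, and keep adding the best ones
        # until up to walk_top_m candidate topics have been collected.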
38 | candidate_topics_ids = []
39 | candidate_topics_names = []
40 | qe_emb = np.array(qe)
41 |
42 | query_seed_topics = """
43 | MATCH (t:Topic) WHERE id(t) IN $seed_topic_ids
44 | RETURN collect(id(t)) AS seed_ids, collect(t.主题) AS seed_names
45 | """
46 | seed_topics = self.graph.run(query_seed_topics, seed_topic_ids=seed_topic_ids).data()
47 | if seed_topics:
48 | candidate_topics_ids.extend(seed_topics[0]['seed_ids'])
49 | candidate_topics_names.extend(seed_topics[0]['seed_names'])
50 |
51 | all_topic_ids = []
52 | all_topic_embs = []
53 | all_topic_names = []
54 | for seed_topic_id in seed_topic_ids:
55 | query_super_topic = """
56 | MATCH (t:Topic)-[:相似链接]->(st:SuperTopic)<-[:相似链接]-(other:Topic)
57 | WHERE id(t) = $seed_topic_id
58 | RETURN collect(id(other)) AS other_topic_ids,
59 | collect(other.主题嵌入) AS other_topic_embs,
60 | collect(other.主题) AS other_topic_names
61 | """
62 | connected_topics = self.graph.run(query_super_topic, seed_topic_id=seed_topic_id).data()
63 | if connected_topics:
64 | other_topic_ids = connected_topics[0]['other_topic_ids']
65 | other_topic_embs = [np.array(emb) for emb in connected_topics[0]['other_topic_embs']]
66 | other_topic_names = connected_topics[0]['other_topic_names']
67 | all_topic_ids.extend(other_topic_ids)
68 | all_topic_embs.extend(other_topic_embs)
69 | all_topic_names.extend(other_topic_names)
70 |
71 | if all_topic_embs:
72 | all_topic_embs_matrix = np.vstack(all_topic_embs)
73 | sims = qe_emb @ all_topic_embs_matrix.T
74 | top_indices = np.argsort(-sims)
75 |
76 | added_ids = set(candidate_topics_ids)
77 | for index in top_indices:
78 | if len(added_ids) >= self.walk_top_m:
79 | break
80 | topic_id = all_topic_ids[index]
81 | if topic_id not in added_ids:
82 | added_ids.add(topic_id)
83 | candidate_topics_ids.append(topic_id)
84 | candidate_topics_names.append(all_topic_names[index])
85 |
86 | return candidate_topics_ids, candidate_topics_names
87 |
88 | # def find_topN_paths_per_candidate_topic(self, candidate_topics, kpr_embedding):
89 | # query_content_and_path = '''
90 | # CALL db.index.vector.queryNodes('fact-embeddings', 10000, $embedding)
91 | # YIELD node AS similarContent, score
92 | # UNWIND $topics AS topic
93 | # MATCH path = (t:Topic)-[*]->(similarContent)
94 | # WHERE t.主题 = topic
95 | # WITH path, score, t
96 | # ORDER BY score DESC
97 | # LIMIT $topN
98 | # UNWIND nodes(path) AS p
99 | # WITH p,
100 | # score,
101 | # CASE WHEN p:Topic THEN p.主题
102 | # WHEN p:SubTopic THEN p.路标
103 | # WHEN p:Content THEN p.事实
104 | # ELSE null END AS pathAttribute
105 | # WHERE pathAttribute IS NOT NULL
106 | # RETURN collect(pathAttribute) AS pathAttributes, score
107 | # '''
108 | # results = self.graph.run(query_content_and_path, topics=candidate_topics, topN=self.top_k, embedding=kpr_embedding).data()
109 | # return results
110 |
111 | def seed_topic_recall(self, question_info):
112 | question, question_embedding, qdse = question_info
113 | seed_topic_paths = {}
114 |         print('Query question:', question)
115 | topM_paths = self.seed_topic_recall_base_tn(question_embedding)
116 | seed_topic_paths = {
117 | 'question': question,
118 | 'qe': question_embedding,
119 | 'topM_topic_paths': str(topM_paths),
120 | 'qdse': qdse
121 | }
122 | with open(self.seed_topic_file, 'a', encoding='utf-8') as file:
123 | file.write(str(seed_topic_paths) + '\n')
124 |
125 | def execute(self):
126 | with open(self.eval_data_with_qe_and_qdse_file, 'r', encoding='utf-8') as file:
127 | eval_data = json.load(file)
128 |
129 | questions = []
130 | qes = []
131 | qdses = []
132 | for entry in eval_data:
133 | question_text = entry["question"]
134 | question_embedding = entry[question_text]["qe_bge-base-zh"]
135 | qdse = entry[question_text]['qdse_bge-base-zh']
136 | questions.append(question_text)
137 | qes.append(question_embedding)
138 | qdses.append(qdse)
139 |
140 | question_infos = [(question, qe, qdse) for question, qe, qdse in zip(questions, qes, qdses)]
141 | print(len(question_infos))
142 |
143 | with ThreadPoolExecutor(max_workers=self.num_threads) as executor:
144 | list(tqdm(executor.map(self.seed_topic_recall, question_infos), total=len(question_infos)))
145 |
146 | with open(self.seed_topic_file, 'r', encoding='utf-8') as file:
147 | topM_seed_topic_lines = file.readlines()
148 | for topM_seed_topic_line in topM_seed_topic_lines:
149 | topM_seed_topics_info = eval(topM_seed_topic_line)
150 | question = topM_seed_topics_info['question']
151 | qe = topM_seed_topics_info['qe']
152 | topM_seed_topic_with_sim = eval(topM_seed_topics_info['topM_topic_paths'])
153 | qdse = topM_seed_topics_info['qdse']
154 | print(question)
155 | seed_topic_ids = list(topM_seed_topic_with_sim.keys())
156 | candidate_topic_ids, candidate_topic_names = self.topic_walking(seed_topic_ids, qe)
157 |             print("Candidate topics:", candidate_topic_ids, candidate_topic_names)
158 |
159 | candidate_topics = {
160 | 'question': question,
161 | 'candidate_topic_ids': candidate_topic_ids,
162 | 'candidate_topic_names': candidate_topic_names,
163 | 'qdse': qdse,
164 | 'qe': qe
165 | }
166 | with open(self.candidate_topic_file, 'a', encoding='utf-8') as file:
167 | file.write(str(candidate_topics) + '\n')
168 |
169 |
170 |
--------------------------------------------------------------------------------
/pgrag/sub_pseudo_graph_retriever.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from tqdm import tqdm
3 | from py2neo import Graph
4 |
5 | class PG_RAG_Processor:
6 | def __init__(self, graph_uri, graph_auth, candidate_topic_file,
7 | matrix_templates_file, matrix_templates_with_sim_file, topK=14):
8 | self.graph_uri = graph_uri
9 | self.graph_auth = graph_auth
10 | self.candidate_topic_file = candidate_topic_file
11 | self.matrix_templates_file = matrix_templates_file
12 | self.matrix_templates_with_sim_file = matrix_templates_with_sim_file
13 | self.topK = topK
14 | self.graph = Graph(graph_uri, auth=graph_auth)
15 |
16 | def fetch_paths_and_embeddings(self, candidate_topics_ids):
17 | query = """
18 | MATCH path=(t:Topic)-[:基础链接*]->(f: Content)
19 | WHERE id(t) IN $topic_ids
20 | RETURN [node in nodes(path) | id(node)] AS node_ids,
21 | [node in nodes(path) | coalesce(node.主题嵌入, node.路由嵌入, node.路径嵌入)] AS node_embs
22 | """
23 | paths_results = self.graph.run(query, topic_ids=candidate_topics_ids).data()
24 | path_ids = []
25 | path_embs = []
26 | for result in paths_results:
27 | path_ids.append(result['node_ids'])
28 | path_embs.append(result['node_embs'])
29 | return path_ids, path_embs
30 |
31 |
32 | def create_matrix_templates(self):
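        # For each question, fetch every Topic -> ... -> Content path under its candidate topics and
        # pad the paths into two aligned matrices: an ID matrix of node ids (padded with -1) and an
        # EMB matrix of node embeddings (zero-padded), which later steps score against the
        # deconstructed-question embeddings (qdse).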
33 | with open(self.candidate_topic_file, 'r', encoding='utf-8') as file:
34 | candidate_topic_lines = file.readlines()
35 |
36 | for candidate_topic_line in candidate_topic_lines:
37 | candidate_topic_info = eval(candidate_topic_line)
38 | question = candidate_topic_info['question']
39 | candidate_topics_ids = candidate_topic_info['candidate_topic_ids']
40 | qdse = candidate_topic_info['qdse']
41 | path_ids, path_embs = self.fetch_paths_and_embeddings(candidate_topics_ids)
42 |
43 | max_len_ids = max(len(ids) for ids in path_ids)
44 | id_matrix = np.full((len(path_ids), max_len_ids), -1)
45 | emb_matrix = np.zeros((len(path_embs), max_len_ids, len(path_embs[0][0])))
46 |
47 | for i, ids in enumerate(path_ids):
48 | id_matrix[i, :len(ids)] = ids
49 | for i, embs in enumerate(path_embs):
50 | for j, emb in enumerate(embs):
51 | emb_matrix[i, j, :] = emb
52 |
53 | matrix_templates = {
54 | 'question': question,
55 | 'ID Matrix': id_matrix.tolist(),
56 | 'EMB Matrix': emb_matrix.tolist(),
57 | 'qdse': qdse
58 | }
59 |
60 | with open(self.matrix_templates_file, 'a', encoding='utf-8') as file:
61 | file.write(str(matrix_templates) + '\n')
62 |
63 | def compute_similarity_matrices(self):
64 | with open(self.matrix_templates_file, 'r', encoding='utf-8') as file:
65 | matrix_templates_lines = file.readlines()
66 |
67 | for matrix_templates_line in tqdm(matrix_templates_lines, desc="Processing SM"):
68 | matrix_templates_info = eval(matrix_templates_line)
69 | question = matrix_templates_info['question']
70 | matrix_id_list = matrix_templates_info['ID Matrix']
71 | matrix_emb_list = matrix_templates_info['EMB Matrix']
72 | kps_emb_list = matrix_templates_info['qdse']
73 |
74 | matrix_emb = np.array(matrix_emb_list)
75 | kps_emb = np.array(kps_emb_list)
76 |
77 | num_kps = kps_emb.shape[0]
78 | if num_kps == 1024:
79 | num_kps = 1
80 | num_matrices, num_vectors_per_matrix, emb_dim = matrix_emb.shape
81 | flattened_matrix_emb = matrix_emb.reshape(-1, emb_dim)
82 | sims = kps_emb @ flattened_matrix_emb.T
83 | reshaped_sims = sims.reshape(num_kps, num_matrices, num_vectors_per_matrix)
84 |
85 | matrix_templates_with_sim = {
86 | 'question': question,
87 | 'SIM Matrix': reshaped_sims.tolist(),
88 | 'ID Matrix': matrix_id_list,
89 | 'qdse': kps_emb_list
90 | }
91 |
92 | with open(self.matrix_templates_with_sim_file, 'a', encoding='utf-8') as file:
93 | file.write(str(matrix_templates_with_sim) + '\n')
94 |
95 | def process_top_k_ids(self, contexts_ids_file, final_contexts_file):
96 | with open(self.matrix_templates_with_sim_file, 'r', encoding='utf-8') as file:
97 | matrix_templates_with_sim_lines = file.readlines()
98 |
99 | processor = MatrixProcessor()
100 | for matrix_templates_with_sim_line in tqdm(matrix_templates_with_sim_lines, desc="Processing max id"):
101 | matrix_templates_with_sim_info = eval(matrix_templates_with_sim_line)
102 | question = matrix_templates_with_sim_info['question']
103 | matrix_sim = np.array(matrix_templates_with_sim_info['SIM Matrix'])
104 | matrix_id = np.array(matrix_templates_with_sim_info['ID Matrix'])
105 |
106 | final_matrix = np.zeros_like(matrix_id, dtype=np.float64)
107 | top_values = []
108 | top_indices = []
109 |
110 | for matrix in matrix_sim:
111 | last_non_zeros = []
112 |
113 | for row_index, row in enumerate(matrix):
114 | for col_index in range(len(row) - 1, -1, -1):
115 | if row[col_index] != 0:
116 | last_non_zeros.append((row[col_index], row_index, col_index))
117 | break
118 |
119 | last_non_zeros_sorted = sorted(last_non_zeros, key=lambda x: x[0], reverse=True)[:self.topK]
120 | top_values.append([val[0] for val in last_non_zeros_sorted])
121 | top_indices.append([(val[1], val[2]) for val in last_non_zeros_sorted])
122 |
123 | control_matrices, pathway_matrices = processor.create_control_and_pathway_matrices(matrix, matrix_id, top_values, top_indices)
124 | temp_result_matrix = processor.color_matrices(control_matrices, pathway_matrices)
125 | final_matrix += temp_result_matrix
126 | top_k_ids = processor.find_top_k_ids(final_matrix, matrix_id, self.topK)
127 | top_k_contexts = self.convert_paths_to_json(top_k_ids)
128 | contexts_ids = {
129 | 'question': question,
130 | 'top_k_ids': top_k_ids
131 | }
132 | with open(contexts_ids_file, 'a', encoding='utf-8') as file:
133 | file.write(str(contexts_ids) + '\n')
134 |
135 | final_contexts = {
136 | 'question': question,
137 | 'top_k_contexts': top_k_contexts
138 | }
139 | with open(final_contexts_file, 'a', encoding='utf-8') as file:
140 | file.write(str(final_contexts) + '\n')
141 |
142 | def convert_paths_to_json(self, top_k_ids):
143 | query = '''
144 | MATCH (f)
145 | WHERE id(f) IN $id_list
146 | OPTIONAL MATCH path=(t:Topic)-[:基础链接*]->(f: Content)
147 | WITH COLLECT(nodes(path)) AS all_nodes
148 | RETURN DISTINCT all_nodes
149 | '''
150 | result = self.graph.run(query, id_list=top_k_ids).data()
151 | all_paths = list(result[0]['all_nodes'])
152 | # print('all_paths:', all_paths)
153 | final_json = {}
154 | for path_nodes in all_paths:
155 |             current_level = final_json  # start the current level at the root of the nested context
156 |             for node in path_nodes:
157 |                 # derive the node's key (and embedding) from its label
158 |                 if 'Topic' in node.labels:
159 |                     node_key = node['主题'].strip('\'').strip('。')
160 |                     node_embedding = node['主题嵌入']  # node embedding (not used further here)
161 |                 elif 'SubTopic' in node.labels:
162 |                     node_key = node['路标'].strip('\'')
163 |                     node_embedding = node['路由嵌入']  # node embedding (not used further here)
164 |                 elif 'Content' in node.labels:
165 |                     node_key = node['事实'].strip('\'')
166 |                     node_embedding = node['路径嵌入']  # node embedding (not used further here)
167 |                 # check whether this key already exists at the current level
168 |                 if node_key not in current_level:
169 |                     # fact (Content) nodes become empty leaves
170 |                     if 'Content' in node.labels:
171 |                         current_level[node_key] = {}
172 |                     else:
173 |                         # Topic and SubTopic nodes get a new dict to hold their children
174 |                         current_level[node_key] = {}
175 |                         # descend into the newly created node
176 |                         current_level = current_level[node_key]
177 |                 else:
178 |                     # if the key already exists, descend into it unless it is a fact node
179 |                     if 'Content' not in node.labels:
180 |                         current_level = current_level[node_key]
181 | return final_json
182 |
183 | class MatrixProcessor:
184 | def create_control_and_pathway_matrices(self, matrix_sim, matrix_id, top_values, top_indices, support_threshold=0.01, continue_threshold=0.02):
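        # For each anchor fact cell (value, row, col) taken from the top-K hits of the last key point,
        # build two masks over the similarity/ID matrices:
        #  - a control matrix: walk left from the anchor until the similarity deviates from the anchor
        #    value by more than continue_threshold (cells within support_threshold keep full weight,
        #    cells within continue_threshold half weight), then score the other rows from that stopping
        #    column onward with the same rule;
        #  - a pathway matrix: mark the ID-matrix cells of the anchor's own path and of the adjacent
        #    rows that share the same sub-root node at the stopping column.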
185 | control_matrices = []
186 | pathway_matrices = []
187 |
188 | for value, (row_idx, col_idx) in zip(top_values[-1], top_indices[-1]):
189 | control_matrix = np.zeros_like(matrix_sim)
190 | pathway_matrix = np.zeros_like(matrix_id)
191 |
192 | f = col_idx
193 | while f >= 0:
194 | diff = abs(matrix_sim[row_idx, f] - value)
195 | if diff <= support_threshold:
196 | control_matrix[row_idx, f] = 1 * matrix_sim[row_idx, f]
197 | elif support_threshold < diff <= continue_threshold:
198 | control_matrix[row_idx, f] = 0.5 * matrix_sim[row_idx, f]
199 | else:
200 | control_matrix[row_idx, f] = 0
201 | break
202 | f -= 1
203 |
204 | col_f = f + 1
205 | sub_root = matrix_id[row_idx, col_f]
206 |
207 | for i in range(matrix_sim.shape[0]):
208 | if i == row_idx:
209 | continue
210 | for j in range(col_f, matrix_sim.shape[1]):
211 | diff = abs(matrix_sim[i, j] - value)
212 | if diff > continue_threshold:
213 | break
214 | elif diff <= support_threshold:
215 | control_matrix[i, j] = 1 * matrix_sim[i, j]
216 | elif support_threshold < diff <= continue_threshold:
217 | control_matrix[i, j] = 0.5 * matrix_sim[i, j]
218 |
219 | control_matrices.append(control_matrix)
220 |
221 | for j in range(col_f, matrix_id.shape[1]):
222 | if matrix_id[row_idx, j] == -1:
223 | break
224 | pathway_matrix[row_idx, j] = 1
225 |
226 | for direction in [-1, 1]:
227 | i = row_idx
228 | while 0 <= i + direction < matrix_id.shape[0]:
229 | i += direction
230 | if matrix_id[i, col_f] != sub_root:
231 | break
232 | for j in range(col_f, matrix_id.shape[1]):
233 | if matrix_id[i, j] == -1:
234 | break
235 | pathway_matrix[i, j] = 1
236 |
237 | pathway_matrices.append(pathway_matrix)
238 |
239 | return control_matrices, pathway_matrices
240 |
241 | def color_matrices(self, control_matrices, pathway_matrices):
242 | result_matrix = np.zeros_like(control_matrices[0])
243 |
244 | for control_matrix, pathway_matrix in zip(control_matrices, pathway_matrices):
245 | result_matrix += np.multiply(control_matrix, pathway_matrix)
246 |
247 | return result_matrix
248 |
249 | def find_top_k_ids(self, final_matrix, matrix_id, topK):
250 | row_sums = np.sum(final_matrix, axis=1)
251 | top_k_row_indices = np.argsort(row_sums)[-topK:][::-1]
252 |
253 | top_k_ids = []
254 | for row_idx in top_k_row_indices:
255 | last_id = -1
256 | for col_idx in range(matrix_id.shape[1]):
257 | id_val = matrix_id[row_idx, col_idx]
258 | if id_val == -1:
259 | break
260 | last_id = id_val
261 | if last_id != -1:
262 | top_k_ids.append(int(last_id))
263 |
264 | return top_k_ids
265 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | evaluate==0.4.1
2 | FlagEmbedding==1.2.5
3 | jieba==0.42.1
4 | loguru==0.7.2
5 | neo4j==5.17.0
6 | numpy==1.24.0
7 | openai==1.25.2
8 | py2neo==2021.2.4
9 | requests==2.31.0
10 | rouge_score==0.1.2
11 | scikit-learn==1.4.2
12 | sentence-transformers==2.4.0
13 | text2vec==1.2.9
14 | tiktoken==0.6.0
15 | tokenizers==0.15.2
16 | torch==2.2.1
17 | tqdm==4.66.2
18 | transformers==4.38.1
19 | umap==0.1.1
20 |
--------------------------------------------------------------------------------