├── LICENSE ├── README.md ├── assets ├── main_results.jpg ├── pg_gen.jpg └── pgr.jpg ├── data └── CRUD_RAG │ ├── QAGeneration_gpt-4-1106-preview_1doc.json │ ├── QAGeneration_gpt-4-1106-preview_2docs.json │ └── QAGeneration_gpt-4-1106-preview_3docs.json ├── pgrag ├── configs │ ├── __pycache__ │ │ └── real_config.cpython-39.pyc │ └── real_config.py ├── data │ └── eval │ │ └── eval_data_with_qe_and_qdse_example.json ├── llms │ ├── __pycache__ │ │ ├── base.cpython-39.pyc │ │ ├── local.cpython-39.pyc │ │ ├── qwen13b.cpython-39.pyc │ │ ├── qwen14b.cpython-39.pyc │ │ ├── remote.cpython-311.pyc │ │ ├── remote.cpython-38.pyc │ │ └── remote.cpython-39.pyc │ ├── base.py │ └── remote.py ├── mindmap_generator.py ├── prompts │ ├── extract_fact_verification_items.txt │ ├── gen_mindmap.txt │ ├── query_deconstruction.txt │ └── topic_extract.txt ├── pseudo_graph_constructor.py ├── seed_context_recall.py └── sub_pseudo_graph_retriever.py └── requirements.txt /LICENSE: -------------------------------------------------------------------------------- 1 | Waiting for updates 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | PG-RAG: Empowering Large Language Models to Set up a Knowledge Retrieval Indexer via Self-Learning 3 |

4 | 5 |

6 | 7 | ## Introduction 8 | 9 | PG-RAG proposes a pre-retrieval augmented generation method that introduces a _refinement_ step before the _indexing-retrieval-generation_ process, ensuring the accuracy of retrieved content from the outset. We leverage the self-learning capabilities of LLMs to transform documents into easily understandable and retrievable hierarchical indexes. This process naturally filters out noise and enhances information readability. By establishing connections between similar or complementary pieces of knowledge, we enable the retriever to function across multiple documents. During the knowledge retrieval phase, we use _pseudo-answers_ to help the retriever locate relevant information and to walk through the fact matrices, achieving accurate and rapid knowledge localization. Finally, we assemble the retrieved fact paths into a structured context, providing rich background information for LLMs to generate knowledge-grounded responses. 10 | 11 | 

12 | 13 |
Supported models of the PG-RAG framework 14 | 15 | | Model Type | Loading Method | Example Models | References | 16 | |------------|------------------------------|------------------------------------|--------------------------------------------------------------| 17 | | `api` | `requests` | OpenAI models | [OpenAI API](https://platform.openai.com/docs/introduction) | 18 | | `local` | HuggingFace `sentence-transformers` | `BAAI/bge-large-zh-v1.5` | [`sentence-transformers` (SBERT)](https://sbert.net/) | 19 | | `remote` | `requests` | Internal use only | Not available to the public | 20 | 21 | 
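For the `local` loading method, the embedding model is loaded through the `sentence-transformers` library. A minimal sketch (the model name mirrors the example in the table; point it at a local copy if you have one):

```python
# Minimal sketch of the `local` loading method via sentence-transformers.
# PG-RAG encodes topics and fact paths with normalized embeddings so that
# cosine similarity can be used directly in Neo4j vector indexes.
from sentence_transformers import SentenceTransformer

emb_model = SentenceTransformer("BAAI/bge-large-zh-v1.5")  # or a local path
embedding = emb_model.encode(
    "Announcement of the results of national medical device supervision sampling",
    normalize_embeddings=True,
)
print(embedding.shape)  # (1024,) for bge-large-zh-v1.5
```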
22 | 23 |
A specific case of main topic and fact-checking items. 24 | 25 | ```json 26 | { 27 | "Main Topic": "Announcement of the results of national medical device supervision sampling", 28 | "Fact-Checking Items": { 29 | "Date and Time": "2023-07-28", 30 | "Issuing Organization": "National Medical Products Administration", 31 | "Types of Products Sampled": "Dental low-voltage electric motors, various medical patches (far infrared therapy patches, magnetic therapy patches, acupoint magnetic therapy patches), among other five types", 32 | "Sampling Results": "A total of 12 batches (units) of products did not meet the standard requirements", 33 | "Specific Non-compliant Products and Issues": { 34 | "Dental Low-Voltage Electric Motor (1 unit)": "Produced by Guangdong Jingmei Medical Technology Co., Ltd., with issues related to leakage current and patient auxiliary current (under working temperature), and no-load speed not meeting the standard requirements.", 35 | "Vertical Pressure Steam Sterilizer (1 unit)": "Produced by Hefei Huatai Medical Equipment Co., Ltd., involving 'permissible limit values of accessible parts under normal conditions' and limit values under single fault condition (ground fault) not meeting the standard requirements.", 36 | "Electric Suction Device (1 unit)": "Produced by Suzhou Bein Technology Co., Ltd., involving 'network-powered, movable high negative pressure/high flow equipment' not meeting the standard requirements.", 37 | "Patch-type Medical Devices (far infrared therapy patch, magnetic therapy patch, acupoint magnetic therapy patch) 6 batches": "Produced by Jiujiang Gaoke Pharmaceutical Technology Co., Ltd., Zhengzhou Zhongyuan Fuli Industrial & Trade Co., Ltd., Ulanqab Qiao's Weiye Medical Device Co., Ltd., Hunan Dexi Medical Technology Co., Ltd., and Chongqing Zhengren Medical Device Co., Ltd., with issues involving detection of 'pharmaceutical ingredients that should not be detected according to supplementary testing methods.'", 38 | "Human Blood and Blood Component Plastic Bag Containers (blood bags) 3 batches": "Produced by Nanjing Sailjin Biomedical Co., Ltd., with issues involving non-compliance of the blood bag transfusion ports with the standards." 39 | } 40 | } 41 | } 42 | ``` 43 | 44 |
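Fact-checking items like the ones above are extracted from the raw document by the LLM and then checked for faithfulness before being used. A simplified sketch of that check, mirroring `rougeL_score`, `bert_score`, and `process_efvi` in `pgrag/llms/base.py` (the text2vec model path below is a placeholder; the repository uses a local copy):

```python
# Simplified faithfulness gate for extracted fact-checking items
# (mirrors process_efvi in pgrag/llms/base.py).
import jieba
import evaluate
from text2vec import Similarity

def is_faithful(extracted_items: str, source_text: str) -> bool:
    tokenize = lambda text: list(jieba.cut(text))
    rouge = evaluate.load("rouge")
    rouge_l = rouge.compute(
        predictions=[extracted_items],
        references=[[source_text]],
        tokenizer=tokenize,
        rouge_types=["rougeL"],
    )["rougeL"]
    sim = Similarity(model_name_or_path="shibing624/text2vec-base-chinese")  # placeholder path
    bert_sim = sim.get_score(extracted_items, source_text)
    # Repository thresholds: keep the items only if both checks pass;
    # otherwise the raw document is moved to data/raw_news/regen/ for regeneration.
    return rouge_l >= 0.15 and bert_sim >= 0.85
```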
45 | 46 |
A specific case of the mind map generated through main topic and fact-checking items. 47 | 48 | ```json 49 | { 50 | "Announcement of the results of national medical device supervision sampling": { 51 | "Announcement Date and Time": "July 28, 2023", 52 | "Sampled Items": [ 53 | "Dental Low-Voltage Electric Motors", 54 | "Patch-Type Medical Devices (including Far Infrared Therapy Patches, Magnetic Therapy Patches, Acupoint Magnetic Therapy Patches)" 55 | ], 56 | "Total Number of Sampled Products": 12, 57 | "Number of Product Types Not Meeting Standard Requirements": 5, 58 | "Non-Compliant Medical Devices and Manufacturer Information": { 59 | "Dental Low-Voltage Electric Motor": { 60 | "Quantity": 1, 61 | "Manufacturer": "Guangdong Jingmei Medical Technology Co., Ltd.", 62 | "Issue Description": "Involving leakage current and patient auxiliary current, no-load speed not meeting the standard requirements" 63 | }, 64 | "Vertical Pressure Steam Sterilizer": { 65 | "Quantity": 1, 66 | "Manufacturer": "Hefei Huatai Medical Equipment Co., Ltd.", 67 | "Issue Description": "Involving the allowable limit values of accessible parts under normal conditions and limit values under single fault condition (ground fault) not meeting the standard requirements" 68 | }, 69 | "Electric Suction Device": { 70 | "Quantity": 1, 71 | "Manufacturer": "Suzhou Bein Technology Co., Ltd.", 72 | "Issue Description": "Involving network-powered, movable high negative pressure/high flow equipment not meeting the standard requirements" 73 | }, 74 | "Patch-Type Medical Devices": { 75 | "Sub-Types": [ 76 | "Far Infrared Therapy Patch", 77 | "Magnetic Therapy Patch", 78 | "Acupoint Magnetic Therapy Patch" 79 | ], 80 | "Batch Quantity": 6, 81 | "List of Manufacturers": [ 82 | "Jiujiang Gaoke Pharmaceutical Technology Co., Ltd.", 83 | "Zhengzhou Zhongyuan Fuli Industrial & Trade Co., Ltd.", 84 | "Ulangab Qiao's Weiye Medical Device Co., Ltd.", 85 | "Hunan Dexi Medical Technology Co., Ltd.", 86 | "Chongqing Zhengren Medical Device Co., Ltd." 87 | ], 88 | "Issue Description": "Involving detection of pharmaceutical ingredients that should not be detected according to supplementary testing methods" 89 | }, 90 | "Human Blood and Blood Component Plastic Bag Containers (Blood Bags)": { 91 | "Quantity": 3, 92 | "Manufacturer": "Nanjing Sailjin Biomedical Co., Ltd.", 93 | "Issue Description": "Involving blood bag transfusion ports not meeting the standard requirements" 94 | } 95 | } 96 | } 97 | } 98 | 99 | ``` 100 | 101 |
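The LLM returns a mind map like the one above as a raw string, which is parsed into JSON before graph insertion. A simplified sketch of that conversion (the repository's `mindmap_str_to_json` in `pgrag/llms/base.py` does the same, with file I/O around it):

```python
# Simplified sketch: turn a generated mind-map string into a Python dict.
import json

FENCE = "`" * 3  # three-backtick Markdown fence the model sometimes wraps JSON in

def mindmap_str_to_dict(mindmap_str: str) -> dict:
    cleaned = mindmap_str.replace(FENCE + "json", "").replace(FENCE, "").strip()
    return json.loads(cleaned)  # raises json.JSONDecodeError on malformed output
```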
102 | 103 |
A specific case of a fact path in the PG (pseudo-graph). 104 | 105 | "Announcement of the results of national medical device supervision sampling" > "Non-Compliant Medical Devices and Manufacturer Information" > "Dental Low-Voltage Electric Motor" > "Manufacturer" > "Guangdong Jingmei Medical Technology Co., Ltd." 106 | 107 | 
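Fact paths like the one above are obtained by enumerating every root-to-leaf route through the mind-map JSON. A minimal sketch of that enumeration, in the spirit of `recursive_json_iterator` in `pgrag/pseudo_graph_constructor.py` (`mindmap_json` in the usage comment is a placeholder for a parsed mind map):

```python
# Minimal sketch: enumerate fact paths from a nested mind-map dictionary.
def enumerate_fact_paths(node, path=()):
    if isinstance(node, dict):
        for key, value in node.items():
            yield from enumerate_fact_paths(value, path + (str(key),))
    elif isinstance(node, list):
        # Parallel items are joined into a single leaf, as in the repository code.
        yield path + (" ".join(map(str, node)),)
    else:
        yield path + (str(node),)

# Usage (placeholder variable):
# for parts in enumerate_fact_paths(mindmap_json):
#     print(" > ".join(f'"{p}"' for p in parts))
```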
108 | 109 |
Project structure 110 | 111 | ```bash 112 | . 113 | ├── .github 114 | ├── .gitignore 115 | ├── LICENSE 116 | ├── README.md 117 | ├── assets # Static files such as images used in the documentation 118 | ├── data # Datasets (e.g., 1-Document QA) 119 | ├── output # Stores the recalled contexts for queries 120 | ├── requirements.txt 121 | └── pgrag # Source code for the project 122 |     ├── configs # Scripts for initializing model-loading parameters 123 |     ├── data # Intermediate result data 124 |     │   ├── eval # Evaluation dataset samples (eval_data_with_qe_and_qdse_example.json) 125 |     │   ├── raw_news # The original documents used to build the retrieval library 126 |     │   ├── pg_gen # Data produced during pseudo-graph construction 127 |     │   └── context_recall # Data produced during pseudo-graph retrieval 128 |     ├── mindmap_generator.py # Script for generating mind maps 129 |     ├── pseudo_graph_constructor.py # Script for pseudo-graph construction 130 |     ├── seed_context_recall.py # Script for seed context recall 131 |     ├── sub_pseudo_graph_retriever.py # Script for structured context recall 132 |     ├── llms # LLM calling utilities 133 |     └── prompts # Prompt templates 134 | ``` 135 | 136 | 
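If you are laying out the data directories by hand, the intermediate folders referenced in the Common Setup section below can be created with a small convenience sketch (paths are the batch-0 examples from that section, not part of the repository; adjust them to your own batches):

```python
# Convenience sketch: create the intermediate directories used by the pipeline.
import os

for d in [
    "data/raw_news/batch0",
    "data/pg_gen/batch0/title",
    "data/pg_gen/batch0/textToVerificationText",
    "data/pg_gen/batch0/mindmap_str",
    "data/pg_gen/batch0/mindmap_json",
    "data/context_recall/pgrag",
]:
    os.makedirs(d, exist_ok=True)
```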
137 | 138 | ## Installation 139 | 140 | Before using PG-RAG: 141 | 142 | 1. Ensure you have Python 3.9.0+ 143 | 2. Install the required packages: 144 | 145 | ```bash 146 | pip install -r requirements.txt 147 | ``` 148 | 149 | 3. Prerequisites 150 | 151 | - **JDK 17**: Ensure you have JDK 17 installed on your machine. 152 | - **Neo4j**: Download and install Neo4j 5.17.0. Start the Neo4j console using the following command: 153 | 154 | ```bash 155 | neo4j console 156 | ``` 157 | 158 | ## PG-RAG Pipeline 159 | 160 | This repository contains the implementation of the PG-RAG (Pseudo-Graph Retrieval Augmented Generation) pipeline. The pipeline consists of four main stages: mind map generation, pseudo-graph construction, seed context recall, and final context extension and generation. 161 | 162 | ### Common Setup 163 | 164 | Before running the individual scripts, ensure the following configuration: 165 | 166 | ```python 167 | graph_uri = "bolt://localhost:7687" 168 | graph_auth = ("neo4j", "password") 169 | emb_model_name = "/path/to/your/local/bge-base-zh-1.5" 170 | num_threads = 20 171 | topK = 8 172 | 173 | # Parameters for Pseudo-Graph Construction 174 | model_name = 'gpt-3.5-turbo' # Model name can be: gpt-3.5-turbo, gpt-4-0613, gpt-4-1106-preview 175 | raw_news_files_dir = 'data/raw_news/batch0' 176 | title_files_dir = 'data/pg_gen/batch0/title' 177 | fcis_files_dir = "data/pg_gen/batch0/textToVerificationText/" 178 | mindmaps_str_files_dir = "data/pg_gen/batch0/mindmap_str/" 179 | mindmaps_json_files_dir = "data/pg_gen/batch0/mindmap_json/" 180 | 181 | # Parameters for Knowledge Recall via Pseudo-Graph Retrieval 182 | eval_data_with_qe_and_qdse_file = 'data/eval/eval_data_with_qe_and_qdse.json' 183 | seed_topic_file = 'data/context_recall/pgrag/seed_topics.json' 184 | candidate_topic_file = 'data/context_recall/pgrag/candidate_topics.json' 185 | matrix_templates_file = 'data/context_recall/pgrag/matrix_templates.json' 186 | matrix_templates_with_sim_file = 'data/context_recall/pgrag/matrix_templates_with_sim.json' 187 | contexts_ids_file = 'data/context_recall/pgrag/contexts_ids.json' 188 | final_contexts_file = 'data/context_recall/pgrag/final_contexts.json' 189 | 190 | recall_top_m = 3 191 | walk_top_m = 6 192 | ``` 193 | 194 | Ensure Neo4j database is running and accessible at the specified `graph_uri`. Adjust file paths and model names as per your environment setup. Modify the parameters as needed to fine-tune the performance and results. 195 | 196 | ## Steps to Execute 197 | 198 | 1. **Generate Mind Maps**: Generate mind maps for original texts using `pgrag/mindmap_generator.py`. 199 | 2. **Build Pseudo-Graph**: Construct a pseudo-graph through `pgrag/pseudo_graph_constructor.py`. 200 | 3. **Seed Context Recall**: Recall seed topics using `pgrag/seed_context_recall.py`. 201 | 4. **Final Context Extensions**: Perform final context extensions and context generation using `pgrag/sub_pseudo_graph_retriever.py`. 202 | 203 | ### 1. Generate Mind Maps 204 | 205 | To generate mind maps from the original texts: 206 | 207 | ```python 208 | # pgrag/mindmap_generator.py 209 | 210 | mindmap_generation = MindmapGeneration(model_name, num_threads, raw_news_files_dir, title_files_dir, fcis_files_dir, mindmaps_str_files_dir, mindmaps_json_files_dir) 211 | mindmap_generation.execute() 212 | ``` 213 | 214 | ### 2. 
Build Pseudo-Graph 215 | 216 | To build the pseudo-graph: 217 | 218 | ```python 219 | # pgrag/pseudo_graph_constructor.py 220 | 221 | inserter = Neo4jDataInserter(graph_uri, graph_auth, emb_model_name, num_threads) 222 | inserter.execute(raw_news_files_dir, mindmaps_json_files_dir) 223 | ``` 224 | 225 | To ensure efficient querying and retrieval of topic embeddings and fact embeddings, you need to manually create the following vector indexes in Neo4j: 226 | 227 | ```cypher 228 | CREATE VECTOR INDEX topic-embeddings IF NOT EXISTS 229 | FOR (m:Topic) 230 | ON m.主题嵌入 231 | OPTIONS {indexConfig: { 232 | `vector.dimensions`: 1024, 233 | `vector.similarity_function`: 'cosine' 234 | }} 235 | 236 | CREATE VECTOR INDEX fact-embeddings IF NOT EXISTS 237 | FOR (m:Content) 238 | ON m.路径嵌入 239 | OPTIONS {indexConfig: { 240 | `vector.dimensions`: 1024, 241 | `vector.similarity_function`: 'cosine' 242 | }} 243 | ``` 244 | 245 | These commands will create the necessary indexes for the `Topic` and `Content` nodes, respectively, with a vector dimension of 1024 and using the cosine similarity function for the embeddings. 246 | 247 | Please run the above commands in your Neo4j database before executing the main application to ensure all functionalities work as expected. Then, 248 | 249 | ```python 250 | fusion = TopicAndContentFusion(graph_uri, graph_auth, emb_model_name) 251 | fusion.fuse_topics_and_contents() 252 | ``` 253 | 254 | ### 3. Seed Context Recall 255 | 256 | To recall seed contexts: 257 | 258 | ```python 259 | # pgrag/seed_context_recall.py 260 | 261 | seed_context_recall = SeedContextRecall(graph_uri, graph_auth, emb_model_name, eval_data_with_qe_and_qdse_file, seed_topic_file, candidate_topic_file, recall_top_m, walk_top_m, num_threads, topK) 262 | seed_context_recall.execute() 263 | ``` 264 | 265 | ### 4. Final Context Extensions 266 | 267 | To perform final context extensions and context generation: 268 | 269 | ```python 270 | # pgrag/sub_pseudo_graph_retriever.py 271 | 272 | processor = PG_RAG_Processor(graph_uri, graph_auth, candidate_topic_file, matrix_templates_file, matrix_templates_with_sim_file, topK) 273 | processor.create_matrix_templates() 274 | processor.compute_similarity_matrices() 275 | processor.process_top_k_ids(contexts_ids_file, final_contexts_file) 276 | ``` 277 | 278 | ## Results for Experiment 279 | 280 |

281 | -------------------------------------------------------------------------------- /assets/main_results.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IAAR-Shanghai/PGRAG/c2bded31b00c02baaba87fda50e93d5fc0e86409/assets/main_results.jpg -------------------------------------------------------------------------------- /assets/pg_gen.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IAAR-Shanghai/PGRAG/c2bded31b00c02baaba87fda50e93d5fc0e86409/assets/pg_gen.jpg -------------------------------------------------------------------------------- /assets/pgr.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IAAR-Shanghai/PGRAG/c2bded31b00c02baaba87fda50e93d5fc0e86409/assets/pgr.jpg -------------------------------------------------------------------------------- /pgrag/configs/__pycache__/real_config.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IAAR-Shanghai/PGRAG/c2bded31b00c02baaba87fda50e93d5fc0e86409/pgrag/configs/__pycache__/real_config.cpython-39.pyc -------------------------------------------------------------------------------- /pgrag/configs/real_config.py: -------------------------------------------------------------------------------- 1 | GPT_transit_token = '' # openai.api_key 2 | GPT_transit_url = '' # openai.base_url 3 | -------------------------------------------------------------------------------- /pgrag/llms/__pycache__/base.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IAAR-Shanghai/PGRAG/c2bded31b00c02baaba87fda50e93d5fc0e86409/pgrag/llms/__pycache__/base.cpython-39.pyc -------------------------------------------------------------------------------- /pgrag/llms/__pycache__/local.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IAAR-Shanghai/PGRAG/c2bded31b00c02baaba87fda50e93d5fc0e86409/pgrag/llms/__pycache__/local.cpython-39.pyc -------------------------------------------------------------------------------- /pgrag/llms/__pycache__/qwen13b.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IAAR-Shanghai/PGRAG/c2bded31b00c02baaba87fda50e93d5fc0e86409/pgrag/llms/__pycache__/qwen13b.cpython-39.pyc -------------------------------------------------------------------------------- /pgrag/llms/__pycache__/qwen14b.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IAAR-Shanghai/PGRAG/c2bded31b00c02baaba87fda50e93d5fc0e86409/pgrag/llms/__pycache__/qwen14b.cpython-39.pyc -------------------------------------------------------------------------------- /pgrag/llms/__pycache__/remote.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IAAR-Shanghai/PGRAG/c2bded31b00c02baaba87fda50e93d5fc0e86409/pgrag/llms/__pycache__/remote.cpython-311.pyc -------------------------------------------------------------------------------- /pgrag/llms/__pycache__/remote.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/IAAR-Shanghai/PGRAG/c2bded31b00c02baaba87fda50e93d5fc0e86409/pgrag/llms/__pycache__/remote.cpython-38.pyc -------------------------------------------------------------------------------- /pgrag/llms/__pycache__/remote.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IAAR-Shanghai/PGRAG/c2bded31b00c02baaba87fda50e93d5fc0e86409/pgrag/llms/__pycache__/remote.cpython-39.pyc -------------------------------------------------------------------------------- /pgrag/llms/base.py: -------------------------------------------------------------------------------- 1 | import jieba 2 | from text2vec import Similarity 3 | import evaluate 4 | import shutil 5 | import os 6 | import copy 7 | import json 8 | from abc import ABC, abstractmethod 9 | 10 | from loguru import logger 11 | 12 | 13 | class BaseLLM(ABC): 14 | def __init__( 15 | self, 16 | model_name: str = None, 17 | temperature: float = 1.0, 18 | max_new_tokens: int = 4096, 19 | top_p: float = 0.9, 20 | top_k: int = 5, 21 | **more_params 22 | ): 23 | self.params = { 24 | 'model_name': model_name if model_name else self.__class__.__name__, 25 | 'temperature': temperature, 26 | 'max_new_tokens': max_new_tokens, 27 | 'top_p': top_p, 28 | 'top_k': top_k, 29 | **more_params 30 | } 31 | self.post_init() 32 | 33 | def post_init(self): 34 | """Post initialization method for subclasses. 35 | Normally, this method should initialize the model and tokenizer. 36 | """ 37 | ... 38 | 39 | def update_params(self, inplace: bool = True, **params): 40 | if inplace: 41 | self.params.update(params) 42 | return self 43 | else: 44 | new_obj = copy.deepcopy(self) 45 | new_obj.params.update(params) 46 | return new_obj 47 | 48 | @abstractmethod 49 | def request(self, query:str) -> str: 50 | return '' 51 | 52 | def safe_request(self, query: str) -> str: 53 | """Safely make a request to the language model, handling exceptions.""" 54 | try: 55 | response = self.request(query) 56 | except Exception as e: 57 | logger.warning(repr(e)) 58 | response = '' 59 | return response 60 | 61 | def safe_request_with_prompt(self, text, prompt_file) -> str: 62 | template = self._read_prompt_template(prompt_file) 63 | query = template.format(text=text) 64 | query = query.replace('{{', '{').replace('}}', '}') 65 | respond = self.safe_request(query) 66 | return respond 67 | 68 | def extract_fact_verification_items(self, news_body) -> str: 69 | prompt_file = 'extract_fact_verification_items.txt' 70 | respond = self.safe_request_with_prompt(news_body, prompt_file) 71 | return respond 72 | 73 | def extract_title(self, text) -> str: 74 | prompt_file = 'topic_extract.txt' 75 | respond = self.safe_request_with_prompt(text, prompt_file) 76 | return respond 77 | 78 | def gen_mindmap(self, title, text) -> str: 79 | prompt_file = 'gen_mindmap.txt' 80 | template = self._read_prompt_template(prompt_file) 81 | query = template.format(title=title, text=text) 82 | query = query.replace('{{', '{').replace('}}', '}') 83 | respond = self.safe_request(query) 84 | return respond 85 | 86 | def process_input_output_pair(self, line_no, output_data, output_dir): 87 | os.makedirs(output_dir, exist_ok=True) 88 | filename_output = os.path.join(output_dir, f"{line_no}.txt") 89 | 90 | with open(filename_output, 'w', encoding='UTF-8') as output_file: 91 | output_file.write(output_data) 92 | 93 | 94 | def process_efvi(self, file_path, gpt_instance, output_dir): 95 | line_no = 
os.path.basename(file_path).replace('.txt', '') 96 | with open(file_path, 'r', encoding='UTF-8') as file: 97 | news_body = file.read() 98 | output_data = gpt_instance.extract_fact_verification_items(news_body) 99 | rougeL_score = self.rougeL_score(output_data, news_body) 100 | print(f"ROUGE-L:'{rougeL_score}") 101 | # 测试 BertScore 分数计算 102 | bert_score = self.bert_score(output_data, news_body) 103 | print(f"bert_score:'{bert_score}") 104 | if rougeL_score >= 0.15 and bert_score >= 0.85: 105 | gpt_instance.process_input_output_pair(line_no, output_data, output_dir) 106 | else: 107 | new_file_path = "data/raw_news/regen/" 108 | os.makedirs(new_file_path, exist_ok=True) 109 | shutil.move(file_path, new_file_path) 110 | 111 | def process_et(self, file_path, gpt_instance, output_dir): 112 | line_no = os.path.basename(file_path).replace('.txt', '') 113 | with open(file_path, 'r', encoding='UTF-8') as file: 114 | news_body = file.read() 115 | output_data = gpt_instance.extract_title(news_body) 116 | gpt_instance.process_input_output_pair(line_no, output_data, output_dir) 117 | 118 | 119 | def process_gm(self, title_files_dir, fcis_files_dir, gpt_instance, output_dir): 120 | line_no = os.path.basename(title_files_dir).replace('.txt', '') 121 | with open(title_files_dir, 'r', encoding='UTF-8') as file: 122 | title = file.read() 123 | ttv_path = os.path.join(fcis_files_dir, f"{line_no}.txt") 124 | with open(ttv_path, 'r', encoding='UTF-8') as file: 125 | text = file.read() 126 | output_data = gpt_instance.gen_mindmap(title, text) 127 | gpt_instance.process_input_output_pair(line_no, output_data, output_dir) 128 | 129 | def mindmap_str_to_json(self, mindmap_str_file_path, mindmap_json_dir): 130 | mindmap_json_file_path = os.path.join(mindmap_json_dir, os.path.basename(mindmap_str_file_path).replace('.txt', '.json')) 131 | with open(mindmap_str_file_path, 'r', encoding='utf-8') as file: 132 | mindmap_str = file.read() 133 | try: 134 | if '```json' in mindmap_str: 135 | real_content = mindmap_str.replace('```json', '').replace('```', '').strip() 136 | mindmap = json.loads(real_content) 137 | else: 138 | # 否则直接使用原始响应字符串 139 | mindmap = json.loads(mindmap_str) 140 | with open(mindmap_json_file_path, 'w', encoding='utf-8') as f: 141 | json.dump(mindmap, f, ensure_ascii=False, indent=4) 142 | except json.JSONDecodeError as e: 143 | print(f'JSON解析错误在文件{mindmap_str_file_path}', e) 144 | 145 | 146 | def query_deconstruction(self, question): 147 | template = self._read_prompt_template('query_deconstruction.txt') 148 | query = template.format(question=question) 149 | respond = self.safe_request(query) 150 | return respond 151 | 152 | 153 | @staticmethod 154 | def _read_prompt_template(filename: str) -> str: 155 | path = os.path.join('prompts/', filename) 156 | if os.path.exists(path): 157 | with open(path, encoding='utf-8') as f: 158 | return f.read() 159 | else: 160 | logger.error(f'Prompt template not found at {path}') 161 | return '' 162 | 163 | def rougeL_score(self, 164 | continuation: str, 165 | reference: str 166 | ) -> float: 167 | f = lambda text: list(jieba.cut(text)) 168 | # rouge = evaluate.load('/path/to/local/rouge') 169 | rouge = evaluate.load('rouge') 170 | results = rouge.compute(predictions=[continuation], references=[[reference]], tokenizer=f, rouge_types=['rougeL']) 171 | score = results['rougeL'] 172 | return score 173 | 174 | def bert_score(self, 175 | continuation: str, 176 | reference: str 177 | ) -> float: 178 | from text2vec import Similarity 179 | sim = 
Similarity(model_name_or_path="/path/to/local/text2vec-base-chinese") 180 | score = sim.get_score(continuation, reference) 181 | return score 182 | -------------------------------------------------------------------------------- /pgrag/llms/remote.py: -------------------------------------------------------------------------------- 1 | import json 2 | import requests 3 | from llms.base import BaseLLM 4 | from configs import real_config as conf 5 | 6 | class GPT_transit(BaseLLM): 7 | def __init__(self, model_name='gpt-3.5-turbo', temperature=1.0, max_new_tokens=4096, report=False): 8 | super().__init__(model_name, temperature, max_new_tokens) 9 | self.report = report 10 | 11 | def request(self, query: str) -> str: 12 | url = conf.GPT_transit_url 13 | payload = json.dumps({ 14 | "model": self.params['model_name'], 15 | "messages": [{"role": "user", "content": query}], 16 | "temperature": self.params['temperature'], 17 | 'max_tokens': self.params['max_new_tokens'], 18 | "top_p": self.params['top_p'], 19 | }) 20 | headers = { 21 | 'Authorization': 'Bearer {}'.format(conf.GPT_transit_token), 22 | 'Content-Type': 'application/json', 23 | } 24 | res = requests.request("POST", url, headers=headers, data=payload,timeout=300) 25 | print('res:', res.text) 26 | res = res.json() 27 | real_res = res["choices"][0]["message"]["content"] 28 | return real_res 29 | -------------------------------------------------------------------------------- /pgrag/mindmap_generator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | from llms.remote import GPT_transit 4 | import concurrent.futures 5 | from tqdm import tqdm 6 | 7 | class MindmapGeneration: 8 | def __init__(self, model_name, num_threads, raw_news_files_dir, title_files_dir, fcis_files_dir, mindmaps_str_files_dir, mindmaps_json_files_dir): 9 | self.gpt = GPT_transit(model_name=model_name, report=True) 10 | self.num_threads = num_threads 11 | self.raw_news_files_dir = raw_news_files_dir 12 | self.title_files_dir = title_files_dir 13 | self.fcis_files_dir = fcis_files_dir 14 | self.mindmaps_str_files_dir = mindmaps_str_files_dir 15 | self.mindmaps_json_files_dir = mindmaps_json_files_dir 16 | 17 | def extract_mt(self): 18 | raw_news_files_to_process = [os.path.join(self.raw_news_files_dir, file) for file in os.listdir(self.raw_news_files_dir) if file.endswith('.txt')] 19 | with concurrent.futures.ThreadPoolExecutor(max_workers=self.num_threads) as executor: 20 | list(tqdm(executor.map(self.gpt.process_et, raw_news_files_to_process, [self.gpt] * len(raw_news_files_to_process), [self.title_files_dir] * len(raw_news_files_to_process)), total=len(raw_news_files_to_process))) 21 | 22 | def extract_fcis(self): 23 | raw_news_files_to_process = [os.path.join(self.raw_news_files_dir, file) for file in os.listdir(self.raw_news_files_dir) if file.endswith('.txt')] 24 | with concurrent.futures.ThreadPoolExecutor(max_workers=self.num_threads) as executor: 25 | list(tqdm(executor.map(self.gpt.process_efvi, raw_news_files_to_process, [self.gpt] * len(raw_news_files_to_process), [self.fcis_files_dir] * len(raw_news_files_to_process)), total=len(raw_news_files_to_process))) 26 | 27 | def generate_mindmaps_str(self): 28 | title_files_to_process = [os.path.join(self.title_files_dir, file) for file in os.listdir(self.title_files_dir) if file.endswith('.txt')] 29 | fcis_files_to_process = [os.path.join(self.fcis_files_dir, file) for file in os.listdir(self.fcis_files_dir) if file.endswith('.txt')] 30 | 
with concurrent.futures.ThreadPoolExecutor(max_workers=self.num_threads) as executor: 31 | list(tqdm(executor.map(self.gpt.process_gm, title_files_to_process, fcis_files_to_process, [self.gpt] * len(fcis_files_to_process), [self.mindmaps_str_files_dir] * len(fcis_files_to_process)), total=len(fcis_files_to_process))) 32 | 33 | def generate_mindmaps_json(self): 34 | mindmap_str_files_to_process = [os.path.join(self.mindmaps_str_files_dir, file) for file in os.listdir(self.mindmaps_str_files_dir) if file.endswith('.txt')] 35 | with concurrent.futures.ThreadPoolExecutor(max_workers=self.num_threads) as executor: 36 | list(tqdm(executor.map(self.gpt.mindmap_str_to_json, mindmap_str_files_to_process, [self.mindmaps_json_files_dir] * len(mindmap_str_files_to_process)), total=len(mindmap_str_files_to_process))) 37 | 38 | def execute(self): 39 | self.extract_mt() 40 | self.extract_fcis() 41 | self.generate_mindmaps_str() 42 | self.generate_mindmaps_json() 43 | 44 | -------------------------------------------------------------------------------- /pgrag/prompts/extract_fact_verification_items.txt: -------------------------------------------------------------------------------- 1 | 你是一名优秀的自然语言处理专家,请重新整理以下文本,梳理出所有信息项和对应内容,用于文本内容的幻觉纠正。 2 | 注意,确保细节内容不被丢失 3 | 4 | 文本: 5 | ''' 6 | {text} 7 | ''' -------------------------------------------------------------------------------- /pgrag/prompts/gen_mindmap.txt: -------------------------------------------------------------------------------- 1 | 你是一名杰出的自然语言处理专家,你的任务是将下面的文本按提供的主题进行结构化,并转换成一个层次分明的思维导图。请确保你的输出严格遵循JSON格式。 2 | 3 | 在创建思维导图时,请遵循以下要求: 4 | (1)保证内容的层次性和逻辑性:确保相似或相关的信息被归纳在同一子主题下。 5 | (2)保证所有层级清晰分明,以帮助理解文本内容的结构。 6 | (3)并列内容应直接以列表形式存在于最后一层。 7 | (4)请直接输出JSON,不要输出任何其他内容。 8 | 9 | 请根据以下主题和文本信息,完成你的任务: 10 | 11 | 文本主题: 12 | ''' 13 | {title} 14 | ''' 15 | 16 | 文本内容: 17 | ''' 18 | {text} 19 | ''' 20 | 21 | 请将你的结果以如下格式输出: 22 | ```json 23 | {{ 24 | "文本主题": {{ 25 | "子主题": {{ 26 | "子主题": "内容1", 27 | "子主题": "内容2", 28 | "子主题": "", 29 | "子主题": [ 30 | "内容1", 31 | "内容2" 32 | ] 33 | }}, 34 | "子主题": {{ 35 | // 以此类推 36 | }} 37 | }} 38 | }} 39 | ``` -------------------------------------------------------------------------------- /pgrag/prompts/query_deconstruction.txt: -------------------------------------------------------------------------------- 1 | 问题: {text} 2 | 请将此问题的回答要点:提供一个精简且准确的回答回答这个问题所需的关键信息或具体信息。 3 | 4 | 注意,不要出现幻觉和重复描述。仅输出回答问题的要点(写在之间)即可。 5 | 6 | 示例: 7 | 8 | 干眼症的发生趋势,是否有明显增加或减少。 9 | 10 | 11 | 哪些环境因素可能导致孩子们干眼症症状加重,如电子屏幕使用时间、室内空气质量等。 12 | -------------------------------------------------------------------------------- /pgrag/prompts/topic_extract.txt: -------------------------------------------------------------------------------- 1 | 你是一名优秀的新闻工作者,请为以下文本生成一个准确、简洁、客观标题。仅输出提取出文本标题即可。 2 | 3 | 文本: 4 | ''' 5 | {text} 6 | ''' -------------------------------------------------------------------------------- /pgrag/pseudo_graph_constructor.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import json 3 | from py2neo import Graph 4 | import os 5 | from sentence_transformers import SentenceTransformer 6 | from concurrent.futures import ThreadPoolExecutor 7 | from tqdm import tqdm 8 | 9 | class Neo4jDataInserter: 10 | def __init__(self, graph_uri, graph_auth, emb_model_name, max_workers=20): 11 | self.graph = Graph(graph_uri, auth=graph_auth) 12 | self.emb_model = SentenceTransformer(emb_model_name) 13 | self.max_workers = max_workers 14 | print('初始化成功!') 15 | 16 | def recursive_json_iterator(self, json_data, 
path='', topic_paths=None): 17 | if topic_paths is None: 18 | topic_paths = [] 19 | 20 | if isinstance(json_data, dict): 21 | for key, value in json_data.items(): 22 | current_path = f"{path} '{key}'".lstrip() 23 | self.recursive_json_iterator(value, current_path, topic_paths) 24 | elif isinstance(json_data, list): 25 | topic_path = f"{path} '{' '.join(map(str, json_data))}'".lstrip() 26 | topic_paths.append(topic_path) 27 | else: 28 | topic_path = f"{path} '{json_data}'".lstrip() 29 | topic_paths.append(topic_path) 30 | return topic_paths 31 | 32 | def load_json_files(self, mindmap_json_dir, raw_doc_dir): 33 | file_names = [file for file in os.listdir(mindmap_json_dir) if file.endswith('.json')] 34 | all_json_contents = {} 35 | all_news = [] 36 | print('总文件数:', len(file_names)) 37 | 38 | for file_name in file_names: 39 | mindmap_json_file_path = os.path.join(mindmap_json_dir, file_name) 40 | base_name, _ = os.path.splitext(file_name) 41 | raw_data_path = os.path.join(raw_doc_dir, base_name + ".txt") 42 | with open(raw_data_path, 'r', encoding='utf-8') as file: 43 | news = file.read() 44 | all_news.append(news) 45 | with open(mindmap_json_file_path, 'r', encoding='utf-8') as file: 46 | content = file.read() 47 | try: 48 | parsed_line = json.loads(content) 49 | except json.JSONDecodeError as e: 50 | print(f'JSON解析错误在文件{file_name}', e) 51 | all_json_contents[file_name] = parsed_line 52 | 53 | return all_json_contents, all_news 54 | 55 | def process_and_insert_single_data(self, raw_data, json_data): 56 | title = list(json_data.keys())[0] 57 | emb_t = self.emb_model.encode(title, normalize_embeddings=True) 58 | document_properties = { 59 | "主题": title, 60 | "主题嵌入": emb_t.tolist() 61 | } 62 | label = "Topic" 63 | primary_key = "主题" 64 | topic_paths = self.recursive_json_iterator(json_data.get(title, {})) 65 | 66 | for key, value in document_properties.items(): 67 | if value: 68 | update_query = f"MERGE (d:{label} {{{primary_key}: $primary_key_value}}) SET d.{key} = $value RETURN d" 69 | updated_node = self.graph.run(update_query, 70 | parameters={"primary_key_value": document_properties[primary_key], 71 | "value": value}).evaluate() 72 | 73 | topic_name = updated_node.get('主题') 74 | print(f'名为“{topic_name}”的主题节点插入成功!') 75 | for topic_path in topic_paths: 76 | split_parts = topic_path.strip().strip("'").split("' '") 77 | parts = [part.strip() for part in split_parts if part.strip()] 78 | print('-------------------------------') 79 | print('待插入主题路径:', parts) 80 | 81 | for j, sub_topic_type in enumerate(parts[:-1]): 82 | if j == 0: 83 | create_TST_query = "MATCH (d:Topic {主题: $topic_name}) MERGE (d)-[r:基础链接]->(st:SubTopic {路标: $sub_topic_type}) RETURN d, st" 84 | TST_result = self.graph.run(create_TST_query, 85 | parameters={"topic_name": topic_name, 86 | "sub_topic_type": sub_topic_type}).data() 87 | print('主题到子主题插入成功!结果显示:', TST_result) 88 | else: 89 | match_query_parts = [f"-[r{k}:基础链接]->(pst{k}:SubTopic {{路标: $part{k}}})" for k, part in 90 | enumerate(parts[:j], 1)] 91 | match_query = "MATCH (d:Topic {主题: $topic_name}) " + ''.join(match_query_parts) 92 | merge_query = f" WITH pst{j} MERGE (pst{j})-[r:基础链接]->(st:SubTopic {{路标: $sub_topic_type}}) RETURN st" 93 | create_STST_query = match_query + merge_query 94 | 95 | params = {"topic_name": topic_name, "sub_topic_type": sub_topic_type} 96 | for k, part in enumerate(parts[:j], 1): 97 | params[f"part{k}"] = part 98 | 99 | STST_result = self.graph.run(create_STST_query, parameters=params).data() 100 | print('插入的主题路径:', STST_result) 101 | if j == 
len(parts) - 2: 102 | emb_fp = self.emb_model.encode(parts[0] + ' '.join(parts[1:]), normalize_embeddings=True) 103 | fact = parts[j + 1] 104 | match_query_parts = [f"-[r{k}:基础链接]->(pst{k}:SubTopic {{路标: $part{k}}})" for k, part in 105 | enumerate(parts[:j + 1], 1)] 106 | match_query = f"MATCH (d:Topic {{主题: $topic_name}}) " + ''.join(match_query_parts) 107 | merge_query = f" MERGE (c:Content {{事实: $fact, 路径嵌入: $fp}}) WITH pst{j + 1}, c MERGE (pst{j + 1})-[r:基础链接]->(c) RETURN c" 108 | 109 | create_STC_query = match_query + merge_query 110 | 111 | params = {"topic_name": topic_name, "fact": fact, "fp": emb_fp.tolist()} 112 | 113 | for k, part in enumerate(parts[:j + 1], 1): 114 | params[f"part{k}"] = part 115 | 116 | STC_result = self.graph.run(create_STC_query, parameters=params).data() 117 | print("**完整的路径插入成功!结果显示:", STC_result) 118 | 119 | def chunked_data(self, data, size): 120 | """将列表分成指定大小的块。""" 121 | for i in range(0, len(data), size): 122 | yield data[i:i + size] 123 | 124 | def process_and_insert_data(self, raw_doc_dir, mindmap_json_dir, start_batch=0, batch_size=20): 125 | result, all_news = self.load_json_files(mindmap_json_dir, raw_doc_dir) 126 | all_batches = zip(self.chunked_data(all_news, batch_size), self.chunked_data(list(result.values()), batch_size)) 127 | for i, (batch_news, batch_json_data) in enumerate(all_batches): 128 | if i < start_batch: # 跳过已处理的批次 129 | continue 130 | with ThreadPoolExecutor(max_workers=self.max_workers) as executor: 131 | list(executor.map(self.process_and_insert_single_data, batch_news, batch_json_data)) 132 | print(f"已完成第{i + 1}批处理,大小:{batch_size}") 133 | 134 | def update_single_subtopic_embedding(self, subtopic_path): 135 | """ 136 | 更新单个子主题节点的路由嵌入。 137 | """ 138 | subtopic = subtopic_path['st'] 139 | path_names = subtopic_path['path_names'] 140 | 141 | path_str = ' '.join(path_names) 142 | emb_p = self.emb_model.encode(path_str, normalize_embeddings=True) 143 | print(path_str) 144 | query_update_embedding = """ 145 | MATCH path = (t:Topic)-[:基础链接*]->(st:SubTopic {路标: $subtopic_label}) 146 | WHERE [node IN nodes(path) | CASE WHEN node:Topic THEN node.主题 WHEN node:SubTopic THEN node.路标 END] = $path_names 147 | SET st.路由嵌入 = $embedding 148 | RETURN count(st) as updated 149 | """ 150 | result = self.graph.run(query_update_embedding, subtopic_label=subtopic['路标'], path_names=path_names, 151 | embedding=emb_p.tolist()).evaluate() 152 | 153 | if result == 0: 154 | print(f"Failed to update routing embedding for path: {path_str}") 155 | 156 | def update_subtopic_embeddings(self): 157 | query_subtopics_paths = """ 158 | MATCH path = (t:Topic)-[:基础链接*]->(st:SubTopic) 159 | RETURN st, [node IN nodes(path) | CASE WHEN node:Topic THEN node.主题 WHEN node:SubTopic THEN node.路标 END] AS path_names 160 | """ 161 | subtopics_paths = self.graph.run(query_subtopics_paths).data() 162 | with ThreadPoolExecutor(max_workers=self.max_workers) as executor: 163 | executor.map(self.update_single_subtopic_embedding, subtopics_paths) 164 | 165 | def execute(self, raw_doc_dir, mindmap_json_dir, start_batch=0, batch_size=20): 166 | self.process_and_insert_data(raw_doc_dir, mindmap_json_dir, start_batch, batch_size) 167 | self.update_subtopic_embeddings() 168 | 169 | class TopicAndContentFusion: 170 | def __init__(self, graph_uri, graph_auth, emb_model_name, topic_threshold=0.92, content_threshold=0.98): 171 | self.graph = Graph(graph_uri, auth=graph_auth) 172 | self.emb_model = SentenceTransformer(emb_model_name) 173 | self.topic_threshold = topic_threshold 174 | 
self.content_threshold = content_threshold 175 | self.topic_clusters = [] 176 | self.content_clusters = [] 177 | self.processed_topic_nodes = set() 178 | self.processed_content_nodes = set() 179 | 180 | def get_topic_node_ids(self): 181 | query = """ 182 | MATCH (n:Topic) 183 | RETURN ID(n) AS topicNodeID 184 | """ 185 | return self.graph.run(query).data() 186 | 187 | def get_content_node_ids(self): 188 | query = """ 189 | MATCH (n:Content) 190 | RETURN ID(n) AS contentNodeID 191 | """ 192 | return self.graph.run(query).data() 193 | 194 | def cluster_nodes(self, node_type, threshold): 195 | if node_type == 'Topic': 196 | node_id_list = self.get_topic_node_ids() 197 | processed_nodes = self.processed_topic_nodes 198 | clusters = self.topic_clusters 199 | embedding_field = 'n.主题嵌入' 200 | embedding_index = 'topic-embeddings' 201 | super_node_label = 'SuperTopic' 202 | elif node_type == 'Content': 203 | node_id_list = self.get_content_node_ids() 204 | processed_nodes = self.processed_content_nodes 205 | clusters = self.content_clusters 206 | embedding_field = 'n.路径嵌入' 207 | embedding_index = 'fact-embeddings' 208 | super_node_label = 'SuperContent' 209 | else: 210 | raise ValueError("Unsupported node type. Use 'Topic' or 'Content'.") 211 | 212 | print(f"{node_type} node count: {len(node_id_list)}") 213 | print(node_id_list) 214 | 215 | for node_dict in tqdm(node_id_list, desc=f"Clustering {node_type}s"): 216 | node_id = node_dict[f'{node_type.lower()}NodeID'] 217 | if node_id in processed_nodes: 218 | continue 219 | 220 | query = f""" 221 | MATCH (n:{node_type}) WHERE ID(n) = {node_id} 222 | WITH {embedding_field} AS embedding 223 | CALL db.index.vector.queryNodes('{embedding_index}', {min(len(node_id_list), 100)}, embedding) YIELD node, score 224 | WHERE score > {threshold} 225 | RETURN ID(node) AS similarNodeId 226 | """ 227 | similar_nodes = self.graph.run(query).data() 228 | similar_node_ids = {item['similarNodeId'] for item in similar_nodes} 229 | 230 | new_cluster = list(similar_node_ids) 231 | clusters.append(new_cluster) 232 | processed_nodes.update(similar_node_ids) 233 | processed_nodes.add(node_id) 234 | 235 | print(f"Total clusters for {node_type}: {len(clusters)}") 236 | for i, cluster in enumerate(clusters): 237 | print(f"{node_type} Cluster {i+1}: {cluster}") 238 | 239 | super_node_id = f"{super_node_label}_{i}" 240 | self.graph.run(f"MERGE (:{super_node_label} {{id: $id}})", id=super_node_id) 241 | 242 | matches = " ".join(f"MATCH (n{idx}) WHERE ID(n{idx}) = {node_id}" for idx, node_id in enumerate(cluster)) 243 | creates = " ".join(f"CREATE (n{idx})-[:相似链接]->(st)" for idx in range(len(cluster))) 244 | 245 | query = f""" 246 | {matches} 247 | MATCH (st:{super_node_label} {{id: '{super_node_id}'}}) 248 | {creates} 249 | RETURN NULL 250 | """ 251 | self.graph.run(query) 252 | 253 | def fuse_topics_and_contents(self): 254 | self.cluster_nodes('Topic', self.topic_threshold) 255 | self.cluster_nodes('Content', self.content_threshold) 256 | print("Fusion of topics and contents completed.") 257 | 258 | -------------------------------------------------------------------------------- /pgrag/seed_context_recall.py: -------------------------------------------------------------------------------- 1 | import json 2 | from concurrent.futures import ThreadPoolExecutor 3 | from tqdm import tqdm 4 | from py2neo import Graph 5 | from sentence_transformers import SentenceTransformer 6 | import numpy as np 7 | import torch 8 | def tensor(lst): 9 | return torch.tensor(lst) 10 | 11 | class 
SeedContextRecall: 12 | def __init__(self, graph_uri, graph_auth, emb_model_name, 13 | eval_data_with_qe_and_qdse_file, seed_topic_file, candidate_topic_file, 14 | recall_top_m, walk_top_m, num_threads, top_k): 15 | self.graph = Graph(graph_uri, auth=graph_auth) 16 | self.emb_model = SentenceTransformer(emb_model_name) 17 | self.eval_data_with_qe_and_qdse_file = eval_data_with_qe_and_qdse_file 18 | self.seed_topic_file = seed_topic_file 19 | self.candidate_topic_file = candidate_topic_file 20 | self.recall_top_m = recall_top_m 21 | self.walk_top_m = walk_top_m 22 | self.num_threads = num_threads 23 | self.top_k = top_k 24 | print('初始化成功!') 25 | 26 | def seed_topic_recall_base_tn(self, ci_embedding): 27 | query_seed_topic = """ 28 | CALL db.index.vector.queryNodes('topic-embeddings', $M, $emb) 29 | YIELD node AS similarTopic, score 30 | MATCH (similarTopic) 31 | RETURN ID(similarTopic) AS topic_id, score 32 | """ 33 | similar_topics_with_score = self.graph.run(query_seed_topic, M=self.recall_top_m, emb=ci_embedding).data() 34 | seed_topics_with_score = {item['topic_id']: item['score'] for item in similar_topics_with_score} 35 | return seed_topics_with_score 36 | 37 | def topic_walking(self, seed_topic_ids, qe): 38 | candidate_topics_ids = [] 39 | candidate_topics_names = [] 40 | qe_emb = np.array(qe) 41 | 42 | query_seed_topics = """ 43 | MATCH (t:Topic) WHERE id(t) IN $seed_topic_ids 44 | RETURN collect(id(t)) AS seed_ids, collect(t.主题) AS seed_names 45 | """ 46 | seed_topics = self.graph.run(query_seed_topics, seed_topic_ids=seed_topic_ids).data() 47 | if seed_topics: 48 | candidate_topics_ids.extend(seed_topics[0]['seed_ids']) 49 | candidate_topics_names.extend(seed_topics[0]['seed_names']) 50 | 51 | all_topic_ids = [] 52 | all_topic_embs = [] 53 | all_topic_names = [] 54 | for seed_topic_id in seed_topic_ids: 55 | query_super_topic = """ 56 | MATCH (t:Topic)-[:相似链接]->(st:SuperTopic)<-[:相似链接]-(other:Topic) 57 | WHERE id(t) = $seed_topic_id 58 | RETURN collect(id(other)) AS other_topic_ids, 59 | collect(other.主题嵌入) AS other_topic_embs, 60 | collect(other.主题) AS other_topic_names 61 | """ 62 | connected_topics = self.graph.run(query_super_topic, seed_topic_id=seed_topic_id).data() 63 | if connected_topics: 64 | other_topic_ids = connected_topics[0]['other_topic_ids'] 65 | other_topic_embs = [np.array(emb) for emb in connected_topics[0]['other_topic_embs']] 66 | other_topic_names = connected_topics[0]['other_topic_names'] 67 | all_topic_ids.extend(other_topic_ids) 68 | all_topic_embs.extend(other_topic_embs) 69 | all_topic_names.extend(other_topic_names) 70 | 71 | if all_topic_embs: 72 | all_topic_embs_matrix = np.vstack(all_topic_embs) 73 | sims = qe_emb @ all_topic_embs_matrix.T 74 | top_indices = np.argsort(-sims) 75 | 76 | added_ids = set(candidate_topics_ids) 77 | for index in top_indices: 78 | if len(added_ids) >= self.walk_top_m: 79 | break 80 | topic_id = all_topic_ids[index] 81 | if topic_id not in added_ids: 82 | added_ids.add(topic_id) 83 | candidate_topics_ids.append(topic_id) 84 | candidate_topics_names.append(all_topic_names[index]) 85 | 86 | return candidate_topics_ids, candidate_topics_names 87 | 88 | # def find_topN_paths_per_candidate_topic(self, candidate_topics, kpr_embedding): 89 | # query_content_and_path = ''' 90 | # CALL db.index.vector.queryNodes('fact-embeddings', 10000, $embedding) 91 | # YIELD node AS similarContent, score 92 | # UNWIND $topics AS topic 93 | # MATCH path = (t:Topic)-[*]->(similarContent) 94 | # WHERE t.主题 = topic 95 | # WITH path, score, t 96 | # 
ORDER BY score DESC 97 | # LIMIT $topN 98 | # UNWIND nodes(path) AS p 99 | # WITH p, 100 | # score, 101 | # CASE WHEN p:Topic THEN p.主题 102 | # WHEN p:SubTopic THEN p.路标 103 | # WHEN p:Content THEN p.事实 104 | # ELSE null END AS pathAttribute 105 | # WHERE pathAttribute IS NOT NULL 106 | # RETURN collect(pathAttribute) AS pathAttributes, score 107 | # ''' 108 | # results = self.graph.run(query_content_and_path, topics=candidate_topics, topN=self.top_k, embedding=kpr_embedding).data() 109 | # return results 110 | 111 | def seed_topic_recall(self, question_info): 112 | question, question_embedding, qdse = question_info 113 | seed_topic_paths = {} 114 | print('查询问题:', question) 115 | topM_paths = self.seed_topic_recall_base_tn(question_embedding) 116 | seed_topic_paths = { 117 | 'question': question, 118 | 'qe': question_embedding, 119 | 'topM_topic_paths': str(topM_paths), 120 | 'qdse': qdse 121 | } 122 | with open(self.seed_topic_file, 'a', encoding='utf-8') as file: 123 | file.write(str(seed_topic_paths) + '\n') 124 | 125 | def execute(self): 126 | with open(self.eval_data_with_qe_and_qdse_file, 'r', encoding='utf-8') as file: 127 | eval_data = json.load(file) 128 | 129 | questions = [] 130 | qes = [] 131 | qdses = [] 132 | for entry in eval_data: 133 | question_text = entry["question"] 134 | question_embedding = entry[question_text]["qe_bge-base-zh"] 135 | qdse = entry[question_text]['qdse_bge-base-zh'] 136 | questions.append(question_text) 137 | qes.append(question_embedding) 138 | qdses.append(qdse) 139 | 140 | question_infos = [(question, qe, qdse) for question, qe, qdse in zip(questions, qes, qdses)] 141 | print(len(question_infos)) 142 | 143 | with ThreadPoolExecutor(max_workers=self.num_threads) as executor: 144 | list(tqdm(executor.map(self.seed_topic_recall, question_infos), total=len(question_infos))) 145 | 146 | with open(self.seed_topic_file, 'r', encoding='utf-8') as file: 147 | topM_seed_topic_lines = file.readlines() 148 | for topM_seed_topic_line in topM_seed_topic_lines: 149 | topM_seed_topics_info = eval(topM_seed_topic_line) 150 | question = topM_seed_topics_info['question'] 151 | qe = topM_seed_topics_info['qe'] 152 | topM_seed_topic_with_sim = eval(topM_seed_topics_info['topM_topic_paths']) 153 | qdse = topM_seed_topics_info['qdse'] 154 | print(question) 155 | seed_topic_ids = list(topM_seed_topic_with_sim.keys()) 156 | candidate_topic_ids, candidate_topic_names = self.topic_walking(seed_topic_ids, qe) 157 | print("候选主题:", candidate_topic_ids, candidate_topic_names) 158 | 159 | candidate_topics = { 160 | 'question': question, 161 | 'candidate_topic_ids': candidate_topic_ids, 162 | 'candidate_topic_names': candidate_topic_names, 163 | 'qdse': qdse, 164 | 'qe': qe 165 | } 166 | with open(self.candidate_topic_file, 'a', encoding='utf-8') as file: 167 | file.write(str(candidate_topics) + '\n') 168 | 169 | 170 | -------------------------------------------------------------------------------- /pgrag/sub_pseudo_graph_retriever.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | from py2neo import Graph 4 | 5 | class PG_RAG_Processor: 6 | def __init__(self, graph_uri, graph_auth, candidate_topic_file, 7 | matrix_templates_file, matrix_templates_with_sim_file, topK=14): 8 | self.graph_uri = graph_uri 9 | self.graph_auth = graph_auth 10 | self.candidate_topic_file = candidate_topic_file 11 | self.matrix_templates_file = matrix_templates_file 12 | self.matrix_templates_with_sim_file = 
matrix_templates_with_sim_file 13 | self.topK = topK 14 | self.graph = Graph(graph_uri, auth=graph_auth) 15 | 16 | def fetch_paths_and_embeddings(self, candidate_topics_ids): 17 | query = """ 18 | MATCH path=(t:Topic)-[:基础链接*]->(f: Content) 19 | WHERE id(t) IN $topic_ids 20 | RETURN [node in nodes(path) | id(node)] AS node_ids, 21 | [node in nodes(path) | coalesce(node.主题嵌入, node.路由嵌入, node.路径嵌入)] AS node_embs 22 | """ 23 | paths_results = self.graph.run(query, topic_ids=candidate_topics_ids).data() 24 | path_ids = [] 25 | path_embs = [] 26 | for result in paths_results: 27 | path_ids.append(result['node_ids']) 28 | path_embs.append(result['node_embs']) 29 | return path_ids, path_embs 30 | 31 | 32 | def create_matrix_templates(self): 33 | with open(self.candidate_topic_file, 'r', encoding='utf-8') as file: 34 | candidate_topic_lines = file.readlines() 35 | 36 | for candidate_topic_line in candidate_topic_lines: 37 | candidate_topic_info = eval(candidate_topic_line) 38 | question = candidate_topic_info['question'] 39 | candidate_topics_ids = candidate_topic_info['candidate_topic_ids'] 40 | qdse = candidate_topic_info['qdse'] 41 | path_ids, path_embs = self.fetch_paths_and_embeddings(candidate_topics_ids) 42 | 43 | max_len_ids = max(len(ids) for ids in path_ids) 44 | id_matrix = np.full((len(path_ids), max_len_ids), -1) 45 | emb_matrix = np.zeros((len(path_embs), max_len_ids, len(path_embs[0][0]))) 46 | 47 | for i, ids in enumerate(path_ids): 48 | id_matrix[i, :len(ids)] = ids 49 | for i, embs in enumerate(path_embs): 50 | for j, emb in enumerate(embs): 51 | emb_matrix[i, j, :] = emb 52 | 53 | matrix_templates = { 54 | 'question': question, 55 | 'ID Matrix': id_matrix.tolist(), 56 | 'EMB Matrix': emb_matrix.tolist(), 57 | 'qdse': qdse 58 | } 59 | 60 | with open(self.matrix_templates_file, 'a', encoding='utf-8') as file: 61 | file.write(str(matrix_templates) + '\n') 62 | 63 | def compute_similarity_matrices(self): 64 | with open(self.matrix_templates_file, 'r', encoding='utf-8') as file: 65 | matrix_templates_lines = file.readlines() 66 | 67 | for matrix_templates_line in tqdm(matrix_templates_lines, desc="Processing SM"): 68 | matrix_templates_info = eval(matrix_templates_line) 69 | question = matrix_templates_info['question'] 70 | matrix_id_list = matrix_templates_info['ID Matrix'] 71 | matrix_emb_list = matrix_templates_info['EMB Matrix'] 72 | kps_emb_list = matrix_templates_info['qdse'] 73 | 74 | matrix_emb = np.array(matrix_emb_list) 75 | kps_emb = np.array(kps_emb_list) 76 | 77 | num_kps = kps_emb.shape[0] 78 | if num_kps == 1024: 79 | num_kps = 1 80 | num_matrices, num_vectors_per_matrix, emb_dim = matrix_emb.shape 81 | flattened_matrix_emb = matrix_emb.reshape(-1, emb_dim) 82 | sims = kps_emb @ flattened_matrix_emb.T 83 | reshaped_sims = sims.reshape(num_kps, num_matrices, num_vectors_per_matrix) 84 | 85 | matrix_templates_with_sim = { 86 | 'question': question, 87 | 'SIM Matrix': reshaped_sims.tolist(), 88 | 'ID Matrix': matrix_id_list, 89 | 'qdse': kps_emb_list 90 | } 91 | 92 | with open(self.matrix_templates_with_sim_file, 'a', encoding='utf-8') as file: 93 | file.write(str(matrix_templates_with_sim) + '\n') 94 | 95 | def process_top_k_ids(self, contexts_ids_file, final_contexts_file): 96 | with open(self.matrix_templates_with_sim_file, 'r', encoding='utf-8') as file: 97 | matrix_templates_with_sim_lines = file.readlines() 98 | 99 | processor = MatrixProcessor() 100 | for matrix_templates_with_sim_line in tqdm(matrix_templates_with_sim_lines, desc="Processing max id"): 101 | 
matrix_templates_with_sim_info = eval(matrix_templates_with_sim_line) 102 | question = matrix_templates_with_sim_info['question'] 103 | matrix_sim = np.array(matrix_templates_with_sim_info['SIM Matrix']) 104 | matrix_id = np.array(matrix_templates_with_sim_info['ID Matrix']) 105 | 106 | final_matrix = np.zeros_like(matrix_id, dtype=np.float64) 107 | top_values = [] 108 | top_indices = [] 109 | 110 | for matrix in matrix_sim: 111 | last_non_zeros = [] 112 | 113 | for row_index, row in enumerate(matrix): 114 | for col_index in range(len(row) - 1, -1, -1): 115 | if row[col_index] != 0: 116 | last_non_zeros.append((row[col_index], row_index, col_index)) 117 | break 118 | 119 | last_non_zeros_sorted = sorted(last_non_zeros, key=lambda x: x[0], reverse=True)[:self.topK] 120 | top_values.append([val[0] for val in last_non_zeros_sorted]) 121 | top_indices.append([(val[1], val[2]) for val in last_non_zeros_sorted]) 122 | 123 | control_matrices, pathway_matrices = processor.create_control_and_pathway_matrices(matrix, matrix_id, top_values, top_indices) 124 | temp_result_matrix = processor.color_matrices(control_matrices, pathway_matrices) 125 | final_matrix += temp_result_matrix 126 | top_k_ids = processor.find_top_k_ids(final_matrix, matrix_id, self.topK) 127 | top_k_contexts = self.convert_paths_to_json(top_k_ids) 128 | contexts_ids = { 129 | 'question': question, 130 | 'top_k_ids': top_k_ids 131 | } 132 | with open(contexts_ids_file, 'a', encoding='utf-8') as file: 133 | file.write(str(contexts_ids) + '\n') 134 | 135 | final_contexts = { 136 | 'question': question, 137 | 'top_k_contexts': top_k_contexts 138 | } 139 | with open(final_contexts_file, 'a', encoding='utf-8') as file: 140 | file.write(str(final_contexts) + '\n') 141 | 142 | def convert_paths_to_json(self, top_k_ids): 143 | query = ''' 144 | MATCH (f) 145 | WHERE id(f) IN $id_list 146 | OPTIONAL MATCH path=(t:Topic)-[:基础链接*]->(f: Content) 147 | WITH COLLECT(nodes(path)) AS all_nodes 148 | RETURN DISTINCT all_nodes 149 | ''' 150 | result = self.graph.run(query, id_list=top_k_ids).data() 151 | all_paths = list(result[0]['all_nodes']) 152 | # print('all_paths:', all_paths) 153 | final_json = {} 154 | for path_nodes in all_paths: 155 | current_level = final_json # 初始化当前层级指向final_json 156 | for node in path_nodes: 157 | # 假设有方法从节点获取名称和类型 158 | if 'Topic' in node.labels: 159 | node_key = node['主题'].strip('\'').strip('。') 160 | node_embedding = node['主题嵌入'] # 获取节点嵌入 161 | elif 'SubTopic' in node.labels: 162 | node_key = node['路标'].strip('\'') 163 | node_embedding = node['路由嵌入'] # 获取节点嵌入 164 | elif 'Content' in node.labels: 165 | node_key = node['事实'].strip('\'') 166 | node_embedding = node['路径嵌入'] # 获取节点嵌入 167 | # 检查当前层级是否已存在节点键 168 | if node_key not in current_level: 169 | # 如果节点是事实节点,直接添加 170 | if 'Content' in node.labels: 171 | current_level[node_key] = {} 172 | else: 173 | # 对于主题和子主题节点,创建新的字典来保存子节点 174 | current_level[node_key] = {} if 'Topic' in node.labels else {} 175 | # 更新当前层级的引用,指向新添加的节点 176 | current_level = current_level[node_key] 177 | else: 178 | # 如果当前层级已存在节点键,更新当前层级的引用,除非是事实节点 179 | if 'Content' not in node.labels: 180 | current_level = current_level[node_key] 181 | return final_json 182 | 183 | class MatrixProcessor: 184 | def create_control_and_pathway_matrices(self, matrix_sim, matrix_id, top_values, top_indices, support_threshold=0.01, continue_threshold=0.02): 185 | control_matrices = [] 186 | pathway_matrices = [] 187 | 188 | for value, (row_idx, col_idx) in zip(top_values[-1], top_indices[-1]): 189 | control_matrix = 
np.zeros_like(matrix_sim) 190 | pathway_matrix = np.zeros_like(matrix_id) 191 | 192 | f = col_idx 193 | while f >= 0: 194 | diff = abs(matrix_sim[row_idx, f] - value) 195 | if diff <= support_threshold: 196 | control_matrix[row_idx, f] = 1 * matrix_sim[row_idx, f] 197 | elif support_threshold < diff <= continue_threshold: 198 | control_matrix[row_idx, f] = 0.5 * matrix_sim[row_idx, f] 199 | else: 200 | control_matrix[row_idx, f] = 0 201 | break 202 | f -= 1 203 | 204 | col_f = f + 1 205 | sub_root = matrix_id[row_idx, col_f] 206 | 207 | for i in range(matrix_sim.shape[0]): 208 | if i == row_idx: 209 | continue 210 | for j in range(col_f, matrix_sim.shape[1]): 211 | diff = abs(matrix_sim[i, j] - value) 212 | if diff > continue_threshold: 213 | break 214 | elif diff <= support_threshold: 215 | control_matrix[i, j] = 1 * matrix_sim[i, j] 216 | elif support_threshold < diff <= continue_threshold: 217 | control_matrix[i, j] = 0.5 * matrix_sim[i, j] 218 | 219 | control_matrices.append(control_matrix) 220 | 221 | for j in range(col_f, matrix_id.shape[1]): 222 | if matrix_id[row_idx, j] == -1: 223 | break 224 | pathway_matrix[row_idx, j] = 1 225 | 226 | for direction in [-1, 1]: 227 | i = row_idx 228 | while 0 <= i + direction < matrix_id.shape[0]: 229 | i += direction 230 | if matrix_id[i, col_f] != sub_root: 231 | break 232 | for j in range(col_f, matrix_id.shape[1]): 233 | if matrix_id[i, j] == -1: 234 | break 235 | pathway_matrix[i, j] = 1 236 | 237 | pathway_matrices.append(pathway_matrix) 238 | 239 | return control_matrices, pathway_matrices 240 | 241 | def color_matrices(self, control_matrices, pathway_matrices): 242 | result_matrix = np.zeros_like(control_matrices[0]) 243 | 244 | for control_matrix, pathway_matrix in zip(control_matrices, pathway_matrices): 245 | result_matrix += np.multiply(control_matrix, pathway_matrix) 246 | 247 | return result_matrix 248 | 249 | def find_top_k_ids(self, final_matrix, matrix_id, topK): 250 | row_sums = np.sum(final_matrix, axis=1) 251 | top_k_row_indices = np.argsort(row_sums)[-topK:][::-1] 252 | 253 | top_k_ids = [] 254 | for row_idx in top_k_row_indices: 255 | last_id = -1 256 | for col_idx in range(matrix_id.shape[1]): 257 | id_val = matrix_id[row_idx, col_idx] 258 | if id_val == -1: 259 | break 260 | last_id = id_val 261 | if last_id != -1: 262 | top_k_ids.append(int(last_id)) 263 | 264 | return top_k_ids 265 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | evaluate==0.4.1 2 | FlagEmbedding==1.2.5 3 | jieba==0.42.1 4 | loguru==0.7.2 5 | neo4j==5.17.0 6 | numpy==1.24.0 7 | openai==1.25.2 8 | py2neo==2021.2.4 9 | requests==2.31.0 10 | rouge_score==0.1.2 11 | scikit-learn==1.4.2 12 | sentence-transformers==2.4.0 13 | text2vec==1.2.9 14 | tiktoken==0.6.0 15 | tokenizers==0.15.2 16 | torch==2.2.1 17 | tqdm==4.66.2 18 | transformers==4.38.1 19 | umap==0.1.1 20 | --------------------------------------------------------------------------------