├── .gitignore ├── CATS.pdf ├── CATS.png ├── README.md ├── build_close_path.py ├── build_explain_instructions.py ├── build_instructions.py ├── build_path_together.py ├── data_manager.py ├── explain.py ├── path_count_bins_histogram.pdf ├── prediction.py ├── prompt_templates.py └── statistics.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | 164 | data/ 165 | datasets/ 166 | instructions/ 167 | instructions* 168 | logs/ 169 | logs* -------------------------------------------------------------------------------- /CATS.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-FinAI/CATS/cca49e33ed9f918ec9ae19871edf3786a51e47fe/CATS.pdf -------------------------------------------------------------------------------- /CATS.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-FinAI/CATS/cca49e33ed9f918ec9ae19871edf3786a51e47fe/CATS.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [AAAI2025] CATS: Context-aware Inductive Knowledge Graph Completion with Latent Type Constraints and Subgraph Reasoning 2 | 3 | This repository provides the official implementation of the paper *"Context-aware Inductive Knowledge Graph Completion with Latent Type Constraints and Subgraph Reasoning"*. 4 | 5 | ![CATS](CATS.png) 6 | 7 | ## Experiment Environment Setup 8 | Create a Python environment and install the required packages. We suggest you use Python 3.10 with PyTorch 2.2.2. 9 | For detailed Python package versions, you may refer to the suggested settings listed in `requirements.txt` from [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory/blob/main/requirements.txt). (vLLM is not required) 10 | 11 | ```bash 12 | pip install -r requirements.txt 13 | ``` 14 | 15 | Additionally, install `sentence_transformers`: 16 | 17 | ```bash 18 | pip install sentence_transformers 19 | ``` 20 | 21 | ## Dataset 22 | 23 | 1. Download the full dataset and LLM instructions from the following link: 24 | 25 | - [Dataset & Instructions](https://drive.google.com/drive/folders/17C3BsllCWy_TK3B5WwCjxPQo2heuLJPz?usp=drive_link) 26 | 27 | 2. Copy the two subfolders "datasets" and "instructions" into the project directory. 28 | 29 | Alternatively, you can construct the LLM instruction prompts by executing `python build_instructions.py`. 30 | 31 | ## LLM Setup 32 | 33 | Our experimental results can be reproduced with the Qwen2-7B-Instruct LLM. 34 | You may download LLM checkpoints from the following links:
35 | 36 | - [Qwen2-7B-Instruct](https://huggingface.co/Qwen/Qwen2-7B-Instruct) 37 | - [Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) 38 | 39 | Please update the default value of `LLM_PATH` in the script `data_manager.py` with your local model path. 40 | 41 | ## Instruction-tuning 42 | 43 | Please refer to the official documentation of [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory/) to conduct LLM supervised fine-tuning with the provided prompts. You may need to specify the prompt path and other training settings in a configuration file. Detailed hyper-parameters are provided in our paper. 44 | 45 | ## Inference 46 | 47 | The following command evaluates the model performance. You may alter the parameters below to test the model in different (transductive, inductive, and few-shot) scenarios. 48 | 49 | ```bash 50 | python3 prediction.py --dataset FB15k-237-subset --setting inductive --train_size full --model_name {model_path_after_sft} --prompt_type CATS --subgraph_type combine --path_type degree 51 | ``` 52 | 53 | ## Citation 54 | 55 | If you find this code useful, please consider citing the following paper. 56 | ``` 57 | @article{Li_Yang_Xu_Song_Jiang_Guo_Leung_King_2025, 58 | title={Context-aware Inductive Knowledge Graph Completion with Latent Type Constraints and Subgraph Reasoning}, 59 | volume={39}, 60 | url={https://ojs.aaai.org/index.php/AAAI/article/view/33318}, 61 | DOI={10.1609/aaai.v39i11.33318}, 62 | number={11}, 63 | journal={Proceedings of the AAAI Conference on Artificial Intelligence}, 64 | author={Li, Muzhi and Yang, Cehao and Xu, Chengjin and Song, Zixing and Jiang, Xuhui and Guo, Jian and Leung, Ho-fung and King, Irwin}, 65 | year={2025}, 66 | month={Apr.}, 67 | pages={12102-12111} 68 | } 69 | ``` 70 | -------------------------------------------------------------------------------- /build_close_path.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from collections import defaultdict, deque 4 | from data_manager import DataManager 5 | from tqdm import tqdm 6 | import argparse 7 | 8 | def process_dataset(dataset, setting, train_size, max_path_hops): 9 | data_manager = DataManager(dataset=dataset, setting=setting, train_size=train_size) 10 | data_manager.max_path_hops = max_path_hops 11 | 12 | paths_dir = f"{data_manager.dataset_path}/paths_{max_path_hops}hop" 13 | os.makedirs(paths_dir, exist_ok=True) 14 | 15 | close_path_dict = {} 16 | close_paths_text = [] 17 | 18 | for triple in tqdm(data_manager.test_set, desc=f"Processing {dataset} - setting: {setting} - Train_size: {train_size}"): 19 | head, relation, tail = triple 20 | paths = list(data_manager.bfs_paths(head, tail)) 21 | close_path_dict[f"{head}-{tail}"] = paths 22 | 23 | for path in paths: 24 | hop_texts = [] 25 | for hop in path: 26 | hop_text = [data_manager.entity2text[hop[0]], hop[1], data_manager.entity2text[hop[2]]] 27 | hop_texts.append(' -> '.join(hop_text)) 28 | close_paths_text.append(f"{data_manager.entity2text[head]}, {data_manager.entity2text[tail]}: {'; '.join(hop_texts)}") 29 | 30 | if setting == "inductive": 31 | close_path_dict_path = f"{paths_dir}/close_path.json" 32 | close_path_text_path = f"{paths_dir}/close_path_text.txt" 33 | else: 34 | close_path_dict_path = f"{paths_dir}/close_path_train_size_{train_size}.json" 35 | close_path_text_path = f"{paths_dir}/close_path_text_train_size_{train_size}.txt" 36 | 37 | with open(close_path_dict_path, "w",
encoding="utf-8") as f: 38 | json.dump(close_path_dict, f, ensure_ascii=False, indent=4) 39 | 40 | with open(close_path_text_path, "w", encoding="utf-8") as file: 41 | for line in close_paths_text: 42 | file.write(line + "\n") 43 | 44 | def main(): 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument("--dataset", type=str, choices=["FB15k-237-subset", "NELL-995-subset", "WN18RR-subset"], default="FB15k-237-subset") 47 | parser.add_argument("--setting", type=str, choices=["inductive", "transductive"], default="inductive", help="Inductive or Transductive setting") 48 | parser.add_argument("--train_size", type=str, choices=["full", "1000", "2000"], default="full", help="Size of the training data") 49 | parser.add_argument("--max_path_hops", type=int, default=3, help="Maximum number of hops in the path") 50 | 51 | args = parser.parse_args() 52 | process_dataset(args.dataset, args.setting, args.train_size, args.max_path_hops) 53 | 54 | if __name__ == "__main__": 55 | main() -------------------------------------------------------------------------------- /build_explain_instructions.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from data_manager import DataManager 5 | from tqdm import tqdm 6 | import argparse 7 | from prompt_templates import EXPLAINING_PROMPT 8 | 9 | def close_path_finder(data_manager:DataManager, triple): 10 | head, relation, tail = triple 11 | close_paths = list(data_manager.bfs_paths(head, tail)) 12 | 13 | if close_paths: 14 | path_degrees = [] 15 | for path in close_paths: 16 | degree_sum = sum(data_manager.relation_degree_dict[rel] for _, rel, _ in path) 17 | path_degrees.append((degree_sum, path)) 18 | path_degrees.sort(key=lambda x: x[0]) 19 | 20 | top_paths = [path for _, path in path_degrees[:data_manager.max_reason_paths]] 21 | top_paths.reverse() 22 | return top_paths 23 | 24 | return [] 25 | 26 | def build_instructions(dataset, train_size, neg_num): 27 | setting = "transductive" 28 | 29 | data_manager = DataManager(dataset=dataset, setting=setting, train_size=train_size) 30 | 31 | paths_dir = f"instructions_explain/{dataset}" 32 | os.makedirs(paths_dir, exist_ok=True) 33 | 34 | sft_instructions = [] 35 | 36 | 37 | for pos_triple in tqdm(data_manager.path_set[:2000], desc=f"Processing {dataset} - setting: {setting} - Train_size: {train_size}"): 38 | pos_head, relation, pos_tail = pos_triple 39 | 40 | removed_from_head = (relation, pos_tail, 1) 41 | removed_from_tail = (relation, pos_head, -1) 42 | data_manager.entity2relationtail_dict[pos_head].remove(removed_from_head) 43 | data_manager.entity2relationtail_dict[pos_tail].remove(removed_from_tail) 44 | 45 | pos_neighbor_triples = data_manager.neighbor_triple_finder(pos_triple) 46 | pos_close_paths = close_path_finder(data_manager, pos_triple) 47 | pos_reasoning_paths = "\n".join( 48 | " -> ".join(data_manager.triple_to_sentence(triple) for triple in path) 49 | for path in pos_close_paths 50 | ) 51 | pos_explain_prompt = EXPLAINING_PROMPT.format(neighbor_triples="\n".join(pos_neighbor_triples), reasoning_paths=pos_reasoning_paths, test_triple=data_manager.triple_to_sentence(pos_triple)) 52 | pos_explain_output = "" 53 | sft_instructions.append({"instruction": pos_explain_prompt, "input": "", "output": pos_explain_prompt}) 54 | 55 | neg_samples = data_manager.neg_sampling(pos_triple, neg_num) 56 | for neg_triple in neg_samples: 57 | 58 | neg_neighbor_triples = data_manager.neighbor_triple_finder(neg_triple) 59 | neg_close_paths = 
close_path_finder(data_manager, neg_triple) 60 | neg_reasoning_paths = "\n".join( 61 | " -> ".join(data_manager.triple_to_sentence(triple) for triple in path) 62 | for path in neg_close_paths 63 | ) 64 | neg_explain_prompt = EXPLAINING_PROMPT.format(neighbor_triples="\n".join(neg_neighbor_triples), reasoning_paths=neg_reasoning_paths, test_triple=data_manager.triple_to_sentence(neg_triple)) 65 | neg_explain_output = "" 66 | sft_instructions.append({"instruction": neg_explain_prompt, "input": "", "output": neg_explain_output}) 67 | 68 | data_manager.entity2relationtail_dict[pos_head].append(removed_from_head) 69 | data_manager.entity2relationtail_dict[pos_tail].append(removed_from_tail) 70 | 71 | sft_instructions_path = f"{paths_dir}/{dataset}_train_size_{train_size}.json" 72 | with open(sft_instructions_path, "w", encoding="utf-8") as f: 73 | json.dump(sft_instructions, f, ensure_ascii=False, indent=4) 74 | 75 | def main(): 76 | parser = argparse.ArgumentParser(description='Process datasets with given hyperparameters') 77 | parser.add_argument("--dataset", type=str, choices=["FB15k-237-subset", "NELL-995-subset", "WN18RR-subset"], default="FB15k-237-subset") 78 | parser.add_argument("--train_size", type=str, choices=["full"], default="full", help="Size of the training data") 79 | parser.add_argument("--neg_num", type=int, default=1, help="Number of negative samples") 80 | 81 | args = parser.parse_args() 82 | build_instructions(args.dataset, args.train_size, args.neg_num) 83 | 84 | if __name__ == "__main__": 85 | main() 86 | -------------------------------------------------------------------------------- /build_instructions.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from data_manager import DataManager 5 | from tqdm import tqdm 6 | import argparse 7 | from prompt_templates import SUBGRAPH_REASON_PROMPT, NEIGHBOR_REASON_PROMPT, CLOSE_PATH_REASON_PROMPT 8 | 9 | 10 | def close_path_finder(data_manager:DataManager, triple): 11 | head, relation, tail = triple 12 | close_paths = list(data_manager.bfs_paths(head, tail)) 13 | 14 | if close_paths: 15 | path_degrees = [] 16 | for path in close_paths: 17 | degree_sum = sum(data_manager.relation_degree_dict[rel] for _, rel, _ in path) 18 | path_degrees.append((degree_sum, path)) 19 | path_degrees.sort(key=lambda x: x[0]) 20 | 21 | top_paths = [path for _, path in path_degrees[:data_manager.max_reason_paths]] 22 | top_paths.reverse() 23 | return top_paths 24 | 25 | return [] 26 | 27 | def build_instructions(dataset, train_size, subgraph_type, neg_num, version): 28 | setting = "transductive" 29 | 30 | data_manager = DataManager(dataset=dataset, setting=setting, train_size=train_size) 31 | 32 | paths_dir = f"instructions{version}/{dataset}" 33 | os.makedirs(paths_dir, exist_ok=True) 34 | 35 | sft_instructions = [] 36 | 37 | for pos_triple in tqdm(data_manager.path_set, desc=f"Processing {dataset} - setting: {setting} - Train_size: {train_size}"): 38 | pos_head, relation, pos_tail = pos_triple 39 | 40 | removed_from_head = (relation, pos_tail, 1) 41 | removed_from_tail = (relation, pos_head, -1) 42 | data_manager.entity2relationtail_dict[pos_head].remove(removed_from_head) 43 | data_manager.entity2relationtail_dict[pos_tail].remove(removed_from_tail) 44 | 45 | pos_type_prompt = data_manager.build_type_prompt(pos_triple) 46 | sft_instructions.append({"instruction": pos_type_prompt, "input": "", "output": "Y"}) 47 | 48 | if subgraph_type == "combine": 49 | pos_neighbor_triples = 
data_manager.neighbor_triple_finder(pos_triple) 50 | pos_close_paths = close_path_finder(data_manager, pos_triple) 51 | pos_reasoning_paths = "\n".join( 52 | " -> ".join(data_manager.triple_to_sentence(triple) for triple in path) 53 | for path in pos_close_paths 54 | ) 55 | pos_subgraph_prompt = SUBGRAPH_REASON_PROMPT.format(neighbor_triples="\n".join(pos_neighbor_triples), reasoning_paths=pos_reasoning_paths, test_triple=data_manager.triple_to_sentence(pos_triple)) 56 | elif subgraph_type == "neighbor-only": 57 | pos_neighbor_triples = data_manager.neighbor_triple_finder(pos_triple) 58 | pos_subgraph_prompt = NEIGHBOR_REASON_PROMPT.format(neighbor_triples="\n".join(pos_neighbor_triples), test_triple=data_manager.triple_to_sentence(pos_triple)) 59 | elif subgraph_type == "path-only": 60 | pos_close_paths = close_path_finder(data_manager, pos_triple) 61 | pos_reasoning_paths = "\n".join( 62 | " -> ".join(data_manager.triple_to_sentence(triple) for triple in path) 63 | for path in pos_close_paths 64 | ) 65 | pos_subgraph_prompt = CLOSE_PATH_REASON_PROMPT.format(reasoning_paths=pos_reasoning_paths, test_triple=data_manager.triple_to_sentence(pos_triple)) 66 | 67 | sft_instructions.append({"instruction": pos_subgraph_prompt, "input": "", "output": "Y"}) 68 | 69 | neg_samples = data_manager.neg_sampling(pos_triple, neg_num) 70 | for neg_triple in neg_samples: 71 | neg_type_prompt = data_manager.build_type_prompt(neg_triple) 72 | sft_instructions.append({"instruction": neg_type_prompt, "input": "", "output": "N"}) 73 | 74 | if subgraph_type == "combine": 75 | neg_neighbor_triples = data_manager.neighbor_triple_finder(neg_triple) 76 | neg_close_paths = close_path_finder(data_manager, neg_triple) 77 | neg_reasoning_paths = "\n".join( 78 | " -> ".join(data_manager.triple_to_sentence(triple) for triple in path) 79 | for path in neg_close_paths 80 | ) 81 | neg_subgraph_prompt = SUBGRAPH_REASON_PROMPT.format(neighbor_triples="\n".join(neg_neighbor_triples), reasoning_paths=neg_reasoning_paths, test_triple=data_manager.triple_to_sentence(neg_triple)) 82 | elif subgraph_type == "neighbor-only": 83 | neg_neighbor_triples = data_manager.neighbor_triple_finder(neg_triple) 84 | neg_subgraph_prompt = NEIGHBOR_REASON_PROMPT.format(neighbor_triples="\n".join(neg_neighbor_triples), test_triple=data_manager.triple_to_sentence(neg_triple)) 85 | elif subgraph_type == "path-only": 86 | neg_close_paths = close_path_finder(data_manager, neg_triple) 87 | neg_reasoning_paths = "\n".join( 88 | " -> ".join(data_manager.triple_to_sentence(triple) for triple in path) 89 | for path in neg_close_paths 90 | ) 91 | neg_subgraph_prompt = CLOSE_PATH_REASON_PROMPT.format(reasoning_paths=neg_reasoning_paths, test_triple=data_manager.triple_to_sentence(neg_triple)) 92 | 93 | sft_instructions.append({"instruction": neg_subgraph_prompt, "input": "", "output": "N"}) 94 | 95 | data_manager.entity2relationtail_dict[pos_head].append(removed_from_head) 96 | data_manager.entity2relationtail_dict[pos_tail].append(removed_from_tail) 97 | 98 | sft_instructions_path = f"{paths_dir}/{dataset}_train_size_{train_size}_{subgraph_type}.json" 99 | with open(sft_instructions_path, "w", encoding="utf-8") as f: 100 | json.dump(sft_instructions, f, ensure_ascii=False, indent=4) 101 | 102 | def build_vanilla_instructions(dataset, train_size, neg_num, version): 103 | setting = "transductive" # instruction construction defaults to the transductive setting and uses the training set 104 | 105 | data_manager = DataManager(dataset=dataset, setting=setting, train_size=train_size) 106 | 107 | paths_dir =
f"vanilla_instructions{version}/{dataset}" 108 | os.makedirs(paths_dir, exist_ok=True) 109 | 110 | sft_instructions = [] 111 | 112 | for pos_triple in tqdm(data_manager.path_set, desc=f"Processing {dataset} - setting: {setting} - Train_size: {train_size}"): 113 | pos_vanilla_prompt = data_manager.build_vanilla_prompt(pos_triple) 114 | sft_instructions.append({"instruction": pos_vanilla_prompt, "input": "", "output": "Y"}) 115 | neg_samples = data_manager.neg_sampling(pos_triple, neg_num) 116 | 117 | for neg_triple in neg_samples: 118 | neg_vanilla_prompt = data_manager.build_vanilla_prompt(neg_triple) 119 | sft_instructions.append({"instruction": neg_vanilla_prompt, "input": "", "output": "N"}) 120 | 121 | sft_instructions_path = f"{paths_dir}/{dataset}_train_size_{train_size}.json" 122 | with open(sft_instructions_path, "w", encoding="utf-8") as f: 123 | json.dump(sft_instructions, f, ensure_ascii=False, indent=4) 124 | 125 | def main(): 126 | parser = argparse.ArgumentParser(description='Process datasets with given hyperparameters') 127 | parser.add_argument("--dataset", type=str, choices=["FB15k-237-subset", "NELL-995-subset", "WN18RR-subset"], default="FB15k-237-subset") 128 | parser.add_argument("--train_size", type=str, choices=["full", "1000", "2000"], default="full", help="Size of the training data") 129 | parser.add_argument("--neg_num", type=int, default=3, help="Number of negative samples") 130 | parser.add_argument("--prompt_type", type=str, default="CATS", choices=["CATS", "vanilla"]) 131 | parser.add_argument("--subgraph_type", type=str, default="combine", choices=["neighbor-only", "path-only", "combine"]) 132 | parser.add_argument("--version", type=str, default="") 133 | 134 | args = parser.parse_args() 135 | if args.prompt_type == "CATS": 136 | build_instructions(args.dataset, args.train_size, args.subgraph_type, args.neg_num, args.version) 137 | elif args.prompt_type == "vanilla": 138 | build_vanilla_instructions(args.dataset, args.train_size, args.neg_num, args.version) 139 | 140 | if __name__ == "__main__": 141 | main() 142 | -------------------------------------------------------------------------------- /build_path_together.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | from concurrent.futures import ThreadPoolExecutor 3 | from itertools import product 4 | import datetime 5 | 6 | def build_path(params): 7 | dataset, setting, train_size, max_path_hops = params 8 | model_params = [ 9 | '--dataset', dataset, 10 | '--setting', setting, 11 | '--train_size', train_size, 12 | '--max_path_hops', max_path_hops 13 | ] 14 | command = ["python", "build_close_path.py"] + model_params 15 | subprocess.run(command) 16 | 17 | def main(): 18 | datasets = ["FB15k-237-subset", "NELL-995-subset", "WN18RR-subset"] 19 | settings = ["transductive"] 20 | train_sizes = ["1000", "2000"] 21 | max_path_hops = "3" 22 | # datasets = ["FB15k-237-subset"] 23 | # settings = ["inductive"] 24 | # train_sizes = ["full"] 25 | 26 | parameter_sets = [] 27 | for setting in settings: 28 | if setting == "inductive": 29 | for dataset in datasets: 30 | parameter_sets.append((dataset, setting, "full", max_path_hops)) 31 | 32 | if setting == "transductive": 33 | for dataset in datasets: 34 | for train_size in train_sizes: 35 | parameter_sets.append((dataset, setting, train_size, max_path_hops)) 36 | 37 | with ThreadPoolExecutor(max_workers=len(parameter_sets)) as executor: 38 | futures = [executor.submit(build_path, params) for params in 
parameter_sets] 39 | for future in futures: 40 | future.result() 41 | 42 | if __name__ == "__main__": 43 | main() 44 | -------------------------------------------------------------------------------- /data_manager.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | import numpy as np 5 | from collections import defaultdict, deque 6 | from sentence_transformers import SentenceTransformer 7 | from prompt_templates import TYPE_REASON_PROMPT, SUBGRAPH_REASON_PROMPT, NEIGHBOR_REASON_PROMPT, CLOSE_PATH_REASON_PROMPT, BASE_REASON_PROMPT, ALL_REASON_PROMPT 8 | 9 | LLM_PATH = "" 10 | 11 | class DataManager: 12 | def __init__(self, dataset="FB15k-237-subset", setting="inductive", train_size="full", model_name="Qwen2-7B-Instruct", llm_type="sft"): 13 | self.dataset = dataset 14 | self.model_name = model_name 15 | self.dataset_name = dataset.split("-")[0] 16 | self.dataset_path = f"datasets/{dataset}" + ("-inductive" if setting=="inductive" else "") 17 | self.train_size = train_size 18 | self.model_path = f"{LLM_PATH}/{self.model_name}-{self.dataset_name}-{train_size}" if llm_type == "sft" else f"{LLM_PATH}/{self.model_name}" 19 | 20 | self.test_batch_size = 50 # the test set is evaluated in batches of 50 samples; MRR and Hits@1 are computed over these batches 21 | self.max_type_triples = 5 # the type reasoning stage uses at most 5 few-shot triples 22 | self.max_reason_paths = 6 # the path reasoning stage uses at most 6 paths; both neighbor_triples and close_paths are capped at six 23 | self.max_path_hops = 3 # maximum BFS depth when searching for close paths 24 | 25 | self.entity2text = self._load_text_file("entity2text.txt") 26 | self.relation2text = self._load_text_file("relation2text.txt") 27 | 28 | self.train_set = self._load_triples(f"train_{self.train_size}.txt") 29 | self.path_set = self._load_triples("inductive_graph.txt") if setting=="inductive" else self.train_set 30 | self.valid_set = self._load_triples("valid.txt") 31 | self.test_set_head = self._load_triples("ranking_head.txt") 32 | self.test_set_tail = self._load_triples("ranking_tail.txt") 33 | self.test_set = self.test_set_head + self.test_set_tail 34 | 35 | self.relation2headtail_dict = self._load_relation2headtail_dict(self.path_set) 36 | self.entity2relationtail_dict = self._load_entity2relationtail_dict(self.path_set) 37 | self.relation_degree_dict = self._load_relation_degree_dict(self.path_set) 38 | self.close_path_file = "paths/close_path.json" if setting=="inductive" else f"paths/close_path_train_size_{self.train_size}.json" 39 | self.close_path_dict = self._load_close_path_dict(self.close_path_file) 40 | 41 | self.embedding_model = SentenceTransformer( 42 | model_name_or_path='BAAI/bge-small-en-v1.5', 43 | device="cuda" 44 | ) 45 | 46 | def _load_text_file(self, filename): 47 | filepath = f"{self.dataset_path}/{filename}" 48 | with open(filepath, "r", encoding="utf-8") as file: 49 | return dict(line.strip().split('\t', 1) for line in file if line.strip()) 50 | 51 | def _load_triples(self, filename): 52 | filepath = f"{self.dataset_path}/{filename}" 53 | with open(filepath, "r", encoding="utf-8") as file: 54 | return [line.strip().split('\t') for line in file if line.strip()] 55 | 56 | def _load_relation2headtail_dict(self, triple_set): 57 | relation2headtail_dict = defaultdict(list) 58 | for head, relation, tail in triple_set: 59 | relation2headtail_dict[relation].append([head, tail]) 60 | return relation2headtail_dict 61 | 62 | def _load_entity2relationtail_dict(self, triple_set): 63 | entity2relationtail_dict = defaultdict(list) 64 | for head, relation, tail in triple_set: 65 |
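# the third element of each stored tuple is a direction flag: 1 marks an outgoing edge (head -> tail) and -1 an incoming edge (tail -> head); bfs_paths and neighbor_triple_finder below rely on this convention to restore the original orientation of a triple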
entity2relationtail_dict[head].append((relation, tail, 1)) 66 | entity2relationtail_dict[tail].append((relation, head, -1)) 67 | return entity2relationtail_dict 68 | 69 | def _load_relation_degree_dict(self, triple_set): 70 | relation_degree_dict = defaultdict(int) 71 | for _, relation, _ in triple_set: 72 | relation_degree_dict[relation] += 1 73 | return relation_degree_dict 74 | 75 | def _load_close_path_dict(self, filename): 76 | filepath = f"{self.dataset_path}/{filename}" 77 | if os.path.exists(filepath): 78 | with open(filepath, "r", encoding="utf-8") as file: 79 | return json.load(file) 80 | return {} 81 | 82 | # given a head entity and a tail entity, use BFS to search for all close paths between them 83 | def bfs_paths(self, start, goal): 84 | queue = deque([(start, [], 0, set([start]))]) 85 | paths = [] 86 | while queue: 87 | current, path, hops, visited = queue.popleft() 88 | if hops < self.max_path_hops: 89 | for relation, neighbor, direction in self.entity2relationtail_dict[current]: 90 | if direction == 1: 91 | new_path = path + [(current, relation, neighbor)] 92 | else: 93 | new_path = path + [(neighbor, relation, current)] 94 | if neighbor == goal: 95 | paths.append(new_path) 96 | elif neighbor not in visited: 97 | queue.append((neighbor, new_path, hops + 1, visited | set([neighbor]))) 98 | return paths 99 | 100 | # score each close path by the sum of its relation degrees, sort ascending, and keep the smallest few; this filters out high-frequency relations such as "gender" and "ethnicity" 101 | def close_path_finder(self, triple): 102 | head, relation, tail = triple 103 | head_tail = f"{head}-{tail}" 104 | close_paths = self.close_path_dict[head_tail] 105 | 106 | if close_paths: 107 | path_degrees = [] 108 | for path in close_paths: 109 | degree_sum = sum(self.relation_degree_dict[rel] for _, rel, _ in path) 110 | path_degrees.append((degree_sum, path)) 111 | path_degrees.sort(key=lambda x: x[0]) 112 | 113 | top_paths = [path for _, path in path_degrees[:self.max_reason_paths]] 114 | top_paths.reverse() 115 | return top_paths 116 | 117 | return [] 118 | 119 | def close_path_finder_no_degree(self, triple): 120 | head, relation, tail = triple 121 | head_tail = f"{head}-{tail}" 122 | close_paths = self.close_path_dict[head_tail] 123 | 124 | if close_paths: 125 | return close_paths[:self.max_reason_paths] 126 | 127 | return [] 128 | 129 | def linearize_triple(self, triple): 130 | return f"({self.entity2text[triple[0]]}, {self.relation2text[triple[1]]}, {self.entity2text[triple[2]]})" 131 | 132 | def triple_to_sentence(self, triple): 133 | head, relation, tail = triple 134 | if self.dataset == "FB15k-237-subset": 135 | head_property = relation.split('/')[2] 136 | tail_property = relation.split('/')[-1] 137 | return f"('{self.entity2text[tail]}' is the {tail_property} of {head_property} '{self.entity2text[head]}')" 138 | elif self.dataset == "WN18RR-subset": 139 | return f"('{self.entity2text[head]}' {self.relation2text[relation]} '{self.entity2text[tail]}')" 140 | elif self.dataset == "NELL-995-subset": 141 | return f"('{self.entity2text[head]}' {self.relation2text[relation]} '{self.entity2text[tail]}')" 142 | 143 | def build_type_prompt(self, triple): 144 | fewshot_triples = self.diverse_fewshot_triple_finder(triple) 145 | fewshot_triples_sentence = '\n'.join(self.triple_to_sentence(triple) for triple in fewshot_triples) 146 | return TYPE_REASON_PROMPT.format(fewshot_triples=fewshot_triples_sentence, test_triple=self.triple_to_sentence(triple)) 147 | 148 | def build_subgraph_prompt(self, triple): 149 | neighbor_triples = self.neighbor_triple_finder(triple) 150 | close_paths =
self.close_path_finder(triple) 151 | reasoning_paths = "\n".join( 152 | " -> ".join(self.triple_to_sentence(triple) for triple in path) 153 | for path in close_paths 154 | ) 155 | return SUBGRAPH_REASON_PROMPT.format(neighbor_triples="\n".join(neighbor_triples), reasoning_paths=reasoning_paths, test_triple=self.triple_to_sentence(triple)) 156 | 157 | def build_neighbor_prompt(self, triple): 158 | neighbor_triples = self.neighbor_triple_finder(triple) 159 | return NEIGHBOR_REASON_PROMPT.format(neighbor_triples="\n".join(neighbor_triples), test_triple=self.triple_to_sentence(triple)) 160 | 161 | def build_close_path_prompt(self, triple): 162 | close_paths = self.close_path_finder(triple) 163 | reasoning_paths = "\n".join( 164 | " -> ".join(self.triple_to_sentence(triple) for triple in path) 165 | for path in close_paths 166 | ) 167 | return CLOSE_PATH_REASON_PROMPT.format(reasoning_paths=reasoning_paths, test_triple=self.triple_to_sentence(triple)) 168 | 169 | def build_close_path_no_degree_prompt(self, triple): 170 | close_paths = self.close_path_finder_no_degree(triple) 171 | reasoning_paths = "\n".join( 172 | " -> ".join(self.triple_to_sentence(triple) for triple in path) 173 | for path in close_paths 174 | ) 175 | return CLOSE_PATH_REASON_PROMPT.format(reasoning_paths=reasoning_paths, test_triple=self.triple_to_sentence(triple)) 176 | 177 | def build_vanilla_prompt(self, triple): 178 | return BASE_REASON_PROMPT.format(test_triple=self.triple_to_sentence(triple)) 179 | 180 | def build_all_prompt(self, triple): 181 | fewshot_triples = self.diverse_fewshot_triple_finder(triple) 182 | neighbor_triples = self.neighbor_triple_finder(triple) 183 | close_paths = self.close_path_finder(triple) 184 | fewshot_triples_sentence = '\n'.join(self.triple_to_sentence(triple) for triple in fewshot_triples) 185 | reasoning_paths = "\n".join( 186 | " -> ".join(self.triple_to_sentence(triple) for triple in path) 187 | for path in close_paths 188 | ) 189 | return ALL_REASON_PROMPT.format(fewshot_triples=fewshot_triples_sentence, neighbor_triples="\n".join(neighbor_triples), reasoning_paths=reasoning_paths, test_triple=self.triple_to_sentence(triple)) 190 | 191 | def get_test_batches(self): 192 | return [self.test_set[i:i + self.test_batch_size] for i in range(0, len(self.test_set), self.test_batch_size)] 193 | 194 | def diverse_fewshot_triple_finder(self, test_triple): 195 | test_head, relation, test_tail = test_triple 196 | head_tail_pairs = self.relation2headtail_dict[relation] 197 | 198 | if len(head_tail_pairs) <= self.max_type_triples: 199 | return [[head, relation, tail] for head, tail in head_tail_pairs] 200 | 201 | used_heads = {test_head, test_tail} 202 | used_tails = {test_tail, test_head} 203 | used_pairs = set() 204 | selected_triples = [] 205 | 206 | for head, tail in head_tail_pairs: 207 | if head not in used_heads and tail not in used_tails: 208 | selected_triples.append([head, relation, tail]) 209 | used_heads.add(head) 210 | used_tails.add(tail) 211 | used_pairs.add((head, tail)) 212 | if len(selected_triples) == self.max_type_triples: 213 | return selected_triples 214 | 215 | for head, tail in head_tail_pairs: 216 | if (head, tail) not in used_pairs: 217 | if len(selected_triples) < self.max_type_triples: 218 | selected_triples.append([head, relation, tail]) 219 | used_heads.add(head) 220 | used_tails.add(tail) 221 | used_pairs.add((head, tail)) 222 | else: 223 | break 224 | 225 | return selected_triples 226 | 227 | # neighbor triples for path reasoning: try to find the neighbor triples most relevant to the current triple 228 |
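# a sketch of the selection rule implemented below: the test triple and every candidate neighbor triple are verbalized and embedded (embeddings are L2-normalized, so the dot product equals cosine similarity), and the max_reason_paths // 2 candidates most similar to the test triple are kept on the head side and on the tail side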
def neighbor_triple_finder(self, triple): 229 | head, relation, tail = triple 230 | head_triples = self.entity2relationtail_dict[head] 231 | tail_triples = self.entity2relationtail_dict[tail] 232 | 233 | triple_sentence = self.triple_to_sentence(triple) 234 | head_sentences = [self.triple_to_sentence((head, rel, t)) if direction == 1 else self.triple_to_sentence((t, rel, head)) 235 | for rel, t, direction in head_triples] 236 | tail_sentences = [self.triple_to_sentence((tail, rel, h)) if direction == 1 else self.triple_to_sentence((h, rel, tail)) 237 | for rel, h, direction in tail_triples] 238 | 239 | all_head_sentences = [triple_sentence] + head_sentences 240 | all_tail_sentences = [triple_sentence] + tail_sentences 241 | 242 | each_count = self.max_reason_paths // 2 243 | 244 | top_head_sentences = head_sentences 245 | top_tail_sentences = tail_sentences 246 | 247 | if len(head_sentences) > each_count: 248 | head_embeddings = self.embedding_model.encode(all_head_sentences, normalize_embeddings=True) 249 | head_similarity = head_embeddings[0] @ head_embeddings[1:].T 250 | top_head_indices = np.argsort(-head_similarity)[:each_count] 251 | top_head_sentences = [head_sentences[i] for i in top_head_indices] 252 | 253 | if len(tail_sentences) > each_count: 254 | tail_embeddings = self.embedding_model.encode(all_tail_sentences, normalize_embeddings=True) 255 | tail_similarity = tail_embeddings[0] @ tail_embeddings[1:].T 256 | top_tail_indices = np.argsort(-tail_similarity)[:each_count] 257 | top_tail_sentences = [tail_sentences[i] for i in top_tail_indices] 258 | 259 | return top_head_sentences + top_tail_sentences 260 | 261 | # negative sampling: for a positive triple, corrupt the head and the tail respectively and randomly sample replacements for them. 262 | def neg_sampling(self, pos_triple, count): 263 | head, relation, tail = pos_triple 264 | 265 | entities = set() 266 | for triple in self.path_set: 267 | entities.add(triple[0]) 268 | entities.add(triple[2]) 269 | 270 | candidate_entities = list(entities - {head, tail}) 271 | seen_triples = {tuple(triple) for triple in self.path_set} 272 | negative_samples = [] 273 | 274 | # corrupt the head 275 | for _ in range(count): 276 | while True: 277 | new_head = random.choice(candidate_entities) 278 | if (new_head, relation, tail) not in seen_triples: 279 | seen_triples.add((new_head, relation, tail)) 280 | negative_samples.append((new_head, relation, tail)) 281 | break 282 | 283 | # corrupt the tail 284 | for _ in range(count): 285 | while True: 286 | new_tail = random.choice(candidate_entities) 287 | if (head, relation, new_tail) not in seen_triples: 288 | seen_triples.add((head, relation, new_tail)) 289 | negative_samples.append((head, relation, new_tail)) 290 | break 291 | 292 | # # corrupt the relation 293 | # candidate_relations = {triple[1] for triple in self.path_set} - {relation} 294 | # for _ in range(count): 295 | # while True: 296 | # new_relation = random.choice(list(candidate_relations)) 297 | # if (head, new_relation, tail) not in seen_triples: 298 | # seen_triples.add((head, new_relation, tail)) 299 | # negative_samples.append((head, new_relation, tail)) 300 | # break 301 | 302 | return negative_samples 303 | -------------------------------------------------------------------------------- /explain.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-FinAI/CATS/cca49e33ed9f918ec9ae19871edf3786a51e47fe/explain.py -------------------------------------------------------------------------------- /path_count_bins_histogram.pdf:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-FinAI/CATS/cca49e33ed9f918ec9ae19871edf3786a51e47fe/path_count_bins_histogram.pdf -------------------------------------------------------------------------------- /prediction.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import time 4 | from tqdm import tqdm 5 | from transformers import AutoTokenizer, AutoModelForCausalLM 6 | from data_manager import DataManager 7 | from datetime import datetime 8 | 9 | def cal_Y_prob(model:AutoModelForCausalLM, tokenizer:AutoTokenizer, generation_config, prompt_list): 10 | messages_batch = [ 11 | [{"role": "user", "content": prompt}] 12 | for prompt in prompt_list 13 | ] 14 | texts = [tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) for messages in messages_batch] 15 | inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True).to("cuda") 16 | 17 | generated_output = model.generate( 18 | input_ids=inputs.input_ids, attention_mask=inputs.attention_mask, 19 | pad_token_id=tokenizer.eos_token_id, 20 | eos_token_id=tokenizer.eos_token_id, 21 | return_dict_in_generate=True, 22 | output_scores=True, 23 | **generation_config 24 | ) 25 | 26 | scores = generated_output.scores[0] 27 | probs = scores.softmax(dim=-1) 28 | 29 | Y_id = tokenizer.encode("Y", add_special_tokens=False)[0] 30 | N_id = tokenizer.encode("N", add_special_tokens=False)[0] # currently unused; only the probability of 'Y' is needed for ranking 31 | 32 | Y_probs = [probs[i, Y_id].item() for i in range(probs.shape[0])] 33 | 34 | return Y_probs 35 | 36 | def main(): 37 | parser = argparse.ArgumentParser() 38 | parser.add_argument("--dataset", type=str, choices=["FB15k-237-subset", "NELL-995-subset", "WN18RR-subset"], default="FB15k-237-subset", help="Name of the dataset") 39 | parser.add_argument("--setting", type=str, choices=["inductive", "transductive"], default="inductive", help="Inductive or Transductive setting") 40 | parser.add_argument("--train_size", type=str, choices=["full", "1000", "2000"], default="full", help="Size of the training data") 41 | parser.add_argument("--model_name", type=str, choices=["Qwen2-7B-Instruct", "Meta-Llama-3-8B-Instruct", "Qwen2-1.5B-Instruct"], default="Qwen2-7B-Instruct") 42 | parser.add_argument("--llm_type", type=str, choices=["sft", "base"], default="base") 43 | parser.add_argument("--prompt_type", type=str, choices=["CATS", "vanilla", "CATS-all"], default="CATS") 44 | parser.add_argument("--subgraph_type", type=str, choices=["neighbor-only", "path-only", "combine"], default="combine") 45 | parser.add_argument("--path_type", type=str, choices=["degree", "no-degree"], default="degree") 46 | 47 | args = parser.parse_args() 48 | 49 | log_dir = f"logs_{args.model_name}_{args.llm_type}_{args.prompt_type}_{args.subgraph_type}_{args.path_type}" 50 | os.makedirs(log_dir, exist_ok=True) 51 | timestamp = datetime.now().strftime("%m%d%H%M") 52 | log_file = os.path.join(log_dir, f"log_{args.dataset}_{args.setting}_{args.train_size}_{timestamp}.txt") 53 | 54 | data_manager = DataManager(dataset=args.dataset, setting=args.setting, train_size=args.train_size, model_name=args.model_name, llm_type=args.llm_type) 55 | test_batches = data_manager.get_test_batches() 56 | 57 | model = AutoModelForCausalLM.from_pretrained(data_manager.model_path, torch_dtype="auto", device_map="auto") 58 | tokenizer = AutoTokenizer.from_pretrained(data_manager.model_path) 59 | generation_config = dict( 60 | temperature=0, 61 | top_k=0, 62 | top_p=0, 63 |
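# with do_sample=False decoding is greedy, so the sampling parameters above are effectively inert; max_new_tokens=1 suffices because only the first generated token's scores are read in cal_Y_prob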
do_sample=False, 64 | max_new_tokens=1, 65 | ) 66 | 67 | llm_batch_size = 1 68 | sample_counter = 0 69 | 70 | def log_results(label, results): 71 | log.write(f"{label} Hits results: {results}\n") 72 | hit_at_1 = round(sum(1 for hits in results if hits == 1) / len(results), 3) 73 | mrr = round(sum(1 / hits for hits in results if hits != 0) / len(results), 3) 74 | log.write(f"{label} Hit@1: {hit_at_1}\n") 75 | log.write(f"{label} MRR: {mrr}\n") 76 | 77 | with open(log_file, 'w') as log: 78 | if args.prompt_type == "vanilla": 79 | hits_result_vanilla = [] 80 | log.write(f"Using model: {data_manager.model_path}\n") 81 | 82 | for idx, batch in enumerate(tqdm(test_batches, desc="Processing test batches")): 83 | vanilla_prompts = [data_manager.build_vanilla_prompt(test_triple) for test_triple in batch] 84 | vanilla_probs = [] 85 | for i in range(0, len(vanilla_prompts), llm_batch_size): 86 | batch_prompts = vanilla_prompts[i:i + llm_batch_size] 87 | vanilla_probs.extend(cal_Y_prob(model, tokenizer, generation_config, batch_prompts)) 88 | for i, (prompt, prob) in enumerate(zip(vanilla_prompts, vanilla_probs)): 89 | log.write(f"Sample {sample_counter} vanilla Prompt: {prompt}\n") 90 | log.write(f"Sample {sample_counter} vanilla 'Y' token Probability: {prob}\n") 91 | log.write("*"*50 + "\n") 92 | sample_counter += 1 93 | vanilla_prob_in_batch = list(zip(vanilla_probs, range(len(vanilla_probs)))) 94 | sorted_vanilla_indices = sorted(range(len(vanilla_prob_in_batch)), key=lambda i: vanilla_prob_in_batch[i][0], reverse=True) 95 | log.write(f"Sorted vanilla indices: {sorted_vanilla_indices}\n") 96 | hits_position_base = sorted_vanilla_indices.index(0) + 1 if 0 in sorted_vanilla_indices else 0 97 | hits_result_vanilla.append(hits_position_base) 98 | log.write("*"*50 + "\n") 99 | log.flush() 100 | 101 | if (idx + 1) % 100 == 0: 102 | log.write(f"\nMetrics after processing {idx + 1} batches:\n") 103 | log_results("Vanilla", hits_result_vanilla) 104 | log.write("\n" + "="*50 + "\n") 105 | log.flush() 106 | 107 | log.write("Final Results:\n") 108 | log_results("Vanilla", hits_result_vanilla) 109 | log.flush() 110 | 111 | elif args.prompt_type == "CATS-all": 112 | hits_result_all = [] 113 | log.write(f"Using model: {data_manager.model_path}\n") 114 | 115 | for idx, batch in enumerate(tqdm(test_batches, desc="Processing test batches")): 116 | all_prompts = [data_manager.build_all_prompt(test_triple) for test_triple in batch] 117 | all_probs = [] 118 | for i in range(0, len(all_prompts), llm_batch_size): 119 | batch_prompts = all_prompts[i:i + llm_batch_size] 120 | all_probs.extend(cal_Y_prob(model, tokenizer, generation_config, batch_prompts)) 121 | for i, (prompt, prob) in enumerate(zip(all_prompts, all_probs)): 122 | log.write(f"Sample {sample_counter} all Prompt: {prompt}\n") 123 | log.write(f"Sample {sample_counter} all 'Y' token Probability: {prob}\n") 124 | log.write("*"*50 + "\n") 125 | sample_counter += 1 126 | all_prob_in_batch = list(zip(all_probs, range(len(all_probs)))) 127 | sorted_all_indices = sorted(range(len(all_prob_in_batch)), key=lambda i: all_prob_in_batch[i][0], reverse=True) 128 | log.write(f"Sorted all indices: {sorted_all_indices}\n") 129 | hits_position_all = sorted_all_indices.index(0) + 1 if 0 in sorted_all_indices else 0 130 | hits_result_all.append(hits_position_all) 131 | log.write("*"*50 + "\n") 132 | log.flush() 133 | 134 | if (idx + 1) % 100 == 0: 135 | log.write(f"\nMetrics after processing {idx + 1} batches:\n") 136 | log_results("All", hits_result_all) 137 | log.write("\n" 
+ "="*50 + "\n") 138 | log.flush() 139 | 140 | log.write("Final Results:\n") 141 | log_results("All", hits_result_all) 142 | log.flush() 143 | 144 | elif args.prompt_type == "CATS": 145 | hits_result_type = [] 146 | hits_result_subgraph = [] 147 | hits_result_average_ensemble = [] 148 | TAR_infer_times = [] 149 | SR_infer_times = [] 150 | # hits_result_weighted_ensemble = [] 151 | # hits_result_type_filtered_subgraph = [] 152 | log.write(f"Using model: {data_manager.model_path}\n") 153 | 154 | for idx, batch in enumerate(tqdm(test_batches, desc="Processing test batches")): 155 | type_prompts = [data_manager.build_type_prompt(test_triple) for test_triple in batch] 156 | if args.subgraph_type == "combine": 157 | subgraph_prompts = [data_manager.build_subgraph_prompt(test_triple) for test_triple in batch] 158 | elif args.subgraph_type == "neighbor-only": 159 | subgraph_prompts = [data_manager.build_neighbor_prompt(test_triple) for test_triple in batch] 160 | elif args.subgraph_type == "path-only": 161 | if args.path_type == "degree": 162 | subgraph_prompts = [data_manager.build_close_path_prompt(test_triple) for test_triple in batch] 163 | elif args.path_type == "no-degree": 164 | subgraph_prompts = [data_manager.build_close_path_no_degree_prompt(test_triple) for test_triple in batch] 165 | type_probs = [] 166 | batch_infer_times = 0 167 | for i in range(0, len(type_prompts), llm_batch_size): 168 | batch_prompts = type_prompts[i:i + llm_batch_size] 169 | start_time = time.time() 170 | type_probs.extend(cal_Y_prob(model, tokenizer, generation_config, batch_prompts)) 171 | end_time = time.time() 172 | time_interval = end_time - start_time 173 | batch_infer_times += time_interval 174 | # log.write(f"Time for type reasoning inference: {time_interval}\n") 175 | TAR_infer_times.append(batch_infer_times) 176 | for i, (prompt, prob) in enumerate(zip(type_prompts, type_probs)): 177 | log.write(f"Sample {sample_counter} type Prompt: {prompt}") 178 | log.write(f"Sample {sample_counter} type 'Y' token Probability: {prob}\n") 179 | log.write("*"*50 + "\n") 180 | sample_counter += 1 181 | 182 | type_prob_in_batch = list(zip(type_probs, range(len(type_probs)))) 183 | sorted_type_indices = sorted(range(len(type_prob_in_batch)), key=lambda i: type_prob_in_batch[i][0], reverse=True) 184 | log.write(f"Sorted type indices: {sorted_type_indices}\n") 185 | hits_position_type = sorted_type_indices.index(0) + 1 if 0 in sorted_type_indices else 0 186 | hits_result_type.append(hits_position_type) 187 | 188 | top_10_type_indices = sorted_type_indices[:10] 189 | type_filtered_set = set(top_10_type_indices) 190 | 191 | subgraph_probs = [] 192 | batch_infer_times = 0 193 | for i in range(0, len(subgraph_prompts), llm_batch_size): 194 | batch_prompts = subgraph_prompts[i:i + llm_batch_size] 195 | start_time = time.time() 196 | subgraph_probs.extend(cal_Y_prob(model, tokenizer, generation_config, batch_prompts)) 197 | end_time = time.time() 198 | time_interval = end_time - start_time 199 | batch_infer_times += time_interval 200 | # log.write(f"Time for subgraph reasoning inference: {time_interval}\n") 201 | SR_infer_times.append(batch_infer_times) 202 | 203 | for i, (prompt, prob) in enumerate(zip(subgraph_prompts, subgraph_probs)): 204 | log.write(f"Sample {sample_counter} Subgraph Prompt: {prompt}\n") 205 | log.write(f"Sample {sample_counter} Subgraph 'Y' token Probability: {prob}\n") 206 | log.write("*"*50 + "\n") 207 | sample_counter += 1 208 | 209 | subgraph_prob_in_batch = list(zip(subgraph_probs, 
range(len(subgraph_probs)))) 210 | sorted_subgraph_indices = sorted(range(len(subgraph_prob_in_batch)), key=lambda i: subgraph_prob_in_batch[i][0], reverse=True) 211 | log.write(f"Sorted Subgraph indices: {sorted_subgraph_indices}\n") 212 | hits_position_subgraph = sorted_subgraph_indices.index(0) + 1 if 0 in sorted_subgraph_indices else 0 213 | hits_result_subgraph.append(hits_position_subgraph) 214 | 215 | # Ensemble type reasoning and subgraph reasoning 216 | combined_ranks = [sorted_type_indices.index(i) + sorted_subgraph_indices.index(i) for i in range(len(sorted_type_indices))] 217 | sorted_combined_indices = sorted(range(len(combined_ranks)), key=lambda i: combined_ranks[i]) 218 | hits_position_average_ensemble = sorted_combined_indices.index(0) + 1 if 0 in sorted_combined_indices else 0 219 | hits_result_average_ensemble.append(hits_position_average_ensemble) 220 | 221 | # # Weighted Ensemble 222 | # weighted_scores = [(1 / (sorted_type_indices.index(i) + 1) + 1 / (sorted_subgraph_indices.index(i) + 1)) for i in range(len(sorted_type_indices))] 223 | # sorted_weighted_indices = sorted(range(len(weighted_scores)), key=lambda i: weighted_scores[i], reverse=True) 224 | # hits_position_weighted_ensemble = sorted_weighted_indices.index(0) + 1 if 0 in sorted_weighted_indices else 0 225 | # hits_result_weighted_ensemble.append(hits_position_weighted_ensemble) 226 | 227 | # # Filter subgraph results based on type_filtered_list 228 | # sorted_filtered_subgraph_indices = [index for index in sorted_subgraph_indices if index in type_filtered_set] 229 | # log.write(f"Sorted filtered subgraph indices: {sorted_filtered_subgraph_indices}\n") 230 | # hits_position_type_filtered_subgraph = sorted_filtered_subgraph_indices.index(0) + 1 if 0 in sorted_filtered_subgraph_indices else 0 231 | # hits_result_type_filtered_subgraph.append(hits_position_type_filtered_subgraph) 232 | 233 | log.write("*"*50 + "\n") 234 | log.flush() 235 | 236 | if (idx + 1) % 100 == 0: 237 | log.write(f"\nMetrics after processing {idx + 1} batches:\n") 238 | log_results("Type", hits_result_type) 239 | log_results("Subgraph", hits_result_subgraph) 240 | log_results("Average Ensemble", hits_result_average_ensemble) 241 | # log_results("Weighted Ensemble", hits_result_weighted_ensemble) 242 | # log_results("Type Filtered Subgraph", hits_result_type_filtered_subgraph) 243 | log.write("\n" + "="*50 + "\n") 244 | log.flush() 245 | 246 | log.write("Final Results:\n") 247 | log.write("Proportion of type reasoning top 5: {}\n".format(sum(1 for hits in hits_result_type if hits <= 5) / len(hits_result_type))) 248 | log.write("Proportion of type reasoning top 10: {}\n".format(sum(1 for hits in hits_result_type if hits <= 10) / len(hits_result_type))) 249 | 250 | log_results("Type", hits_result_type) 251 | log_results("Subgraph", hits_result_subgraph) 252 | log_results("Average Ensemble", hits_result_average_ensemble) 253 | # log_results("Weighted Ensemble", hits_result_weighted_ensemble) 254 | # log_results("Type Filtered Subgraph", hits_result_type_filtered_subgraph) 255 | 256 | # Time cost 257 | log.write("Average time for type reasoning inference: {}\n".format(sum(TAR_infer_times) / len(TAR_infer_times))) 258 | log.write("Average time for subgraph reasoning inference: {}\n".format(sum(SR_infer_times) / len(SR_infer_times))) 259 | log.flush() 260 | 261 | if __name__ == "__main__": 262 | main() 263 | -------------------------------------------------------------------------------- /prompt_templates.py:
-------------------------------------------------------------------------------- 1 | TYPE_REASON_PROMPT = """Please determine whether the entities in the input triples are consistent in entity type with a set of known triples in the knowledge graph provided. 2 | A set of known triples are: 3 | {fewshot_triples} 4 | The triple to be determined is: 5 | {test_triple} 6 | Please return 'Y' if the input triple is consistent in entity type, otherwise return 'N'. Do not say anything else except your determination. 7 | """ 8 | 9 | SUBGRAPH_REASON_PROMPT = """Please determine whether the relation in the input can be reliably inferred between the head and tail entities, based on a set of neighbor triples and reasoning paths from the knowledge graph. 10 | A set of neighbor triples from the knowledge graph are: 11 | {neighbor_triples} 12 | A set of reasoning paths from the knowledge graph are: 13 | {reasoning_paths} 14 | The relation to be inferred is: 15 | {test_triple} 16 | Please return 'Y' if there is sufficient evidence from the knowledge graph to infer the relation, otherwise return 'N'. Do not say anything else except your determination. 17 | """ 18 | 19 | NEIGHBOR_REASON_PROMPT = """Please determine whether the relation in the input can be reliably inferred between the head and tail entities, based on a set of neighbor triples from the knowledge graph. 20 | A set of neighbor triples from the knowledge graph are: 21 | {neighbor_triples} 22 | The relation to be inferred is: 23 | {test_triple} 24 | Please return 'Y' if there is sufficient evidence from the knowledge graph to infer the relation, otherwise return 'N'. Do not say anything else except your determination. 25 | """ 26 | 27 | CLOSE_PATH_REASON_PROMPT = """Please determine whether the relation in the input can be reliably inferred between the head and tail entities, based on a set of reasoning paths from the knowledge graph. 28 | A set of reasoning paths from the knowledge graph are: 29 | {reasoning_paths} 30 | The relation to be inferred is: 31 | {test_triple} 32 | Please return 'Y' if there is sufficient evidence from the knowledge graph to infer the relation, otherwise return 'N'. Do not say anything else except your determination. 33 | """ 34 | 35 | BASE_REASON_PROMPT = """Please determine whether the input triple from a knowledge graph is correct or incorrect. 36 | {test_triple} 37 | Please return 'Y' if it is correct, otherwise return 'N'. Do not say anything else except your determination. 38 | """ 39 | 40 | ALL_REASON_PROMPT = """Please determine whether the relation in the input can be reliably inferred between the head and tail entities, based on a set of known triples, neighbor triples and reasoning paths from the knowledge graph. 41 | A set of known triples are: 42 | {fewshot_triples} 43 | A set of neighbor triples from the knowledge graph are: 44 | {neighbor_triples} 45 | A set of reasoning paths from the knowledge graph are: 46 | {reasoning_paths} 47 | The relation to be inferred is: 48 | {test_triple} 49 | Please return 'Y' if there is sufficient evidence from the knowledge graph to infer the relation, otherwise return 'N'. Do not say anything else except your determination. 50 | """ 51 | 52 | EXPLAINING_PROMPT = """Please determine whether the relation in the input can be reliably inferred between the head and tail entities, based on a set of neighbor triples and reasoning paths from the knowledge graph.
53 | A set of neighbor triples from the knowledge graph are: 54 | {neighbor_triples} 55 | A set of reasoning paths from the knowledge graph are: 56 | {reasoning_paths} 57 | The relation to be inferred is: 58 | {test_triple} 59 | Please return 'Y' if there is sufficient evidence from the knowledge graph to infer the relation, otherwise return 'N'. Please provide a brief explanation for your determination. 60 | """ 61 | -------------------------------------------------------------------------------- /statistics.py: -------------------------------------------------------------------------------- 1 | import json 2 | import matplotlib.pyplot as plt 3 | 4 | # file paths 5 | json_file_path = 'datasets/FB15k-237-subset-inductive/paths/close_path.json' 6 | txt_file_path = 'datasets/FB15k-237-subset-inductive/test.txt' 7 | 8 | # load the JSON file 9 | with open(json_file_path, 'r') as file: 10 | data = json.load(file) 11 | 12 | # initialize the path counters 13 | test_triple_path_count = {} 14 | empty_path_count = 0 15 | 16 | # read and process test.txt 17 | with open(txt_file_path, 'r') as file: 18 | for line in file: 19 | parts = line.strip().split('\t') 20 | if len(parts) >= 2: # make sure there are at least two fields 21 | key = f"{parts[0]}-{parts[-1]}" 22 | # check whether this key exists in the JSON data 23 | if key in data: 24 | value = data[key] 25 | # count the paths: [] counts as 0, otherwise the length of the list 26 | path_count = len(value) if value != [] else 0 27 | if path_count == 0: 28 | empty_path_count += 1 29 | if path_count not in test_triple_path_count: 30 | test_triple_path_count[path_count] = 0 31 | test_triple_path_count[path_count] += 1 32 | 33 | # prepare the bins 34 | bins = [0, 1, 2, 3, 4, 5, 10, 20, 50] 35 | bin_labels = ['0', '1', '2', '3', '4', '5', '6-10', '11-20', '21-50', '51+'] 36 | bin_counts = {label: 0 for label in bin_labels} 37 | 38 | # assign the path counts to bins 39 | for count, num in test_triple_path_count.items(): 40 | if count > 50: 41 | bin_counts['51+'] += num 42 | else: 43 | for i in range(len(bins)): 44 | if count <= bins[i]: 45 | bin_counts[bin_labels[i]] += num 46 | break 47 | 48 | # plot the histogram 49 | plt.figure(figsize=(5, 2)) # the figure height is set to 2 inches here 50 | plt.rcParams.update({'font.size': 15}) 51 | x_labels = list(bin_counts.keys()) 52 | y_values = list(bin_counts.values()) 53 | 54 | plt.bar(x_labels, y_values, color='skyblue') 55 | 56 | # annotate each bar with its y value 57 | for i, y in enumerate(y_values): 58 | plt.text(i, y, str(y), ha='center', va='bottom') 59 | 60 | plt.xticks(rotation=45) 61 | plt.tight_layout() 62 | 63 | plt.xlabel('#Paths') 64 | plt.gca().xaxis.set_label_coords(1.05, -0.05) # Adjust the position as needed 65 | 66 | # Move the ylabel to the top side of the y-axis 67 | plt.ylabel('#Triples') 68 | plt.gca().yaxis.set_label_coords(-0.05, 1.05) # Adjust the position as needed, and rotate it 69 | plt.gca().yaxis.label.set_rotation(0) 70 | 71 | # save the figure to a file 72 | from matplotlib.transforms import Bbox 73 | plt.savefig('path_count_bins_histogram.pdf', 74 | bbox_inches=Bbox.from_extents(-0.1, 0, 5.4, 2)) 75 | plt.show() --------------------------------------------------------------------------------