├── .gitignore ├── KnowledgeBase ├── KG_api.py ├── KG_sparse_api.py └── __init__.py ├── LICENSE ├── README.md ├── asset ├── model.png ├── spider_rea_wo_icl_v1.png ├── spider_syn_wo_icl_v1.png ├── spider_wo_icl_v1.png ├── tabfact_wo_icl_v1.png ├── webqsp_wo_icl_v1.png ├── wikisql_wo_icl_v1.png └── wtq_wo_icl_v1.png ├── evaluate_for_spider.py ├── evaluate_for_tabfact.py ├── evaluate_for_webqsp.py ├── evaluate_for_wikisql.py ├── outputs ├── spider-realistic │ └── output_wo_icl_v1.jsonl ├── spider-syn │ └── output_wo_icl_v1.jsonl ├── spider │ └── output_wo_icl_v1.jsonl ├── tabfact │ └── output_wo_icl_v1.jsonl ├── wikisql │ └── output_wo_icl_v1.jsonl └── wtq │ └── output_wo_icl_v1.jsonl ├── process_sql.py ├── prompts ├── prompt_for_spider.json ├── prompt_for_tabfact.json ├── prompt_for_webqsp.json ├── prompt_for_wikisql.json └── prompt_for_wtq.json ├── scripts ├── eval_for_tabfact.sh ├── eval_spider_pred.sh ├── run_spider_rea_wo_icl_v1.sh ├── run_spider_syn_wo_icl_v1.sh ├── run_spider_wo_icl_v1.sh ├── run_tabfact_wo_icl_v1.sh ├── run_webqsp_wo_icl_v1.sh ├── run_wikisql_wo_icl_v1.sh └── run_wtq_wo_icl_v1.sh ├── structgpt_for_tableqa.py ├── structgpt_for_text_to_sql.py └── structgpt_for_webqsp.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /KnowledgeBase/KG_api.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from copy import deepcopy 3 | 4 | import logging 5 | import numpy as np 6 | 7 | from KnowledgeBase.KG_sparse_api import KnowledgeGraphSparse 8 | import pickle 9 | 10 | END_OF_HOP = "end.hop" 11 | SEP = "[SEP]" 12 | 13 | 14 | class KnowledgeGraph(object): 15 | def __init__(self, sparse_triples_path, sparse_ent_type_path, ent2id_path, rel2id_path): 16 | triples_path, ent_type_path = sparse_triples_path, sparse_ent_type_path 17 | print("The sparse KG instantiate via int triples from the %s" % (triples_path)) 18 | self.sparse_kg = KnowledgeGraphSparse(triples_path=triples_path, ent_type_path=ent_type_path) 19 | self.ent2id = self._load_pickle_file(ent2id_path) 20 | self.id2ent = self._reverse_dict(self.ent2id) 21 | self.rel2id = self._load_pickle_file(rel2id_path) 22 | self.id2rel = self._reverse_dict(self.rel2id) 23 | print("The sparse KG instantiate over, all triples: %d, max head id: %d." % ( 24 | self.sparse_kg.E, self.sparse_kg.max_head)) 25 | 26 | @staticmethod 27 | def _load_pickle_file(filename): 28 | with open(filename, "rb") as f: 29 | return pickle.load(f) 30 | 31 | @staticmethod 32 | def _reverse_dict(ori_dict): 33 | reversed_dict = {v: k for k, v in ori_dict.items()} 34 | return reversed_dict 35 | 36 | def get_facts_1hop(self, seeds, max_triples_per_relation, first_flag, gold_relations): 37 | if first_flag: 38 | seeds_id = [] 39 | for seed in seeds: 40 | try: 41 | seed_id = self.ent2id[seed] 42 | seeds_id.append(seed_id) 43 | except Exception as e: 44 | logging.exception(e) 45 | print("Entity string: %s not in ent2id dict" % seed) 46 | continue 47 | else: 48 | seeds_id = seeds 49 | if len(seeds_id) == 0: 50 | return defaultdict(list), [] 51 | triples_per_hop, tails = self.sparse_kg.get_facts_1hop(seeds_id, self.id2rel, self.rel2id, max_triples_per_relation, gold_relations) 52 | triples_per_hop = {hop: [[self.id2ent[triple[0]], self.id2rel[triple[1]], self.id2ent[triple[2]]] for triple in triples] 53 | for hop, triples in triples_per_hop.items()} 54 | return triples_per_hop, tails 55 | 56 | def get_rels_1hop(self, seeds, first_flag): 57 | if first_flag: 58 | seeds_id = [] 59 | for seed in seeds: 60 | try: 61 | seed_id = self.ent2id[seed] 62 | seeds_id.append(seed_id) 63 | except Exception as e: 64 | logging.exception(e) 65 | print("Entity string: %s not in ent2id dict" % seed) 66 | continue 67 | else: 68 | seeds_id = seeds 69 | if len(seeds_id) == 0: 70 | return [] 71 | can_rels = self.sparse_kg.get_rels_1hop(seeds_id, self.id2rel) 72 | return can_rels 73 | -------------------------------------------------------------------------------- /KnowledgeBase/KG_sparse_api.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | from collections import defaultdict 4 | from copy import deepcopy 5 | 6 | import numpy as np 7 | import pickle 8 | import pandas as pd 9 | from scipy.sparse import csr_matrix 10 | VERY_LARGT_NUM = 10**8 11 | PATH_CUTOFF = 10**6 12 | NODE_CUTOFF = 10**4 13 | 14 | 15 | class KnowledgeGraphSparse(object): 16 | def __init__(self, triples_path: str, ent_type_path: str): 17 | self.triple = self._load_npy_file(triples_path) 18 | self.ent_type = self._load_npy_file(ent_type_path) 19 | self.bin_map = np.zeros_like(self.ent_type, dtype=np.int32) 20 | self.E = self.triple.shape[0] 21 | self.head2fact = csr_matrix( 22 | (np.ones(self.E), (self.triple[:, 0], np.arange(self.E)))).astype('bool') 23 | self.rel2fact = csr_matrix( 24 | (np.ones(self.E), (self.triple[:, 1], np.arange(self.E)))).astype('bool') 25 | self.tail2fact = csr_matrix( 26 | (np.ones(self.E), (self.triple[:, 2], np.arange(self.E)))).astype('bool') 27 | self.max_head = max(self.triple[:, 0]) 28 | self.max_tail = max(self.triple[:, 2]) 29 | self.last_tails = set() 30 | 31 | @staticmethod 32 | def _load_npy_file(filename): 33 | return np.load(filename) 34 | 35 | def _fetch_forward_triple(self, seed_set): 36 | seed_set = np.clip(seed_set, a_min=0, a_max=self.max_head) 37 | indices = self.head2fact[seed_set].indices 38 | return self.triple[indices] 39 | 40 | def _fetch_backward_triple(self, seed_set): 41 | seed_set = np.clip(seed_set, a_min=0, a_max=self.max_tail) 42 | indices = self.tail2fact[seed_set].indices 43 | return self.triple[indices] 44 | 45 | def filter_cvt_nodes(self, seed_ary, CVT_TYPE=3): 46 | seed_type = self.ent_type[seed_ary] 47 | return seed_ary[seed_type == CVT_TYPE] 48 | 49 | def get_facts_1hop(self, seed_set, id2rel, rel2id, max_triples_per_relation, gold_relations): 50 | filtered_triples = defaultdict(list) 51 | filtered_tails = [] 52 | 53 | triples = self._fetch_forward_triple(seed_set) 54 | if len(triples) == 0: 55 | print("No triples") 56 | return filtered_triples, filtered_tails 57 | 58 | cur_heads = set() 59 | cur_heads.update(seed_set) 60 | 61 | candidate_rels = set(triples[:, 1].tolist()) 62 | candidate_rels_str = [id2rel[rel] for rel in candidate_rels] 63 | 64 | if len(candidate_rels_str) == 0: 65 | print("No candidate_rels_str") 66 | return filtered_triples, filtered_tails 67 | 68 | if gold_relations is None: 69 | filtered_rels_str = set(candidate_rels_str) 70 | else: 71 | filtered_rels_str = set(candidate_rels_str) & set(gold_relations) 72 | assert len(filtered_rels_str) == len(set(gold_relations)) 73 | 74 | for rel_str in filtered_rels_str: 75 | rel_indices = (triples[:, 1] == rel2id[rel_str]) 76 | triples_for_rel = triples[rel_indices] 77 | # if len(triples_for_rel) > max_triples_per_relation: 78 | # continue 79 | for triple in triples_for_rel.tolist(): 80 | if triple[2] not in self.last_tails: 81 | filtered_triples[0].append(triple) 82 | filtered_tails.append(triple[2]) 83 | 84 | cvt_tails = [tail for tail in filtered_tails if self.ent_type[tail] == 3] 85 | filtered_tails = list(set(filtered_tails) - set(cvt_tails)) 86 | if len(cvt_tails) > 0: 87 | triples = self._fetch_forward_triple(cvt_tails) 88 | if len(triples) > 0: 89 | cur_heads.update(cvt_tails) 90 | 91 | cur_invalid_rels = set() 92 | candidate_rels = set(triples[:, 1].tolist()) 93 | candidate_rels_str = [id2rel[rel] for rel in candidate_rels if rel not in cur_invalid_rels] 94 | filtered_rels_str = candidate_rels_str 95 | for rel_str in filtered_rels_str: 96 | rel_indices = (triples[:, 1] == rel2id[rel_str]) 97 | triples_for_rel = triples[rel_indices].tolist() 98 | for triple in triples_for_rel: 99 | if triple[2] not in seed_set: 100 | if self.ent_type[triple[2]] != 3: 101 | filtered_triples[1].append(triple) 102 | filtered_tails.append(triple[2]) 103 | 104 | self.last_tails = deepcopy(cur_heads) 105 | filtered_tails = list(set(filtered_tails)) 106 | return filtered_triples, filtered_tails 107 | 108 | def get_rels_1hop(self, seed_set, id2rel): 109 | triples = self._fetch_forward_triple(seed_set) 110 | if len(triples) == 0: 111 | return [] 112 | 113 | cur_heads = set() 114 | cur_heads.update(seed_set) 115 | 116 | cur_invalid_rels = set() 117 | for tail in self.last_tails: 118 | invalid_triples_indices = (triples[:, 2] == tail) 119 | invalid_triples = triples[invalid_triples_indices] 120 | invalid_rels = set(invalid_triples[:, 1]) 121 | cur_invalid_rels.update(invalid_rels) 122 | 123 | candidate_rels = set(triples[:, 1].tolist()) 124 | candidate_rels_str = [id2rel[rel] for rel in candidate_rels if rel not in cur_invalid_rels] 125 | 126 | return candidate_rels_str 127 | 128 | def get_filtered_rels(self, question, cur_relations, tokenizer, model, topk, filter_score): 129 | scored_rel_list, filtered_rel_scored_list = self.score_relations(question, cur_relations, tokenizer, model, filter_score) 130 | # 过滤关系和得分 131 | ordered_rels_scored = sorted(filtered_rel_scored_list, key=lambda x: x[1], reverse=True) 132 | # 过滤方法为topk和最少路径filter_method == "topk": 133 | reserved_rels = ordered_rels_scored[:topk] 134 | reserved_rels = [rel_score[0] for rel_score in reserved_rels] 135 | return reserved_rels 136 | 137 | 138 | -------------------------------------------------------------------------------- /KnowledgeBase/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JBoRu/StructGPT/46c4c2c3998dca87cb4807f9eadb03afacce399f/KnowledgeBase/__init__.py -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # StructGPT: A General Framework for Large Language Model to Reason on Structured Data 2 | 3 | This repo provides the source code & data of our paper: [StructGPT: A General Framework for Large Language Model to Reason on Structured Data](https://arxiv.org/pdf/2305.09645.pdf) (Arxiv 2023). 4 | 5 | ``` 6 | @InProceedings{Jiang-StructGPT-2022, 7 | author = {Jinhao Jiang and Kun Zhou and Zican Dong and Keming Ye and Wayne Xin Zhao and Ji-Rong Wen}, 8 | title = {StructGPT: A general framework for Large Language Model to Reason on Structured Data}, 9 | year = {2023}, 10 | journal={arXiv preprint arXiv:2305.09645}, 11 | url={https://arxiv.org/pdf/2305.09645} 12 | } 13 | ``` 14 | 15 | 16 |

17 | 18 |

19 | 20 | 21 | ## Usage 22 | ### 0. Requirements 23 | You only need to install the python library and openai for querying OpenAI model API. 24 | 25 | ### 1. Prepare Dataset 26 | We strongly suggest that download the processed datasets from [here](https://drive.google.com/drive/folders/11_2pqU_MhEtmxpp3zfK_8EJ1bbQzsnfJ?usp=sharing) and then directly use them. 27 | Apart from downloading from their original website, we use the processed datasets from [UnifiedSKG](https://github.com/HKUNLP/UnifiedSKG). 28 | After downloading our processed data, you can unzip them and put them in the */data* directory. 29 | 30 | ### 3. Experiment 31 | We have organized the running and evaluation scripts for each dataset under the */script* directory. 32 | 33 | #### 3.1 Evaluation on Text-to-SQL 34 | It is difficult to control the **randomness** of ChatGPT, so the reproduced results maybe a little different to the reported results. 35 | 36 | For **Spider** dataset, you can directly use the following command to start running and output the evaluation results: 37 | ```bash 38 | bash ./scripts/run_spider_wo_icl_v1.sh 39 | ``` 40 | 41 |

42 | 43 |

44 | 45 | Similarly, you can run the corresponding script for **Spider-SYN** and **Spider-Realistic** to get the evaluation results. 46 | 47 | **Spider_Realistic** 48 |

49 | 50 |

51 | 52 | **Spider-SYN** 53 |

54 | 55 |

56 | We save all the prediction file in *outputs/* directory. 57 | 58 | #### 3.2 Evaluation on TableQA 59 | It is difficult to control the **randomness** of ChatGPT, so the reproduced results maybe a little different to the reported results. 60 | 61 | For **TabFact** dataset, you can directly use the following command to start running and output the evaluation results: 62 | ```bash 63 | bash ./scripts/run_tabfact_wo_icl_v1.sh 64 | ``` 65 | 66 |

67 | 68 |

69 | 70 | Similarly, you can run the corresponding script for **WTQ** and **WikiSQL** to get the evaluation results. 71 | 72 | **WTQ** 73 |

74 | 75 |

76 | 77 | **WikiSQL** 78 |

79 | 80 |

81 | We save all the prediction file in *outputs/* directory. 82 | 83 | #### 3.3 Evaluation on KGQA 84 | It is difficult to control the **randomness** of ChatGPT, so the reproduced results maybe a little different to the reported results. 85 | 86 | For **WebQSP** dataset, you can directly use the following command to start running and output the evaluation results: 87 | ```bash 88 | bash ./scripts/run_webqsp_wo_icl_v1.sh 89 | ``` 90 | 91 |

92 | 93 |

94 | 95 | Similarly, you can run the corresponding script for **MetaQA (1hop,2hop,3hop)** to get the evaluation results. 96 | 97 | We save all the prediction file in *outputs/* directory. 98 | 99 | ## Plan 100 | Thanks for your attention. 101 | A version with better performance is on the way. 102 | Please continue to follow us! 103 | -------------------------------------------------------------------------------- /asset/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JBoRu/StructGPT/46c4c2c3998dca87cb4807f9eadb03afacce399f/asset/model.png -------------------------------------------------------------------------------- /asset/spider_rea_wo_icl_v1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JBoRu/StructGPT/46c4c2c3998dca87cb4807f9eadb03afacce399f/asset/spider_rea_wo_icl_v1.png -------------------------------------------------------------------------------- /asset/spider_syn_wo_icl_v1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JBoRu/StructGPT/46c4c2c3998dca87cb4807f9eadb03afacce399f/asset/spider_syn_wo_icl_v1.png -------------------------------------------------------------------------------- /asset/spider_wo_icl_v1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JBoRu/StructGPT/46c4c2c3998dca87cb4807f9eadb03afacce399f/asset/spider_wo_icl_v1.png -------------------------------------------------------------------------------- /asset/tabfact_wo_icl_v1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JBoRu/StructGPT/46c4c2c3998dca87cb4807f9eadb03afacce399f/asset/tabfact_wo_icl_v1.png -------------------------------------------------------------------------------- /asset/webqsp_wo_icl_v1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JBoRu/StructGPT/46c4c2c3998dca87cb4807f9eadb03afacce399f/asset/webqsp_wo_icl_v1.png -------------------------------------------------------------------------------- /asset/wikisql_wo_icl_v1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JBoRu/StructGPT/46c4c2c3998dca87cb4807f9eadb03afacce399f/asset/wikisql_wo_icl_v1.png -------------------------------------------------------------------------------- /asset/wtq_wo_icl_v1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JBoRu/StructGPT/46c4c2c3998dca87cb4807f9eadb03afacce399f/asset/wtq_wo_icl_v1.png -------------------------------------------------------------------------------- /evaluate_for_spider.py: -------------------------------------------------------------------------------- 1 | 2 | ################################ 3 | # val: number(float)/string(str)/sql(dict) 4 | # col_unit: (agg_id, col_id, isDistinct(bool)) 5 | # val_unit: (unit_op, col_unit1, col_unit2) 6 | # table_unit: (table_type, col_unit/sql) 7 | # cond_unit: (not_op, op_id, val_unit, val1, val2) 8 | # condition: [cond_unit1, 'and'/'or', cond_unit2, ...] 9 | # sql { 10 | # 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...]) 11 | # 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition} 12 | # 'where': condition 13 | # 'groupBy': [col_unit1, col_unit2, ...] 14 | # 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...]) 15 | # 'having': condition 16 | # 'limit': None/limit value 17 | # 'intersect': None/sql 18 | # 'except': None/sql 19 | # 'union': None/sql 20 | # } 21 | ################################ 22 | 23 | from __future__ import print_function 24 | import os, sys 25 | import json 26 | import sqlite3 27 | import traceback 28 | import argparse 29 | from tqdm import tqdm 30 | from itertools import product 31 | from collections import defaultdict 32 | import random 33 | 34 | from process_sql import tokenize, get_schema, get_tables_with_alias, Schema, get_sql 35 | 36 | # Flag to disable value evaluation 37 | DISABLE_VALUE = True 38 | # Flag to disable distinct in select evaluation 39 | DISABLE_DISTINCT = True 40 | 41 | 42 | CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except') 43 | JOIN_KEYWORDS = ('join', 'on', 'as') 44 | 45 | WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists') 46 | UNIT_OPS = ('none', '-', '+', "*", '/') 47 | AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg') 48 | TABLE_TYPE = { 49 | 'sql': "sql", 50 | 'table_unit': "table_unit", 51 | } 52 | 53 | COND_OPS = ('and', 'or') 54 | SQL_OPS = ('intersect', 'union', 'except') 55 | ORDER_OPS = ('desc', 'asc') 56 | 57 | 58 | HARDNESS = { 59 | "component1": ('where', 'group', 'order', 'limit', 'join', 'or', 'like'), 60 | "component2": ('except', 'union', 'intersect') 61 | } 62 | 63 | 64 | def condition_has_or(conds): 65 | return 'or' in conds[1::2] 66 | 67 | 68 | def condition_has_like(conds): 69 | return WHERE_OPS.index('like') in [cond_unit[1] for cond_unit in conds[::2]] 70 | 71 | 72 | def condition_has_sql(conds): 73 | for cond_unit in conds[::2]: 74 | val1, val2 = cond_unit[3], cond_unit[4] 75 | if val1 is not None and type(val1) is dict: 76 | return True 77 | if val2 is not None and type(val2) is dict: 78 | return True 79 | return False 80 | 81 | 82 | def val_has_op(val_unit): 83 | return val_unit[0] != UNIT_OPS.index('none') 84 | 85 | 86 | def has_agg(unit): 87 | return unit[0] != AGG_OPS.index('none') 88 | 89 | 90 | def accuracy(count, total): 91 | if count == total: 92 | return 1 93 | return 0 94 | 95 | 96 | def recall(count, total): 97 | if count == total: 98 | return 1 99 | return 0 100 | 101 | 102 | def F1(acc, rec): 103 | if (acc + rec) == 0: 104 | return 0 105 | return (2. * acc * rec) / (acc + rec) 106 | 107 | 108 | def get_scores(count, pred_total, label_total): 109 | if pred_total != label_total: 110 | return 0,0,0 111 | elif count == pred_total: 112 | return 1,1,1 113 | return 0,0,0 114 | 115 | 116 | def eval_sel(pred, label): 117 | pred_sel = pred['select'][1] 118 | label_sel = label['select'][1] 119 | label_wo_agg = [unit[1] for unit in label_sel] 120 | pred_total = len(pred_sel) 121 | label_total = len(label_sel) 122 | cnt = 0 123 | cnt_wo_agg = 0 124 | 125 | for unit in pred_sel: 126 | if unit in label_sel: 127 | cnt += 1 128 | label_sel.remove(unit) 129 | if unit[1] in label_wo_agg: 130 | cnt_wo_agg += 1 131 | label_wo_agg.remove(unit[1]) 132 | 133 | return label_total, pred_total, cnt, cnt_wo_agg 134 | 135 | 136 | def eval_where(pred, label): 137 | pred_conds = [unit for unit in pred['where'][::2]] 138 | label_conds = [unit for unit in label['where'][::2]] 139 | label_wo_agg = [unit[2] for unit in label_conds] 140 | pred_total = len(pred_conds) 141 | label_total = len(label_conds) 142 | cnt = 0 143 | cnt_wo_agg = 0 144 | 145 | for unit in pred_conds: 146 | if unit in label_conds: 147 | cnt += 1 148 | label_conds.remove(unit) 149 | if unit[2] in label_wo_agg: 150 | cnt_wo_agg += 1 151 | label_wo_agg.remove(unit[2]) 152 | 153 | return label_total, pred_total, cnt, cnt_wo_agg 154 | 155 | 156 | def eval_group(pred, label): 157 | pred_cols = [unit[1] for unit in pred['groupBy']] 158 | label_cols = [unit[1] for unit in label['groupBy']] 159 | pred_total = len(pred_cols) 160 | label_total = len(label_cols) 161 | cnt = 0 162 | pred_cols = [pred.split(".")[1] if "." in pred else pred for pred in pred_cols] 163 | label_cols = [label.split(".")[1] if "." in label else label for label in label_cols] 164 | for col in pred_cols: 165 | if col in label_cols: 166 | cnt += 1 167 | label_cols.remove(col) 168 | return label_total, pred_total, cnt 169 | 170 | 171 | def eval_having(pred, label): 172 | pred_total = label_total = cnt = 0 173 | if len(pred['groupBy']) > 0: 174 | pred_total = 1 175 | if len(label['groupBy']) > 0: 176 | label_total = 1 177 | 178 | pred_cols = [unit[1] for unit in pred['groupBy']] 179 | label_cols = [unit[1] for unit in label['groupBy']] 180 | if pred_total == label_total == 1 \ 181 | and pred_cols == label_cols \ 182 | and pred['having'] == label['having']: 183 | cnt = 1 184 | 185 | return label_total, pred_total, cnt 186 | 187 | 188 | def eval_order(pred, label): 189 | pred_total = label_total = cnt = 0 190 | if len(pred['orderBy']) > 0: 191 | pred_total = 1 192 | if len(label['orderBy']) > 0: 193 | label_total = 1 194 | if len(label['orderBy']) > 0 and pred['orderBy'] == label['orderBy'] and \ 195 | ((pred['limit'] is None and label['limit'] is None) or (pred['limit'] is not None and label['limit'] is not None)): 196 | cnt = 1 197 | return label_total, pred_total, cnt 198 | 199 | 200 | def eval_and_or(pred, label): 201 | pred_ao = pred['where'][1::2] 202 | label_ao = label['where'][1::2] 203 | pred_ao = set(pred_ao) 204 | label_ao = set(label_ao) 205 | 206 | if pred_ao == label_ao: 207 | return 1,1,1 208 | return len(pred_ao),len(label_ao),0 209 | 210 | 211 | def get_nestedSQL(sql): 212 | nested = [] 213 | for cond_unit in sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]: 214 | if type(cond_unit[3]) is dict: 215 | nested.append(cond_unit[3]) 216 | if type(cond_unit[4]) is dict: 217 | nested.append(cond_unit[4]) 218 | if sql['intersect'] is not None: 219 | nested.append(sql['intersect']) 220 | if sql['except'] is not None: 221 | nested.append(sql['except']) 222 | if sql['union'] is not None: 223 | nested.append(sql['union']) 224 | return nested 225 | 226 | 227 | def eval_nested(pred, label): 228 | label_total = 0 229 | pred_total = 0 230 | cnt = 0 231 | if pred is not None: 232 | pred_total += 1 233 | if label is not None: 234 | label_total += 1 235 | if pred is not None and label is not None: 236 | cnt += Evaluator().eval_exact_match(pred, label) 237 | return label_total, pred_total, cnt 238 | 239 | 240 | def eval_IUEN(pred, label): 241 | lt1, pt1, cnt1 = eval_nested(pred['intersect'], label['intersect']) 242 | lt2, pt2, cnt2 = eval_nested(pred['except'], label['except']) 243 | lt3, pt3, cnt3 = eval_nested(pred['union'], label['union']) 244 | label_total = lt1 + lt2 + lt3 245 | pred_total = pt1 + pt2 + pt3 246 | cnt = cnt1 + cnt2 + cnt3 247 | return label_total, pred_total, cnt 248 | 249 | 250 | def get_keywords(sql): 251 | res = set() 252 | if len(sql['where']) > 0: 253 | res.add('where') 254 | if len(sql['groupBy']) > 0: 255 | res.add('group') 256 | if len(sql['having']) > 0: 257 | res.add('having') 258 | if len(sql['orderBy']) > 0: 259 | res.add(sql['orderBy'][0]) 260 | res.add('order') 261 | if sql['limit'] is not None: 262 | res.add('limit') 263 | if sql['except'] is not None: 264 | res.add('except') 265 | if sql['union'] is not None: 266 | res.add('union') 267 | if sql['intersect'] is not None: 268 | res.add('intersect') 269 | 270 | # or keyword 271 | ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2] 272 | if len([token for token in ao if token == 'or']) > 0: 273 | res.add('or') 274 | 275 | cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2] 276 | # not keyword 277 | if len([cond_unit for cond_unit in cond_units if cond_unit[0]]) > 0: 278 | res.add('not') 279 | 280 | # in keyword 281 | if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('in')]) > 0: 282 | res.add('in') 283 | 284 | # like keyword 285 | if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) > 0: 286 | res.add('like') 287 | 288 | return res 289 | 290 | 291 | def eval_keywords(pred, label): 292 | pred_keywords = get_keywords(pred) 293 | label_keywords = get_keywords(label) 294 | pred_total = len(pred_keywords) 295 | label_total = len(label_keywords) 296 | cnt = 0 297 | 298 | for k in pred_keywords: 299 | if k in label_keywords: 300 | cnt += 1 301 | return label_total, pred_total, cnt 302 | 303 | 304 | def count_agg(units): 305 | return len([unit for unit in units if has_agg(unit)]) 306 | 307 | 308 | def count_component1(sql): 309 | count = 0 310 | if len(sql['where']) > 0: 311 | count += 1 312 | if len(sql['groupBy']) > 0: 313 | count += 1 314 | if len(sql['orderBy']) > 0: 315 | count += 1 316 | if sql['limit'] is not None: 317 | count += 1 318 | if len(sql['from']['table_units']) > 0: # JOIN 319 | count += len(sql['from']['table_units']) - 1 320 | 321 | ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2] 322 | count += len([token for token in ao if token == 'or']) 323 | cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2] 324 | count += len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) 325 | 326 | return count 327 | 328 | 329 | def count_component2(sql): 330 | nested = get_nestedSQL(sql) 331 | return len(nested) 332 | 333 | 334 | def count_others(sql): 335 | count = 0 336 | # number of aggregation 337 | agg_count = count_agg(sql['select'][1]) 338 | agg_count += count_agg(sql['where'][::2]) 339 | agg_count += count_agg(sql['groupBy']) 340 | if len(sql['orderBy']) > 0: 341 | agg_count += count_agg([unit[1] for unit in sql['orderBy'][1] if unit[1]] + 342 | [unit[2] for unit in sql['orderBy'][1] if unit[2]]) 343 | agg_count += count_agg(sql['having']) 344 | if agg_count > 1: 345 | count += 1 346 | 347 | # number of select columns 348 | if len(sql['select'][1]) > 1: 349 | count += 1 350 | 351 | # number of where conditions 352 | if len(sql['where']) > 1: 353 | count += 1 354 | 355 | # number of group by clauses 356 | if len(sql['groupBy']) > 1: 357 | count += 1 358 | 359 | return count 360 | 361 | 362 | class Evaluator: 363 | """A simple evaluator""" 364 | def __init__(self): 365 | self.partial_scores = None 366 | 367 | def eval_hardness(self, sql): 368 | count_comp1_ = count_component1(sql) 369 | count_comp2_ = count_component2(sql) 370 | count_others_ = count_others(sql) 371 | 372 | if count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ == 0: 373 | return "easy" 374 | elif (count_others_ <= 2 and count_comp1_ <= 1 and count_comp2_ == 0) or \ 375 | (count_comp1_ <= 2 and count_others_ < 2 and count_comp2_ == 0): 376 | return "medium" 377 | elif (count_others_ > 2 and count_comp1_ <= 2 and count_comp2_ == 0) or \ 378 | (2 < count_comp1_ <= 3 and count_others_ <= 2 and count_comp2_ == 0) or \ 379 | (count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ <= 1): 380 | return "hard" 381 | else: 382 | return "extra" 383 | 384 | def eval_exact_match(self, pred, label): 385 | partial_scores = self.eval_partial_match(pred, label) 386 | self.partial_scores = partial_scores 387 | 388 | for _, score in partial_scores.items(): 389 | if score['f1'] != 1: 390 | return 0 391 | if len(label['from']['table_units']) > 0: 392 | label_tables = sorted(label['from']['table_units']) 393 | pred_tables = sorted(pred['from']['table_units']) 394 | return label_tables == pred_tables 395 | return 1 396 | 397 | def eval_partial_match(self, pred, label): 398 | res = {} 399 | 400 | label_total, pred_total, cnt, cnt_wo_agg = eval_sel(pred, label) 401 | acc, rec, f1 = get_scores(cnt, pred_total, label_total) 402 | res['select'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} 403 | acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total) 404 | res['select(no AGG)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} 405 | 406 | label_total, pred_total, cnt, cnt_wo_agg = eval_where(pred, label) 407 | acc, rec, f1 = get_scores(cnt, pred_total, label_total) 408 | res['where'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} 409 | acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total) 410 | res['where(no OP)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} 411 | 412 | label_total, pred_total, cnt = eval_group(pred, label) 413 | acc, rec, f1 = get_scores(cnt, pred_total, label_total) 414 | res['group(no Having)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} 415 | 416 | label_total, pred_total, cnt = eval_having(pred, label) 417 | acc, rec, f1 = get_scores(cnt, pred_total, label_total) 418 | res['group'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} 419 | 420 | label_total, pred_total, cnt = eval_order(pred, label) 421 | acc, rec, f1 = get_scores(cnt, pred_total, label_total) 422 | res['order'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} 423 | 424 | label_total, pred_total, cnt = eval_and_or(pred, label) 425 | acc, rec, f1 = get_scores(cnt, pred_total, label_total) 426 | res['and/or'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} 427 | 428 | label_total, pred_total, cnt = eval_IUEN(pred, label) 429 | acc, rec, f1 = get_scores(cnt, pred_total, label_total) 430 | res['IUEN'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} 431 | 432 | label_total, pred_total, cnt = eval_keywords(pred, label) 433 | acc, rec, f1 = get_scores(cnt, pred_total, label_total) 434 | res['keywords'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} 435 | 436 | return res 437 | 438 | 439 | def isValidSQL(sql, db): 440 | conn = sqlite3.connect(db) 441 | cursor = conn.cursor() 442 | try: 443 | cursor.execute(sql) 444 | except: 445 | return False 446 | return True 447 | 448 | 449 | def print_scores(scores, etype): 450 | levels = ['easy', 'medium', 'hard', 'extra', 'all'] 451 | partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)', 452 | 'group', 'order', 'and/or', 'IUEN', 'keywords'] 453 | 454 | print("{:20} {:20} {:20} {:20} {:20} {:20}".format("", *levels)) 455 | counts = [scores[level]['count'] for level in levels] 456 | print("{:20} {:<20d} {:<20d} {:<20d} {:<20d} {:<20d}".format("count", *counts)) 457 | 458 | if etype in ["all", "exec"]: 459 | print('===================== EXECUTION ACCURACY =====================') 460 | this_scores = [scores[level]['exec'] for level in levels] 461 | print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("execution", *this_scores)) 462 | 463 | if etype in ["all", "match"]: 464 | print('\n====================== EXACT MATCHING ACCURACY =====================') 465 | exact_scores = [scores[level]['exact'] for level in levels] 466 | print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("exact match", *exact_scores)) 467 | print('\n---------------------PARTIAL MATCHING ACCURACY----------------------') 468 | for type_ in partial_types: 469 | this_scores = [scores[level]['partial'][type_]['acc'] for level in levels] 470 | print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores)) 471 | 472 | print('---------------------- PARTIAL MATCHING RECALL ----------------------') 473 | for type_ in partial_types: 474 | this_scores = [scores[level]['partial'][type_]['rec'] for level in levels] 475 | print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores)) 476 | 477 | print('---------------------- PARTIAL MATCHING F1 --------------------------') 478 | for type_ in partial_types: 479 | this_scores = [scores[level]['partial'][type_]['f1'] for level in levels] 480 | print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores)) 481 | 482 | 483 | def evaluate(path, db_dir, etype, kmaps, db_dir2): 484 | with open(path, "r") as f: 485 | all_lines = f.readlines() 486 | all_data = [json.loads(line) for line in all_lines] 487 | glist = [] 488 | plist = [] 489 | for item in all_data: 490 | glist.append((item['query'] + '\t' + item['db_id']).strip().split('\t')) 491 | result_query = 'SELECT ' + item['Prediction'].replace('\n', ' ') 492 | result_query = result_query.replace('SELECT SELECT', 'SELECT') 493 | plist.append(result_query.strip()) 494 | 495 | evaluator = Evaluator() 496 | 497 | levels = ['easy', 'medium', 'hard', 'extra', 'all'] 498 | partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)', 499 | 'group', 'order', 'and/or', 'IUEN', 'keywords'] 500 | entries = [] 501 | scores = {} 502 | 503 | for level in levels: 504 | scores[level] = {'count': 0, 'partial': {}, 'exact': 0.} 505 | scores[level]['exec'] = 0 506 | for type_ in partial_types: 507 | scores[level]['partial'][type_] = {'acc': 0., 'rec': 0., 'f1': 0.,'acc_count':0,'rec_count':0} 508 | 509 | eval_err_num = 0 510 | index = 0 511 | for p, g in zip(plist, glist): 512 | # print(index) 513 | index += 1 514 | # p_str = p[0] 515 | p_str = p 516 | p_ori = p 517 | try: 518 | g_str, db = g 519 | except: 520 | import ipdb; ipdb.set_trace() 521 | db_name = db 522 | db = os.path.join(db_dir, db_name, db_name + ".sqlite") 523 | if db_dir2 != "": 524 | db2 = os.path.join(db_dir2, db_name, db_name + ".sqlite") 525 | schema = Schema(get_schema(db)) 526 | try: 527 | g_sql = get_sql(schema, g_str) 528 | except: 529 | continue 530 | 531 | hardness = evaluator.eval_hardness(g_sql) 532 | scores[hardness]['count'] += 1 533 | scores['all']['count'] += 1 534 | 535 | try: 536 | p_sql = get_sql(schema, p_str) 537 | except: 538 | # If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql 539 | p_sql = { 540 | "except": None, 541 | "from": { 542 | "conds": [], 543 | "table_units": [] 544 | }, 545 | "groupBy": [], 546 | "having": [], 547 | "intersect": None, 548 | "limit": None, 549 | "orderBy": [], 550 | "select": [ 551 | False, 552 | [] 553 | ], 554 | "union": None, 555 | "where": [] 556 | } 557 | # import ipdb; ipdb.set_trace() 558 | eval_err_num += 1 559 | # print("eval_err_num:{}".format(eval_err_num)) 560 | 561 | # rebuild sql for value evaluation 562 | kmap = kmaps[db_name] 563 | g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema) 564 | g_sql = rebuild_sql_val(g_sql) 565 | g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap) 566 | p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema) 567 | p_sql = rebuild_sql_val(p_sql) 568 | p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap) 569 | 570 | if etype in ["all", "exec"]: 571 | if db_dir2 == "": 572 | db2 = db 573 | results = eval_exec_match(db, db2, p_ori, g_str, p_sql, g_sql) 574 | if results is False: 575 | continue 576 | else: 577 | exec_score, (pred, gold) = results 578 | if exec_score == 0: 579 | print("{} pred: {}".format(hardness, p_str)) 580 | print("{} gold: {}".format(hardness, g_str)) 581 | 582 | if exec_score: 583 | scores[hardness]['exec'] += 1.0 584 | scores['all']['exec'] += 1.0 585 | 586 | if etype in ["all", "match"]: 587 | exact_score = evaluator.eval_exact_match(p_sql, g_sql) 588 | partial_scores = evaluator.partial_scores 589 | scores[hardness]['exact'] += exact_score 590 | scores['all']['exact'] += exact_score 591 | for type_ in partial_types: 592 | if partial_scores[type_]['pred_total'] > 0: 593 | scores[hardness]['partial'][type_]['acc'] += partial_scores[type_]['acc'] 594 | scores[hardness]['partial'][type_]['acc_count'] += 1 595 | if partial_scores[type_]['label_total'] > 0: 596 | scores[hardness]['partial'][type_]['rec'] += partial_scores[type_]['rec'] 597 | scores[hardness]['partial'][type_]['rec_count'] += 1 598 | scores[hardness]['partial'][type_]['f1'] += partial_scores[type_]['f1'] 599 | if partial_scores[type_]['pred_total'] > 0: 600 | scores['all']['partial'][type_]['acc'] += partial_scores[type_]['acc'] 601 | scores['all']['partial'][type_]['acc_count'] += 1 602 | if partial_scores[type_]['label_total'] > 0: 603 | scores['all']['partial'][type_]['rec'] += partial_scores[type_]['rec'] 604 | scores['all']['partial'][type_]['rec_count'] += 1 605 | scores['all']['partial'][type_]['f1'] += partial_scores[type_]['f1'] 606 | 607 | entries.append({ 608 | 'predictSQL': p_str, 609 | 'goldSQL': g_str, 610 | 'hardness': hardness, 611 | 'exact': exact_score, 612 | 'partial': partial_scores 613 | }) 614 | 615 | for level in levels: 616 | if scores[level]['count'] == 0: 617 | continue 618 | if etype in ["all", "exec"]: 619 | scores[level]['exec'] /= scores[level]['count'] 620 | 621 | if etype in ["all", "match"]: 622 | scores[level]['exact'] /= scores[level]['count'] 623 | for type_ in partial_types: 624 | if scores[level]['partial'][type_]['acc_count'] == 0: 625 | scores[level]['partial'][type_]['acc'] = 0 626 | else: 627 | scores[level]['partial'][type_]['acc'] = scores[level]['partial'][type_]['acc'] / \ 628 | scores[level]['partial'][type_]['acc_count'] * 1.0 629 | if scores[level]['partial'][type_]['rec_count'] == 0: 630 | scores[level]['partial'][type_]['rec'] = 0 631 | else: 632 | scores[level]['partial'][type_]['rec'] = scores[level]['partial'][type_]['rec'] / \ 633 | scores[level]['partial'][type_]['rec_count'] * 1.0 634 | if scores[level]['partial'][type_]['acc'] == 0 and scores[level]['partial'][type_]['rec'] == 0: 635 | scores[level]['partial'][type_]['f1'] = 1 636 | else: 637 | scores[level]['partial'][type_]['f1'] = \ 638 | 2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / ( 639 | scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc']) 640 | 641 | print_scores(scores, etype) 642 | 643 | def get_constraint_permutation(tab1_sets_by_columns, result2): 644 | num_cols = len(result2[0]) 645 | perm_constraints = [{i for i in range(num_cols)} for _ in range(num_cols)] 646 | if num_cols <= 3: 647 | return product(*perm_constraints) 648 | 649 | # we sample 20 rows and constrain the space of permutations 650 | for _ in range(20): 651 | random_tab2_row = random.choice(result2) 652 | 653 | for tab1_col in range(num_cols): 654 | for tab2_col in set(perm_constraints[tab1_col]): 655 | if random_tab2_row[tab2_col] not in tab1_sets_by_columns[tab1_col]: 656 | perm_constraints[tab1_col].remove(tab2_col) 657 | return product(*perm_constraints) 658 | 659 | def unorder_row(row): 660 | return tuple(sorted(row, key=lambda x: str(x) + str(type(x)))) 661 | 662 | 663 | def quick_rej(result1, result2, order_matters): 664 | s1 = [unorder_row(row) for row in result1] 665 | s2 = [unorder_row(row) for row in result2] 666 | if order_matters: 667 | return s1 == s2 668 | else: 669 | return set(s1) == set(s2) 670 | 671 | def multiset_eq(l1, l2): 672 | if len(l1) != len(l2): 673 | return False 674 | d = defaultdict(int) 675 | for e in l1: 676 | d[e] = d[e] + 1 677 | for e in l2: 678 | d[e] = d[e] - 1 679 | if d[e] < 0: 680 | return False 681 | return True 682 | 683 | def permute_tuple(element, perm): 684 | assert len(element) == len(perm) 685 | return tuple([element[i] for i in perm]) 686 | 687 | 688 | def result_eq(result1, result2, order_matters): 689 | if len(result1) == 0 and len(result2) == 0: 690 | return True 691 | 692 | # if length is not the same, then they are definitely different bag of rows 693 | if len(result1) != len(result2): 694 | return False 695 | 696 | num_cols = len(result1[0]) 697 | 698 | # if the results do not have the same number of columns, they are different 699 | if len(result2[0]) != num_cols: 700 | return False 701 | 702 | # unorder each row and compare whether the denotation is the same 703 | # this can already find most pair of denotations that are different 704 | if not quick_rej(result1, result2, order_matters): 705 | return False 706 | 707 | # the rest of the problem is in fact more complicated than one might think 708 | # we want to find a permutation of column order and a permutation of row order, 709 | # s.t. result_1 is the same as result_2 710 | # we return true if we can find such column & row permutations 711 | # and false if we cannot 712 | tab1_sets_by_columns = [{row[i] for row in result1} for i in range(num_cols)] 713 | 714 | # on a high level, we enumerate all possible column permutations that might make result_1 == result_2 715 | # we decrease the size of the column permutation space by the function get_constraint_permutation 716 | # if one of the permutation make result_1, result_2 equivalent, then they are equivalent 717 | for perm in get_constraint_permutation(tab1_sets_by_columns, result2): 718 | if len(perm) != len(set(perm)): 719 | continue 720 | if num_cols == 1: 721 | result2_perm = result2 722 | else: 723 | result2_perm = [permute_tuple(element, perm) for element in result2] 724 | if order_matters: 725 | if result1 == result2_perm: 726 | return True 727 | else: 728 | # in fact the first condition must hold if the second condition holds 729 | # but the first is way more efficient implementation-wise 730 | # and we use it to quickly reject impossible candidates 731 | if set(result1) == set(result2_perm) and multiset_eq(result1, result2_perm): 732 | return True 733 | return False 734 | 735 | 736 | def eval_exec_match(db, db2, p_str, g_str, pred, gold): 737 | """ 738 | return 1 if the values between prediction and gold are matching 739 | in the corresponding index. Currently not support multiple col_unit(pairs). 740 | """ 741 | conn = sqlite3.connect(db2) 742 | cursor = conn.cursor() 743 | try: 744 | cursor.execute(p_str) 745 | p_res = cursor.fetchall() 746 | except Exception as e: 747 | # import ipdb; ipdb.set_trace() 748 | print(e) 749 | print(p_str) 750 | return False 751 | 752 | conn = sqlite3.connect(db) 753 | cursor = conn.cursor() 754 | try: 755 | cursor.execute(g_str) 756 | except: 757 | # import ipdb; ipdb.set_trace() 758 | return False 759 | q_res = cursor.fetchall() 760 | 761 | orders_matter = 'order by' in g_str.lower() 762 | 763 | return result_eq(p_res, q_res, order_matters=orders_matter), (p_res, q_res) 764 | 765 | 766 | # Rebuild SQL functions for value evaluation 767 | def rebuild_cond_unit_val(cond_unit): 768 | if cond_unit is None or not DISABLE_VALUE: 769 | return cond_unit 770 | 771 | not_op, op_id, val_unit, val1, val2 = cond_unit 772 | if type(val1) is not dict: 773 | val1 = None 774 | else: 775 | val1 = rebuild_sql_val(val1) 776 | if type(val2) is not dict: 777 | val2 = None 778 | else: 779 | val2 = rebuild_sql_val(val2) 780 | return not_op, op_id, val_unit, val1, val2 781 | 782 | 783 | def rebuild_condition_val(condition): 784 | if condition is None or not DISABLE_VALUE: 785 | return condition 786 | 787 | res = [] 788 | for idx, it in enumerate(condition): 789 | if idx % 2 == 0: 790 | res.append(rebuild_cond_unit_val(it)) 791 | else: 792 | res.append(it) 793 | return res 794 | 795 | 796 | def rebuild_sql_val(sql): 797 | if sql is None or not DISABLE_VALUE: 798 | return sql 799 | 800 | sql['from']['conds'] = rebuild_condition_val(sql['from']['conds']) 801 | sql['having'] = rebuild_condition_val(sql['having']) 802 | sql['where'] = rebuild_condition_val(sql['where']) 803 | sql['intersect'] = rebuild_sql_val(sql['intersect']) 804 | sql['except'] = rebuild_sql_val(sql['except']) 805 | sql['union'] = rebuild_sql_val(sql['union']) 806 | 807 | return sql 808 | 809 | 810 | # Rebuild SQL functions for foreign key evaluation 811 | def build_valid_col_units(table_units, schema): 812 | col_ids = [table_unit[1] for table_unit in table_units if table_unit[0] == TABLE_TYPE['table_unit']] 813 | prefixs = [col_id[:-2] for col_id in col_ids] 814 | valid_col_units= [] 815 | for value in schema.idMap.values(): 816 | if '.' in value and value[:value.index('.')] in prefixs: 817 | valid_col_units.append(value) 818 | return valid_col_units 819 | 820 | 821 | def rebuild_col_unit_col(valid_col_units, col_unit, kmap): 822 | if col_unit is None: 823 | return col_unit 824 | 825 | agg_id, col_id, distinct = col_unit 826 | if col_id in kmap and col_id in valid_col_units: 827 | col_id = kmap[col_id] 828 | if DISABLE_DISTINCT: 829 | distinct = None 830 | return agg_id, col_id, distinct 831 | 832 | 833 | def rebuild_val_unit_col(valid_col_units, val_unit, kmap): 834 | if val_unit is None: 835 | return val_unit 836 | 837 | unit_op, col_unit1, col_unit2 = val_unit 838 | col_unit1 = rebuild_col_unit_col(valid_col_units, col_unit1, kmap) 839 | col_unit2 = rebuild_col_unit_col(valid_col_units, col_unit2, kmap) 840 | return unit_op, col_unit1, col_unit2 841 | 842 | 843 | def rebuild_table_unit_col(valid_col_units, table_unit, kmap): 844 | if table_unit is None: 845 | return table_unit 846 | 847 | table_type, col_unit_or_sql = table_unit 848 | if isinstance(col_unit_or_sql, tuple): 849 | col_unit_or_sql = rebuild_col_unit_col(valid_col_units, col_unit_or_sql, kmap) 850 | return table_type, col_unit_or_sql 851 | 852 | 853 | def rebuild_cond_unit_col(valid_col_units, cond_unit, kmap): 854 | if cond_unit is None: 855 | return cond_unit 856 | 857 | not_op, op_id, val_unit, val1, val2 = cond_unit 858 | val_unit = rebuild_val_unit_col(valid_col_units, val_unit, kmap) 859 | return not_op, op_id, val_unit, val1, val2 860 | 861 | 862 | def rebuild_condition_col(valid_col_units, condition, kmap): 863 | for idx in range(len(condition)): 864 | if idx % 2 == 0: 865 | condition[idx] = rebuild_cond_unit_col(valid_col_units, condition[idx], kmap) 866 | return condition 867 | 868 | 869 | def rebuild_select_col(valid_col_units, sel, kmap): 870 | if sel is None: 871 | return sel 872 | distinct, _list = sel 873 | new_list = [] 874 | for it in _list: 875 | agg_id, val_unit = it 876 | new_list.append((agg_id, rebuild_val_unit_col(valid_col_units, val_unit, kmap))) 877 | if DISABLE_DISTINCT: 878 | distinct = None 879 | return distinct, new_list 880 | 881 | 882 | def rebuild_from_col(valid_col_units, from_, kmap): 883 | if from_ is None: 884 | return from_ 885 | 886 | from_['table_units'] = [rebuild_table_unit_col(valid_col_units, table_unit, kmap) for table_unit in from_['table_units']] 887 | from_['conds'] = rebuild_condition_col(valid_col_units, from_['conds'], kmap) 888 | return from_ 889 | 890 | 891 | def rebuild_group_by_col(valid_col_units, group_by, kmap): 892 | if group_by is None: 893 | return group_by 894 | 895 | return [rebuild_col_unit_col(valid_col_units, col_unit, kmap) for col_unit in group_by] 896 | 897 | 898 | def rebuild_order_by_col(valid_col_units, order_by, kmap): 899 | if order_by is None or len(order_by) == 0: 900 | return order_by 901 | 902 | direction, val_units = order_by 903 | new_val_units = [rebuild_val_unit_col(valid_col_units, val_unit, kmap) for val_unit in val_units] 904 | return direction, new_val_units 905 | 906 | 907 | def rebuild_sql_col(valid_col_units, sql, kmap): 908 | if sql is None: 909 | return sql 910 | 911 | sql['select'] = rebuild_select_col(valid_col_units, sql['select'], kmap) 912 | sql['from'] = rebuild_from_col(valid_col_units, sql['from'], kmap) 913 | sql['where'] = rebuild_condition_col(valid_col_units, sql['where'], kmap) 914 | sql['groupBy'] = rebuild_group_by_col(valid_col_units, sql['groupBy'], kmap) 915 | sql['orderBy'] = rebuild_order_by_col(valid_col_units, sql['orderBy'], kmap) 916 | sql['having'] = rebuild_condition_col(valid_col_units, sql['having'], kmap) 917 | sql['intersect'] = rebuild_sql_col(valid_col_units, sql['intersect'], kmap) 918 | sql['except'] = rebuild_sql_col(valid_col_units, sql['except'], kmap) 919 | sql['union'] = rebuild_sql_col(valid_col_units, sql['union'], kmap) 920 | 921 | return sql 922 | 923 | 924 | def build_foreign_key_map(entry): 925 | try: 926 | cols_orig = entry["column_names_original"] 927 | except: 928 | import ipdb; ipdb.set_trace() 929 | tables_orig = entry["table_names_original"] 930 | 931 | # rebuild cols corresponding to idmap in Schema 932 | cols = [] 933 | for col_orig in cols_orig: 934 | if col_orig[0] >= 0: 935 | t = tables_orig[col_orig[0]] 936 | c = col_orig[1] 937 | cols.append("__" + t.lower() + "." + c.lower() + "__") 938 | else: 939 | cols.append("__all__") 940 | 941 | def keyset_in_list(k1, k2, k_list): 942 | for k_set in k_list: 943 | if k1 in k_set or k2 in k_set: 944 | return k_set 945 | new_k_set = set() 946 | k_list.append(new_k_set) 947 | return new_k_set 948 | 949 | foreign_key_list = [] 950 | foreign_keys = entry["foreign_keys"] 951 | for fkey in foreign_keys: 952 | key1, key2 = fkey 953 | key_set = keyset_in_list(key1, key2, foreign_key_list) 954 | key_set.add(key1) 955 | key_set.add(key2) 956 | 957 | foreign_key_map = {} 958 | for key_set in foreign_key_list: 959 | sorted_list = sorted(list(key_set)) 960 | midx = sorted_list[0] 961 | for idx in sorted_list: 962 | foreign_key_map[cols[idx]] = cols[midx] 963 | 964 | return foreign_key_map 965 | 966 | 967 | def build_foreign_key_map_from_json(table): 968 | with open(table) as f: 969 | data = json.load(f) 970 | tables = {} 971 | for entry in data: 972 | tables[entry['db_id']] = build_foreign_key_map(entry) 973 | return tables 974 | 975 | 976 | if __name__ == "__main__": 977 | parser = argparse.ArgumentParser() 978 | parser.add_argument('--path', dest='path', type=str,default="outputs/spider/output.jsonl") 979 | parser.add_argument('--db', dest='db', type=str,default="data/spider/database") 980 | parser.add_argument('--db2', dest='db2', type=str, default="") 981 | parser.add_argument('--table', dest='table', type=str, default='data/spider/tables.json') 982 | parser.add_argument('--etype', dest='etype', type=str, default="exec") 983 | args = parser.parse_args() 984 | 985 | path=args.path 986 | db_dir = args.db 987 | db_dir2 = args.db2 988 | table = args.table 989 | etype = args.etype 990 | 991 | assert etype in ["all", "exec", "match"], "Unknown evaluation method" 992 | 993 | kmaps = build_foreign_key_map_from_json(table) 994 | 995 | evaluate(path, db_dir, etype, kmaps, db_dir2) 996 | -------------------------------------------------------------------------------- /evaluate_for_tabfact.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import argparse 4 | 5 | 6 | def evaluate(ori_path, inp_path, error_cases_output, write_flag): 7 | with open(ori_path, "r") as f: 8 | all_data = json.loads(f.read()) 9 | print("Totally %d test data" % len(all_data)) 10 | 11 | pred_data = [] 12 | with open(inp_path, "r") as f: 13 | lines = f.readlines() 14 | datas = [json.loads(line) for line in lines] 15 | pred_data.extend(datas) 16 | 17 | all_pred_data = {pred['id']: pred for pred in pred_data} 18 | print("Totally %d prediction data" % len(pred_data)) # evaluate_is_right 19 | avg_acc = [] 20 | bad_cases = [] 21 | error_count = 0 22 | max_count = 0 23 | right_count = 0 24 | for data in all_data: 25 | if data["id"] in all_pred_data: 26 | data = all_pred_data[data["id"]] 27 | question = data['statement'] 28 | pred = data['Prediction'].lower() 29 | 30 | if 'yes' in pred and 'no' in pred: 31 | pred = 'unknown' 32 | elif 'yes' in pred: 33 | pred = 'entailed' 34 | elif 'no' in pred: 35 | pred = 'refuted' 36 | else: 37 | pred = 'unknown' 38 | 39 | answers = data['seq_out'].lower() 40 | if pred.strip() == answers.strip(): 41 | avg_acc.append(1) 42 | right_count += 1 43 | 44 | else: 45 | error_count += 1 46 | avg_acc.append(0) 47 | print("ID: %s Ques: %s" % (data["id"], question)) 48 | print("Pred: ", pred) 49 | print("Ans: ", answers) 50 | print("------------------------------------------------------------------------") 51 | bad_cases.append(data) 52 | 53 | else: 54 | avg_acc.append(0) 55 | print("ID: %s can't be predicted" % (data["id"])) 56 | bad_cases.append(data) 57 | error_count += 1 58 | max_count += 1 59 | 60 | acc = np.mean(avg_acc) 61 | print("Acc: %.4f" % (acc)) 62 | if write_flag: 63 | with open(error_cases_output, "w") as f: 64 | for bc in bad_cases: 65 | f.write(json.dumps(bc) + "\n") 66 | print("Totally %d bad cases need further solved." % len(bad_cases)) 67 | print("Right count: %d, Error count: %d(Max len count: %d)" % (right_count, error_count, max_count)) 68 | 69 | 70 | if __name__ == "__main__": 71 | parser = argparse.ArgumentParser() 72 | parser.add_argument('--ori_path', type=str, default="./data/tabfact/tab_fact_test.json") 73 | parser.add_argument('--inp_path', type=str, default="./outputs/tabfact/tabfact_test_output.jsonl") 74 | parser.add_argument('--error_cases_output', type=str, 75 | default='./outputs/tabfact/bad_cases.jsonl') 76 | parser.add_argument('--write_flag', action="store_true") 77 | args = parser.parse_args() 78 | evaluate(args.ori_path, args.inp_path, args.error_cases_output, args.write_flag) 79 | -------------------------------------------------------------------------------- /evaluate_for_webqsp.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import openai 4 | import json 5 | import os 6 | import numpy as np 7 | from collections import defaultdict 8 | import pickle 9 | 10 | import json 11 | import numpy as np 12 | 13 | def evaluate(ori_path, pred_path, error_cases_output, write_flag): 14 | with open(ori_path, "r") as f: 15 | all_data = f.readlines() 16 | all_data = [json.loads(line) for line in all_data] 17 | 18 | with open(pred_path, "r", encoding="UTF-8") as f: 19 | all_lines = f.readlines() 20 | all_pred_data = [] 21 | for idx, line in enumerate(all_lines): 22 | line = line.replace("\x00", "").strip("\n") 23 | all_pred_data.append(json.loads(line)) 24 | all_pred_data = {pred['ID']: pred for pred in all_pred_data} 25 | print("Load %d prediction" % len(all_pred_data)) 26 | 27 | max_len_count = len(all_data) - len(all_pred_data) 28 | print("Totally %d prediction / %d all data" % (len(all_pred_data), len(all_data))) 29 | 30 | avg_hits1 = [] 31 | bad_cases = [] 32 | right_cases_id = [] 33 | bad_cases_id = [] 34 | right_count = 0 35 | bad_count = 0 36 | need_cvt_count = 0 37 | for data in all_data: 38 | if data["ID"] in all_pred_data: 39 | data = all_pred_data[data["ID"]] 40 | question = data['Question'] 41 | pred = data['Prediction'].lower() 42 | if "i'm sorry" in pred: 43 | pred = '' 44 | 45 | answers = data['Answers'] 46 | aliases = data['Aliases'] 47 | hit_flag = [] 48 | recall_flag = [] 49 | for ans in answers: 50 | ans = ans.lower() 51 | if ans in pred: 52 | hit_flag.append(1) 53 | recall_flag.append(1) 54 | else: 55 | hit_flag.append(0) 56 | recall_flag.append(0) 57 | for alia in aliases: 58 | alia = alia.lower() 59 | if alia in pred: 60 | hit_flag.append(1) 61 | else: 62 | hit_flag.append(0) 63 | 64 | if len(hit_flag) == 0: 65 | # print("ID:%s doesn't have any gold answers." % data['ID']) 66 | continue 67 | 68 | if any(hit_flag): 69 | avg_hits1.append(1) 70 | right_count += 1 71 | # other_count += 1 72 | right_cases_id.append(data['ID']) 73 | else: 74 | avg_hits1.append(0) 75 | bad_count += 1 76 | # other_count += 1 77 | # if "max length" in pred: 78 | # need_cvt_count += 1 79 | # else: 80 | # other_count += 1 81 | print(data["ID"]) 82 | print("ID: %s Ques: %s" % (data["ID"], question)) 83 | print("Pred: ", pred) 84 | print("Ans: ", answers) 85 | print("------------------------------------------------------------------------") 86 | bad_cases.append(data) 87 | bad_cases_id.append(data["ID"]) 88 | 89 | else: 90 | avg_hits1.append(0) 91 | print("ID: %s can't be predicted" % (data["ID"])) 92 | bad_cases.append(data) 93 | bad_cases_id.append(data["ID"]) 94 | 95 | hits1 = np.mean(avg_hits1) 96 | print("Hits@1: %.4f" % (hits1)) 97 | if write_flag: 98 | with open(error_cases_output, "w") as f: 99 | for bc in bad_cases: 100 | f.write(json.dumps(bc) + "\n") 101 | print("Totally %d bad cases need further solved." % len(bad_cases)) 102 | print("Right:%d, Wrong:%d, Max_len:%d" % (right_count, bad_count, max_len_count)) 103 | 104 | if __name__ == "__main__": 105 | parser = argparse.ArgumentParser() 106 | parser.add_argument('--ori_path', type=str) 107 | parser.add_argument('--inp_path', type=str) 108 | parser.add_argument('--error_cases_output', action="store_true") 109 | parser.add_argument('--write_flag', action="store_true") 110 | args = parser.parse_args() 111 | evaluate(args.ori_path, args.inp_path, args.error_cases_output, args.write_flag) 112 | -------------------------------------------------------------------------------- /evaluate_for_wikisql.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from collections import defaultdict 4 | import numpy as np 5 | 6 | 7 | def evaluate_example(_predict_str: str, _ground_str: list, target_delimiter=', '): 8 | _predict_spans = _predict_str.split(target_delimiter) 9 | _predict_spans = [x.lower().strip().strip('.').strip("'").strip('"').strip() for x in _predict_spans] 10 | for i in range(len(_predict_spans)): 11 | if _predict_spans[i].endswith('.0'): 12 | _predict_spans[i] = _predict_spans[i][:-2] 13 | 14 | if _predict_spans[i].replace(',', '').isnumeric(): 15 | _predict_spans[i] = _predict_spans[i].replace(',', '') 16 | # _ground_spans = _ground_str.split(target_delimiter) 17 | _ground_spans = [x.lower().strip().strip('.').strip("'").strip('"').strip() for x in _ground_str] 18 | for i in range(len(_ground_spans)): 19 | if _ground_spans[i].endswith('.0'): 20 | _ground_spans[i] = _ground_spans[i][:-2] 21 | 22 | if _ground_spans[i].replace(',', '').isnumeric(): 23 | _ground_spans[i] = _ground_spans[i].replace(',', '') 24 | _predict_values = defaultdict(lambda: 0) 25 | _ground_values = defaultdict(lambda: 0) 26 | for span in _predict_spans: 27 | try: 28 | _predict_values[float(span)] += 1 29 | except ValueError: 30 | _predict_values[span.strip()] += 1 31 | for span in _ground_spans: 32 | try: 33 | _ground_values[float(span)] += 1 34 | except ValueError: 35 | _ground_values[span.strip()] += 1 36 | _is_correct = _predict_values == _ground_values 37 | return _is_correct 38 | 39 | 40 | def evaluate(ori_path, inp_path, error_cases_output, write_flag): 41 | with open(ori_path, "r") as f: 42 | all_data = json.loads(f.read()) 43 | print("Totally %d test data" % len(all_data)) 44 | 45 | pred_data = [] 46 | with open(inp_path, "r") as f: 47 | lines = f.readlines() 48 | datas = [json.loads(line) for line in lines] 49 | pred_data.extend(datas) 50 | 51 | all_pred_data = {pred['question']: pred for pred in pred_data} 52 | print("Totally %d prediction data" % len(pred_data)) # evaluate_is_right 53 | avg_deno_acc = [] 54 | bad_cases = [] 55 | error_count = 0 56 | max_count = 0 57 | right_count = 0 58 | for data in all_data: 59 | if data["question"] in all_pred_data: 60 | data = all_pred_data[data["question"]] 61 | pred = data['Prediction'].lower() 62 | 63 | if "answers: " in pred: 64 | pred = pred.split("answers: ")[1].strip() 65 | elif ":" in pred: 66 | pred = pred.split(":")[1].strip() 67 | else: 68 | pred = pred 69 | 70 | answers = data['answer_text'] 71 | answers = [ans if not ans.endswith(".0") else ans.replace(".0", "") for ans in answers] 72 | 73 | if evaluate_example(pred, answers): 74 | avg_deno_acc.append(1) 75 | right_count += 1 76 | else: 77 | error_count += 1 78 | avg_deno_acc.append(0) 79 | # print("ID: %s Ques: %s" % (data["id"], question)) 80 | # print("Pred: ", pred) 81 | # print("Ans: ", answers) 82 | # print("------------------------------------------------------------------------") 83 | bad_cases.append(data) 84 | else: 85 | avg_deno_acc.append(0) 86 | print("ID: %s can't be predicted" % (data["id"])) 87 | bad_cases.append(data) 88 | error_count += 1 89 | max_count += 1 90 | 91 | acc = np.mean(avg_deno_acc) 92 | print("Denotation Acc: %.4f" % (acc)) 93 | if write_flag: 94 | with open(error_cases_output, "w") as f: 95 | for bc in bad_cases: 96 | f.write(json.dumps(bc) + "\n") 97 | print("Totally %d bad cases need further solved." % len(bad_cases)) 98 | print("Right count: %d, Error count: %d(Max len count: %d)" % (right_count, error_count, max_count)) 99 | 100 | 101 | if __name__ == "__main__": 102 | parser = argparse.ArgumentParser() 103 | parser.add_argument('--ori_path', type=str, default="./data/wikisql/wikisql_test.json") 104 | parser.add_argument('--inp_path', type=str, default="./outputs/wikisql/output_wo_icl_v1.jsonl") 105 | parser.add_argument('--error_cases_output', type=str, 106 | default='./outputs/wikisql/bad_cases.jsonl') 107 | parser.add_argument('--write_flag', action="store_true") 108 | args = parser.parse_args() 109 | evaluate(args.ori_path, args.inp_path, args.error_cases_output, args.write_flag) 110 | -------------------------------------------------------------------------------- /process_sql.py: -------------------------------------------------------------------------------- 1 | ################################ 2 | # Assumptions: 3 | # 1. sql is correct 4 | # 2. only table name has alias 5 | # 3. only one intersect/union/except 6 | # 7 | # val: number(float)/string(str)/sql(dict) 8 | # col_unit: (agg_id, col_id, isDistinct(bool)) 9 | # val_unit: (unit_op, col_unit1, col_unit2) 10 | # table_unit: (table_type, col_unit/sql) 11 | # cond_unit: (not_op, op_id, val_unit, val1, val2) 12 | # condition: [cond_unit1, 'and'/'or', cond_unit2, ...] 13 | # sql { 14 | # 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...]) 15 | # 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition} 16 | # 'where': condition 17 | # 'groupBy': [col_unit1, col_unit2, ...] 18 | # 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...]) 19 | # 'having': condition 20 | # 'limit': None/limit value 21 | # 'intersect': None/sql 22 | # 'except': None/sql 23 | # 'union': None/sql 24 | # } 25 | ################################ 26 | 27 | import json 28 | import sqlite3 29 | from nltk import word_tokenize 30 | 31 | CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except') 32 | JOIN_KEYWORDS = ('join', 'on', 'as') 33 | 34 | WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists') 35 | UNIT_OPS = ('none', '-', '+', "*", '/') 36 | AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg') 37 | TABLE_TYPE = { 38 | 'sql': "sql", 39 | 'table_unit': "table_unit", 40 | } 41 | 42 | COND_OPS = ('and', 'or') 43 | SQL_OPS = ('intersect', 'union', 'except') 44 | ORDER_OPS = ('desc', 'asc') 45 | 46 | 47 | 48 | class Schema: 49 | """ 50 | Simple schema which maps table&column to a unique identifier 51 | """ 52 | def __init__(self, schema): 53 | self._schema = schema 54 | self._idMap = self._map(self._schema) 55 | 56 | @property 57 | def schema(self): 58 | return self._schema 59 | 60 | @property 61 | def idMap(self): 62 | return self._idMap 63 | 64 | def _map(self, schema): 65 | idMap = {'*': "__all__"} 66 | id = 1 67 | for key, vals in schema.items(): 68 | for val in vals: 69 | idMap[key.lower() + "." + val.lower()] = "__" + key.lower() + "." + val.lower() + "__" 70 | id += 1 71 | 72 | for key in schema: 73 | idMap[key.lower()] = "__" + key.lower() + "__" 74 | id += 1 75 | 76 | return idMap 77 | 78 | 79 | def get_schema(db): 80 | """ 81 | Get database's schema, which is a dict with table name as key 82 | and list of column names as value 83 | :param db: database path 84 | :return: schema dict 85 | """ 86 | 87 | schema = {} 88 | conn = sqlite3.connect(db) 89 | cursor = conn.cursor() 90 | 91 | # fetch table names 92 | cursor.execute("SELECT name FROM sqlite_master WHERE type='table';") 93 | tables = [str(table[0].lower()) for table in cursor.fetchall()] 94 | 95 | # fetch table info 96 | for table in tables: 97 | cursor.execute("PRAGMA table_info({})".format(table)) 98 | schema[table] = [str(col[1].lower()) for col in cursor.fetchall()] 99 | # import ipdb; ipdb.set_trace() 100 | return schema 101 | 102 | 103 | def get_schema_from_json(fpath): 104 | with open(fpath) as f: 105 | data = json.load(f) 106 | 107 | schema = {} 108 | for entry in data: 109 | table = str(entry['table'].lower()) 110 | cols = [str(col['column_name'].lower()) for col in entry['col_data']] 111 | schema[table] = cols 112 | 113 | return schema 114 | 115 | 116 | def tokenize(string): 117 | string = str(string) 118 | string = string.replace("\'", "\"") # ensures all string values wrapped by "" problem?? 119 | quote_idxs = [idx for idx, char in enumerate(string) if char == '"'] 120 | assert len(quote_idxs) % 2 == 0, "Unexpected quote" 121 | 122 | # keep string value as token 123 | vals = {} 124 | for i in range(len(quote_idxs)-1, -1, -2): 125 | qidx1 = quote_idxs[i-1] 126 | qidx2 = quote_idxs[i] 127 | val = string[qidx1: qidx2+1] 128 | key = "__val_{}_{}__".format(qidx1, qidx2) 129 | string = string[:qidx1] + key + string[qidx2+1:] 130 | vals[key] = val 131 | 132 | toks = [word.lower() for word in word_tokenize(string)] 133 | # replace with string value token 134 | for i in range(len(toks)): 135 | if toks[i] in vals: 136 | toks[i] = vals[toks[i]] 137 | 138 | # find if there exists !=, >=, <= 139 | eq_idxs = [idx for idx, tok in enumerate(toks) if tok == "="] 140 | eq_idxs.reverse() 141 | prefix = ('!', '>', '<') 142 | for eq_idx in eq_idxs: 143 | pre_tok = toks[eq_idx-1] 144 | if pre_tok in prefix: 145 | toks = toks[:eq_idx-1] + [pre_tok + "="] + toks[eq_idx+1: ] 146 | 147 | return toks 148 | 149 | 150 | def scan_alias(toks): 151 | """Scan the index of 'as' and build the map for all alias""" 152 | as_idxs = [idx for idx, tok in enumerate(toks) if tok == 'as'] 153 | alias = {} 154 | for idx in as_idxs: 155 | alias[toks[idx+1]] = toks[idx-1] 156 | return alias 157 | 158 | 159 | def get_tables_with_alias(schema, toks): 160 | tables = scan_alias(toks) 161 | for key in schema: 162 | assert key not in tables, "Alias {} has the same name in table".format(key) 163 | tables[key] = key 164 | return tables 165 | 166 | 167 | def parse_col(toks, start_idx, tables_with_alias, schema, default_tables=None): 168 | """ 169 | :returns next idx, column id 170 | """ 171 | tok = toks[start_idx] 172 | if tok == "*": 173 | return start_idx + 1, schema.idMap[tok] 174 | 175 | if '.' in tok: # if token is a composite 176 | alias, col = tok.split('.') 177 | key = tables_with_alias[alias] + "." + col 178 | return start_idx+1, schema.idMap[key] 179 | 180 | assert default_tables is not None and len(default_tables) > 0, "Default tables should not be None or empty" 181 | 182 | for alias in default_tables: 183 | table = tables_with_alias[alias] 184 | if tok in schema.schema[table]: 185 | key = table + "." + tok 186 | return start_idx+1, schema.idMap[key] 187 | 188 | assert False, "Error col: {}".format(tok) 189 | 190 | 191 | def parse_col_unit(toks, start_idx, tables_with_alias, schema, default_tables=None): 192 | """ 193 | :returns next idx, (agg_op id, col_id) 194 | """ 195 | idx = start_idx 196 | len_ = len(toks) 197 | isBlock = False 198 | isDistinct = False 199 | if toks[idx] == '(': 200 | isBlock = True 201 | idx += 1 202 | 203 | if toks[idx] in AGG_OPS: 204 | agg_id = AGG_OPS.index(toks[idx]) 205 | idx += 1 206 | assert idx < len_ and toks[idx] == '(' 207 | idx += 1 208 | if toks[idx] == "distinct": 209 | idx += 1 210 | isDistinct = True 211 | idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables) 212 | assert idx < len_ and toks[idx] == ')' 213 | idx += 1 214 | return idx, (agg_id, col_id, isDistinct) 215 | 216 | if toks[idx] == "distinct": 217 | idx += 1 218 | isDistinct = True 219 | agg_id = AGG_OPS.index("none") 220 | idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables) 221 | 222 | if isBlock: 223 | assert toks[idx] == ')' 224 | idx += 1 # skip ')' 225 | 226 | return idx, (agg_id, col_id, isDistinct) 227 | 228 | 229 | def parse_val_unit(toks, start_idx, tables_with_alias, schema, default_tables=None): 230 | idx = start_idx 231 | len_ = len(toks) 232 | isBlock = False 233 | if toks[idx] == '(': 234 | isBlock = True 235 | idx += 1 236 | 237 | col_unit1 = None 238 | col_unit2 = None 239 | unit_op = UNIT_OPS.index('none') 240 | 241 | idx, col_unit1 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables) 242 | if idx < len_ and toks[idx] in UNIT_OPS: 243 | unit_op = UNIT_OPS.index(toks[idx]) 244 | idx += 1 245 | idx, col_unit2 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables) 246 | 247 | if isBlock: 248 | assert toks[idx] == ')' 249 | idx += 1 # skip ')' 250 | 251 | return idx, (unit_op, col_unit1, col_unit2) 252 | 253 | 254 | def parse_table_unit(toks, start_idx, tables_with_alias, schema): 255 | """ 256 | :returns next idx, table id, table name 257 | """ 258 | idx = start_idx 259 | len_ = len(toks) 260 | key = tables_with_alias[toks[idx]] 261 | 262 | if idx + 1 < len_ and toks[idx+1] == "as": 263 | idx += 3 264 | else: 265 | idx += 1 266 | 267 | return idx, schema.idMap[key], key 268 | 269 | 270 | def parse_value(toks, start_idx, tables_with_alias, schema, default_tables=None): 271 | idx = start_idx 272 | len_ = len(toks) 273 | 274 | isBlock = False 275 | if toks[idx] == '(': 276 | isBlock = True 277 | idx += 1 278 | 279 | if toks[idx] == 'select': 280 | idx, val = parse_sql(toks, idx, tables_with_alias, schema) 281 | elif "\"" in toks[idx]: # token is a string value 282 | val = toks[idx] 283 | idx += 1 284 | else: 285 | try: 286 | val = float(toks[idx]) 287 | idx += 1 288 | except: 289 | end_idx = idx 290 | while end_idx < len_ and toks[end_idx] != ',' and toks[end_idx] != ')'\ 291 | and toks[end_idx] != 'and' and toks[end_idx] not in CLAUSE_KEYWORDS and toks[end_idx] not in JOIN_KEYWORDS: 292 | end_idx += 1 293 | 294 | idx, val = parse_col_unit(toks[start_idx: end_idx], 0, tables_with_alias, schema, default_tables) 295 | idx = end_idx 296 | 297 | if isBlock: 298 | assert toks[idx] == ')' 299 | idx += 1 300 | 301 | return idx, val 302 | 303 | 304 | def parse_condition(toks, start_idx, tables_with_alias, schema, default_tables=None): 305 | idx = start_idx 306 | len_ = len(toks) 307 | conds = [] 308 | 309 | while idx < len_: 310 | idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables) 311 | not_op = False 312 | if toks[idx] == 'not': 313 | not_op = True 314 | idx += 1 315 | 316 | assert idx < len_ and toks[idx] in WHERE_OPS, "Error condition: idx: {}, tok: {}".format(idx, toks[idx]) 317 | op_id = WHERE_OPS.index(toks[idx]) 318 | idx += 1 319 | val1 = val2 = None 320 | if op_id == WHERE_OPS.index('between'): # between..and... special case: dual values 321 | idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables) 322 | assert toks[idx] == 'and' 323 | idx += 1 324 | idx, val2 = parse_value(toks, idx, tables_with_alias, schema, default_tables) 325 | else: # normal case: single value 326 | idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables) 327 | val2 = None 328 | 329 | conds.append((not_op, op_id, val_unit, val1, val2)) 330 | 331 | if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";") or toks[idx] in JOIN_KEYWORDS): 332 | break 333 | 334 | if idx < len_ and toks[idx] in COND_OPS: 335 | conds.append(toks[idx]) 336 | idx += 1 # skip and/or 337 | 338 | return idx, conds 339 | 340 | 341 | def parse_select(toks, start_idx, tables_with_alias, schema, default_tables=None): 342 | idx = start_idx 343 | len_ = len(toks) 344 | 345 | assert toks[idx] == 'select', "'select' not found" 346 | idx += 1 347 | isDistinct = False 348 | if idx < len_ and toks[idx] == 'distinct': 349 | idx += 1 350 | isDistinct = True 351 | val_units = [] 352 | 353 | while idx < len_ and toks[idx] not in CLAUSE_KEYWORDS: 354 | agg_id = AGG_OPS.index("none") 355 | if toks[idx] in AGG_OPS: 356 | agg_id = AGG_OPS.index(toks[idx]) 357 | idx += 1 358 | idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables) 359 | val_units.append((agg_id, val_unit)) 360 | if idx < len_ and toks[idx] == ',': 361 | idx += 1 # skip ',' 362 | 363 | return idx, (isDistinct, val_units) 364 | 365 | 366 | def parse_from(toks, start_idx, tables_with_alias, schema): 367 | """ 368 | Assume in the from clause, all table units are combined with join 369 | """ 370 | assert 'from' in toks[start_idx:], "'from' not found" 371 | 372 | len_ = len(toks) 373 | idx = toks.index('from', start_idx) + 1 374 | default_tables = [] 375 | table_units = [] 376 | conds = [] 377 | 378 | while idx < len_: 379 | isBlock = False 380 | if toks[idx] == '(': 381 | isBlock = True 382 | idx += 1 383 | 384 | if toks[idx] == 'select': 385 | idx, sql = parse_sql(toks, idx, tables_with_alias, schema) 386 | table_units.append((TABLE_TYPE['sql'], sql)) 387 | else: 388 | if idx < len_ and toks[idx] == 'join': 389 | idx += 1 # skip join 390 | idx, table_unit, table_name = parse_table_unit(toks, idx, tables_with_alias, schema) 391 | table_units.append((TABLE_TYPE['table_unit'],table_unit)) 392 | default_tables.append(table_name) 393 | if idx < len_ and toks[idx] == "on": 394 | idx += 1 # skip on 395 | idx, this_conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables) 396 | if len(conds) > 0: 397 | conds.append('and') 398 | conds.extend(this_conds) 399 | 400 | if isBlock: 401 | assert toks[idx] == ')' 402 | idx += 1 403 | if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")): 404 | break 405 | 406 | return idx, table_units, conds, default_tables 407 | 408 | 409 | def parse_where(toks, start_idx, tables_with_alias, schema, default_tables): 410 | idx = start_idx 411 | len_ = len(toks) 412 | 413 | if idx >= len_ or toks[idx] != 'where': 414 | return idx, [] 415 | 416 | idx += 1 417 | idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables) 418 | return idx, conds 419 | 420 | 421 | def parse_group_by(toks, start_idx, tables_with_alias, schema, default_tables): 422 | idx = start_idx 423 | len_ = len(toks) 424 | col_units = [] 425 | 426 | if idx >= len_ or toks[idx] != 'group': 427 | return idx, col_units 428 | 429 | idx += 1 430 | assert toks[idx] == 'by' 431 | idx += 1 432 | 433 | while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")): 434 | idx, col_unit = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables) 435 | col_units.append(col_unit) 436 | if idx < len_ and toks[idx] == ',': 437 | idx += 1 # skip ',' 438 | else: 439 | break 440 | 441 | return idx, col_units 442 | 443 | 444 | def parse_order_by(toks, start_idx, tables_with_alias, schema, default_tables): 445 | idx = start_idx 446 | len_ = len(toks) 447 | val_units = [] 448 | order_type = 'asc' # default type is 'asc' 449 | 450 | if idx >= len_ or toks[idx] != 'order': 451 | return idx, val_units 452 | 453 | idx += 1 454 | assert toks[idx] == 'by' 455 | idx += 1 456 | 457 | while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")): 458 | idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables) 459 | val_units.append(val_unit) 460 | if idx < len_ and toks[idx] in ORDER_OPS: 461 | order_type = toks[idx] 462 | idx += 1 463 | if idx < len_ and toks[idx] == ',': 464 | idx += 1 # skip ',' 465 | else: 466 | break 467 | 468 | return idx, (order_type, val_units) 469 | 470 | 471 | def parse_having(toks, start_idx, tables_with_alias, schema, default_tables): 472 | idx = start_idx 473 | len_ = len(toks) 474 | 475 | if idx >= len_ or toks[idx] != 'having': 476 | return idx, [] 477 | 478 | idx += 1 479 | idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables) 480 | return idx, conds 481 | 482 | 483 | def parse_limit(toks, start_idx): 484 | idx = start_idx 485 | len_ = len(toks) 486 | 487 | if idx < len_ and toks[idx] == 'limit': 488 | idx += 2 489 | return idx, int(toks[idx-1]) 490 | 491 | return idx, None 492 | 493 | 494 | def parse_sql(toks, start_idx, tables_with_alias, schema): 495 | isBlock = False # indicate whether this is a block of sql/sub-sql 496 | len_ = len(toks) 497 | idx = start_idx 498 | 499 | sql = {} 500 | if toks[idx] == '(': 501 | isBlock = True 502 | idx += 1 503 | 504 | # parse from clause in order to get default tables 505 | from_end_idx, table_units, conds, default_tables = parse_from(toks, start_idx, tables_with_alias, schema) 506 | sql['from'] = {'table_units': table_units, 'conds': conds} 507 | # select clause 508 | _, select_col_units = parse_select(toks, idx, tables_with_alias, schema, default_tables) 509 | idx = from_end_idx 510 | sql['select'] = select_col_units 511 | # where clause 512 | idx, where_conds = parse_where(toks, idx, tables_with_alias, schema, default_tables) 513 | sql['where'] = where_conds 514 | # group by clause 515 | idx, group_col_units = parse_group_by(toks, idx, tables_with_alias, schema, default_tables) 516 | sql['groupBy'] = group_col_units 517 | # having clause 518 | idx, having_conds = parse_having(toks, idx, tables_with_alias, schema, default_tables) 519 | sql['having'] = having_conds 520 | # order by clause 521 | idx, order_col_units = parse_order_by(toks, idx, tables_with_alias, schema, default_tables) 522 | sql['orderBy'] = order_col_units 523 | # limit clause 524 | idx, limit_val = parse_limit(toks, idx) 525 | sql['limit'] = limit_val 526 | 527 | idx = skip_semicolon(toks, idx) 528 | if isBlock: 529 | assert toks[idx] == ')' 530 | idx += 1 # skip ')' 531 | idx = skip_semicolon(toks, idx) 532 | 533 | # intersect/union/except clause 534 | for op in SQL_OPS: # initialize IUE 535 | sql[op] = None 536 | if idx < len_ and toks[idx] in SQL_OPS: 537 | sql_op = toks[idx] 538 | idx += 1 539 | idx, IUE_sql = parse_sql(toks, idx, tables_with_alias, schema) 540 | sql[sql_op] = IUE_sql 541 | return idx, sql 542 | 543 | 544 | def load_data(fpath): 545 | with open(fpath) as f: 546 | data = json.load(f) 547 | return data 548 | 549 | 550 | def get_sql(schema, query): 551 | toks = tokenize(query) 552 | tables_with_alias = get_tables_with_alias(schema.schema, toks) 553 | _, sql = parse_sql(toks, 0, tables_with_alias, schema) 554 | 555 | return sql 556 | 557 | 558 | def skip_semicolon(toks, start_idx): 559 | idx = start_idx 560 | while idx < len(toks) and toks[idx] == ";": 561 | idx += 1 562 | return idx 563 | -------------------------------------------------------------------------------- /prompts/prompt_for_spider.json: -------------------------------------------------------------------------------- 1 | { 2 | "chat_v1": { 3 | "system": "You are a helpful assistant.", 4 | "free_generate": "### Here are the Sqlite SQL tables, with their properties:\n#\n{table}\n#\n### {question} Which tables do you need to complete the SQLite SQL query? Let's think step by step.", 5 | "table_column_select_reorganize": "Sure! Based on your response, provide only the required tables in the following format: 'Table_Name_1 | Table_Name_2' to simplify your response, which means using '|' to separate different tables. Don't use any other format or add any additional explanations.", 6 | "ask_final_answers": { 7 | "has_fk": "### Complete sqlite SQL query only and with no explanation.\n#\n### Sqlite SQL tables, with their properties: \n#\n{table}\n# {fk}\n#\n### {question}\n SELECT", 8 | "no_fk": "### Complete sqlite SQL query only and with no explanation.\n#\n### Sqlite SQL tables, with their properties: \n#\n{table}\n#\n### {question}\n SELECT" 9 | } 10 | } 11 | } -------------------------------------------------------------------------------- /prompts/prompt_for_tabfact.json: -------------------------------------------------------------------------------- 1 | { 2 | "chat_v1": { 3 | "system": "You are a helpful assistant.", 4 | "columns_select": "You need to answer a question using a table with multiple rows and column headings. Specifically, you need to select the relevant columns and rows from the table and then obtain a sub-table that is relevant to the question. Then, use the sub-table to answer the question. \nTherefore, to answer \"{question}\", first look at the available columns in the table: {columns}. Which columns are most relevant to answering the question? Your output format is only the “Columns: ColumnName1, ColumnName2, ColumnName3...” form, no other form, with no explanation.", 5 | "rows_select": "You need to answer a question using a table with multiple rows and column headings. Specifically, you need to select the relevant columns and rows from the table and then obtain a sub-table that is relevant to the question. Then, use the sub-table to answer the question. \nTherefore, to answer \"{question}\", once you have selected the relevant columns {selected_columns}, your next step is to identify the necessary rows to answer the question based on their column values. Below is the list of rows in the table, arranged in order. Each row is represented by one line, starting with the row name followed by its corresponding column name and value pairs:\n{rows}\nTo answer \"{question}\", which rows should be considered? Your output format is only the “Rows: RowName1, RowName2, RowName3...” form, no other form, with no explanation. Your response should only contain the row names from the above candidates such as item 1, item 2.", 6 | "ask_final_answer_or_next_question": "The ordered list below shows the rows of a table, with each line displaying a different row along with its corresponding column name and value pairs. The format for each pair is (column name, value). The table contains:\n{table}\nAccording to the table, is the statement \"{question}\" ture? If you think the statement is true, only output \"Yes.\". If you think the statement is not ture, only output \"No.\"." 7 | } 8 | } -------------------------------------------------------------------------------- /prompts/prompt_for_webqsp.json: -------------------------------------------------------------------------------- 1 | { 2 | "chat_v1": { 3 | "init_relation_rerank": "The candidate relations: {relations}.\nThe question is \"{question}\" and you'll start with \"{tpe}\". To answer this question, typically you would need to identify some relations that correspond to the meaning of the question. Therefore, select one relation from the candidate relations above that can be used to answer the question. Provide only one relevant relation that's present in the candidates, and begin your response with \"The relevant relation: \".", 4 | "constraints_flag": "The question is \"{question}\" and you'll start with \"{tpe}\". To answer this question, typically you would need to identify some relations that correspond to the meaning of the question. The already selected relevant relation {selected_relations}, and there are many candidate entities along these relations for the next step. If you think you can narrow down the current candidate entities using hints in the question, respond \"Yes\". If there are no hits in the question to narrow down the current candidate entities, respond \"No\".", 5 | "choose_constraints": "One constraint contains a relation and an entity. The list below shows the candidate relations and their possible entities. Each line contains one candidate relation and its corresponding entities:\n{relation_tails}\nThe question is \"{question}\" and you'll start with \"{tpe}\". Are there any constraint hints in the question that can narrow down the candidate entities?\nIf you think there are hints in the question that can narrow down the candidate entities, provide all possible constraints by combining the candidate relation and entity only using the above candidates. For each constraint, you should first choose one relation only from the above candidates, and then select one entity only from the corresponding entity candidates list of that relation. Do not modify the surface form of the chosen relation and entity, which means that they should be exactly consistent with the above provided. Use the format \"[relation: entity]\" and begin your response with \"The possible constraints: \".\nIf you think there are not any hints, directly respond \"No\".\"", 6 | "ask_final_answer_or_next_question": "The triples are: {facts}.\nBased on these triples, if you believe you have gathered sufficient information to answer \"{question}\", give me the final answer entity and start your response with \"The final answers:\". You just need to provide only one answer entity. If you think you still do not have enough information to answer the question, respond \"Need further information\".", 7 | "relation_rerank": "The candidate relations: {relations}.\nThe question is \"{question}\" and you'll start with \"{tpe}\". To answer this question, typically you would need to identify some relations that correspond to the meaning of the question. The already selected relevant relation {selected_relations}, then select the next relation from the candidate relations above that can be used to answer the question. Provide only one relevant relation that's present in the candidates, and begin your response with \"The relevant relation: \".", 8 | "final_query_template": "According to existing information, {question} Please start your response with \"The final answers:\". You just need to provide only one answer entity." 9 | } 10 | } -------------------------------------------------------------------------------- /prompts/prompt_for_wikisql.json: -------------------------------------------------------------------------------- 1 | { 2 | "chat_v1": { 3 | "system": "You are a helpful assistant.", 4 | "columns_select": "You need to answer a question using a table with multiple rows and column headings. Specifically, you need to select the relevant columns and rows from the table and then obtain a sub-table that is relevant to the question. Then, use the sub-table to answer the question. \nTherefore, to answer \"{question}\", first look at the available columns in the table: {columns}. Which columns are most relevant to answering the question? Your output format is only the “Columns: ColumnName1, ColumnName2, ColumnName3...” form, no other form, with no explanation.", 5 | "rows_select": "You need to answer a question using a table with multiple rows and column headings. Specifically, you need to select the relevant columns and rows from the table and then obtain a sub-table that is relevant to the question. Then, use the sub-table to answer the question. \nTherefore, to answer \"{question}\", once you have selected the relevant columns {selected_columns}, your next step is to identify the necessary rows to answer the question based on their column values. Below is the list of rows in the table, arranged in order. Each row is represented by one line, starting with the row name followed by its corresponding column name and value pairs:\n{rows}\nTo answer \"{question}\", which rows should be considered? Your output format is only the “Rows: RowName1, RowName2, RowName3...” form, no other form, with no explanation. Your response should only contain the row names from the above candidates such as item 1, item 2.", 6 | "ask_final_answer_or_next_question": "The ordered list below shows the rows of a table, with each line displaying a different row along with its corresponding column name and value pairs. The format for each pair is (column name, value). The table contains:\n{table}\nUsing this information, {question} Your output format is only “Answers: AnswerName1, AnswerName2...” form, no other form. And the output should be the number or entity names, as short as possible, without any explanation." 7 | } 8 | } -------------------------------------------------------------------------------- /prompts/prompt_for_wtq.json: -------------------------------------------------------------------------------- 1 | { 2 | "chat_v1": { 3 | "system": "You are a helpful assistant.", 4 | "columns_select": "You need to answer a question using a table with multiple rows and column headings. Specifically, you need to select the relevant columns and rows from the table and then obtain a sub-table that is relevant to the question. Then, use the sub-table to answer the question. \nTherefore, to answer \"{question}\", first look at the available columns in the table: {columns}. Which columns are most relevant to answering the question? Your output format is only the “Columns: ColumnName1, ColumnName2, ColumnName3...” form, no other form, with no explanation.", 5 | "rows_select": "You need to answer a question using a table with multiple rows and column headings. Specifically, you need to select the relevant columns and rows from the table and then obtain a sub-table that is relevant to the question. Then, use the sub-table to answer the question. \nTherefore, to answer \"{question}\", once you have selected the relevant columns {selected_columns}, your next step is to identify the necessary rows to answer the question based on their column values. Below is the list of rows in the table, arranged in order. Each row is represented by one line, starting with the row name followed by its corresponding column name and value pairs:\n{rows}\nTo answer \"{question}\", which rows should be considered? Your output format is only the “Rows: RowName1, RowName2, RowName3...” form, no other form, with no explanation. Your response should only contain the row names from the above candidates such as item 1, item 2.", 6 | "ask_final_answer_or_next_question": "The ordered list below shows the rows of a table, with each line displaying a different row along with its corresponding column name and value pairs. The format for each pair is (column name, value). The table contains:\n{table}\nUsing this information, {question} Your output format is only “Answers: AnswerName1, AnswerName2...” form, no other form. And the output should be the number or entity names, as short as possible, without any explanation." 7 | } 8 | } -------------------------------------------------------------------------------- /scripts/eval_for_tabfact.sh: -------------------------------------------------------------------------------- 1 | python evaluate_for_tabfact.py --ori_path ./data/tabfact/tab_fact_test.json --inp_path ./outputs/tabfact/tabfact_test_output.jsonl -------------------------------------------------------------------------------- /scripts/eval_spider_pred.sh: -------------------------------------------------------------------------------- 1 | python evaluate_for_spider.py --path ./outputs/spider/output_wo_icl_v1.jsonl --db=data/spider/database --table=data/spider/tables.json --etype=exec -------------------------------------------------------------------------------- /scripts/run_spider_rea_wo_icl_v1.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python3 structgpt_for_text_to_sql.py \ 4 | --api_key ./api_key.txt --num_process 21 \ 5 | --prompt_path ./prompts/prompt_for_spider.json --prompt_name chat_v1 \ 6 | --input_path ./data/spider-realistic/spider-realistic.json \ 7 | --output_path ./outputs/spider-realistic/output_wo_icl_v1.jsonl \ 8 | --chat_log_path ./outputs/spider-realistic/chat_wo_icl_v1.txt \ 9 | --schema_path ./data/spider-realistic/tables.json 10 | 11 | # single process usage 12 | #python3 structgpt_for_text_to_sql.py \ 13 | #--api_key sk-?? --num_process 1 \ 14 | #--prompt_path ./prompts/prompt_for_spider.json --prompt_name chat_v1 \ 15 | #--input_path ./data/spider-realistic/spider-realistic.json \ 16 | #--output_path ./outputs/spider-realistic/output_wo_icl_v1.jsonl \ 17 | #--chat_log_path ./outputs/spider-realistic/chat_wo_icl_v1.txt \ 18 | #--schema_path ./data/spider-realistic/tables.json 19 | 20 | cat ./outputs/spider-realistic/output_wo_icl_v1.jsonl_* > ./outputs/spider-realistic/output_wo_icl_v1.jsonl 21 | rm ./outputs/spider-realistic/output_wo_icl_v1.jsonl_* 22 | cat ./outputs/spider-realistic/chat_wo_icl_v1.txt_* > ./outputs/spider-realistic/chat_wo_icl_v1.txt 23 | rm ./outputs/spider-realistic/chat_wo_icl_v1.txt_* 24 | 25 | python evaluate_for_spider.py --path ./outputs/spider-realistic/output_wo_icl_v1.jsonl --db data/spider/database --table data/spider-realistic/tables.json --etype exec -------------------------------------------------------------------------------- /scripts/run_spider_syn_wo_icl_v1.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python3 structgpt_for_text_to_sql.py \ 4 | --api_key ./api_key.txt --num_process 21 \ 5 | --prompt_path ./prompts/prompt_for_spider.json --prompt_name chat_v1 \ 6 | --input_path ./data/spider-syn/dev.json \ 7 | --output_path ./outputs/spider-syn/output_wo_icl_v1.jsonl \ 8 | --chat_log_path ./outputs/spider-syn/chat_wo_icl_v1.txt \ 9 | --schema_path ./data/spider-syn/tables.json 10 | 11 | # single process usage 12 | #python3 structgpt_for_text_to_sql.py \ 13 | #--api_key sk-?? --num_process 1 \ 14 | #--prompt_path ./prompts/prompt_for_spider.json --prompt_name chat_v1 \ 15 | #--input_path ./data/spider-syn/spider-syn.json \ 16 | #--output_path ./outputs/spider-syn/output_wo_icl_v1.jsonl \ 17 | #--chat_log_path ./outputs/spider-syn/chat_wo_icl_v1.txt \ 18 | #--schema_path ./data/spider-syn/tables.json 19 | 20 | cat ./outputs/spider-syn/output_wo_icl_v1.jsonl_* > ./outputs/spider-syn/output_wo_icl_v1.jsonl 21 | rm ./outputs/spider-syn/output_wo_icl_v1.jsonl_* 22 | cat ./outputs/spider-syn/chat_wo_icl_v1.txt_* > ./outputs/spider-syn/chat_wo_icl_v1.txt 23 | rm ./outputs/spider-syn/chat_wo_icl_v1.txt_* 24 | 25 | python evaluate_for_spider.py --path ./outputs/spider-syn/output_wo_icl_v1.jsonl --db data/spider/database --table data/spider-syn/tables.json --etype exec -------------------------------------------------------------------------------- /scripts/run_spider_wo_icl_v1.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python3 structgpt_for_text_to_sql.py \ 4 | --api_key ./api_key.txt --num_process 21 \ 5 | --prompt_path ./prompts/prompt_for_spider.json --prompt_name chat_v1 \ 6 | --input_path ./data/spider/dev.jsonl \ 7 | --output_path ./outputs/spider/output_wo_icl_v1.jsonl \ 8 | --chat_log_path ./outputs/spider/chat_wo_icl_v1.txt \ 9 | --db_path ./data/spider/all_tables_content.json \ 10 | --schema_path ./data/spider/tables.json 11 | 12 | # single process usage 13 | #python3 structgpt_for_text_to_sql.py \ 14 | #--api_key sk-?? --num_process 1 \ 15 | #--prompt_path ./prompts/prompt_for_spider.json --prompt_name chat_v1 \ 16 | #--input_path ./data/spider/dev.jsonl \ 17 | #--output_path ./outputs/spider/output_wo_icl_v1.jsonl \ 18 | #--chat_log_path ./outputs/spider/chat_wo_icl_v1.txt \ 19 | #--schema_path ./data/spider/tables.json 20 | 21 | cat ./outputs/spider/output_wo_icl_v1.jsonl_* > ./outputs/spider/output_wo_icl_v1.jsonl 22 | rm ./outputs/spider/output_wo_icl_v1.jsonl_* 23 | cat ./outputs/spider/chat_wo_icl_v1.txt_* > ./outputs/spider/chat_wo_icl_v1.txt 24 | rm ./outputs/spider/chat_wo_icl_v1.txt_* 25 | 26 | python evaluate_for_spider.py --path ./outputs/spider/output_wo_icl_v1.jsonl --db=data/spider/database --table=data/spider/tables.json --etype=exec -------------------------------------------------------------------------------- /scripts/run_tabfact_wo_icl_v1.sh: -------------------------------------------------------------------------------- 1 | python3 structgpt_for_tableqa.py \ 2 | --api_key ./api_key.txt --num_process 37 \ 3 | --prompt_path ./prompts/prompt_for_tabfact.json --prompt_name chat_v1 \ 4 | --input_path ./data/tabfact/tab_fact_test.json \ 5 | --output_path ./outputs/tabfact/output_wo_icl_v1.jsonl \ 6 | --chat_log_path ./outputs/tabfact/chat_wo_icl_v1.txt --max_tokens 350 7 | 8 | cat ./outputs/tabfact/output_wo_icl_v1.jsonl_* > ./outputs/tabfact/output_wo_icl_v1.jsonl 9 | rm ./outputs/tabfact/output_wo_icl_v1.jsonl_* 10 | cat ./outputs/tabfact/chat_wo_icl_v1.txt_* > ./outputs/tabfact/chat_wo_icl_v1.txt 11 | rm ./outputs/tabfact/chat_wo_icl_v1.txt_* 12 | 13 | python evaluate_for_tabfact.py --ori_path ./data/tabfact/tab_fact_test.json --inp_path ./outputs/tabfact/output_wo_icl_v1.jsonl -------------------------------------------------------------------------------- /scripts/run_webqsp_wo_icl_v1.sh: -------------------------------------------------------------------------------- 1 | python3 structgpt_for_webqsp.py \ 2 | --api_key ./api_key_for_kg.txt --num_process 14 \ 3 | --prompt_path ./prompts/prompt_for_webqsp.json --max_tokens 300 --prompt_name chat_v1 \ 4 | --kg_source_path ./data/webqsp/subgraph_2hop_triples.npy \ 5 | --ent_type_path ./data/webqsp/ent_type_ary.npy \ 6 | --ent2id_path ./data/webqsp/ent2id.pickle \ 7 | --rel2id_path ./data/webqsp/rel2id.pickle \ 8 | --ent2name_path ./data/webqsp/entity_name.pickle \ 9 | --max_triples_per_relation 60 \ 10 | --input_path ./data/webqsp/webqsp_simple_test.jsonl \ 11 | --output_path ./outputs/webqsp/output_wo_icl_v1.jsonl \ 12 | --chat_log_path ./outputs/webqsp/chat_wo_icl_v1.txt 13 | 14 | cat ./outputs/webqsp/output_wo_icl_v1.jsonl_* > ./outputs/webqsp/output_wo_icl_v1.jsonl 15 | rm ./outputs/webqsp/output_wo_icl_v1.jsonl_* 16 | cat ./outputs/webqsp/chat_wo_icl_v1.txt_* > ./outputs/webqsp/chat_wo_icl_v1.txt 17 | rm ./outputs/webqsp/chat_wo_icl_v1.txt_* -------------------------------------------------------------------------------- /scripts/run_wikisql_wo_icl_v1.sh: -------------------------------------------------------------------------------- 1 | python3 structgpt_for_tableqa.py \ 2 | --api_key ./api_key.txt --num_process 37 \ 3 | --prompt_path ./prompts/prompt_for_wikisql.json --prompt_name chat_v1 \ 4 | --input_path ./data/wikisql/wikisql_test.json \ 5 | --output_path ./outputs/wikisql/output_wo_icl_v1.jsonl \ 6 | --chat_log_path ./outputs/wikisql/chat_wo_icl_v1.txt --max_tokens 350 7 | 8 | cat ./outputs/wikisql/output_wo_icl_v1.jsonl_* > ./outputs/wikisql/output_wo_icl_v1.jsonl 9 | rm ./outputs/wikisql/output_wo_icl_v1.jsonl_* 10 | cat ./outputs/wikisql/chat_wo_icl_v1.txt_* > ./outputs/wikisql/chat_wo_icl_v1.txt 11 | rm ./outputs/wikisql/chat_wo_icl_v1.txt_* 12 | 13 | python evaluate_for_wikisql.py --ori_path ./data/wikisql/wikisql_test.json --inp_path ./outputs/wikisql/output_wo_icl_v1.jsonl -------------------------------------------------------------------------------- /scripts/run_wtq_wo_icl_v1.sh: -------------------------------------------------------------------------------- 1 | python3 structgpt_for_tableqa.py \ 2 | --api_key ./api_key.txt --num_process 37 \ 3 | --prompt_path ./prompts/prompt_for_wtq.json --prompt_name chat_v1 \ 4 | --input_path ./data/wtq/wikitq_test.json \ 5 | --output_path ./outputs/wtq/output_wo_icl_v1.jsonl \ 6 | --chat_log_path ./outputs/wtq/chat_wo_icl_v1.txt --max_tokens 350 7 | 8 | cat ./outputs/wtq/output_wo_icl_v1.jsonl_* > ./outputs/wtq/output_wo_icl_v1.jsonl 9 | rm ./outputs/wtq/output_wo_icl_v1.jsonl_* 10 | cat ./outputs/wtq/chat_wo_icl_v1.txt_* > ./outputs/wtq/chat_wo_icl_v1.txt 11 | rm ./outputs/wtq/chat_wo_icl_v1.txt_* 12 | 13 | python evaluate_for_wikisql.py --ori_path ./data/wtq/wikitq_test.json --inp_path ./outputs/wtq/output_wo_icl_v1.jsonl 14 | -------------------------------------------------------------------------------- /structgpt_for_tableqa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import os 5 | import pickle 6 | import re 7 | import multiprocessing as mp 8 | from tqdm import tqdm 9 | import time 10 | 11 | import openai 12 | 13 | 14 | class ChatGPT: 15 | def __init__(self, args, prompt_path, prompt_name, max_tokens): 16 | self.args = args 17 | self.history_messages = [] 18 | self.history_contents = [] 19 | self.max_tokens = max_tokens 20 | self.prompt = self.load_prompt_template(prompt_path, prompt_name) 21 | self.idx_mapping = {"0": "first", "1": "second", "2": "third", "3": "fourth", "4": "fifth", "5": "sixth", 22 | "6": "seventh", 23 | "7": "eighth", "8": "ninth", "9": "tenth"} 24 | 25 | def get_response_v1(self, input_text, turn_type): 26 | message = self.create_message_v1(input_text, turn_type) 27 | self.history_contents.append(message['content']) 28 | self.history_messages.append(message) 29 | message = self.query_API_to_get_message(self.history_messages) 30 | self.history_contents.append(message['content']) 31 | self.history_messages.append(message) 32 | response = message['content'] 33 | return response 34 | 35 | def create_message_v1(self, input_text, turn_type): 36 | if turn_type == "columns_select": 37 | template = self.prompt['columns_select'] 38 | columns, question = input_text 39 | # question = question.capitalize() 40 | input_text = template.format(question=question, columns=columns) 41 | elif turn_type == 'rows_select': 42 | template = self.prompt['rows_select'] 43 | selected_cols, rows, question = input_text 44 | # question = question.capitalize() 45 | input_text = template.format(selected_columns=selected_cols, rows=rows, question=question) 46 | elif turn_type == "ask_final_answer_or_next_question": 47 | question, serialized_table = input_text 48 | template = self.prompt['ask_final_answer_or_next_question'] 49 | input_text = template.format(table=serialized_table, question=question) 50 | else: 51 | raise NotImplementedError 52 | message = {'role': 'user', 'content': input_text} 53 | return message 54 | 55 | def query_API_to_get_message(self, messages): 56 | while True: 57 | try: 58 | res = openai.ChatCompletion.create( 59 | model="gpt-3.5-turbo", 60 | messages=messages, 61 | temperature=0, 62 | max_tokens=self.max_tokens, 63 | top_p=1, 64 | frequency_penalty=0, 65 | presence_penalty=0, 66 | ) 67 | return res['choices'][0]['message'] 68 | except openai.error.RateLimitError as e: 69 | err_mes = str(e) 70 | if "You exceeded your current quota" in err_mes: 71 | print("You exceeded your current quota: %s" % openai.api_key) 72 | print('openai.error.RateLimitError\nRetrying...') 73 | time.sleep(30) 74 | except openai.error.ServiceUnavailableError: 75 | print('openai.error.ServiceUnavailableError\nRetrying...') 76 | time.sleep(20) 77 | except openai.error.Timeout: 78 | print('openai.error.Timeout\nRetrying...') 79 | time.sleep(20) 80 | except openai.error.APIError: 81 | print('openai.error.APIError\nRetrying...') 82 | time.sleep(20) 83 | except openai.error.APIConnectionError: 84 | print('openai.error.APIConnectionError\nRetrying...') 85 | time.sleep(20) 86 | 87 | def parse_result(self, result, turn_type): 88 | content = result['content'].strip() 89 | if turn_type in ["initial", "question_template"]: 90 | if "should be" in content: 91 | content = content.split("should be")[1].strip() 92 | if content.startswith('"') and content.endswith('"'): 93 | content = content[1:-1] 94 | else: 95 | matchObj = re.search(r'"(.*?)"', content) 96 | if matchObj is not None: 97 | content = matchObj.group() 98 | content = content[1:-1] 99 | else: 100 | content = content.strip().strip('"') 101 | print("Not exactly parse, we directly use content: %s" % content) 102 | 103 | return content 104 | 105 | def reset_history(self): 106 | self.history_messages = [] 107 | self.history_contents = [] 108 | 109 | def reset_history_messages(self): 110 | self.history_messages = [] 111 | 112 | def load_prompt_template(self, prompt_path, prompt_name): 113 | if prompt_path.endswith(".json"): 114 | with open(prompt_path, "rb") as f: 115 | prompt = json.load(f) 116 | return prompt[prompt_name] 117 | 118 | 119 | class Retriever: 120 | def __init__(self, args): 121 | self.args = args 122 | 123 | def serialize_headers(self, headers): 124 | # headers = ['"' + header.replace("\n", " ") + '"' for header in headers] 125 | # if len(headers) == 0: 126 | # ser_hea = "" 127 | # elif len(headers) == 1: 128 | # ser_hea = headers[0] 129 | # elif len(headers) == 2: 130 | # ser_hea = headers[0] + " and " + headers[1] 131 | # else: 132 | # ser_hea = ", ".join(headers[0:-1]) + ", and " + headers[-1] 133 | headers = [header.replace("\n", " ") for header in headers] 134 | ser_hea = ", ".join(headers) 135 | return ser_hea 136 | 137 | def filter_table_with_col_name(self, table, selected_relations_list, selected_relations_str): 138 | new_table = dict() 139 | header = table['header'] 140 | rows = table['rows'] 141 | reserved_col_idx = [idx for idx, rel in enumerate(header) if rel.replace("\n", " ") in selected_relations_list 142 | or rel.replace("\n", " ").lower() in selected_relations_str.lower()] 143 | new_header = [header[idx] for idx in reserved_col_idx] 144 | new_rows = [[row[idx] for idx in reserved_col_idx] for row in rows] 145 | new_table["header"] = new_header 146 | new_table["rows"] = new_rows 147 | return new_table 148 | 149 | 150 | class Solver: 151 | def __init__(self, args): 152 | self.args = args 153 | self.LLM = ChatGPT(args=args, prompt_path=args.prompt_path, prompt_name=args.prompt_name, 154 | max_tokens=args.max_tokens) 155 | self.SLM = Retriever(args) 156 | self.max_serialization_tokens = args.max_llm_input_tokens 157 | self.selected_relations = [] 158 | 159 | def forward(self, question, table): 160 | self.LLM.reset_history() 161 | self.reset_history() 162 | 163 | iterative_step = 0 164 | 165 | # select 166 | table = self.normalize_table_header(table) 167 | header = table['header'] 168 | ser_hea = self.SLM.serialize_headers(header) 169 | if args.debug: 170 | print("Step-%d: ser_hea:%s" % (iterative_step, ser_hea)) 171 | 172 | llm_selected_cols = self.LLM.get_response_v1((ser_hea, question), "columns_select") 173 | self.LLM.reset_history_messages() 174 | if args.debug: 175 | print("Step-%d: llm_selected_cols:%s" % (iterative_step, llm_selected_cols)) 176 | 177 | selected_cols_list = self.parse_selected_cols(llm_selected_cols, header) 178 | if args.debug: 179 | print("Step-%d: selected_cols_list:%s" % (iterative_step, selected_cols_list)) 180 | 181 | filtered_table = self.SLM.filter_table_with_col_name(table, selected_cols_list, llm_selected_cols) 182 | if args.debug: 183 | print("Step-%d: filtered_table:%s" % (iterative_step, filtered_table)) 184 | 185 | candidate_rows = self.serialize_table(filtered_table) 186 | if args.debug: 187 | print("Step-%d: candidate_rows:%s" % (iterative_step, candidate_rows)) 188 | 189 | selected_columns = self.SLM.serialize_headers(selected_cols_list) 190 | choose_rows = self.LLM.get_response_v1((selected_columns, candidate_rows, question), "rows_select") 191 | self.LLM.reset_history_messages() 192 | try: 193 | choose_row_list = self.parse_row_idx(choose_rows) 194 | filtered_table = self.filter_table_with_rows_constraints(filtered_table, choose_row_list) 195 | except Exception as e: 196 | logging.exception(e) 197 | # print(candidate_rows) 198 | # print(choose_rows) 199 | if args.debug: 200 | print("Step-%d: filtered_table:%s" % (iterative_step, filtered_table)) 201 | 202 | serialized_table = self.serialize_table(filtered_table) 203 | if args.debug: 204 | print("Step-%d: serialized_table:%s" % (iterative_step, serialized_table)) 205 | 206 | final_answers = self.LLM.get_response_v1((question, serialized_table), 207 | "ask_final_answer_or_next_question") 208 | self.LLM.reset_history_messages() 209 | 210 | return final_answers, self.LLM.history_contents, self.log 211 | 212 | def is_end(self, response, iterative_step): 213 | if "no" in response.lower() or iterative_step > 8: 214 | return True 215 | else: 216 | return False 217 | 218 | def is_end_v1(self, response, iterative_step): 219 | if "final" in response.lower() or iterative_step > 3: 220 | return True 221 | elif "next" in response.lower(): 222 | return False 223 | else: 224 | return False 225 | 226 | def get_final_answers(self, history_responses, final_response): 227 | answer_tmp_1 = history_responses[-2] 228 | answer_tmp_2 = final_response 229 | return [answer_tmp_1, answer_tmp_2] 230 | 231 | def parse_result(self, response, parse_type): 232 | response = response.lower() 233 | if parse_type == "next_question": 234 | if "the next question:" in response: 235 | next_question = response.split("the next question:")[1].strip() 236 | elif ":" in response: 237 | next_question = response.split(":")[1].strip() 238 | else: 239 | next_question = response 240 | print("Not parse the next question exactly, directly use the response: ", response) 241 | return next_question 242 | elif parse_type == "final_answer": 243 | if 'yes' in response and 'no' in response: 244 | final_answer = 'unknown' 245 | elif 'yes' in response: 246 | final_answer = 'entailed' 247 | else: 248 | final_answer='refuted' 249 | 250 | return final_answer 251 | 252 | def reset_history(self): 253 | self.log = [] 254 | self.selected_relations = [] 255 | 256 | def serialize_table(self, table): 257 | header = table['header'] 258 | rows = table['rows'] 259 | lines = [] 260 | for idx, row in enumerate(rows): 261 | pairs = [] 262 | for rel, ent in zip(header, row): 263 | pair = "(" + rel + ", " + ent + ")" 264 | pairs.append(pair) 265 | 266 | line = 'item ' + str(idx + 1) + ': ' + "; ".join(pairs) 267 | lines.append(line) 268 | output = "\n".join(lines) 269 | return output 270 | 271 | def parse_selected_cols(self, llm_selected_cols, header): 272 | llm_selected_cols = [h for h in header if h.replace("\n", " ").lower() in llm_selected_cols.lower()] 273 | return llm_selected_cols 274 | 275 | def parse_row_idx(self, selected_rows): 276 | pattern = re.compile(r'(\d+)') 277 | m = pattern.finditer(selected_rows) 278 | m = [i.group() for i in m] 279 | selected_rows = [int(rid)-1 for rid in m] 280 | return selected_rows 281 | 282 | def filter_table_with_rows_constraints(self, table, row_constraints): 283 | new_table = dict() 284 | header = table['header'] 285 | rows = table['rows'] 286 | new_rows = [] 287 | for rid in row_constraints: 288 | if rid < len(rows): 289 | new_rows.append(rows[rid]) 290 | new_table["header"] = header 291 | new_table["rows"] = new_rows 292 | return new_table 293 | 294 | def normalize_table_header(self, table): 295 | header = table['header'] 296 | rows = table['rows'] 297 | new_table = {} 298 | new_header = [] 299 | for h in header: 300 | h = h.replace("\n", " ") 301 | new_header.append(h) 302 | new_table['header'] = new_header 303 | new_table['rows'] = rows 304 | return new_table 305 | 306 | def main(args, all_data, idx, api_key): 307 | import openai 308 | openai.api_key = api_key 309 | 310 | if idx == -1: 311 | output_path = args.output_path 312 | chat_log_path = args.chat_log_path 313 | else: 314 | idx = "0" + str(idx) if idx < 10 else str(idx) # 00 01 02 ... 29 315 | output_path = args.output_path + "_" + idx 316 | chat_log_path = args.chat_log_path + "_" + idx 317 | 318 | print("Start PID %d and save to %s" % (os.getpid(), output_path)) 319 | 320 | solver = Solver(args) 321 | 322 | count = 0 323 | with open(output_path, "w") as f: 324 | with open(chat_log_path, "w") as fclog: 325 | for sample in tqdm(all_data, total=len(all_data), desc="PID: %d" % os.getpid()): 326 | try: 327 | question = sample["statement"] if 'statement' in sample else sample['question'] 328 | question = question + "?" if not question.endswith("?") else question 329 | table = sample['table'] 330 | prediction, chat_history, record = solver.forward(question, table) 331 | except openai.error.InvalidRequestError as e: 332 | print(e) 333 | continue 334 | except Exception as e: 335 | logging.exception(e) 336 | continue 337 | if 'id' in sample.keys(): 338 | flag = str(sample['id']) 339 | else: 340 | flag = question 341 | 342 | try: 343 | chat = flag + "\n" + "\n******\n".join(chat_history) + "\nAnswers: " + str( 344 | sample['seq_out']) + "\n------------------------------------------\n" 345 | fclog.write(chat) 346 | except Exception as e: 347 | print(e) 348 | count += 1 349 | if count < 5: 350 | print(sample['seq_out']) 351 | print(prediction) 352 | print("---------------------") 353 | sample["Prediction"] = prediction 354 | f.write(json.dumps(sample) + "\n") 355 | 356 | 357 | def parse_args(): 358 | parser = argparse.ArgumentParser() 359 | parser.add_argument('--input_path', default=None) 360 | parser.add_argument('--output_path', default=None) 361 | parser.add_argument('--chat_log_path', default=None) 362 | parser.add_argument('--debug', action="store_true") 363 | parser.add_argument('--prompt_path') 364 | parser.add_argument('--prompt_name', default="chat", ) 365 | parser.add_argument('--overwrite', action="store_true") 366 | parser.add_argument('--num_process', default=1, type=int, help='the number of multi-process') 367 | parser.add_argument('--max_tokens', default=10, type=int, help='retrieve the topk score paths') 368 | parser.add_argument('--api_key', default="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", type=str) 369 | parser.add_argument('--max_llm_input_tokens', default=3400, type=int) 370 | 371 | args = parser.parse_args() 372 | 373 | print("Start querying the LLM.") 374 | return args 375 | 376 | 377 | if __name__ == '__main__': 378 | args = parse_args() 379 | if not args.api_key.startswith("sk-"): 380 | with open(args.api_key, "r") as f: 381 | all_keys = f.readlines() 382 | all_keys = [line.strip('\n') for line in all_keys] 383 | assert len(all_keys) == args.num_process, (len(all_keys), args.num_process) 384 | 385 | with open(args.input_path, "rb") as f: 386 | all_data = json.load(f) 387 | print("Totally %d test examples." % len(all_data)) 388 | 389 | # Test data that has not yet been predicted 390 | # if os.path.exists(args.output_path): 391 | # with open(args.output_path, "r") as f: 392 | # all_lines = f.readlines() 393 | # all_lines = [json.loads(line.strip("\n")) for line in all_lines] 394 | # already_id = [line['id'] for line in all_lines] 395 | # all_data = [data for data in all_data if data['id'] not in already_id] 396 | # print("There are %d test examples need to be processed." % len(all_data)) 397 | 398 | if args.num_process == 1: 399 | main(args, all_data, idx=-1, api_key=args.api_key) 400 | else: 401 | num_each_split = int(len(all_data) / args.num_process) 402 | p = mp.Pool(args.num_process) 403 | for idx in range(args.num_process): 404 | start = idx * num_each_split 405 | if idx == args.num_process - 1: 406 | end = max((idx + 1) * num_each_split, len(all_data)) 407 | else: 408 | end = (idx + 1) * num_each_split 409 | split_data = all_data[start:end] 410 | p.apply_async(main, args=(args, split_data, idx, all_keys[idx])) 411 | p.close() 412 | p.join() 413 | print("All of the child processes over!") 414 | -------------------------------------------------------------------------------- /structgpt_for_text_to_sql.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import os 5 | import random 6 | from collections import defaultdict 7 | import pandas as pd 8 | import multiprocessing as mp 9 | from tqdm import tqdm 10 | import time 11 | 12 | import openai 13 | 14 | random.seed(42) 15 | 16 | 17 | class ChatGPT: 18 | def __init__(self, args, prompt_path, prompt_name, max_tokens): 19 | self.args = args 20 | self.history_messages = [] 21 | self.history_contents = [] 22 | self.max_tokens = max_tokens 23 | self.prompt = self.load_prompt_template(prompt_path, prompt_name) 24 | self.idx_mapping = {"0": "first", "1": "second", "2": "third", "3": "fourth", "4": "fifth", "5": "sixth", 25 | "6": "seventh", 26 | "7": "eighth", "8": "ninth", "9": "tenth"} 27 | 28 | def get_response_v1(self, input_text, turn_type, max_output_len): 29 | message = self.create_message_v1(input_text, turn_type) 30 | self.history_messages.append(message) 31 | self.history_contents.append(message['content']) 32 | message = self.query_API_to_get_message(self.history_messages, max_output_len) 33 | self.history_messages.append(message) 34 | self.history_contents.append(message['content']) 35 | response = self.parse_result(message) 36 | 37 | return response 38 | 39 | def create_message_v1(self, input_text, turn_type): 40 | if turn_type == "select_tab": 41 | template = self.prompt['free_generate'] 42 | question, ser_table_name = input_text 43 | input_text = template.format(question=question, table=ser_table_name) 44 | elif turn_type == "reorg_sel_tab": 45 | template = self.prompt['table_column_select_reorganize'] 46 | input_text = template 47 | elif turn_type == "ask_final_answers": 48 | question, ser_table_name, ser_fks = input_text 49 | if len(ser_fks) > 1: 50 | template = self.prompt['ask_final_answers']['has_fk'] 51 | input_text = template.format(question=question, table=ser_table_name, fk=ser_fks) 52 | else: 53 | template = self.prompt['ask_final_answers']['no_fk'] 54 | input_text = template.format(question=question, table=ser_table_name) 55 | else: 56 | raise NotImplementedError 57 | message = {'role': 'user', 'content': input_text} 58 | return message 59 | 60 | def query_API_to_get_message(self, messages, max_output_len): 61 | while True: 62 | try: 63 | res = openai.ChatCompletion.create( 64 | model="gpt-3.5-turbo", 65 | messages=messages, 66 | temperature=0, 67 | max_tokens=max_output_len, 68 | top_p=1, 69 | frequency_penalty=0, 70 | presence_penalty=0, 71 | ) 72 | return res['choices'][0]['message'] 73 | except openai.error.RateLimitError as e: 74 | err_mes = str(e) 75 | if "You exceeded your current quota" in err_mes: 76 | print("You exceeded your current quota: %s" % openai.api_key) 77 | print('openai.error.RateLimitError\nRetrying...') 78 | time.sleep(30) 79 | except openai.error.ServiceUnavailableError: 80 | print('openai.error.ServiceUnavailableError\nRetrying...') 81 | time.sleep(20) 82 | except openai.error.Timeout: 83 | print('openai.error.Timeout\nRetrying...') 84 | time.sleep(20) 85 | except openai.error.APIError: 86 | print('openai.error.APIError\nRetrying...') 87 | time.sleep(20) 88 | except openai.error.APIConnectionError: 89 | print('openai.error.APIConnectionError\nRetrying...') 90 | time.sleep(20) 91 | except openai.error.InvalidRequestError as e: 92 | logging.exception(e) 93 | exit(0) 94 | 95 | 96 | def parse_result(self, result): 97 | content = result['content'].strip() 98 | 99 | return content 100 | 101 | def reset_history(self): 102 | self.history_messages = [] 103 | self.history_contents = [] 104 | 105 | def reset_history_messages(self): 106 | self.history_messages = [] 107 | 108 | def reseta_history_contents(self): 109 | self.history_contents = [] 110 | 111 | def load_prompt_template(self, prompt_path, prompt_name): 112 | if prompt_path.endswith(".json"): 113 | with open(prompt_path, "rb") as f: 114 | prompt = json.load(f) 115 | return prompt[prompt_name] 116 | 117 | 118 | class Retriever: 119 | def __init__(self, args): 120 | self.args = args 121 | self.prompt = self.load_prompt_template(args.prompt_path, args.prompt_name) 122 | self.creatiing_schema(args.schema_path) 123 | 124 | def filter_table_with_col_name(self, table, selected_relations_list, selected_relations_str): 125 | new_table = dict() 126 | header = table['header'] 127 | rows = table['rows'] 128 | reserved_col_idx = [idx for idx, h in enumerate(header) if h in selected_relations_list] 129 | new_header = [header[idx] for idx in reserved_col_idx] 130 | new_rows = [[row[idx] for idx in reserved_col_idx] for row in rows] 131 | new_table["header"] = new_header 132 | new_table["rows"] = new_rows 133 | return new_table 134 | 135 | def load_db(self, db_path): 136 | with open(db_path, "r") as f: 137 | db = json.load(f) 138 | return db 139 | 140 | def serialize_table_and_column(self, db_name): 141 | df = self.spider_schema[self.spider_schema['Database name'] == db_name] 142 | df = df.groupby('Table Name') 143 | table2columns = {} 144 | tables_name = [] 145 | for name, group in df: 146 | columns = [] 147 | for index, row in group.iterrows(): 148 | columns.append(row["Field Name"]) 149 | table2columns[name] = columns 150 | if name not in tables_name: 151 | tables_name.append(name) 152 | ser_heas = [] 153 | prompt_tmp = "# {tn}({cols});" 154 | for name, cols in table2columns.items(): 155 | cols = [col for col in cols] 156 | cols = ",".join(cols) 157 | hea = prompt_tmp.format(tn=name, cols=cols) 158 | ser_heas.append(hea) 159 | ser_hea = "\n".join(ser_heas) 160 | return ser_hea, table2columns 161 | 162 | def serialize_tab_and_col_of_demons(self, tables): 163 | prompt_tmp = " # {tn}({cols});" 164 | ser_heas = [] 165 | for name, cols in tables.items(): 166 | cols = [col for col in cols] 167 | cols = ",".join(cols) 168 | hea = prompt_tmp.format(tn=name, cols=cols) 169 | ser_heas.append(hea) 170 | ser_heas = "\n".join(ser_heas) 171 | return ser_heas 172 | 173 | def serialize_selected_table_name(self, sel_tab_cols): 174 | ser_heas = [] 175 | prompt_tmp = "# {tn}({cols});" 176 | for name, cols in sel_tab_cols.items(): 177 | cols = [col for col in cols] 178 | cols = ",".join(cols) 179 | hea = prompt_tmp.format(tn=name, cols=cols) 180 | ser_heas.append(hea) 181 | 182 | ser_hea = "\n".join(ser_heas) 183 | return ser_hea 184 | 185 | def creatiing_schema(self, schema_path): 186 | schema_df = pd.read_json(schema_path) 187 | schema_df = schema_df.drop(['column_names', 'table_names'], axis=1) 188 | schema = [] 189 | f_keys = [] 190 | p_keys = [] 191 | for index, row in schema_df.iterrows(): 192 | tables = row['table_names_original'] 193 | col_names = row['column_names_original'] 194 | col_types = row['column_types'] 195 | foreign_keys = row['foreign_keys'] 196 | primary_keys = row['primary_keys'] 197 | for col, col_type in zip(col_names, col_types): 198 | index, col_name = col 199 | if index == -1: 200 | for table in tables: 201 | schema.append([row['db_id'], table, '*', 'text']) 202 | else: 203 | schema.append([row['db_id'], tables[index], col_name, col_type]) 204 | for primary_key in primary_keys: 205 | index, column = col_names[primary_key] 206 | p_keys.append([row['db_id'], tables[index], column]) 207 | for foreign_key in foreign_keys: 208 | first, second = foreign_key 209 | first_index, first_column = col_names[first] 210 | second_index, second_column = col_names[second] 211 | f_keys.append([row['db_id'], tables[first_index], tables[second_index], first_column, second_column]) 212 | self.spider_schema = pd.DataFrame(schema, columns=['Database name', 'Table Name', 'Field Name', 'Type']) 213 | self.spider_primary = pd.DataFrame(p_keys, columns=['Database name', 'Table Name', 'Primary Key']) 214 | self.spider_foreign = pd.DataFrame(f_keys, 215 | columns=['Database name', 'First Table Name', 'Second Table Name', 216 | 'First Table Foreign Key', 217 | 'Second Table Foreign Key']) 218 | 219 | def find_primary_keys(self, db_name, table_name): 220 | df = self.spider_primary[self.spider_primary['Database name'] == db_name] 221 | for index, row in df.iterrows(): 222 | if row['Table Name'] == table_name: 223 | return row['Primary Key'] 224 | 225 | def find_foreign_keys(self, db_name): 226 | df = self.spider_foreign[self.spider_foreign['Database name'] == db_name] 227 | output = [] 228 | for index, row in df.iterrows(): 229 | first_tab_fk = (row['First Table Name'], row['First Table Foreign Key']) 230 | second_tab_fk = (row['Second Table Name'], row['Second Table Foreign Key']) 231 | output.append((first_tab_fk, second_tab_fk)) 232 | return output 233 | 234 | def serialize_fk_with_sel_table(self, db_id, table_cols): 235 | prompt_tmp = "{ftk_0}.{ftk_1}={stk_0}.{stk_1}" 236 | fk = self.find_foreign_keys(db_id) 237 | ser_fks = [] 238 | for (ftk, stk) in fk: 239 | if ftk[0] in table_cols and stk[0] in table_cols: 240 | ser_fk = prompt_tmp.format(ftk_0=ftk[0], ftk_1=ftk[1], stk_0=stk[0], stk_1=stk[1]) 241 | ser_fks.append(ser_fk) 242 | ser_fks = ",".join(ser_fks) 243 | ser_fks = ser_fks + ";" 244 | return ser_fks 245 | 246 | def serialize_demonstrations(self, demonstrations, example_ids): 247 | prompt = self.prompt["demonstration"] 248 | prompt_fk = "{ftk_0}.{ftk_1}={stk_0}.{stk_1}" 249 | all_sel_demons = [] 250 | for name, ids in example_ids.items(): 251 | demons = demonstrations[name] 252 | sel_demons = [demons[i] for i in ids] 253 | all_sel_demons.extend(sel_demons) 254 | ser_demons = [] 255 | for d in all_sel_demons: 256 | ser_fks = [] 257 | for (ftk, stk) in d['fks']: 258 | ser_fk = prompt_fk.format(ftk_0=ftk[0], ftk_1=ftk[1], stk_0=stk[0], stk_1=stk[1]) 259 | ser_fks.append(ser_fk) 260 | ser_fks = ",".join(ser_fks) 261 | ser_fks = ser_fks + ";" 262 | 263 | tables = d['tables'] 264 | question = d['question'] 265 | query = d['query'] 266 | ser_tables = self.serialize_tab_and_col_of_demons(tables) 267 | 268 | if len(d['fks']) == 0: 269 | p = prompt['no_fk'] 270 | one_demo = p.format(table=ser_tables, question=question, sql=query) 271 | else: 272 | p = prompt['has_fk'] 273 | one_demo = p.format(table=ser_tables, question=question, sql=query, fk=ser_fks) 274 | ser_demons.append(one_demo) 275 | random.shuffle(ser_demons) 276 | ser_demons = "\n #\n".join(ser_demons) 277 | return ser_demons 278 | 279 | def serialize_demonstrations_with_const(self, demonstrations, example_ids, num_tables, fk_flag): 280 | num_tables = 1 if num_tables == 1 else 2 281 | prompt = self.prompt["demonstration"] 282 | prompt_fk = "{ftk_0}.{ftk_1}={stk_0}.{stk_1}" 283 | all_sel_demons = [] 284 | for name, ids in example_ids.items(): 285 | demons = demonstrations[name] 286 | sel_demons = [demons[i] for i in ids] 287 | all_sel_demons.extend(sel_demons) 288 | ser_demons = [] 289 | for d in all_sel_demons: 290 | ser_fks = [] 291 | for (ftk, stk) in d['fks']: 292 | ser_fk = prompt_fk.format(ftk_0=ftk[0], ftk_1=ftk[1], stk_0=stk[0], stk_1=stk[1]) 293 | ser_fks.append(ser_fk) 294 | ser_fks = ",".join(ser_fks) 295 | ser_fks = ser_fks + ";" 296 | 297 | tables = d['tables'] 298 | num_tables_demon = 1 if len(tables) == 1 else 2 299 | 300 | if num_tables_demon != num_tables: 301 | continue 302 | 303 | question = d['question'] 304 | query = d['query'] 305 | ser_tables = self.serialize_table_name_v6(tables) 306 | 307 | if len(d['fks']) == 0: 308 | p = prompt['no_fk'] 309 | one_demo = p.format(table=ser_tables, question=question, sql=query) 310 | else: 311 | p = prompt['has_fk'] 312 | one_demo = p.format(table=ser_tables, question=question, sql=query, fk=ser_fks) 313 | ser_demons.append(one_demo) 314 | if len(ser_demons) == 0: 315 | print("-------------------No demonstrations-------------------") 316 | random.shuffle(ser_demons) 317 | ser_demons = "\n\n".join(ser_demons) 318 | return ser_demons 319 | 320 | def load_prompt_template(self, prompt_path, prompt_name): 321 | if prompt_path.endswith(".json"): 322 | with open(prompt_path, "rb") as f: 323 | prompt = json.load(f) 324 | return prompt[prompt_name] 325 | 326 | 327 | class Solver: 328 | def __init__(self, args): 329 | self.args = args 330 | self.LLM = ChatGPT(args=args, prompt_path=args.prompt_path, prompt_name=args.prompt_name, 331 | max_tokens=args.max_tokens) 332 | self.SLM = Retriever(args) 333 | self.max_serialization_tokens = args.max_llm_input_tokens 334 | self.log = [] 335 | self.selected_relations = [] 336 | 337 | def forward_wo_icl_v1(self, question, db_id): 338 | self.LLM.reset_history() 339 | self.reset_history() 340 | 341 | iterative_step = 0 342 | 343 | ser_tab_col, table2columns = self.SLM.serialize_table_and_column(db_id) 344 | if args.debug: 345 | print("-----Step-%d: ser_table_name:\n%s" % (iterative_step, ser_tab_col)) 346 | 347 | llm_select_tab = self.LLM.get_response_v1((question, ser_tab_col), "select_tab", max_output_len=600) 348 | if args.debug: 349 | print("-----Step-%d: llm_select_tab:\n%s" % (iterative_step, llm_select_tab)) 350 | 351 | llm_reorg_sel_tab = self.LLM.get_response_v1("", "reorg_sel_tab", max_output_len=300) 352 | if args.debug: 353 | print("-----Step-%d: llm_reorg_sel_tab:\n%s" % (iterative_step, llm_reorg_sel_tab)) 354 | self.LLM.reset_history_messages() 355 | 356 | sel_tab_cols = self.parse_sele_tab(llm_reorg_sel_tab, table2columns) 357 | if args.debug: 358 | print("-----Step-%d: sel_tab_cols:\n%s" % (iterative_step, sel_tab_cols)) 359 | 360 | ser_table_name = self.SLM.serialize_selected_table_name(sel_tab_cols) 361 | ser_fk = self.SLM.serialize_fk_with_sel_table(db_id, sel_tab_cols) 362 | if args.debug: 363 | print("-----Step-%d: ser_table_name:\n%s" % (iterative_step, ser_table_name)) 364 | print("-----Step-%d: ser_fk:\n%s" % (iterative_step, ser_fk)) 365 | 366 | final_answers = self.LLM.get_response_v1((question, ser_table_name, ser_fk), "ask_final_answers", 367 | max_output_len=300) 368 | if args.debug: 369 | print("-----Step-%d: final_answers:\n%s" % (iterative_step, final_answers)) 370 | self.LLM.reset_history_messages() 371 | 372 | return final_answers, self.LLM.history_contents 373 | 374 | def reset_history(self): 375 | self.log = [] 376 | self.selected_relations = [] 377 | 378 | def serialize_table(self, db_name, table_name): 379 | # get primary key 380 | pk = self.SLM.find_primary_keys(db_name, table_name) 381 | # get table content 382 | table = self.SLM.db_content[db_name][table_name] 383 | # serialize 384 | header = table['headers'] 385 | rows = table['rows'] 386 | lines = [] 387 | for idx, row in enumerate(rows): 388 | pairs = [] 389 | row_name = "" 390 | for rel, ent in zip(header, row): 391 | if rel == pk: 392 | row_name = pk + " " + str(ent) + ": " 393 | else: 394 | pair = "(" + rel + ", " + str(ent) + ")" 395 | pairs.append(pair) 396 | line = row_name + "; ".join(pairs) 397 | lines.append(line) 398 | output = "\n".join(lines) 399 | return output 400 | 401 | def serialize_table_with_constraints(self, db_name, table_name, constraints_cols): 402 | # get primary key 403 | pk = self.SLM.find_primary_keys(db_name, table_name) 404 | # get table content 405 | table = self.SLM.db_content[db_name][table_name] 406 | # serialize 407 | header = table['headers'] 408 | rows = table['rows'] 409 | lines = [] 410 | if len(constraints_cols) == 1 and pk in constraints_cols: 411 | constraints_cols.append(random.sample(list(set(header) - set(constraints_cols)), k=1)[0]) 412 | for idx, row in enumerate(rows): 413 | pairs = [] 414 | row_name = "" 415 | for rel, ent in zip(header, row): 416 | if rel == pk: 417 | row_name = pk + " " + str(ent) + ": " 418 | else: 419 | if rel in constraints_cols: 420 | pair = "(" + rel + ", " + str(ent) + ")" 421 | pairs.append(pair) 422 | line = row_name + "; ".join(pairs) 423 | lines.append(line) 424 | output = "\n".join(lines) 425 | return output 426 | 427 | def parse_sele_tab(self, llm_selected_table_col, table2columns): 428 | sel_table_to_cols = defaultdict(list) 429 | try: 430 | tabs = llm_selected_table_col.strip(" \n|.;`").strip() 431 | tabs = tabs.split("|") 432 | tabs = [tab.strip(" \n|.;`") for tab in tabs] 433 | for tab in tabs: 434 | if tab in table2columns: 435 | sel_table_to_cols[tab] = table2columns[tab] 436 | except Exception as e: 437 | logging.exception(e) 438 | for tab, cols in table2columns.items(): 439 | if tab in llm_selected_table_col: 440 | sel_table_to_cols[tab] = table2columns[tab] 441 | print("*****LLM selected tables output doesn't match the predefined format:\n%s" % llm_selected_table_col) 442 | return sel_table_to_cols 443 | 444 | 445 | def main(args, all_data, idx, api_key): 446 | import openai 447 | openai.api_key = api_key 448 | 449 | if idx == -1: 450 | output_path = args.output_path 451 | chat_log_path = args.chat_log_path 452 | else: 453 | idx = "0" + str(idx) if idx < 10 else str(idx) # 00 01 02 ... 29 454 | output_path = args.output_path + "_" + idx 455 | chat_log_path = args.chat_log_path + "_" + idx 456 | 457 | print("Start PID %d and save to %s" % (os.getpid(), output_path)) 458 | solver = Solver(args) 459 | 460 | count = 0 461 | valid_count = 0 462 | with open(output_path, "w") as f: 463 | with open(chat_log_path, "w") as fclog: 464 | for sample in tqdm(all_data, total=len(all_data), desc="PID: %d" % os.getpid()): 465 | try: 466 | if "question" in sample: 467 | question = sample["question"] 468 | elif "SpiderSynQuestion" in sample: 469 | question = sample["SpiderSynQuestion"] 470 | else: 471 | print("Specify an error question key.") 472 | print(sample) 473 | exit(0) 474 | db_id = sample['db_id'] 475 | 476 | if not args.icl: 477 | results = solver.forward_wo_icl_v1(question, db_id) 478 | 479 | if results is not None: 480 | prediction, chat_history = results 481 | valid_count += 1 482 | else: 483 | continue 484 | except Exception as e: 485 | logging.exception(e) 486 | continue 487 | 488 | if 'id' in sample: 489 | flag = sample['id'] 490 | else: 491 | flag = question 492 | chat = flag + "\n" + "\n******\n".join(chat_history) + \ 493 | "\nGold SQL: " + str(sample['query']) + "\n------------------------------------------\n" 494 | fclog.write(chat) 495 | 496 | count += 1 497 | if count < 3: 498 | print(prediction) 499 | print("---------------------") 500 | sample["Prediction"] = prediction 501 | f.write(json.dumps(sample) + "\n") 502 | print("---------------PID %d end with %d/%d samples--------------" % (os.getpid(), valid_count, count)) 503 | 504 | 505 | def parse_args(): 506 | parser = argparse.ArgumentParser() 507 | parser.add_argument('--input_path', default=None) 508 | parser.add_argument('--output_path', default=None) 509 | parser.add_argument('--chat_log_path', default=None) 510 | parser.add_argument('--schema_path', default=None) 511 | parser.add_argument('--debug', action="store_true") 512 | parser.add_argument('--prompt_path') 513 | parser.add_argument('--prompt_name', default="chat", ) 514 | parser.add_argument('--icl', action="store_true", ) 515 | parser.add_argument('--demonstrations', default=None) 516 | parser.add_argument('--example_ids', default=None) 517 | parser.add_argument('--overwrite', action="store_true") 518 | parser.add_argument('--num_process', default=1, type=int, help='the number of multi-process') 519 | parser.add_argument('--max_tokens', default=10, type=int, help='retrieve the topk score paths') 520 | parser.add_argument('--api_key', default="", type=str) 521 | parser.add_argument('--max_llm_input_tokens', default=3400, type=int) 522 | 523 | args = parser.parse_args() 524 | 525 | print("Start querying the LLM.") 526 | return args 527 | 528 | 529 | if __name__ == '__main__': 530 | args = parse_args() 531 | if not args.api_key.startswith("sk-"): 532 | with open(args.api_key, "r") as f: 533 | all_keys = f.readlines() 534 | all_keys = [line.strip('\n') for line in all_keys] 535 | assert len(all_keys) == args.num_process, (len(all_keys), args.num_process) 536 | 537 | if args.input_path.endswith('jsonl'): 538 | with open(args.input_path, "r") as f: 539 | all_lines = f.readlines() 540 | all_data = [json.loads(line) for line in all_lines] 541 | print("Totally %d test examples." % len(all_data)) 542 | elif args.input_path.endswith('json'): 543 | with open(args.input_path, "r") as f: 544 | all_data = json.load(f) 545 | print("Totally %d test examples." % len(all_data)) 546 | 547 | if args.demonstrations is not None: 548 | with open(args.demonstrations, "r") as f: 549 | demonstrations = json.load(f) 550 | if args.example_ids is not None: 551 | with open(args.example_ids, "r") as f: 552 | example_ids = json.load(f) 553 | 554 | if args.num_process == 1: 555 | main(args, all_data, idx=-1, api_key=args.api_key) 556 | else: 557 | num_each_split = int(len(all_data) / args.num_process) 558 | p = mp.Pool(args.num_process) 559 | for idx in range(args.num_process): 560 | start = idx * num_each_split 561 | if idx == args.num_process - 1: 562 | end = max((idx + 1) * num_each_split, len(all_data)) 563 | else: 564 | end = (idx + 1) * num_each_split 565 | split_data = all_data[start:end] 566 | try: 567 | p.apply_async(main, args=(args, split_data, idx, all_keys[idx])) 568 | except Exception as e: 569 | logging.exception(e) 570 | p.close() 571 | p.join() 572 | print("All of the child processes over!") 573 | -------------------------------------------------------------------------------- /structgpt_for_webqsp.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import os 5 | import pickle 6 | import re 7 | 8 | import openai 9 | from KnowledgeBase.KG_api import KnowledgeGraph 10 | # from KnowledgeBase.sparql_executor import * 11 | import multiprocessing as mp 12 | 13 | 14 | class ChatGPT: 15 | def __init__(self, args, prompt_path, prompt_name, max_tokens): 16 | self.args = args 17 | self.history_messages = [] 18 | self.history_contents = [] 19 | self.max_tokens = max_tokens 20 | self.prompt = self.load_prompt_template(prompt_path, prompt_name) 21 | self.idx_mapping = {"0": "first", "1": "second", "2": "third", "3": "fourth", "4": "fifth", "5": "sixth", 22 | "6": "seventh", 23 | "7": "eighth", "8": "ninth", "9": "tenth"} 24 | 25 | def get_response(self, input_text, turn_type, tpe_name=None): 26 | if self.args.debug: 27 | message = self.create_message(input_text, turn_type, tpe_name) 28 | self.history_messages.append(message) 29 | self.history_contents.append(message['content']) 30 | print("query API to get message:\n%s" % message['content']) 31 | # message = self.query_API_to_get_message(self.history) 32 | # self.history.append(message) 33 | # response = self.parse_result(message) 34 | response = input("input the returned response:") 35 | else: 36 | message = self.create_message(input_text, turn_type, tpe_name) 37 | self.history_messages.append(message) 38 | self.history_contents.append(message['content']) 39 | message = self.query_API_to_get_message(self.history_messages) 40 | self.history_messages.append(message) 41 | self.history_contents.append(message['content']) 42 | response = self.parse_result(message, turn_type) 43 | return response 44 | 45 | def get_response_v1(self, input_text, turn_type, tpe_name=None): 46 | if self.args.debug: 47 | message = self.create_message_v1(input_text, turn_type) 48 | self.history_messages.append(message) 49 | self.history_contents.append(message['content']) 50 | print("query API to get message:\n%s" % message['content']) 51 | # message = self.query_API_to_get_message(self.history) 52 | # self.history.append(message) 53 | # response = self.parse_result(message) 54 | response = input("input the returned response:") 55 | else: 56 | message = self.create_message_v1(input_text, turn_type) 57 | self.history_messages.append(message) 58 | self.history_contents.append(message['content']) 59 | message = self.query_API_to_get_message(self.history_messages) 60 | self.history_messages.append(message) 61 | self.history_contents.append(message['content']) 62 | response = self.parse_result_v1(message, turn_type) 63 | return response 64 | 65 | def create_message(self, input_text, turn_type, tpe_name): 66 | if turn_type == "initial": # the initial query 67 | instruction = self.prompt[turn_type]['instruction'] 68 | template = self.prompt[turn_type]['init_template'] 69 | self.question = input_text 70 | input_text = instruction + template.format(question=input_text, tpe=tpe_name) 71 | elif turn_type == "continue_template": 72 | input_text = self.prompt[turn_type] 73 | elif turn_type == "question_template": 74 | template = self.prompt[turn_type] 75 | input_text = template.format(idx=self.idx_mapping[input_text]) 76 | elif turn_type == "answer_template": 77 | template = self.prompt[turn_type] 78 | if len(input_text) > 0: 79 | input_text = template["valid"].format(facts=input_text) 80 | else: 81 | input_text = template["invalid"] 82 | elif turn_type == "final_query_template": 83 | template = self.prompt[turn_type] 84 | input_text = template.format(question=self.question) 85 | else: 86 | raise NotImplementedError 87 | message = {'role': 'user', 'content': input_text} 88 | return message 89 | 90 | def create_message_v1(self, input_text, turn_type): 91 | if turn_type == "instruction": # the initial query 92 | instruction = self.prompt['instruction'] 93 | input_text = instruction 94 | elif turn_type == "init_relation_rerank": 95 | template = self.prompt['init_relation_rerank'] 96 | question, tpe, can_rels = input_text 97 | input_text = template.format(question=question, tpe=tpe, relations=can_rels) 98 | elif turn_type == "ask_question": 99 | template = self.prompt['ask_question'] 100 | idx, relations = input_text 101 | idx = self.idx_mapping[idx] 102 | input_text = template.format(idx=idx, relations=relations) 103 | elif turn_type == "ask_answer": 104 | facts = input_text 105 | template = self.prompt['ask_answer'] 106 | input_text = template.format(facts=facts) 107 | elif turn_type == "ask_final_answer_or_next_question": 108 | question, serialized_facts = input_text 109 | template = self.prompt['ask_final_answer_or_next_question'] 110 | input_text = template.format(facts=serialized_facts, question=question) 111 | elif turn_type == "condition": 112 | input_text = self.prompt['continue_template']['condition'] 113 | elif turn_type == "continue": 114 | input_text = self.prompt['continue_template']['continue'] 115 | elif turn_type == "stop": 116 | input_text = self.prompt['continue_template']['stop'] 117 | elif turn_type == 'relation_rerank': 118 | template = self.prompt['relation_rerank'] 119 | question, can_rels = input_text 120 | input_text = template.format(question=question, relations=can_rels) 121 | else: 122 | raise NotImplementedError 123 | message = {'role': 'user', 'content': input_text} 124 | return message 125 | 126 | def query_API_to_get_message(self, messages): 127 | while True: 128 | try: 129 | res = openai.ChatCompletion.create( 130 | model="gpt-3.5-turbo", 131 | messages=messages, 132 | temperature=0, 133 | max_tokens=self.max_tokens, 134 | top_p=1, 135 | frequency_penalty=0, 136 | presence_penalty=0, 137 | ) 138 | return res['choices'][0]['message'] 139 | except openai.error.RateLimitError: 140 | print('openai.error.RateLimitError\nRetrying...') 141 | time.sleep(30) 142 | except openai.error.ServiceUnavailableError: 143 | print('openai.error.ServiceUnavailableError\nRetrying...') 144 | time.sleep(20) 145 | except openai.error.Timeout: 146 | print('openai.error.Timeout\nRetrying...') 147 | time.sleep(20) 148 | except openai.error.APIError: 149 | print('openai.error.APIError\nRetrying...') 150 | time.sleep(20) 151 | except openai.error.APIConnectionError: 152 | print('openai.error.APIConnectionError\nRetrying...') 153 | time.sleep(20) 154 | # except openai.error.InvalidRequestError: 155 | # print('openai.error.InvalidRequestError\nRetrying...') 156 | 157 | def parse_result(self, result, turn_type): 158 | content = result['content'].strip() 159 | if turn_type in ["initial", "question_template"]: 160 | if "should be" in content: 161 | content = content.split("should be")[1].strip() 162 | if content.startswith('"') and content.endswith('"'): 163 | content = content[1:-1] 164 | else: 165 | matchObj = re.search(r'"(.*?)"', content) 166 | if matchObj is not None: 167 | content = matchObj.group() 168 | content = content[1:-1] 169 | else: 170 | content = content.strip().strip('"') 171 | print("Not exactly parse, we directly use content: %s" % content) 172 | 173 | return content 174 | 175 | def parse_result_v1(self, result, turn_type): 176 | content = result['content'].strip() 177 | if turn_type in ["ask_question", "continue"]: 178 | if "the simple question:" in content: 179 | content = content.split("the simple question:")[1].strip() 180 | if content.startswith('"') and content.endswith('"'): 181 | content = content[1:-1] 182 | else: 183 | matchObj = re.search(r'"(.*?)"', content) 184 | if matchObj is not None: 185 | content = matchObj.group() 186 | content = content[1:-1] 187 | else: 188 | content = content.strip().strip('"') 189 | print("Not exactly parse, we directly use content: %s" % content) 190 | 191 | return content 192 | 193 | def parse_result_v2(self, result, turn_type): 194 | content = result['content'].strip() 195 | 196 | return content 197 | 198 | def reset_history(self): 199 | self.history_messages = [] 200 | self.history_contents = [] 201 | 202 | def reset_history_messages(self): 203 | self.history_messages = [] 204 | 205 | def reset_history_contents(self): 206 | self.history_contents = [] 207 | 208 | def load_prompt_template(self, prompt_path, prompt_name): 209 | if prompt_path.endswith(".json"): 210 | with open(prompt_path, "rb") as f: 211 | prompt = json.load(f) 212 | return prompt[prompt_name] 213 | 214 | def get_response_v2(self, input_text, turn_type): 215 | message = self.create_message_v2(input_text, turn_type) 216 | self.history_messages.append(message) 217 | self.history_contents.append(message['content']) 218 | message = self.query_API_to_get_message(self.history_messages) 219 | self.history_messages.append(message) 220 | self.history_contents.append(message['content']) 221 | response = message['content'].strip() 222 | 223 | return response 224 | 225 | def create_message_v2(self, input_text, turn_type): 226 | if turn_type == "instruction": # the initial query 227 | instruction = self.prompt['instruction'] 228 | input_text = instruction 229 | # ykm 230 | # elif turn_type == "init_relation_rerank": 231 | # template = self.prompt['init_relation_rerank'] 232 | # can_rels, question, tpe, hop = input_text 233 | # if hop == 1: 234 | # hop = "first" 235 | # elif hop == 2: 236 | # hop = "second" 237 | # elif hop == 3: 238 | # hop = "third" 239 | # input_text = template.format(question=question, tpe=tpe, relations=can_rels, hop=hop) 240 | elif turn_type == "init_relation_rerank": 241 | template = self.prompt['init_relation_rerank'] 242 | can_rels, question, tpe = input_text 243 | input_text = template.format(question=question, tpe=tpe, relations=can_rels) 244 | elif turn_type == "constraints_flag": 245 | template = self.prompt['constraints_flag'] 246 | question, tpe, selected_relations = input_text 247 | if len(selected_relations) > 1: 248 | selected_relations = "are " + ", ".join(selected_relations) 249 | else: 250 | selected_relations = "is " + ", ".join(selected_relations) 251 | input_text = template.format(question=question, tpe=tpe, selected_relations=selected_relations) 252 | elif turn_type == "ask_final_answer_or_next_question": 253 | question, serialized_facts = input_text 254 | template = self.prompt['ask_final_answer_or_next_question'] 255 | input_text = template.format(facts=serialized_facts, question=question) 256 | elif turn_type == "choose_constraints": 257 | question, relation_tails, tpe_name = input_text 258 | template = self.prompt['choose_constraints'] 259 | input_text = template.format(question=question, relation_tails=relation_tails, tpe=tpe_name) 260 | elif turn_type == "final_query_template": 261 | template = self.prompt['final_query_template'] 262 | input_text = template.format(question=input_text) 263 | elif turn_type == 'relation_rerank': 264 | template = self.prompt['relation_rerank'] 265 | can_rels, question, tpe, selected_relations = input_text 266 | # 暂时注释掉 267 | # if len(selected_relations) > 1: 268 | # selected_relations = "are " + ", ".join(selected_relations) 269 | # else: 270 | # selected_relations = "is " + ", ".join(selected_relations) 271 | selected_relations = "".join(selected_relations) 272 | input_text = template.format(question=question, relations=can_rels, tpe=tpe, 273 | selected_relations=selected_relations) 274 | elif turn_type == 'relation_rerank_2hop': 275 | template = self.prompt['relation_rerank_2hop'] 276 | can_rels, question, tpe, sub_question, selected_relations = input_text 277 | sub_question = ", ".join(sub_question) 278 | selected_relations = ", ".join(selected_relations) 279 | input_text = template.format(question=question, relations=can_rels, tpe=tpe, 280 | first_sub_question=sub_question, first_relation=selected_relations) 281 | elif turn_type == 'relation_rerank_3hop': 282 | template = self.prompt['relation_rerank_3hop'] 283 | can_rels, question, tpe, sub_question, selected_relations = input_text 284 | first_sub_question = sub_question[0] 285 | second_sub_question = sub_question[1] 286 | fisrt_relation = selected_relations[0] 287 | second_relation = selected_relations[1] 288 | input_text = template.format(question=question, relations=can_rels, tpe=tpe, 289 | first_sub_question=first_sub_question, first_relation = fisrt_relation, 290 | second_sub_question=second_sub_question, second_relation=second_relation) 291 | elif turn_type == 'direct_ask_final_answer': 292 | template = self.prompt['direct_ask_final_answer'] 293 | question = input_text 294 | input_text = template.format(question=question) 295 | elif turn_type == 'final_answer_organize': 296 | template = self.prompt['final_answer_organize'] 297 | input_text = template 298 | else: 299 | raise NotImplementedError 300 | message = {'role': 'user', 'content': input_text} 301 | return message 302 | 303 | 304 | class Retriever: 305 | def __init__(self, args): 306 | self.args = args 307 | # self.initialize_PLM(args) 308 | self.initialize_KG(args) 309 | 310 | def get_retrieval_information(self, first_flag=False, gold_relations=None): 311 | triples_per_hop, tails = self.KG.get_facts_1hop(self.cur_ents, self.args.max_triples_per_relation, 312 | first_flag, gold_relations) 313 | self.reset_cur_ents(tails) 314 | # self.reset_last_ents(self.cur_ents) 315 | return triples_per_hop 316 | 317 | # 直接获得triples 318 | def get_retrieval_information_direct(self, response, tpe, first_flag=False, gold_relations=None): 319 | triples, tails = self.KG.get_facts_1hop_direct(response, tpe, self.cur_ents, self.tokenizer, self.retriever, 320 | self.args.topk, 321 | self.args.filter_score, self.args.max_triples_per_relation, 322 | first_flag, gold_relations) 323 | self.reset_cur_ents(tails) 324 | # self.reset_last_ents(self.cur_ents) 325 | return triples 326 | 327 | def get_retrieval_relations(self, first_flag=False): 328 | rels = self.KG.get_rels_1hop(self.cur_ents, first_flag) 329 | return rels 330 | 331 | def initialize_KG(self, args): 332 | self.KG = KnowledgeGraph(args.kg_source_path, args.ent_type_path, args.ent2id_path, args.rel2id_path) 333 | 334 | def initialize_PLM(self, args): 335 | tokenizer = AutoTokenizer.from_pretrained(args.model_path) 336 | self.tokenizer = tokenizer 337 | model = AutoModel.from_pretrained(args.model_path) 338 | # self.retriever = model.cuda("cuda:" + str(args.device)) 339 | self.retriever = None 340 | 341 | def reset_cur_ents(self, entity_list): 342 | self.cur_ents = entity_list 343 | # print("Current entity num: ", len(self.cur_ents)) 344 | 345 | def update_cur_ents(self, filtered_triples_per_hop): 346 | new_tails = set() 347 | for tri in filtered_triples_per_hop[1]: 348 | h, r, t = tri 349 | try: 350 | t_id = self.KG.ent2id[t] 351 | new_tails.add(t_id) 352 | except Exception as e: 353 | logging.exception(e) 354 | print("Entity string: %s not in ent2id dict" % t) 355 | continue 356 | new_tails = list(new_tails) 357 | self.reset_cur_ents(new_tails) 358 | 359 | def extract_facts(self, facts, response): 360 | response = response.lower().strip() 361 | # if response.startswith("the relevant relations:"): 362 | # response = response.replace("the relevant relations:", "") 363 | # response = response.strip() 364 | # nor_rels = response.split(",") 365 | # nor_rels = [rel.strip() for rel in nor_rels] 366 | # else: 367 | # nor_rels = response 368 | 369 | filtered_facts = [] 370 | for tri in facts: 371 | h, r, t = tri 372 | if self.filter_relation(r): 373 | continue 374 | nor_r = self.normalize_relation(r) 375 | if nor_r in response: 376 | filtered_facts.append(tri) 377 | return filtered_facts 378 | 379 | def filter_relation(self, rel): 380 | # same criteria as GraftNet 381 | relation = rel 382 | if relation == "common.topic.notable_types": return False 383 | if relation == "base.kwebbase.kwtopic.has_sentences": return False 384 | domain = relation.split(".")[0] 385 | if domain == "type" or domain == "common": return True 386 | return False 387 | 388 | def should_ignore(self, rel): 389 | if self.filter_relation(rel): 390 | return True 391 | return False 392 | 393 | def normalize_relation(self, rel): 394 | # e.g. 395 | rel_surface = rel 396 | # replace '.' and '_' with ' ' 397 | rel_surface = rel_surface.replace('.', ' ') 398 | # only keep the last two words 399 | rel_surface = ' '.join(rel_surface.split(' ')[-2:]) 400 | rel_surface = rel_surface.replace('_', ' ') 401 | return rel_surface 402 | 403 | def get_one_hop_cand_rels(self, question): 404 | pass 405 | 406 | def get_tails_list(self, cur_ents): 407 | tails = [self.KG.id2ent[ent] for ent in cur_ents] 408 | return tails 409 | 410 | 411 | class Solver: 412 | def __init__(self, args): 413 | self.args = args 414 | self.LLM = ChatGPT(args=args, prompt_path=args.prompt_path, prompt_name=args.prompt_name, 415 | max_tokens=args.max_tokens) 416 | self.SLM = Retriever(args) 417 | self.max_serialization_tokens = args.max_llm_input_tokens 418 | self.load_ent2name(args.ent2name_path) 419 | self.log = [] 420 | self.selected_relations = [] 421 | # 暂时添加一个selected_sub_questions = []来存放解析的子问题 422 | self.selected_sub_questions = [] 423 | 424 | 425 | def forward_v2(self, question, tpe_str, tpe_id): 426 | self.LLM.reset_history() 427 | self.SLM.reset_cur_ents([tpe_id]) 428 | self.reset_history() 429 | 430 | iterative_step = 0 431 | 432 | # start_response = self.LLM.get_response_v2("", "instruction") 433 | # self.log.append(start_response) 434 | 435 | while True: 436 | # select 437 | all_rel_one_hop = self.SLM.get_retrieval_relations(first_flag=iterative_step == 0) 438 | if len(all_rel_one_hop) == 0: 439 | final_answers = self.LLM.get_response_v2(question, "final_query_template") 440 | break 441 | 442 | serialized_rels = self.extract_can_rels(all_rel_one_hop, normalize_rel=False) 443 | if args.debug: 444 | print("Step-%d: serialized_rels:%s" % (iterative_step, serialized_rels)) 445 | 446 | if iterative_step == 0: 447 | llm_selected_rels = self.LLM.get_response_v2((serialized_rels, question, tpe_str), 448 | "init_relation_rerank") 449 | else: 450 | llm_selected_rels = self.LLM.get_response_v2( 451 | (serialized_rels, question, tpe_str, self.selected_relations), "relation_rerank") 452 | self.LLM.reset_history_messages() 453 | if args.debug: 454 | print("Step-%d: llm_selected_rels:%s" % (iterative_step, llm_selected_rels)) 455 | 456 | selected_relations_list = self.parse_llm_selected_relations(llm_selected_rels, all_rel_one_hop) 457 | if args.debug: 458 | print("Step-%d: selected_relations_list:%s" % (iterative_step, selected_relations_list)) 459 | if len(selected_relations_list) == 0: 460 | final_answers = self.LLM.get_response_v2(question, "final_query_template") 461 | break 462 | 463 | self.selected_relations.extend(selected_relations_list) 464 | if args.debug: 465 | print("Step-%d: self.selected_relations:%s" % (iterative_step, self.selected_relations)) 466 | 467 | filtered_triples_per_hop = self.SLM.get_retrieval_information(first_flag=iterative_step == 0, 468 | gold_relations=selected_relations_list) 469 | cvt_triples, mid_triples, entstr_triples = self.classify_triples(filtered_triples_per_hop) 470 | if len(cvt_triples) > 0: 471 | # constraint 472 | if args.debug: 473 | print("Step-%d: Constraints" % iterative_step) 474 | constraints_candidate = self.serialize_constraints(cvt_triples) 475 | if args.debug: 476 | print("Step-%d: constraints_candidate:%s" % (iterative_step, constraints_candidate)) 477 | constraint_response = self.LLM.get_response_v2((question, constraints_candidate, tpe_str), "choose_constraints") 478 | self.log.append(constraint_response) 479 | if args.debug: 480 | print("Step-%d: constraint_response:%s" % (iterative_step, constraint_response)) 481 | if self.has_constraints(constraint_response): 482 | filtered_triples_per_hop = self.filter_triples(filtered_triples_per_hop, cvt_triples, 483 | constraint_response) 484 | self.SLM.update_cur_ents(filtered_triples_per_hop) 485 | if args.debug: 486 | print("Step-%d: filtered_triples_per_hop:%s" % (iterative_step, filtered_triples_per_hop)) 487 | if args.debug: 488 | print("Step-%d: self.SLM.cur_ents:%s" % (iterative_step, self.SLM.cur_ents)) 489 | serialized_facts = self.serialize_facts(filtered_triples_per_hop) 490 | self.log.append(serialized_facts) 491 | 492 | if args.debug: 493 | print("Step-%d: serialized_facts:%s" % (iterative_step, serialized_facts)) 494 | 495 | final_ans_or_next_que = self.LLM.get_response_v2((question, serialized_facts), 496 | "ask_final_answer_or_next_question") 497 | self.log.append(final_ans_or_next_que) 498 | 499 | # 新加的 500 | final_answers = self.parse_result(final_ans_or_next_que, "final_answer") 501 | self.log.append(final_answers) 502 | break 503 | return final_answers, self.LLM.history_contents, self.log 504 | 505 | def reset_selected_list(self): 506 | self.selected_sub_questions = [] 507 | self.selected_relations = [] 508 | 509 | def is_end(self, response, iterative_step): 510 | if "no" in response.lower() or iterative_step > 8: 511 | return True 512 | else: 513 | return False 514 | 515 | def load_ent2name(self, ent2name_path): 516 | with open(ent2name_path, "rb") as f: 517 | self.cvt_flag_dict, self.mid_mapping_dict = pickle.load(f) 518 | 519 | def convert_hyper_facts_to_text(self, facts): 520 | subj, rels, objs = facts 521 | 522 | if self.is_cvt(subj): 523 | return None 524 | elif subj in self.mid_mapping_dict: 525 | subj_surface = self.mid_mapping_dict[subj] 526 | elif self.is_ent(subj): 527 | # print("head entity %s doesn't have name, we skip this triple." % subj) 528 | return None 529 | else: 530 | subj_surface = subj 531 | 532 | flat_facts = [] 533 | for rel, obj in zip(rels, objs): 534 | if self.should_ignore(rel): 535 | continue 536 | else: 537 | nor_rel = self.normalize_relation(rel) 538 | 539 | if self.is_cvt(obj): 540 | continue 541 | elif obj in self.mid_mapping_dict: 542 | obj_surface = self.mid_mapping_dict[obj] 543 | elif self.is_ent(obj): 544 | # print("tail entity %s doesn't have name, we skip this triple." % obj) 545 | continue 546 | else: 547 | obj_surface = obj 548 | 549 | flat_facts.append((subj_surface, nor_rel, obj_surface)) 550 | 551 | return flat_facts 552 | 553 | def convert_fact_to_text(self, fact, normalize_rel=False): 554 | subj, rel, obj = fact 555 | 556 | if self.should_ignore(rel): 557 | return None 558 | 559 | if rel.endswith(".from"): 560 | rel = rel.rstrip(".from") 561 | rel = rel + ".start_time" 562 | if rel.endswith(".to"): 563 | rel = rel.rstrip(".to") 564 | rel = rel + ".end_time" 565 | rel_surface = self.normalize_relation(rel) if normalize_rel else rel 566 | 567 | # subject 568 | if subj.startswith("CVT"): 569 | subj_surface = subj 570 | elif subj in self.mid_mapping_dict: 571 | subj_surface = self.mid_mapping_dict[subj] 572 | elif subj.startswith("m.") or subj.startswith('g.'): 573 | # print("head entity %s doesn't have name, we skip this triple." % subj) 574 | return None 575 | else: 576 | subj_surface = subj 577 | 578 | # object 579 | if obj.startswith("CVT"): 580 | obj_surface = obj 581 | elif obj in self.mid_mapping_dict: 582 | obj_surface = self.mid_mapping_dict[obj] 583 | elif obj.startswith("m.") or obj.startswith('g.'): 584 | # print("tail entity %s doesn't have name, we skip this triple." % obj) 585 | return None 586 | else: 587 | obj_surface = obj 588 | 589 | return (subj_surface, rel_surface, obj_surface) 590 | 591 | def extract_can_rels(self, all_rel_one_hop, normalize_rel=True): 592 | rel_prompt = '"{relation}"' 593 | nor_rels_set = [] 594 | for rel in all_rel_one_hop: 595 | if self.filter_relation(rel): 596 | continue 597 | nor_r = self.normalize_relation(rel) if normalize_rel else rel 598 | if nor_r not in nor_rels_set: 599 | nor_rels_set.append(rel_prompt.format(relation=nor_r)) 600 | rel_candidate = ", ".join(nor_rels_set) 601 | return rel_candidate 602 | 603 | def serialize_rels(self, rels, normalize_rel=True): 604 | nor_rels_set = [] 605 | for rel in rels: 606 | if self.filter_relation(rel): 607 | continue 608 | nor_r = self.normalize_relation(rel) if normalize_rel else rel 609 | if nor_r not in nor_rels_set: 610 | nor_rels_set.append(nor_r) 611 | # rel_candidate = ", ".join(nor_rels_set) 612 | rel_candidate = ";\n ".join(nor_rels_set) 613 | return rel_candidate 614 | 615 | # 直接拼接 616 | def serialize_facts_direct(self, facts): 617 | # 拼接triples 618 | facts_str_for_one_tail_ent = ["(" + ", ".join(fact) + ")" for fact in facts] 619 | 620 | serialized_facts = "" 621 | for fact in facts_str_for_one_tail_ent: 622 | serialized_facts_tmp = serialized_facts + fact + "; " 623 | serialized_facts = serialized_facts_tmp 624 | return serialized_facts 625 | 626 | def serialize_facts(self, facts_per_hop): 627 | h_r_t = defaultdict(lambda: defaultdict(set)) 628 | visited_flag = {} 629 | name2cvt_tmp = {} 630 | cvt_count = 0 631 | all_facts = [] 632 | for hop, facts in facts_per_hop.items(): 633 | if len(facts) > 0: 634 | for fact in facts: 635 | h, r, t = fact 636 | if self.is_cvt(h): 637 | if h not in name2cvt_tmp: 638 | cvt = "CVT_" + str(cvt_count) 639 | cvt_count += 1 640 | name2cvt_tmp[h] = cvt 641 | h = name2cvt_tmp[h] 642 | if self.is_cvt(t): 643 | if t not in name2cvt_tmp: 644 | cvt = "CVT_" + str(cvt_count) 645 | cvt_count += 1 646 | name2cvt_tmp[t] = cvt 647 | t = name2cvt_tmp[t] 648 | fact = (h, r, t) 649 | all_facts.append(fact) 650 | visited_flag[fact] = False 651 | h_r_t[h][r].add(t) 652 | 653 | if len(all_facts) > 0: 654 | all_facts_str = [] 655 | for tri in all_facts: 656 | facts_str_for_one_tail_ent = [] 657 | if not visited_flag[tri]: 658 | h, r, t = tri 659 | if t.startswith("CVT") and len(h_r_t[t]) == 0: 660 | continue 661 | 662 | if h.startswith("CVT"): 663 | # print("Qid:[%s] has single cvt head entities." % qid) 664 | # logger.info(triples_per_hop) 665 | continue 666 | elif t.startswith("CVT"): 667 | st = self.convert_fact_to_text(tri, normalize_rel=False) 668 | facts_str_for_one_tail_ent.append(st) 669 | one_hop_triples = h_r_t[t] 670 | if len(one_hop_triples) > 0: 671 | for key_r, value_ts in one_hop_triples.items(): 672 | for t_ in value_ts: 673 | visit_tri = (t, key_r, t_) 674 | if not visited_flag[visit_tri]: 675 | visited_flag[visit_tri] = True 676 | st = self.convert_fact_to_text(visit_tri, normalize_rel=False) 677 | if st is not None: 678 | assert len(st) == 3 679 | facts_str_for_one_tail_ent.append(st) 680 | # h_new = t 681 | # r_new = [] 682 | # t_new = [] 683 | # for key_r, value_ts in one_hop_triples.items(): 684 | # for t_ in value_ts: 685 | # visit_tri = (t, key_r, t_) 686 | # if not visited_flag[visit_tri]: 687 | # r_new.append(key_r) 688 | # t_new.append(t_) 689 | # visited_flag[visit_tri] = True 690 | # tri_new = (t, r_new, t_new) 691 | # if len(r_new) == len(t_new) > 0: 692 | # str_tri_list = self.convert_hyper_facts_to_text(tri_new) 693 | # if str_tri_list is not None: 694 | # for st in str_tri_list: 695 | # assert len(st) == 3 696 | # if st not in facts_str: 697 | # facts_str.append(st) 698 | else: 699 | st = self.convert_fact_to_text(tri, normalize_rel=False) 700 | if st is not None: 701 | assert len(st) == 3 702 | if st not in facts_str_for_one_tail_ent: 703 | facts_str_for_one_tail_ent.append(st) 704 | facts_str_for_one_tail_ent = ["(" + ", ".join(fact) + ")" for fact in facts_str_for_one_tail_ent] 705 | facts_str = ", ".join(facts_str_for_one_tail_ent) 706 | all_facts_str.append(facts_str) 707 | 708 | # facts_str = ["(" + ", ".join(fact) + ")" for fact in facts_str] 709 | serialized_facts = "" 710 | for fact in all_facts_str: 711 | serialized_facts_tmp = serialized_facts + fact + "; " 712 | if len(serialized_facts_tmp.split()) > self.max_serialization_tokens: 713 | break 714 | else: 715 | serialized_facts = serialized_facts_tmp 716 | serialized_facts = serialized_facts.strip("; ") 717 | else: 718 | serialized_facts = "" 719 | return serialized_facts 720 | 721 | def serialize_facts_v1(self, facts): 722 | if len(facts) > 0: 723 | h_r_t = defaultdict(lambda: defaultdict(set)) 724 | visited_flag = {} 725 | for fact in facts: 726 | h, r, t = fact 727 | visited_flag[tuple(fact)] = False 728 | h_r_t[h][r].add(t) 729 | facts_str = [] 730 | for tri in facts: 731 | if not visited_flag[tuple(tri)]: 732 | h, r, t = tri 733 | if self.is_cvt(t) and len(h_r_t[t]) == 0: 734 | continue 735 | if self.is_cvt(h): 736 | # print("Qid:[%s] has single cvt head entities." % qid) 737 | # logger.info(triples_per_hop) 738 | continue 739 | elif self.is_cvt(t): 740 | one_hop_triples = h_r_t[t] 741 | if len(one_hop_triples) > 0: 742 | h_new = t 743 | r_new = [] 744 | t_new = [] 745 | for key_r, value_ts in one_hop_triples.items(): 746 | for t_ in value_ts: 747 | visit_tri = (t, key_r, t_) 748 | if not visited_flag[visit_tri]: 749 | r_new.append(key_r) 750 | t_new.append(t_) 751 | visited_flag[visit_tri] = True 752 | tri_new = (h, r_new, t_new) 753 | if len(r_new) == len(t_new) > 0: 754 | str_tri_list = self.convert_hyper_facts_to_text(tri_new) 755 | if str_tri_list is not None: 756 | for st in str_tri_list: 757 | assert len(st) == 3 758 | if st not in facts_str: 759 | facts_str.append(st) 760 | else: 761 | st = self.convert_fact_to_text(tri) 762 | if st is not None: 763 | assert len(st) == 3 764 | if st not in facts_str: 765 | facts_str.append(st) 766 | facts_str = ["(" + ", ".join(fact) + ")" for fact in facts_str] 767 | serialized_facts = "" 768 | for fact in facts_str: 769 | serialized_facts_tmp = serialized_facts + fact + "; " 770 | if len(serialized_facts_tmp.split()) > self.max_serialization_tokens: 771 | break 772 | else: 773 | serialized_facts = serialized_facts_tmp 774 | # serialized_facts = "; ".join(facts_str) 775 | serialized_facts = serialized_facts.strip("; ") 776 | else: 777 | serialized_facts = "" 778 | return serialized_facts 779 | 780 | def is_cvt(self, entity): 781 | if self.cvt_flag_dict[entity]: 782 | return True 783 | else: 784 | return False 785 | 786 | def is_ent(self, ent_str): 787 | if type(ent_str) is not bool and (ent_str.startswith("m.") or ent_str.startswith("g.")): 788 | return True 789 | else: 790 | return False 791 | 792 | def filter_relation(self, rel): 793 | # same criteria as GraftNet 794 | relation = rel 795 | if relation == "common.topic.notable_types": return False 796 | if relation == "base.kwebbase.kwtopic.has_sentences": return False 797 | domain = relation.split(".")[0] 798 | if domain == "type" or domain == "common": return True 799 | return False 800 | 801 | def should_ignore(self, rel): 802 | if self.filter_relation(rel): 803 | return True 804 | return False 805 | 806 | def normalize_relation(self, rel): 807 | # e.g. 808 | rel_surface = rel 809 | # replace '.' and '_' with ' ' 810 | rel_surface = rel_surface.replace('.', ' ') 811 | # only keep the last two words 812 | rel_surface = ' '.join(rel_surface.split(' ')[-2:]) 813 | rel_surface = rel_surface.replace('_', ' ') 814 | return rel_surface 815 | 816 | def parse_llm_selected_relations(self, llm_sel_rels_str, can_rels): 817 | # llm_sel_rels = llm_sel_rels_str.strip(" ;.|,<>`[]'") 818 | # llm_sel_rels = llm_sel_rels.split(',') 819 | # llm_sel_rels = [rel.strip(" ;.|,<>`[]'").strip(" ;.|,<>`[]'") for rel in llm_sel_rels] 820 | # llm_sel_rel_list = [] 821 | # for rel in llm_sel_rels: 822 | # if rel in can_rels: 823 | # llm_sel_rel_list.append(rel) 824 | # else: 825 | # print(rel) 826 | # if len(llm_sel_rel_list) == 0: 827 | # for rel in can_rels: 828 | # if rel in llm_sel_rels_str: 829 | # llm_sel_rel_list.append(rel) 830 | # print("-----llm_ser_rels:\n%s\ndoesn't match the predefined format" % llm_sel_rels) 831 | llm_sel_rel_list = [] 832 | for rel in can_rels: 833 | if rel in llm_sel_rels_str: 834 | llm_sel_rel_list.append(rel) 835 | return llm_sel_rel_list 836 | 837 | def parse_result(self, response, parse_type): 838 | response = response.lower() 839 | if parse_type == "next_question": 840 | if "the next question:" in response: 841 | next_question = response.split("the next question:")[1].strip() 842 | elif ":" in response: 843 | next_question = response.split(":")[1].strip() 844 | else: 845 | next_question = response 846 | print("Not parse the next question exactly, directly use the response: ", response) 847 | return next_question 848 | elif parse_type == "final_answer": 849 | if "the final answers:" in response: 850 | final_answer = response.split("the final answers:")[1].strip() 851 | # 暂时注释掉 852 | elif ":" in response: 853 | final_answer = response.split(":")[1].strip() 854 | # 新添加的用于解析direct query 855 | else: 856 | final_answer = response 857 | # 暂时注释掉 858 | # print("Not parse the final answer exactly, directly use the response: ", response) 859 | return final_answer 860 | 861 | def classify_triples(self, filtered_triples_per_hop): 862 | cvt_triples, mid_triples, entstr_triples = set(), set(), set() 863 | if 0 in filtered_triples_per_hop: 864 | triples_0 = filtered_triples_per_hop[0] 865 | else: 866 | triples_0 = [] 867 | if 1 in filtered_triples_per_hop: 868 | triples_1 = filtered_triples_per_hop[1] 869 | else: 870 | triples_1 = [] 871 | 872 | if len(triples_1) == 0: 873 | for tri in triples_0: 874 | if self.is_ent(tri[2]): 875 | mid_triples.add(tuple(tri)) 876 | else: 877 | entstr_triples.add(tuple(tri)) 878 | else: 879 | for tri in triples_1: 880 | cvt_triples.add(tuple(tri)) 881 | return cvt_triples, mid_triples, entstr_triples 882 | 883 | def serialize_constraints(self, cvt_triples): 884 | r2t_set = defaultdict(set) 885 | for tri in cvt_triples: 886 | subj, rel, obj = tri 887 | if self.should_ignore(rel): 888 | continue 889 | 890 | if rel.endswith(".from"): 891 | rel = rel.rstrip(".from") 892 | rel = rel + ".start_time" 893 | if rel.endswith(".to"): 894 | rel = rel.rstrip(".to") 895 | rel = rel + ".end_time" 896 | 897 | rel_surface = rel 898 | 899 | # object 900 | if obj in self.mid_mapping_dict: 901 | obj_surface = self.mid_mapping_dict[obj] 902 | elif obj.startswith("m.") or obj.startswith('g.'): 903 | # print("tail entity %s doesn't have name, we skip this triple." % obj) 904 | continue 905 | else: 906 | obj_surface = obj 907 | 908 | if obj_surface == "To" or "has_no_value" in rel: 909 | continue 910 | 911 | r2t_set[rel_surface].add(obj_surface) 912 | 913 | constraints = [] 914 | for r, t_set in r2t_set.items(): 915 | t_set = ['"' + t + '"' for t in t_set] 916 | constraints.append('"' + r + '"' + ": [" + ", ".join(t_set) + "]") 917 | # constraints = constraints.rstrip("\n") 918 | constraints = "\n".join(constraints) 919 | return constraints 920 | 921 | def has_constraints(self, constraint_response): 922 | if "no" in constraint_response.lower(): 923 | return False 924 | else: 925 | return True 926 | 927 | def filter_triples(self, filtered_triples_per_hop, cvt_triples, constraint_response): 928 | valid_cvt_nodes = set() 929 | h_r_t = defaultdict(list) 930 | for tri in cvt_triples: 931 | h, r, t = tri 932 | h_r_t[h].append((r, t)) 933 | for cvt, r_ts in h_r_t.items(): 934 | flag = True 935 | at_leat_one_flag = False 936 | for r_t in r_ts: 937 | rel, obj = r_t 938 | 939 | if rel.endswith(".from"): 940 | rel = rel.rstrip(".from") 941 | rel = rel + ".start_time" 942 | if rel.endswith(".to"): 943 | rel = rel.rstrip(".to") 944 | rel = rel + ".end_time" 945 | rel_surface = rel 946 | 947 | # object 948 | if obj in self.mid_mapping_dict: 949 | obj_surface = self.mid_mapping_dict[obj] 950 | elif obj.startswith("m.") or obj.startswith('g.'): 951 | # print("tail entity %s doesn't have name, we skip this triple." % obj) 952 | continue 953 | else: 954 | obj_surface = obj 955 | 956 | if rel_surface.lower() in constraint_response.lower(): 957 | at_leat_one_flag = True 958 | if obj_surface.lower() not in constraint_response.lower(): 959 | flag = False 960 | break 961 | if flag and at_leat_one_flag: 962 | valid_cvt_nodes.add(cvt) 963 | 964 | # 添加软约束条件,解析cvt结点的rel,若有两部分在response中则选中 965 | if len(valid_cvt_nodes) == 0: 966 | for cvt, r_ts in h_r_t.items(): 967 | flag = True 968 | at_leat_one_flag = False 969 | for r_t in r_ts: 970 | rel, obj = r_t 971 | 972 | if rel.endswith(".from"): 973 | rel = rel.rstrip(".from") 974 | rel = rel + ".start_time" 975 | if rel.endswith(".to"): 976 | rel = rel.rstrip(".to") 977 | rel = rel + ".end_time" 978 | rel_surface = rel 979 | 980 | # object 981 | if obj in self.mid_mapping_dict: 982 | obj_surface = self.mid_mapping_dict[obj] 983 | elif obj.startswith("m.") or obj.startswith('g.'): 984 | # print("tail entity %s doesn't have name, we skip this triple." % obj) 985 | continue 986 | else: 987 | obj_surface = obj 988 | 989 | rel_surface_list = rel_surface.split(".") 990 | for rel in rel_surface_list: 991 | if rel.lower() in constraint_response.lower(): 992 | at_leat_one_flag = True 993 | if obj_surface.lower() not in constraint_response.lower(): 994 | flag = False 995 | break 996 | else: 997 | flag = True 998 | if flag and at_leat_one_flag: 999 | valid_cvt_nodes.add(cvt) 1000 | break 1001 | 1002 | new_tris_per_hop = defaultdict(set) 1003 | for hop in [0, 1]: 1004 | triples = filtered_triples_per_hop[hop] 1005 | for tri in triples: 1006 | h, r, t = tri 1007 | if hop == 0: 1008 | if t in valid_cvt_nodes: 1009 | new_tris_per_hop[hop].add(tuple(tri)) 1010 | elif hop == 1: 1011 | if h in valid_cvt_nodes: 1012 | new_tris_per_hop[hop].add(tuple(tri)) 1013 | return new_tris_per_hop 1014 | 1015 | def serialize_facts_one_hop(self, facts): 1016 | if len(facts) > 0: 1017 | h_r_t = defaultdict(lambda: defaultdict(set)) 1018 | visited_flag = {} 1019 | for fact in facts: 1020 | h, r, t = fact 1021 | visited_flag[tuple(fact)] = False 1022 | h_r_t[h][r].add(t) 1023 | facts_str = [] 1024 | for tri in facts: 1025 | if not visited_flag[tuple(tri)]: 1026 | h, r, t = tri 1027 | if self.is_cvt(t) and len(h_r_t[t]) == 0: 1028 | continue 1029 | if self.is_cvt(h): 1030 | # print("Qid:[%s] has single cvt head entities." % qid) 1031 | # logger.info(triples_per_hop) 1032 | continue 1033 | elif self.is_cvt(t): 1034 | one_hop_triples = h_r_t[t] 1035 | if len(one_hop_triples) > 0: 1036 | h_new = t 1037 | r_new = [] 1038 | t_new = [] 1039 | for key_r, value_ts in one_hop_triples.items(): 1040 | for t_ in value_ts: 1041 | visit_tri = (t, key_r, t_) 1042 | if not visited_flag[visit_tri]: 1043 | r_new.append(key_r) 1044 | t_new.append(t_) 1045 | visited_flag[visit_tri] = True 1046 | tri_new = (h, r_new, t_new) 1047 | if len(r_new) == len(t_new) > 0: 1048 | str_tri_list = self.convert_hyper_facts_to_text(tri_new) 1049 | if str_tri_list is not None: 1050 | for st in str_tri_list: 1051 | assert len(st) == 3 1052 | if st not in facts_str: 1053 | facts_str.append(st) 1054 | else: 1055 | st = self.convert_fact_to_text(tri) 1056 | if st is not None: 1057 | assert len(st) == 3 1058 | if st not in facts_str: 1059 | facts_str.append(st) 1060 | facts_str = ["(" + ", ".join(fact) + ")" for fact in facts_str] 1061 | serialized_facts = "" 1062 | for fact in facts_str: 1063 | serialized_facts_tmp = serialized_facts + fact + "; " 1064 | if len(serialized_facts_tmp.split()) > self.max_serialization_tokens: 1065 | break 1066 | else: 1067 | serialized_facts = serialized_facts_tmp 1068 | # serialized_facts = "; ".join(facts_str) 1069 | serialized_facts = serialized_facts.strip("; ") 1070 | else: 1071 | serialized_facts = "" 1072 | return serialized_facts 1073 | 1074 | def is_end_v2(self, response, iterative_step): 1075 | if "final" in response.lower() or iterative_step > 3: 1076 | return True 1077 | else: 1078 | return False 1079 | 1080 | def reset_history(self): 1081 | self.log = [] 1082 | self.selected_relations = [] 1083 | self.selected_sub_questions = [] 1084 | 1085 | def get_tails_list(self, cur_ents): 1086 | tails = self.SLM.get_tails_list(cur_ents) 1087 | return tails 1088 | 1089 | def main(args, all_data, idx, api_key): 1090 | import openai 1091 | openai.api_key = api_key 1092 | if idx == -1: 1093 | output_path = args.output_path 1094 | chat_log_path = args.chat_log_path 1095 | else: 1096 | idx = "0" + str(idx) if idx < 10 else str(idx) # 00 01 02 ... 29 1097 | output_path = args.output_path + "_" + idx 1098 | chat_log_path = args.chat_log_path + "_" + idx 1099 | 1100 | print("Start PID %d and save to %s" % (os.getpid(), output_path)) 1101 | solver = Solver(args) 1102 | 1103 | count = 0 1104 | valid_count = 0 1105 | with open(output_path, "w") as f: 1106 | with open(chat_log_path, "w") as fclog: 1107 | for sample in tqdm(all_data, total=len(all_data)): 1108 | # if sample["ID"] not in ["test_10943"]: 1109 | # continue 1110 | try: 1111 | question = sample["Question"] 1112 | tpe_name = sample["TopicEntityName"] 1113 | tpe_id = sample['TopicEntityID'] 1114 | 1115 | prediction, chat_history, record = solver.forward_v2(question, tpe_name, tpe_id) 1116 | valid_count += 1 1117 | except openai.error.InvalidRequestError as e: 1118 | print(e) 1119 | continue 1120 | except Exception as e: 1121 | logging.exception(e) 1122 | continue 1123 | 1124 | chat = sample["ID"] + "\n" + "\n******\n".join(chat_history) + "\nAnswers: " + str( 1125 | sample['Answers']) + "\n------------------------------------------\n" 1126 | fclog.write(chat) 1127 | 1128 | count += 1 1129 | if count < 5: 1130 | print(sample['Answers']) 1131 | print(prediction) 1132 | print("---------------------") 1133 | sample["Prediction"] = prediction 1134 | f.write(json.dumps(sample) + "\n") 1135 | 1136 | print("---------------PID %d end with %d/%d samples--------------" % (os.getpid(), valid_count, count)) 1137 | 1138 | 1139 | 1140 | def parse_args(): 1141 | parser = argparse.ArgumentParser() 1142 | parser.add_argument('--input_path', default=None) 1143 | parser.add_argument('--output_path', default=None) 1144 | parser.add_argument('--chat_log_path', default=None) 1145 | parser.add_argument('--log_path', default=None) 1146 | parser.add_argument('--model_path', default=None) 1147 | parser.add_argument('--debug', action="store_true") 1148 | parser.add_argument('--prompt_path') 1149 | parser.add_argument('--prompt_name', default="chat", ) 1150 | parser.add_argument('--bagging_type', default="llm", ) 1151 | parser.add_argument('--overwrite', action="store_true") 1152 | parser.add_argument('--device', default=0, help='the gpu device') 1153 | parser.add_argument('--topk', default=10, type=int, help='retrieve the topk score paths') 1154 | parser.add_argument('--max_tokens', default=10, type=int, help='retrieve the topk score paths') 1155 | parser.add_argument('--api_key', default="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", type=str) 1156 | parser.add_argument('--filter_score', default=0.0, type=float, help='the minimal cosine similarity') 1157 | parser.add_argument('--kg_source_path', default=None, help='the sparse triples file') 1158 | parser.add_argument('--ent_type_path', default=None, help='the file of entities type of sparse triples') 1159 | parser.add_argument('--ent2id_path', default=None, help='the sparse ent2id file') 1160 | parser.add_argument('--rel2id_path', default=None, help='the sparse rel2id file') 1161 | parser.add_argument('--ent2name_path', default=None, help='the sparse rel2id file') 1162 | parser.add_argument('--max_triples_per_relation', default=40, type=int) 1163 | parser.add_argument('--max_llm_input_tokens', default=3400, type=int) 1164 | parser.add_argument('--num_process', default=1, type=int, help='the number of multi-process') 1165 | 1166 | 1167 | args = parser.parse_args() 1168 | 1169 | print("Start querying the LLM.") 1170 | return args 1171 | 1172 | 1173 | if __name__ == '__main__': 1174 | args = parse_args() 1175 | if not args.api_key.startswith("sk-"): 1176 | with open(args.api_key, "r") as f: 1177 | all_keys = f.readlines() 1178 | all_keys = [line.strip('\n') for line in all_keys] 1179 | assert len(all_keys) == args.num_process, (len(all_keys), args.num_process) 1180 | 1181 | if args.input_path.endswith('jsonl'): 1182 | with open(args.input_path, "r") as f: 1183 | all_lines = f.readlines() 1184 | all_data = [json.loads(line) for line in all_lines] 1185 | print("Totally %d test examples." % len(all_data)) 1186 | elif args.input_path.endswith('json'): 1187 | with open(args.input_path, "r") as f: 1188 | all_data = json.load(f) 1189 | print("Totally %d test examples." % len(all_data)) 1190 | 1191 | # used for interrupted scenario 1192 | # with open(args.output_path, "r") as f: 1193 | # all_lines = f.readlines() 1194 | # all_lines = [json.loads(line.strip("\n")) for line in all_lines] 1195 | # already_id = [line['ID'] for line in all_lines] 1196 | # all_data = [data for data in all_data if data['ID'] not in already_id] 1197 | # print("There are %d test examples need to be processed." % len(all_data)) 1198 | 1199 | if args.num_process == 1: 1200 | main(args, all_data, idx=-1, api_key=args.api_key) 1201 | else: 1202 | num_each_split = int(len(all_data) / args.num_process) 1203 | p = mp.Pool(args.num_process) 1204 | for idx in range(args.num_process): 1205 | start = idx * num_each_split 1206 | if idx == args.num_process - 1: 1207 | end = max((idx + 1) * num_each_split, len(all_data)) 1208 | else: 1209 | end = (idx + 1) * num_each_split 1210 | split_data = all_data[start:end] 1211 | try: 1212 | p.apply_async(main, args=(args, split_data, idx, all_keys[idx])) 1213 | except Exception as e: 1214 | logging.exception(e) 1215 | 1216 | p.close() 1217 | p.join() 1218 | print("All of the child processes over!") 1219 | --------------------------------------------------------------------------------