├── .gitignore
├── KnowledgeBase
│   ├── KG_api.py
│   ├── KG_sparse_api.py
│   └── __init__.py
├── LICENSE
├── README.md
├── asset
│   ├── model.png
│   ├── spider_rea_wo_icl_v1.png
│   ├── spider_syn_wo_icl_v1.png
│   ├── spider_wo_icl_v1.png
│   ├── tabfact_wo_icl_v1.png
│   ├── webqsp_wo_icl_v1.png
│   ├── wikisql_wo_icl_v1.png
│   └── wtq_wo_icl_v1.png
├── evaluate_for_spider.py
├── evaluate_for_tabfact.py
├── evaluate_for_webqsp.py
├── evaluate_for_wikisql.py
├── outputs
│   ├── spider-realistic
│   │   └── output_wo_icl_v1.jsonl
│   ├── spider-syn
│   │   └── output_wo_icl_v1.jsonl
│   ├── spider
│   │   └── output_wo_icl_v1.jsonl
│   ├── tabfact
│   │   └── output_wo_icl_v1.jsonl
│   ├── wikisql
│   │   └── output_wo_icl_v1.jsonl
│   └── wtq
│       └── output_wo_icl_v1.jsonl
├── process_sql.py
├── prompts
│   ├── prompt_for_spider.json
│   ├── prompt_for_tabfact.json
│   ├── prompt_for_webqsp.json
│   ├── prompt_for_wikisql.json
│   └── prompt_for_wtq.json
├── scripts
│   ├── eval_for_tabfact.sh
│   ├── eval_spider_pred.sh
│   ├── run_spider_rea_wo_icl_v1.sh
│   ├── run_spider_syn_wo_icl_v1.sh
│   ├── run_spider_wo_icl_v1.sh
│   ├── run_tabfact_wo_icl_v1.sh
│   ├── run_webqsp_wo_icl_v1.sh
│   ├── run_wikisql_wo_icl_v1.sh
│   └── run_wtq_wo_icl_v1.sh
├── structgpt_for_tableqa.py
├── structgpt_for_text_to_sql.py
└── structgpt_for_webqsp.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/KnowledgeBase/KG_api.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from copy import deepcopy
3 |
4 | import logging
5 | import numpy as np
6 |
7 | from KnowledgeBase.KG_sparse_api import KnowledgeGraphSparse
8 | import pickle
9 |
10 | END_OF_HOP = "end.hop"
11 | SEP = "[SEP]"
12 |
13 |
14 | class KnowledgeGraph(object):
15 | def __init__(self, sparse_triples_path, sparse_ent_type_path, ent2id_path, rel2id_path):
16 | triples_path, ent_type_path = sparse_triples_path, sparse_ent_type_path
17 | print("The sparse KG instantiate via int triples from the %s" % (triples_path))
18 | self.sparse_kg = KnowledgeGraphSparse(triples_path=triples_path, ent_type_path=ent_type_path)
19 | self.ent2id = self._load_pickle_file(ent2id_path)
20 | self.id2ent = self._reverse_dict(self.ent2id)
21 | self.rel2id = self._load_pickle_file(rel2id_path)
22 | self.id2rel = self._reverse_dict(self.rel2id)
23 | print("The sparse KG instantiate over, all triples: %d, max head id: %d." % (
24 | self.sparse_kg.E, self.sparse_kg.max_head))
25 |
26 | @staticmethod
27 | def _load_pickle_file(filename):
28 | with open(filename, "rb") as f:
29 | return pickle.load(f)
30 |
31 | @staticmethod
32 | def _reverse_dict(ori_dict):
33 | reversed_dict = {v: k for k, v in ori_dict.items()}
34 | return reversed_dict
35 |
36 | def get_facts_1hop(self, seeds, max_triples_per_relation, first_flag, gold_relations):
37 | if first_flag:
38 | seeds_id = []
39 | for seed in seeds:
40 | try:
41 | seed_id = self.ent2id[seed]
42 | seeds_id.append(seed_id)
43 | except Exception as e:
44 | logging.exception(e)
45 | print("Entity string: %s not in ent2id dict" % seed)
46 | continue
47 | else:
48 | seeds_id = seeds
49 | if len(seeds_id) == 0:
50 | return defaultdict(list), []
51 | triples_per_hop, tails = self.sparse_kg.get_facts_1hop(seeds_id, self.id2rel, self.rel2id, max_triples_per_relation, gold_relations)
52 | triples_per_hop = {hop: [[self.id2ent[triple[0]], self.id2rel[triple[1]], self.id2ent[triple[2]]] for triple in triples]
53 | for hop, triples in triples_per_hop.items()}
54 | return triples_per_hop, tails
55 |
56 | def get_rels_1hop(self, seeds, first_flag):
57 | if first_flag:
58 | seeds_id = []
59 | for seed in seeds:
60 | try:
61 | seed_id = self.ent2id[seed]
62 | seeds_id.append(seed_id)
63 | except Exception as e:
64 | logging.exception(e)
65 | print("Entity string: %s not in ent2id dict" % seed)
66 | continue
67 | else:
68 | seeds_id = seeds
69 | if len(seeds_id) == 0:
70 | return []
71 | can_rels = self.sparse_kg.get_rels_1hop(seeds_id, self.id2rel)
72 | return can_rels
73 |
--------------------------------------------------------------------------------
/KnowledgeBase/KG_sparse_api.py:
--------------------------------------------------------------------------------
1 | import math
2 | import time
3 | from collections import defaultdict
4 | from copy import deepcopy
5 |
6 | import numpy as np
7 | import pickle
8 | import pandas as pd
9 | from scipy.sparse import csr_matrix
10 | VERY_LARGT_NUM = 10**8
11 | PATH_CUTOFF = 10**6
12 | NODE_CUTOFF = 10**4
13 |
14 |
15 | class KnowledgeGraphSparse(object):
16 | def __init__(self, triples_path: str, ent_type_path: str):
17 | self.triple = self._load_npy_file(triples_path)
18 | self.ent_type = self._load_npy_file(ent_type_path)
19 | self.bin_map = np.zeros_like(self.ent_type, dtype=np.int32)
20 | self.E = self.triple.shape[0]
21 | self.head2fact = csr_matrix(
22 | (np.ones(self.E), (self.triple[:, 0], np.arange(self.E)))).astype('bool')
23 | self.rel2fact = csr_matrix(
24 | (np.ones(self.E), (self.triple[:, 1], np.arange(self.E)))).astype('bool')
25 | self.tail2fact = csr_matrix(
26 | (np.ones(self.E), (self.triple[:, 2], np.arange(self.E)))).astype('bool')
27 | self.max_head = max(self.triple[:, 0])
28 | self.max_tail = max(self.triple[:, 2])
29 | self.last_tails = set()
30 |
31 | @staticmethod
32 | def _load_npy_file(filename):
33 | return np.load(filename)
34 |
35 | def _fetch_forward_triple(self, seed_set):
36 | seed_set = np.clip(seed_set, a_min=0, a_max=self.max_head)
37 | indices = self.head2fact[seed_set].indices
38 | return self.triple[indices]
39 |
40 | def _fetch_backward_triple(self, seed_set):
41 | seed_set = np.clip(seed_set, a_min=0, a_max=self.max_tail)
42 | indices = self.tail2fact[seed_set].indices
43 | return self.triple[indices]
44 |
45 | def filter_cvt_nodes(self, seed_ary, CVT_TYPE=3):
46 | seed_type = self.ent_type[seed_ary]
47 | return seed_ary[seed_type == CVT_TYPE]
48 |
49 | def get_facts_1hop(self, seed_set, id2rel, rel2id, max_triples_per_relation, gold_relations):
50 | filtered_triples = defaultdict(list)
51 | filtered_tails = []
52 |
53 | triples = self._fetch_forward_triple(seed_set)
54 | if len(triples) == 0:
55 | print("No triples")
56 | return filtered_triples, filtered_tails
57 |
58 | cur_heads = set()
59 | cur_heads.update(seed_set)
60 |
61 | candidate_rels = set(triples[:, 1].tolist())
62 | candidate_rels_str = [id2rel[rel] for rel in candidate_rels]
63 |
64 | if len(candidate_rels_str) == 0:
65 | print("No candidate_rels_str")
66 | return filtered_triples, filtered_tails
67 |
68 | if gold_relations is None:
69 | filtered_rels_str = set(candidate_rels_str)
70 | else:
71 | filtered_rels_str = set(candidate_rels_str) & set(gold_relations)
72 | assert len(filtered_rels_str) == len(set(gold_relations))
73 |
74 | for rel_str in filtered_rels_str:
75 | rel_indices = (triples[:, 1] == rel2id[rel_str])
76 | triples_for_rel = triples[rel_indices]
77 | # if len(triples_for_rel) > max_triples_per_relation:
78 | # continue
79 | for triple in triples_for_rel.tolist():
80 | if triple[2] not in self.last_tails:
81 | filtered_triples[0].append(triple)
82 | filtered_tails.append(triple[2])
83 |
84 | cvt_tails = [tail for tail in filtered_tails if self.ent_type[tail] == 3]
85 | filtered_tails = list(set(filtered_tails) - set(cvt_tails))
86 | if len(cvt_tails) > 0:
87 | triples = self._fetch_forward_triple(cvt_tails)
88 | if len(triples) > 0:
89 | cur_heads.update(cvt_tails)
90 |
91 | cur_invalid_rels = set()
92 | candidate_rels = set(triples[:, 1].tolist())
93 | candidate_rels_str = [id2rel[rel] for rel in candidate_rels if rel not in cur_invalid_rels]
94 | filtered_rels_str = candidate_rels_str
95 | for rel_str in filtered_rels_str:
96 | rel_indices = (triples[:, 1] == rel2id[rel_str])
97 | triples_for_rel = triples[rel_indices].tolist()
98 | for triple in triples_for_rel:
99 | if triple[2] not in seed_set:
100 | if self.ent_type[triple[2]] != 3:
101 | filtered_triples[1].append(triple)
102 | filtered_tails.append(triple[2])
103 |
104 | self.last_tails = deepcopy(cur_heads)
105 | filtered_tails = list(set(filtered_tails))
106 | return filtered_triples, filtered_tails
107 |
108 | def get_rels_1hop(self, seed_set, id2rel):
109 | triples = self._fetch_forward_triple(seed_set)
110 | if len(triples) == 0:
111 | return []
112 |
113 | cur_heads = set()
114 | cur_heads.update(seed_set)
115 |
116 | cur_invalid_rels = set()
117 | for tail in self.last_tails:
118 | invalid_triples_indices = (triples[:, 2] == tail)
119 | invalid_triples = triples[invalid_triples_indices]
120 | invalid_rels = set(invalid_triples[:, 1])
121 | cur_invalid_rels.update(invalid_rels)
122 |
123 | candidate_rels = set(triples[:, 1].tolist())
124 | candidate_rels_str = [id2rel[rel] for rel in candidate_rels if rel not in cur_invalid_rels]
125 |
126 | return candidate_rels_str
127 |
128 | def get_filtered_rels(self, question, cur_relations, tokenizer, model, topk, filter_score):
129 | scored_rel_list, filtered_rel_scored_list = self.score_relations(question, cur_relations, tokenizer, model, filter_score)
130 | # filter the relations and their scores
131 | ordered_rels_scored = sorted(filtered_rel_scored_list, key=lambda x: x[1], reverse=True)
132 | # filtering can be done by top-k or by a minimum number of paths; here filter_method == "topk"
133 | reserved_rels = ordered_rels_scored[:topk]
134 | reserved_rels = [rel_score[0] for rel_score in reserved_rels]
135 | return reserved_rels
136 |
137 |
138 |
--------------------------------------------------------------------------------
/KnowledgeBase/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JBoRu/StructGPT/46c4c2c3998dca87cb4807f9eadb03afacce399f/KnowledgeBase/__init__.py
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # StructGPT: A General Framework for Large Language Model to Reason on Structured Data
2 |
3 | This repo provides the source code & data of our paper: [StructGPT: A General Framework for Large Language Model to Reason on Structured Data](https://arxiv.org/pdf/2305.09645.pdf) (arXiv 2023).
4 |
5 | ```
6 | @article{Jiang-StructGPT-2023,
7 | author = {Jinhao Jiang and Kun Zhou and Zican Dong and Keming Ye and Wayne Xin Zhao and Ji-Rong Wen},
8 | title = {StructGPT: A general framework for Large Language Model to Reason on Structured Data},
9 | year = {2023},
10 | journal={arXiv preprint arXiv:2305.09645},
11 | url={https://arxiv.org/pdf/2305.09645}
12 | }
13 | ```
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 | ## Usage
22 | ### 0. Requirements
23 | You only need to install the `openai` Python library for querying the OpenAI model API.
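For example, a minimal setup (assuming a recent Python 3 environment; `openai` is the official client package):

```bash
pip install openai
```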
24 |
25 | ### 1. Prepare Dataset
26 | We strongly suggest downloading the processed datasets from [here](https://drive.google.com/drive/folders/11_2pqU_MhEtmxpp3zfK_8EJ1bbQzsnfJ?usp=sharing) and using them directly.
27 | Apart from downloading the datasets from their original websites, we use the processed versions from [UnifiedSKG](https://github.com/HKUNLP/UnifiedSKG).
28 | After downloading our processed data, unzip the archives and put them in the */data* directory.
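For example, a minimal sketch of this step (the archive name below is a placeholder for whatever file you downloaded, not the actual file name):

```bash
mkdir -p data
unzip <downloaded_archive>.zip -d ./data  # placeholder name; adjust to the real file
```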
29 |
30 | ### 3. Experiment
31 | We have organized the running and evaluation scripts for each dataset under the */scripts* directory.
32 |
33 | #### 3.1 Evaluation on Text-to-SQL
34 | It is difficult to control the **randomness** of ChatGPT, so the reproduced results may be slightly different from the reported results.
35 |
36 | For the **Spider** dataset, you can directly use the following command to start running and output the evaluation results:
37 | ```bash
38 | bash ./scripts/run_spider_wo_icl_v1.sh
39 | ```
40 |
41 |
42 |
43 |
44 |
45 | Similarly, you can run the corresponding script for **Spider-SYN** and **Spider-Realistic** to get the evaluation results.
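For reference, the corresponding commands (using the scripts listed under */scripts*) should be:

```bash
bash ./scripts/run_spider_syn_wo_icl_v1.sh
bash ./scripts/run_spider_rea_wo_icl_v1.sh
```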
46 |
47 | **Spider-Realistic**
48 |
49 |
50 |
51 |
52 | **Spider-SYN**
53 |
54 |
55 |
56 | We save all the prediction files in the *outputs/* directory.
57 |
58 | #### 3.2 Evaluation on TableQA
59 | It is difficult to control the **randomness** of ChatGPT, so the reproduced results may be slightly different from the reported results.
60 |
61 | For the **TabFact** dataset, you can directly use the following command to start running and output the evaluation results:
62 | ```bash
63 | bash ./scripts/run_tabfact_wo_icl_v1.sh
64 | ```
65 |
66 |
67 |
68 |
69 |
70 | Similarly, you can run the corresponding script for **WTQ** and **WikiSQL** to get the evaluation results.
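For reference, the corresponding commands should be:

```bash
bash ./scripts/run_wtq_wo_icl_v1.sh
bash ./scripts/run_wikisql_wo_icl_v1.sh
```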
71 |
72 | **WTQ**
73 |
74 |
75 |
76 |
77 | **WikiSQL**
78 |
79 |
80 |
81 | We save all the prediction files in the *outputs/* directory.
82 |
83 | #### 3.3 Evaluation on KGQA
84 | It is difficult to control the **randomness** of ChatGPT, so the reproduced results may be slightly different from the reported results.
85 |
86 | For the **WebQSP** dataset, you can directly use the following command to start running and output the evaluation results:
87 | ```bash
88 | bash ./scripts/run_webqsp_wo_icl_v1.sh
89 | ```
90 |
91 |
92 |
93 |
94 |
95 | Similarly, you can run the corresponding scripts for **MetaQA (1-hop, 2-hop, 3-hop)** to get the evaluation results.
96 |
97 | We save all the prediction files in the *outputs/* directory.
98 |
99 | ## Plan
100 | Thanks for your attention.
101 | A version with better performance is on the way.
102 | Please stay tuned!
103 |
--------------------------------------------------------------------------------
/asset/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JBoRu/StructGPT/46c4c2c3998dca87cb4807f9eadb03afacce399f/asset/model.png
--------------------------------------------------------------------------------
/asset/spider_rea_wo_icl_v1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JBoRu/StructGPT/46c4c2c3998dca87cb4807f9eadb03afacce399f/asset/spider_rea_wo_icl_v1.png
--------------------------------------------------------------------------------
/asset/spider_syn_wo_icl_v1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JBoRu/StructGPT/46c4c2c3998dca87cb4807f9eadb03afacce399f/asset/spider_syn_wo_icl_v1.png
--------------------------------------------------------------------------------
/asset/spider_wo_icl_v1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JBoRu/StructGPT/46c4c2c3998dca87cb4807f9eadb03afacce399f/asset/spider_wo_icl_v1.png
--------------------------------------------------------------------------------
/asset/tabfact_wo_icl_v1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JBoRu/StructGPT/46c4c2c3998dca87cb4807f9eadb03afacce399f/asset/tabfact_wo_icl_v1.png
--------------------------------------------------------------------------------
/asset/webqsp_wo_icl_v1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JBoRu/StructGPT/46c4c2c3998dca87cb4807f9eadb03afacce399f/asset/webqsp_wo_icl_v1.png
--------------------------------------------------------------------------------
/asset/wikisql_wo_icl_v1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JBoRu/StructGPT/46c4c2c3998dca87cb4807f9eadb03afacce399f/asset/wikisql_wo_icl_v1.png
--------------------------------------------------------------------------------
/asset/wtq_wo_icl_v1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JBoRu/StructGPT/46c4c2c3998dca87cb4807f9eadb03afacce399f/asset/wtq_wo_icl_v1.png
--------------------------------------------------------------------------------
/evaluate_for_spider.py:
--------------------------------------------------------------------------------
1 |
2 | ################################
3 | # val: number(float)/string(str)/sql(dict)
4 | # col_unit: (agg_id, col_id, isDistinct(bool))
5 | # val_unit: (unit_op, col_unit1, col_unit2)
6 | # table_unit: (table_type, col_unit/sql)
7 | # cond_unit: (not_op, op_id, val_unit, val1, val2)
8 | # condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
9 | # sql {
10 | # 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
11 | # 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
12 | # 'where': condition
13 | # 'groupBy': [col_unit1, col_unit2, ...]
14 | # 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
15 | # 'having': condition
16 | # 'limit': None/limit value
17 | # 'intersect': None/sql
18 | # 'except': None/sql
19 | # 'union': None/sql
20 | # }
21 | ################################
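# Illustrative sketch (added for clarity, not produced by this script): the query
#   SELECT count(*) FROM singer
# is parsed by process_sql.get_sql into roughly the following dict; the exact
# table/column ids come from process_sql.Schema.idMap, so treat them as placeholders:
#   {
#     'select': (False, [(3, (0, (0, '__all__', False), None))]),  # agg_id 3 == 'count'
#     'from': {'table_units': [('table_unit', '__singer__')], 'conds': []},
#     'where': [], 'groupBy': [], 'having': [], 'orderBy': [], 'limit': None,
#     'intersect': None, 'except': None, 'union': None
#   }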
22 |
23 | from __future__ import print_function
24 | import os, sys
25 | import json
26 | import sqlite3
27 | import traceback
28 | import argparse
29 | from tqdm import tqdm
30 | from itertools import product
31 | from collections import defaultdict
32 | import random
33 |
34 | from process_sql import tokenize, get_schema, get_tables_with_alias, Schema, get_sql
35 |
36 | # Flag to disable value evaluation
37 | DISABLE_VALUE = True
38 | # Flag to disable distinct in select evaluation
39 | DISABLE_DISTINCT = True
40 |
41 |
42 | CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
43 | JOIN_KEYWORDS = ('join', 'on', 'as')
44 |
45 | WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
46 | UNIT_OPS = ('none', '-', '+', "*", '/')
47 | AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
48 | TABLE_TYPE = {
49 | 'sql': "sql",
50 | 'table_unit': "table_unit",
51 | }
52 |
53 | COND_OPS = ('and', 'or')
54 | SQL_OPS = ('intersect', 'union', 'except')
55 | ORDER_OPS = ('desc', 'asc')
56 |
57 |
58 | HARDNESS = {
59 | "component1": ('where', 'group', 'order', 'limit', 'join', 'or', 'like'),
60 | "component2": ('except', 'union', 'intersect')
61 | }
62 |
63 |
64 | def condition_has_or(conds):
65 | return 'or' in conds[1::2]
66 |
67 |
68 | def condition_has_like(conds):
69 | return WHERE_OPS.index('like') in [cond_unit[1] for cond_unit in conds[::2]]
70 |
71 |
72 | def condition_has_sql(conds):
73 | for cond_unit in conds[::2]:
74 | val1, val2 = cond_unit[3], cond_unit[4]
75 | if val1 is not None and type(val1) is dict:
76 | return True
77 | if val2 is not None and type(val2) is dict:
78 | return True
79 | return False
80 |
81 |
82 | def val_has_op(val_unit):
83 | return val_unit[0] != UNIT_OPS.index('none')
84 |
85 |
86 | def has_agg(unit):
87 | return unit[0] != AGG_OPS.index('none')
88 |
89 |
90 | def accuracy(count, total):
91 | if count == total:
92 | return 1
93 | return 0
94 |
95 |
96 | def recall(count, total):
97 | if count == total:
98 | return 1
99 | return 0
100 |
101 |
102 | def F1(acc, rec):
103 | if (acc + rec) == 0:
104 | return 0
105 | return (2. * acc * rec) / (acc + rec)
106 |
107 |
108 | def get_scores(count, pred_total, label_total):
109 | if pred_total != label_total:
110 | return 0,0,0
111 | elif count == pred_total:
112 | return 1,1,1
113 | return 0,0,0
114 |
115 |
116 | def eval_sel(pred, label):
117 | pred_sel = pred['select'][1]
118 | label_sel = label['select'][1]
119 | label_wo_agg = [unit[1] for unit in label_sel]
120 | pred_total = len(pred_sel)
121 | label_total = len(label_sel)
122 | cnt = 0
123 | cnt_wo_agg = 0
124 |
125 | for unit in pred_sel:
126 | if unit in label_sel:
127 | cnt += 1
128 | label_sel.remove(unit)
129 | if unit[1] in label_wo_agg:
130 | cnt_wo_agg += 1
131 | label_wo_agg.remove(unit[1])
132 |
133 | return label_total, pred_total, cnt, cnt_wo_agg
134 |
135 |
136 | def eval_where(pred, label):
137 | pred_conds = [unit for unit in pred['where'][::2]]
138 | label_conds = [unit for unit in label['where'][::2]]
139 | label_wo_agg = [unit[2] for unit in label_conds]
140 | pred_total = len(pred_conds)
141 | label_total = len(label_conds)
142 | cnt = 0
143 | cnt_wo_agg = 0
144 |
145 | for unit in pred_conds:
146 | if unit in label_conds:
147 | cnt += 1
148 | label_conds.remove(unit)
149 | if unit[2] in label_wo_agg:
150 | cnt_wo_agg += 1
151 | label_wo_agg.remove(unit[2])
152 |
153 | return label_total, pred_total, cnt, cnt_wo_agg
154 |
155 |
156 | def eval_group(pred, label):
157 | pred_cols = [unit[1] for unit in pred['groupBy']]
158 | label_cols = [unit[1] for unit in label['groupBy']]
159 | pred_total = len(pred_cols)
160 | label_total = len(label_cols)
161 | cnt = 0
162 | pred_cols = [pred.split(".")[1] if "." in pred else pred for pred in pred_cols]
163 | label_cols = [label.split(".")[1] if "." in label else label for label in label_cols]
164 | for col in pred_cols:
165 | if col in label_cols:
166 | cnt += 1
167 | label_cols.remove(col)
168 | return label_total, pred_total, cnt
169 |
170 |
171 | def eval_having(pred, label):
172 | pred_total = label_total = cnt = 0
173 | if len(pred['groupBy']) > 0:
174 | pred_total = 1
175 | if len(label['groupBy']) > 0:
176 | label_total = 1
177 |
178 | pred_cols = [unit[1] for unit in pred['groupBy']]
179 | label_cols = [unit[1] for unit in label['groupBy']]
180 | if pred_total == label_total == 1 \
181 | and pred_cols == label_cols \
182 | and pred['having'] == label['having']:
183 | cnt = 1
184 |
185 | return label_total, pred_total, cnt
186 |
187 |
188 | def eval_order(pred, label):
189 | pred_total = label_total = cnt = 0
190 | if len(pred['orderBy']) > 0:
191 | pred_total = 1
192 | if len(label['orderBy']) > 0:
193 | label_total = 1
194 | if len(label['orderBy']) > 0 and pred['orderBy'] == label['orderBy'] and \
195 | ((pred['limit'] is None and label['limit'] is None) or (pred['limit'] is not None and label['limit'] is not None)):
196 | cnt = 1
197 | return label_total, pred_total, cnt
198 |
199 |
200 | def eval_and_or(pred, label):
201 | pred_ao = pred['where'][1::2]
202 | label_ao = label['where'][1::2]
203 | pred_ao = set(pred_ao)
204 | label_ao = set(label_ao)
205 |
206 | if pred_ao == label_ao:
207 | return 1,1,1
208 | return len(pred_ao),len(label_ao),0
209 |
210 |
211 | def get_nestedSQL(sql):
212 | nested = []
213 | for cond_unit in sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]:
214 | if type(cond_unit[3]) is dict:
215 | nested.append(cond_unit[3])
216 | if type(cond_unit[4]) is dict:
217 | nested.append(cond_unit[4])
218 | if sql['intersect'] is not None:
219 | nested.append(sql['intersect'])
220 | if sql['except'] is not None:
221 | nested.append(sql['except'])
222 | if sql['union'] is not None:
223 | nested.append(sql['union'])
224 | return nested
225 |
226 |
227 | def eval_nested(pred, label):
228 | label_total = 0
229 | pred_total = 0
230 | cnt = 0
231 | if pred is not None:
232 | pred_total += 1
233 | if label is not None:
234 | label_total += 1
235 | if pred is not None and label is not None:
236 | cnt += Evaluator().eval_exact_match(pred, label)
237 | return label_total, pred_total, cnt
238 |
239 |
240 | def eval_IUEN(pred, label):
241 | lt1, pt1, cnt1 = eval_nested(pred['intersect'], label['intersect'])
242 | lt2, pt2, cnt2 = eval_nested(pred['except'], label['except'])
243 | lt3, pt3, cnt3 = eval_nested(pred['union'], label['union'])
244 | label_total = lt1 + lt2 + lt3
245 | pred_total = pt1 + pt2 + pt3
246 | cnt = cnt1 + cnt2 + cnt3
247 | return label_total, pred_total, cnt
248 |
249 |
250 | def get_keywords(sql):
251 | res = set()
252 | if len(sql['where']) > 0:
253 | res.add('where')
254 | if len(sql['groupBy']) > 0:
255 | res.add('group')
256 | if len(sql['having']) > 0:
257 | res.add('having')
258 | if len(sql['orderBy']) > 0:
259 | res.add(sql['orderBy'][0])
260 | res.add('order')
261 | if sql['limit'] is not None:
262 | res.add('limit')
263 | if sql['except'] is not None:
264 | res.add('except')
265 | if sql['union'] is not None:
266 | res.add('union')
267 | if sql['intersect'] is not None:
268 | res.add('intersect')
269 |
270 | # or keyword
271 | ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
272 | if len([token for token in ao if token == 'or']) > 0:
273 | res.add('or')
274 |
275 | cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
276 | # not keyword
277 | if len([cond_unit for cond_unit in cond_units if cond_unit[0]]) > 0:
278 | res.add('not')
279 |
280 | # in keyword
281 | if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('in')]) > 0:
282 | res.add('in')
283 |
284 | # like keyword
285 | if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) > 0:
286 | res.add('like')
287 |
288 | return res
289 |
290 |
291 | def eval_keywords(pred, label):
292 | pred_keywords = get_keywords(pred)
293 | label_keywords = get_keywords(label)
294 | pred_total = len(pred_keywords)
295 | label_total = len(label_keywords)
296 | cnt = 0
297 |
298 | for k in pred_keywords:
299 | if k in label_keywords:
300 | cnt += 1
301 | return label_total, pred_total, cnt
302 |
303 |
304 | def count_agg(units):
305 | return len([unit for unit in units if has_agg(unit)])
306 |
307 |
308 | def count_component1(sql):
309 | count = 0
310 | if len(sql['where']) > 0:
311 | count += 1
312 | if len(sql['groupBy']) > 0:
313 | count += 1
314 | if len(sql['orderBy']) > 0:
315 | count += 1
316 | if sql['limit'] is not None:
317 | count += 1
318 | if len(sql['from']['table_units']) > 0: # JOIN
319 | count += len(sql['from']['table_units']) - 1
320 |
321 | ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
322 | count += len([token for token in ao if token == 'or'])
323 | cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
324 | count += len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')])
325 |
326 | return count
327 |
328 |
329 | def count_component2(sql):
330 | nested = get_nestedSQL(sql)
331 | return len(nested)
332 |
333 |
334 | def count_others(sql):
335 | count = 0
336 | # number of aggregation
337 | agg_count = count_agg(sql['select'][1])
338 | agg_count += count_agg(sql['where'][::2])
339 | agg_count += count_agg(sql['groupBy'])
340 | if len(sql['orderBy']) > 0:
341 | agg_count += count_agg([unit[1] for unit in sql['orderBy'][1] if unit[1]] +
342 | [unit[2] for unit in sql['orderBy'][1] if unit[2]])
343 | agg_count += count_agg(sql['having'])
344 | if agg_count > 1:
345 | count += 1
346 |
347 | # number of select columns
348 | if len(sql['select'][1]) > 1:
349 | count += 1
350 |
351 | # number of where conditions
352 | if len(sql['where']) > 1:
353 | count += 1
354 |
355 | # number of group by clauses
356 | if len(sql['groupBy']) > 1:
357 | count += 1
358 |
359 | return count
360 |
361 |
362 | class Evaluator:
363 | """A simple evaluator"""
364 | def __init__(self):
365 | self.partial_scores = None
366 |
367 | def eval_hardness(self, sql):
368 | count_comp1_ = count_component1(sql)
369 | count_comp2_ = count_component2(sql)
370 | count_others_ = count_others(sql)
371 |
372 | if count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ == 0:
373 | return "easy"
374 | elif (count_others_ <= 2 and count_comp1_ <= 1 and count_comp2_ == 0) or \
375 | (count_comp1_ <= 2 and count_others_ < 2 and count_comp2_ == 0):
376 | return "medium"
377 | elif (count_others_ > 2 and count_comp1_ <= 2 and count_comp2_ == 0) or \
378 | (2 < count_comp1_ <= 3 and count_others_ <= 2 and count_comp2_ == 0) or \
379 | (count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ <= 1):
380 | return "hard"
381 | else:
382 | return "extra"
383 |
384 | def eval_exact_match(self, pred, label):
385 | partial_scores = self.eval_partial_match(pred, label)
386 | self.partial_scores = partial_scores
387 |
388 | for _, score in partial_scores.items():
389 | if score['f1'] != 1:
390 | return 0
391 | if len(label['from']['table_units']) > 0:
392 | label_tables = sorted(label['from']['table_units'])
393 | pred_tables = sorted(pred['from']['table_units'])
394 | return label_tables == pred_tables
395 | return 1
396 |
397 | def eval_partial_match(self, pred, label):
398 | res = {}
399 |
400 | label_total, pred_total, cnt, cnt_wo_agg = eval_sel(pred, label)
401 | acc, rec, f1 = get_scores(cnt, pred_total, label_total)
402 | res['select'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
403 | acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
404 | res['select(no AGG)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
405 |
406 | label_total, pred_total, cnt, cnt_wo_agg = eval_where(pred, label)
407 | acc, rec, f1 = get_scores(cnt, pred_total, label_total)
408 | res['where'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
409 | acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
410 | res['where(no OP)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
411 |
412 | label_total, pred_total, cnt = eval_group(pred, label)
413 | acc, rec, f1 = get_scores(cnt, pred_total, label_total)
414 | res['group(no Having)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
415 |
416 | label_total, pred_total, cnt = eval_having(pred, label)
417 | acc, rec, f1 = get_scores(cnt, pred_total, label_total)
418 | res['group'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
419 |
420 | label_total, pred_total, cnt = eval_order(pred, label)
421 | acc, rec, f1 = get_scores(cnt, pred_total, label_total)
422 | res['order'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
423 |
424 | label_total, pred_total, cnt = eval_and_or(pred, label)
425 | acc, rec, f1 = get_scores(cnt, pred_total, label_total)
426 | res['and/or'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
427 |
428 | label_total, pred_total, cnt = eval_IUEN(pred, label)
429 | acc, rec, f1 = get_scores(cnt, pred_total, label_total)
430 | res['IUEN'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
431 |
432 | label_total, pred_total, cnt = eval_keywords(pred, label)
433 | acc, rec, f1 = get_scores(cnt, pred_total, label_total)
434 | res['keywords'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
435 |
436 | return res
437 |
438 |
439 | def isValidSQL(sql, db):
440 | conn = sqlite3.connect(db)
441 | cursor = conn.cursor()
442 | try:
443 | cursor.execute(sql)
444 | except:
445 | return False
446 | return True
447 |
448 |
449 | def print_scores(scores, etype):
450 | levels = ['easy', 'medium', 'hard', 'extra', 'all']
451 | partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
452 | 'group', 'order', 'and/or', 'IUEN', 'keywords']
453 |
454 | print("{:20} {:20} {:20} {:20} {:20} {:20}".format("", *levels))
455 | counts = [scores[level]['count'] for level in levels]
456 | print("{:20} {:<20d} {:<20d} {:<20d} {:<20d} {:<20d}".format("count", *counts))
457 |
458 | if etype in ["all", "exec"]:
459 | print('===================== EXECUTION ACCURACY =====================')
460 | this_scores = [scores[level]['exec'] for level in levels]
461 | print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("execution", *this_scores))
462 |
463 | if etype in ["all", "match"]:
464 | print('\n====================== EXACT MATCHING ACCURACY =====================')
465 | exact_scores = [scores[level]['exact'] for level in levels]
466 | print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("exact match", *exact_scores))
467 | print('\n---------------------PARTIAL MATCHING ACCURACY----------------------')
468 | for type_ in partial_types:
469 | this_scores = [scores[level]['partial'][type_]['acc'] for level in levels]
470 | print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
471 |
472 | print('---------------------- PARTIAL MATCHING RECALL ----------------------')
473 | for type_ in partial_types:
474 | this_scores = [scores[level]['partial'][type_]['rec'] for level in levels]
475 | print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
476 |
477 | print('---------------------- PARTIAL MATCHING F1 --------------------------')
478 | for type_ in partial_types:
479 | this_scores = [scores[level]['partial'][type_]['f1'] for level in levels]
480 | print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
481 |
482 |
483 | def evaluate(path, db_dir, etype, kmaps, db_dir2):
484 | with open(path, "r") as f:
485 | all_lines = f.readlines()
486 | all_data = [json.loads(line) for line in all_lines]
487 | glist = []
488 | plist = []
489 | for item in all_data:
490 | glist.append((item['query'] + '\t' + item['db_id']).strip().split('\t'))
491 | result_query = 'SELECT ' + item['Prediction'].replace('\n', ' ')
492 | result_query = result_query.replace('SELECT SELECT', 'SELECT')
493 | plist.append(result_query.strip())
494 |
495 | evaluator = Evaluator()
496 |
497 | levels = ['easy', 'medium', 'hard', 'extra', 'all']
498 | partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
499 | 'group', 'order', 'and/or', 'IUEN', 'keywords']
500 | entries = []
501 | scores = {}
502 |
503 | for level in levels:
504 | scores[level] = {'count': 0, 'partial': {}, 'exact': 0.}
505 | scores[level]['exec'] = 0
506 | for type_ in partial_types:
507 | scores[level]['partial'][type_] = {'acc': 0., 'rec': 0., 'f1': 0.,'acc_count':0,'rec_count':0}
508 |
509 | eval_err_num = 0
510 | index = 0
511 | for p, g in zip(plist, glist):
512 | # print(index)
513 | index += 1
514 | # p_str = p[0]
515 | p_str = p
516 | p_ori = p
517 | try:
518 | g_str, db = g
519 | except:
520 | import ipdb; ipdb.set_trace()
521 | db_name = db
522 | db = os.path.join(db_dir, db_name, db_name + ".sqlite")
523 | if db_dir2 != "":
524 | db2 = os.path.join(db_dir2, db_name, db_name + ".sqlite")
525 | schema = Schema(get_schema(db))
526 | try:
527 | g_sql = get_sql(schema, g_str)
528 | except:
529 | continue
530 |
531 | hardness = evaluator.eval_hardness(g_sql)
532 | scores[hardness]['count'] += 1
533 | scores['all']['count'] += 1
534 |
535 | try:
536 | p_sql = get_sql(schema, p_str)
537 | except:
538 | # If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql
539 | p_sql = {
540 | "except": None,
541 | "from": {
542 | "conds": [],
543 | "table_units": []
544 | },
545 | "groupBy": [],
546 | "having": [],
547 | "intersect": None,
548 | "limit": None,
549 | "orderBy": [],
550 | "select": [
551 | False,
552 | []
553 | ],
554 | "union": None,
555 | "where": []
556 | }
557 | # import ipdb; ipdb.set_trace()
558 | eval_err_num += 1
559 | # print("eval_err_num:{}".format(eval_err_num))
560 |
561 | # rebuild sql for value evaluation
562 | kmap = kmaps[db_name]
563 | g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema)
564 | g_sql = rebuild_sql_val(g_sql)
565 | g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
566 | p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema)
567 | p_sql = rebuild_sql_val(p_sql)
568 | p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)
569 |
570 | if etype in ["all", "exec"]:
571 | if db_dir2 == "":
572 | db2 = db
573 | results = eval_exec_match(db, db2, p_ori, g_str, p_sql, g_sql)
574 | if results is False:
575 | continue
576 | else:
577 | exec_score, (pred, gold) = results
578 | if exec_score == 0:
579 | print("{} pred: {}".format(hardness, p_str))
580 | print("{} gold: {}".format(hardness, g_str))
581 |
582 | if exec_score:
583 | scores[hardness]['exec'] += 1.0
584 | scores['all']['exec'] += 1.0
585 |
586 | if etype in ["all", "match"]:
587 | exact_score = evaluator.eval_exact_match(p_sql, g_sql)
588 | partial_scores = evaluator.partial_scores
589 | scores[hardness]['exact'] += exact_score
590 | scores['all']['exact'] += exact_score
591 | for type_ in partial_types:
592 | if partial_scores[type_]['pred_total'] > 0:
593 | scores[hardness]['partial'][type_]['acc'] += partial_scores[type_]['acc']
594 | scores[hardness]['partial'][type_]['acc_count'] += 1
595 | if partial_scores[type_]['label_total'] > 0:
596 | scores[hardness]['partial'][type_]['rec'] += partial_scores[type_]['rec']
597 | scores[hardness]['partial'][type_]['rec_count'] += 1
598 | scores[hardness]['partial'][type_]['f1'] += partial_scores[type_]['f1']
599 | if partial_scores[type_]['pred_total'] > 0:
600 | scores['all']['partial'][type_]['acc'] += partial_scores[type_]['acc']
601 | scores['all']['partial'][type_]['acc_count'] += 1
602 | if partial_scores[type_]['label_total'] > 0:
603 | scores['all']['partial'][type_]['rec'] += partial_scores[type_]['rec']
604 | scores['all']['partial'][type_]['rec_count'] += 1
605 | scores['all']['partial'][type_]['f1'] += partial_scores[type_]['f1']
606 |
607 | entries.append({
608 | 'predictSQL': p_str,
609 | 'goldSQL': g_str,
610 | 'hardness': hardness,
611 | 'exact': exact_score,
612 | 'partial': partial_scores
613 | })
614 |
615 | for level in levels:
616 | if scores[level]['count'] == 0:
617 | continue
618 | if etype in ["all", "exec"]:
619 | scores[level]['exec'] /= scores[level]['count']
620 |
621 | if etype in ["all", "match"]:
622 | scores[level]['exact'] /= scores[level]['count']
623 | for type_ in partial_types:
624 | if scores[level]['partial'][type_]['acc_count'] == 0:
625 | scores[level]['partial'][type_]['acc'] = 0
626 | else:
627 | scores[level]['partial'][type_]['acc'] = scores[level]['partial'][type_]['acc'] / \
628 | scores[level]['partial'][type_]['acc_count'] * 1.0
629 | if scores[level]['partial'][type_]['rec_count'] == 0:
630 | scores[level]['partial'][type_]['rec'] = 0
631 | else:
632 | scores[level]['partial'][type_]['rec'] = scores[level]['partial'][type_]['rec'] / \
633 | scores[level]['partial'][type_]['rec_count'] * 1.0
634 | if scores[level]['partial'][type_]['acc'] == 0 and scores[level]['partial'][type_]['rec'] == 0:
635 | scores[level]['partial'][type_]['f1'] = 1
636 | else:
637 | scores[level]['partial'][type_]['f1'] = \
638 | 2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / (
639 | scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc'])
640 |
641 | print_scores(scores, etype)
642 |
643 | def get_constraint_permutation(tab1_sets_by_columns, result2):
644 | num_cols = len(result2[0])
645 | perm_constraints = [{i for i in range(num_cols)} for _ in range(num_cols)]
646 | if num_cols <= 3:
647 | return product(*perm_constraints)
648 |
649 | # we sample 20 rows and constrain the space of permutations
650 | for _ in range(20):
651 | random_tab2_row = random.choice(result2)
652 |
653 | for tab1_col in range(num_cols):
654 | for tab2_col in set(perm_constraints[tab1_col]):
655 | if random_tab2_row[tab2_col] not in tab1_sets_by_columns[tab1_col]:
656 | perm_constraints[tab1_col].remove(tab2_col)
657 | return product(*perm_constraints)
658 |
659 | def unorder_row(row):
660 | return tuple(sorted(row, key=lambda x: str(x) + str(type(x))))
661 |
662 |
663 | def quick_rej(result1, result2, order_matters):
664 | s1 = [unorder_row(row) for row in result1]
665 | s2 = [unorder_row(row) for row in result2]
666 | if order_matters:
667 | return s1 == s2
668 | else:
669 | return set(s1) == set(s2)
670 |
671 | def multiset_eq(l1, l2):
672 | if len(l1) != len(l2):
673 | return False
674 | d = defaultdict(int)
675 | for e in l1:
676 | d[e] = d[e] + 1
677 | for e in l2:
678 | d[e] = d[e] - 1
679 | if d[e] < 0:
680 | return False
681 | return True
682 |
683 | def permute_tuple(element, perm):
684 | assert len(element) == len(perm)
685 | return tuple([element[i] for i in perm])
686 |
687 |
688 | def result_eq(result1, result2, order_matters):
689 | if len(result1) == 0 and len(result2) == 0:
690 | return True
691 |
692 | # if the lengths are not the same, then they are definitely different bags of rows
693 | if len(result1) != len(result2):
694 | return False
695 |
696 | num_cols = len(result1[0])
697 |
698 | # if the results do not have the same number of columns, they are different
699 | if len(result2[0]) != num_cols:
700 | return False
701 |
702 | # unorder each row and compare whether the denotation is the same
703 | # this can already find most pair of denotations that are different
704 | if not quick_rej(result1, result2, order_matters):
705 | return False
706 |
707 | # the rest of the problem is in fact more complicated than one might think
708 | # we want to find a permutation of column order and a permutation of row order,
709 | # s.t. result_1 is the same as result_2
710 | # we return true if we can find such column & row permutations
711 | # and false if we cannot
712 | tab1_sets_by_columns = [{row[i] for row in result1} for i in range(num_cols)]
713 |
714 | # on a high level, we enumerate all possible column permutations that might make result_1 == result_2
715 | # we decrease the size of the column permutation space by the function get_constraint_permutation
716 | # if one of the permutations makes result_1 and result_2 equivalent, then they are equivalent
717 | for perm in get_constraint_permutation(tab1_sets_by_columns, result2):
718 | if len(perm) != len(set(perm)):
719 | continue
720 | if num_cols == 1:
721 | result2_perm = result2
722 | else:
723 | result2_perm = [permute_tuple(element, perm) for element in result2]
724 | if order_matters:
725 | if result1 == result2_perm:
726 | return True
727 | else:
728 | # in fact the first condition must hold if the second condition holds
729 | # but the first is way more efficient implementation-wise
730 | # and we use it to quickly reject impossible candidates
731 | if set(result1) == set(result2_perm) and multiset_eq(result1, result2_perm):
732 | return True
733 | return False
734 |
735 |
736 | def eval_exec_match(db, db2, p_str, g_str, pred, gold):
737 | """
738 |     Return whether the execution results of the predicted and gold queries match, along with
739 |     both result sets; returns False if either query fails to execute. Multiple col_unit pairs are not supported.
740 |     """
741 | conn = sqlite3.connect(db2)
742 | cursor = conn.cursor()
743 | try:
744 | cursor.execute(p_str)
745 | p_res = cursor.fetchall()
746 | except Exception as e:
747 | # import ipdb; ipdb.set_trace()
748 | print(e)
749 | print(p_str)
750 | return False
751 |
752 | conn = sqlite3.connect(db)
753 | cursor = conn.cursor()
754 | try:
755 | cursor.execute(g_str)
756 |     except Exception:
757 |         # the gold query failed to execute; treat as a mismatch
758 | return False
759 | q_res = cursor.fetchall()
760 |
761 | orders_matter = 'order by' in g_str.lower()
762 |
763 | return result_eq(p_res, q_res, order_matters=orders_matter), (p_res, q_res)
764 |
765 |
766 | # Rebuild SQL functions for value evaluation
767 | def rebuild_cond_unit_val(cond_unit):
768 | if cond_unit is None or not DISABLE_VALUE:
769 | return cond_unit
770 |
771 | not_op, op_id, val_unit, val1, val2 = cond_unit
772 | if type(val1) is not dict:
773 | val1 = None
774 | else:
775 | val1 = rebuild_sql_val(val1)
776 | if type(val2) is not dict:
777 | val2 = None
778 | else:
779 | val2 = rebuild_sql_val(val2)
780 | return not_op, op_id, val_unit, val1, val2
781 |
782 |
783 | def rebuild_condition_val(condition):
784 | if condition is None or not DISABLE_VALUE:
785 | return condition
786 |
787 | res = []
788 | for idx, it in enumerate(condition):
789 | if idx % 2 == 0:
790 | res.append(rebuild_cond_unit_val(it))
791 | else:
792 | res.append(it)
793 | return res
794 |
795 |
796 | def rebuild_sql_val(sql):
797 | if sql is None or not DISABLE_VALUE:
798 | return sql
799 |
800 | sql['from']['conds'] = rebuild_condition_val(sql['from']['conds'])
801 | sql['having'] = rebuild_condition_val(sql['having'])
802 | sql['where'] = rebuild_condition_val(sql['where'])
803 | sql['intersect'] = rebuild_sql_val(sql['intersect'])
804 | sql['except'] = rebuild_sql_val(sql['except'])
805 | sql['union'] = rebuild_sql_val(sql['union'])
806 |
807 | return sql
808 |
809 |
810 | # Rebuild SQL functions for foreign key evaluation
811 | def build_valid_col_units(table_units, schema):
812 | col_ids = [table_unit[1] for table_unit in table_units if table_unit[0] == TABLE_TYPE['table_unit']]
813 | prefixs = [col_id[:-2] for col_id in col_ids]
814 | valid_col_units= []
815 | for value in schema.idMap.values():
816 | if '.' in value and value[:value.index('.')] in prefixs:
817 | valid_col_units.append(value)
818 | return valid_col_units
819 |
820 |
821 | def rebuild_col_unit_col(valid_col_units, col_unit, kmap):
822 | if col_unit is None:
823 | return col_unit
824 |
825 | agg_id, col_id, distinct = col_unit
826 | if col_id in kmap and col_id in valid_col_units:
827 | col_id = kmap[col_id]
828 | if DISABLE_DISTINCT:
829 | distinct = None
830 | return agg_id, col_id, distinct
831 |
832 |
833 | def rebuild_val_unit_col(valid_col_units, val_unit, kmap):
834 | if val_unit is None:
835 | return val_unit
836 |
837 | unit_op, col_unit1, col_unit2 = val_unit
838 | col_unit1 = rebuild_col_unit_col(valid_col_units, col_unit1, kmap)
839 | col_unit2 = rebuild_col_unit_col(valid_col_units, col_unit2, kmap)
840 | return unit_op, col_unit1, col_unit2
841 |
842 |
843 | def rebuild_table_unit_col(valid_col_units, table_unit, kmap):
844 | if table_unit is None:
845 | return table_unit
846 |
847 | table_type, col_unit_or_sql = table_unit
848 | if isinstance(col_unit_or_sql, tuple):
849 | col_unit_or_sql = rebuild_col_unit_col(valid_col_units, col_unit_or_sql, kmap)
850 | return table_type, col_unit_or_sql
851 |
852 |
853 | def rebuild_cond_unit_col(valid_col_units, cond_unit, kmap):
854 | if cond_unit is None:
855 | return cond_unit
856 |
857 | not_op, op_id, val_unit, val1, val2 = cond_unit
858 | val_unit = rebuild_val_unit_col(valid_col_units, val_unit, kmap)
859 | return not_op, op_id, val_unit, val1, val2
860 |
861 |
862 | def rebuild_condition_col(valid_col_units, condition, kmap):
863 | for idx in range(len(condition)):
864 | if idx % 2 == 0:
865 | condition[idx] = rebuild_cond_unit_col(valid_col_units, condition[idx], kmap)
866 | return condition
867 |
868 |
869 | def rebuild_select_col(valid_col_units, sel, kmap):
870 | if sel is None:
871 | return sel
872 | distinct, _list = sel
873 | new_list = []
874 | for it in _list:
875 | agg_id, val_unit = it
876 | new_list.append((agg_id, rebuild_val_unit_col(valid_col_units, val_unit, kmap)))
877 | if DISABLE_DISTINCT:
878 | distinct = None
879 | return distinct, new_list
880 |
881 |
882 | def rebuild_from_col(valid_col_units, from_, kmap):
883 | if from_ is None:
884 | return from_
885 |
886 | from_['table_units'] = [rebuild_table_unit_col(valid_col_units, table_unit, kmap) for table_unit in from_['table_units']]
887 | from_['conds'] = rebuild_condition_col(valid_col_units, from_['conds'], kmap)
888 | return from_
889 |
890 |
891 | def rebuild_group_by_col(valid_col_units, group_by, kmap):
892 | if group_by is None:
893 | return group_by
894 |
895 | return [rebuild_col_unit_col(valid_col_units, col_unit, kmap) for col_unit in group_by]
896 |
897 |
898 | def rebuild_order_by_col(valid_col_units, order_by, kmap):
899 | if order_by is None or len(order_by) == 0:
900 | return order_by
901 |
902 | direction, val_units = order_by
903 | new_val_units = [rebuild_val_unit_col(valid_col_units, val_unit, kmap) for val_unit in val_units]
904 | return direction, new_val_units
905 |
906 |
907 | def rebuild_sql_col(valid_col_units, sql, kmap):
908 | if sql is None:
909 | return sql
910 |
911 | sql['select'] = rebuild_select_col(valid_col_units, sql['select'], kmap)
912 | sql['from'] = rebuild_from_col(valid_col_units, sql['from'], kmap)
913 | sql['where'] = rebuild_condition_col(valid_col_units, sql['where'], kmap)
914 | sql['groupBy'] = rebuild_group_by_col(valid_col_units, sql['groupBy'], kmap)
915 | sql['orderBy'] = rebuild_order_by_col(valid_col_units, sql['orderBy'], kmap)
916 | sql['having'] = rebuild_condition_col(valid_col_units, sql['having'], kmap)
917 | sql['intersect'] = rebuild_sql_col(valid_col_units, sql['intersect'], kmap)
918 | sql['except'] = rebuild_sql_col(valid_col_units, sql['except'], kmap)
919 | sql['union'] = rebuild_sql_col(valid_col_units, sql['union'], kmap)
920 |
921 | return sql
922 |
923 |
924 | def build_foreign_key_map(entry):
925 |     try:
926 |         cols_orig = entry["column_names_original"]
927 |     except KeyError:
928 |         raise KeyError("table entry for db '%s' is missing 'column_names_original'" % entry.get("db_id"))
929 |     tables_orig = entry["table_names_original"]
930 |
931 | # rebuild cols corresponding to idmap in Schema
932 | cols = []
933 | for col_orig in cols_orig:
934 | if col_orig[0] >= 0:
935 | t = tables_orig[col_orig[0]]
936 | c = col_orig[1]
937 | cols.append("__" + t.lower() + "." + c.lower() + "__")
938 | else:
939 | cols.append("__all__")
940 |
941 | def keyset_in_list(k1, k2, k_list):
942 | for k_set in k_list:
943 | if k1 in k_set or k2 in k_set:
944 | return k_set
945 | new_k_set = set()
946 | k_list.append(new_k_set)
947 | return new_k_set
948 |
949 | foreign_key_list = []
950 | foreign_keys = entry["foreign_keys"]
951 | for fkey in foreign_keys:
952 | key1, key2 = fkey
953 | key_set = keyset_in_list(key1, key2, foreign_key_list)
954 | key_set.add(key1)
955 | key_set.add(key2)
956 |
957 | foreign_key_map = {}
958 | for key_set in foreign_key_list:
959 | sorted_list = sorted(list(key_set))
960 | midx = sorted_list[0]
961 | for idx in sorted_list:
962 | foreign_key_map[cols[idx]] = cols[midx]
963 |
964 | return foreign_key_map
965 |
966 |
967 | def build_foreign_key_map_from_json(table):
968 | with open(table) as f:
969 | data = json.load(f)
970 | tables = {}
971 | for entry in data:
972 | tables[entry['db_id']] = build_foreign_key_map(entry)
973 | return tables
974 |
975 |
976 | if __name__ == "__main__":
977 | parser = argparse.ArgumentParser()
978 | parser.add_argument('--path', dest='path', type=str,default="outputs/spider/output.jsonl")
979 | parser.add_argument('--db', dest='db', type=str,default="data/spider/database")
980 | parser.add_argument('--db2', dest='db2', type=str, default="")
981 | parser.add_argument('--table', dest='table', type=str, default='data/spider/tables.json')
982 | parser.add_argument('--etype', dest='etype', type=str, default="exec")
983 | args = parser.parse_args()
984 |
985 | path=args.path
986 | db_dir = args.db
987 | db_dir2 = args.db2
988 | table = args.table
989 | etype = args.etype
990 |
991 | assert etype in ["all", "exec", "match"], "Unknown evaluation method"
992 |
993 | kmaps = build_foreign_key_map_from_json(table)
994 |
995 | evaluate(path, db_dir, etype, kmaps, db_dir2)
996 |
--------------------------------------------------------------------------------
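
The execution-accuracy check above treats each query result as a bag of rows: quick_rej first compares row multisets with columns unordered, and result_eq then searches for a column permutation, pruned by get_constraint_permutation, under which the two results coincide, enforcing row order only when the gold query contains ORDER BY. A minimal sketch of that behaviour, assuming evaluate_for_spider.py is importable from the repository root and using made-up result rows:

    from evaluate_for_spider import result_eq

    pred_rows = [(1, 'a'), (2, 'b')]
    gold_rows = [('a', 1), ('b', 2)]   # same bag of rows, columns swapped

    # Without ORDER BY, rows are compared as a multiset under some column permutation.
    assert result_eq(pred_rows, gold_rows, order_matters=False)
    assert result_eq(list(reversed(pred_rows)), gold_rows, order_matters=False)

    # With ORDER BY in the gold query, row order must also match after permuting columns.
    assert not result_eq(list(reversed(pred_rows)), gold_rows, order_matters=True)
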
/evaluate_for_tabfact.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numpy as np
3 | import argparse
4 |
5 |
6 | def evaluate(ori_path, inp_path, error_cases_output, write_flag):
7 | with open(ori_path, "r") as f:
8 | all_data = json.loads(f.read())
9 |     print("Loaded %d test examples in total" % len(all_data))
10 |
11 | pred_data = []
12 | with open(inp_path, "r") as f:
13 | lines = f.readlines()
14 | datas = [json.loads(line) for line in lines]
15 | pred_data.extend(datas)
16 |
17 | all_pred_data = {pred['id']: pred for pred in pred_data}
18 |     print("Loaded %d predictions in total" % len(pred_data))
19 | avg_acc = []
20 | bad_cases = []
21 | error_count = 0
22 | max_count = 0
23 | right_count = 0
24 | for data in all_data:
25 | if data["id"] in all_pred_data:
26 | data = all_pred_data[data["id"]]
27 | question = data['statement']
28 | pred = data['Prediction'].lower()
29 |
30 | if 'yes' in pred and 'no' in pred:
31 | pred = 'unknown'
32 | elif 'yes' in pred:
33 | pred = 'entailed'
34 | elif 'no' in pred:
35 | pred = 'refuted'
36 | else:
37 | pred = 'unknown'
38 |
39 | answers = data['seq_out'].lower()
40 | if pred.strip() == answers.strip():
41 | avg_acc.append(1)
42 | right_count += 1
43 |
44 | else:
45 | error_count += 1
46 | avg_acc.append(0)
47 | print("ID: %s Ques: %s" % (data["id"], question))
48 | print("Pred: ", pred)
49 | print("Ans: ", answers)
50 | print("------------------------------------------------------------------------")
51 | bad_cases.append(data)
52 |
53 | else:
54 | avg_acc.append(0)
55 | print("ID: %s can't be predicted" % (data["id"]))
56 | bad_cases.append(data)
57 | error_count += 1
58 | max_count += 1
59 |
60 | acc = np.mean(avg_acc)
61 | print("Acc: %.4f" % (acc))
62 | if write_flag:
63 | with open(error_cases_output, "w") as f:
64 | for bc in bad_cases:
65 | f.write(json.dumps(bc) + "\n")
66 |         print("In total, %d bad cases still need to be resolved." % len(bad_cases))
67 |     print("Right count: %d, Error count: %d (Max len count: %d)" % (right_count, error_count, max_count))
68 |
69 |
70 | if __name__ == "__main__":
71 | parser = argparse.ArgumentParser()
72 | parser.add_argument('--ori_path', type=str, default="./data/tabfact/tab_fact_test.json")
73 | parser.add_argument('--inp_path', type=str, default="./outputs/tabfact/tabfact_test_output.jsonl")
74 | parser.add_argument('--error_cases_output', type=str,
75 | default='./outputs/tabfact/bad_cases.jsonl')
76 | parser.add_argument('--write_flag', action="store_true")
77 | args = parser.parse_args()
78 | evaluate(args.ori_path, args.inp_path, args.error_cases_output, args.write_flag)
79 |
--------------------------------------------------------------------------------
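
evaluate_for_tabfact.py maps the model's free-form output onto TabFact labels by substring matching: an output containing both 'yes' and 'no' (or neither) counts as 'unknown', otherwise 'yes' maps to 'entailed' and 'no' to 'refuted', and the label is string-compared against the gold seq_out field. A minimal sketch of that normalization, with a hypothetical helper name and made-up predictions:

    def normalize_tabfact_prediction(pred: str) -> str:
        # mirrors the in-loop mapping in evaluate_for_tabfact.py
        pred = pred.lower()
        if 'yes' in pred and 'no' in pred:
            return 'unknown'
        if 'yes' in pred:
            return 'entailed'
        if 'no' in pred:
            return 'refuted'
        return 'unknown'

    assert normalize_tabfact_prediction("Yes.") == 'entailed'
    assert normalize_tabfact_prediction("No, the table contradicts it") == 'refuted'
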
/evaluate_for_webqsp.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import openai
4 | import json
5 | import os
6 | import numpy as np
7 | from collections import defaultdict
8 | import pickle
9 |
10 |
11 |
12 |
13 | def evaluate(ori_path, pred_path, error_cases_output, write_flag):
14 | with open(ori_path, "r") as f:
15 | all_data = f.readlines()
16 | all_data = [json.loads(line) for line in all_data]
17 |
18 | with open(pred_path, "r", encoding="UTF-8") as f:
19 | all_lines = f.readlines()
20 | all_pred_data = []
21 | for idx, line in enumerate(all_lines):
22 | line = line.replace("\x00", "").strip("\n")
23 | all_pred_data.append(json.loads(line))
24 | all_pred_data = {pred['ID']: pred for pred in all_pred_data}
25 |     print("Loaded %d predictions" % len(all_pred_data))
26 |
27 |     max_len_count = len(all_data) - len(all_pred_data)
28 |     print("In total, %d predictions / %d examples" % (len(all_pred_data), len(all_data)))
29 |
30 | avg_hits1 = []
31 | bad_cases = []
32 | right_cases_id = []
33 | bad_cases_id = []
34 | right_count = 0
35 | bad_count = 0
36 | need_cvt_count = 0
37 | for data in all_data:
38 | if data["ID"] in all_pred_data:
39 | data = all_pred_data[data["ID"]]
40 | question = data['Question']
41 | pred = data['Prediction'].lower()
42 | if "i'm sorry" in pred:
43 | pred = ''
44 |
45 | answers = data['Answers']
46 | aliases = data['Aliases']
47 | hit_flag = []
48 | recall_flag = []
49 | for ans in answers:
50 | ans = ans.lower()
51 | if ans in pred:
52 | hit_flag.append(1)
53 | recall_flag.append(1)
54 | else:
55 | hit_flag.append(0)
56 | recall_flag.append(0)
57 | for alia in aliases:
58 | alia = alia.lower()
59 | if alia in pred:
60 | hit_flag.append(1)
61 | else:
62 | hit_flag.append(0)
63 |
64 | if len(hit_flag) == 0:
65 | # print("ID:%s doesn't have any gold answers." % data['ID'])
66 | continue
67 |
68 | if any(hit_flag):
69 | avg_hits1.append(1)
70 | right_count += 1
71 | # other_count += 1
72 | right_cases_id.append(data['ID'])
73 | else:
74 | avg_hits1.append(0)
75 | bad_count += 1
76 | # other_count += 1
77 | # if "max length" in pred:
78 | # need_cvt_count += 1
79 | # else:
80 | # other_count += 1
81 | print(data["ID"])
82 | print("ID: %s Ques: %s" % (data["ID"], question))
83 | print("Pred: ", pred)
84 | print("Ans: ", answers)
85 | print("------------------------------------------------------------------------")
86 | bad_cases.append(data)
87 | bad_cases_id.append(data["ID"])
88 |
89 | else:
90 | avg_hits1.append(0)
91 | print("ID: %s can't be predicted" % (data["ID"]))
92 | bad_cases.append(data)
93 | bad_cases_id.append(data["ID"])
94 |
95 | hits1 = np.mean(avg_hits1)
96 | print("Hits@1: %.4f" % (hits1))
97 | if write_flag:
98 | with open(error_cases_output, "w") as f:
99 | for bc in bad_cases:
100 | f.write(json.dumps(bc) + "\n")
101 |             print("In total, %d bad cases still need to be resolved." % len(bad_cases))
102 |     print("Right: %d, Wrong: %d, Max_len: %d" % (right_count, bad_count, max_len_count))
103 |
104 | if __name__ == "__main__":
105 | parser = argparse.ArgumentParser()
106 | parser.add_argument('--ori_path', type=str)
107 | parser.add_argument('--inp_path', type=str)
108 |     parser.add_argument('--error_cases_output', type=str)
109 | parser.add_argument('--write_flag', action="store_true")
110 | args = parser.parse_args()
111 | evaluate(args.ori_path, args.inp_path, args.error_cases_output, args.write_flag)
112 |
--------------------------------------------------------------------------------
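
evaluate_for_webqsp.py scores Hits@1 by lower-casing the prediction and checking whether any gold answer or alias occurs in it as a substring; predictions containing "i'm sorry" are treated as empty. A minimal sketch of that criterion, with a hypothetical helper and made-up strings:

    def is_hit(pred: str, answers, aliases) -> bool:
        pred = pred.lower()
        if "i'm sorry" in pred:
            pred = ''
        candidates = [x.lower() for x in list(answers) + list(aliases)]
        return any(c in pred for c in candidates)

    assert is_hit("The final answers: Jamaican English", ["Jamaican English"], [])
    assert not is_hit("I'm sorry, I cannot answer that.", ["Jamaican English"], [])
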
/evaluate_for_wikisql.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | from collections import defaultdict
4 | import numpy as np
5 |
6 |
7 | def evaluate_example(_predict_str: str, _ground_str: list, target_delimiter=', '):
8 | _predict_spans = _predict_str.split(target_delimiter)
9 | _predict_spans = [x.lower().strip().strip('.').strip("'").strip('"').strip() for x in _predict_spans]
10 | for i in range(len(_predict_spans)):
11 | if _predict_spans[i].endswith('.0'):
12 | _predict_spans[i] = _predict_spans[i][:-2]
13 |
14 | if _predict_spans[i].replace(',', '').isnumeric():
15 | _predict_spans[i] = _predict_spans[i].replace(',', '')
16 | # _ground_spans = _ground_str.split(target_delimiter)
17 | _ground_spans = [x.lower().strip().strip('.').strip("'").strip('"').strip() for x in _ground_str]
18 | for i in range(len(_ground_spans)):
19 | if _ground_spans[i].endswith('.0'):
20 | _ground_spans[i] = _ground_spans[i][:-2]
21 |
22 | if _ground_spans[i].replace(',', '').isnumeric():
23 | _ground_spans[i] = _ground_spans[i].replace(',', '')
24 | _predict_values = defaultdict(lambda: 0)
25 | _ground_values = defaultdict(lambda: 0)
26 | for span in _predict_spans:
27 | try:
28 | _predict_values[float(span)] += 1
29 | except ValueError:
30 | _predict_values[span.strip()] += 1
31 | for span in _ground_spans:
32 | try:
33 | _ground_values[float(span)] += 1
34 | except ValueError:
35 | _ground_values[span.strip()] += 1
36 | _is_correct = _predict_values == _ground_values
37 | return _is_correct
38 |
39 |
40 | def evaluate(ori_path, inp_path, error_cases_output, write_flag):
41 | with open(ori_path, "r") as f:
42 | all_data = json.loads(f.read())
43 |     print("Loaded %d test examples in total" % len(all_data))
44 |
45 | pred_data = []
46 | with open(inp_path, "r") as f:
47 | lines = f.readlines()
48 | datas = [json.loads(line) for line in lines]
49 | pred_data.extend(datas)
50 |
51 | all_pred_data = {pred['question']: pred for pred in pred_data}
52 |     print("Loaded %d predictions in total" % len(pred_data))
53 | avg_deno_acc = []
54 | bad_cases = []
55 | error_count = 0
56 | max_count = 0
57 | right_count = 0
58 | for data in all_data:
59 | if data["question"] in all_pred_data:
60 | data = all_pred_data[data["question"]]
61 | pred = data['Prediction'].lower()
62 |
63 | if "answers: " in pred:
64 | pred = pred.split("answers: ")[1].strip()
65 | elif ":" in pred:
66 | pred = pred.split(":")[1].strip()
67 | else:
68 | pred = pred
69 |
70 | answers = data['answer_text']
71 | answers = [ans if not ans.endswith(".0") else ans.replace(".0", "") for ans in answers]
72 |
73 | if evaluate_example(pred, answers):
74 | avg_deno_acc.append(1)
75 | right_count += 1
76 | else:
77 | error_count += 1
78 | avg_deno_acc.append(0)
79 | # print("ID: %s Ques: %s" % (data["id"], question))
80 | # print("Pred: ", pred)
81 | # print("Ans: ", answers)
82 | # print("------------------------------------------------------------------------")
83 | bad_cases.append(data)
84 | else:
85 | avg_deno_acc.append(0)
86 | print("ID: %s can't be predicted" % (data["id"]))
87 | bad_cases.append(data)
88 | error_count += 1
89 | max_count += 1
90 |
91 | acc = np.mean(avg_deno_acc)
92 | print("Denotation Acc: %.4f" % (acc))
93 | if write_flag:
94 | with open(error_cases_output, "w") as f:
95 | for bc in bad_cases:
96 | f.write(json.dumps(bc) + "\n")
97 |         print("In total, %d bad cases still need to be resolved." % len(bad_cases))
98 |     print("Right count: %d, Error count: %d (Max len count: %d)" % (right_count, error_count, max_count))
99 |
100 |
101 | if __name__ == "__main__":
102 | parser = argparse.ArgumentParser()
103 | parser.add_argument('--ori_path', type=str, default="./data/wikisql/wikisql_test.json")
104 | parser.add_argument('--inp_path', type=str, default="./outputs/wikisql/output_wo_icl_v1.jsonl")
105 | parser.add_argument('--error_cases_output', type=str,
106 | default='./outputs/wikisql/bad_cases.jsonl')
107 | parser.add_argument('--write_flag', action="store_true")
108 | args = parser.parse_args()
109 | evaluate(args.ori_path, args.inp_path, args.error_cases_output, args.write_flag)
110 |
--------------------------------------------------------------------------------
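
The denotation check above, also reused by the WTQ run script, splits the prediction on ', ', strips surrounding quotes, trailing '.0' and thousands separators, and then compares prediction and gold as multisets with numeric coercion. A minimal sketch, assuming the repository root is on sys.path and using made-up values:

    from evaluate_for_wikisql import evaluate_example

    assert evaluate_example("1,000", ["1000.0"])                     # both sides normalize to 1000
    assert evaluate_example("Paris, London", ["london", "paris"])    # order-insensitive multiset match
    assert not evaluate_example("Paris", ["Paris", "London"])        # missing span, so no match
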
/process_sql.py:
--------------------------------------------------------------------------------
1 | ################################
2 | # Assumptions:
3 | # 1. sql is correct
4 | # 2. only table name has alias
5 | # 3. only one intersect/union/except
6 | #
7 | # val: number(float)/string(str)/sql(dict)
8 | # col_unit: (agg_id, col_id, isDistinct(bool))
9 | # val_unit: (unit_op, col_unit1, col_unit2)
10 | # table_unit: (table_type, col_unit/sql)
11 | # cond_unit: (not_op, op_id, val_unit, val1, val2)
12 | # condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
13 | # sql {
14 | # 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
15 | # 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
16 | # 'where': condition
17 | # 'groupBy': [col_unit1, col_unit2, ...]
18 | # 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
19 | # 'having': condition
20 | # 'limit': None/limit value
21 | # 'intersect': None/sql
22 | # 'except': None/sql
23 | # 'union': None/sql
24 | # }
25 | ################################
26 |
27 | import json
28 | import sqlite3
29 | from nltk import word_tokenize
30 |
31 | CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
32 | JOIN_KEYWORDS = ('join', 'on', 'as')
33 |
34 | WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
35 | UNIT_OPS = ('none', '-', '+', "*", '/')
36 | AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
37 | TABLE_TYPE = {
38 | 'sql': "sql",
39 | 'table_unit': "table_unit",
40 | }
41 |
42 | COND_OPS = ('and', 'or')
43 | SQL_OPS = ('intersect', 'union', 'except')
44 | ORDER_OPS = ('desc', 'asc')
45 |
46 |
47 |
48 | class Schema:
49 | """
50 | Simple schema which maps table&column to a unique identifier
51 | """
52 | def __init__(self, schema):
53 | self._schema = schema
54 | self._idMap = self._map(self._schema)
55 |
56 | @property
57 | def schema(self):
58 | return self._schema
59 |
60 | @property
61 | def idMap(self):
62 | return self._idMap
63 |
64 | def _map(self, schema):
65 | idMap = {'*': "__all__"}
66 | id = 1
67 | for key, vals in schema.items():
68 | for val in vals:
69 | idMap[key.lower() + "." + val.lower()] = "__" + key.lower() + "." + val.lower() + "__"
70 | id += 1
71 |
72 | for key in schema:
73 | idMap[key.lower()] = "__" + key.lower() + "__"
74 | id += 1
75 |
76 | return idMap
77 |
78 |
79 | def get_schema(db):
80 | """
81 | Get database's schema, which is a dict with table name as key
82 | and list of column names as value
83 | :param db: database path
84 | :return: schema dict
85 | """
86 |
87 | schema = {}
88 | conn = sqlite3.connect(db)
89 | cursor = conn.cursor()
90 |
91 | # fetch table names
92 | cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
93 | tables = [str(table[0].lower()) for table in cursor.fetchall()]
94 |
95 | # fetch table info
96 | for table in tables:
97 | cursor.execute("PRAGMA table_info({})".format(table))
98 | schema[table] = [str(col[1].lower()) for col in cursor.fetchall()]
99 | # import ipdb; ipdb.set_trace()
100 | return schema
101 |
102 |
103 | def get_schema_from_json(fpath):
104 | with open(fpath) as f:
105 | data = json.load(f)
106 |
107 | schema = {}
108 | for entry in data:
109 | table = str(entry['table'].lower())
110 | cols = [str(col['column_name'].lower()) for col in entry['col_data']]
111 | schema[table] = cols
112 |
113 | return schema
114 |
115 |
116 | def tokenize(string):
117 | string = str(string)
118 |     string = string.replace("\'", "\"") # ensure all string values are wrapped in double quotes
119 | quote_idxs = [idx for idx, char in enumerate(string) if char == '"']
120 | assert len(quote_idxs) % 2 == 0, "Unexpected quote"
121 |
122 | # keep string value as token
123 | vals = {}
124 | for i in range(len(quote_idxs)-1, -1, -2):
125 | qidx1 = quote_idxs[i-1]
126 | qidx2 = quote_idxs[i]
127 | val = string[qidx1: qidx2+1]
128 | key = "__val_{}_{}__".format(qidx1, qidx2)
129 | string = string[:qidx1] + key + string[qidx2+1:]
130 | vals[key] = val
131 |
132 | toks = [word.lower() for word in word_tokenize(string)]
133 | # replace with string value token
134 | for i in range(len(toks)):
135 | if toks[i] in vals:
136 | toks[i] = vals[toks[i]]
137 |
138 | # find if there exists !=, >=, <=
139 | eq_idxs = [idx for idx, tok in enumerate(toks) if tok == "="]
140 | eq_idxs.reverse()
141 | prefix = ('!', '>', '<')
142 | for eq_idx in eq_idxs:
143 | pre_tok = toks[eq_idx-1]
144 | if pre_tok in prefix:
145 | toks = toks[:eq_idx-1] + [pre_tok + "="] + toks[eq_idx+1: ]
146 |
147 | return toks
148 |
149 |
150 | def scan_alias(toks):
151 | """Scan the index of 'as' and build the map for all alias"""
152 | as_idxs = [idx for idx, tok in enumerate(toks) if tok == 'as']
153 | alias = {}
154 | for idx in as_idxs:
155 | alias[toks[idx+1]] = toks[idx-1]
156 | return alias
157 |
158 |
159 | def get_tables_with_alias(schema, toks):
160 | tables = scan_alias(toks)
161 | for key in schema:
162 |         assert key not in tables, "Alias {} has the same name as a table".format(key)
163 | tables[key] = key
164 | return tables
165 |
166 |
167 | def parse_col(toks, start_idx, tables_with_alias, schema, default_tables=None):
168 | """
169 | :returns next idx, column id
170 | """
171 | tok = toks[start_idx]
172 | if tok == "*":
173 | return start_idx + 1, schema.idMap[tok]
174 |
175 | if '.' in tok: # if token is a composite
176 | alias, col = tok.split('.')
177 | key = tables_with_alias[alias] + "." + col
178 | return start_idx+1, schema.idMap[key]
179 |
180 | assert default_tables is not None and len(default_tables) > 0, "Default tables should not be None or empty"
181 |
182 | for alias in default_tables:
183 | table = tables_with_alias[alias]
184 | if tok in schema.schema[table]:
185 | key = table + "." + tok
186 | return start_idx+1, schema.idMap[key]
187 |
188 | assert False, "Error col: {}".format(tok)
189 |
190 |
191 | def parse_col_unit(toks, start_idx, tables_with_alias, schema, default_tables=None):
192 | """
193 | :returns next idx, (agg_op id, col_id)
194 | """
195 | idx = start_idx
196 | len_ = len(toks)
197 | isBlock = False
198 | isDistinct = False
199 | if toks[idx] == '(':
200 | isBlock = True
201 | idx += 1
202 |
203 | if toks[idx] in AGG_OPS:
204 | agg_id = AGG_OPS.index(toks[idx])
205 | idx += 1
206 | assert idx < len_ and toks[idx] == '('
207 | idx += 1
208 | if toks[idx] == "distinct":
209 | idx += 1
210 | isDistinct = True
211 | idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables)
212 | assert idx < len_ and toks[idx] == ')'
213 | idx += 1
214 | return idx, (agg_id, col_id, isDistinct)
215 |
216 | if toks[idx] == "distinct":
217 | idx += 1
218 | isDistinct = True
219 | agg_id = AGG_OPS.index("none")
220 | idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables)
221 |
222 | if isBlock:
223 | assert toks[idx] == ')'
224 | idx += 1 # skip ')'
225 |
226 | return idx, (agg_id, col_id, isDistinct)
227 |
228 |
229 | def parse_val_unit(toks, start_idx, tables_with_alias, schema, default_tables=None):
230 | idx = start_idx
231 | len_ = len(toks)
232 | isBlock = False
233 | if toks[idx] == '(':
234 | isBlock = True
235 | idx += 1
236 |
237 | col_unit1 = None
238 | col_unit2 = None
239 | unit_op = UNIT_OPS.index('none')
240 |
241 | idx, col_unit1 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
242 | if idx < len_ and toks[idx] in UNIT_OPS:
243 | unit_op = UNIT_OPS.index(toks[idx])
244 | idx += 1
245 | idx, col_unit2 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
246 |
247 | if isBlock:
248 | assert toks[idx] == ')'
249 | idx += 1 # skip ')'
250 |
251 | return idx, (unit_op, col_unit1, col_unit2)
252 |
253 |
254 | def parse_table_unit(toks, start_idx, tables_with_alias, schema):
255 | """
256 | :returns next idx, table id, table name
257 | """
258 | idx = start_idx
259 | len_ = len(toks)
260 | key = tables_with_alias[toks[idx]]
261 |
262 | if idx + 1 < len_ and toks[idx+1] == "as":
263 | idx += 3
264 | else:
265 | idx += 1
266 |
267 | return idx, schema.idMap[key], key
268 |
269 |
270 | def parse_value(toks, start_idx, tables_with_alias, schema, default_tables=None):
271 | idx = start_idx
272 | len_ = len(toks)
273 |
274 | isBlock = False
275 | if toks[idx] == '(':
276 | isBlock = True
277 | idx += 1
278 |
279 | if toks[idx] == 'select':
280 | idx, val = parse_sql(toks, idx, tables_with_alias, schema)
281 | elif "\"" in toks[idx]: # token is a string value
282 | val = toks[idx]
283 | idx += 1
284 | else:
285 | try:
286 | val = float(toks[idx])
287 | idx += 1
288 | except:
289 | end_idx = idx
290 | while end_idx < len_ and toks[end_idx] != ',' and toks[end_idx] != ')'\
291 | and toks[end_idx] != 'and' and toks[end_idx] not in CLAUSE_KEYWORDS and toks[end_idx] not in JOIN_KEYWORDS:
292 | end_idx += 1
293 |
294 | idx, val = parse_col_unit(toks[start_idx: end_idx], 0, tables_with_alias, schema, default_tables)
295 | idx = end_idx
296 |
297 | if isBlock:
298 | assert toks[idx] == ')'
299 | idx += 1
300 |
301 | return idx, val
302 |
303 |
304 | def parse_condition(toks, start_idx, tables_with_alias, schema, default_tables=None):
305 | idx = start_idx
306 | len_ = len(toks)
307 | conds = []
308 |
309 | while idx < len_:
310 | idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
311 | not_op = False
312 | if toks[idx] == 'not':
313 | not_op = True
314 | idx += 1
315 |
316 | assert idx < len_ and toks[idx] in WHERE_OPS, "Error condition: idx: {}, tok: {}".format(idx, toks[idx])
317 | op_id = WHERE_OPS.index(toks[idx])
318 | idx += 1
319 | val1 = val2 = None
320 | if op_id == WHERE_OPS.index('between'): # between..and... special case: dual values
321 | idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
322 | assert toks[idx] == 'and'
323 | idx += 1
324 | idx, val2 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
325 | else: # normal case: single value
326 | idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
327 | val2 = None
328 |
329 | conds.append((not_op, op_id, val_unit, val1, val2))
330 |
331 | if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";") or toks[idx] in JOIN_KEYWORDS):
332 | break
333 |
334 | if idx < len_ and toks[idx] in COND_OPS:
335 | conds.append(toks[idx])
336 | idx += 1 # skip and/or
337 |
338 | return idx, conds
339 |
340 |
341 | def parse_select(toks, start_idx, tables_with_alias, schema, default_tables=None):
342 | idx = start_idx
343 | len_ = len(toks)
344 |
345 | assert toks[idx] == 'select', "'select' not found"
346 | idx += 1
347 | isDistinct = False
348 | if idx < len_ and toks[idx] == 'distinct':
349 | idx += 1
350 | isDistinct = True
351 | val_units = []
352 |
353 | while idx < len_ and toks[idx] not in CLAUSE_KEYWORDS:
354 | agg_id = AGG_OPS.index("none")
355 | if toks[idx] in AGG_OPS:
356 | agg_id = AGG_OPS.index(toks[idx])
357 | idx += 1
358 | idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
359 | val_units.append((agg_id, val_unit))
360 | if idx < len_ and toks[idx] == ',':
361 | idx += 1 # skip ','
362 |
363 | return idx, (isDistinct, val_units)
364 |
365 |
366 | def parse_from(toks, start_idx, tables_with_alias, schema):
367 | """
368 | Assume in the from clause, all table units are combined with join
369 | """
370 | assert 'from' in toks[start_idx:], "'from' not found"
371 |
372 | len_ = len(toks)
373 | idx = toks.index('from', start_idx) + 1
374 | default_tables = []
375 | table_units = []
376 | conds = []
377 |
378 | while idx < len_:
379 | isBlock = False
380 | if toks[idx] == '(':
381 | isBlock = True
382 | idx += 1
383 |
384 | if toks[idx] == 'select':
385 | idx, sql = parse_sql(toks, idx, tables_with_alias, schema)
386 | table_units.append((TABLE_TYPE['sql'], sql))
387 | else:
388 | if idx < len_ and toks[idx] == 'join':
389 | idx += 1 # skip join
390 | idx, table_unit, table_name = parse_table_unit(toks, idx, tables_with_alias, schema)
391 | table_units.append((TABLE_TYPE['table_unit'],table_unit))
392 | default_tables.append(table_name)
393 | if idx < len_ and toks[idx] == "on":
394 | idx += 1 # skip on
395 | idx, this_conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
396 | if len(conds) > 0:
397 | conds.append('and')
398 | conds.extend(this_conds)
399 |
400 | if isBlock:
401 | assert toks[idx] == ')'
402 | idx += 1
403 | if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
404 | break
405 |
406 | return idx, table_units, conds, default_tables
407 |
408 |
409 | def parse_where(toks, start_idx, tables_with_alias, schema, default_tables):
410 | idx = start_idx
411 | len_ = len(toks)
412 |
413 | if idx >= len_ or toks[idx] != 'where':
414 | return idx, []
415 |
416 | idx += 1
417 | idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
418 | return idx, conds
419 |
420 |
421 | def parse_group_by(toks, start_idx, tables_with_alias, schema, default_tables):
422 | idx = start_idx
423 | len_ = len(toks)
424 | col_units = []
425 |
426 | if idx >= len_ or toks[idx] != 'group':
427 | return idx, col_units
428 |
429 | idx += 1
430 | assert toks[idx] == 'by'
431 | idx += 1
432 |
433 | while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
434 | idx, col_unit = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
435 | col_units.append(col_unit)
436 | if idx < len_ and toks[idx] == ',':
437 | idx += 1 # skip ','
438 | else:
439 | break
440 |
441 | return idx, col_units
442 |
443 |
444 | def parse_order_by(toks, start_idx, tables_with_alias, schema, default_tables):
445 | idx = start_idx
446 | len_ = len(toks)
447 | val_units = []
448 | order_type = 'asc' # default type is 'asc'
449 |
450 | if idx >= len_ or toks[idx] != 'order':
451 | return idx, val_units
452 |
453 | idx += 1
454 | assert toks[idx] == 'by'
455 | idx += 1
456 |
457 | while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
458 | idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
459 | val_units.append(val_unit)
460 | if idx < len_ and toks[idx] in ORDER_OPS:
461 | order_type = toks[idx]
462 | idx += 1
463 | if idx < len_ and toks[idx] == ',':
464 | idx += 1 # skip ','
465 | else:
466 | break
467 |
468 | return idx, (order_type, val_units)
469 |
470 |
471 | def parse_having(toks, start_idx, tables_with_alias, schema, default_tables):
472 | idx = start_idx
473 | len_ = len(toks)
474 |
475 | if idx >= len_ or toks[idx] != 'having':
476 | return idx, []
477 |
478 | idx += 1
479 | idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
480 | return idx, conds
481 |
482 |
483 | def parse_limit(toks, start_idx):
484 | idx = start_idx
485 | len_ = len(toks)
486 |
487 | if idx < len_ and toks[idx] == 'limit':
488 | idx += 2
489 | return idx, int(toks[idx-1])
490 |
491 | return idx, None
492 |
493 |
494 | def parse_sql(toks, start_idx, tables_with_alias, schema):
495 | isBlock = False # indicate whether this is a block of sql/sub-sql
496 | len_ = len(toks)
497 | idx = start_idx
498 |
499 | sql = {}
500 | if toks[idx] == '(':
501 | isBlock = True
502 | idx += 1
503 |
504 | # parse from clause in order to get default tables
505 | from_end_idx, table_units, conds, default_tables = parse_from(toks, start_idx, tables_with_alias, schema)
506 | sql['from'] = {'table_units': table_units, 'conds': conds}
507 | # select clause
508 | _, select_col_units = parse_select(toks, idx, tables_with_alias, schema, default_tables)
509 | idx = from_end_idx
510 | sql['select'] = select_col_units
511 | # where clause
512 | idx, where_conds = parse_where(toks, idx, tables_with_alias, schema, default_tables)
513 | sql['where'] = where_conds
514 | # group by clause
515 | idx, group_col_units = parse_group_by(toks, idx, tables_with_alias, schema, default_tables)
516 | sql['groupBy'] = group_col_units
517 | # having clause
518 | idx, having_conds = parse_having(toks, idx, tables_with_alias, schema, default_tables)
519 | sql['having'] = having_conds
520 | # order by clause
521 | idx, order_col_units = parse_order_by(toks, idx, tables_with_alias, schema, default_tables)
522 | sql['orderBy'] = order_col_units
523 | # limit clause
524 | idx, limit_val = parse_limit(toks, idx)
525 | sql['limit'] = limit_val
526 |
527 | idx = skip_semicolon(toks, idx)
528 | if isBlock:
529 | assert toks[idx] == ')'
530 | idx += 1 # skip ')'
531 | idx = skip_semicolon(toks, idx)
532 |
533 | # intersect/union/except clause
534 | for op in SQL_OPS: # initialize IUE
535 | sql[op] = None
536 | if idx < len_ and toks[idx] in SQL_OPS:
537 | sql_op = toks[idx]
538 | idx += 1
539 | idx, IUE_sql = parse_sql(toks, idx, tables_with_alias, schema)
540 | sql[sql_op] = IUE_sql
541 | return idx, sql
542 |
543 |
544 | def load_data(fpath):
545 | with open(fpath) as f:
546 | data = json.load(f)
547 | return data
548 |
549 |
550 | def get_sql(schema, query):
551 | toks = tokenize(query)
552 | tables_with_alias = get_tables_with_alias(schema.schema, toks)
553 | _, sql = parse_sql(toks, 0, tables_with_alias, schema)
554 |
555 | return sql
556 |
557 |
558 | def skip_semicolon(toks, start_idx):
559 | idx = start_idx
560 | while idx < len(toks) and toks[idx] == ";":
561 | idx += 1
562 | return idx
563 |
--------------------------------------------------------------------------------
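
process_sql.py is typically driven by building a Schema from a SQLite database and calling get_sql, which returns the nested dict documented in the header comment of the module. A minimal sketch; the database path is hypothetical, and any Spider-style database containing a singer table would behave the same way:

    from process_sql import Schema, get_schema, get_sql

    db_path = "data/spider/database/concert_singer/concert_singer.sqlite"  # hypothetical path
    schema = Schema(get_schema(db_path))
    sql = get_sql(schema, "SELECT count(*) FROM singer")

    # sql['select'] has the form (isDistinct, [(agg_id, val_unit), ...]); here it is
    # (False, [(3, (0, (0, '__all__', False), None))]), i.e. count(*) without DISTINCT.
    print(sql['select'])
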
/prompts/prompt_for_spider.json:
--------------------------------------------------------------------------------
1 | {
2 | "chat_v1": {
3 | "system": "You are a helpful assistant.",
4 | "free_generate": "### Here are the Sqlite SQL tables, with their properties:\n#\n{table}\n#\n### {question} Which tables do you need to complete the SQLite SQL query? Let's think step by step.",
5 | "table_column_select_reorganize": "Sure! Based on your response, provide only the required tables in the following format: 'Table_Name_1 | Table_Name_2' to simplify your response, which means using '|' to separate different tables. Don't use any other format or add any additional explanations.",
6 | "ask_final_answers": {
7 | "has_fk": "### Complete sqlite SQL query only and with no explanation.\n#\n### Sqlite SQL tables, with their properties: \n#\n{table}\n# {fk}\n#\n### {question}\n SELECT",
8 | "no_fk": "### Complete sqlite SQL query only and with no explanation.\n#\n### Sqlite SQL tables, with their properties: \n#\n{table}\n#\n### {question}\n SELECT"
9 | }
10 | }
11 | }
--------------------------------------------------------------------------------
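
The chat_v1 entries in prompt_for_spider.json are plain str.format templates: {table} receives the serialized schema, {question} the natural-language question, and {fk} a foreign-key description when one exists. A minimal sketch of filling them; the table and question strings are made up, and the real serialization happens inside structgpt_for_text_to_sql.py:

    import json

    with open("prompts/prompt_for_spider.json") as f:
        prompt = json.load(f)["chat_v1"]

    table = "# singer(singer_id, name, country, age)\n# concert(concert_id, concert_name, year)"
    question = "How many singers do we have?"

    first_turn = prompt["free_generate"].format(table=table, question=question)
    final_turn = prompt["ask_final_answers"]["no_fk"].format(table=table, question=question)
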
/prompts/prompt_for_tabfact.json:
--------------------------------------------------------------------------------
1 | {
2 | "chat_v1": {
3 | "system": "You are a helpful assistant.",
4 | "columns_select": "You need to answer a question using a table with multiple rows and column headings. Specifically, you need to select the relevant columns and rows from the table and then obtain a sub-table that is relevant to the question. Then, use the sub-table to answer the question. \nTherefore, to answer \"{question}\", first look at the available columns in the table: {columns}. Which columns are most relevant to answering the question? Your output format is only the “Columns: ColumnName1, ColumnName2, ColumnName3...” form, no other form, with no explanation.",
5 | "rows_select": "You need to answer a question using a table with multiple rows and column headings. Specifically, you need to select the relevant columns and rows from the table and then obtain a sub-table that is relevant to the question. Then, use the sub-table to answer the question. \nTherefore, to answer \"{question}\", once you have selected the relevant columns {selected_columns}, your next step is to identify the necessary rows to answer the question based on their column values. Below is the list of rows in the table, arranged in order. Each row is represented by one line, starting with the row name followed by its corresponding column name and value pairs:\n{rows}\nTo answer \"{question}\", which rows should be considered? Your output format is only the “Rows: RowName1, RowName2, RowName3...” form, no other form, with no explanation. Your response should only contain the row names from the above candidates such as item 1, item 2.",
6 |     "ask_final_answer_or_next_question": "The ordered list below shows the rows of a table, with each line displaying a different row along with its corresponding column name and value pairs. The format for each pair is (column name, value). The table contains:\n{table}\nAccording to the table, is the statement \"{question}\" true? If you think the statement is true, only output \"Yes.\". If you think the statement is not true, only output \"No.\"."
7 | }
8 | }
--------------------------------------------------------------------------------
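
The rows_select and ask_final_answer_or_next_question templates above, like the near-identical ones in prompt_for_wikisql.json and prompt_for_wtq.json, expect the table serialized one row per line as "item k: (column, value), ...". A minimal, hypothetical serializer consistent with that wording; the actual serialization lives in structgpt_for_tableqa.py:

    def serialize_rows(header, rows):
        lines = []
        for k, row in enumerate(rows, start=1):
            pairs = ", ".join("({}, {})".format(col, val) for col, val in zip(header, row))
            lines.append("item {}: {}".format(k, pairs))
        return "\n".join(lines)

    print(serialize_rows(["player", "score"], [["alice", 10], ["bob", 7]]))
    # item 1: (player, alice), (score, 10)
    # item 2: (player, bob), (score, 7)
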
/prompts/prompt_for_webqsp.json:
--------------------------------------------------------------------------------
1 | {
2 | "chat_v1": {
3 | "init_relation_rerank": "The candidate relations: {relations}.\nThe question is \"{question}\" and you'll start with \"{tpe}\". To answer this question, typically you would need to identify some relations that correspond to the meaning of the question. Therefore, select one relation from the candidate relations above that can be used to answer the question. Provide only one relevant relation that's present in the candidates, and begin your response with \"The relevant relation: \".",
4 |     "constraints_flag": "The question is \"{question}\" and you'll start with \"{tpe}\". To answer this question, typically you would need to identify some relations that correspond to the meaning of the question. The already selected relevant relation {selected_relations}, and there are many candidate entities along these relations for the next step. If you think you can narrow down the current candidate entities using hints in the question, respond \"Yes\". If there are no hints in the question to narrow down the current candidate entities, respond \"No\".",
5 | "choose_constraints": "One constraint contains a relation and an entity. The list below shows the candidate relations and their possible entities. Each line contains one candidate relation and its corresponding entities:\n{relation_tails}\nThe question is \"{question}\" and you'll start with \"{tpe}\". Are there any constraint hints in the question that can narrow down the candidate entities?\nIf you think there are hints in the question that can narrow down the candidate entities, provide all possible constraints by combining the candidate relation and entity only using the above candidates. For each constraint, you should first choose one relation only from the above candidates, and then select one entity only from the corresponding entity candidates list of that relation. Do not modify the surface form of the chosen relation and entity, which means that they should be exactly consistent with the above provided. Use the format \"[relation: entity]\" and begin your response with \"The possible constraints: \".\nIf you think there are not any hints, directly respond \"No\".\"",
6 | "ask_final_answer_or_next_question": "The triples are: {facts}.\nBased on these triples, if you believe you have gathered sufficient information to answer \"{question}\", give me the final answer entity and start your response with \"The final answers:\". You just need to provide only one answer entity. If you think you still do not have enough information to answer the question, respond \"Need further information\".",
7 | "relation_rerank": "The candidate relations: {relations}.\nThe question is \"{question}\" and you'll start with \"{tpe}\". To answer this question, typically you would need to identify some relations that correspond to the meaning of the question. The already selected relevant relation {selected_relations}, then select the next relation from the candidate relations above that can be used to answer the question. Provide only one relevant relation that's present in the candidates, and begin your response with \"The relevant relation: \".",
8 | "final_query_template": "According to existing information, {question} Please start your response with \"The final answers:\". You just need to provide only one answer entity."
9 | }
10 | }
--------------------------------------------------------------------------------
/prompts/prompt_for_wikisql.json:
--------------------------------------------------------------------------------
1 | {
2 | "chat_v1": {
3 | "system": "You are a helpful assistant.",
4 | "columns_select": "You need to answer a question using a table with multiple rows and column headings. Specifically, you need to select the relevant columns and rows from the table and then obtain a sub-table that is relevant to the question. Then, use the sub-table to answer the question. \nTherefore, to answer \"{question}\", first look at the available columns in the table: {columns}. Which columns are most relevant to answering the question? Your output format is only the “Columns: ColumnName1, ColumnName2, ColumnName3...” form, no other form, with no explanation.",
5 | "rows_select": "You need to answer a question using a table with multiple rows and column headings. Specifically, you need to select the relevant columns and rows from the table and then obtain a sub-table that is relevant to the question. Then, use the sub-table to answer the question. \nTherefore, to answer \"{question}\", once you have selected the relevant columns {selected_columns}, your next step is to identify the necessary rows to answer the question based on their column values. Below is the list of rows in the table, arranged in order. Each row is represented by one line, starting with the row name followed by its corresponding column name and value pairs:\n{rows}\nTo answer \"{question}\", which rows should be considered? Your output format is only the “Rows: RowName1, RowName2, RowName3...” form, no other form, with no explanation. Your response should only contain the row names from the above candidates such as item 1, item 2.",
6 | "ask_final_answer_or_next_question": "The ordered list below shows the rows of a table, with each line displaying a different row along with its corresponding column name and value pairs. The format for each pair is (column name, value). The table contains:\n{table}\nUsing this information, {question} Your output format is only “Answers: AnswerName1, AnswerName2...” form, no other form. And the output should be the number or entity names, as short as possible, without any explanation."
7 | }
8 | }
--------------------------------------------------------------------------------
/prompts/prompt_for_wtq.json:
--------------------------------------------------------------------------------
1 | {
2 | "chat_v1": {
3 | "system": "You are a helpful assistant.",
4 | "columns_select": "You need to answer a question using a table with multiple rows and column headings. Specifically, you need to select the relevant columns and rows from the table and then obtain a sub-table that is relevant to the question. Then, use the sub-table to answer the question. \nTherefore, to answer \"{question}\", first look at the available columns in the table: {columns}. Which columns are most relevant to answering the question? Your output format is only the “Columns: ColumnName1, ColumnName2, ColumnName3...” form, no other form, with no explanation.",
5 | "rows_select": "You need to answer a question using a table with multiple rows and column headings. Specifically, you need to select the relevant columns and rows from the table and then obtain a sub-table that is relevant to the question. Then, use the sub-table to answer the question. \nTherefore, to answer \"{question}\", once you have selected the relevant columns {selected_columns}, your next step is to identify the necessary rows to answer the question based on their column values. Below is the list of rows in the table, arranged in order. Each row is represented by one line, starting with the row name followed by its corresponding column name and value pairs:\n{rows}\nTo answer \"{question}\", which rows should be considered? Your output format is only the “Rows: RowName1, RowName2, RowName3...” form, no other form, with no explanation. Your response should only contain the row names from the above candidates such as item 1, item 2.",
6 | "ask_final_answer_or_next_question": "The ordered list below shows the rows of a table, with each line displaying a different row along with its corresponding column name and value pairs. The format for each pair is (column name, value). The table contains:\n{table}\nUsing this information, {question} Your output format is only “Answers: AnswerName1, AnswerName2...” form, no other form. And the output should be the number or entity names, as short as possible, without any explanation."
7 | }
8 | }
--------------------------------------------------------------------------------
/scripts/eval_for_tabfact.sh:
--------------------------------------------------------------------------------
1 | python evaluate_for_tabfact.py --ori_path ./data/tabfact/tab_fact_test.json --inp_path ./outputs/tabfact/tabfact_test_output.jsonl
--------------------------------------------------------------------------------
/scripts/eval_spider_pred.sh:
--------------------------------------------------------------------------------
1 | python evaluate_for_spider.py --path ./outputs/spider/output_wo_icl_v1.jsonl --db=data/spider/database --table=data/spider/tables.json --etype=exec
--------------------------------------------------------------------------------
/scripts/run_spider_rea_wo_icl_v1.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | python3 structgpt_for_text_to_sql.py \
4 | --api_key ./api_key.txt --num_process 21 \
5 | --prompt_path ./prompts/prompt_for_spider.json --prompt_name chat_v1 \
6 | --input_path ./data/spider-realistic/spider-realistic.json \
7 | --output_path ./outputs/spider-realistic/output_wo_icl_v1.jsonl \
8 | --chat_log_path ./outputs/spider-realistic/chat_wo_icl_v1.txt \
9 | --schema_path ./data/spider-realistic/tables.json
10 |
11 | # single process usage
12 | #python3 structgpt_for_text_to_sql.py \
13 | #--api_key sk-?? --num_process 1 \
14 | #--prompt_path ./prompts/prompt_for_spider.json --prompt_name chat_v1 \
15 | #--input_path ./data/spider-realistic/spider-realistic.json \
16 | #--output_path ./outputs/spider-realistic/output_wo_icl_v1.jsonl \
17 | #--chat_log_path ./outputs/spider-realistic/chat_wo_icl_v1.txt \
18 | #--schema_path ./data/spider-realistic/tables.json
19 |
20 | cat ./outputs/spider-realistic/output_wo_icl_v1.jsonl_* > ./outputs/spider-realistic/output_wo_icl_v1.jsonl
21 | rm ./outputs/spider-realistic/output_wo_icl_v1.jsonl_*
22 | cat ./outputs/spider-realistic/chat_wo_icl_v1.txt_* > ./outputs/spider-realistic/chat_wo_icl_v1.txt
23 | rm ./outputs/spider-realistic/chat_wo_icl_v1.txt_*
24 |
25 | python evaluate_for_spider.py --path ./outputs/spider-realistic/output_wo_icl_v1.jsonl --db data/spider/database --table data/spider-realistic/tables.json --etype exec
--------------------------------------------------------------------------------
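
The run_*.sh scripts launch --num_process workers in parallel; each worker writes its own shard (the files matched by the trailing _* globs), and the closing cat/rm lines merge those shards before evaluation. A minimal Python equivalent of that merge step, using the same output path as the script above:

    import glob
    import os

    out = "./outputs/spider-realistic/output_wo_icl_v1.jsonl"
    with open(out, "w") as merged:
        for shard in sorted(glob.glob(out + "_*")):
            with open(shard) as f:
                merged.write(f.read())
            os.remove(shard)
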
/scripts/run_spider_syn_wo_icl_v1.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | python3 structgpt_for_text_to_sql.py \
4 | --api_key ./api_key.txt --num_process 21 \
5 | --prompt_path ./prompts/prompt_for_spider.json --prompt_name chat_v1 \
6 | --input_path ./data/spider-syn/dev.json \
7 | --output_path ./outputs/spider-syn/output_wo_icl_v1.jsonl \
8 | --chat_log_path ./outputs/spider-syn/chat_wo_icl_v1.txt \
9 | --schema_path ./data/spider-syn/tables.json
10 |
11 | # single process usage
12 | #python3 structgpt_for_text_to_sql.py \
13 | #--api_key sk-?? --num_process 1 \
14 | #--prompt_path ./prompts/prompt_for_spider.json --prompt_name chat_v1 \
15 | #--input_path ./data/spider-syn/spider-syn.json \
16 | #--output_path ./outputs/spider-syn/output_wo_icl_v1.jsonl \
17 | #--chat_log_path ./outputs/spider-syn/chat_wo_icl_v1.txt \
18 | #--schema_path ./data/spider-syn/tables.json
19 |
20 | cat ./outputs/spider-syn/output_wo_icl_v1.jsonl_* > ./outputs/spider-syn/output_wo_icl_v1.jsonl
21 | rm ./outputs/spider-syn/output_wo_icl_v1.jsonl_*
22 | cat ./outputs/spider-syn/chat_wo_icl_v1.txt_* > ./outputs/spider-syn/chat_wo_icl_v1.txt
23 | rm ./outputs/spider-syn/chat_wo_icl_v1.txt_*
24 |
25 | python evaluate_for_spider.py --path ./outputs/spider-syn/output_wo_icl_v1.jsonl --db data/spider/database --table data/spider-syn/tables.json --etype exec
--------------------------------------------------------------------------------
/scripts/run_spider_wo_icl_v1.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | python3 structgpt_for_text_to_sql.py \
4 | --api_key ./api_key.txt --num_process 21 \
5 | --prompt_path ./prompts/prompt_for_spider.json --prompt_name chat_v1 \
6 | --input_path ./data/spider/dev.jsonl \
7 | --output_path ./outputs/spider/output_wo_icl_v1.jsonl \
8 | --chat_log_path ./outputs/spider/chat_wo_icl_v1.txt \
9 | --db_path ./data/spider/all_tables_content.json \
10 | --schema_path ./data/spider/tables.json
11 |
12 | # single process usage
13 | #python3 structgpt_for_text_to_sql.py \
14 | #--api_key sk-?? --num_process 1 \
15 | #--prompt_path ./prompts/prompt_for_spider.json --prompt_name chat_v1 \
16 | #--input_path ./data/spider/dev.jsonl \
17 | #--output_path ./outputs/spider/output_wo_icl_v1.jsonl \
18 | #--chat_log_path ./outputs/spider/chat_wo_icl_v1.txt \
19 | #--schema_path ./data/spider/tables.json
20 |
21 | cat ./outputs/spider/output_wo_icl_v1.jsonl_* > ./outputs/spider/output_wo_icl_v1.jsonl
22 | rm ./outputs/spider/output_wo_icl_v1.jsonl_*
23 | cat ./outputs/spider/chat_wo_icl_v1.txt_* > ./outputs/spider/chat_wo_icl_v1.txt
24 | rm ./outputs/spider/chat_wo_icl_v1.txt_*
25 |
26 | python evaluate_for_spider.py --path ./outputs/spider/output_wo_icl_v1.jsonl --db=data/spider/database --table=data/spider/tables.json --etype=exec
--------------------------------------------------------------------------------
/scripts/run_tabfact_wo_icl_v1.sh:
--------------------------------------------------------------------------------
1 | python3 structgpt_for_tableqa.py \
2 | --api_key ./api_key.txt --num_process 37 \
3 | --prompt_path ./prompts/prompt_for_tabfact.json --prompt_name chat_v1 \
4 | --input_path ./data/tabfact/tab_fact_test.json \
5 | --output_path ./outputs/tabfact/output_wo_icl_v1.jsonl \
6 | --chat_log_path ./outputs/tabfact/chat_wo_icl_v1.txt --max_tokens 350
7 |
8 | cat ./outputs/tabfact/output_wo_icl_v1.jsonl_* > ./outputs/tabfact/output_wo_icl_v1.jsonl
9 | rm ./outputs/tabfact/output_wo_icl_v1.jsonl_*
10 | cat ./outputs/tabfact/chat_wo_icl_v1.txt_* > ./outputs/tabfact/chat_wo_icl_v1.txt
11 | rm ./outputs/tabfact/chat_wo_icl_v1.txt_*
12 |
13 | python evaluate_for_tabfact.py --ori_path ./data/tabfact/tab_fact_test.json --inp_path ./outputs/tabfact/output_wo_icl_v1.jsonl
--------------------------------------------------------------------------------
/scripts/run_webqsp_wo_icl_v1.sh:
--------------------------------------------------------------------------------
1 | python3 structgpt_for_webqsp.py \
2 | --api_key ./api_key_for_kg.txt --num_process 14 \
3 | --prompt_path ./prompts/prompt_for_webqsp.json --max_tokens 300 --prompt_name chat_v1 \
4 | --kg_source_path ./data/webqsp/subgraph_2hop_triples.npy \
5 | --ent_type_path ./data/webqsp/ent_type_ary.npy \
6 | --ent2id_path ./data/webqsp/ent2id.pickle \
7 | --rel2id_path ./data/webqsp/rel2id.pickle \
8 | --ent2name_path ./data/webqsp/entity_name.pickle \
9 | --max_triples_per_relation 60 \
10 | --input_path ./data/webqsp/webqsp_simple_test.jsonl \
11 | --output_path ./outputs/webqsp/output_wo_icl_v1.jsonl \
12 | --chat_log_path ./outputs/webqsp/chat_wo_icl_v1.txt
13 |
14 | cat ./outputs/webqsp/output_wo_icl_v1.jsonl_* > ./outputs/webqsp/output_wo_icl_v1.jsonl
15 | rm ./outputs/webqsp/output_wo_icl_v1.jsonl_*
16 | cat ./outputs/webqsp/chat_wo_icl_v1.txt_* > ./outputs/webqsp/chat_wo_icl_v1.txt
17 | rm ./outputs/webqsp/chat_wo_icl_v1.txt_*
--------------------------------------------------------------------------------
/scripts/run_wikisql_wo_icl_v1.sh:
--------------------------------------------------------------------------------
1 | python3 structgpt_for_tableqa.py \
2 | --api_key ./api_key.txt --num_process 37 \
3 | --prompt_path ./prompts/prompt_for_wikisql.json --prompt_name chat_v1 \
4 | --input_path ./data/wikisql/wikisql_test.json \
5 | --output_path ./outputs/wikisql/output_wo_icl_v1.jsonl \
6 | --chat_log_path ./outputs/wikisql/chat_wo_icl_v1.txt --max_tokens 350
7 |
8 | cat ./outputs/wikisql/output_wo_icl_v1.jsonl_* > ./outputs/wikisql/output_wo_icl_v1.jsonl
9 | rm ./outputs/wikisql/output_wo_icl_v1.jsonl_*
10 | cat ./outputs/wikisql/chat_wo_icl_v1.txt_* > ./outputs/wikisql/chat_wo_icl_v1.txt
11 | rm ./outputs/wikisql/chat_wo_icl_v1.txt_*
12 |
13 | python evaluate_for_wikisql.py --ori_path ./data/wikisql/wikisql_test.json --inp_path ./outputs/wikisql/output_wo_icl_v1.jsonl
--------------------------------------------------------------------------------
/scripts/run_wtq_wo_icl_v1.sh:
--------------------------------------------------------------------------------
1 | python3 structgpt_for_tableqa.py \
2 | --api_key ./api_key.txt --num_process 37 \
3 | --prompt_path ./prompts/prompt_for_wtq.json --prompt_name chat_v1 \
4 | --input_path ./data/wtq/wikitq_test.json \
5 | --output_path ./outputs/wtq/output_wo_icl_v1.jsonl \
6 | --chat_log_path ./outputs/wtq/chat_wo_icl_v1.txt --max_tokens 350
7 |
8 | cat ./outputs/wtq/output_wo_icl_v1.jsonl_* > ./outputs/wtq/output_wo_icl_v1.jsonl
9 | rm ./outputs/wtq/output_wo_icl_v1.jsonl_*
10 | cat ./outputs/wtq/chat_wo_icl_v1.txt_* > ./outputs/wtq/chat_wo_icl_v1.txt
11 | rm ./outputs/wtq/chat_wo_icl_v1.txt_*
12 |
13 | python evaluate_for_wikisql.py --ori_path ./data/wtq/wikitq_test.json --inp_path ./outputs/wtq/output_wo_icl_v1.jsonl
14 |
--------------------------------------------------------------------------------
/structgpt_for_tableqa.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import logging
4 | import os
5 | import pickle
6 | import re
7 | import multiprocessing as mp
8 | from tqdm import tqdm
9 | import time
10 |
11 | import openai
12 |
13 |
14 | class ChatGPT:
15 | def __init__(self, args, prompt_path, prompt_name, max_tokens):
16 | self.args = args
17 | self.history_messages = []
18 | self.history_contents = []
19 | self.max_tokens = max_tokens
20 | self.prompt = self.load_prompt_template(prompt_path, prompt_name)
21 | self.idx_mapping = {"0": "first", "1": "second", "2": "third", "3": "fourth", "4": "fifth", "5": "sixth",
22 | "6": "seventh",
23 | "7": "eighth", "8": "ninth", "9": "tenth"}
24 |
25 | def get_response_v1(self, input_text, turn_type):
26 | message = self.create_message_v1(input_text, turn_type)
27 | self.history_contents.append(message['content'])
28 | self.history_messages.append(message)
29 | message = self.query_API_to_get_message(self.history_messages)
30 | self.history_contents.append(message['content'])
31 | self.history_messages.append(message)
32 | response = message['content']
33 | return response
34 |
35 | def create_message_v1(self, input_text, turn_type):
36 | if turn_type == "columns_select":
37 | template = self.prompt['columns_select']
38 | columns, question = input_text
39 | # question = question.capitalize()
40 | input_text = template.format(question=question, columns=columns)
41 | elif turn_type == 'rows_select':
42 | template = self.prompt['rows_select']
43 | selected_cols, rows, question = input_text
44 | # question = question.capitalize()
45 | input_text = template.format(selected_columns=selected_cols, rows=rows, question=question)
46 | elif turn_type == "ask_final_answer_or_next_question":
47 | question, serialized_table = input_text
48 | template = self.prompt['ask_final_answer_or_next_question']
49 | input_text = template.format(table=serialized_table, question=question)
50 | else:
51 | raise NotImplementedError
52 | message = {'role': 'user', 'content': input_text}
53 | return message
54 |
55 | def query_API_to_get_message(self, messages):
56 | while True:
57 | try:
58 | res = openai.ChatCompletion.create(
59 | model="gpt-3.5-turbo",
60 | messages=messages,
61 | temperature=0,
62 | max_tokens=self.max_tokens,
63 | top_p=1,
64 | frequency_penalty=0,
65 | presence_penalty=0,
66 | )
67 | return res['choices'][0]['message']
68 | except openai.error.RateLimitError as e:
69 | err_mes = str(e)
70 | if "You exceeded your current quota" in err_mes:
71 | print("You exceeded your current quota: %s" % openai.api_key)
72 | print('openai.error.RateLimitError\nRetrying...')
73 | time.sleep(30)
74 | except openai.error.ServiceUnavailableError:
75 | print('openai.error.ServiceUnavailableError\nRetrying...')
76 | time.sleep(20)
77 | except openai.error.Timeout:
78 | print('openai.error.Timeout\nRetrying...')
79 | time.sleep(20)
80 | except openai.error.APIError:
81 | print('openai.error.APIError\nRetrying...')
82 | time.sleep(20)
83 | except openai.error.APIConnectionError:
84 | print('openai.error.APIConnectionError\nRetrying...')
85 | time.sleep(20)
86 |
87 | def parse_result(self, result, turn_type):
88 | content = result['content'].strip()
89 | if turn_type in ["initial", "question_template"]:
90 | if "should be" in content:
91 | content = content.split("should be")[1].strip()
92 | if content.startswith('"') and content.endswith('"'):
93 | content = content[1:-1]
94 | else:
95 | matchObj = re.search(r'"(.*?)"', content)
96 | if matchObj is not None:
97 | content = matchObj.group()
98 | content = content[1:-1]
99 | else:
100 | content = content.strip().strip('"')
101 | print("Could not parse exactly; using the raw content: %s" % content)
102 |
103 | return content
104 |
105 | def reset_history(self):
106 | self.history_messages = []
107 | self.history_contents = []
108 |
109 | def reset_history_messages(self):
110 | self.history_messages = []
111 |
112 | def load_prompt_template(self, prompt_path, prompt_name):
113 | if prompt_path.endswith(".json"):
114 | with open(prompt_path, "rb") as f:
115 | prompt = json.load(f)
116 | return prompt[prompt_name]
117 |
118 |
119 | class Retriever:
120 | def __init__(self, args):
121 | self.args = args
122 |
123 | def serialize_headers(self, headers):
124 | # headers = ['"' + header.replace("\n", " ") + '"' for header in headers]
125 | # if len(headers) == 0:
126 | # ser_hea = ""
127 | # elif len(headers) == 1:
128 | # ser_hea = headers[0]
129 | # elif len(headers) == 2:
130 | # ser_hea = headers[0] + " and " + headers[1]
131 | # else:
132 | # ser_hea = ", ".join(headers[0:-1]) + ", and " + headers[-1]
133 | headers = [header.replace("\n", " ") for header in headers]
134 | ser_hea = ", ".join(headers)
135 | return ser_hea
136 |
137 | def filter_table_with_col_name(self, table, selected_relations_list, selected_relations_str):
138 | new_table = dict()
139 | header = table['header']
140 | rows = table['rows']
141 | reserved_col_idx = [idx for idx, rel in enumerate(header) if rel.replace("\n", " ") in selected_relations_list
142 | or rel.replace("\n", " ").lower() in selected_relations_str.lower()]
143 | new_header = [header[idx] for idx in reserved_col_idx]
144 | new_rows = [[row[idx] for idx in reserved_col_idx] for row in rows]
145 | new_table["header"] = new_header
146 | new_table["rows"] = new_rows
147 | return new_table
148 |
149 |
150 | class Solver:
151 | def __init__(self, args):
152 | self.args = args
153 | self.LLM = ChatGPT(args=args, prompt_path=args.prompt_path, prompt_name=args.prompt_name,
154 | max_tokens=args.max_tokens)
155 | self.SLM = Retriever(args)
156 | self.max_serialization_tokens = args.max_llm_input_tokens
157 | self.selected_relations = []
158 |
159 | def forward(self, question, table):
160 | self.LLM.reset_history()
161 | self.reset_history()
162 |
163 | iterative_step = 0
164 |
165 | # select
166 | table = self.normalize_table_header(table)
167 | header = table['header']
168 | ser_hea = self.SLM.serialize_headers(header)
169 | if args.debug:
170 | print("Step-%d: ser_hea:%s" % (iterative_step, ser_hea))
171 |
172 | llm_selected_cols = self.LLM.get_response_v1((ser_hea, question), "columns_select")
173 | self.LLM.reset_history_messages()
174 | if args.debug:
175 | print("Step-%d: llm_selected_cols:%s" % (iterative_step, llm_selected_cols))
176 |
177 | selected_cols_list = self.parse_selected_cols(llm_selected_cols, header)
178 | if args.debug:
179 | print("Step-%d: selected_cols_list:%s" % (iterative_step, selected_cols_list))
180 |
181 | filtered_table = self.SLM.filter_table_with_col_name(table, selected_cols_list, llm_selected_cols)
182 | if args.debug:
183 | print("Step-%d: filtered_table:%s" % (iterative_step, filtered_table))
184 |
185 | candidate_rows = self.serialize_table(filtered_table)
186 | if args.debug:
187 | print("Step-%d: candidate_rows:%s" % (iterative_step, candidate_rows))
188 |
189 | selected_columns = self.SLM.serialize_headers(selected_cols_list)
190 | choose_rows = self.LLM.get_response_v1((selected_columns, candidate_rows, question), "rows_select")
191 | self.LLM.reset_history_messages()
192 | try:
193 | choose_row_list = self.parse_row_idx(choose_rows)
194 | filtered_table = self.filter_table_with_rows_constraints(filtered_table, choose_row_list)
195 | except Exception as e:
196 | logging.exception(e)
197 | # print(candidate_rows)
198 | # print(choose_rows)
199 | if args.debug:
200 | print("Step-%d: filtered_table:%s" % (iterative_step, filtered_table))
201 |
202 | serialized_table = self.serialize_table(filtered_table)
203 | if args.debug:
204 | print("Step-%d: serialized_table:%s" % (iterative_step, serialized_table))
205 |
206 | final_answers = self.LLM.get_response_v1((question, serialized_table),
207 | "ask_final_answer_or_next_question")
208 | self.LLM.reset_history_messages()
209 |
210 | return final_answers, self.LLM.history_contents, self.log
211 |
212 | def is_end(self, response, iterative_step):
213 | if "no" in response.lower() or iterative_step > 8:
214 | return True
215 | else:
216 | return False
217 |
218 | def is_end_v1(self, response, iterative_step):
219 | if "final" in response.lower() or iterative_step > 3:
220 | return True
221 | elif "next" in response.lower():
222 | return False
223 | else:
224 | return False
225 |
226 | def get_final_answers(self, history_responses, final_response):
227 | answer_tmp_1 = history_responses[-2]
228 | answer_tmp_2 = final_response
229 | return [answer_tmp_1, answer_tmp_2]
230 |
231 | def parse_result(self, response, parse_type):
232 | response = response.lower()
233 | if parse_type == "next_question":
234 | if "the next question:" in response:
235 | next_question = response.split("the next question:")[1].strip()
236 | elif ":" in response:
237 | next_question = response.split(":")[1].strip()
238 | else:
239 | next_question = response
240 | print("Could not parse the next question exactly; using the raw response: ", response)
241 | return next_question
242 | elif parse_type == "final_answer":
243 | if 'yes' in response and 'no' in response:
244 | final_answer = 'unknown'
245 | elif 'yes' in response:
246 | final_answer = 'entailed'
247 | else:
248 | final_answer = 'refuted'
249 |
250 | return final_answer
251 |
252 | def reset_history(self):
253 | self.log = []
254 | self.selected_relations = []
255 |
256 | def serialize_table(self, table):
257 | header = table['header']
258 | rows = table['rows']
259 | lines = []
260 | for idx, row in enumerate(rows):
261 | pairs = []
262 | for rel, ent in zip(header, row):
263 | pair = "(" + rel + ", " + ent + ")"
264 | pairs.append(pair)
265 |
266 | line = 'item ' + str(idx + 1) + ': ' + "; ".join(pairs)
267 | lines.append(line)
268 | output = "\n".join(lines)
269 | return output
270 |
271 | def parse_selected_cols(self, llm_selected_cols, header):
272 | llm_selected_cols = [h for h in header if h.replace("\n", " ").lower() in llm_selected_cols.lower()]
273 | return llm_selected_cols
274 |
275 | def parse_row_idx(self, selected_rows):
276 | pattern = re.compile(r'(\d+)')
277 | m = pattern.finditer(selected_rows)
278 | m = [i.group() for i in m]
279 | selected_rows = [int(rid)-1 for rid in m]
280 | return selected_rows
281 |
282 | def filter_table_with_rows_constraints(self, table, row_constraints):
283 | new_table = dict()
284 | header = table['header']
285 | rows = table['rows']
286 | new_rows = []
287 | for rid in row_constraints:
288 | if rid < len(rows):
289 | new_rows.append(rows[rid])
290 | new_table["header"] = header
291 | new_table["rows"] = new_rows
292 | return new_table
293 |
294 | def normalize_table_header(self, table):
295 | header = table['header']
296 | rows = table['rows']
297 | new_table = {}
298 | new_header = []
299 | for h in header:
300 | h = h.replace("\n", " ")
301 | new_header.append(h)
302 | new_table['header'] = new_header
303 | new_table['rows'] = rows
304 | return new_table
305 |
306 | def main(args, all_data, idx, api_key):
307 | import openai
308 | openai.api_key = api_key
309 |
310 | if idx == -1:
311 | output_path = args.output_path
312 | chat_log_path = args.chat_log_path
313 | else:
314 | idx = "0" + str(idx) if idx < 10 else str(idx) # 00 01 02 ... 29
315 | output_path = args.output_path + "_" + idx
316 | chat_log_path = args.chat_log_path + "_" + idx
317 |
318 | print("Start PID %d and save to %s" % (os.getpid(), output_path))
319 |
320 | solver = Solver(args)
321 |
322 | count = 0
323 | with open(output_path, "w") as f:
324 | with open(chat_log_path, "w") as fclog:
325 | for sample in tqdm(all_data, total=len(all_data), desc="PID: %d" % os.getpid()):
326 | try:
327 | question = sample["statement"] if 'statement' in sample else sample['question']
328 | question = question + "?" if not question.endswith("?") else question
329 | table = sample['table']
330 | prediction, chat_history, record = solver.forward(question, table)
331 | except openai.error.InvalidRequestError as e:
332 | print(e)
333 | continue
334 | except Exception as e:
335 | logging.exception(e)
336 | continue
337 | if 'id' in sample.keys():
338 | flag = str(sample['id'])
339 | else:
340 | flag = question
341 |
342 | try:
343 | chat = flag + "\n" + "\n******\n".join(chat_history) + "\nAnswers: " + str(
344 | sample['seq_out']) + "\n------------------------------------------\n"
345 | fclog.write(chat)
346 | except Exception as e:
347 | print(e)
348 | count += 1
349 | if count < 5:
350 | print(sample['seq_out'])
351 | print(prediction)
352 | print("---------------------")
353 | sample["Prediction"] = prediction
354 | f.write(json.dumps(sample) + "\n")
355 |
356 |
357 | def parse_args():
358 | parser = argparse.ArgumentParser()
359 | parser.add_argument('--input_path', default=None)
360 | parser.add_argument('--output_path', default=None)
361 | parser.add_argument('--chat_log_path', default=None)
362 | parser.add_argument('--debug', action="store_true")
363 | parser.add_argument('--prompt_path')
364 | parser.add_argument('--prompt_name', default="chat", )
365 | parser.add_argument('--overwrite', action="store_true")
366 | parser.add_argument('--num_process', default=1, type=int, help='number of worker processes')
367 | parser.add_argument('--max_tokens', default=10, type=int, help='maximum number of tokens the LLM may generate per call')
368 | parser.add_argument('--api_key', default="", type=str, help='an OpenAI API key (sk-...) or a path to a file with one key per process')
369 | parser.add_argument('--max_llm_input_tokens', default=3400, type=int)
370 |
371 | args = parser.parse_args()
372 |
373 | print("Start querying the LLM.")
374 | return args
375 |
376 |
377 | if __name__ == '__main__':
378 | args = parse_args()
379 | if not args.api_key.startswith("sk-"):
380 | with open(args.api_key, "r") as f:
381 | all_keys = f.readlines()
382 | all_keys = [line.strip('\n') for line in all_keys]
383 | assert len(all_keys) == args.num_process, (len(all_keys), args.num_process)
384 |
385 | with open(args.input_path, "rb") as f:
386 | all_data = json.load(f)
387 | print("Loaded %d test examples in total." % len(all_data))
388 |
389 | # Test data that has not yet been predicted
390 | # if os.path.exists(args.output_path):
391 | # with open(args.output_path, "r") as f:
392 | # all_lines = f.readlines()
393 | # all_lines = [json.loads(line.strip("\n")) for line in all_lines]
394 | # already_id = [line['id'] for line in all_lines]
395 | # all_data = [data for data in all_data if data['id'] not in already_id]
396 | # print("There are %d test examples need to be processed." % len(all_data))
397 |
398 | if args.num_process == 1:
399 | main(args, all_data, idx=-1, api_key=args.api_key)
400 | else:
401 | num_each_split = int(len(all_data) / args.num_process)
402 | p = mp.Pool(args.num_process)
403 | for idx in range(args.num_process):
404 | start = idx * num_each_split
405 | if idx == args.num_process - 1:
406 | end = max((idx + 1) * num_each_split, len(all_data))
407 | else:
408 | end = (idx + 1) * num_each_split
409 | split_data = all_data[start:end]
410 | p.apply_async(main, args=(args, split_data, idx, all_keys[idx]))
411 | p.close()
412 | p.join()
413 | print("All child processes have finished.")
414 |
--------------------------------------------------------------------------------
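
For reference, `Solver.serialize_table` above linearizes the filtered table into one line per row in the form `item N: (column, value); ...` before it is placed into the prompt. The following is a self-contained sketch of that format; the sample table is made up purely for illustration.

```python
# Hedged sketch of the row linearization performed by Solver.serialize_table above.
def serialize_table(table):
    lines = []
    for idx, row in enumerate(table["rows"]):
        # pair each cell with its column name: "(column, value)"
        pairs = ["(" + rel + ", " + ent + ")" for rel, ent in zip(table["header"], row)]
        lines.append("item " + str(idx + 1) + ": " + "; ".join(pairs))
    return "\n".join(lines)

example = {"header": ["Year", "Champion"], "rows": [["2004", "Brazil"], ["2005", "Italy"]]}
print(serialize_table(example))
# item 1: (Year, 2004); (Champion, Brazil)
# item 2: (Year, 2005); (Champion, Italy)
```
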
/structgpt_for_text_to_sql.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import logging
4 | import os
5 | import random
6 | from collections import defaultdict
7 | import pandas as pd
8 | import multiprocessing as mp
9 | from tqdm import tqdm
10 | import time
11 |
12 | import openai
13 |
14 | random.seed(42)
15 |
16 |
17 | class ChatGPT:
18 | def __init__(self, args, prompt_path, prompt_name, max_tokens):
19 | self.args = args
20 | self.history_messages = []
21 | self.history_contents = []
22 | self.max_tokens = max_tokens
23 | self.prompt = self.load_prompt_template(prompt_path, prompt_name)
24 | self.idx_mapping = {"0": "first", "1": "second", "2": "third", "3": "fourth", "4": "fifth", "5": "sixth",
25 | "6": "seventh",
26 | "7": "eighth", "8": "ninth", "9": "tenth"}
27 |
28 | def get_response_v1(self, input_text, turn_type, max_output_len):
29 | message = self.create_message_v1(input_text, turn_type)
30 | self.history_messages.append(message)
31 | self.history_contents.append(message['content'])
32 | message = self.query_API_to_get_message(self.history_messages, max_output_len)
33 | self.history_messages.append(message)
34 | self.history_contents.append(message['content'])
35 | response = self.parse_result(message)
36 |
37 | return response
38 |
39 | def create_message_v1(self, input_text, turn_type):
40 | if turn_type == "select_tab":
41 | template = self.prompt['free_generate']
42 | question, ser_table_name = input_text
43 | input_text = template.format(question=question, table=ser_table_name)
44 | elif turn_type == "reorg_sel_tab":
45 | template = self.prompt['table_column_select_reorganize']
46 | input_text = template
47 | elif turn_type == "ask_final_answers":
48 | question, ser_table_name, ser_fks = input_text
49 | if len(ser_fks) > 1:
50 | template = self.prompt['ask_final_answers']['has_fk']
51 | input_text = template.format(question=question, table=ser_table_name, fk=ser_fks)
52 | else:
53 | template = self.prompt['ask_final_answers']['no_fk']
54 | input_text = template.format(question=question, table=ser_table_name)
55 | else:
56 | raise NotImplementedError
57 | message = {'role': 'user', 'content': input_text}
58 | return message
59 |
60 | def query_API_to_get_message(self, messages, max_output_len):
61 | while True:
62 | try:
63 | res = openai.ChatCompletion.create(
64 | model="gpt-3.5-turbo",
65 | messages=messages,
66 | temperature=0,
67 | max_tokens=max_output_len,
68 | top_p=1,
69 | frequency_penalty=0,
70 | presence_penalty=0,
71 | )
72 | return res['choices'][0]['message']
73 | except openai.error.RateLimitError as e:
74 | err_mes = str(e)
75 | if "You exceeded your current quota" in err_mes:
76 | print("You exceeded your current quota: %s" % openai.api_key)
77 | print('openai.error.RateLimitError\nRetrying...')
78 | time.sleep(30)
79 | except openai.error.ServiceUnavailableError:
80 | print('openai.error.ServiceUnavailableError\nRetrying...')
81 | time.sleep(20)
82 | except openai.error.Timeout:
83 | print('openai.error.Timeout\nRetrying...')
84 | time.sleep(20)
85 | except openai.error.APIError:
86 | print('openai.error.APIError\nRetrying...')
87 | time.sleep(20)
88 | except openai.error.APIConnectionError:
89 | print('openai.error.APIConnectionError\nRetrying...')
90 | time.sleep(20)
91 | except openai.error.InvalidRequestError as e:
92 | logging.exception(e)
93 | exit(0)
94 |
95 |
96 | def parse_result(self, result):
97 | content = result['content'].strip()
98 |
99 | return content
100 |
101 | def reset_history(self):
102 | self.history_messages = []
103 | self.history_contents = []
104 |
105 | def reset_history_messages(self):
106 | self.history_messages = []
107 |
108 | def reset_history_contents(self):
109 | self.history_contents = []
110 |
111 | def load_prompt_template(self, prompt_path, prompt_name):
112 | if prompt_path.endswith(".json"):
113 | with open(prompt_path, "rb") as f:
114 | prompt = json.load(f)
115 | return prompt[prompt_name]
116 |
117 |
118 | class Retriever:
119 | def __init__(self, args):
120 | self.args = args
121 | self.prompt = self.load_prompt_template(args.prompt_path, args.prompt_name)
122 | self.creating_schema(args.schema_path)
123 |
124 | def filter_table_with_col_name(self, table, selected_relations_list, selected_relations_str):
125 | new_table = dict()
126 | header = table['header']
127 | rows = table['rows']
128 | reserved_col_idx = [idx for idx, h in enumerate(header) if h in selected_relations_list]
129 | new_header = [header[idx] for idx in reserved_col_idx]
130 | new_rows = [[row[idx] for idx in reserved_col_idx] for row in rows]
131 | new_table["header"] = new_header
132 | new_table["rows"] = new_rows
133 | return new_table
134 |
135 | def load_db(self, db_path):
136 | with open(db_path, "r") as f:
137 | db = json.load(f)
138 | return db
139 |
140 | def serialize_table_and_column(self, db_name):
141 | df = self.spider_schema[self.spider_schema['Database name'] == db_name]
142 | df = df.groupby('Table Name')
143 | table2columns = {}
144 | tables_name = []
145 | for name, group in df:
146 | columns = []
147 | for index, row in group.iterrows():
148 | columns.append(row["Field Name"])
149 | table2columns[name] = columns
150 | if name not in tables_name:
151 | tables_name.append(name)
152 | ser_heas = []
153 | prompt_tmp = "# {tn}({cols});"
154 | for name, cols in table2columns.items():
155 | cols = [col for col in cols]
156 | cols = ",".join(cols)
157 | hea = prompt_tmp.format(tn=name, cols=cols)
158 | ser_heas.append(hea)
159 | ser_hea = "\n".join(ser_heas)
160 | return ser_hea, table2columns
161 |
162 | def serialize_tab_and_col_of_demons(self, tables):
163 | prompt_tmp = " # {tn}({cols});"
164 | ser_heas = []
165 | for name, cols in tables.items():
166 | cols = [col for col in cols]
167 | cols = ",".join(cols)
168 | hea = prompt_tmp.format(tn=name, cols=cols)
169 | ser_heas.append(hea)
170 | ser_heas = "\n".join(ser_heas)
171 | return ser_heas
172 |
173 | def serialize_selected_table_name(self, sel_tab_cols):
174 | ser_heas = []
175 | prompt_tmp = "# {tn}({cols});"
176 | for name, cols in sel_tab_cols.items():
177 | cols = [col for col in cols]
178 | cols = ",".join(cols)
179 | hea = prompt_tmp.format(tn=name, cols=cols)
180 | ser_heas.append(hea)
181 |
182 | ser_hea = "\n".join(ser_heas)
183 | return ser_hea
184 |
185 | def creating_schema(self, schema_path):
186 | schema_df = pd.read_json(schema_path)
187 | schema_df = schema_df.drop(['column_names', 'table_names'], axis=1)
188 | schema = []
189 | f_keys = []
190 | p_keys = []
191 | for index, row in schema_df.iterrows():
192 | tables = row['table_names_original']
193 | col_names = row['column_names_original']
194 | col_types = row['column_types']
195 | foreign_keys = row['foreign_keys']
196 | primary_keys = row['primary_keys']
197 | for col, col_type in zip(col_names, col_types):
198 | index, col_name = col
199 | if index == -1:
200 | for table in tables:
201 | schema.append([row['db_id'], table, '*', 'text'])
202 | else:
203 | schema.append([row['db_id'], tables[index], col_name, col_type])
204 | for primary_key in primary_keys:
205 | index, column = col_names[primary_key]
206 | p_keys.append([row['db_id'], tables[index], column])
207 | for foreign_key in foreign_keys:
208 | first, second = foreign_key
209 | first_index, first_column = col_names[first]
210 | second_index, second_column = col_names[second]
211 | f_keys.append([row['db_id'], tables[first_index], tables[second_index], first_column, second_column])
212 | self.spider_schema = pd.DataFrame(schema, columns=['Database name', 'Table Name', 'Field Name', 'Type'])
213 | self.spider_primary = pd.DataFrame(p_keys, columns=['Database name', 'Table Name', 'Primary Key'])
214 | self.spider_foreign = pd.DataFrame(f_keys,
215 | columns=['Database name', 'First Table Name', 'Second Table Name',
216 | 'First Table Foreign Key',
217 | 'Second Table Foreign Key'])
218 |
219 | def find_primary_keys(self, db_name, table_name):
220 | df = self.spider_primary[self.spider_primary['Database name'] == db_name]
221 | for index, row in df.iterrows():
222 | if row['Table Name'] == table_name:
223 | return row['Primary Key']
224 |
225 | def find_foreign_keys(self, db_name):
226 | df = self.spider_foreign[self.spider_foreign['Database name'] == db_name]
227 | output = []
228 | for index, row in df.iterrows():
229 | first_tab_fk = (row['First Table Name'], row['First Table Foreign Key'])
230 | second_tab_fk = (row['Second Table Name'], row['Second Table Foreign Key'])
231 | output.append((first_tab_fk, second_tab_fk))
232 | return output
233 |
234 | def serialize_fk_with_sel_table(self, db_id, table_cols):
235 | prompt_tmp = "{ftk_0}.{ftk_1}={stk_0}.{stk_1}"
236 | fk = self.find_foreign_keys(db_id)
237 | ser_fks = []
238 | for (ftk, stk) in fk:
239 | if ftk[0] in table_cols and stk[0] in table_cols:
240 | ser_fk = prompt_tmp.format(ftk_0=ftk[0], ftk_1=ftk[1], stk_0=stk[0], stk_1=stk[1])
241 | ser_fks.append(ser_fk)
242 | ser_fks = ",".join(ser_fks)
243 | ser_fks = ser_fks + ";"
244 | return ser_fks
245 |
246 | def serialize_demonstrations(self, demonstrations, example_ids):
247 | prompt = self.prompt["demonstration"]
248 | prompt_fk = "{ftk_0}.{ftk_1}={stk_0}.{stk_1}"
249 | all_sel_demons = []
250 | for name, ids in example_ids.items():
251 | demons = demonstrations[name]
252 | sel_demons = [demons[i] for i in ids]
253 | all_sel_demons.extend(sel_demons)
254 | ser_demons = []
255 | for d in all_sel_demons:
256 | ser_fks = []
257 | for (ftk, stk) in d['fks']:
258 | ser_fk = prompt_fk.format(ftk_0=ftk[0], ftk_1=ftk[1], stk_0=stk[0], stk_1=stk[1])
259 | ser_fks.append(ser_fk)
260 | ser_fks = ",".join(ser_fks)
261 | ser_fks = ser_fks + ";"
262 |
263 | tables = d['tables']
264 | question = d['question']
265 | query = d['query']
266 | ser_tables = self.serialize_tab_and_col_of_demons(tables)
267 |
268 | if len(d['fks']) == 0:
269 | p = prompt['no_fk']
270 | one_demo = p.format(table=ser_tables, question=question, sql=query)
271 | else:
272 | p = prompt['has_fk']
273 | one_demo = p.format(table=ser_tables, question=question, sql=query, fk=ser_fks)
274 | ser_demons.append(one_demo)
275 | random.shuffle(ser_demons)
276 | ser_demons = "\n #\n".join(ser_demons)
277 | return ser_demons
278 |
279 | def serialize_demonstrations_with_const(self, demonstrations, example_ids, num_tables, fk_flag):
280 | num_tables = 1 if num_tables == 1 else 2
281 | prompt = self.prompt["demonstration"]
282 | prompt_fk = "{ftk_0}.{ftk_1}={stk_0}.{stk_1}"
283 | all_sel_demons = []
284 | for name, ids in example_ids.items():
285 | demons = demonstrations[name]
286 | sel_demons = [demons[i] for i in ids]
287 | all_sel_demons.extend(sel_demons)
288 | ser_demons = []
289 | for d in all_sel_demons:
290 | ser_fks = []
291 | for (ftk, stk) in d['fks']:
292 | ser_fk = prompt_fk.format(ftk_0=ftk[0], ftk_1=ftk[1], stk_0=stk[0], stk_1=stk[1])
293 | ser_fks.append(ser_fk)
294 | ser_fks = ",".join(ser_fks)
295 | ser_fks = ser_fks + ";"
296 |
297 | tables = d['tables']
298 | num_tables_demon = 1 if len(tables) == 1 else 2
299 |
300 | if num_tables_demon != num_tables:
301 | continue
302 |
303 | question = d['question']
304 | query = d['query']
305 | ser_tables = self.serialize_tab_and_col_of_demons(tables)
306 |
307 | if len(d['fks']) == 0:
308 | p = prompt['no_fk']
309 | one_demo = p.format(table=ser_tables, question=question, sql=query)
310 | else:
311 | p = prompt['has_fk']
312 | one_demo = p.format(table=ser_tables, question=question, sql=query, fk=ser_fks)
313 | ser_demons.append(one_demo)
314 | if len(ser_demons) == 0:
315 | print("-------------------No demonstrations-------------------")
316 | random.shuffle(ser_demons)
317 | ser_demons = "\n\n".join(ser_demons)
318 | return ser_demons
319 |
320 | def load_prompt_template(self, prompt_path, prompt_name):
321 | if prompt_path.endswith(".json"):
322 | with open(prompt_path, "rb") as f:
323 | prompt = json.load(f)
324 | return prompt[prompt_name]
325 |
326 |
327 | class Solver:
328 | def __init__(self, args):
329 | self.args = args
330 | self.LLM = ChatGPT(args=args, prompt_path=args.prompt_path, prompt_name=args.prompt_name,
331 | max_tokens=args.max_tokens)
332 | self.SLM = Retriever(args)
333 | self.max_serialization_tokens = args.max_llm_input_tokens
334 | self.log = []
335 | self.selected_relations = []
336 |
337 | def forward_wo_icl_v1(self, question, db_id):
338 | self.LLM.reset_history()
339 | self.reset_history()
340 |
341 | iterative_step = 0
342 |
343 | ser_tab_col, table2columns = self.SLM.serialize_table_and_column(db_id)
344 | if args.debug:
345 | print("-----Step-%d: ser_table_name:\n%s" % (iterative_step, ser_tab_col))
346 |
347 | llm_select_tab = self.LLM.get_response_v1((question, ser_tab_col), "select_tab", max_output_len=600)
348 | if args.debug:
349 | print("-----Step-%d: llm_select_tab:\n%s" % (iterative_step, llm_select_tab))
350 |
351 | llm_reorg_sel_tab = self.LLM.get_response_v1("", "reorg_sel_tab", max_output_len=300)
352 | if args.debug:
353 | print("-----Step-%d: llm_reorg_sel_tab:\n%s" % (iterative_step, llm_reorg_sel_tab))
354 | self.LLM.reset_history_messages()
355 |
356 | sel_tab_cols = self.parse_sele_tab(llm_reorg_sel_tab, table2columns)
357 | if args.debug:
358 | print("-----Step-%d: sel_tab_cols:\n%s" % (iterative_step, sel_tab_cols))
359 |
360 | ser_table_name = self.SLM.serialize_selected_table_name(sel_tab_cols)
361 | ser_fk = self.SLM.serialize_fk_with_sel_table(db_id, sel_tab_cols)
362 | if args.debug:
363 | print("-----Step-%d: ser_table_name:\n%s" % (iterative_step, ser_table_name))
364 | print("-----Step-%d: ser_fk:\n%s" % (iterative_step, ser_fk))
365 |
366 | final_answers = self.LLM.get_response_v1((question, ser_table_name, ser_fk), "ask_final_answers",
367 | max_output_len=300)
368 | if args.debug:
369 | print("-----Step-%d: final_answers:\n%s" % (iterative_step, final_answers))
370 | self.LLM.reset_history_messages()
371 |
372 | return final_answers, self.LLM.history_contents
373 |
374 | def reset_history(self):
375 | self.log = []
376 | self.selected_relations = []
377 |
378 | def serialize_table(self, db_name, table_name):
379 | # get primary key
380 | pk = self.SLM.find_primary_keys(db_name, table_name)
381 | # get table content
382 | table = self.SLM.db_content[db_name][table_name]
383 | # serialize
384 | header = table['headers']
385 | rows = table['rows']
386 | lines = []
387 | for idx, row in enumerate(rows):
388 | pairs = []
389 | row_name = ""
390 | for rel, ent in zip(header, row):
391 | if rel == pk:
392 | row_name = pk + " " + str(ent) + ": "
393 | else:
394 | pair = "(" + rel + ", " + str(ent) + ")"
395 | pairs.append(pair)
396 | line = row_name + "; ".join(pairs)
397 | lines.append(line)
398 | output = "\n".join(lines)
399 | return output
400 |
401 | def serialize_table_with_constraints(self, db_name, table_name, constraints_cols):
402 | # get primary key
403 | pk = self.SLM.find_primary_keys(db_name, table_name)
404 | # get table content
405 | table = self.SLM.db_content[db_name][table_name]
406 | # serialize
407 | header = table['headers']
408 | rows = table['rows']
409 | lines = []
410 | if len(constraints_cols) == 1 and pk in constraints_cols:
411 | constraints_cols.append(random.sample(list(set(header) - set(constraints_cols)), k=1)[0])
412 | for idx, row in enumerate(rows):
413 | pairs = []
414 | row_name = ""
415 | for rel, ent in zip(header, row):
416 | if rel == pk:
417 | row_name = pk + " " + str(ent) + ": "
418 | else:
419 | if rel in constraints_cols:
420 | pair = "(" + rel + ", " + str(ent) + ")"
421 | pairs.append(pair)
422 | line = row_name + "; ".join(pairs)
423 | lines.append(line)
424 | output = "\n".join(lines)
425 | return output
426 |
427 | def parse_sele_tab(self, llm_selected_table_col, table2columns):
428 | sel_table_to_cols = defaultdict(list)
429 | try:
430 | tabs = llm_selected_table_col.strip(" \n|.;`").strip()
431 | tabs = tabs.split("|")
432 | tabs = [tab.strip(" \n|.;`") for tab in tabs]
433 | for tab in tabs:
434 | if tab in table2columns:
435 | sel_table_to_cols[tab] = table2columns[tab]
436 | except Exception as e:
437 | logging.exception(e)
438 | for tab, cols in table2columns.items():
439 | if tab in llm_selected_table_col:
440 | sel_table_to_cols[tab] = table2columns[tab]
441 | print("*****The LLM's selected-table output doesn't match the expected format:\n%s" % llm_selected_table_col)
442 | return sel_table_to_cols
443 |
444 |
445 | def main(args, all_data, idx, api_key):
446 | import openai
447 | openai.api_key = api_key
448 |
449 | if idx == -1:
450 | output_path = args.output_path
451 | chat_log_path = args.chat_log_path
452 | else:
453 | idx = "0" + str(idx) if idx < 10 else str(idx) # 00 01 02 ... 29
454 | output_path = args.output_path + "_" + idx
455 | chat_log_path = args.chat_log_path + "_" + idx
456 |
457 | print("Start PID %d and save to %s" % (os.getpid(), output_path))
458 | solver = Solver(args)
459 |
460 | count = 0
461 | valid_count = 0
462 | with open(output_path, "w") as f:
463 | with open(chat_log_path, "w") as fclog:
464 | for sample in tqdm(all_data, total=len(all_data), desc="PID: %d" % os.getpid()):
465 | try:
466 | if "question" in sample:
467 | question = sample["question"]
468 | elif "SpiderSynQuestion" in sample:
469 | question = sample["SpiderSynQuestion"]
470 | else:
471 | print("No known question key found in this sample.")
472 | print(sample)
473 | exit(0)
474 | db_id = sample['db_id']
475 |
476 | if not args.icl:
477 | results = solver.forward_wo_icl_v1(question, db_id)
478 |
479 | if results is not None:
480 | prediction, chat_history = results
481 | valid_count += 1
482 | else:
483 | continue
484 | except Exception as e:
485 | logging.exception(e)
486 | continue
487 |
488 | if 'id' in sample:
489 | flag = sample['id']
490 | else:
491 | flag = question
492 | chat = flag + "\n" + "\n******\n".join(chat_history) + \
493 | "\nGold SQL: " + str(sample['query']) + "\n------------------------------------------\n"
494 | fclog.write(chat)
495 |
496 | count += 1
497 | if count < 3:
498 | print(prediction)
499 | print("---------------------")
500 | sample["Prediction"] = prediction
501 | f.write(json.dumps(sample) + "\n")
502 | print("---------------PID %d end with %d/%d samples--------------" % (os.getpid(), valid_count, count))
503 |
504 |
505 | def parse_args():
506 | parser = argparse.ArgumentParser()
507 | parser.add_argument('--input_path', default=None)
508 | parser.add_argument('--output_path', default=None)
509 | parser.add_argument('--chat_log_path', default=None)
510 | parser.add_argument('--schema_path', default=None)
511 | parser.add_argument('--debug', action="store_true")
512 | parser.add_argument('--prompt_path')
513 | parser.add_argument('--prompt_name', default="chat", )
514 | parser.add_argument('--icl', action="store_true", )
515 | parser.add_argument('--demonstrations', default=None)
516 | parser.add_argument('--example_ids', default=None)
517 | parser.add_argument('--overwrite', action="store_true")
518 | parser.add_argument('--num_process', default=1, type=int, help='number of worker processes')
519 | parser.add_argument('--max_tokens', default=10, type=int, help='maximum number of tokens the LLM may generate per call')
520 | parser.add_argument('--api_key', default="", type=str)
521 | parser.add_argument('--max_llm_input_tokens', default=3400, type=int)
522 |
523 | args = parser.parse_args()
524 |
525 | print("Start querying the LLM.")
526 | return args
527 |
528 |
529 | if __name__ == '__main__':
530 | args = parse_args()
531 | if not args.api_key.startswith("sk-"):
532 | with open(args.api_key, "r") as f:
533 | all_keys = f.readlines()
534 | all_keys = [line.strip('\n') for line in all_keys]
535 | assert len(all_keys) == args.num_process, (len(all_keys), args.num_process)
536 |
537 | if args.input_path.endswith('jsonl'):
538 | with open(args.input_path, "r") as f:
539 | all_lines = f.readlines()
540 | all_data = [json.loads(line) for line in all_lines]
541 | print("Loaded %d test examples in total." % len(all_data))
542 | elif args.input_path.endswith('json'):
543 | with open(args.input_path, "r") as f:
544 | all_data = json.load(f)
545 | print("Loaded %d test examples in total." % len(all_data))
546 |
547 | if args.demonstrations is not None:
548 | with open(args.demonstrations, "r") as f:
549 | demonstrations = json.load(f)
550 | if args.example_ids is not None:
551 | with open(args.example_ids, "r") as f:
552 | example_ids = json.load(f)
553 |
554 | if args.num_process == 1:
555 | main(args, all_data, idx=-1, api_key=args.api_key)
556 | else:
557 | num_each_split = int(len(all_data) / args.num_process)
558 | p = mp.Pool(args.num_process)
559 | for idx in range(args.num_process):
560 | start = idx * num_each_split
561 | if idx == args.num_process - 1:
562 | end = max((idx + 1) * num_each_split, len(all_data))
563 | else:
564 | end = (idx + 1) * num_each_split
565 | split_data = all_data[start:end]
566 | try:
567 | p.apply_async(main, args=(args, split_data, idx, all_keys[idx]))
568 | except Exception as e:
569 | logging.exception(e)
570 | p.close()
571 | p.join()
572 | print("All child processes have finished.")
573 |
--------------------------------------------------------------------------------
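
For reference, the text-to-SQL driver above serializes the target database's schema as one `# table(col1,col2,...);` line per table (`serialize_table_and_column` / `serialize_selected_table_name`) and the relevant foreign keys as `t1.col=t2.col` pairs (`serialize_fk_with_sel_table`). The sketch below shows the resulting prompt fragment using a toy schema rather than a real Spider database.

```python
# Hedged sketch of the schema/foreign-key serialization format used above (toy schema, not from Spider).
table2columns = {"singer": ["singer_id", "name"], "concert": ["concert_id", "singer_id", "year"]}
foreign_keys = [(("concert", "singer_id"), ("singer", "singer_id"))]

ser_tables = "\n".join("# {tn}({cols});".format(tn=t, cols=",".join(cols))
                       for t, cols in table2columns.items())
ser_fks = ",".join("{}.{}={}.{}".format(a[0], a[1], b[0], b[1]) for a, b in foreign_keys) + ";"

print(ser_tables)
# # singer(singer_id,name);
# # concert(concert_id,singer_id,year);
print(ser_fks)
# concert.singer_id=singer.singer_id;
```
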
/structgpt_for_webqsp.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import logging
4 | import os
5 | import pickle
6 | import re
7 | import time
8 | import openai
9 | from KnowledgeBase.KG_api import KnowledgeGraph
10 | # from KnowledgeBase.sparql_executor import *
11 | import multiprocessing as mp
12 |
13 |
14 | class ChatGPT:
15 | def __init__(self, args, prompt_path, prompt_name, max_tokens):
16 | self.args = args
17 | self.history_messages = []
18 | self.history_contents = []
19 | self.max_tokens = max_tokens
20 | self.prompt = self.load_prompt_template(prompt_path, prompt_name)
21 | self.idx_mapping = {"0": "first", "1": "second", "2": "third", "3": "fourth", "4": "fifth", "5": "sixth",
22 | "6": "seventh",
23 | "7": "eighth", "8": "ninth", "9": "tenth"}
24 |
25 | def get_response(self, input_text, turn_type, tpe_name=None):
26 | if self.args.debug:
27 | message = self.create_message(input_text, turn_type, tpe_name)
28 | self.history_messages.append(message)
29 | self.history_contents.append(message['content'])
30 | print("query API to get message:\n%s" % message['content'])
31 | # message = self.query_API_to_get_message(self.history)
32 | # self.history.append(message)
33 | # response = self.parse_result(message)
34 | response = input("input the returned response:")
35 | else:
36 | message = self.create_message(input_text, turn_type, tpe_name)
37 | self.history_messages.append(message)
38 | self.history_contents.append(message['content'])
39 | message = self.query_API_to_get_message(self.history_messages)
40 | self.history_messages.append(message)
41 | self.history_contents.append(message['content'])
42 | response = self.parse_result(message, turn_type)
43 | return response
44 |
45 | def get_response_v1(self, input_text, turn_type, tpe_name=None):
46 | if self.args.debug:
47 | message = self.create_message_v1(input_text, turn_type)
48 | self.history_messages.append(message)
49 | self.history_contents.append(message['content'])
50 | print("query API to get message:\n%s" % message['content'])
51 | # message = self.query_API_to_get_message(self.history)
52 | # self.history.append(message)
53 | # response = self.parse_result(message)
54 | response = input("input the returned response:")
55 | else:
56 | message = self.create_message_v1(input_text, turn_type)
57 | self.history_messages.append(message)
58 | self.history_contents.append(message['content'])
59 | message = self.query_API_to_get_message(self.history_messages)
60 | self.history_messages.append(message)
61 | self.history_contents.append(message['content'])
62 | response = self.parse_result_v1(message, turn_type)
63 | return response
64 |
65 | def create_message(self, input_text, turn_type, tpe_name):
66 | if turn_type == "initial": # the initial query
67 | instruction = self.prompt[turn_type]['instruction']
68 | template = self.prompt[turn_type]['init_template']
69 | self.question = input_text
70 | input_text = instruction + template.format(question=input_text, tpe=tpe_name)
71 | elif turn_type == "continue_template":
72 | input_text = self.prompt[turn_type]
73 | elif turn_type == "question_template":
74 | template = self.prompt[turn_type]
75 | input_text = template.format(idx=self.idx_mapping[input_text])
76 | elif turn_type == "answer_template":
77 | template = self.prompt[turn_type]
78 | if len(input_text) > 0:
79 | input_text = template["valid"].format(facts=input_text)
80 | else:
81 | input_text = template["invalid"]
82 | elif turn_type == "final_query_template":
83 | template = self.prompt[turn_type]
84 | input_text = template.format(question=self.question)
85 | else:
86 | raise NotImplementedError
87 | message = {'role': 'user', 'content': input_text}
88 | return message
89 |
90 | def create_message_v1(self, input_text, turn_type):
91 | if turn_type == "instruction": # the initial query
92 | instruction = self.prompt['instruction']
93 | input_text = instruction
94 | elif turn_type == "init_relation_rerank":
95 | template = self.prompt['init_relation_rerank']
96 | question, tpe, can_rels = input_text
97 | input_text = template.format(question=question, tpe=tpe, relations=can_rels)
98 | elif turn_type == "ask_question":
99 | template = self.prompt['ask_question']
100 | idx, relations = input_text
101 | idx = self.idx_mapping[idx]
102 | input_text = template.format(idx=idx, relations=relations)
103 | elif turn_type == "ask_answer":
104 | facts = input_text
105 | template = self.prompt['ask_answer']
106 | input_text = template.format(facts=facts)
107 | elif turn_type == "ask_final_answer_or_next_question":
108 | question, serialized_facts = input_text
109 | template = self.prompt['ask_final_answer_or_next_question']
110 | input_text = template.format(facts=serialized_facts, question=question)
111 | elif turn_type == "condition":
112 | input_text = self.prompt['continue_template']['condition']
113 | elif turn_type == "continue":
114 | input_text = self.prompt['continue_template']['continue']
115 | elif turn_type == "stop":
116 | input_text = self.prompt['continue_template']['stop']
117 | elif turn_type == 'relation_rerank':
118 | template = self.prompt['relation_rerank']
119 | question, can_rels = input_text
120 | input_text = template.format(question=question, relations=can_rels)
121 | else:
122 | raise NotImplementedError
123 | message = {'role': 'user', 'content': input_text}
124 | return message
125 |
126 | def query_API_to_get_message(self, messages):
127 | while True:
128 | try:
129 | res = openai.ChatCompletion.create(
130 | model="gpt-3.5-turbo",
131 | messages=messages,
132 | temperature=0,
133 | max_tokens=self.max_tokens,
134 | top_p=1,
135 | frequency_penalty=0,
136 | presence_penalty=0,
137 | )
138 | return res['choices'][0]['message']
139 | except openai.error.RateLimitError:
140 | print('openai.error.RateLimitError\nRetrying...')
141 | time.sleep(30)
142 | except openai.error.ServiceUnavailableError:
143 | print('openai.error.ServiceUnavailableError\nRetrying...')
144 | time.sleep(20)
145 | except openai.error.Timeout:
146 | print('openai.error.Timeout\nRetrying...')
147 | time.sleep(20)
148 | except openai.error.APIError:
149 | print('openai.error.APIError\nRetrying...')
150 | time.sleep(20)
151 | except openai.error.APIConnectionError:
152 | print('openai.error.APIConnectionError\nRetrying...')
153 | time.sleep(20)
154 | # except openai.error.InvalidRequestError:
155 | # print('openai.error.InvalidRequestError\nRetrying...')
156 |
157 | def parse_result(self, result, turn_type):
158 | content = result['content'].strip()
159 | if turn_type in ["initial", "question_template"]:
160 | if "should be" in content:
161 | content = content.split("should be")[1].strip()
162 | if content.startswith('"') and content.endswith('"'):
163 | content = content[1:-1]
164 | else:
165 | matchObj = re.search(r'"(.*?)"', content)
166 | if matchObj is not None:
167 | content = matchObj.group()
168 | content = content[1:-1]
169 | else:
170 | content = content.strip().strip('"')
171 | print("Could not parse exactly; using the raw content: %s" % content)
172 |
173 | return content
174 |
175 | def parse_result_v1(self, result, turn_type):
176 | content = result['content'].strip()
177 | if turn_type in ["ask_question", "continue"]:
178 | if "the simple question:" in content:
179 | content = content.split("the simple question:")[1].strip()
180 | if content.startswith('"') and content.endswith('"'):
181 | content = content[1:-1]
182 | else:
183 | matchObj = re.search(r'"(.*?)"', content)
184 | if matchObj is not None:
185 | content = matchObj.group()
186 | content = content[1:-1]
187 | else:
188 | content = content.strip().strip('"')
189 | print("Could not parse exactly; using the raw content: %s" % content)
190 |
191 | return content
192 |
193 | def parse_result_v2(self, result, turn_type):
194 | content = result['content'].strip()
195 |
196 | return content
197 |
198 | def reset_history(self):
199 | self.history_messages = []
200 | self.history_contents = []
201 |
202 | def reset_history_messages(self):
203 | self.history_messages = []
204 |
205 | def reset_history_contents(self):
206 | self.history_contents = []
207 |
208 | def load_prompt_template(self, prompt_path, prompt_name):
209 | if prompt_path.endswith(".json"):
210 | with open(prompt_path, "rb") as f:
211 | prompt = json.load(f)
212 | return prompt[prompt_name]
213 |
214 | def get_response_v2(self, input_text, turn_type):
215 | message = self.create_message_v2(input_text, turn_type)
216 | self.history_messages.append(message)
217 | self.history_contents.append(message['content'])
218 | message = self.query_API_to_get_message(self.history_messages)
219 | self.history_messages.append(message)
220 | self.history_contents.append(message['content'])
221 | response = message['content'].strip()
222 |
223 | return response
224 |
225 | def create_message_v2(self, input_text, turn_type):
226 | if turn_type == "instruction": # the initial query
227 | instruction = self.prompt['instruction']
228 | input_text = instruction
229 | # ykm
230 | # elif turn_type == "init_relation_rerank":
231 | # template = self.prompt['init_relation_rerank']
232 | # can_rels, question, tpe, hop = input_text
233 | # if hop == 1:
234 | # hop = "first"
235 | # elif hop == 2:
236 | # hop = "second"
237 | # elif hop == 3:
238 | # hop = "third"
239 | # input_text = template.format(question=question, tpe=tpe, relations=can_rels, hop=hop)
240 | elif turn_type == "init_relation_rerank":
241 | template = self.prompt['init_relation_rerank']
242 | can_rels, question, tpe = input_text
243 | input_text = template.format(question=question, tpe=tpe, relations=can_rels)
244 | elif turn_type == "constraints_flag":
245 | template = self.prompt['constraints_flag']
246 | question, tpe, selected_relations = input_text
247 | if len(selected_relations) > 1:
248 | selected_relations = "are " + ", ".join(selected_relations)
249 | else:
250 | selected_relations = "is " + ", ".join(selected_relations)
251 | input_text = template.format(question=question, tpe=tpe, selected_relations=selected_relations)
252 | elif turn_type == "ask_final_answer_or_next_question":
253 | question, serialized_facts = input_text
254 | template = self.prompt['ask_final_answer_or_next_question']
255 | input_text = template.format(facts=serialized_facts, question=question)
256 | elif turn_type == "choose_constraints":
257 | question, relation_tails, tpe_name = input_text
258 | template = self.prompt['choose_constraints']
259 | input_text = template.format(question=question, relation_tails=relation_tails, tpe=tpe_name)
260 | elif turn_type == "final_query_template":
261 | template = self.prompt['final_query_template']
262 | input_text = template.format(question=input_text)
263 | elif turn_type == 'relation_rerank':
264 | template = self.prompt['relation_rerank']
265 | can_rels, question, tpe, selected_relations = input_text
266 | # temporarily commented out
267 | # if len(selected_relations) > 1:
268 | # selected_relations = "are " + ", ".join(selected_relations)
269 | # else:
270 | # selected_relations = "is " + ", ".join(selected_relations)
271 | selected_relations = "".join(selected_relations)
272 | input_text = template.format(question=question, relations=can_rels, tpe=tpe,
273 | selected_relations=selected_relations)
274 | elif turn_type == 'relation_rerank_2hop':
275 | template = self.prompt['relation_rerank_2hop']
276 | can_rels, question, tpe, sub_question, selected_relations = input_text
277 | sub_question = ", ".join(sub_question)
278 | selected_relations = ", ".join(selected_relations)
279 | input_text = template.format(question=question, relations=can_rels, tpe=tpe,
280 | first_sub_question=sub_question, first_relation=selected_relations)
281 | elif turn_type == 'relation_rerank_3hop':
282 | template = self.prompt['relation_rerank_3hop']
283 | can_rels, question, tpe, sub_question, selected_relations = input_text
284 | first_sub_question = sub_question[0]
285 | second_sub_question = sub_question[1]
286 | first_relation = selected_relations[0]
287 | second_relation = selected_relations[1]
288 | input_text = template.format(question=question, relations=can_rels, tpe=tpe,
289 | first_sub_question=first_sub_question, first_relation=first_relation,
290 | second_sub_question=second_sub_question, second_relation=second_relation)
291 | elif turn_type == 'direct_ask_final_answer':
292 | template = self.prompt['direct_ask_final_answer']
293 | question = input_text
294 | input_text = template.format(question=question)
295 | elif turn_type == 'final_answer_organize':
296 | template = self.prompt['final_answer_organize']
297 | input_text = template
298 | else:
299 | raise NotImplementedError
300 | message = {'role': 'user', 'content': input_text}
301 | return message
302 |
303 |
304 | class Retriever:
305 | def __init__(self, args):
306 | self.args = args
307 | # self.initialize_PLM(args)
308 | self.initialize_KG(args)
309 |
310 | def get_retrieval_information(self, first_flag=False, gold_relations=None):
311 | triples_per_hop, tails = self.KG.get_facts_1hop(self.cur_ents, self.args.max_triples_per_relation,
312 | first_flag, gold_relations)
313 | self.reset_cur_ents(tails)
314 | # self.reset_last_ents(self.cur_ents)
315 | return triples_per_hop
316 |
317 | # directly fetch the triples
318 | def get_retrieval_information_direct(self, response, tpe, first_flag=False, gold_relations=None):
319 | triples, tails = self.KG.get_facts_1hop_direct(response, tpe, self.cur_ents, self.tokenizer, self.retriever,
320 | self.args.topk,
321 | self.args.filter_score, self.args.max_triples_per_relation,
322 | first_flag, gold_relations)
323 | self.reset_cur_ents(tails)
324 | # self.reset_last_ents(self.cur_ents)
325 | return triples
326 |
327 | def get_retrieval_relations(self, first_flag=False):
328 | rels = self.KG.get_rels_1hop(self.cur_ents, first_flag)
329 | return rels
330 |
331 | def initialize_KG(self, args):
332 | self.KG = KnowledgeGraph(args.kg_source_path, args.ent_type_path, args.ent2id_path, args.rel2id_path)
333 |
334 | def initialize_PLM(self, args):
335 | from transformers import AutoTokenizer, AutoModel  # needed only if this dense retriever is re-enabled
336 | self.tokenizer = AutoTokenizer.from_pretrained(args.model_path)
337 | model = AutoModel.from_pretrained(args.model_path)
338 | # self.retriever = model.cuda("cuda:" + str(args.device))
339 | self.retriever = None
340 |
341 | def reset_cur_ents(self, entity_list):
342 | self.cur_ents = entity_list
343 | # print("Current entity num: ", len(self.cur_ents))
344 |
345 | def update_cur_ents(self, filtered_triples_per_hop):
346 | new_tails = set()
347 | for tri in filtered_triples_per_hop[1]:
348 | h, r, t = tri
349 | try:
350 | t_id = self.KG.ent2id[t]
351 | new_tails.add(t_id)
352 | except Exception as e:
353 | logging.exception(e)
354 | print("Entity string: %s not in ent2id dict" % t)
355 | continue
356 | new_tails = list(new_tails)
357 | self.reset_cur_ents(new_tails)
358 |
359 | def extract_facts(self, facts, response):
360 | response = response.lower().strip()
361 | # if response.startswith("the relevant relations:"):
362 | # response = response.replace("the relevant relations:", "")
363 | # response = response.strip()
364 | # nor_rels = response.split(",")
365 | # nor_rels = [rel.strip() for rel in nor_rels]
366 | # else:
367 | # nor_rels = response
368 |
369 | filtered_facts = []
370 | for tri in facts:
371 | h, r, t = tri
372 | if self.filter_relation(r):
373 | continue
374 | nor_r = self.normalize_relation(r)
375 | if nor_r in response:
376 | filtered_facts.append(tri)
377 | return filtered_facts
378 |
379 | def filter_relation(self, rel):
380 | # same criteria as GraftNet
381 | relation = rel
382 | if relation == "common.topic.notable_types": return False
383 | if relation == "base.kwebbase.kwtopic.has_sentences": return False
384 | domain = relation.split(".")[0]
385 | if domain == "type" or domain == "common": return True
386 | return False
387 |
388 | def should_ignore(self, rel):
389 | if self.filter_relation(rel):
390 | return True
391 | return False
392 |
393 | def normalize_relation(self, rel):
394 | # e.g.
395 | rel_surface = rel
396 | # replace '.' and '_' with ' '
397 | rel_surface = rel_surface.replace('.', ' ')
398 | # only keep the last two words
399 | rel_surface = ' '.join(rel_surface.split(' ')[-2:])
400 | rel_surface = rel_surface.replace('_', ' ')
401 | return rel_surface
402 |
403 | def get_one_hop_cand_rels(self, question):
404 | pass
405 |
406 | def get_tails_list(self, cur_ents):
407 | tails = [self.KG.id2ent[ent] for ent in cur_ents]
408 | return tails
409 |
410 |
411 | class Solver:
412 | def __init__(self, args):
413 | self.args = args
414 | self.LLM = ChatGPT(args=args, prompt_path=args.prompt_path, prompt_name=args.prompt_name,
415 | max_tokens=args.max_tokens)
416 | self.SLM = Retriever(args)
417 | self.max_serialization_tokens = args.max_llm_input_tokens
418 | self.load_ent2name(args.ent2name_path)
419 | self.log = []
420 | self.selected_relations = []
421 | # temporarily add selected_sub_questions = [] to store the parsed sub-questions
422 | self.selected_sub_questions = []
423 |
424 |
425 | def forward_v2(self, question, tpe_str, tpe_id):
426 | self.LLM.reset_history()
427 | self.SLM.reset_cur_ents([tpe_id])
428 | self.reset_history()
429 |
430 | iterative_step = 0
431 |
432 | # start_response = self.LLM.get_response_v2("", "instruction")
433 | # self.log.append(start_response)
434 |
435 | while True:
436 | # select
437 | all_rel_one_hop = self.SLM.get_retrieval_relations(first_flag=iterative_step == 0)
438 | if len(all_rel_one_hop) == 0:
439 | final_answers = self.LLM.get_response_v2(question, "final_query_template")
440 | break
441 |
442 | serialized_rels = self.extract_can_rels(all_rel_one_hop, normalize_rel=False)
443 | if args.debug:
444 | print("Step-%d: serialized_rels:%s" % (iterative_step, serialized_rels))
445 |
446 | if iterative_step == 0:
447 | llm_selected_rels = self.LLM.get_response_v2((serialized_rels, question, tpe_str),
448 | "init_relation_rerank")
449 | else:
450 | llm_selected_rels = self.LLM.get_response_v2(
451 | (serialized_rels, question, tpe_str, self.selected_relations), "relation_rerank")
452 | self.LLM.reset_history_messages()
453 | if args.debug:
454 | print("Step-%d: llm_selected_rels:%s" % (iterative_step, llm_selected_rels))
455 |
456 | selected_relations_list = self.parse_llm_selected_relations(llm_selected_rels, all_rel_one_hop)
457 | if args.debug:
458 | print("Step-%d: selected_relations_list:%s" % (iterative_step, selected_relations_list))
459 | if len(selected_relations_list) == 0:
460 | final_answers = self.LLM.get_response_v2(question, "final_query_template")
461 | break
462 |
463 | self.selected_relations.extend(selected_relations_list)
464 | if args.debug:
465 | print("Step-%d: self.selected_relations:%s" % (iterative_step, self.selected_relations))
466 |
467 | filtered_triples_per_hop = self.SLM.get_retrieval_information(first_flag=iterative_step == 0,
468 | gold_relations=selected_relations_list)
469 | cvt_triples, mid_triples, entstr_triples = self.classify_triples(filtered_triples_per_hop)
470 | if len(cvt_triples) > 0:
471 | # constraint
472 | if args.debug:
473 | print("Step-%d: Constraints" % iterative_step)
474 | constraints_candidate = self.serialize_constraints(cvt_triples)
475 | if args.debug:
476 | print("Step-%d: constraints_candidate:%s" % (iterative_step, constraints_candidate))
477 | constraint_response = self.LLM.get_response_v2((question, constraints_candidate, tpe_str), "choose_constraints")
478 | self.log.append(constraint_response)
479 | if args.debug:
480 | print("Step-%d: constraint_response:%s" % (iterative_step, constraint_response))
481 | if self.has_constraints(constraint_response):
482 | filtered_triples_per_hop = self.filter_triples(filtered_triples_per_hop, cvt_triples,
483 | constraint_response)
484 | self.SLM.update_cur_ents(filtered_triples_per_hop)
485 | if args.debug:
486 | print("Step-%d: filtered_triples_per_hop:%s" % (iterative_step, filtered_triples_per_hop))
487 | if args.debug:
488 | print("Step-%d: self.SLM.cur_ents:%s" % (iterative_step, self.SLM.cur_ents))
489 | serialized_facts = self.serialize_facts(filtered_triples_per_hop)
490 | self.log.append(serialized_facts)
491 |
492 | if args.debug:
493 | print("Step-%d: serialized_facts:%s" % (iterative_step, serialized_facts))
494 |
495 | final_ans_or_next_que = self.LLM.get_response_v2((question, serialized_facts),
496 | "ask_final_answer_or_next_question")
497 | self.log.append(final_ans_or_next_que)
498 |
499 |                 # newly added: parse the final answer out of the response
500 | final_answers = self.parse_result(final_ans_or_next_que, "final_answer")
501 | self.log.append(final_answers)
502 | break
503 | return final_answers, self.LLM.history_contents, self.log
504 |
505 | def reset_selected_list(self):
506 | self.selected_sub_questions = []
507 | self.selected_relations = []
508 |
509 | def is_end(self, response, iterative_step):
510 | if "no" in response.lower() or iterative_step > 8:
511 | return True
512 | else:
513 | return False
514 |
515 | def load_ent2name(self, ent2name_path):
516 | with open(ent2name_path, "rb") as f:
517 | self.cvt_flag_dict, self.mid_mapping_dict = pickle.load(f)
518 |
519 | def convert_hyper_facts_to_text(self, facts):
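        """Flatten a hyper fact (subj, [rels], [objs]) into (subject, relation, object) surface
        triples, skipping ignored relations, CVT nodes, and MIDs that have no name; returns None
        when the subject itself cannot be verbalized."""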
520 | subj, rels, objs = facts
521 |
522 | if self.is_cvt(subj):
523 | return None
524 | elif subj in self.mid_mapping_dict:
525 | subj_surface = self.mid_mapping_dict[subj]
526 | elif self.is_ent(subj):
527 | # print("head entity %s doesn't have name, we skip this triple." % subj)
528 | return None
529 | else:
530 | subj_surface = subj
531 |
532 | flat_facts = []
533 | for rel, obj in zip(rels, objs):
534 | if self.should_ignore(rel):
535 | continue
536 | else:
537 | nor_rel = self.normalize_relation(rel)
538 |
539 | if self.is_cvt(obj):
540 | continue
541 | elif obj in self.mid_mapping_dict:
542 | obj_surface = self.mid_mapping_dict[obj]
543 | elif self.is_ent(obj):
544 | # print("tail entity %s doesn't have name, we skip this triple." % obj)
545 | continue
546 | else:
547 | obj_surface = obj
548 |
549 | flat_facts.append((subj_surface, nor_rel, obj_surface))
550 |
551 | return flat_facts
552 |
553 | def convert_fact_to_text(self, fact, normalize_rel=False):
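        """Verbalize a single triple: map MIDs to names via mid_mapping_dict, rewrite "*.from"/"*.to"
        relations to "*.start_time"/"*.end_time", keep "CVT_*" placeholders as-is, and return None
        for triples that should be ignored or whose entities have no surface name."""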
554 | subj, rel, obj = fact
555 |
556 | if self.should_ignore(rel):
557 | return None
558 |
559 |         if rel.endswith(".from"):
560 |             rel = rel[:-len(".from")]  # slice the suffix; rstrip would strip a character set
561 |             rel = rel + ".start_time"
562 |         if rel.endswith(".to"):
563 |             rel = rel[:-len(".to")]
564 |             rel = rel + ".end_time"
565 | rel_surface = self.normalize_relation(rel) if normalize_rel else rel
566 |
567 | # subject
568 | if subj.startswith("CVT"):
569 | subj_surface = subj
570 | elif subj in self.mid_mapping_dict:
571 | subj_surface = self.mid_mapping_dict[subj]
572 | elif subj.startswith("m.") or subj.startswith('g.'):
573 | # print("head entity %s doesn't have name, we skip this triple." % subj)
574 | return None
575 | else:
576 | subj_surface = subj
577 |
578 | # object
579 | if obj.startswith("CVT"):
580 | obj_surface = obj
581 | elif obj in self.mid_mapping_dict:
582 | obj_surface = self.mid_mapping_dict[obj]
583 | elif obj.startswith("m.") or obj.startswith('g.'):
584 | # print("tail entity %s doesn't have name, we skip this triple." % obj)
585 | return None
586 | else:
587 | obj_surface = obj
588 |
589 | return (subj_surface, rel_surface, obj_surface)
590 |
591 | def extract_can_rels(self, all_rel_one_hop, normalize_rel=True):
592 | rel_prompt = '"{relation}"'
593 | nor_rels_set = []
594 | for rel in all_rel_one_hop:
595 | if self.filter_relation(rel):
596 | continue
597 | nor_r = self.normalize_relation(rel) if normalize_rel else rel
598 | if nor_r not in nor_rels_set:
599 | nor_rels_set.append(rel_prompt.format(relation=nor_r))
600 | rel_candidate = ", ".join(nor_rels_set)
601 | return rel_candidate
602 |
603 | def serialize_rels(self, rels, normalize_rel=True):
604 | nor_rels_set = []
605 | for rel in rels:
606 | if self.filter_relation(rel):
607 | continue
608 | nor_r = self.normalize_relation(rel) if normalize_rel else rel
609 | if nor_r not in nor_rels_set:
610 | nor_rels_set.append(nor_r)
611 | # rel_candidate = ", ".join(nor_rels_set)
612 | rel_candidate = ";\n ".join(nor_rels_set)
613 | return rel_candidate
614 |
615 |     # directly concatenate the triples without grouping
616 |     def serialize_facts_direct(self, facts):
617 |         # join the triples as "(s, r, o)" strings
618 | facts_str_for_one_tail_ent = ["(" + ", ".join(fact) + ")" for fact in facts]
619 |
620 | serialized_facts = ""
621 | for fact in facts_str_for_one_tail_ent:
622 | serialized_facts_tmp = serialized_facts + fact + "; "
623 | serialized_facts = serialized_facts_tmp
624 | return serialized_facts
625 |
626 | def serialize_facts(self, facts_per_hop):
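        """Serialize per-hop triples into a "(s, r, o), ..." string: CVT entities are renamed to
        CVT_i placeholders, triples ending in a CVT node are expanded with that node's outgoing
        triples, and the output is truncated once it exceeds max_serialization_tokens
        (whitespace-delimited) tokens."""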
627 | h_r_t = defaultdict(lambda: defaultdict(set))
628 | visited_flag = {}
629 | name2cvt_tmp = {}
630 | cvt_count = 0
631 | all_facts = []
632 | for hop, facts in facts_per_hop.items():
633 | if len(facts) > 0:
634 | for fact in facts:
635 | h, r, t = fact
636 | if self.is_cvt(h):
637 | if h not in name2cvt_tmp:
638 | cvt = "CVT_" + str(cvt_count)
639 | cvt_count += 1
640 | name2cvt_tmp[h] = cvt
641 | h = name2cvt_tmp[h]
642 | if self.is_cvt(t):
643 | if t not in name2cvt_tmp:
644 | cvt = "CVT_" + str(cvt_count)
645 | cvt_count += 1
646 | name2cvt_tmp[t] = cvt
647 | t = name2cvt_tmp[t]
648 | fact = (h, r, t)
649 | all_facts.append(fact)
650 | visited_flag[fact] = False
651 | h_r_t[h][r].add(t)
652 |
653 | if len(all_facts) > 0:
654 | all_facts_str = []
655 | for tri in all_facts:
656 | facts_str_for_one_tail_ent = []
657 | if not visited_flag[tri]:
658 | h, r, t = tri
659 | if t.startswith("CVT") and len(h_r_t[t]) == 0:
660 | continue
661 |
662 | if h.startswith("CVT"):
663 | # print("Qid:[%s] has single cvt head entities." % qid)
664 | # logger.info(triples_per_hop)
665 | continue
666 | elif t.startswith("CVT"):
667 |                         st = self.convert_fact_to_text(tri, normalize_rel=False)
668 |                         if st is not None: facts_str_for_one_tail_ent.append(st)  # may be None for skipped triples
669 | one_hop_triples = h_r_t[t]
670 | if len(one_hop_triples) > 0:
671 | for key_r, value_ts in one_hop_triples.items():
672 | for t_ in value_ts:
673 | visit_tri = (t, key_r, t_)
674 | if not visited_flag[visit_tri]:
675 | visited_flag[visit_tri] = True
676 | st = self.convert_fact_to_text(visit_tri, normalize_rel=False)
677 | if st is not None:
678 | assert len(st) == 3
679 | facts_str_for_one_tail_ent.append(st)
680 | # h_new = t
681 | # r_new = []
682 | # t_new = []
683 | # for key_r, value_ts in one_hop_triples.items():
684 | # for t_ in value_ts:
685 | # visit_tri = (t, key_r, t_)
686 | # if not visited_flag[visit_tri]:
687 | # r_new.append(key_r)
688 | # t_new.append(t_)
689 | # visited_flag[visit_tri] = True
690 | # tri_new = (t, r_new, t_new)
691 | # if len(r_new) == len(t_new) > 0:
692 | # str_tri_list = self.convert_hyper_facts_to_text(tri_new)
693 | # if str_tri_list is not None:
694 | # for st in str_tri_list:
695 | # assert len(st) == 3
696 | # if st not in facts_str:
697 | # facts_str.append(st)
698 | else:
699 | st = self.convert_fact_to_text(tri, normalize_rel=False)
700 | if st is not None:
701 | assert len(st) == 3
702 | if st not in facts_str_for_one_tail_ent:
703 | facts_str_for_one_tail_ent.append(st)
704 | facts_str_for_one_tail_ent = ["(" + ", ".join(fact) + ")" for fact in facts_str_for_one_tail_ent]
705 | facts_str = ", ".join(facts_str_for_one_tail_ent)
706 | all_facts_str.append(facts_str)
707 |
708 | # facts_str = ["(" + ", ".join(fact) + ")" for fact in facts_str]
709 | serialized_facts = ""
710 | for fact in all_facts_str:
711 | serialized_facts_tmp = serialized_facts + fact + "; "
712 | if len(serialized_facts_tmp.split()) > self.max_serialization_tokens:
713 | break
714 | else:
715 | serialized_facts = serialized_facts_tmp
716 | serialized_facts = serialized_facts.strip("; ")
717 | else:
718 | serialized_facts = ""
719 | return serialized_facts
720 |
721 | def serialize_facts_v1(self, facts):
722 | if len(facts) > 0:
723 | h_r_t = defaultdict(lambda: defaultdict(set))
724 | visited_flag = {}
725 | for fact in facts:
726 | h, r, t = fact
727 | visited_flag[tuple(fact)] = False
728 | h_r_t[h][r].add(t)
729 | facts_str = []
730 | for tri in facts:
731 | if not visited_flag[tuple(tri)]:
732 | h, r, t = tri
733 | if self.is_cvt(t) and len(h_r_t[t]) == 0:
734 | continue
735 | if self.is_cvt(h):
736 | # print("Qid:[%s] has single cvt head entities." % qid)
737 | # logger.info(triples_per_hop)
738 | continue
739 | elif self.is_cvt(t):
740 | one_hop_triples = h_r_t[t]
741 | if len(one_hop_triples) > 0:
742 | h_new = t
743 | r_new = []
744 | t_new = []
745 | for key_r, value_ts in one_hop_triples.items():
746 | for t_ in value_ts:
747 | visit_tri = (t, key_r, t_)
748 | if not visited_flag[visit_tri]:
749 | r_new.append(key_r)
750 | t_new.append(t_)
751 | visited_flag[visit_tri] = True
752 | tri_new = (h, r_new, t_new)
753 | if len(r_new) == len(t_new) > 0:
754 | str_tri_list = self.convert_hyper_facts_to_text(tri_new)
755 | if str_tri_list is not None:
756 | for st in str_tri_list:
757 | assert len(st) == 3
758 | if st not in facts_str:
759 | facts_str.append(st)
760 | else:
761 | st = self.convert_fact_to_text(tri)
762 | if st is not None:
763 | assert len(st) == 3
764 | if st not in facts_str:
765 | facts_str.append(st)
766 | facts_str = ["(" + ", ".join(fact) + ")" for fact in facts_str]
767 | serialized_facts = ""
768 | for fact in facts_str:
769 | serialized_facts_tmp = serialized_facts + fact + "; "
770 | if len(serialized_facts_tmp.split()) > self.max_serialization_tokens:
771 | break
772 | else:
773 | serialized_facts = serialized_facts_tmp
774 | # serialized_facts = "; ".join(facts_str)
775 | serialized_facts = serialized_facts.strip("; ")
776 | else:
777 | serialized_facts = ""
778 | return serialized_facts
779 |
780 | def is_cvt(self, entity):
781 | if self.cvt_flag_dict[entity]:
782 | return True
783 | else:
784 | return False
785 |
786 | def is_ent(self, ent_str):
787 | if type(ent_str) is not bool and (ent_str.startswith("m.") or ent_str.startswith("g.")):
788 | return True
789 | else:
790 | return False
791 |
792 | def filter_relation(self, rel):
793 | # same criteria as GraftNet
794 | relation = rel
795 | if relation == "common.topic.notable_types": return False
796 | if relation == "base.kwebbase.kwtopic.has_sentences": return False
797 | domain = relation.split(".")[0]
798 | if domain == "type" or domain == "common": return True
799 | return False
800 |
801 | def should_ignore(self, rel):
802 | if self.filter_relation(rel):
803 | return True
804 | return False
805 |
806 | def normalize_relation(self, rel):
807 |         # e.g. "people.person.place_of_birth" -> "person place of birth"
808 | rel_surface = rel
809 | # replace '.' and '_' with ' '
810 | rel_surface = rel_surface.replace('.', ' ')
811 | # only keep the last two words
812 | rel_surface = ' '.join(rel_surface.split(' ')[-2:])
813 | rel_surface = rel_surface.replace('_', ' ')
814 | return rel_surface
815 |
816 | def parse_llm_selected_relations(self, llm_sel_rels_str, can_rels):
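        # keep every candidate relation whose exact string occurs as a substring of the LLM response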
817 | # llm_sel_rels = llm_sel_rels_str.strip(" ;.|,<>`[]'")
818 | # llm_sel_rels = llm_sel_rels.split(',')
819 | # llm_sel_rels = [rel.strip(" ;.|,<>`[]'").strip(" ;.|,<>`[]'") for rel in llm_sel_rels]
820 | # llm_sel_rel_list = []
821 | # for rel in llm_sel_rels:
822 | # if rel in can_rels:
823 | # llm_sel_rel_list.append(rel)
824 | # else:
825 | # print(rel)
826 | # if len(llm_sel_rel_list) == 0:
827 | # for rel in can_rels:
828 | # if rel in llm_sel_rels_str:
829 | # llm_sel_rel_list.append(rel)
830 | # print("-----llm_ser_rels:\n%s\ndoesn't match the predefined format" % llm_sel_rels)
831 | llm_sel_rel_list = []
832 | for rel in can_rels:
833 | if rel in llm_sel_rels_str:
834 | llm_sel_rel_list.append(rel)
835 | return llm_sel_rel_list
836 |
837 | def parse_result(self, response, parse_type):
838 | response = response.lower()
839 | if parse_type == "next_question":
840 | if "the next question:" in response:
841 | next_question = response.split("the next question:")[1].strip()
842 | elif ":" in response:
843 | next_question = response.split(":")[1].strip()
844 | else:
845 | next_question = response
846 |                 print("Could not parse the next question exactly; using the response directly: ", response)
847 | return next_question
848 | elif parse_type == "final_answer":
849 | if "the final answers:" in response:
850 | final_answer = response.split("the final answers:")[1].strip()
851 |                 # temporarily commented out
852 |             elif ":" in response:
853 |                 final_answer = response.split(":")[1].strip()
854 |                 # newly added to handle direct-query responses
855 |             else:
856 |                 final_answer = response
857 |                 # temporarily commented out:
858 |                 # print("Could not parse the final answer exactly; using the response directly: ", response)
859 | return final_answer
860 |
861 | def classify_triples(self, filtered_triples_per_hop):
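        """Split retrieved triples: any second-hop triples are treated as CVT triples; otherwise the
        first-hop triples are divided into MID-tailed (mid_triples) and literal-tailed (entstr_triples)."""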
862 | cvt_triples, mid_triples, entstr_triples = set(), set(), set()
863 | if 0 in filtered_triples_per_hop:
864 | triples_0 = filtered_triples_per_hop[0]
865 | else:
866 | triples_0 = []
867 | if 1 in filtered_triples_per_hop:
868 | triples_1 = filtered_triples_per_hop[1]
869 | else:
870 | triples_1 = []
871 |
872 | if len(triples_1) == 0:
873 | for tri in triples_0:
874 | if self.is_ent(tri[2]):
875 | mid_triples.add(tuple(tri))
876 | else:
877 | entstr_triples.add(tuple(tri))
878 | else:
879 | for tri in triples_1:
880 | cvt_triples.add(tuple(tri))
881 | return cvt_triples, mid_triples, entstr_triples
882 |
883 | def serialize_constraints(self, cvt_triples):
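        """Group CVT triples into '"relation": ["value", ...]' lines for the "choose_constraints"
        prompt, skipping ignored relations, unnamed MIDs, "has_no_value" relations, and the literal
        placeholder value "To"."""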
884 | r2t_set = defaultdict(set)
885 | for tri in cvt_triples:
886 | subj, rel, obj = tri
887 | if self.should_ignore(rel):
888 | continue
889 |
890 |             if rel.endswith(".from"):
891 |                 rel = rel[:-len(".from")]  # slice the suffix; rstrip would strip a character set
892 |                 rel = rel + ".start_time"
893 |             if rel.endswith(".to"):
894 |                 rel = rel[:-len(".to")]
895 |                 rel = rel + ".end_time"
896 |
897 | rel_surface = rel
898 |
899 | # object
900 | if obj in self.mid_mapping_dict:
901 | obj_surface = self.mid_mapping_dict[obj]
902 | elif obj.startswith("m.") or obj.startswith('g.'):
903 | # print("tail entity %s doesn't have name, we skip this triple." % obj)
904 | continue
905 | else:
906 | obj_surface = obj
907 |
908 | if obj_surface == "To" or "has_no_value" in rel:
909 | continue
910 |
911 | r2t_set[rel_surface].add(obj_surface)
912 |
913 | constraints = []
914 | for r, t_set in r2t_set.items():
915 | t_set = ['"' + t + '"' for t in t_set]
916 | constraints.append('"' + r + '"' + ": [" + ", ".join(t_set) + "]")
917 | # constraints = constraints.rstrip("\n")
918 | constraints = "\n".join(constraints)
919 | return constraints
920 |
921 | def has_constraints(self, constraint_response):
922 | if "no" in constraint_response.lower():
923 | return False
924 | else:
925 | return True
926 |
927 | def filter_triples(self, filtered_triples_per_hop, cvt_triples, constraint_response):
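        """Keep only the CVT nodes that are consistent with the LLM's constraint response: a node
        survives when every one of its relations mentioned in the response also has its object value
        mentioned (and at least one relation matches); if nothing survives, retry with a softer
        per-token relation match. Both hops are then restricted to triples touching surviving nodes."""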
928 | valid_cvt_nodes = set()
929 | h_r_t = defaultdict(list)
930 | for tri in cvt_triples:
931 | h, r, t = tri
932 | h_r_t[h].append((r, t))
933 |         for cvt, r_ts in h_r_t.items():
934 |             flag = True
935 |             at_least_one_flag = False
936 |             for r_t in r_ts:
937 |                 rel, obj = r_t
938 |
939 |                 if rel.endswith(".from"):
940 |                     rel = rel[:-len(".from")]  # slice the suffix; rstrip would strip a character set
941 |                     rel = rel + ".start_time"
942 |                 if rel.endswith(".to"):
943 |                     rel = rel[:-len(".to")]
944 |                     rel = rel + ".end_time"
945 | rel_surface = rel
946 |
947 | # object
948 | if obj in self.mid_mapping_dict:
949 | obj_surface = self.mid_mapping_dict[obj]
950 | elif obj.startswith("m.") or obj.startswith('g.'):
951 | # print("tail entity %s doesn't have name, we skip this triple." % obj)
952 | continue
953 | else:
954 | obj_surface = obj
955 |
956 |                 if rel_surface.lower() in constraint_response.lower():
957 |                     at_least_one_flag = True
958 |                     if obj_surface.lower() not in constraint_response.lower():
959 |                         flag = False
960 |                         break
961 |             if flag and at_least_one_flag:
962 |                 valid_cvt_nodes.add(cvt)
963 |
964 |         # soft constraint: match the dot-separated parts of each CVT relation (and the object values) against the response
965 |         if len(valid_cvt_nodes) == 0:
966 |             for cvt, r_ts in h_r_t.items():
967 |                 flag = True
968 |                 at_least_one_flag = False
969 |                 for r_t in r_ts:
970 |                     rel, obj = r_t
971 |
972 |                     if rel.endswith(".from"):
973 |                         rel = rel[:-len(".from")]
974 |                         rel = rel + ".start_time"
975 |                     if rel.endswith(".to"):
976 |                         rel = rel[:-len(".to")]
977 |                         rel = rel + ".end_time"
978 |                     rel_surface = rel
979 |
980 | # object
981 | if obj in self.mid_mapping_dict:
982 | obj_surface = self.mid_mapping_dict[obj]
983 | elif obj.startswith("m.") or obj.startswith('g.'):
984 | # print("tail entity %s doesn't have name, we skip this triple." % obj)
985 | continue
986 | else:
987 | obj_surface = obj
988 |
989 |                     rel_surface_list = rel_surface.split(".")
990 |                     for rel in rel_surface_list:
991 |                         if rel.lower() in constraint_response.lower():
992 |                             at_least_one_flag = True
993 |                             if obj_surface.lower() not in constraint_response.lower():
994 |                                 flag = False
995 |                                 break
996 |                         else:
997 |                             flag = True
998 |                 if flag and at_least_one_flag:
999 |                     valid_cvt_nodes.add(cvt)
1000 | break
1001 |
1002 | new_tris_per_hop = defaultdict(set)
1003 | for hop in [0, 1]:
1004 | triples = filtered_triples_per_hop[hop]
1005 | for tri in triples:
1006 | h, r, t = tri
1007 | if hop == 0:
1008 | if t in valid_cvt_nodes:
1009 | new_tris_per_hop[hop].add(tuple(tri))
1010 | elif hop == 1:
1011 | if h in valid_cvt_nodes:
1012 | new_tris_per_hop[hop].add(tuple(tri))
1013 | return new_tris_per_hop
1014 |
1015 | def serialize_facts_one_hop(self, facts):
1016 | if len(facts) > 0:
1017 | h_r_t = defaultdict(lambda: defaultdict(set))
1018 | visited_flag = {}
1019 | for fact in facts:
1020 | h, r, t = fact
1021 | visited_flag[tuple(fact)] = False
1022 | h_r_t[h][r].add(t)
1023 | facts_str = []
1024 | for tri in facts:
1025 | if not visited_flag[tuple(tri)]:
1026 | h, r, t = tri
1027 | if self.is_cvt(t) and len(h_r_t[t]) == 0:
1028 | continue
1029 | if self.is_cvt(h):
1030 | # print("Qid:[%s] has single cvt head entities." % qid)
1031 | # logger.info(triples_per_hop)
1032 | continue
1033 | elif self.is_cvt(t):
1034 | one_hop_triples = h_r_t[t]
1035 | if len(one_hop_triples) > 0:
1036 | h_new = t
1037 | r_new = []
1038 | t_new = []
1039 | for key_r, value_ts in one_hop_triples.items():
1040 | for t_ in value_ts:
1041 | visit_tri = (t, key_r, t_)
1042 | if not visited_flag[visit_tri]:
1043 | r_new.append(key_r)
1044 | t_new.append(t_)
1045 | visited_flag[visit_tri] = True
1046 | tri_new = (h, r_new, t_new)
1047 | if len(r_new) == len(t_new) > 0:
1048 | str_tri_list = self.convert_hyper_facts_to_text(tri_new)
1049 | if str_tri_list is not None:
1050 | for st in str_tri_list:
1051 | assert len(st) == 3
1052 | if st not in facts_str:
1053 | facts_str.append(st)
1054 | else:
1055 | st = self.convert_fact_to_text(tri)
1056 | if st is not None:
1057 | assert len(st) == 3
1058 | if st not in facts_str:
1059 | facts_str.append(st)
1060 | facts_str = ["(" + ", ".join(fact) + ")" for fact in facts_str]
1061 | serialized_facts = ""
1062 | for fact in facts_str:
1063 | serialized_facts_tmp = serialized_facts + fact + "; "
1064 | if len(serialized_facts_tmp.split()) > self.max_serialization_tokens:
1065 | break
1066 | else:
1067 | serialized_facts = serialized_facts_tmp
1068 | # serialized_facts = "; ".join(facts_str)
1069 | serialized_facts = serialized_facts.strip("; ")
1070 | else:
1071 | serialized_facts = ""
1072 | return serialized_facts
1073 |
1074 | def is_end_v2(self, response, iterative_step):
1075 | if "final" in response.lower() or iterative_step > 3:
1076 | return True
1077 | else:
1078 | return False
1079 |
1080 | def reset_history(self):
1081 | self.log = []
1082 | self.selected_relations = []
1083 | self.selected_sub_questions = []
1084 |
1085 | def get_tails_list(self, cur_ents):
1086 | tails = self.SLM.get_tails_list(cur_ents)
1087 | return tails
1088 |
1089 | def main(args, all_data, idx, api_key):
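    # One worker process: set the OpenAI key, suffix the output/chat-log paths with the zero-padded
    # split index (unless idx == -1, i.e. single-process mode), run Solver.forward_v2 on each sample,
    # skip samples that raise, and write one JSON line per prediction plus the full chat transcript.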
1090 | import openai
1091 | openai.api_key = api_key
1092 | if idx == -1:
1093 | output_path = args.output_path
1094 | chat_log_path = args.chat_log_path
1095 | else:
1096 | idx = "0" + str(idx) if idx < 10 else str(idx) # 00 01 02 ... 29
1097 | output_path = args.output_path + "_" + idx
1098 | chat_log_path = args.chat_log_path + "_" + idx
1099 |
1100 | print("Start PID %d and save to %s" % (os.getpid(), output_path))
1101 | solver = Solver(args)
1102 |
1103 | count = 0
1104 | valid_count = 0
1105 | with open(output_path, "w") as f:
1106 | with open(chat_log_path, "w") as fclog:
1107 | for sample in tqdm(all_data, total=len(all_data)):
1108 | # if sample["ID"] not in ["test_10943"]:
1109 | # continue
1110 | try:
1111 | question = sample["Question"]
1112 | tpe_name = sample["TopicEntityName"]
1113 | tpe_id = sample['TopicEntityID']
1114 |
1115 | prediction, chat_history, record = solver.forward_v2(question, tpe_name, tpe_id)
1116 | valid_count += 1
1117 | except openai.error.InvalidRequestError as e:
1118 | print(e)
1119 | continue
1120 | except Exception as e:
1121 | logging.exception(e)
1122 | continue
1123 |
1124 | chat = sample["ID"] + "\n" + "\n******\n".join(chat_history) + "\nAnswers: " + str(
1125 | sample['Answers']) + "\n------------------------------------------\n"
1126 | fclog.write(chat)
1127 |
1128 | count += 1
1129 | if count < 5:
1130 | print(sample['Answers'])
1131 | print(prediction)
1132 | print("---------------------")
1133 | sample["Prediction"] = prediction
1134 | f.write(json.dumps(sample) + "\n")
1135 |
1136 | print("---------------PID %d end with %d/%d samples--------------" % (os.getpid(), valid_count, count))
1137 |
1138 |
1139 |
1140 | def parse_args():
1141 | parser = argparse.ArgumentParser()
1142 | parser.add_argument('--input_path', default=None)
1143 | parser.add_argument('--output_path', default=None)
1144 | parser.add_argument('--chat_log_path', default=None)
1145 | parser.add_argument('--log_path', default=None)
1146 | parser.add_argument('--model_path', default=None)
1147 | parser.add_argument('--debug', action="store_true")
1148 | parser.add_argument('--prompt_path')
1149 | parser.add_argument('--prompt_name', default="chat", )
1150 | parser.add_argument('--bagging_type', default="llm", )
1151 | parser.add_argument('--overwrite', action="store_true")
1152 | parser.add_argument('--device', default=0, help='the gpu device')
1153 | parser.add_argument('--topk', default=10, type=int, help='retrieve the topk score paths')
1154 |     parser.add_argument('--max_tokens', default=10, type=int, help='the maximum number of tokens the LLM may generate')
1155 |     parser.add_argument('--api_key', default="sk-xxxx", type=str, help='an OpenAI API key (sk-...) or a path to a file with one key per line')
1156 | parser.add_argument('--filter_score', default=0.0, type=float, help='the minimal cosine similarity')
1157 | parser.add_argument('--kg_source_path', default=None, help='the sparse triples file')
1158 | parser.add_argument('--ent_type_path', default=None, help='the file of entities type of sparse triples')
1159 | parser.add_argument('--ent2id_path', default=None, help='the sparse ent2id file')
1160 | parser.add_argument('--rel2id_path', default=None, help='the sparse rel2id file')
1161 |     parser.add_argument('--ent2name_path', default=None, help='the pickled (cvt_flag_dict, mid_mapping_dict) entity-to-name file')
1162 | parser.add_argument('--max_triples_per_relation', default=40, type=int)
1163 | parser.add_argument('--max_llm_input_tokens', default=3400, type=int)
1164 | parser.add_argument('--num_process', default=1, type=int, help='the number of multi-process')
1165 |
1166 |
1167 | args = parser.parse_args()
1168 |
1169 | print("Start querying the LLM.")
1170 | return args
1171 |
1172 |
1173 | if __name__ == '__main__':
1174 | args = parse_args()
1175 | if not args.api_key.startswith("sk-"):
1176 | with open(args.api_key, "r") as f:
1177 | all_keys = f.readlines()
1178 | all_keys = [line.strip('\n') for line in all_keys]
1179 | assert len(all_keys) == args.num_process, (len(all_keys), args.num_process)
1180 |
1181 | if args.input_path.endswith('jsonl'):
1182 | with open(args.input_path, "r") as f:
1183 | all_lines = f.readlines()
1184 | all_data = [json.loads(line) for line in all_lines]
1185 | print("Totally %d test examples." % len(all_data))
1186 | elif args.input_path.endswith('json'):
1187 | with open(args.input_path, "r") as f:
1188 | all_data = json.load(f)
1189 | print("Totally %d test examples." % len(all_data))
1190 |
1191 | # used for interrupted scenario
1192 | # with open(args.output_path, "r") as f:
1193 | # all_lines = f.readlines()
1194 | # all_lines = [json.loads(line.strip("\n")) for line in all_lines]
1195 | # already_id = [line['ID'] for line in all_lines]
1196 | # all_data = [data for data in all_data if data['ID'] not in already_id]
1197 | # print("There are %d test examples need to be processed." % len(all_data))
1198 |
1199 | if args.num_process == 1:
1200 | main(args, all_data, idx=-1, api_key=args.api_key)
1201 | else:
1202 | num_each_split = int(len(all_data) / args.num_process)
1203 | p = mp.Pool(args.num_process)
1204 | for idx in range(args.num_process):
1205 | start = idx * num_each_split
1206 | if idx == args.num_process - 1:
1207 | end = max((idx + 1) * num_each_split, len(all_data))
1208 | else:
1209 | end = (idx + 1) * num_each_split
1210 | split_data = all_data[start:end]
1211 | try:
1212 | p.apply_async(main, args=(args, split_data, idx, all_keys[idx]))
1213 | except Exception as e:
1214 | logging.exception(e)
1215 |
1216 | p.close()
1217 | p.join()
1218 | print("All of the child processes over!")
1219 |
--------------------------------------------------------------------------------