├── .gitignore
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── data.zip
├── operations
│   ├── __init__.py
│   ├── add_column.py
│   ├── final_query.py
│   ├── group_by.py
│   ├── select_column.py
│   ├── select_row.py
│   └── sort_by.py
├── requirements.txt
├── run_demo.ipynb
├── run_tabfact.py
├── third_party
│   └── select_column_row_prompts
│       ├── LICENSE
│       └── select_column_row_prompts.py
└── utils
    ├── chain.py
    ├── evaluate.py
    ├── helper.py
    ├── llm.py
    └── load_data.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | .DS_Store
3 | data/
4 | results/
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We would love to accept your patches and contributions to this project.
4 |
5 | ## Before you begin
6 |
7 | ### Sign our Contributor License Agreement
8 |
9 | Contributions to this project must be accompanied by a
10 | [Contributor License Agreement](https://cla.developers.google.com/about) (CLA).
11 | You (or your employer) retain the copyright to your contribution; this simply
12 | gives us permission to use and redistribute your contributions as part of the
13 | project.
14 |
15 | If you or your current employer have already signed the Google CLA (even if it
16 | was for a different project), you probably don't need to do it again.
17 |
18 | Visit <https://cla.developers.google.com/> to see your current agreements or
19 | to sign a new one.
20 |
21 | ### Review our Community Guidelines
22 |
23 | This project follows
24 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
25 |
26 | ## Contribution process
27 |
28 | ### Code Reviews
29 |
30 | All submissions, including submissions by project members, require review. We
31 | use [GitHub pull requests](https://docs.github.com/articles/about-pull-requests)
32 | for this purpose.
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Chain of Table
2 |
3 | Code for paper [Chain-of-Table: Evolving Tables in the Reasoning Chain for Table Understanding](https://arxiv.org/abs/2401.04398)
4 |
5 | *This is not an officially supported Google product.*
6 |
7 | ## Environment
8 |
9 | ```shell
10 | conda create --name cotable python=3.10 -y
11 | conda activate cotable
12 | pip install -r requirements.txt
13 | ```
14 |
15 | ## Data
16 |
17 | ```shell
18 | unzip data.zip
19 | ```
20 |
21 | ## Command Usage
22 |
23 | ### Arguments
24 |
25 | - `--dataset_path`: path to the dataset, default: `./data/tabfact/test.jsonl`
26 | - `--raw2clean_path`: path to the preprocessed raw2clean file, default: `./data/tabfact/raw2clean.json` (cleaned by [Dater](https://arxiv.org/pdf/2301.13808.pdf))
27 | - `--model_name`: name of the OpenAI model, default: `gpt-3.5-turbo-16k-0613`
28 | - `--result_dir`: path to the result directory, default: `./results/tabfact`
29 | - `--openai_api_key`: OpenAI API key
30 | - `--first_n`: evaluate only the first n samples, default: `-1` (the whole dataset)
31 | - `--n_proc`: number of processes to use in multiprocessing, default: `1`
32 | - `--chunk_size`: chunk size used in multiprocessing, default: `1`
33 |
34 | ### Example usage
35 |
36 | 1. Run tests on the first 10 cases
37 |
38 | ```shell
39 | python run_tabfact.py \
40 | --result_dir 'results/tabfact_first10' \
41 | --first_n 10 \
42 | --n_proc 10 \
43 | --chunk_size 1 \
44 |   --openai_api_key <your_openai_api_key>
45 | ```
46 |
47 | 2. Run the experiment on the whole dataset
48 |
49 | ```shell
50 | python run_tabfact.py \
51 | --result_dir 'results/tabfact' \
52 | --n_proc 20 \
53 | --chunk_size 10 \
54 |   --openai_api_key <your_openai_api_key>
55 | ```
56 |
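The `--n_proc` and `--chunk_size` flags follow the standard `multiprocessing.Pool` pattern: `chunk_size` controls how many samples each worker pulls per dispatch (larger chunks reduce inter-process overhead; smaller chunks balance load better). Purely as an illustrative sketch of that pattern — this is not the code in `run_tabfact.py`, and `evaluate_one` is a hypothetical stand-in for the per-sample evaluation:

```python
from multiprocessing import Pool

def evaluate_one(sample):
    # Hypothetical stand-in for the per-sample chain-of-table evaluation.
    return {"id": sample["id"], "correct": sample["id"] % 2 == 0}

def run(samples, n_proc=1, chunk_size=1):
    # With a single process, evaluate sequentially; otherwise fan out to a
    # worker pool, handing each worker chunk_size samples at a time.
    if n_proc == 1:
        return [evaluate_one(s) for s in samples]
    with Pool(processes=n_proc) as pool:
        return list(pool.imap(evaluate_one, samples, chunksize=chunk_size))
```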
57 | ## Cite
58 |
59 | If you find this repository useful, please consider citing:
60 |
61 | ```bibtex
62 | @article{wang2024chain,
63 | title={Chain-of-Table: Evolving Tables in the Reasoning Chain for Table Understanding},
64 | author={Wang, Zilong and Zhang, Hao and Li, Chun-Liang and Eisenschlos, Julian Martin and Perot, Vincent and Wang, Zifeng and Miculicich, Lesly and Fujii, Yasuhisa and Shang, Jingbo and Lee, Chen-Yu and Pfister, Tomas},
65 | journal={ICLR},
66 | year={2024}
67 | }
68 | ```
69 |
70 | ## Acknowledgement
71 |
72 | We thank [Dater](https://arxiv.org/pdf/2301.13808.pdf) for providing the cleaned TabFact dataset and releasing the [code](https://github.com/AlibabaResearch/DAMO-ConvAI/tree/main/dater). We include the cleaned raw2clean file in the `data.zip` and the prompts for row/column selection in the `third_party/select_column_row_prompts/select_column_row_prompts.py` under the MIT License.
73 |
--------------------------------------------------------------------------------
/data.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/chain-of-table/1fd716c608f56e7d9d156ead432219c7ab9008af/data.zip
--------------------------------------------------------------------------------
/operations/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Chain-of-Table authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from .add_column import add_column_func, add_column_act
17 | from .group_by import group_column_func, group_column_act
18 | from .select_column import select_column_func, select_column_act
19 | from .select_row import select_row_func, select_row_act
20 | from .sort_by import sort_column_func, sort_column_act
21 |
22 | from .final_query import simple_query
--------------------------------------------------------------------------------
/operations/add_column.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Chain-of-Table authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import re
17 | import copy
18 | import numpy as np
19 | from utils.helper import table2string
20 |
21 |
22 | add_column_demo = """To tell the statement is true or false, we can first use f_add_column() to add more columns to the table.
23 |
24 | The added columns should have these data types:
25 | 1. Numerical: the numerical strings that can be used in sort, sum
26 | 2. Datetype: the strings that describe a date, such as year, month, day
27 | 3. String: other strings
28 |
29 | /*
30 | col : week | when | kickoff | opponent | results; final score | results; team record | game site | attendance
31 | row 1 : 1 | saturday, april 13 | 7:00 p.m. | at rhein fire | w 27–21 | 1–0 | rheinstadion | 32,092
32 | row 2 : 2 | saturday, april 20 | 7:00 p.m. | london monarchs | w 37–3 | 2–0 | waldstadion | 34,186
33 | row 3 : 3 | sunday, april 28 | 6:00 p.m. | at barcelona dragons | w 33–29 | 3–0 | estadi olímpic de montjuïc | 17,503
34 | */
35 | Statement: april 20 is the date of the competition with highest attendance.
36 | The existing columns are: "week", "when", "kickoff", "opponent", "results; final score", "results; team record", "game site", "attendance".
37 | Explanation: To tell this statement is true or false, we need to know the attendance number of each competition. We extract the value from column "attendance" and create a different column "attendance number" for each row. The datatype is numerical.
38 | Therefore, the answer is: f_add_column(attendance number). The value: 32092 | 34186 | 17503
39 |
40 | /*
41 | col : rank | lane | player | time
42 | row 1 : | 5 | olga tereshkova (kaz) | 51.86
43 | row 2 : | 6 | manjeet kaur (ind) | 52.17
44 | row 3 : | 3 | asami tanno (jpn) | 53.04
45 | */
46 | Statement: there are one athlete from japan.
47 | The existing columns are: rank, lane, player, time.
48 | Explanation: To tell this statement is true or false, we need to know the country of each athlete. We extract the value from column "player" and create a different column "country of athletes" for each row. The datatype is string.
49 | Therefore, the answer is: f_add_column(country of athletes). The value: kaz | ind | jpn
50 |
51 | /*
52 | col : year | competition | venue | position | notes
53 | row 1 : 1991 | european junior championships | thessaloniki, greece | 10th | 4.90 m
54 | row 2 : 1992 | world junior championships | seoul, south korea | 1st | 5.45 m
55 | row 3 : 1996 | european indoor championships | stockholm, sweden | 14th (q) | 5.45 m
56 | */
57 | Statement: laurens place 1st in 1991.
58 | The existing columns are: year, competition, venue, position, notes.
59 | Explanation: To tell this statement is true or false, we need to know the place of each competition. We extract the value from column "position" and create a different column "placing result" for each row. The datatype is numerical.
60 | Therefore, the answer is: f_add_column(placing result). The value: 10 | 1 | 14
61 |
62 | /*
63 | col : iso/iec standard | status | wg
64 | row 1 : iso/iec tr 19759 | published (2005) | 20
65 | row 2 : iso/iec 15288 | published (2008) | 7
66 | row 3 : iso/iec 12207 | published (2008) | 7
67 | */
68 | Statement: the standards published three times in 2008.
69 | The existing columns are: iso/iec standard, title, status, description, wg.
70 | Explanation: To tell this statement is true or false, we need to know the year of each standard. We extract the value from column "status" and create a different column "year of standard" for each row. The datatype is datetype.
71 | Therefore, the answer is: f_add_column(year of standard). The value: 2005 | 2008 | 2008
72 |
73 | /*
74 | col : match | date | ground | opponent | score1 | pos. | pts. | gd
75 | row 1 : 1 | 15 august | a | bayer uerdingen | 3 – 0 | 1 | 2 | 3
76 | row 2 : 2 | 22 july | h | 1. fc kaiserslautern | 1 – 0 | 1 | 4 | 4
77 | row 3 : 4 | 29 september | h | dynamo dresden | 3 – 1 | 1 | 6 | 6
78 | */
79 | Statement: they play 5 times in august.
80 | The existing columns are: match, date, ground, opponent, score1, pos., pts., gd.
81 | Explanation: To tell this statement is true or false, we need to know the month of each match. We extract the value from column "date" and create a different column "month" for each row. The datatype is datetype.
82 | Therefore, the answer is: f_add_column(month). The value: august | july | september
83 |
84 | /*
85 | table caption : 1984 u.s. open (golf)
86 | col : place | player | country | score | to par
87 | row 1 : 1 | hale irwin | united states | 68 + 68 = 136 | - 4
88 | row 2 : 2 | fuzzy zoeller | united states | 71 + 66 = 137 | -- 3
89 | row 3 : t3 | david canipe | united states | 69 + 69 = 138 | - 2
90 | */
91 | Statement: david canipe of united states has 138 score
92 | The existing columns are: place, player, country, score, to par.
93 | Explanation: To tell this statement is true or false, we need to know the score values of each player. We extract the value from column "score" and create a different column "score value" for each row. The datatype is numerical.
94 | Therefore, the answer is: f_add_column(score value). The value: 136 | 137 | 138
95 |
96 | /*
97 | col : code | county | former province | area (km2) | population; census 2009 | capital
98 | row 1 : 1 | mombasa | coast | 212.5 | 939,370 | mombasa (city)
99 | row 2 : 2 | kwale | coast | 8,270.3 | 649,931 | kwale
100 | row 3 : 3 | kilifi | coast | 12,245.9 | 1,109,735 | kilifi
101 | */
102 | Statement: kwale has a population in 2009 higher than 500,000.
103 | The existing columns are: code, county, former province, area (km2), population; census 2009, capital.
104 | Explanation: To tell this statement is true or false, we need to know the population of each county. We extract the value from column "population; census 2009" and create a different column "population" for each row. The datatype is numerical.
105 | Therefore, the answer is: f_add_column(population). The value: 939370 | 649931 | 1109735"""
106 |
107 |
108 | def add_column_build_prompt(table_text, statement, table_caption=None, num_rows=100):
109 | table_str = table2string(table_text, caption=table_caption, num_rows=num_rows)
110 | prompt = "/*\n" + table_str + "\n*/\n"
111 | prompt += "Statement: " + statement + "\n"
112 | prompt += "The existing columns are: "
113 | prompt += ", ".join(table_text[0]) + ".\n"
114 | prompt += "Explanation:"
115 | return prompt
116 |
117 |
118 | def add_column_func(
119 | sample, table_info, llm, llm_options=None, debug=False, skip_op=[], strategy="top"
120 | ):
121 | operation = {
122 | "operation_name": "add_column",
123 | "parameter_and_conf": [],
124 | }
125 | failure_sample_copy = copy.deepcopy(sample)
126 | failure_sample_copy["chain"].append(operation)
127 |
128 | # table_info = get_table_info(sample, skip_op=skip_op)
129 | table_text = table_info["table_text"]
130 |
131 | table_caption = sample["table_caption"]
132 | cleaned_statement = sample["cleaned_statement"]
133 | cleaned_statement = re.sub(r"\d+", "_", cleaned_statement)
134 |
135 | prompt = "" + add_column_demo.rstrip() + "\n\n"
136 | prompt += add_column_build_prompt(
137 | table_text, cleaned_statement, table_caption=table_caption, num_rows=3
138 | )
139 | if llm_options is None:
140 | llm_options = llm.get_model_options()
141 | llm_options["n"] = 1
142 | responses = llm.generate_plus_with_score(
143 | prompt,
144 | options=llm_options,
145 | )
146 |
147 | add_column_and_conf = {}
148 | for res, score in responses:
149 | try:
150 | f_add_func = re.findall(r"f_add_column\(.*\)", res, re.S)[0].strip()
151 | left = f_add_func.index("(") + 1
152 | right = f_add_func.index(")")
153 | add_column = f_add_func[left:right].strip()
154 | first_3_values = res.split("The value:")[-1].strip().split("|")
155 | first_3_values = [v.strip() for v in first_3_values]
156 | assert len(first_3_values) == 3
157 | except:
158 | continue
159 |
160 | add_column_key = str((add_column, first_3_values, res))
161 | if add_column_key not in add_column_and_conf:
162 | add_column_and_conf[add_column_key] = 0
163 | add_column_and_conf[add_column_key] += np.exp(score)
164 |
165 | if len(add_column_and_conf) == 0:
166 | return failure_sample_copy
167 |
168 | add_column_and_conf_list = sorted(
169 | add_column_and_conf.items(), key=lambda x: x[1], reverse=True
170 | )
171 | if strategy == "top":
172 | selected_add_column_key = add_column_and_conf_list[0][0]
173 | selected_add_column_conf = add_column_and_conf_list[0][1]
174 | else:
175 | raise NotImplementedError()
176 |
177 | add_column, first_3_values, llm_response = eval(selected_add_column_key)
178 |
179 | existing_columns = table_text[0]
180 | if add_column in existing_columns:
181 | return failure_sample_copy
182 |
183 | add_column_contents = [] + first_3_values
184 |
185 | # get following contents
186 | try:
187 | left_index = llm_response.index("We extract the value from")
188 | right_index = llm_response.index("The value:")
189 | explanaiton_beginning = llm_response[left_index:right_index] + "The value:"
190 | except:
191 | return failure_sample_copy
192 |
193 | def _sample_to_simple_prompt_header(table_text, num_rows=3):
194 | x = ""
195 | x += "/*\n"
196 | x += table2string(table_text, caption=table_caption, num_rows=num_rows) + "\n"
197 | x += "*/\n"
198 | x += "Explanation: "
199 | return x
200 |
201 | new_prompt = ""
202 | new_prompt += (
203 | _sample_to_simple_prompt_header(table_text, num_rows=3)
204 | + llm_response[left_index:]
205 | )
206 |
207 | headers = table_text[0]
208 | rows = table_text[1:]
209 | for i in range(3, len(rows)):
210 | partial_table_text = [headers] + rows[i : i + 1]
211 | cur_prompt = (
212 | new_prompt
213 | + "\n\n"
214 | + _sample_to_simple_prompt_header(partial_table_text)
215 | + explanaiton_beginning
216 | )
217 | cur_response = llm.generate(
218 | cur_prompt,
219 | options=llm.get_model_options(
220 | per_example_max_decode_steps=150, per_example_top_p=1.0
221 | ),
222 | ).strip()
223 | if debug:
224 | print(cur_prompt)
225 | print(cur_response)
226 | print("---")
227 | print()
228 |
229 | contents = cur_response
230 | if "|" in contents:
231 | contents = contents.split("|")[0].strip()
232 |
233 | add_column_contents.append(contents)
234 |
235 | if debug:
236 | print("New col contents: ", add_column_contents)
237 |
238 | add_column_info = [
239 | (str((add_column, add_column_contents)), selected_add_column_conf)
240 | ]
241 |
242 | operation = {
243 | "operation_name": "add_column",
244 | "parameter_and_conf": add_column_info,
245 | }
246 |
247 | sample_copy = copy.deepcopy(sample)
248 | sample_copy["chain"].append(operation)
249 |
250 | return sample_copy
251 |
252 |
253 | def add_column_act(table_info, operation, skip_op=[], debug=False):
254 | table_info = copy.deepcopy(table_info)
255 |
256 | failure_table_info = copy.deepcopy(table_info)
257 | failure_table_info["act_chain"].append("skip f_add_column()")
258 | if "add_column" in skip_op:
259 | return failure_table_info
260 | if len(operation["parameter_and_conf"]) == 0:
261 | return failure_table_info
262 |
263 | add_column_key, _ = operation["parameter_and_conf"][0]
264 | add_column, add_column_contents = eval(add_column_key)
265 |
266 | table_text = table_info["table_text"]
267 | headers = table_text[0]
268 | rows = table_text[1:]
269 |
270 | header2contents = {}
271 | for i, header in enumerate(headers):
272 | header2contents[header] = []
273 | for row in rows:
274 | header2contents[header].append(row[i])
275 |
276 | if add_column.startswith("number of"):
277 | # remove 'number of'
278 | if debug:
279 | print("remove number of")
280 | return failure_table_info
281 |
282 | if len(set(add_column_contents)) == 1:
283 | # all same
284 | if debug:
285 | print("all same")
286 | return failure_table_info
287 |
288 | for x in add_column_contents:
289 | if x.strip() == "":
290 | # empty cell
291 | if debug:
292 | print("empty cell")
293 | return failure_table_info
294 |
295 | if add_column in headers:
296 | # same column header
297 | if debug:
298 | print("same column header")
299 | return failure_table_info
300 |
301 | for header in header2contents:
302 | if add_column_contents == header2contents[header]:
303 | # different header, same content
304 | if debug:
305 | print("different header, same content")
306 | return failure_table_info
307 |
308 | exist_flag = False
309 |
310 | for header, contents in header2contents.items():
311 | current_column_exist_flag = True
312 |
313 | for i in range(len(contents)):
314 | if add_column_contents[i] not in contents[i]:
315 | current_column_exist_flag = False
316 | break
317 |
318 | if current_column_exist_flag:
319 | exist_flag = True
320 | break
321 | if not exist_flag:
322 | if debug:
323 | print(add_column, add_column_contents)
324 | print("not substring of a column")
325 | return failure_table_info
326 |
327 | if debug:
328 | print("default")
329 | new_headers = headers + [add_column]
330 | new_rows = []
331 | for i, row in enumerate(rows):
332 | row.append(add_column_contents[i])
333 | new_rows.append(row)
334 |
335 | new_table_text = [new_headers] + new_rows
336 | table_info["table_text"] = new_table_text
337 | table_info["act_chain"].append(f"f_add_column({add_column})")
338 | return table_info
339 |
--------------------------------------------------------------------------------
/operations/final_query.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Chain-of-Table authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import copy
17 | import numpy as np
18 | from utils.helper import table2string
19 |
20 |
21 | general_demo = """/*
22 | table caption : 2008 sidecarcross world championship.
23 | col : position | driver / passenger | equipment | bike no | points
24 | row 1 : 1 | daniël willemsen / reto grütter | ktm - ayr | 1 | 531
25 | row 2 : 2 | kristers sergis / kaspars stupelis | ktm - ayr | 3 | 434
26 | row 3 : 3 | jan hendrickx / tim smeuninx | zabel - vmc | 2 | 421
27 | row 4 : 4 | joris hendrickx / kaspars liepins | zabel - vmc | 8 | 394
28 | row 5 : 5 | marco happich / meinrad schelbert | zabel - mefo | 7 | 317
29 | */
30 | Statement: bike number 3 is the only one to use equipment ktm - ayr.
31 | The answer is: NO
32 |
33 | /*
34 | table caption : 1957 vfl season.
35 | col : home team | home team score | away team | away team score | venue | crowd | date
36 | row 1 : footscray | 6.6 (42) | north melbourne | 8.13 (61) | western oval | 13325 | 10 august 1957
37 | row 2 : essendon | 10.15 (75) | south melbourne | 7.13 (55) | windy hill | 16000 | 10 august 1957
38 | row 3 : st kilda | 1.5 (11) | melbourne | 6.13 (49) | junction oval | 17100 | 10 august 1957
39 | row 4 : hawthorn | 14.19 (103) | geelong | 8.7 (55) | brunswick street oval | 12000 | 10 august 1957
40 | row 5 : fitzroy | 8.14 (62) | collingwood | 8.13 (61) | glenferrie oval | 22000 | 10 august 1957
41 | */
42 | Statement: collingwood was the away team playing at the brunswick street oval venue.
43 | The answer is: NO
44 |
45 | /*
46 | table caption : co - operative commonwealth federation (ontario section).
47 | col : year of election | candidates elected | of seats available | of votes | % of popular vote
48 | row 1 : 1934 | 1 | 90 | na | 7.0%
49 | row 2 : 1937 | 0 | 90 | na | 5.6%
50 | row 3 : 1943 | 34 | 90 | na | 31.7%
51 | row 4 : 1945 | 8 | 90 | na | 22.4%
52 | row 5 : 1948 | 21 | 90 | na | 27.0%
53 | */
54 | Statement: the 1937 election had a % of popular vote that was 1.4% lower than that of the 1959 election.
55 | The answer is: NO
56 |
57 | /*
58 | table caption : 2003 pga championship.
59 | col : place | player | country | score | to par
60 | row 1 : 1 | shaun micheel | united states | 69 + 68 = 137 | - 3
61 | row 2 : t2 | billy andrade | united states | 67 + 72 = 139 | - 1
62 | row 3 : t2 | mike weir | canada | 68 + 71 = 139 | - 1
63 | row 4 : 4 | rod pampling | australia | 66 + 74 = 140 | e
64 | row 5 : t5 | chad campbell | united states | 69 + 72 = 141 | + 1
65 | */
66 | Statement: phil mickelson was one of five players with + 1 to par , all of which had placed t5.
67 | The answer is: YES"""
68 |
69 |
70 | def simple_query(sample, table_info, llm, debug=False, use_demo=False, llm_options=None):
71 | table_text = table_info["table_text"]
72 |
73 | caption = sample["table_caption"]
74 | statement = sample["statement"]
75 |
76 | prompt = ""
77 |     prompt += "Here is the statement about the table and the task is to tell whether the statement is True or False.\n"
78 |     prompt += "If the statement is true, answer YES; otherwise, answer NO.\n"
79 |
80 | if use_demo:
81 | prompt += "\n"
82 | prompt += general_demo + "\n\n"
83 |         prompt += "Here is the statement about the table and the task is to tell whether the statement is True or False.\n"
84 |         prompt += "If the statement is true, answer YES; otherwise, answer NO.\n"
85 | prompt += "\n"
86 |
87 | prompt += "/*\n"
88 | prompt += table2string(table_text, caption=caption) + "\n"
89 | prompt += "*/\n"
90 |
91 | if "group_sub_table" in table_info:
92 | group_column, group_info = table_info["group_sub_table"]
93 | prompt += "/*\n"
94 | prompt += "Group the rows according to column: {}.\n".format(group_column)
95 | group_headers = ["Group ID", group_column, "Count"]
96 | group_rows = []
97 | for i, (v, count) in enumerate(group_info):
98 | if v.strip() == "":
99 | v = "[Empty Cell]"
100 | group_rows.append([f"Group {i+1}", v, str(count)])
101 | prompt += " | ".join(group_headers) + "\n"
102 | for row in group_rows:
103 | prompt += " | ".join(row) + "\n"
104 | prompt += "*/\n"
105 |
106 | prompt += "Statement: " + statement + "\n"
107 |
108 | prompt += "The answer is:"
109 | responses = llm.generate_plus_with_score(prompt, options=llm_options)
110 | responses = [(res.strip(), np.exp(score)) for res, score in responses]
111 |
112 | if debug:
113 | print(prompt)
114 | print(responses)
115 |
116 | operation = {
117 | "operation_name": "simple_query",
118 | "parameter_and_conf": responses,
119 | }
120 | sample_copy = copy.deepcopy(sample)
121 | sample_copy["chain"].append(operation)
122 |
123 | return sample_copy
124 |
125 |
--------------------------------------------------------------------------------
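simple_query above records each sampled LLM response together with exp(logprob) in `parameter_and_conf`, but the reduction of those weighted votes to a single YES/NO verdict happens elsewhere in the pipeline. A minimal sketch of such an aggregation step, assuming majority-by-confidence (the name `aggregate_answers` is hypothetical, not part of the repo):

```python
from collections import defaultdict

def aggregate_answers(parameter_and_conf):
    # Sum the confidence mass behind each normalized answer string and
    # return the answer with the largest total (hypothetical helper; the
    # repo performs this reduction outside final_query.py).
    totals = defaultdict(float)
    for answer, conf in parameter_and_conf:
        # Normalize so "YES", "yes", and "Yes." vote together.
        totals[answer.strip().rstrip(".").upper()] += conf
    return max(totals.items(), key=lambda kv: kv[1])[0]
```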
/operations/group_by.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Chain-of-Table authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import re
17 | import numpy as np
18 | import copy
19 | from utils.helper import table2string
20 |
21 |
22 | group_column_demo = """To tell whether the statement is true or false, we can first use f_group() to group the values in a column.
23 |
24 | /*
25 | col : rank | lane | athlete | time | country
26 | row 1 : 1 | 6 | manjeet kaur (ind) | 52.17 | ind
27 | row 2 : 2 | 5 | olga tereshkova (kaz) | 51.86 | kaz
28 | row 3 : 3 | 4 | pinki pramanik (ind) | 53.06 | ind
29 | row 4 : 4 | 1 | tang xiaoyin (chn) | 53.66 | chn
30 | row 5 : 5 | 8 | marina maslyonko (kaz) | 53.99 | kaz
31 | */
32 | Statement: there are one athlete from japan.
33 | The existing columns are: rank, lane, athlete, time, country.
34 | Explanation: the statement says the number of athletes from japan is one. Each row is about an athlete. We can group column "country" to group the athletes from the same country.
35 | Therefore, the answer is: f_group(country).
36 |
37 | /*
38 | col : district | name | party | residence | first served
39 | row 1 : district 1 | nelson albano | dem | vineland | 2006
40 | row 2 : district 1 | robert andrzejczak | dem | middle twp. | 2013†
41 | row 3 : district 2 | john f. amodeo | rep | margate | 2008
42 | row 4 : district 2 | chris a. brown | rep | ventnor | 2012
43 | row 5 : district 3 | john j. burzichelli | dem | paulsboro | 2002
44 | */
45 | Statement: the number of districts that are democratic is 5.
46 | The existing columns are: district, name, party, residence, first served.
47 | Explanation: the statement says the number of districts that are democratic is 5. Each row is about a district. We can group the column "party" to group the districts from the same party.
48 | Therefore, the answer is: f_group(party)."""
49 |
50 |
51 | def group_column_build_prompt(table_text, statement, table_caption=None, num_rows=100):
52 | table_str = table2string(
53 | table_text, caption=table_caption, num_rows=num_rows
54 | ).strip()
55 | prompt = "/*\n" + table_str + "\n*/\n"
56 | prompt += "Statement: " + statement + "\n"
57 | prompt += "The existing columns are: "
58 | prompt += ", ".join(table_text[0]) + ".\n"
59 | prompt += "Explanation:"
60 | return prompt
61 |
62 |
63 | def group_column_func(
64 | sample, table_info, llm, llm_options=None, debug=False, skip_op=[]
65 | ):
66 | table_text = table_info["table_text"]
67 |
68 | table_caption = sample["table_caption"]
69 | statement = sample["statement"]
70 | prompt = "" + group_column_demo.rstrip() + "\n\n"
71 | prompt += group_column_build_prompt(
72 | table_text, statement, table_caption=table_caption, num_rows=5
73 | )
74 | responses = llm.generate_plus_with_score(
75 | prompt,
76 | options=llm_options,
77 | )
78 |
79 | if debug:
80 | print(prompt)
81 | print(responses)
82 |
83 | group_param_and_conf = {}
84 | group_column_and_conf = {}
85 |
86 | headers = table_text[0]
87 | rows = table_text[1:]
88 | for res, score in responses:
89 | re_result = re.findall(r"f_group\(([^\)]*)\)", res, re.S)
90 |
91 | if debug:
92 | print("Re result: ", re_result)
93 |
94 | try:
95 | group_column = re_result[0].strip()
96 | assert group_column in headers
97 |         except Exception:
98 | continue
99 |
100 | if group_column not in group_column_and_conf:
101 | group_column_and_conf[group_column] = 0
102 | group_column_and_conf[group_column] += np.exp(score)
103 |
104 | for group_column, conf in group_column_and_conf.items():
105 | group_column_contents = []
106 | index = headers.index(group_column)
107 | for row in rows:
108 | group_column_contents.append(row[index])
109 |
110 | def check_if_group(vs):
111 | vs_without_empty = [v for v in vs if v.strip()]
112 |             return bool(vs_without_empty) and len(set(vs_without_empty)) / len(vs_without_empty) <= 0.8
113 |
114 | if not check_if_group(group_column_contents):
115 | continue
116 |
117 | vs_to_group = []
118 | for i in range(len(group_column_contents)):
119 | vs_to_group.append((group_column_contents[i], i))
120 |
121 | group_info = []
122 | for v in sorted(set(group_column_contents)):
123 | group_info.append((v, group_column_contents.count(v)))
124 | group_info = sorted(group_info, key=lambda x: x[1], reverse=True)
125 |
126 | group_key = str((group_column, group_info))
127 | group_param_and_conf[group_key] = conf
128 |
129 | group_param_and_conf_list = sorted(
130 | group_param_and_conf.items(), key=lambda x: x[1], reverse=True
131 | )
132 |
133 | operation = {
134 | "operation_name": "group_column",
135 | "parameter_and_conf": group_param_and_conf_list,
136 | }
137 |
138 | sample_copy = copy.deepcopy(sample)
139 | sample_copy["chain"].append(operation)
140 |
141 | return sample_copy
142 |
143 |
144 | def group_column_act(table_info, operation, strategy="top", skip_op=[]):
145 | table_info = copy.deepcopy(table_info)
146 |
147 | failure_table_info = copy.deepcopy(table_info)
148 | failure_table_info["act_chain"].append("skip f_group_column()")
149 |
150 | if "group_column" in skip_op:
151 | return failure_table_info
152 | if len(operation["parameter_and_conf"]) == 0:
153 | return failure_table_info
154 | if strategy == "top":
155 | group_column, group_info = eval(operation["parameter_and_conf"][0][0])
156 | else:
157 | raise NotImplementedError()
158 |
159 | table_info["group_sub_table"] = (group_column, group_info)
160 | table_info["act_chain"].append(f"f_group_column({group_column})")
161 |
162 | return table_info
163 |
--------------------------------------------------------------------------------
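group_column_func derives its `group_info` by counting each distinct value of the chosen column and ordering the groups by size, largest first. The same computation isolated from the LLM plumbing (the name `build_group_info` is illustrative):

```python
def build_group_info(column_values):
    # Count each distinct cell value, then order groups by frequency,
    # largest first -- mirrors the group_info step in group_column_func
    # (the function name is illustrative, not defined in the repo).
    counts = [(v, column_values.count(v)) for v in sorted(set(column_values))]
    return sorted(counts, key=lambda x: x[1], reverse=True)
```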
/operations/select_column.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Chain-of-Table authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import json
17 | import copy
18 | import re
19 | import numpy as np
20 | from utils.helper import table2df, NoIndent, MyEncoder
21 |
22 | from third_party.select_column_row_prompts.select_column_row_prompts import select_column_demo
23 |
24 |
25 | def twoD_list_transpose(arr, keep_num_rows=3):
26 | arr = arr[: keep_num_rows + 1] if keep_num_rows + 1 <= len(arr) else arr
27 | return [[arr[i][j] for i in range(len(arr))] for j in range(len(arr[0]))]
28 |
29 |
30 | def select_column_build_prompt(table_text, statement, table_caption=None, num_rows=100):
31 | df = table2df(table_text, num_rows=num_rows)
32 | tmp = df.values.tolist()
33 | list_table = [list(df.columns)] + tmp
34 | list_table = twoD_list_transpose(list_table, len(list_table))
35 | if table_caption is not None:
36 | dic = {
37 | "table_caption": table_caption,
38 | "columns": NoIndent(list(df.columns)),
39 | "table_column_priority": [NoIndent(i) for i in list_table],
40 | }
41 | else:
42 | dic = {
43 | "columns": NoIndent(list(df.columns)),
44 | "table_column_priority": [NoIndent(i) for i in list_table],
45 | }
46 | linear_dic = json.dumps(
47 | dic, cls=MyEncoder, ensure_ascii=False, sort_keys=False, indent=2
48 | )
49 | prompt = "/*\ntable = " + linear_dic + "\n*/\n"
50 | prompt += "statement : " + statement + ".\n"
51 | prompt += "similar words link to columns :\n"
52 | return prompt
53 |
54 |
55 | def select_column_func(sample, table_info, llm, llm_options, debug=False, num_rows=100):
56 | # table_info = get_table_info(sample)
57 | table_text = table_info["table_text"]
58 |
59 | table_caption = sample["table_caption"]
60 | statement = sample["statement"]
61 |
62 | prompt = "" + select_column_demo.rstrip() + "\n\n"
63 | prompt += select_column_build_prompt(
64 | table_text, statement, table_caption, num_rows=num_rows
65 | )
66 |
67 | responses = llm.generate_plus_with_score(prompt, options=llm_options)
68 |
69 | if debug:
70 | print(prompt)
71 | print(responses)
72 |
73 | pattern_col = r"f_col\(\[(.*?)\]\)"
74 |
75 | pred_conf_dict = {}
76 | for res, score in responses:
77 | try:
78 | pred = re.findall(pattern_col, res, re.S)[0].strip()
79 | except Exception:
80 | continue
81 | pred = pred.split(", ")
82 | pred = [i.strip() for i in pred]
83 | pred = sorted(pred)
84 | pred = str(pred)
85 | if pred not in pred_conf_dict:
86 | pred_conf_dict[pred] = 0
87 | pred_conf_dict[pred] += np.exp(score)
88 |
89 | select_col_rank = sorted(pred_conf_dict.items(), key=lambda x: x[1], reverse=True)
90 |
91 | operation = {
92 | "operation_name": "select_column",
93 | "parameter_and_conf": select_col_rank,
94 | }
95 |
96 | sample_copy = copy.deepcopy(sample)
97 | sample_copy["chain"].append(operation)
98 |
99 | return sample_copy
100 |
101 |
102 | def select_column_act(table_info, operation, union_num=2, skip_op=[]):
103 | table_info = copy.deepcopy(table_info)
104 |
105 | failure_table_info = copy.deepcopy(table_info)
106 | failure_table_info["act_chain"].append("skip f_select_column()")
107 |
108 | if "select_column" in skip_op:
109 | return failure_table_info
110 |
111 | def union_lists(to_union):
112 | return list(set().union(*to_union))
113 |
114 | def twoD_list_transpose(arr):
115 | return [[arr[i][j] for i in range(len(arr))] for j in range(len(arr[0]))]
116 |
117 | selected_columns_info = operation["parameter_and_conf"]
118 | selected_columns_info = sorted(
119 | selected_columns_info, key=lambda x: x[1], reverse=True
120 | )
121 | selected_columns_info = selected_columns_info[:union_num]
122 | selected_columns = [x[0] for x in selected_columns_info]
123 | selected_columns = [eval(x) for x in selected_columns]
124 | selected_columns = union_lists(selected_columns)
125 |
126 | real_selected_columns = []
127 |
128 | table_text = table_info["table_text"]
129 | table = twoD_list_transpose(table_text)
130 | new_table = []
131 | for cols in table:
132 | if cols[0].lower() in selected_columns:
133 | real_selected_columns.append(cols[0])
134 | new_table.append(copy.deepcopy(cols))
135 | if len(new_table) == 0:
136 | new_table = table
137 | real_selected_columns = ["*"]
138 | new_table = twoD_list_transpose(new_table)
139 |
140 | table_info["table_text"] = new_table
141 | table_info["act_chain"].append(
142 | f"f_select_column({', '.join(real_selected_columns)})"
143 | )
144 |
145 | return table_info
146 |
--------------------------------------------------------------------------------
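select_column_act filters columns by transposing the table, keeping the columns whose lowercased header was selected, and transposing back, with a fall-through to the full table when nothing matches. The core of that logic as a standalone sketch (the name `keep_columns` is illustrative):

```python
def keep_columns(table_text, wanted):
    # Transpose to columns, keep those whose header (lowercased) is in
    # `wanted`, transpose back; fall back to the full table when nothing
    # matches -- same shape as select_column_act (illustrative name).
    cols = [list(c) for c in zip(*table_text)]
    kept = [c for c in cols if c[0].lower() in wanted]
    if not kept:
        kept = cols
    return [list(r) for r in zip(*kept)]
```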
/operations/select_row.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Chain-of-Table authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import copy
17 | import re
18 | import numpy as np
19 | from utils.helper import table2string
20 |
21 | from third_party.select_column_row_prompts.select_column_row_prompts import select_row_demo
22 |
23 |
24 | def select_row_build_prompt(table_text, statement, table_caption=None, num_rows=100):
25 |     table_str = table2string(table_text, caption=table_caption, num_rows=num_rows).strip()
26 | prompt = "/*\n" + table_str + "\n*/\n"
27 | question = statement
28 | prompt += "statement : " + question + "\n"
29 | prompt += "explain :"
30 | return prompt
31 |
32 |
33 | def select_row_func(sample, table_info, llm, llm_options=None, debug=False):
34 | table_text = table_info["table_text"]
35 |
36 | table_caption = sample["table_caption"]
37 | statement = sample["statement"]
38 |
39 | prompt = "" + select_row_demo.rstrip() + "\n\n"
40 | prompt += select_row_build_prompt(table_text, statement, table_caption)
41 |
42 | responses = llm.generate_plus_with_score(prompt, options=llm_options)
43 |
44 | if debug:
45 | print(responses)
46 |
47 | pattern_row = r"f_row\(\[(.*?)\]\)"
48 |
49 | pred_conf_dict = {}
50 | for res, score in responses:
51 | try:
52 | pred = re.findall(pattern_row, res, re.S)[0].strip()
53 | except Exception:
54 | continue
55 | pred = pred.split(", ")
56 | pred = [i.strip() for i in pred]
57 | pred = [i.split(" ")[-1] for i in pred]
58 | pred = sorted(pred)
59 | pred = str(pred)
60 | if pred not in pred_conf_dict:
61 | pred_conf_dict[pred] = 0
62 | pred_conf_dict[pred] += np.exp(score)
63 |
64 | select_row_rank = sorted(pred_conf_dict.items(), key=lambda x: x[1], reverse=True)
65 |
66 | operation = {
67 | "operation_name": "select_row",
68 | "parameter_and_conf": select_row_rank,
69 | }
70 |
71 | sample_copy = copy.deepcopy(sample)
72 | sample_copy["chain"].append(operation)
73 |
74 | return sample_copy
75 |
76 |
77 | def select_row_act(table_info, operation, union_num=2, skip_op=[]):
78 | table_info = copy.deepcopy(table_info)
79 |
80 | if "select_row" in skip_op:
81 | failure_table_info = copy.deepcopy(table_info)
82 | failure_table_info["act_chain"].append("skip f_select_row()")
83 | return failure_table_info
84 |
85 | def union_lists(to_union):
86 | return list(set().union(*to_union))
87 |
88 | selected_rows_info = operation["parameter_and_conf"]
89 | selected_rows_info = sorted(selected_rows_info, key=lambda x: x[1], reverse=True)
90 | selected_rows_info = selected_rows_info[:union_num]
91 | selected_rows = [x[0] for x in selected_rows_info]
92 | selected_rows = [eval(x) for x in selected_rows]
93 | selected_rows = union_lists(selected_rows)
94 |
95 | if "*" in selected_rows:
96 | failure_table_info = copy.deepcopy(table_info)
97 | failure_table_info["act_chain"].append("f_select_row(*)")
98 | return failure_table_info
99 |
100 | real_selected_rows = []
101 |
102 | table_text = table_info["table_text"]
103 | new_table = [copy.deepcopy(table_text[0])]
104 | for row_id, row in enumerate(table_text):
105 | row_id = str(row_id)
106 | if row_id in selected_rows:
107 | new_table.append(copy.deepcopy(row))
108 | real_selected_rows.append(row_id)
109 |
110 | if len(new_table) == 1:
111 | failure_table_info = copy.deepcopy(table_info)
112 | failure_table_info["act_chain"].append("f_select_row(*)")
113 | return failure_table_info
114 |
115 | table_info["table_text"] = new_table
116 | selected_row_names = [f"row {x+1}" for x in range(len(real_selected_rows))]
117 | table_info["act_chain"].append(f"f_select_row({', '.join(selected_row_names)})")
118 |
119 | _real_selected_row_names = [f"row {x-1}" for x in map(int, real_selected_rows)]
120 | table_info['_real_select_rows'] = f"f_select_row({', '.join(_real_selected_row_names)})"
121 |
122 | return table_info
123 |
--------------------------------------------------------------------------------
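select_row_act keeps the header row plus the data rows whose 1-based index the LLM selected. Stripped of the confidence bookkeeping and fallbacks, the filter reduces to (the name `keep_rows` is illustrative):

```python
def keep_rows(table_text, selected_ids):
    # Keep the header plus data rows whose 1-based index (as a string)
    # was selected -- the filtering core of select_row_act (the function
    # name is illustrative, not part of the repo).
    new_table = [table_text[0]]
    for row_id, row in enumerate(table_text):
        if row_id != 0 and str(row_id) in selected_ids:
            new_table.append(row)
    return new_table
```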
/operations/sort_by.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Chain-of-Table authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import copy
17 | import re
18 | import numpy as np
19 | from utils.helper import table2string
20 |
21 |
22 | sort_column_demo = """To tell whether the statement is true or false, we can first use f_sort() to sort the values in a column to get the order of the items. The order can be "large to small" or "small to large".
23 |
24 | The column to sort should have these data types:
25 | 1. Numerical: the numerical strings that can be used in sort
26 | 2. DateType: the strings that describe a date, such as year, month, day
27 | 3. String: other strings
28 |
29 | /*
30 | col : position | club | played | points | wins | draws | losses | goals for | goals against | goal difference
31 | row 1 : 1 | malaga cf | 42 | 79 | 22 | 13 | 7 | 72 | 47 | +25
32 | row 10 : 10 | cp merida | 42 | 59 | 15 | 14 | 13 | 48 | 41 | +7
33 | row 3 : 3 | cd numancia | 42 | 73 | 21 | 10 | 11 | 68 | 40 | +28
34 | */
35 | Statement: cd numancia placed in the last position
36 | The existing columns are: position, club, played, points, wins, draws, losses, goals for, goals against, goal difference.
37 | Explanation: the statement wants to check whether cd numancia is in the last position. Each row is about a club. We need to know the order of position from last to front. There is a column for position and the column name is position. The datatype is Numerical.
38 | Therefore, the answer is: f_sort(position), the order is "large to small".
39 |
40 | /*
41 | col : year | team | games | combined tackles | tackles | assisted tackles |
42 | row 1 : 2004 | hou | 16 | 63 | 51 | 12 |
43 | row 2 : 2005 | hou | 12 | 35 | 24 | 11 |
44 | row 3 : 2006 | hou | 15 | 26 | 19 | 7 |
45 | */
46 | Statement: in 2006 babin had the least amount of tackles
47 | The existing columns are: year, team, games, combined tackles, tackles, assisted tackles.
48 | Explanation: the statement wants to check whether babin had the least amount of tackles in 2006. Each row is about a year. We need to know the order of tackles from the least to the most. There is a column for tackles and the column name is tackles. The datatype is Numerical.
49 | Therefore, the answer is: f_sort(tackles), the order is "small to large"."""
50 |
51 |
52 | def only_keep_num_and_first_dot(s):
53 | if s.strip() and s.strip()[0] == "-":
54 | minus = True
55 | else:
56 | minus = False
57 | ns = ""
58 | dot = False
59 | for c in s:
60 | if c in "0123456789":
61 | ns += c
62 | if c == ".":
63 |             if not dot:
64 | ns += c
65 | dot = True
66 | if ns == ".":
67 | return ""
68 | if ns == "":
69 | return ""
70 | if minus:
71 | ns = "-" + ns
72 | return ns
73 |
74 |
75 | def sort_column_build_prompt(table_text, statement, table_caption=None, num_rows=100):
76 | table_str = table2string(
77 | table_text, caption=table_caption, num_rows=num_rows
78 | ).strip()
79 | prompt = "/*\n" + table_str + "\n*/\n"
80 | prompt += "Statement: " + statement + "\n"
81 | prompt += "The existing columns are: "
82 | prompt += ", ".join(table_text[0]) + ".\n"
83 | prompt += "Explanation:"
84 | return prompt
85 |
86 |
87 | def sort_column_func(
88 | sample, table_info, llm, llm_options=None, debug=False, skip_op=[]
89 | ):
90 | # table_info = get_table_info(sample, skip_op=skip_op)
91 | table_text = table_info["table_text"]
92 |
93 | statement = sample["statement"]
94 | prompt = "" + sort_column_demo.rstrip() + "\n\n"
95 | prompt += sort_column_build_prompt(table_text, statement, num_rows=3)
96 | responses = llm.generate_plus_with_score(
97 | prompt,
98 | options=llm_options,
99 | )
100 |
101 | if debug:
102 | print(prompt)
103 | print(responses)
104 |
105 | sort_info_and_conf = {}
106 |
107 | headers = table_text[0]
108 | rows = table_text[1:]
109 | for res, score in responses:
110 | try:
111 | datatype = re.findall(r"The datatype is (\w*).", res, re.S)[0].strip()
112 | sort_order = re.findall(r'the order is "(.*)"\.', res, re.S)[0].strip()
113 | sort_column = re.findall(r"f_sort\((.*)\)", res, re.S)[0].strip()
114 |         except Exception:
115 | continue
116 |
117 | if sort_order not in ["small to large", "large to small"]:
118 | continue
119 | if sort_column not in headers:
120 | continue
121 | sort_key = (sort_column, sort_order, datatype)
122 | if sort_key not in sort_info_and_conf:
123 | sort_info_and_conf[sort_key] = 0
124 | sort_info_and_conf[sort_key] += np.exp(score)
125 |
126 | sort_param_and_conf_list = []
127 | for (sort_column, sort_order, datatype), conf in sort_info_and_conf.items():
128 | sort_column_contents = []
129 | index = headers.index(sort_column)
130 | for row in rows:
131 | sort_column_contents.append(row[index])
132 |
133 | vs_to_sort = []
134 | vs_not_to_sort = []
135 | if datatype == "Numerical":
136 | for i in range(len(sort_column_contents)):
137 | v_str = sort_column_contents[i]
138 | v_str = only_keep_num_and_first_dot(v_str)
139 | if v_str == "" or v_str == ".":
140 | vs_not_to_sort.append((sort_column_contents[i], i))
141 | else:
142 | vs_to_sort.append((float(v_str), i))
143 | else:
144 | for i in range(len(sort_column_contents)):
145 | v_str = sort_column_contents[i]
146 | v_str = v_str.strip()
147 | if v_str == "":
148 | vs_not_to_sort.append((sort_column_contents[i], i))
149 | else:
150 | vs_to_sort.append((v_str, i))
151 |
152 | # check if already sorted
153 | pure_vs_to_sort = [x[0] for x in vs_to_sort]
154 | if (
155 | sorted(pure_vs_to_sort) == pure_vs_to_sort
156 | or sorted(pure_vs_to_sort, reverse=True) == pure_vs_to_sort
157 | ):
158 | continue
159 |
160 | # get sorted index
161 | if sort_order == "small to large":
162 | vs_to_sort = sorted(vs_to_sort, key=lambda x: x[0])
163 | else:
164 | vs_to_sort = sorted(vs_to_sort, reverse=True, key=lambda x: x[0])
165 | index_order = [x[1] for x in vs_to_sort] + [x[1] for x in vs_not_to_sort]
166 |
167 | sort_param_and_conf_list.append(
168 | (
169 | sort_column,
170 | sort_order,
171 | datatype,
172 | index_order,
173 | max([x[0] for x in vs_to_sort]),
174 | min([x[0] for x in vs_to_sort]),
175 | conf,
176 | )
177 | )
178 |
179 |     sort_param_and_conf_list = sorted(sort_param_and_conf_list, key=lambda x: x[-1], reverse=True)
180 |
181 | operation = {
182 | "operation_name": "sort_column",
183 | "parameter_and_conf": sort_param_and_conf_list,
184 | }
185 |
186 | sample_copy = copy.deepcopy(sample)
187 | sample_copy["chain"].append(operation)
188 |
189 | if debug:
190 | print(sort_param_and_conf_list)
191 |
192 | return sample_copy
193 |
194 |
195 | def sort_column_act(
196 | table_info, operation, strategy="top", filter="Only Numerical", skip_op=[]
197 | ):
198 | table_info = copy.deepcopy(table_info)
199 |
200 | failure_table_info = copy.deepcopy(table_info)
201 | failure_table_info["act_chain"].append("skip f_sort_column()")
202 |
203 | if "sort_column" in skip_op:
204 | return failure_table_info
205 | if len(operation["parameter_and_conf"]) == 0:
206 | return failure_table_info
207 |
208 | if strategy == "top":
209 | sort_column, sort_order, datatype, index_order, max_v, min_v = operation[
210 | "parameter_and_conf"
211 | ][0][:-1]
212 | else:
213 | raise NotImplementedError()
214 |
215 | if filter == "Only Numerical":
216 | if datatype != "Numerical":
217 | return failure_table_info
218 | else:
219 | raise NotImplementedError()
220 |
221 | table_text = table_info["table_text"]
222 | headers = table_text[0]
223 | rows = table_text[1:]
224 | new_rows = [rows[i] for i in index_order]
225 | new_table_text = [headers] + new_rows
226 |
227 | table_info["table_text"] = new_table_text
228 | table_info["sort_sub_table"] = (sort_column, max_v, min_v)
229 | table_info["act_chain"].append(f"f_sort_column({sort_column})")
230 |
231 | return table_info
232 |
--------------------------------------------------------------------------------
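For Numerical columns, sort_column_func parses each cell into a float, orders the parseable rows by value, and appends the unparseable rows at the end in their original order to form `index_order`. A simplified sketch of that ordering (the name `numeric_index_order` is illustrative, and plain `float()` stands in for `only_keep_num_and_first_dot`):

```python
def numeric_index_order(values, reverse=False):
    # Rows whose cell parses as a number are ordered by that number;
    # unparseable rows keep their relative position at the end -- the
    # index_order construction of sort_column_func in miniature
    # (illustrative name, not defined in the repo).
    to_sort, tail = [], []
    for i, v in enumerate(values):
        try:
            to_sort.append((float(v), i))
        except ValueError:
            tail.append(i)
    to_sort.sort(key=lambda x: x[0], reverse=reverse)
    return [i for _, i in to_sort] + tail
```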
/requirements.txt:
--------------------------------------------------------------------------------
1 | fire
2 | numpy
3 | pandas
4 | tqdm
5 | openai==0.28.1
--------------------------------------------------------------------------------
/run_demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "Copyright 2024 The Chain-of-Table authors\n",
8 | "\n",
9 | "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
10 | "you may not use this file except in compliance with the License.\n",
11 | "You may obtain a copy of the License at\n",
12 | "\n",
13 | " https://www.apache.org/licenses/LICENSE-2.0\n",
14 | "\n",
15 | "Unless required by applicable law or agreed to in writing, software\n",
16 | "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
17 | "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
18 | "See the License for the specific language governing permissions and\n",
19 | "limitations under the License."
20 | ]
21 | },
22 | {
23 | "cell_type": "markdown",
24 | "metadata": {},
25 | "source": [
26 |     "# Demo of Chain-of-Table\n",
27 | "\n",
28 | "Paper: https://arxiv.org/abs/2401.04398"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": 1,
34 | "metadata": {},
35 | "outputs": [],
36 | "source": [
37 | "import pandas as pd\n",
38 | "\n",
39 | "from utils.load_data import wrap_input_for_demo\n",
40 | "from utils.llm import ChatGPT\n",
41 | "from utils.helper import *\n",
42 | "from utils.evaluate import *\n",
43 | "from utils.chain import *\n",
44 | "from operations import *"
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": 2,
50 | "metadata": {},
51 | "outputs": [],
52 | "source": [
53 | "# User parameters\n",
54 | "model_name: str = \"gpt-3.5-turbo-0613\"\n",
55 | "openai_api_key: str = None\n",
56 | "\n",
57 | "# Default parameters\n",
58 | "dataset_path: str = \"data/tabfact/test.jsonl\"\n",
59 | "raw2clean_path: str = \"data/tabfact/raw2clean.jsonl\""
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": 3,
65 | "metadata": {},
66 | "outputs": [],
67 | "source": [
68 | "gpt_llm = ChatGPT(\n",
69 | " model_name=model_name,\n",
70 | " key=openai_api_key,\n",
71 | ")"
72 | ]
73 | },
74 | {
75 | "cell_type": "code",
76 | "execution_count": 4,
77 | "metadata": {},
78 | "outputs": [],
79 | "source": [
80 | "statement = \"the wildcats kept the opposing team scoreless in four games\"\n",
81 | "table_caption = \"1947 kentucky wildcats football team\"\n",
82 | "table_text = [\n",
83 | " ['game', 'date', 'opponent', 'result', 'wildcats points', 'opponents', 'record'],\n",
84 | " ['1', 'sept 20', 'ole miss', 'loss', '7', '14', '0 - 1'],\n",
85 | " ['2', 'sept 27', 'cincinnati', 'win', '20', '0', '1 - 1'],\n",
86 | " ['3', 'oct 4', 'xavier', 'win', '20', '7', '2 - 1'],\n",
87 | " ['4', 'oct 11', '9 georgia', 'win', '26', '0', '3 - 1 , 20'],\n",
88 | " ['5', 'oct 18', '10 vanderbilt', 'win', '14', '0', '4 - 1 , 14'],\n",
89 | " ['6', 'oct 25', 'michigan state', 'win', '7', '6', '5 - 1 , 13'],\n",
90 | " ['7', 'nov 1', '18 alabama', 'loss', '0', '13', '5 - 2'],\n",
91 | " ['8', 'nov 8', 'west virginia', 'win', '15', '6', '6 - 2'],\n",
92 | " ['9', 'nov 15', 'evansville', 'win', '36', '0', '7 - 2'],\n",
93 | " ['10', 'nov 22', 'tennessee', 'loss', '6', '13', '7 - 3']\n",
94 | "]\n",
95 | "answer = \"True\""
96 | ]
97 | },
98 | {
99 | "cell_type": "code",
100 | "execution_count": 5,
101 | "metadata": {},
102 | "outputs": [],
103 | "source": [
104 | "demo_sample = wrap_input_for_demo(\n",
105 | " statement=statement, table_caption=table_caption, table_text=table_text\n",
106 | ")\n",
107 | "proc_sample, dynamic_chain_log = dynamic_chain_exec_one_sample(\n",
108 | " sample=demo_sample, llm=gpt_llm\n",
109 | ")\n",
110 | "output_sample = simple_query(\n",
111 | " sample=proc_sample,\n",
112 | " table_info=get_table_info(proc_sample),\n",
113 | " llm=gpt_llm,\n",
114 | " use_demo=True,\n",
115 | " llm_options=gpt_llm.get_model_options(\n",
116 | " temperature=0.0, per_example_max_decode_steps=200, per_example_top_p=1.0\n",
117 | " ),\n",
118 | ")\n",
119 | "cotable_log = get_table_log(output_sample)"
120 | ]
121 | },
122 | {
123 | "cell_type": "code",
124 | "execution_count": 9,
125 | "metadata": {},
126 | "outputs": [
127 | {
128 | "name": "stdout",
129 | "output_type": "stream",
130 | "text": [
131 | "Statements: the wildcats kept the opposing team scoreless in four games\n",
132 | "\n",
133 | "Table: 1947 kentucky wildcats football team\n",
134 | " game result wildcats points opponents\n",
135 | "0 2 win 20 0\n",
136 | "1 4 win 26 0\n",
137 | "2 5 win 14 0\n",
138 | "3 9 win 36 0\n",
139 | "4 3 win 20 7\n",
140 | "\n",
141 | "-> f_select_row(row 1, row 2, row 3, row 4, row 8)\n",
142 | " game date opponent result wildcats points opponents record\n",
143 | "0 2 sept 27 cincinnati win 20 0 1 - 1\n",
144 | "1 3 oct 4 xavier win 20 7 2 - 1\n",
145 | "2 4 oct 11 9 georgia win 26 0 3 - 1 , 20\n",
146 | "3 5 oct 18 10 vanderbilt win 14 0 4 - 1 , 14\n",
147 | "4 9 nov 15 evansville win 36 0 7 - 2\n",
148 | "\n",
149 | "-> f_select_column(game, result, wildcats points, opponents)\n",
150 | " game result wildcats points opponents\n",
151 | "0 2 win 20 0\n",
152 | "1 3 win 20 7\n",
153 | "2 4 win 26 0\n",
154 | "3 5 win 14 0\n",
155 | "4 9 win 36 0\n",
156 | "\n",
157 | "-> f_group_column(opponents)\n",
158 | " game result wildcats points opponents\n",
159 | "0 2 win 20 0\n",
160 | "1 3 win 20 7\n",
161 | "2 4 win 26 0\n",
162 | "3 5 win 14 0\n",
163 | "4 9 win 36 0\n",
164 | " Group ID opponents Count\n",
165 | "0 Group 1 0 4\n",
166 | "1 Group 2 7 1\n",
167 | "\n",
168 | "-> f_sort_column(opponents)\n",
169 | " game result wildcats points opponents\n",
170 | "0 2 win 20 0\n",
171 | "1 4 win 26 0\n",
172 | "2 5 win 14 0\n",
173 | "3 9 win 36 0\n",
174 | "4 3 win 20 7\n",
175 | " Group ID opponents Count\n",
176 | "0 Group 1 0 4\n",
177 | "1 Group 2 7 1\n",
178 | "\n",
179 | "-> simple_query()\n",
180 | "The statement is True\n",
181 | "\n",
182 | "Groundtruth: The statement is True\n"
183 | ]
184 | }
185 | ],
186 | "source": [
187 | "print(f'Statements: {output_sample[\"statement\"]}\\n')\n",
188 | "print(f'Table: {output_sample[\"table_caption\"]}')\n",
189 | "print(f\"{pd.DataFrame(table_text[1:], columns=table_text[0])}\\n\")\n",
190 | "for table_info in cotable_log:\n",
191 | " if table_info[\"act_chain\"]:\n",
192 | " table_text = table_info[\"table_text\"]\n",
193 | " table_action = table_info[\"act_chain\"][-1]\n",
194 | " if \"skip\" in table_action:\n",
195 | " continue\n",
196 | " if \"query\" in table_action:\n",
197 | " result = table_info[\"cotable_result\"]\n",
198 | " if result == \"YES\":\n",
199 | " print(f\"-> {table_action}\\nThe statement is True\\n\")\n",
200 | " else:\n",
201 | " print(f\"-> {table_action}\\nThe statement is False\\n\")\n",
202 | " else:\n",
203 | " print(f\"-> {table_action}\\n{pd.DataFrame(table_text[1:], columns=table_text[0])}\")\n",
204 | " if 'group_sub_table' in table_info:\n",
205 | " group_column, group_info = table_info[\"group_sub_table\"]\n",
206 | " group_headers = [\"Group ID\", group_column, \"Count\"]\n",
207 | " group_rows = []\n",
208 | " for i, (v, count) in enumerate(group_info):\n",
209 | " if v.strip() == \"\":\n",
210 | " v = \"[Empty Cell]\"\n",
211 | " group_rows.append([f\"Group {i+1}\", v, str(count)])\n",
212 | " print(f\"{pd.DataFrame(group_rows, columns=group_headers)}\")\n",
213 | " print()\n",
214 | "\n",
215 | "print(f\"Groundtruth: The statement is {answer}\")"
216 | ]
217 | },
218 | {
219 | "cell_type": "code",
220 | "execution_count": null,
221 | "metadata": {},
222 | "outputs": [],
223 | "source": []
224 | }
225 | ],
226 | "metadata": {
227 | "kernelspec": {
228 | "display_name": "cotable",
229 | "language": "python",
230 | "name": "python3"
231 | },
232 | "language_info": {
233 | "codemirror_mode": {
234 | "name": "ipython",
235 | "version": 3
236 | },
237 | "file_extension": ".py",
238 | "mimetype": "text/x-python",
239 | "name": "python",
240 | "nbconvert_exporter": "python",
241 | "pygments_lexer": "ipython3",
242 | "version": "3.10.13"
243 | }
244 | },
245 | "nbformat": 4,
246 | "nbformat_minor": 2
247 | }
248 |
--------------------------------------------------------------------------------
/run_tabfact.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Chain-of-Table authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import fire
17 | import os
18 | import pickle
19 | from utils.load_data import load_tabfact_dataset
20 | from utils.llm import ChatGPT
21 | from utils.helper import *
22 | from utils.evaluate import *
23 | from utils.chain import *
24 | from operations import *
25 |
26 |
27 | def main(
28 | dataset_path: str = "data/tabfact/test.jsonl",
29 |     raw2clean_path: str = "data/tabfact/raw2clean.jsonl",
30 | model_name: str = "gpt-3.5-turbo-16k-0613",
31 | result_dir: str = "results/tabfact",
32 | openai_api_key: str = None,
33 | first_n=-1,
34 | n_proc=1,
35 | chunk_size=1,
36 | ):
37 | dataset = load_tabfact_dataset(dataset_path, raw2clean_path, first_n=first_n)
38 | gpt_llm = ChatGPT(
39 | model_name=model_name,
40 | key=os.environ["OPENAI_API_KEY"] if openai_api_key is None else openai_api_key,
41 | )
42 | os.makedirs(result_dir, exist_ok=True)
43 |
44 | proc_samples, dynamic_chain_log_list = dynamic_chain_exec_with_cache_mp(
45 | dataset,
46 | llm=gpt_llm,
47 | llm_options=gpt_llm.get_model_options(
48 | temperature=0.0, per_example_max_decode_steps=200, per_example_top_p=1.0
49 | ),
50 | strategy="top",
51 | cache_dir=os.path.join(result_dir, "cache"),
52 | n_proc=n_proc,
53 | chunk_size=chunk_size,
54 | )
55 | fixed_chain = [
56 | (
57 | "simpleQuery_fewshot",
58 | simple_query,
59 | dict(use_demo=True),
60 | dict(
61 | temperature=0, per_example_max_decode_steps=200, per_example_top_p=1.0
62 | ),
63 | ),
64 | ]
65 | final_result, _ = fixed_chain_exec_mp(gpt_llm, proc_samples, fixed_chain)
66 | acc = tabfact_match_func_for_samples(final_result)
67 | print("Accuracy:", acc)
68 |
69 |     with open(os.path.join(result_dir, "result.txt"), "w") as f:
70 |         print(f"Accuracy: {acc}", file=f)
71 |
72 |     with open(os.path.join(result_dir, "final_result.pkl"), "wb") as f:
73 |         pickle.dump(final_result, f)
74 |
75 |     with open(
76 |         os.path.join(result_dir, "dynamic_chain_log_list.pkl"), "wb"
77 |     ) as f:
78 |         pickle.dump(dynamic_chain_log_list, f)
80 |
81 |
82 | if __name__ == "__main__":
83 | fire.Fire(main)
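84 | # Example invocation (a sketch; the flags mirror main()'s parameters and the
85 | # paths assume the default data layout described in the repository):
86 | #   OPENAI_API_KEY=... python run_tabfact.py --n_proc 4 --first_n 100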
84 |
--------------------------------------------------------------------------------
/third_party/select_column_row_prompts/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Alibaba Research
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
--------------------------------------------------------------------------------
/third_party/select_column_row_prompts/select_column_row_prompts.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2022 Alibaba Research
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 |
15 |
16 | select_column_demo = """Use f_col() api to filter out useless columns in the table according to information in the statement and the table.
17 |
18 | /*
19 | {
20 | "table_caption": "south wales derby",
21 | "columns": ["competition", "total matches", "cardiff win", "draw", "swansea win"],
22 | "table_column_priority": [
23 | ["competition", "league", "fa cup", "league cup"],
24 | ["total matches", "55", "2", "5"],
25 | ["cardiff win", "19", "0", "2"],
26 | ["draw", "16", "27", "0"],
27 | ["swansea win", "20", "2", "3"]
28 | ]
29 | }
30 | */
31 | statement : there are no cardiff wins that have a draw greater than 27.
32 | similar words link to columns :
33 | no cardiff wins -> cardiff win
34 | a draw -> draw
35 | column value link to columns :
36 | 27 -> draw
37 | semantic sentence link to columns :
38 | None
39 | The answer is : f_col([cardiff win, draw])
40 |
41 | /*
42 | {
43 | "table_caption": "gambrinus liga",
44 | "columns": ["season", "champions", "runner - up", "third place", "top goalscorer", "club"],
45 | "table_column_priority": [
46 | ["season", "1993 - 94", "1994 - 95", "1995 - 96"],
47 | ["champions", "sparta prague (1)", "sparta prague (2)", "slavia prague (1)"],
48 | ["runner - up", "slavia prague", "slavia prague", "sigma olomouc"],
49 | ["third place", "ban\u00edk ostrava", "fc brno", "baumit jablonec"],
50 | ["top goalscorer", "horst siegl (20)", "radek drulák (15)", "radek drulák (22)"],
51 | ["club", "sparta prague", "drnovice", "drnovice"]
52 | ]
53 | }
54 | */
55 | statement : the top goal scorer for the season 2010 - 2011 was david lafata.
56 | similar words link to columns :
57 | season 2010 - 2011 -> season
58 | the top goal scorer -> top goalscorer
59 | column value link to columns :
60 | 2010 - 2011 -> season
61 | semantic sentence link to columns :
62 | the top goal scorer for ... was david lafata -> top goalscorer
63 | The answer is : f_col([season, top goalscorer])
64 |
65 | /*
66 | {
67 | "table_caption": "head of the river (queensland)",
68 | "columns": ["crew", "open 1st viii", "senior 2nd viii", "senior 3rd viii", "senior iv", "year 12 single scull", "year 11 single scull"],
69 | "table_column_priority": [
70 | ["crew", "2009", "2010", "2011"],
71 | ["open 1st viii", "stm", "splc", "stm"],
72 | ["senior 2nd viii", "sta", "som", "stu"],
73 | ["senior 3rd viii", "sta", "som", "stu"],
74 | ["senior iv", "som", "sth", "sta"],
75 | ["year 12 single scull", "stm", "splc", "stm"],
76 | ["year 11 single scull", "splc", "splc", "splc"]
77 | ]
78 | }
79 | */
80 | statement : the crew that had a senior 2nd viii of som and senior iv of stm was that of 2013.
81 | similar words link to columns :
82 | the crew -> crew
83 | a senior 2nd viii of som -> senior 2nd viii
84 | senior iv of stm -> senior iv
85 | column value link to columns :
86 | som -> senior 2nd viii
87 | stm -> senior iv
88 | semantic sentence link to columns :
89 | None
90 | The answer is : f_col([crew, senior 2nd viii, senior iv])
91 |
92 | /*
93 | {
94 | "table_caption": "2007 - 08 boston celtics season",
95 | "columns": ["game", "date", "team", "score", "high points", "high rebounds", "high assists", "location attendance", "record"],
96 | "table_column_priority": [
97 | ["game", "74", "75", "76"],
98 | ["date", "april 1", "april 2", "april 5"],
99 | ["team", "chicago", "indiana", "charlotte"],
100 | ["score", "106 - 92", "92 - 77", "101 - 78"],
101 | ["high points", "allen (22)", "garnett (20)", "powe (22)"],
102 | ["high rebounds", "perkins (9)", "garnett (11)", "powe (9)"],
103 | ["high assists", "rondo (10)", "rondo (6)", "rondo (5)"],
104 | ["location attendance", "united center 22225", "td banknorth garden 18624", "charlotte bobcats arena 19403"],
105 | ["record", "59 - 15", "60 - 15", "61 - 15"]
106 | ]
107 | }
108 | */
109 | statement : in game 74 against chicago , perkins had the most rebounds (9) and allen had the most points (22).
110 | similar words link to columns :
111 | the most rebounds -> high rebounds
112 | the most points -> high points
113 | in game 74 -> game
114 | column value link to columns :
115 | 74 -> game
116 | semantic sentence link to columns :
117 | 2007 - 08 boston celtics season in game 74 against chicago -> team
118 | perkins had the most rebounds -> high rebounds
119 | allen had the most points -> high points
120 | The answer is : f_col([game, team, high points, high rebounds])
121 |
122 | /*
123 | {
124 | "table_caption": "dan hardy",
125 | "columns": ["res", "record", "opponent", "method", "event", "round", "time", "location"],
126 | "table_column_priority": [
127 | ["res", "win", "win", "loss"],
128 | ["record", "25 - 10 (1)", "24 - 10 (1)", "23 - 10 (1)"],
129 | ["opponent", "amir sadollah", "duane ludwig", "chris lytle"],
130 | ["method", "decision (unanimous)", "ko (punch and elbows)", "submission (guillotine choke)"],
131 | ["event", "ufc on fuel tv : struve vs miocic", "ufc 146", "ufc live : hardy vs lytle"],
132 | ["round", "3", "1", "5"],
133 | ["time", "5:00", "3:51", "4:16"],
134 | ["location", "nottingham , england", "las vegas , nevada , united states", "milwaukee , wisconsin , united states"]
135 | ]
136 | }
137 | */
138 | statement : the record of the match was a 10 - 3 (1) score , resulting in a win in round 5 with a time of 5:00 minutes.
139 | similar words link to columns :
140 | the record of the match was a 10 - 3 (1) score -> record
141 | the record -> record
142 | in round -> round
143 | a time -> time
144 | column value link to columns :
145 | 10 - 3 (1) -> record
146 | 5 -> round
147 | 5:00 minutes -> time
148 | semantic sentence link to columns :
149 | resulting in a win -> res
150 | The answer is : f_col([res, record, round, time])
151 |
152 | /*
153 | {
154 | "table_caption": "list of largest airlines in central america & the caribbean",
155 | "columns": ["rank", "airline", "country", "fleet size", "remarks"],
156 | "table_column_priority": [
157 | ["rank", "1", "2", "3"],
158 | ["airline", "caribbean airlines", "liat", "cubana de aviaci\u00e3 cubicn"],
159 | ["country", "trinidad and tobago", "antigua and barbuda", "cuba"],
160 | ["fleet size", "22", "17", "14"],
161 | ["remarks", "largest airline in the caribbean", "second largest airline in the caribbean", "operational since 1929"]
162 | ]
163 | }
164 | */
165 | statement : the remark on airline of dutch antilles express with fleet size over 4 is curacao second national carrier.
166 | similar words link to columns :
167 | the remark -> remarks
168 | on airline -> airline
169 | fleet size -> fleet size
170 | column value link to columns :
171 | dutch antilles -> country
172 | 4 -> fleet size
173 | curacao second national carrier -> remarks
174 | semantic sentence link to columns :
175 | None
176 | The answer is : f_col([airline, fleet size, remarks])
177 |
178 | /*
179 | {
180 | "table_caption": "cnbc prime 's the profit 200",
181 | "columns": ["year", "date", "driver", "team", "manufacturer", "laps", "-", "race time", "average speed (mph)"],
182 | "table_column_priority": [
183 | ["year", "1990", "1990", "1991"],
184 | ["date", "july 15", "october 14", "july 14"],
185 | ["driver", "tommy ellis", "rick mast", "kenny wallace"],
186 | ["team", "john jackson", "ag dillard motorsports", "rusty wallace racing"],
187 | ["manufacturer", "buick", "buick", "pontiac"],
188 | ["laps", "300", "250", "300"],
189 | ["-", "317.4 (510.805)", "264.5 (425.671)", "317.4 (510.805)"],
190 | ["race time", "3:41:58", "2:44:37", "2:54:38"],
191 | ["average speed (mph)", "85.797", "94.405", "109.093"]
192 | ]
193 | }
194 | */
195 | statement : on june 26th , 2010 kyle busch drove a total of 211.6 miles at an average speed of 110.673 miles per hour.
196 | similar words link to columns :
197 | drove -> driver
198 | column value link to columns :
199 | june 26th , 2010 -> date, year
200 | a total of 211.6 miles -> -
201 | semantic sentence link to columns :
202 | kyle busch drove -> driver
203 | an average speed of 110.673 miles per hour -> average speed (mph)
204 | The answer is : f_col([year, date, driver, -, average speed (mph)])
205 |
206 | /*
207 | {
208 | "table_caption": "2000 ansett australia cup",
209 | "columns": ["home team", "home team score", "away team", "away team score", "ground", "crowd", "date"],
210 | "table_column_priority": [
211 | ["home team", "brisbane lions", "kangaroos", "richmond"],
212 | ["home team score", "13.6 (84)", "10.16 (76)", "11.16 (82)"],
213 | ["away team", "sydney", "richmond", "brisbane lions"],
214 | ["away team score", "17.10 (112)", "9.11 (65)", "15.9 (99)"],
215 | ["ground", "bundaberg rum stadium", "waverley park", "north hobart oval"],
216 | ["crowd", "8818", "16512", "4908"],
217 | ["date", "friday , 28 january", "friday , 28 january", "saturday , 5 february"]
218 | ]
219 | }
220 | */
221 | statement : sydney scored the same amount of points in the first game of the 2000 afl ansett australia cup as their opponent did in their second.
222 | similar words link to columns :
223 | scored -> away team score, home team score
224 | column value link to columns :
225 | sydney -> away team, home team
226 | semantic sentence link to columns :
227 | their opponent -> home team, away team
228 | scored the same amount of points -> away team score, home team score
229 | first game -> date
230 | their second -> date
231 | sydney scored -> home team, away team, home team score, away team score
232 | The answer is : f_col([away team, home team, away team score, home team score, date])"""
233 |
234 |
235 | select_row_demo = """Use f_row() api to select relevant rows in the given table that support or oppose the statement.
236 | Please use f_row([*]) to select all rows in the table.
237 |
238 | /*
239 | table caption : 1972 vfl season.
240 | col : home team | home team score | away team | away team score | venue | crowd | date
241 | row 1 : st kilda | 13.12 (90) | melbourne | 13.11 (89) | moorabbin oval | 18836 | 19 august 1972
242 | row 2 : south melbourne | 9.12 (66) | footscray | 11.13 (79) | lake oval | 9154 | 19 august 1972
243 | row 3 : richmond | 20.17 (137) | fitzroy | 13.22 (100) | mcg | 27651 | 19 august 1972
244 | row 4 : geelong | 17.10 (112) | collingwood | 17.9 (111) | kardinia park | 23108 | 19 august 1972
245 | row 5 : north melbourne | 8.12 (60) | carlton | 23.11 (149) | arden street oval | 11271 | 19 august 1972
246 | row 6 : hawthorn | 15.16 (106) | essendon | 12.15 (87) | vfl park | 36749 | 19 august 1972
247 | */
248 | statement : the away team with the highest score is fitzroy.
249 | explain : the statement wants to check the highest away team score. we need to compare the score of away team fitzroy with all others, so we need all rows. use * to represent all rows in the table.
250 | The answer is : f_row([*])
251 |
252 | /*
253 | table caption : list of largest airlines in central america & the caribbean.
254 | col : rank | airline | country | fleet size | remarks
255 | row 1 : 1 | caribbean airlines | trinidad and tobago | 22 | largest airline in the caribbean
256 | row 2 : 2 | liat | antigua and barbuda | 17 | second largest airline in the caribbean
257 | row 3 : 3 | cubana de aviaciã cubicn | cuba | 14 | operational since 1929
258 | row 4 : 4 | inselair | curacao | 12 | operational since 2006
259 | row 5 : 5 | dutch antilles express | curacao | 4 | curacao second national carrier
260 | row 6 : 6 | air jamaica | trinidad and tobago | 5 | parent company is caribbean airlines
261 | row 7 : 7 | tiara air | aruba | 3 | aruba 's national airline
262 | */
263 | statement : the remark on airline of dutch antilles express with fleet size over 4 is curacao second national carrier.
264 | explain : the statement wants to check a record in the table. we cannot find a record that perfectly satisfies the statement; the most relevant row is row 5, which describes dutch antilles express airline, whose remarks is curacao second national carrier and fleet size is 4, not over 4.
265 | The answer is : f_row([row 5])
266 |
267 | /*
268 | table caption : list of longest - serving soap opera actors.
269 | col : actor | character | soap opera | years | duration
270 | row 1 : tom jordon | charlie kelly | fair city | 1989- | 25 years
271 | row 2 : tony tormey | paul brennan | fair city | 1989- | 25 years
272 | row 3 : jim bartley | bela doyle | fair city | 1989- | 25 years
273 | row 4 : sarah flood | suzanne halpin | fair city | 1989 - 2013 | 24 years
274 | row 5 : pat nolan | barry o'hanlon | fair city | 1989 - 2011 | 22 years
275 | row 6 : martina stanley | dolores molloy | fair city | 1992- | 22 years
276 | row 7 : joan brosnan walsh | mags kelly | fair city | 1989 - 2009 | 20 years
277 | row 8 : jean costello | rita doyle | fair city | 1989 - 2008 , 2010 | 19 years
278 | row 9 : ciara o'callaghan | yvonne gleeson | fair city | 1991 - 2004 , 2008- | 19 years
279 | row 10 : celia murphy | niamh cassidy | fair city | 1995- | 19 years
280 | row 39 : tommy o'neill | john deegan | fair city | 2001- | 13 years
281 | row 40 : seamus moran | mike gleeson | fair city | 1996 - 2008 | 12 years
282 | row 41 : rebecca smith | annette daly | fair city | 1997 - 2009 | 12 years
283 | row 42 : grace barry | mary - ann byrne | glenroe | 1990 - 2001 | 11 years
284 | row 43 : gemma doorly | sarah o'leary | fair city | 2001 - 2011 | 10 years
285 | */
286 | statement : seamus moran and rebecca smith were in soap operas for a duration of 12 years.
287 | explain : the statement wants to check seamus moran and rebecca smith in the table. row 40 describes seamus moran was in soap operas for a duration of 12 years. row 41 describes rebecca smith was in soap operas for a duration of 12 years.
288 | The answer is : f_row([row 40, row 41])
289 |
290 | /*
291 | table caption : jeep grand cherokee.
292 | col : years | displacement | engine | power | torque
293 | row 1 : 1999 - 2004 | 4.0l (242cid) | power tech i6 | - | 3000 rpm
294 | row 2 : 1999 - 2004 | 4.7l (287cid) | powertech v8 | - | 3200 rpm
295 | row 3 : 2002 - 2004 | 4.7l (287cid) | high output powertech v8 | - | -
296 | row 4 : 1999 - 2001 | 3.1l diesel | 531 ohv diesel i5 | - | -
297 | row 5 : 2002 - 2004 | 2.7l diesel | om647 diesel i5 | - | -
298 | */
299 | statement : the jeep grand cherokee with the om647 diesel i5 had the third lowest numbered displacement.
300 | explain : the statement wants to check that the om647 diesel i5 had the third lowest numbered displacement. so we need the first three lowest numbered displacements and all rows where the engine is om647 diesel i5.
301 | The answer is : f_row([row 5, row 4, row 1])"""
302 |
303 |
--------------------------------------------------------------------------------
/utils/chain.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Chain-of-Table authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import copy
17 | import re
18 | from tqdm import tqdm
19 | import numpy as np
20 | from utils.helper import table2string
21 | from collections import defaultdict
22 | import pickle
23 | import os
24 |
25 | import multiprocessing as mp
26 |
27 | from operations import *
28 |
29 |
30 | def fixed_chain_exec_mp(llm, init_samples, fixed_op_list, n_proc=10, chunk_size=50):
31 | history = {}
32 | final_result = None
33 |
34 | chain_header = copy.deepcopy(init_samples)
35 | chain_key = ""
36 |
37 | for i, (op_name, solver_func, kargs, llm_kargs) in enumerate(fixed_op_list):
38 | chain_key += f"->{op_name}" if i > 0 else op_name
39 | chain_header = conduct_single_solver_mp(
40 | llm=llm,
41 | all_samples=chain_header,
42 | solver_func=solver_func,
43 | tqdm_tag=op_name,
44 | n_proc=n_proc,
45 | chunk_size=chunk_size,
46 | llm_options=llm.get_model_options(
47 | **llm_kargs,
48 | ),
49 | **kargs,
50 | )
51 |
52 | history[f"({i}) {chain_key}"] = chain_header
53 | if i == len(fixed_op_list) - 1:
54 | final_result = chain_header
55 |
56 | return final_result, history
57 |
58 |
59 | def conduct_single_solver(llm, all_samples, solver_func, tqdm_tag=None, **kwargs):
60 | result_samples = [None for _ in range(len(all_samples))]
61 |
62 | for idx in tqdm(range(len(all_samples)), desc=tqdm_tag):
63 | try:
64 | sample = all_samples[idx]
65 | table_info = get_table_info(
66 | sample,
67 | skip_op=kwargs.get("skip_op", []),
68 | first_n_op=kwargs.get("first_n_op", None),
69 | )
70 | proc_sample = solver_func(sample, table_info, llm, **kwargs)
71 | result_samples[idx] = proc_sample
72 | except Exception as e:
73 | print(f"Error in {idx}th sample: {e}")
74 | continue
75 | return result_samples
76 |
77 |
78 | def _conduct_single_solver_mp_core(arg):
79 | idx, sample, llm, solver_func, kwargs = arg
80 | try:
81 | table_info = get_table_info(
82 | sample,
83 | skip_op=kwargs.get("skip_op", []),
84 | first_n_op=kwargs.get("first_n_op", None),
85 | )
86 | proc_sample = solver_func(sample, table_info, llm, **kwargs)
87 | return idx, proc_sample
88 | except Exception as e:
89 | print(f"Error in {idx}-th sample: {e}")
90 | return idx, None
91 |
92 |
93 | def conduct_single_solver_mp(
94 | llm, all_samples, solver_func, tqdm_tag=None, n_proc=10, chunk_size=50, **kwargs
95 | ):
96 | result_samples = [None for _ in range(len(all_samples))]
97 |
98 | args = [
99 | (idx, sample, llm, solver_func, kwargs)
100 | for idx, sample in enumerate(all_samples)
101 | ]
102 |
103 | with mp.Pool(n_proc) as p:
104 | for idx, proc_sample in tqdm(
105 | p.imap_unordered(_conduct_single_solver_mp_core, args, chunksize=chunk_size),
106 | total=len(all_samples),
107 | desc=tqdm_tag,
108 | ):
109 | result_samples[idx] = proc_sample
110 |
111 | return result_samples
112 |
113 |
114 | def get_act_func(name):
115 | try:
116 | return eval(f"{name}_act")
117 |     except NameError:
118 |
119 | def _default_act(table_text, *args, **kwargs):
120 | return copy.deepcopy(table_text)
121 |
122 | if "query" not in name:
123 | print("Unknown operation: ", name)
124 | return _default_act
125 |
126 |
127 | def get_table_info(sample, skip_op=[], first_n_op=None):
128 | table_text = sample["table_text"]
129 | chain = sample["chain"]
130 |
131 | if first_n_op is not None:
132 | chain = chain[:first_n_op]
133 |
134 | table_info = {
135 | "table_text": table_text,
136 | "act_chain": [],
137 | }
138 |
139 | for operation in chain:
140 | operation_name = operation["operation_name"]
141 | act_func = get_act_func(operation_name)
142 | table_info = act_func(table_info, operation, skip_op=skip_op)
143 |
144 | return table_info
145 |
146 |
147 | def get_table_log(sample, skip_op=[], first_n_op=None):
148 | table_text = sample["table_text"]
149 | chain = sample["chain"]
150 |
151 | if first_n_op is not None:
152 | chain = chain[:first_n_op]
153 |
154 | table_log = []
155 |
156 | table_info = {
157 | "table_text": table_text,
158 | "act_chain": [],
159 | }
160 | table_log.append(table_info)
161 |
162 | for operation in chain:
163 | operation_name = operation["operation_name"]
164 | act_func = get_act_func(operation_name)
165 | table_info = act_func(table_info, operation, skip_op=skip_op)
166 | if 'row' in operation_name:
167 | table_info['act_chain'][-1] = table_info['_real_select_rows']
168 | if 'query' in operation_name:
169 | table_info['act_chain'].append(f'{operation_name}()')
170 | table_info['cotable_result'] = operation['parameter_and_conf'][0][0]
171 | table_log.append(table_info)
172 |
173 | return table_log
174 |
175 |
176 | # Dynamic Chain Func
177 |
178 |
179 | plan_add_column_demo = """If the table does not have the needed column to tell whether the statement is True or False, we use f_add_column() to add a new column for it. For example,
180 | /*
181 | col : rank | lane | player | time
182 | row 1 : | 5 | olga tereshkova (kaz) | 51.86
183 | row 2 : | 6 | manjeet kaur (ind) | 52.17
184 | row 3 : | 3 | asami tanno (jpn) | 53.04
185 | */
186 | Statement: there are one athlete from japan.
187 | Function: f_add_column(country of athlete)
188 | Explanation: The statement is about the number of athletes from japan. We need to know the country of each athlete. There is no column for the country of athletes. We add a column "country of athlete"."""
189 |
190 | plan_select_column_demo = """If the table only needs a few columns to tell whether the statement is True or False, we use f_select_column() to select these columns for it. For example,
191 | /*
192 | col : code | county | former province | area (km2) | population | capital
193 | row 1 : 1 | mombasa | coast | 212.5 | 939,370 | mombasa (city)
194 | row 2 : 2 | kwale | coast | 8,270.3 | 649,931 | kwale
195 | row 3 : 3 | kilifi | coast | 12,245.9 | 1,109,735 | kilifi
196 | */
197 | Statement: momasa is a county with population higher than 500000.
198 | Function: f_select_column(county, population)
199 | Explanation: The statement wants to check momasa county with population higher than 500000. We need to know the county and its population. We select the column "county" and column "population"."""
200 |
201 | plan_select_row_demo = """If the table only needs a few rows to tell whether the statement is True or False, we use f_select_row() to select these rows for it. For example,
202 | /*
203 | table caption : jeep grand cherokee.
204 | col : years | displacement | engine | power | torque
205 | row 1 : 1999 - 2004 | 4.0l (242cid) | power tech i6 | - | 3000 rpm
206 | row 2 : 1999 - 2004 | 4.7l (287cid) | powertech v8 | - | 3200 rpm
207 | row 3 : 2002 - 2004 | 4.7l (287cid) | high output powertech v8 | - | -
208 | row 4 : 1999 - 2001 | 3.1l diesel | 531 ohv diesel i5 | - | -
209 | row 5 : 2002 - 2004 | 2.7l diesel | om647 diesel i5 | - | -
210 | */
211 | Statement: the jeep grand cherokee with the om647 diesel i5 had the third lowest numbered displacement.
212 | Function: f_select_row(row 1, row 4, row 5)
213 | Explanation: The statement wants to check whether the om647 diesel i5 had the third lowest numbered displacement. We need to know the three lowest numbered displacements and all rows whose engine is om647 diesel i5. We select row 1, row 4, row 5."""
214 |
215 | plan_group_column_demo = """If the statement is about items with the same value and the number of these items, we use f_group_column() to group the items. For example,
216 | /*
217 | col : district | name | party | residence | first served
218 | row 1 : district 1 | nelson albano | dem | vineland | 2006
219 | row 2 : district 1 | robert andrzejczak | dem | middle twp. | 2013†
220 | row 3 : district 2 | john f. amodeo | rep | margate | 2008
221 | */
222 | Statement: there are 5 districts are democratic
223 | Function: f_group_column(party)
224 | Explanation: The statement wants to check whether 5 districts are democratic. We need to know the number of dem in the table. We group the rows according to column "party"."""
225 |
226 | plan_sort_column_demo = """If the statement is about the order of items in a column, we use f_sort_column() to sort the items. For example,
227 | /*
228 | col : position | club | played | points
229 | row 1 : 1 | malaga cf | 42 | 79
230 | row 10 : 10 | cp merida | 42 | 59
231 | row 3 : 3 | cd numancia | 42 | 73
232 | */
233 | Statement: cd numancia placed in the last position.
234 | Function: f_sort_column(position)
235 | Explanation: The statement wants to check whether cd numancia placed in the last position. We need to know the order of positions from last to first. We sort the rows according to column "position"."""
236 |
237 | plan_full_demo_simple = """Here are examples of using the operations to tell whether the statement is True or False.
238 |
239 | /*
240 | col : date | division | league | regular season | playoffs | open cup | avg. attendance
241 | row 1 : 2001/01/02 | 2 | usl a-league | 4th, western | quarterfinals | did not qualify | 7,169
242 | row 2 : 2002/08/06 | 2 | usl a-league | 2nd, pacific | 1st round | did not qualify | 6,260
243 | row 5 : 2005/03/24 | 2 | usl first division | 5th | quarterfinals | 4th round | 6,028
244 | */
245 | Statement: 2005 is the last year where this team was a part of the usl a-league?
246 | Function Chain: f_add_column(year) -> f_select_row(row 1, row 2) -> f_select_column(year, league) -> f_sort_column(year) ->
247 |
248 | /*
249 | col : rank | lane | athlete | time
250 | row 1 : 1 | 6 | manjeet kaur (ind) | 52.17
251 | row 2 : 2 | 5 | olga tereshkova (kaz) | 51.86
252 | row 3 : 3 | 4 | pinki pramanik (ind) | 53.06
253 | */
254 | Statement: There are 10 athletes from India.
255 | Function Chain: f_add_column(country of athletes) -> f_select_row(row 1, row 3) -> f_select_column(athlete, country of athletes) -> f_group_column(country of athletes) ->
256 |
257 | /*
258 | col : week | when | kickoff | opponent | results; final score | results; team record | game site | attendance
259 | row 1 : 1 | saturday, april 13 | 7:00 p.m. | at rhein fire | w 27–21 | 1–0 | rheinstadion | 32,092
260 | row 2 : 2 | saturday, april 20 | 7:00 p.m. | london monarchs | w 37–3 | 2–0 | waldstadion | 34,186
261 | row 3 : 3 | sunday, april 28 | 6:00 p.m. | at barcelona dragons | w 33–29 | 3–0 | estadi olímpic de montjuïc | 17,503
262 | */
263 | Statement: the competition with highest points scored is played on April 20.
264 | Function Chain: f_add_column(points scored) -> f_select_row(*) -> f_select_column(when, points scored) -> f_sort_column(points scored) ->
265 |
266 | /*
267 | col : iso/iec standard | status | wg
268 | row 1 : iso/iec tr 19759 | published (2005) | 20
269 | row 2 : iso/iec 15288 | published (2008) | 7
270 | row 3 : iso/iec 12207 | published (2011) | 7
271 | */
272 | Statement: 2 standards are published in 2011
273 | Function Chain: f_add_column(year) -> f_select_row(row 3) -> f_select_column(year) -> f_group_column(year) ->
274 |
275 | Here are examples of using the operations to tell whether the statement is True or False."""
276 |
277 | possible_next_operation_dict = {
278 | "": [
279 | "add_column",
280 | "select_row",
281 | "select_column",
282 | "group_column",
283 | "sort_column",
284 | ],
285 | "add_column": [
286 | "select_row",
287 | "select_column",
288 | "group_column",
289 | "sort_column",
290 | "",
291 | ],
292 | "select_row": [
293 | "select_column",
294 | "group_column",
295 | "sort_column",
296 | "",
297 | ],
298 | "select_column": [
299 | "group_column",
300 | "sort_column",
301 | "",
302 | ],
303 | "group_column": [
304 | "sort_column",
305 | "",
306 | ],
307 | "sort_column": [
308 | "",
309 | ],
310 | }
311 |
312 |
313 | def get_operation_name(string):
314 | # f_xxxx(...)
315 | res = re.findall(r"f_(.*?)\(.*\)", string)[0]
316 | return res
317 |
318 |
319 | def get_all_operation_names(string):
320 | operation_names = []
321 | parts = string.split("->")
322 | for part in parts:
323 | part = part.strip()
324 | if part == "":
325 | operation_names.append("")
326 | else:
327 | res = re.findall(r"f_(.*?)\(.*\)", part)
328 | if res:
329 | operation_names.append(res[0])
330 | return operation_names
331 |
332 |
333 | def generate_prompt_for_next_step(
334 | sample,
335 | debug=False,
336 | llm=None,
337 | llm_options=None,
338 | strategy="top",
339 | ):
340 | table_info = get_table_info(sample)
341 | act_chain = table_info["act_chain"]
342 |
343 | if debug:
344 | print("Act Chain: ", act_chain, flush=True)
345 |
346 | kept_act_chain = [x for x in act_chain if not x.startswith("skip")]
347 | kept_act_chain_str = " -> ".join(kept_act_chain)
348 | if kept_act_chain_str:
349 | kept_act_chain_str += " ->"
350 |
351 | skip_act_chain = [x for x in act_chain if x.startswith("skip")]
352 | skip_act_chain_op_names = []
353 | for op in skip_act_chain:
354 | op = op[len("skip ") :]
355 | op_name = get_operation_name(op)
356 | skip_act_chain_op_names.append(op_name)
357 |
358 | if debug:
359 | print("Kept Act Chain: ", kept_act_chain, flush=True)
360 | print("Skip Act Chain: ", skip_act_chain, flush=True)
361 |
362 | last_operation = (
363 | "" if not kept_act_chain else get_operation_name(kept_act_chain[-1])
364 | )
365 | possible_next_operations = possible_next_operation_dict[last_operation]
366 | possible_next_operations = [
367 | x for x in possible_next_operations if x not in skip_act_chain_op_names
368 | ]
369 |
370 | if debug:
371 | print("Last Operation: ", last_operation, flush=True)
372 | print("Possible Next Operations: ", possible_next_operations, flush=True)
373 |
374 | if len(possible_next_operations) == 1:
375 | log = {
376 | "act_chain": act_chain,
377 | "last_operation": last_operation,
378 | "possible_next_operations": possible_next_operations,
379 | "prompt": None,
380 | "response": None,
381 | "generate_operations": None,
382 | "next_operation": possible_next_operations[0],
383 | }
384 | return possible_next_operations[0], log
385 |
386 | prompt = ""
387 | for operation in possible_next_operations:
388 | if operation == "":
389 | continue
390 |         prompt += globals()[f"plan_{operation}_demo"] + "\n\n"
391 |
392 | prompt += plan_full_demo_simple + "\n\n"
393 |
394 | prompt += "/*\n" + table2string(table_info["table_text"]) + "\n*/\n"
395 | prompt += "Statement: " + sample["statement"] + "\n"
396 |
397 | _possible_next_operations_str = " or ".join(
398 | [f"f_{op}()" if op != "" else op for op in possible_next_operations]
399 | )
400 |
401 | if len(possible_next_operations) > 1:
402 | prompt += (
403 | f"The next operation must be one of {_possible_next_operations_str}.\n"
404 | )
405 | else:
406 | prompt += f"The next operation must be {_possible_next_operations_str}.\n"
407 |
408 | prompt += "Function Chain: " + kept_act_chain_str
409 |
410 | responses = llm.generate_plus_with_score(
411 | prompt, options=llm_options, end_str="\n\n"
412 | )
413 |
414 | if strategy == "top":
415 | response = responses[0][0]
416 | generate_operations = get_all_operation_names(response)
417 | if debug:
418 | print('Prompt:', prompt.split("\n\n")[-1])
419 | print('Response:', response)
420 | print("Generated Operations: ", generate_operations)
421 | next_operation = ""
422 | for operation in generate_operations:
423 | if operation in possible_next_operations:
424 | next_operation = operation
425 | break
426 | elif strategy == "voting":
427 | next_operation_conf_dict = defaultdict(float)
428 | for response, score in responses:
429 | generate_operations = get_all_operation_names(response)
430 | next_operation = None
431 | for operation in generate_operations:
432 | if operation in possible_next_operations:
433 | next_operation = operation
434 | break
435 | if next_operation:
436 | next_operation_conf_dict[next_operation] += np.exp(score)
437 | if len(next_operation_conf_dict) != 0:
438 | next_operation_conf_pairs = sorted(
439 | next_operation_conf_dict.items(), key=lambda x: x[1], reverse=True
440 | )
441 | next_operation = next_operation_conf_pairs[0][0]
442 | else:
443 | next_operation = ""
444 |
445 | log = {
446 | "act_chain": act_chain,
447 | "last_operation": last_operation,
448 | "possible_next_operations": possible_next_operations,
449 | "prompt": prompt,
450 | "response": response,
451 | "generate_operations": generate_operations,
452 | "next_operation": next_operation,
453 | }
454 |
455 | return next_operation, log
456 |
457 |
458 | def dynamic_chain_exec_one_sample(
459 | sample,
460 | llm,
461 | llm_options=None,
462 | strategy="top",
463 | debug=False,
464 | operation_parameter_dict=None,
465 | ):
466 | if operation_parameter_dict is None:
467 | operation_parameter_dict = {
468 | "add_column": (
469 | "addColumn",
470 | add_column_func,
471 | {},
472 | llm.get_model_options(
473 | temperature=0.0,
474 | per_example_max_decode_steps=150,
475 | per_example_top_p=1.0,
476 | ),
477 | ),
478 | "select_row": (
479 | "selectRow",
480 | select_row_func,
481 | {},
482 | llm.get_model_options(
483 | temperature=0.5,
484 | per_example_max_decode_steps=150,
485 | per_example_top_p=1.0,
486 | n_sample=8,
487 | ),
488 | ),
489 | "select_column": (
490 | "selectColumn",
491 | select_column_func,
492 | {},
493 | llm.get_model_options(
494 | temperature=0.5,
495 | per_example_max_decode_steps=150,
496 | per_example_top_p=1.0,
497 | n_sample=8,
498 | ),
499 | ),
500 | "group_column": (
501 | "groupColumn",
502 | group_column_func,
503 | dict(skip_op=[]),
504 | llm.get_model_options(
505 | temperature=0.0,
506 | per_example_max_decode_steps=150,
507 | per_example_top_p=1.0,
508 | ),
509 | ),
510 | "sort_column": (
511 | "sortColumn",
512 | sort_column_func,
513 | dict(skip_op=[]),
514 | llm.get_model_options(
515 | temperature=0.0,
516 | per_example_max_decode_steps=150,
517 | per_example_top_p=1.0,
518 | ),
519 | ),
520 | }
521 |
522 | dynamic_chain_log = []
523 |
524 | current_sample = copy.deepcopy(sample)
525 | while True:
526 | # generate next operation
527 | next_operation, log = generate_prompt_for_next_step(
528 | current_sample,
529 | llm=llm,
530 | llm_options=llm_options,
531 | strategy=strategy,
532 | debug=debug,
533 | )
534 | dynamic_chain_log.append(log)
535 |
536 | if debug:
537 | print(next_operation)
538 |
539 | if next_operation == "":
540 | break
541 |
542 | param = operation_parameter_dict[next_operation]
543 | op_name, solver_func, kargs, op_llm_options = param
544 |
545 | table_info = get_table_info(current_sample)
546 |
547 | current_sample = solver_func(
548 | current_sample, table_info, llm=llm, llm_options=op_llm_options, **kargs
549 | )
550 | return current_sample, dynamic_chain_log
551 |
552 |
553 | def dynamic_chain_exec_with_cache_for_loop(
554 | all_samples,
555 | llm,
556 | llm_options=None,
557 | strategy="voting",
558 | cache_dir="./cache/debug",
559 | ):
560 | os.makedirs(cache_dir, exist_ok=True)
561 | result_samples = [None for _ in range(len(all_samples))]
562 | dynamic_chain_log_list = [None for _ in range(len(all_samples))]
563 |
564 | cache_filename = "case-{}.pkl"
565 |
566 | def _func(idx):
567 | sample = all_samples[idx]
568 | sample_id = sample["id"]
569 | cache_path = os.path.join(cache_dir, cache_filename.format(sample_id))
570 | if os.path.exists(cache_path):
571 | _, proc_sample, log = pickle.load(open(cache_path, "rb"))
572 | else:
573 | proc_sample, log = dynamic_chain_exec_one_sample(
574 | sample, llm=llm, llm_options=llm_options, strategy=strategy
575 | )
576 | pickle.dump((sample, proc_sample, log), open(cache_path, "wb"))
577 | result_samples[idx] = proc_sample
578 | dynamic_chain_log_list[idx] = log
579 |
580 | for idx in tqdm(range(len(all_samples)), total=len(all_samples)):
581 | try:
582 | _func(idx)
583 | except Exception as e:
584 | print(f"IDX={idx}: {e}", flush=True)
585 |
586 | return result_samples, dynamic_chain_log_list
587 |
588 |
589 | def _dynamic_chain_exec_with_cache_mp_core(arg):
590 | idx, sample, llm, llm_options, strategy, cache_dir = arg
591 |
592 | cache_filename = "case-{}.pkl"
593 |     sample_id = sample["id"]
594 |     try:
595 |         cache_path = os.path.join(cache_dir, cache_filename.format(sample_id))
596 | if os.path.exists(cache_path):
597 | _, proc_sample, log = pickle.load(open(cache_path, "rb"))
598 | else:
599 | proc_sample, log = dynamic_chain_exec_one_sample(
600 | sample, llm=llm, llm_options=llm_options, strategy=strategy
601 | )
602 | pickle.dump((sample, proc_sample, log), open(cache_path, "wb"))
603 | return idx, proc_sample, log
604 | except Exception as e:
605 | print(f"Error in {sample_id}: {e}", flush=True)
606 | return idx, None, None
607 |
608 |
609 | def dynamic_chain_exec_with_cache_mp(
610 | all_samples,
611 | llm,
612 | llm_options=None,
613 | strategy="voting",
614 | cache_dir="./results/debug",
615 | n_proc=10,
616 | chunk_size=50,
617 | ):
618 | os.makedirs(cache_dir, exist_ok=True)
619 | result_samples = [None for _ in range(len(all_samples))]
620 | dynamic_chain_log_list = [None for _ in range(len(all_samples))]
621 |
622 | args = [
623 | (idx, sample, llm, llm_options, strategy, cache_dir)
624 | for idx, sample in enumerate(all_samples)
625 | ]
626 |
627 | with mp.Pool(n_proc) as p:
628 | for idx, proc_sample, log in tqdm(
629 | p.imap_unordered(
630 | _dynamic_chain_exec_with_cache_mp_core, args, chunksize=chunk_size
631 | ),
632 | total=len(all_samples),
633 | ):
634 | result_samples[idx] = proc_sample
635 | dynamic_chain_log_list[idx] = log
636 |
637 | return result_samples, dynamic_chain_log_list
638 |
--------------------------------------------------------------------------------
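The dynamic chain logic in `utils/chain.py` above constrains which operation may follow which via `possible_next_operation_dict`, and parses generated function chains with a small regex. A minimal standalone sketch of parsing a chain string and validating it against that transition graph (the chain string is a made-up example in the same format as the demos):

```python
import re

# Transition graph from utils/chain.py: maps the last executed operation
# to the operations allowed next ("" marks chain start / termination).
possible_next_operation_dict = {
    "": ["add_column", "select_row", "select_column", "group_column", "sort_column"],
    "add_column": ["select_row", "select_column", "group_column", "sort_column", ""],
    "select_row": ["select_column", "group_column", "sort_column", ""],
    "select_column": ["group_column", "sort_column", ""],
    "group_column": ["sort_column", ""],
    "sort_column": [""],
}

def get_all_operation_names(string):
    # Parse "f_xxx(...) -> f_yyy(...) ->" into ["xxx", "yyy", ""],
    # mirroring the function of the same name in utils/chain.py.
    operation_names = []
    for part in string.split("->"):
        part = part.strip()
        if part == "":
            operation_names.append("")
        else:
            res = re.findall(r"f_(.*?)\(.*\)", part)
            if res:
                operation_names.append(res[0])
    return operation_names

chain = "f_add_column(year) -> f_select_row(row 3) -> f_select_column(year) ->"
ops = get_all_operation_names(chain)
print(ops)  # ['add_column', 'select_row', 'select_column', '']

# Every consecutive pair (including the "" start state) must be a legal edge.
valid = all(b in possible_next_operation_dict[a] for a, b in zip([""] + ops, ops))
print(valid)  # True
```

The trailing `->` in a generated chain yields the empty-string operation, which is why `""` doubles as the termination marker checked by `dynamic_chain_exec_one_sample`.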
/utils/evaluate.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Chain-of-Table authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | def tabfact_match_func(sample, strategy="top"):
17 | results = sample["chain"][-1]["parameter_and_conf"]
18 |
19 | if strategy == "top":
20 | res = results[0][0]
21 | elif strategy == "weighted":
22 | res_conf_dict = {}
23 | for res, conf in results:
24 | if res not in res_conf_dict:
25 | res_conf_dict[res] = 0
26 | res_conf_dict[res] += conf
27 | res_conf_rank = sorted(res_conf_dict.items(), key=lambda x: x[1], reverse=True)
28 | res = res_conf_rank[0][0]
29 | else:
30 | raise NotImplementedError
31 |
32 | res = res.lower()
33 | if res == "true":
34 | res = "yes"
35 | if res == "false":
36 | res = "no"
37 | if res == "yes" and sample["label"] == 1:
38 | return True
39 | elif res == "no" and sample["label"] == 0:
40 | return True
41 | else:
42 | return False
43 |
44 |
45 | def tabfact_match_func_for_samples(all_samples, strategy="top"):
46 | correct_list = []
47 | for sample in all_samples:
48 | try:
49 | if tabfact_match_func(sample, strategy):
50 | correct_list.append(1)
51 | else:
52 | correct_list.append(0)
53 |         except Exception as e:
54 |             print(f"Error: {e}")
55 | continue
56 | return sum(correct_list) / len(correct_list)
57 |
--------------------------------------------------------------------------------
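The `weighted` strategy in `tabfact_match_func` sums per-answer confidences across sampled responses before picking a winner. A standalone walk-through of that branch, with made-up `(answer, confidence)` pairs:

```python
# Aggregate confidence per distinct answer, then rank, as the
# "weighted" branch of tabfact_match_func does. The pairs below
# are hypothetical sampled outputs.
results = [("true", 0.4), ("false", 0.35), ("true", 0.1)]

res_conf_dict = {}
for res, conf in results:
    res_conf_dict[res] = res_conf_dict.get(res, 0) + conf
res_conf_rank = sorted(res_conf_dict.items(), key=lambda x: x[1], reverse=True)
print(res_conf_rank[0][0])  # true  ("true" totals 0.5 vs 0.35 for "false")
```

With `strategy="top"` the function instead takes `results[0][0]`, i.e. it trusts the ordering of the response list.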
/utils/helper.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Chain-of-Table authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import pandas as pd
17 | import json
18 | import re
19 | from _ctypes import PyObj_FromPtr
20 |
21 |
22 | def table2df(table_text, num_rows=100):
23 | header, rows = table_text[0], table_text[1:]
24 | rows = rows[:num_rows]
25 | df = pd.DataFrame(data=rows, columns=header)
26 | return df
27 |
28 |
29 | def table2string(
30 | table_text,
31 | num_rows=100,
32 | caption=None,
33 | ):
34 | df = table2df(table_text, num_rows)
35 | linear_table = ""
36 | if caption is not None:
37 | linear_table += "table caption : " + caption + "\n"
38 |
39 | header = "col : " + " | ".join(df.columns) + "\n"
40 | linear_table += header
41 | rows = df.values.tolist()
42 | for row_idx, row in enumerate(rows):
43 | row = [str(x) for x in row]
44 | line = "row {} : ".format(row_idx + 1) + " | ".join(row)
45 | if row_idx != len(rows) - 1:
46 | line += "\n"
47 | linear_table += line
48 | return linear_table
49 |
50 |
51 | class NoIndent(object):
52 | """Value wrapper."""
53 |
54 | def __init__(self, value):
55 | self.value = value
56 |
57 |
58 | class MyEncoder(json.JSONEncoder):
59 | FORMAT_SPEC = "@@{}@@"
60 | regex = re.compile(FORMAT_SPEC.format(r"(\d+)"))
61 |
62 | def __init__(self, **kwargs):
63 | # Save copy of any keyword argument values needed for use here.
64 | self.__sort_keys = kwargs.get("sort_keys", None)
65 | super(MyEncoder, self).__init__(**kwargs)
66 |
67 | def default(self, obj):
68 | return (
69 | self.FORMAT_SPEC.format(id(obj))
70 | if isinstance(obj, NoIndent)
71 | else super(MyEncoder, self).default(obj)
72 | )
73 |
74 | def encode(self, obj):
75 | format_spec = self.FORMAT_SPEC # Local var to expedite access.
76 | json_repr = super(MyEncoder, self).encode(obj) # Default JSON.
77 |
78 | # Replace any marked-up object ids in the JSON repr with the
79 | # value returned from the json.dumps() of the corresponding
80 | # wrapped Python object.
81 | for match in self.regex.finditer(json_repr):
82 | # see https://stackoverflow.com/a/15012814/355230
83 | id = int(match.group(1))
84 | no_indent = PyObj_FromPtr(id)
85 | json_obj_repr = json.dumps(no_indent.value, sort_keys=self.__sort_keys)
86 |
87 | # Replace the matched id string with json formatted representation
88 | # of the corresponding Python object.
89 | json_repr = json_repr.replace(
90 | '"{}"'.format(format_spec.format(id)), json_obj_repr
91 | )
92 |
93 | return json_repr
94 |
--------------------------------------------------------------------------------
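`helper.table2string` linearizes a table into the `col : ...` / `row i : ...` form that all the prompt demos above consume. A dependency-free sketch of the same output format (the real function goes through a pandas DataFrame and supports an optional caption):

```python
# Linearize a header-plus-rows table the way helper.table2string does,
# minus the pandas round-trip and caption handling.
table_text = [
    ["rank", "lane", "athlete", "time"],
    ["1", "6", "manjeet kaur (ind)", "52.17"],
    ["2", "5", "olga tereshkova (kaz)", "51.86"],
]
header, rows = table_text[0], table_text[1:]
lines = ["col : " + " | ".join(header)]
for row_idx, row in enumerate(rows):
    # Rows are 1-indexed in the linearized format.
    lines.append("row {} : ".format(row_idx + 1) + " | ".join(str(x) for x in row))
linear_table = "\n".join(lines)
print(linear_table)
```

`num_rows=100` in the real helper caps how many rows reach the prompt, which keeps large tables within the model's context budget.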
/utils/llm.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Chain-of-Table authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import openai
17 | import time
18 | import numpy as np
19 |
20 |
21 | class ChatGPT:
22 | def __init__(self, model_name, key):
23 | self.model_name = model_name
24 | self.key = key
25 |
26 | def get_model_options(
27 | self,
28 | temperature=0,
29 | per_example_max_decode_steps=150,
30 | per_example_top_p=1,
31 | n_sample=1,
32 | ):
33 | return dict(
34 | temperature=temperature,
35 | n=n_sample,
36 | top_p=per_example_top_p,
37 | max_tokens=per_example_max_decode_steps,
38 | )
39 |
40 | def generate_plus_with_score(self, prompt, options=None, end_str=None):
41 | if options is None:
42 | options = self.get_model_options()
43 | messages = [
44 | {
45 | "role": "system",
46 | "content": "I will give you some examples, you need to follow the examples and complete the text, and no other content.",
47 | },
48 | {"role": "user", "content": prompt},
49 | ]
50 | gpt_responses = None
51 | retry_num = 0
52 | retry_limit = 2
53 | error = None
54 | while gpt_responses is None:
55 | try:
56 | gpt_responses = openai.ChatCompletion.create(
57 | model=self.model_name,
58 | messages=messages,
59 | stop=end_str,
60 | api_key=self.key,
61 | **options
62 | )
63 | error = None
64 | except Exception as e:
65 | print(str(e), flush=True)
66 | error = str(e)
67 | if "This model's maximum context length is" in str(e):
68 | print(e, flush=True)
69 | gpt_responses = {
70 | "choices": [{"message": {"content": "PLACEHOLDER"}}]
71 | }
72 | elif retry_num > retry_limit:
73 | error = "too many retry times"
74 | gpt_responses = {
75 | "choices": [{"message": {"content": "PLACEHOLDER"}}]
76 | }
77 | else:
78 | time.sleep(60)
79 | retry_num += 1
80 | if error:
81 | raise Exception(error)
82 | results = []
83 | for i, res in enumerate(gpt_responses["choices"]):
84 | text = res["message"]["content"]
85 | fake_conf = (len(gpt_responses["choices"]) - i) / len(
86 | gpt_responses["choices"]
87 | )
88 | results.append((text, np.log(fake_conf)))
89 |
90 | return results
91 |
92 | def generate(self, prompt, options=None, end_str=None):
93 | if options is None:
94 | options = self.get_model_options()
95 | options["n"] = 1
96 | result = self.generate_plus_with_score(prompt, options, end_str)[0][0]
97 | return result
98 |
--------------------------------------------------------------------------------
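`ChatGPT.generate_plus_with_score` gets no per-sample log-probabilities from the chat endpoint, so it assigns a rank-based pseudo-confidence: choice `i` of `n` gets `(n - i) / n`, stored as a log so the voting code in `utils/chain.py` can `np.exp()` it back. Illustrated with the stdlib `math` module instead of numpy:

```python
import math

# Rank-based pseudo-confidence as in generate_plus_with_score:
# earlier choices in the response list get higher scores.
choices = ["true", "false", "true"]  # hypothetical sampled completions
results = [
    (text, math.log((len(choices) - i) / len(choices)))
    for i, text in enumerate(choices)
]
confs = [round(math.exp(score), 4) for _, score in results]
print(confs)  # [1.0, 0.6667, 0.3333]
```

Because the scores are monotone in rank, `strategy="top"` and `strategy="voting"` coincide whenever all samples parse to the same operation; they only diverge when later, lower-scored samples outvote the first one in aggregate.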
/utils/load_data.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Chain-of-Table authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import json
17 | from tqdm import tqdm
18 |
19 | def load_tabfact_dataset(
20 | dataset_path,
21 | raw2clean_path,
22 | tag="test",
23 | first_n=-1,
24 | ):
25 | tabfact_statement_raw2clean_dict = {}
26 | with open(raw2clean_path, "r") as f:
27 | lines = f.readlines()
28 | for line in lines:
29 | info = json.loads(line)
30 | tabfact_statement_raw2clean_dict[info["statement"]] = info["cleaned_statement"]
31 |
32 | dataset = []
33 | if first_n != -1:
34 | all_lines = []
35 | for line in open(dataset_path):
36 | all_lines.append(line)
37 | if len(all_lines) >= first_n: break
38 | else:
39 | all_lines = open(dataset_path).readlines()
40 | for i, line in tqdm(enumerate(all_lines), total=len(all_lines), desc=f"Loading tabfact-{tag} dataset"):
41 | info = json.loads(line)
42 | info["id"] = f"{tag}-{i}"
43 | info["chain"] = []
44 | if info["statement"] in tabfact_statement_raw2clean_dict:
45 | info["cleaned_statement"] = tabfact_statement_raw2clean_dict[
46 | info["statement"]
47 | ]
48 | else:
49 | info["cleaned_statement"] = info["statement"]
50 | dataset.append(info)
51 | return dataset
52 |
53 |
54 | def wrap_input_for_demo(statement, table_caption, table_text, cleaned_statement=None):
55 | return {
56 | "statement": statement,
57 | "table_caption": table_caption,
58 | "table_text": table_text,
59 | "cleaned_statement": cleaned_statement if cleaned_statement is not None else statement,
60 | "chain": [],
61 | }
62 |
63 |
--------------------------------------------------------------------------------
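For quick experiments outside the TabFact loader, `wrap_input_for_demo` above builds the sample dict that the chain executor expects. A usage sketch with a made-up table (the function body is copied from `utils/load_data.py`):

```python
def wrap_input_for_demo(statement, table_caption, table_text, cleaned_statement=None):
    # Copied from utils/load_data.py: package one table + statement into
    # the sample dict consumed by dynamic_chain_exec_one_sample.
    return {
        "statement": statement,
        "table_caption": table_caption,
        "table_text": table_text,
        "cleaned_statement": cleaned_statement if cleaned_statement is not None else statement,
        "chain": [],
    }

sample = wrap_input_for_demo(
    statement="there are one athlete from japan.",
    table_caption="athletics results",  # hypothetical caption
    table_text=[["rank", "athlete"], ["1", "asami tanno (jpn)"]],
)
print(sample["cleaned_statement"])  # there are one athlete from japan.
print(sample["chain"])  # []
```

The empty `"chain"` list is filled in by the operation functions as the dynamic chain executes; `"cleaned_statement"` falls back to the raw statement when no raw-to-clean mapping is supplied, matching `load_tabfact_dataset`.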