├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── data.zip ├── operations ├── __init__.py ├── add_column.py ├── final_query.py ├── group_by.py ├── select_column.py ├── select_row.py └── sort_by.py ├── requirements.txt ├── run_demo.ipynb ├── run_tabfact.py ├── third_party └── select_column_row_prompts │ ├── LICENSE │ └── select_column_row_prompts.py └── utils ├── chain.py ├── evaluate.py ├── helper.py ├── llm.py └── load_data.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .DS_Store 3 | data/ 4 | results/ -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We would love to accept your patches and contributions to this project. 4 | 5 | ## Before you begin 6 | 7 | ### Sign our Contributor License Agreement 8 | 9 | Contributions to this project must be accompanied by a 10 | [Contributor License Agreement](https://cla.developers.google.com/about) (CLA). 11 | You (or your employer) retain the copyright to your contribution; this simply 12 | gives us permission to use and redistribute your contributions as part of the 13 | project. 14 | 15 | If you or your current employer have already signed the Google CLA (even if it 16 | was for a different project), you probably don't need to do it again. 17 | 18 | Visit <https://cla.developers.google.com/> to see your current agreements or to 19 | sign a new one. 20 | 21 | ### Review our Community Guidelines 22 | 23 | This project follows 24 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/). 25 | 26 | ## Contribution process 27 | 28 | ### Code Reviews 29 | 30 | All submissions, including submissions by project members, require review. We 31 | use [GitHub pull requests](https://docs.github.com/articles/about-pull-requests) 32 | for this purpose. 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 
40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Chain of Table 2 | 3 | Code for paper [Chain-of-Table: Evolving Tables in the Reasoning Chain for Table Understanding](https://arxiv.org/abs/2401.04398) 4 | 5 | *This is not an officially supported Google product.* 6 | 7 | ## Environment 8 | 9 | ```shell 10 | conda create --name cotable python=3.10 -y 11 | conda activate cotable 12 | pip install -r requirements.txt 13 | ``` 14 | 15 | ## Data 16 | 17 | ```shell 18 | unzip data.zip 19 | ``` 20 | 21 | ## Command Usages 22 | 23 | ### Arguments 24 | 25 | - `--dataset_path`: path to the dataset, default: `./data/tabfact/test.jsonl` 26 | - `--raw2clean_path`: path to the preprocessed raw2clean file, default: `./data/tabfact/raw2clean.json` (cleaned by [Dater](https://arxiv.org/pdf/2301.13808.pdf)) 27 | - `--model_name`: name of the OpenAI API, default: `gpt-3.5-turbo-16k-0613` 28 | - `--result_dir`: path to the result directory, default: `./results/tabfact` 29 | - `--openai_api_key`: key of the OpenAI API 30 
| - `--first_n`: number of the first n samples to evaluate, default: `-1` means whole dataset 31 | - `--n_proc`: number of processes to use in multiprocessing, default: `1` 32 | - `--chunk_size`: chunk size used in multiprocessing, default: `1` 33 | 34 | ### Example usages 35 | 36 | 1. Run tests on the first 10 cases 37 | 38 | ```shell 39 | python run_tabfact.py \ 40 | --result_dir 'results/tabfact_first10' \ 41 | --first_n 10 \ 42 | --n_proc 10 \ 43 | --chunk_size 1 \ 44 | --openai_api_key 45 | ``` 46 | 47 | 2. Run the experiment on the whole dataset 48 | 49 | ```shell 50 | python run_tabfact.py \ 51 | --result_dir 'results/tabfact' \ 52 | --n_proc 20 \ 53 | --chunk_size 10 \ 54 | --openai_api_key 55 | ``` 56 | 57 | ## Cite 58 | 59 | If you find this repository useful, please consider citing: 60 | 61 | ```bibtex 62 | @article{wang2024chain, 63 | title={Chain-of-Table: Evolving Tables in the Reasoning Chain for Table Understanding}, 64 | author={Wang, Zilong and Zhang, Hao and Li, Chun-Liang and Eisenschlos, Julian Martin and Perot, Vincent and Wang, Zifeng and Miculicich, Lesly and Fujii, Yasuhisa and Shang, Jingbo and Lee, Chen-Yu and Pfister, Tomas}, 65 | journal={ICLR}, 66 | year={2024} 67 | } 68 | ``` 69 | 70 | ## Acknowledgement 71 | 72 | We thank [Dater](https://arxiv.org/pdf/2301.13808.pdf) for providing the cleaned TabFact dataset and releasing the [code](https://github.com/AlibabaResearch/DAMO-ConvAI/tree/main/dater). We include the cleaned raw2clean file in the `data.zip` and the prompts for row/column selection in the `third_party/select_column_row_prompts/select_column_row_prompts.py` under the MIT License. 
73 | -------------------------------------------------------------------------------- /data.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/chain-of-table/1fd716c608f56e7d9d156ead432219c7ab9008af/data.zip -------------------------------------------------------------------------------- /operations/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Chain-of-Table authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from .add_column import add_column_func, add_column_act 17 | from .group_by import group_column_func, group_column_act 18 | from .select_column import select_column_func, select_column_act 19 | from .select_row import select_row_func, select_row_act 20 | from .sort_by import sort_column_func, sort_column_act 21 | 22 | from .final_query import simple_query -------------------------------------------------------------------------------- /operations/add_column.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Chain-of-Table authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import re 17 | import copy 18 | import numpy as np 19 | from utils.helper import table2string 20 | 21 | 22 | add_column_demo = """To tell the statement is true or false, we can first use f_add_column() to add more columns to the table. 23 | 24 | The added columns should have these data types: 25 | 1. Numerical: the numerical strings that can be used in sort, sum 26 | 2. Datetype: the strings that describe a date, such as year, month, day 27 | 3. String: other strings 28 | 29 | /* 30 | col : week | when | kickoff | opponent | results; final score | results; team record | game site | attendance 31 | row 1 : 1 | saturday, april 13 | 7:00 p.m. | at rhein fire | w 27–21 | 1–0 | rheinstadion | 32,092 32 | row 2 : 2 | saturday, april 20 | 7:00 p.m. | london monarchs | w 37–3 | 2–0 | waldstadion | 34,186 33 | row 3 : 3 | sunday, april 28 | 6:00 p.m. | at barcelona dragons | w 33–29 | 3–0 | estadi olímpic de montjuïc | 17,503 34 | */ 35 | Statement: april 20 is the date of the competition with highest attendance. 36 | The existing columns are: "week", "when", "kickoff", "opponent", "results; final score", "results; team record", "game site", "attendance". 37 | Explanation: To tell this statement is true or false, we need to know the attendence number of each competition. We extract the value from column "attendance" and create a different column "attendance number" for each row. The datatype is numerical. 38 | Therefore, the answer is: f_add_column(attendance number). 
The value: 32092 | 34186 | 17503 39 | 40 | /* 41 | col : rank | lane | player | time 42 | row 1 : | 5 | olga tereshkova (kaz) | 51.86 43 | row 2 : | 6 | manjeet kaur (ind) | 52.17 44 | row 3 : | 3 | asami tanno (jpn) | 53.04 45 | */ 46 | Statement: there are one athlete from japan. 47 | The existing columns are: rank, lane, player, time. 48 | Explanation: To tell this statement is true or false, we need to know the country of each athelte. We extract the value from column "player" and create a different column "country of athletes" for each row. The datatype is string. 49 | Therefore, the answer is: f_add_column(country of athletes). The value: kaz | ind | jpn 50 | 51 | /* 52 | col : year | competition | venue | position | notes 53 | row 1 : 1991 | european junior championships | thessaloniki, greece | 10th | 4.90 m 54 | row 2 : 1992 | world junior championships | seoul, south korea | 1st | 5.45 m 55 | row 3 : 1996 | european indoor championships | stockholm, sweden | 14th (q) | 5.45 m 56 | */ 57 | Statement: laurens place 1st in 1991. 58 | The existing columns are: year, competition, venue, position, notes. 59 | Explanation: To tell this statement is true or false, we need to know the place of each competition. We extract the value from column "position" and create a different column "placing result" for each row. The datatype is numerical. 60 | Therefore, the answer is: f_add_column(placing result). The value: 10 | 1 | 14 61 | 62 | /* 63 | col : iso/iec standard | status | wg 64 | row 1 : iso/iec tr 19759 | published (2005) | 20 65 | row 2 : iso/iec 15288 | published (2008) | 7 66 | row 3 : iso/iec 12207 | published (2008) | 7 67 | */ 68 | Statement: the standards published three times in 2008. 69 | The existing columns are: iso/iec standard, title, status, description, wg. 70 | Explanation: To tell this statement is true or false, we need to know the year of each standard. 
We extract the value from column "status" and create a different column "year of standard" for each row. The datatype is datetype. 71 | Therefore, the answer is: f_add_column(year of standard). The value: 2005 | 2008 | 2008 72 | 73 | /* 74 | col : match | date | ground | opponent | score1 | pos. | pts. | gd 75 | row 1 : 1 | 15 august | a | bayer uerdingen | 3 – 0 | 1 | 2 | 3 76 | row 2 : 2 | 22 july | h | 1. fc kaiserslautern | 1 – 0 | 1 | 4 | 4 77 | row 3 : 4 | 29 september | h | dynamo dresden | 3 – 1 | 1 | 6 | 6 78 | */ 79 | Statement: they play 5 times in august. 80 | The existing columns are: match, date, ground, opponent, score1, pos., pts., gd. 81 | Explanation: To tell this statement is true or false, we need to know the month of each match. We extract the value from column "date" and create a different column "month" for each row. The datatype is datetype. 82 | Therefore, the answer is: f_add_column(month). The value: august | july | september 83 | 84 | /* 85 | table caption : 1984 u.s. open (golf) 86 | col : place | player | country | score | to par 87 | row 1 : 1 | hale irwin | united states | 68 + 68 = 136 | - 4 88 | row 2 : 2 | fuzzy zoeller | united states | 71 + 66 = 137 | -- 3 89 | row 3 : t3 | david canipe | united states | 69 + 69 = 138 | - 2 90 | */ 91 | Statement: david canipe of united states has 138 score 92 | The existing columns are: place, player, country, score, to par. 93 | Explanation: To tell this statement is true or false, we need to know the score values of each player. We extract the value from column "score" and create a different column "score value" for each row. The datatype is numerical. 94 | Therefore, the answer is: f_add_column(score value). 
def add_column_build_prompt(table_text, statement, table_caption=None, num_rows=100):
    """Build the per-sample tail of the f_add_column prompt.

    Renders the table between /* ... */ fences, then appends the statement,
    the list of existing column headers, and a trailing "Explanation:" cue
    for the LLM to complete.

    Args:
        table_text: list of rows; table_text[0] is the header row.
        statement: the (cleaned) statement being verified.
        table_caption: optional caption forwarded to table2string.
        num_rows: maximum number of rows to render.

    Returns:
        The assembled prompt string.
    """
    rendered_table = table2string(table_text, caption=table_caption, num_rows=num_rows)
    pieces = [
        "/*\n", rendered_table, "\n*/\n",
        "Statement: ", statement, "\n",
        "The existing columns are: ",
        ", ".join(table_text[0]), ".\n",
        "Explanation:",
    ]
    return "".join(pieces)
def add_column_func(
    sample, table_info, llm, llm_options=None, debug=False, skip_op=[], strategy="top"
):
    """Plan an `add_column` operation for one sample.

    Few-shot prompts the LLM (via `add_column_demo`) to propose a new column
    derived from an existing one, together with the values for the first three
    rows; then queries the LLM once per remaining row to fill in the rest.
    Returns a deep copy of `sample` with the planned operation appended to its
    "chain". On any failure the appended operation carries an empty
    "parameter_and_conf" list, which `add_column_act` treats as a skip.

    Args:
        sample: sample dict; reads "table_caption" and "cleaned_statement",
            appends to "chain".
        table_info: dict holding the current "table_text" (header row + rows).
        llm: model wrapper exposing generate / generate_plus_with_score /
            get_model_options — presumably utils.llm; TODO confirm interface.
        llm_options: decoding options; defaults to the model options with n=1.
        debug: print intermediate per-row prompts and responses.
        skip_op: unused here (only referenced by the commented-out
            get_table_info call); kept for a uniform operation interface.
        strategy: candidate-selection strategy; only "top" is implemented.
    """
    operation = {
        "operation_name": "add_column",
        "parameter_and_conf": [],
    }
    # Pre-built failure result: chain records an add_column op with no params.
    failure_sample_copy = copy.deepcopy(sample)
    failure_sample_copy["chain"].append(operation)

    # table_info = get_table_info(sample, skip_op=skip_op)
    table_text = table_info["table_text"]

    table_caption = sample["table_caption"]
    cleaned_statement = sample["cleaned_statement"]
    # Mask digit runs with "_" — presumably so the proposal depends on the
    # statement's structure rather than its specific numbers; TODO confirm.
    cleaned_statement = re.sub(r"\d+", "_", cleaned_statement)

    prompt = "" + add_column_demo.rstrip() + "\n\n"
    prompt += add_column_build_prompt(
        table_text, cleaned_statement, table_caption=table_caption, num_rows=3
    )
    if llm_options is None:
        llm_options = llm.get_model_options()
    llm_options["n"] = 1
    responses = llm.generate_plus_with_score(
        prompt,
        options=llm_options,
    )

    # Parse each response into (column name, first three values); malformed
    # responses are silently skipped. Identical parses accumulate confidence
    # exp(score), i.e. self-consistency voting over samples.
    add_column_and_conf = {}
    for res, score in responses:
        try:
            f_add_func = re.findall(r"f_add_column\(.*\)", res, re.S)[0].strip()
            left = f_add_func.index("(") + 1
            right = f_add_func.index(")")
            add_column = f_add_func[left:right].strip()
            first_3_values = res.split("The value:")[-1].strip().split("|")
            first_3_values = [v.strip() for v in first_3_values]
            # The demo prompt shows only the first 3 rows, so exactly 3 values
            # are expected.
            assert len(first_3_values) == 3
        except:
            continue

        add_column_key = str((add_column, first_3_values, res))
        if add_column_key not in add_column_and_conf:
            add_column_and_conf[add_column_key] = 0
        add_column_and_conf[add_column_key] += np.exp(score)

    if len(add_column_and_conf) == 0:
        return failure_sample_copy

    add_column_and_conf_list = sorted(
        add_column_and_conf.items(), key=lambda x: x[1], reverse=True
    )
    if strategy == "top":
        selected_add_column_key = add_column_and_conf_list[0][0]
        selected_add_column_conf = add_column_and_conf_list[0][1]
    else:
        raise NotImplementedError()

    # Keys were built with str(tuple) above; eval restores the tuple.
    # NOTE(review): eval on LLM-derived text — ast.literal_eval would be safer.
    add_column, first_3_values, llm_response = eval(selected_add_column_key)

    existing_columns = table_text[0]
    if add_column in existing_columns:
        return failure_sample_copy

    add_column_contents = [] + first_3_values

    # get following contents
    # Reuse the winning response's explanation ("We extract the value from
    # ... The value:") as the extraction pattern for the remaining rows.
    try:
        left_index = llm_response.index("We extract the value from")
        right_index = llm_response.index("The value:")
        explanaiton_beginning = llm_response[left_index:right_index] + "The value:"
    except:
        return failure_sample_copy

    def _sample_to_simple_prompt_header(table_text, num_rows=3):
        # Render a table fenced by /* ... */ followed by an "Explanation: " cue.
        x = ""
        x += "/*\n"
        x += table2string(table_text, caption=table_caption, num_rows=num_rows) + "\n"
        x += "*/\n"
        x += "Explanation: " + ""[0:0]
        return x

    new_prompt = ""
    new_prompt += (
        _sample_to_simple_prompt_header(table_text, num_rows=3)
        + llm_response[left_index:]
    )

    headers = table_text[0]
    rows = table_text[1:]
    # Rows 0-2 are covered by first_3_values; query the LLM for each remaining
    # row with a one-row table plus the extraction pattern.
    for i in range(3, len(rows)):
        partial_table_text = [headers] + rows[i : i + 1]
        cur_prompt = (
            new_prompt
            + "\n\n"
            + _sample_to_simple_prompt_header(partial_table_text)
            + explanaiton_beginning
        )
        cur_response = llm.generate(
            cur_prompt,
            options=llm.get_model_options(
                per_example_max_decode_steps=150, per_example_top_p=1.0
            ),
        ).strip()
        if debug:
            print(cur_prompt)
            print(cur_response)
            print("---")
            print()

        # Keep only the value for this row if the model continued the
        # pipe-separated list.
        contents = cur_response
        if "|" in contents:
            contents = contents.split("|")[0].strip()

        add_column_contents.append(contents)

    if debug:
        print("New col contents: ", add_column_contents)

    # Serialize (name, contents) the same way add_column_act deserializes it.
    add_column_info = [
        (str((add_column, add_column_contents)), selected_add_column_conf)
    ]

    operation = {
        "operation_name": "add_column",
        "parameter_and_conf": add_column_info,
    }

    sample_copy = copy.deepcopy(sample)
    sample_copy["chain"].append(operation)

    return sample_copy
def add_column_act(table_info, operation, skip_op=(), debug=False):
    """Apply a planned `add_column` operation to the table.

    Appends the column proposed by `add_column_func` to
    `table_info["table_text"]` after a series of sanity checks. On any
    failure the table is returned unchanged with "skip f_add_column()"
    recorded in its act chain.

    Args:
        table_info: dict with "table_text" (header row + data rows) and
            "act_chain" (list of applied-operation strings). Not mutated;
            a deep copy is returned.
        operation: dict whose "parameter_and_conf" first entry is
            (str((column_name, column_contents)), confidence).
        skip_op: collection of operation names to skip. Default changed from
            a mutable `[]` to an immutable `()` — membership tests behave
            identically, without the shared-mutable-default pitfall.
        debug: print which heuristic rejected the candidate column.

    Returns:
        A new table_info dict with the column appended, or the unchanged
        (deep-copied) table_info with a skip marker on failure.
    """
    import ast  # local: only needed to parse the serialized parameter

    table_info = copy.deepcopy(table_info)

    # Pre-built failure result: same table, with an explicit skip marker.
    failure_table_info = copy.deepcopy(table_info)
    failure_table_info["act_chain"].append("skip f_add_column()")
    if "add_column" in skip_op:
        return failure_table_info
    if len(operation["parameter_and_conf"]) == 0:
        return failure_table_info

    # The parameter was serialized with str((name, contents));
    # ast.literal_eval restores it safely (the former bare eval would
    # execute arbitrary LLM-derived text).
    add_column_key, _ = operation["parameter_and_conf"][0]
    add_column, add_column_contents = ast.literal_eval(add_column_key)

    table_text = table_info["table_text"]
    headers = table_text[0]
    rows = table_text[1:]

    # A value count that differs from the row count cannot be applied
    # (previously this fell through to an IndexError in the substring check).
    if len(add_column_contents) != len(rows):
        if debug:
            print("length mismatch")
        return failure_table_info

    # Column name -> list of its cell values, for the rejection heuristics.
    header2contents = {}
    for i, header in enumerate(headers):
        header2contents[header] = [row[i] for row in rows]

    if add_column.startswith("number of"):
        # Derived "number of ..." columns are rejected outright.
        if debug:
            print("remove number of")
        return failure_table_info

    if len(set(add_column_contents)) == 1:
        # A constant column carries no information.
        if debug:
            print("all same")
        return failure_table_info

    for x in add_column_contents:
        if x.strip() == "":
            # Any empty cell invalidates the extraction.
            if debug:
                print("empty cell")
            return failure_table_info

    if add_column in headers:
        # Duplicate header name.
        if debug:
            print("same column header")
        return failure_table_info

    for header in header2contents:
        if add_column_contents == header2contents[header]:
            # New name but contents identical to an existing column.
            if debug:
                print("different header, same content")
            return failure_table_info

    # Require every new cell to be a substring of the corresponding cell of
    # at least one existing column: the column must be *extracted* from the
    # table, not hallucinated.
    exist_flag = False
    for header, contents in header2contents.items():
        current_column_exist_flag = True
        for i in range(len(contents)):
            if add_column_contents[i] not in contents[i]:
                current_column_exist_flag = False
                break
        if current_column_exist_flag:
            exist_flag = True
            break
    if not exist_flag:
        if debug:
            print(add_column, add_column_contents)
            print("not substring of a column")
        return failure_table_info

    if debug:
        print("default")
    # All checks passed: append the new column to the deep-copied table.
    new_headers = headers + [add_column]
    new_rows = []
    for i, row in enumerate(rows):
        row.append(add_column_contents[i])
        new_rows.append(row)

    table_info["table_text"] = [new_headers] + new_rows
    table_info["act_chain"].append(f"f_add_column({add_column})")
    return table_info
35 | col : home team | home team score | away team | away team score | venue | crowd | date 36 | row 1 : footscray | 6.6 (42) | north melbourne | 8.13 (61) | western oval | 13325 | 10 august 1957 37 | row 2 : essendon | 10.15 (75) | south melbourne | 7.13 (55) | windy hill | 16000 | 10 august 1957 38 | row 3 : st kilda | 1.5 (11) | melbourne | 6.13 (49) | junction oval | 17100 | 10 august 1957 39 | row 4 : hawthorn | 14.19 (103) | geelong | 8.7 (55) | brunswick street oval | 12000 | 10 august 1957 40 | row 5 : fitzroy | 8.14 (62) | collingwood | 8.13 (61) | glenferrie oval | 22000 | 10 august 1957 41 | */ 42 | Statement: collingwood was the away team playing at the brunswick street oval venue. 43 | The anwser is: NO 44 | 45 | /* 46 | table caption : co - operative commonwealth federation (ontario section). 47 | col : year of election | candidates elected | of seats available | of votes | % of popular vote 48 | row 1 : 1934 | 1 | 90 | na | 7.0% 49 | row 2 : 1937 | 0 | 90 | na | 5.6% 50 | row 3 : 1943 | 34 | 90 | na | 31.7% 51 | row 4 : 1945 | 8 | 90 | na | 22.4% 52 | row 5 : 1948 | 21 | 90 | na | 27.0% 53 | */ 54 | Statement: the 1937 election had a % of popular vote that was 1.4% lower than that of the 1959 election. 55 | The anwser is: NO 56 | 57 | /* 58 | table caption : 2003 pga championship. 59 | col : place | player | country | score | to par 60 | row 1 : 1 | shaun micheel | united states | 69 + 68 = 137 | - 3 61 | row 2 : t2 | billy andrade | united states | 67 + 72 = 139 | - 1 62 | row 3 : t2 | mike weir | canada | 68 + 71 = 139 | - 1 63 | row 4 : 4 | rod pampling | australia | 66 + 74 = 140 | e 64 | row 5 : t5 | chad campbell | united states | 69 + 72 = 141 | + 1 65 | */ 66 | Statement: phil mickelson was one of five players with + 1 to par , all of which had placed t5. 
def simple_query(sample, table_info, llm, debug=False, use_demo=False, llm_options=None):
    """Ask the LLM whether the statement is True (YES) or False (NO).

    Builds a verification prompt from the (possibly transformed) table,
    an optional few-shot block, and an optional group-by summary, then
    records the scored responses as a "simple_query" operation on a deep
    copy of ``sample``.

    Args:
        sample: Input sample holding "table_caption", "statement", "chain".
        table_info: Dict with "table_text" and optionally "group_sub_table".
        llm: Model wrapper; ``generate_plus_with_score`` is expected to
            return (text, log-prob) pairs — inferred from the exp() below,
            TODO confirm against the llm wrapper.
        debug: When True, print the prompt and raw responses.
        use_demo: When True, prepend the module-level few-shot demos.
        llm_options: Decoding options forwarded to the LLM.

    Returns:
        A deep copy of ``sample`` with the query operation appended.
    """
    table_text = table_info["table_text"]
    caption = sample["table_caption"]
    statement = sample["statement"]

    instruction = (
        "Here are the statement about the table and the task is to tell whether the statement is True or False.\n"
        "If the statement is true, answer YES, and otherwise answer NO.\n"
    )

    parts = [instruction]
    if use_demo:
        parts.append("\n" + general_demo + "\n\n")
        parts.append(instruction)
        parts.append("\n")

    parts.append("/*\n" + table2string(table_text, caption=caption) + "\n*/\n")

    if "group_sub_table" in table_info:
        group_column, group_info = table_info["group_sub_table"]
        parts.append("/*\n")
        parts.append(f"Group the rows according to column: {group_column}.\n")
        parts.append(" | ".join(["Group ID", group_column, "Count"]) + "\n")
        for i, (value, count) in enumerate(group_info):
            cell = "[Empty Cell]" if value.strip() == "" else value
            parts.append(" | ".join([f"Group {i+1}", cell, str(count)]) + "\n")
        parts.append("*/\n")

    parts.append("Statement: " + statement + "\n")
    parts.append("The answer is:")
    prompt = "".join(parts)

    raw_responses = llm.generate_plus_with_score(prompt, options=llm_options)
    responses = [(text.strip(), np.exp(score)) for text, score in raw_responses]

    if debug:
        print(prompt)
        print(responses)

    sample_copy = copy.deepcopy(sample)
    sample_copy["chain"].append(
        {
            "operation_name": "simple_query",
            "parameter_and_conf": responses,
        }
    )

    return sample_copy
def group_column_build_prompt(table_text, statement, table_caption=None, num_rows=100):
    """Build the suffix prompt asking the LLM which column to group.

    Args:
        table_text: Table as a list of rows; row 0 holds the headers.
        statement: The claim being verified against the table.
        table_caption: Optional caption rendered above the table.
        num_rows: Maximum number of data rows to render.

    Returns:
        A prompt string ending in "Explanation:" for the LLM to complete.
    """
    table_str = table2string(
        table_text, caption=table_caption, num_rows=num_rows
    ).strip()
    prompt = "/*\n" + table_str + "\n*/\n"
    prompt += "Statement: " + statement + "\n"
    prompt += "The existing columns are: "
    prompt += ", ".join(table_text[0]) + ".\n"
    prompt += "Explanation:"
    return prompt


def group_column_func(
    sample, table_info, llm, llm_options=None, debug=False, skip_op=[]
):
    """Ask the LLM which column to group by and record ranked candidates.

    Parses each sampled response for an ``f_group(<column>)`` call, keeps
    only answers naming an existing header, accumulates a confidence per
    column, and stores the value counts for each surviving column.

    Returns:
        A deep copy of ``sample`` with a "group_column" operation appended;
        each parameter string is ``str((column, [(value, count), ...]))``.
    """
    table_text = table_info["table_text"]

    table_caption = sample["table_caption"]
    statement = sample["statement"]
    prompt = "" + group_column_demo.rstrip() + "\n\n"
    prompt += group_column_build_prompt(
        table_text, statement, table_caption=table_caption, num_rows=5
    )
    responses = llm.generate_plus_with_score(
        prompt,
        options=llm_options,
    )

    if debug:
        print(prompt)
        print(responses)

    group_param_and_conf = {}
    group_column_and_conf = {}

    headers = table_text[0]
    rows = table_text[1:]
    for res, score in responses:
        re_result = re.findall(r"f_group\(([^\)]*)\)", res, re.S)

        if debug:
            print("Re result: ", re_result)

        try:
            group_column = re_result[0].strip()
            assert group_column in headers
        except:
            continue

        if group_column not in group_column_and_conf:
            group_column_and_conf[group_column] = 0
        # score is assumed to be a log-probability; exp() turns it into a
        # confidence weight — TODO confirm against the llm wrapper.
        group_column_and_conf[group_column] += np.exp(score)

    for group_column, conf in group_column_and_conf.items():
        group_column_contents = []
        index = headers.index(group_column)
        for row in rows:
            group_column_contents.append(row[index])

        def check_if_group(vs):
            # A column is "groupable" when at least 20% of its non-empty
            # values are duplicates.
            vs_without_empty = [v for v in vs if v.strip()]
            if not vs_without_empty:
                # Fix: an all-empty column previously raised
                # ZeroDivisionError here; treat it as not groupable.
                return False
            return len(set(vs_without_empty)) / len(vs_without_empty) <= 0.8

        if not check_if_group(group_column_contents):
            continue

        group_info = []
        for v in sorted(set(group_column_contents)):
            group_info.append((v, group_column_contents.count(v)))
        group_info = sorted(group_info, key=lambda x: x[1], reverse=True)

        group_key = str((group_column, group_info))
        group_param_and_conf[group_key] = conf

    group_param_and_conf_list = sorted(
        group_param_and_conf.items(), key=lambda x: x[1], reverse=True
    )

    operation = {
        "operation_name": "group_column",
        "parameter_and_conf": group_param_and_conf_list,
    }

    sample_copy = copy.deepcopy(sample)
    sample_copy["chain"].append(operation)

    return sample_copy


def group_column_act(table_info, operation, strategy="top", skip_op=[]):
    """Apply a previously planned group-by to ``table_info``.

    Args:
        table_info: Dict holding "table_text" and the "act_chain" log.
        operation: The "group_column" entry produced by group_column_func.
        strategy: Only "top" (highest-confidence parameter) is implemented.
        skip_op: Operation names to skip entirely.

    Returns:
        A new table_info with "group_sub_table" attached, or an unchanged
        copy whose act_chain records the skip.
    """
    table_info = copy.deepcopy(table_info)

    failure_table_info = copy.deepcopy(table_info)
    failure_table_info["act_chain"].append("skip f_group_column()")

    if "group_column" in skip_op:
        return failure_table_info
    if len(operation["parameter_and_conf"]) == 0:
        return failure_table_info
    if strategy == "top":
        # NOTE: eval() of a string this module generated itself above; do
        # not feed untrusted parameter strings through this path.
        group_column, group_info = eval(operation["parameter_and_conf"][0][0])
    else:
        raise NotImplementedError()

    table_info["group_sub_table"] = (group_column, group_info)
    table_info["act_chain"].append(f"f_group_column({group_column})")

    return table_info
def twoD_list_transpose(arr, keep_num_rows=3):
    """Transpose a 2D list, first truncating it to the header row plus
    ``keep_num_rows`` data rows (when the table is that long)."""
    limit = keep_num_rows + 1
    if limit <= len(arr):
        arr = arr[:limit]
    n_rows, n_cols = len(arr), len(arr[0])
    return [[arr[r][c] for r in range(n_rows)] for c in range(n_cols)]


def select_column_build_prompt(table_text, statement, table_caption=None, num_rows=100):
    """Render the table as a JSON blob (column-major preview) followed by
    the statement, ready for the select-column few-shot prompt."""
    df = table2df(table_text, num_rows=num_rows)
    columns = list(df.columns)
    list_table = [columns] + df.values.tolist()
    column_major = twoD_list_transpose(list_table, len(list_table))

    dic = {}
    if table_caption is not None:
        dic["table_caption"] = table_caption
    dic["columns"] = NoIndent(columns)
    dic["table_column_priority"] = [NoIndent(col) for col in column_major]

    linear_dic = json.dumps(
        dic, cls=MyEncoder, ensure_ascii=False, sort_keys=False, indent=2
    )
    return (
        "/*\ntable = " + linear_dic + "\n*/\n"
        + "statement : " + statement + ".\n"
        + "similar words link to columns :\n"
    )


def select_column_func(sample, table_info, llm, llm_options, debug=False, num_rows=100):
    """Sample select-column predictions from the LLM and rank candidate
    column sets by accumulated confidence.

    Returns a deep copy of ``sample`` with a "select_column" operation
    appended; each parameter string is ``str(sorted([col, ...]))``.
    """
    table_text = table_info["table_text"]

    prompt = select_column_demo.rstrip() + "\n\n"
    prompt += select_column_build_prompt(
        table_text, sample["statement"], sample["table_caption"], num_rows=num_rows
    )

    responses = llm.generate_plus_with_score(prompt, options=llm_options)

    if debug:
        print(prompt)
        print(responses)

    column_list_pattern = r"f_col\(\[(.*?)\]\)"
    conf_by_prediction = {}
    for text, log_prob in responses:
        hits = re.findall(column_list_pattern, text, re.S)
        if not hits:
            continue
        names = [piece.strip() for piece in hits[0].strip().split(", ")]
        key = str(sorted(names))
        conf_by_prediction[key] = conf_by_prediction.get(key, 0) + np.exp(log_prob)

    select_col_rank = sorted(
        conf_by_prediction.items(), key=lambda item: item[1], reverse=True
    )

    sample_copy = copy.deepcopy(sample)
    sample_copy["chain"].append(
        {
            "operation_name": "select_column",
            "parameter_and_conf": select_col_rank,
        }
    )
    return sample_copy


def select_column_act(table_info, operation, union_num=2, skip_op=[]):
    """Keep only the columns named by the union of the top ``union_num``
    predictions (matched case-insensitively against the header row); fall
    back to the full table when nothing matches."""
    table_info = copy.deepcopy(table_info)

    if "select_column" in skip_op:
        skipped = copy.deepcopy(table_info)
        skipped["act_chain"].append("skip f_select_column()")
        return skipped

    def transpose(arr):
        return [[arr[r][c] for r in range(len(arr))] for c in range(len(arr[0]))]

    ranked = sorted(
        operation["parameter_and_conf"], key=lambda item: item[1], reverse=True
    )
    chosen = set()
    for encoded, _ in ranked[:union_num]:
        # NOTE: eval() of an LLM-derived string — trusted here because
        # select_column_func built it, but keep it away from external input.
        chosen.update(eval(encoded))

    kept_columns = []
    kept_names = []
    for column in transpose(table_info["table_text"]):
        if column[0].lower() in chosen:
            kept_names.append(column[0])
            kept_columns.append(copy.deepcopy(column))

    if not kept_columns:
        kept_columns = transpose(table_info["table_text"])
        kept_names = ["*"]

    table_info["table_text"] = transpose(kept_columns)
    table_info["act_chain"].append(
        f"f_select_column({', '.join(kept_names)})"
    )
    return table_info
def select_row_build_prompt(table_text, statement, table_caption=None, num_rows=100):
    """Render the table plus the statement for the select-row prompt,
    ending in "explain :" for the LLM to complete."""
    table_str = table2string(table_text, caption=table_caption).strip()
    return (
        "/*\n" + table_str + "\n*/\n"
        + "statement : " + statement + "\n"
        + "explain :"
    )


def select_row_func(sample, table_info, llm, llm_options=None, debug=False):
    """Sample select-row predictions from the LLM and rank candidate row-id
    sets by accumulated confidence.

    Each prediction is parsed from ``f_row([row 1, row 3, ...])``; only the
    trailing token of each entry (the 1-based row id, or "*") is kept.
    """
    prompt = select_row_demo.rstrip() + "\n\n"
    prompt += select_row_build_prompt(
        table_info["table_text"], sample["statement"], sample["table_caption"]
    )

    responses = llm.generate_plus_with_score(prompt, options=llm_options)

    if debug:
        print(responses)

    row_list_pattern = r"f_row\(\[(.*?)\]\)"
    conf_by_prediction = {}
    for text, log_prob in responses:
        hits = re.findall(row_list_pattern, text, re.S)
        if not hits:
            continue
        ids = [token.strip().split(" ")[-1] for token in hits[0].strip().split(", ")]
        key = str(sorted(ids))
        conf_by_prediction[key] = conf_by_prediction.get(key, 0) + np.exp(log_prob)

    select_row_rank = sorted(
        conf_by_prediction.items(), key=lambda item: item[1], reverse=True
    )

    sample_copy = copy.deepcopy(sample)
    sample_copy["chain"].append(
        {
            "operation_name": "select_row",
            "parameter_and_conf": select_row_rank,
        }
    )
    return sample_copy


def select_row_act(table_info, operation, union_num=2, skip_op=[]):
    """Keep only the rows named by the union of the top ``union_num``
    predictions.

    Falls back (recording "f_select_row(*)") when the model selected "*"
    or when no row id matched.  Kept rows are renumbered sequentially in
    the act_chain; "_real_select_rows" records the original 0-based ids.
    """
    table_info = copy.deepcopy(table_info)

    def fail(label):
        failure = copy.deepcopy(table_info)
        failure["act_chain"].append(label)
        return failure

    if "select_row" in skip_op:
        return fail("skip f_select_row()")

    ranked = sorted(
        operation["parameter_and_conf"], key=lambda item: item[1], reverse=True
    )
    chosen = set()
    for encoded, _ in ranked[:union_num]:
        # NOTE: eval() of an LLM-derived string built by select_row_func.
        chosen.update(eval(encoded))

    if "*" in chosen:
        return fail("f_select_row(*)")

    table_text = table_info["table_text"]
    kept_rows = [copy.deepcopy(table_text[0])]
    kept_ids = []
    # row_id 0 is the header row, so 1-based ids line up with data rows.
    for row_id, row in enumerate(table_text):
        if str(row_id) in chosen:
            kept_rows.append(copy.deepcopy(row))
            kept_ids.append(str(row_id))

    if len(kept_rows) == 1:
        return fail("f_select_row(*)")

    table_info["table_text"] = kept_rows
    renumbered = [f"row {i+1}" for i in range(len(kept_ids))]
    table_info["act_chain"].append(f"f_select_row({', '.join(renumbered)})")

    zero_based = [f"row {int(rid) - 1}" for rid in kept_ids]
    table_info["_real_select_rows"] = f"f_select_row({', '.join(zero_based)})"

    return table_info
License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import copy 17 | import re 18 | import numpy as np 19 | from utils.helper import table2string 20 | 21 | 22 | sort_column_demo = """To tell the statement is true or false, we can first use f_sort() to sort the values in a column to get the order of the items. The order can be "large to small" or "small to large". 23 | 24 | The column to sort should have these data types: 25 | 1. Numerical: the numerical strings that can be used in sort 26 | 2. DateType: the strings that describe a date, such as year, month, day 27 | 3. String: other strings 28 | 29 | /* 30 | col : position | club | played | points | wins | draws | losses | goals for | goals against | goal difference 31 | row 1 : 1 | malaga cf | 42 | 79 | 22 | 13 | 7 | 72 | 47 | +25 32 | row 10 : 10 | cp merida | 42 | 59 | 15 | 14 | 13 | 48 | 41 | +7 33 | row 3 : 3 | cd numancia | 42 | 73 | 21 | 10 | 11 | 68 | 40 | +28 34 | */ 35 | Statement: cd numancia placed in the last position 36 | The existing columns are: position, club, played, points, wins, draws, losses, goals for, goals against, goal difference. 37 | Explanation: the statement wants to check cd numanica is in the last position. Each row is about a club. We need to know the order of position from last to front. There is a column for position and the column name is position. The datatype is Numerical. 38 | Therefore, the answer is: f_sort(position), the order is "large to small". 
def only_keep_num_and_first_dot(s):
    """Reduce a raw cell string to a parseable number string.

    Keeps ASCII digits and only the first '.'; re-attaches a leading '-'
    when the stripped input starts with one.  Returns "" when nothing
    numeric survives (so callers can treat the cell as unsortable).
    """
    stripped = s.strip()
    is_negative = stripped.startswith("-")

    kept = []
    seen_dot = False
    for ch in s:
        if ch in "0123456789":
            kept.append(ch)
        elif ch == "." and not seen_dot:
            kept.append(ch)
            seen_dot = True

    digits = "".join(kept)
    if digits in ("", "."):
        return ""
    return "-" + digits if is_negative else digits


def sort_column_build_prompt(table_text, statement, table_caption=None, num_rows=100):
    """Build the suffix prompt asking the LLM which column to sort,
    ending in "Explanation:" for the model to complete."""
    table_str = table2string(
        table_text, caption=table_caption, num_rows=num_rows
    ).strip()
    return (
        "/*\n" + table_str + "\n*/\n"
        + "Statement: " + statement + "\n"
        + "The existing columns are: " + ", ".join(table_text[0]) + ".\n"
        + "Explanation:"
    )
def sort_column_func(
    sample, table_info, llm, llm_options=None, debug=False, skip_op=[]
):
    """Plan an f_sort operation: ask the LLM which column/order/datatype to
    sort by, then pre-compute the resulting row permutation per candidate.

    Each ranked parameter tuple is
    ``(column, order, datatype, index_order, max_value, min_value, conf)``.

    Returns:
        A deep copy of ``sample`` with a "sort_column" operation appended.
    """
    table_text = table_info["table_text"]

    statement = sample["statement"]
    prompt = "" + sort_column_demo.rstrip() + "\n\n"
    prompt += sort_column_build_prompt(table_text, statement, num_rows=3)
    responses = llm.generate_plus_with_score(
        prompt,
        options=llm_options,
    )

    if debug:
        print(prompt)
        print(responses)

    sort_info_and_conf = {}

    headers = table_text[0]
    rows = table_text[1:]
    for res, score in responses:
        try:
            datatype = re.findall(r"The datatype is (\w*).", res, re.S)[0].strip()
            sort_order = re.findall(r'the order is "(.*)"\.', res, re.S)[0].strip()
            sort_column = re.findall(r"f_sort\((.*)\)", res, re.S)[0].strip()
        except:
            continue

        if sort_order not in ["small to large", "large to small"]:
            continue
        if sort_column not in headers:
            continue
        sort_key = (sort_column, sort_order, datatype)
        if sort_key not in sort_info_and_conf:
            sort_info_and_conf[sort_key] = 0
        # score is assumed to be a log-probability; exp() gives a weight —
        # TODO confirm against the llm wrapper.
        sort_info_and_conf[sort_key] += np.exp(score)

    sort_param_and_conf_list = []
    for (sort_column, sort_order, datatype), conf in sort_info_and_conf.items():
        sort_column_contents = []
        index = headers.index(sort_column)
        for row in rows:
            sort_column_contents.append(row[index])

        # Split cells into sortable values and unsortable leftovers
        # (non-numeric text for Numerical columns, empty cells otherwise).
        vs_to_sort = []
        vs_not_to_sort = []
        if datatype == "Numerical":
            for i in range(len(sort_column_contents)):
                v_str = sort_column_contents[i]
                v_str = only_keep_num_and_first_dot(v_str)
                if v_str == "" or v_str == ".":
                    vs_not_to_sort.append((sort_column_contents[i], i))
                else:
                    vs_to_sort.append((float(v_str), i))
        else:
            for i in range(len(sort_column_contents)):
                v_str = sort_column_contents[i]
                v_str = v_str.strip()
                if v_str == "":
                    vs_not_to_sort.append((sort_column_contents[i], i))
                else:
                    vs_to_sort.append((v_str, i))

        # Skip columns already ordered either way (this also skips empty
        # ones, which protects the max()/min() calls below).
        pure_vs_to_sort = [x[0] for x in vs_to_sort]
        if (
            sorted(pure_vs_to_sort) == pure_vs_to_sort
            or sorted(pure_vs_to_sort, reverse=True) == pure_vs_to_sort
        ):
            continue

        # Row permutation: sortable rows first, leftovers appended last.
        if sort_order == "small to large":
            vs_to_sort = sorted(vs_to_sort, key=lambda x: x[0])
        else:
            vs_to_sort = sorted(vs_to_sort, reverse=True, key=lambda x: x[0])
        index_order = [x[1] for x in vs_to_sort] + [x[1] for x in vs_not_to_sort]

        sort_param_and_conf_list.append(
            (
                sort_column,
                sort_order,
                datatype,
                index_order,
                max([x[0] for x in vs_to_sort]),
                min([x[0] for x in vs_to_sort]),
                conf,
            )
        )

    # Fix: rank candidates by confidence in DESCENDING order, matching the
    # other operations (group_by, select_column, select_row all use
    # reverse=True); the previous ascending sort made the "top" strategy
    # in sort_column_act pick the LEAST confident candidate.
    sort_param_and_conf_list = sorted(
        sort_param_and_conf_list, key=lambda x: x[-1], reverse=True
    )

    operation = {
        "operation_name": "sort_column",
        "parameter_and_conf": sort_param_and_conf_list,
    }

    sample_copy = copy.deepcopy(sample)
    sample_copy["chain"].append(operation)

    if debug:
        print(sort_param_and_conf_list)

    return sample_copy


def sort_column_act(
    table_info, operation, strategy="top", filter="Only Numerical", skip_op=[]
):
    """Apply the top-ranked f_sort plan to the table.

    Args:
        table_info: Dict with "table_text" and the "act_chain" log.
        operation: "sort_column" entry produced by sort_column_func.
        strategy: Only "top" (first ranked candidate) is implemented.
        filter: Only "Only Numerical" is implemented — non-numerical sort
            plans are skipped.
        skip_op: Operation names to skip entirely.

    Returns:
        A new table_info with rows permuted by the pre-computed index
        order and "sort_sub_table" set to (column, max, min); on any skip
        condition, an unchanged copy whose act_chain records the skip.
    """
    table_info = copy.deepcopy(table_info)

    failure_table_info = copy.deepcopy(table_info)
    failure_table_info["act_chain"].append("skip f_sort_column()")

    if "sort_column" in skip_op:
        return failure_table_info
    if len(operation["parameter_and_conf"]) == 0:
        return failure_table_info

    if strategy == "top":
        sort_column, sort_order, datatype, index_order, max_v, min_v = operation[
            "parameter_and_conf"
        ][0][:-1]
    else:
        raise NotImplementedError()

    if filter == "Only Numerical":
        if datatype != "Numerical":
            return failure_table_info
    else:
        raise NotImplementedError()

    table_text = table_info["table_text"]
    headers = table_text[0]
    rows = table_text[1:]
    new_rows = [rows[i] for i in index_order]
    new_table_text = [headers] + new_rows

    table_info["table_text"] = new_table_text
    table_info["sort_sub_table"] = (sort_column, max_v, min_v)
    table_info["act_chain"].append(f"f_sort_column({sort_column})")

    return table_info
20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "metadata": {}, 25 | "source": [ 26 | "# Demo of Chain of Tables\n", 27 | "\n", 28 | "Paper: https://arxiv.org/abs/2401.04398" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 1, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "import pandas as pd\n", 38 | "\n", 39 | "from utils.load_data import wrap_input_for_demo\n", 40 | "from utils.llm import ChatGPT\n", 41 | "from utils.helper import *\n", 42 | "from utils.evaluate import *\n", 43 | "from utils.chain import *\n", 44 | "from operations import *" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 2, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "# User parameters\n", 54 | "model_name: str = \"gpt-3.5-turbo-0613\"\n", 55 | "openai_api_key: str = None\n", 56 | "\n", 57 | "# Default parameters\n", 58 | "dataset_path: str = \"data/tabfact/test.jsonl\"\n", 59 | "raw2clean_path: str = \"data/tabfact/raw2clean.jsonl\"" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 3, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "gpt_llm = ChatGPT(\n", 69 | " model_name=model_name,\n", 70 | " key=openai_api_key,\n", 71 | ")" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 4, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "statement = \"the wildcats kept the opposing team scoreless in four games\"\n", 81 | "table_caption = \"1947 kentucky wildcats football team\"\n", 82 | "table_text = [\n", 83 | " ['game', 'date', 'opponent', 'result', 'wildcats points', 'opponents', 'record'],\n", 84 | " ['1', 'sept 20', 'ole miss', 'loss', '7', '14', '0 - 1'],\n", 85 | " ['2', 'sept 27', 'cincinnati', 'win', '20', '0', '1 - 1'],\n", 86 | " ['3', 'oct 4', 'xavier', 'win', '20', '7', '2 - 1'],\n", 87 | " ['4', 'oct 11', '9 georgia', 'win', '26', '0', '3 - 1 , 20'],\n", 88 | " ['5', 'oct 18', '10 vanderbilt', 'win', '14', '0', '4 - 1 , 
14'],\n", 89 | " ['6', 'oct 25', 'michigan state', 'win', '7', '6', '5 - 1 , 13'],\n", 90 | " ['7', 'nov 1', '18 alabama', 'loss', '0', '13', '5 - 2'],\n", 91 | " ['8', 'nov 8', 'west virginia', 'win', '15', '6', '6 - 2'],\n", 92 | " ['9', 'nov 15', 'evansville', 'win', '36', '0', '7 - 2'],\n", 93 | " ['10', 'nov 22', 'tennessee', 'loss', '6', '13', '7 - 3']\n", 94 | "]\n", 95 | "answer = \"True\"" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 5, 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "demo_sample = wrap_input_for_demo(\n", 105 | " statement=statement, table_caption=table_caption, table_text=table_text\n", 106 | ")\n", 107 | "proc_sample, dynamic_chain_log = dynamic_chain_exec_one_sample(\n", 108 | " sample=demo_sample, llm=gpt_llm\n", 109 | ")\n", 110 | "output_sample = simple_query(\n", 111 | " sample=proc_sample,\n", 112 | " table_info=get_table_info(proc_sample),\n", 113 | " llm=gpt_llm,\n", 114 | " use_demo=True,\n", 115 | " llm_options=gpt_llm.get_model_options(\n", 116 | " temperature=0.0, per_example_max_decode_steps=200, per_example_top_p=1.0\n", 117 | " ),\n", 118 | ")\n", 119 | "cotable_log = get_table_log(output_sample)" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": 9, 125 | "metadata": {}, 126 | "outputs": [ 127 | { 128 | "name": "stdout", 129 | "output_type": "stream", 130 | "text": [ 131 | "Statements: the wildcats kept the opposing team scoreless in four games\n", 132 | "\n", 133 | "Table: 1947 kentucky wildcats football team\n", 134 | " game result wildcats points opponents\n", 135 | "0 2 win 20 0\n", 136 | "1 4 win 26 0\n", 137 | "2 5 win 14 0\n", 138 | "3 9 win 36 0\n", 139 | "4 3 win 20 7\n", 140 | "\n", 141 | "-> f_select_row(row 1, row 2, row 3, row 4, row 8)\n", 142 | " game date opponent result wildcats points opponents record\n", 143 | "0 2 sept 27 cincinnati win 20 0 1 - 1\n", 144 | "1 3 oct 4 xavier win 20 7 2 - 1\n", 145 | "2 4 oct 11 9 georgia win 26 0 
3 - 1 , 20\n", 146 | "3 5 oct 18 10 vanderbilt win 14 0 4 - 1 , 14\n", 147 | "4 9 nov 15 evansville win 36 0 7 - 2\n", 148 | "\n", 149 | "-> f_select_column(game, result, wildcats points, opponents)\n", 150 | " game result wildcats points opponents\n", 151 | "0 2 win 20 0\n", 152 | "1 3 win 20 7\n", 153 | "2 4 win 26 0\n", 154 | "3 5 win 14 0\n", 155 | "4 9 win 36 0\n", 156 | "\n", 157 | "-> f_group_column(opponents)\n", 158 | " game result wildcats points opponents\n", 159 | "0 2 win 20 0\n", 160 | "1 3 win 20 7\n", 161 | "2 4 win 26 0\n", 162 | "3 5 win 14 0\n", 163 | "4 9 win 36 0\n", 164 | " Group ID opponents Count\n", 165 | "0 Group 1 0 4\n", 166 | "1 Group 2 7 1\n", 167 | "\n", 168 | "-> f_sort_column(opponents)\n", 169 | " game result wildcats points opponents\n", 170 | "0 2 win 20 0\n", 171 | "1 4 win 26 0\n", 172 | "2 5 win 14 0\n", 173 | "3 9 win 36 0\n", 174 | "4 3 win 20 7\n", 175 | " Group ID opponents Count\n", 176 | "0 Group 1 0 4\n", 177 | "1 Group 2 7 1\n", 178 | "\n", 179 | "-> simple_query()\n", 180 | "The statement is True\n", 181 | "\n", 182 | "Groundtruth: The statement is True\n" 183 | ] 184 | } 185 | ], 186 | "source": [ 187 | "print(f'Statements: {output_sample[\"statement\"]}\\n')\n", 188 | "print(f'Table: {output_sample[\"table_caption\"]}')\n", 189 | "print(f\"{pd.DataFrame(table_text[1:], columns=table_text[0])}\\n\")\n", 190 | "for table_info in cotable_log:\n", 191 | " if table_info[\"act_chain\"]:\n", 192 | " table_text = table_info[\"table_text\"]\n", 193 | " table_action = table_info[\"act_chain\"][-1]\n", 194 | " if \"skip\" in table_action:\n", 195 | " continue\n", 196 | " if \"query\" in table_action:\n", 197 | " result = table_info[\"cotable_result\"]\n", 198 | " if result == \"YES\":\n", 199 | " print(f\"-> {table_action}\\nThe statement is True\\n\")\n", 200 | " else:\n", 201 | " print(f\"-> {table_action}\\nThe statement is False\\n\")\n", 202 | " else:\n", 203 | " print(f\"-> {table_action}\\n{pd.DataFrame(table_text[1:], 
columns=table_text[0])}\")\n", 204 | " if 'group_sub_table' in table_info:\n", 205 | " group_column, group_info = table_info[\"group_sub_table\"]\n", 206 | " group_headers = [\"Group ID\", group_column, \"Count\"]\n", 207 | " group_rows = []\n", 208 | " for i, (v, count) in enumerate(group_info):\n", 209 | " if v.strip() == \"\":\n", 210 | " v = \"[Empty Cell]\"\n", 211 | " group_rows.append([f\"Group {i+1}\", v, str(count)])\n", 212 | " print(f\"{pd.DataFrame(group_rows, columns=group_headers)}\")\n", 213 | " print()\n", 214 | "\n", 215 | "print(f\"Groundtruth: The statement is {answer}\")" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": null, 221 | "metadata": {}, 222 | "outputs": [], 223 | "source": [] 224 | } 225 | ], 226 | "metadata": { 227 | "kernelspec": { 228 | "display_name": "cotable", 229 | "language": "python", 230 | "name": "python3" 231 | }, 232 | "language_info": { 233 | "codemirror_mode": { 234 | "name": "ipython", 235 | "version": 3 236 | }, 237 | "file_extension": ".py", 238 | "mimetype": "text/x-python", 239 | "name": "python", 240 | "nbconvert_exporter": "python", 241 | "pygments_lexer": "ipython3", 242 | "version": "3.10.13" 243 | } 244 | }, 245 | "nbformat": 4, 246 | "nbformat_minor": 2 247 | } 248 | -------------------------------------------------------------------------------- /run_tabfact.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Chain-of-Table authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import fire 17 | import os 18 | 19 | from utils.load_data import load_tabfact_dataset 20 | from utils.llm import ChatGPT 21 | from utils.helper import * 22 | from utils.evaluate import * 23 | from utils.chain import * 24 | from operations import * 25 | 26 | 27 | def main( 28 | dataset_path: str = "data/tabfact/test.jsonl", 29 | raw2clean_path: str ="data/tabfact/raw2clean.jsonl", 30 | model_name: str = "gpt-3.5-turbo-16k-0613", 31 | result_dir: str = "results/tabfact", 32 | openai_api_key: str = None, 33 | first_n=-1, 34 | n_proc=1, 35 | chunk_size=1, 36 | ): 37 | dataset = load_tabfact_dataset(dataset_path, raw2clean_path, first_n=first_n) 38 | gpt_llm = ChatGPT( 39 | model_name=model_name, 40 | key=os.environ["OPENAI_API_KEY"] if openai_api_key is None else openai_api_key, 41 | ) 42 | os.makedirs(result_dir, exist_ok=True) 43 | 44 | proc_samples, dynamic_chain_log_list = dynamic_chain_exec_with_cache_mp( 45 | dataset, 46 | llm=gpt_llm, 47 | llm_options=gpt_llm.get_model_options( 48 | temperature=0.0, per_example_max_decode_steps=200, per_example_top_p=1.0 49 | ), 50 | strategy="top", 51 | cache_dir=os.path.join(result_dir, "cache"), 52 | n_proc=n_proc, 53 | chunk_size=chunk_size, 54 | ) 55 | fixed_chain = [ 56 | ( 57 | "simpleQuery_fewshot", 58 | simple_query, 59 | dict(use_demo=True), 60 | dict( 61 | temperature=0, per_example_max_decode_steps=200, per_example_top_p=1.0 62 | ), 63 | ), 64 | ] 65 | final_result, _ = fixed_chain_exec_mp(gpt_llm, proc_samples, fixed_chain) 66 | acc = tabfact_match_func_for_samples(final_result) 67 | print("Accuracy:", acc) 68 | 69 | print( 70 | f'Accuracy: {acc}', 71 | file=open(os.path.join(result_dir, "result.txt"), "w") 72 | ) 73 | pickle.dump( 74 | final_result, open(os.path.join(result_dir, "final_result.pkl"), "wb") 75 | ) 76 | pickle.dump( 77 | dynamic_chain_log_list, 78 | 
open(os.path.join(result_dir, "dynamic_chain_log_list.pkl"), "wb") 79 | ) 80 | 81 | 82 | if __name__ == "__main__": 83 | fire.Fire(main) 84 | -------------------------------------------------------------------------------- /third_party/select_column_row_prompts/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Alibaba Research 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. -------------------------------------------------------------------------------- /third_party/select_column_row_prompts/select_column_row_prompts.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2022 Alibaba Research 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 
14 | 15 | 16 | select_column_demo = """Use f_col() api to filter out useless columns in the table according to informations in the statement and the table. 17 | 18 | /* 19 | { 20 | "table_caption": "south wales derby", 21 | "columns": ["competition", "total matches", "cardiff win", "draw", "swansea win"], 22 | "table_column_priority": [ 23 | ["competition", "league", "fa cup", "league cup"], 24 | ["total matches", "55", "2", "5"], 25 | ["cardiff win", "19", "0", "2"], 26 | ["draw", "16", "27", "0"], 27 | ["swansea win", "20", "2", "3"] 28 | ] 29 | } 30 | */ 31 | statement : there are no cardiff wins that have a draw greater than 27. 32 | similar words link to columns : 33 | no cardiff wins -> cardiff win 34 | a draw -> draw 35 | column value link to columns : 36 | 27 -> draw 37 | semantic sentence link to columns : 38 | None 39 | The answer is : f_col([cardiff win, draw]) 40 | 41 | /* 42 | { 43 | "table_caption": "gambrinus liga", 44 | "columns": ["season", "champions", "runner - up", "third place", "top goalscorer", "club"], 45 | "table_column_priority": [ 46 | ["season", "1993 - 94", "1994 - 95", "1995 - 96"], 47 | ["champions", "sparta prague (1)", "sparta prague (2)", "slavia prague (1)"], 48 | ["runner - up", "slavia prague", "slavia prague", "sigma olomouc"], 49 | ["third place", "ban\u00edk ostrava", "fc brno", "baumit jablonec"], 50 | ["top goalscorer", "horst siegl (20)", "radek drulák (15)", "radek drulák (22)"], 51 | ["club", "sparta prague", "drnovice", "drnovice"] 52 | ] 53 | } 54 | */ 55 | statement : the top goal scorer for the season 2010 - 2011 was david lafata. 56 | similar words link to columns : 57 | season 2010 - 2011 -> season 58 | the top goal scorer -> top goalscorer 59 | column value link to columns : 60 | 2010 - 2011 -> season 61 | semantic sentence link to columns : 62 | the top goal scorer for ... 
was david lafata -> top goalscorer 63 | The answer is : f_col([season, top goalscorer]) 64 | 65 | /* 66 | { 67 | "table_caption": "head of the river (queensland)", 68 | "columns": ["crew", "open 1st viii", "senior 2nd viii", "senior 3rd viii", "senior iv", "year 12 single scull", "year 11 single scull"], 69 | "table_column_priority": [ 70 | ["crew", "2009", "2010", "2011"], 71 | ["open 1st viii", "stm", "splc", "stm"], 72 | ["senior 2nd viii", "sta", "som", "stu"], 73 | ["senior 3rd viii", "sta", "som", "stu"], 74 | ["senior iv", "som", "sth", "sta"], 75 | ["year 12 single scull", "stm", "splc", "stm"], 76 | ["year 11 single scull", "splc", "splc", "splc"] 77 | ] 78 | } 79 | */ 80 | statement : the crew that had a senior 2nd viii of som and senior iv of stm was that of 2013. 81 | similar words link to columns : 82 | the crew -> crew 83 | a senior 2nd viii of som -> senior 2nd viii 84 | senior iv of stm -> senior iv 85 | column value link to columns : 86 | som -> senior 2nd viii 87 | stm -> senior iv 88 | semantic sentence link to columns : 89 | None 90 | The answer is : f_col([crew, senior 2nd viii, senior iv]) 91 | 92 | /* 93 | { 94 | "table_caption": "2007 - 08 boston celtics season", 95 | "columns": ["game", "date", "team", "score", "high points", "high rebounds", "high assists", "location attendance", "record"], 96 | "table_column_priority": [ 97 | ["game", "74", "75", "76"], 98 | ["date", "april 1", "april 2", "april 5"], 99 | ["team", "chicago", "indiana", "charlotte"], 100 | ["score", "106 - 92", "92 - 77", "101 - 78"], 101 | ["high points", "allen (22)", "garnett (20)", "powe (22)"], 102 | ["high rebounds", "perkins (9)", "garnett (11)", "powe (9)"], 103 | ["high assists", "rondo (10)", "rondo (6)", "rondo (5)"], 104 | ["location attendance", "united center 22225", "td banknorth garden 18624", "charlotte bobcats arena 19403"], 105 | ["record", "59 - 15", "60 - 15", "61 - 15"] 106 | ] 107 | } 108 | */ 109 | statement : in game 74 against chicago , perkins 
had the most rebounds (9) and allen had the most points (22). 110 | similar words link to columns : 111 | the most rebounds -> high rebounds 112 | the most points -> high points 113 | in game 74 -> game 114 | column value link to columns : 115 | 74 -> game 116 | semantic sentence link to columns : 117 | 2007 - 08 boston celtics season in game 74 against chicago -> team 118 | perkins had the most rebounds -> high rebounds 119 | allen had the most points -> high points 120 | The answer is : f_col([game, team, high points, high rebounds]) 121 | 122 | /* 123 | { 124 | "table_caption": "dan hardy", 125 | "columns": ["res", "record", "opponent", "method", "event", "round", "time", "location"], 126 | "table_column_priority": [ 127 | ["res", "win", "win", "loss"], 128 | ["record", "25 - 10 (1)", "24 - 10 (1)", "23 - 10 (1)"], 129 | ["opponent", "amir sadollah", "duane ludwig", "chris lytle"], 130 | ["method", "decision (unanimous)", "ko (punch and elbows)", "submission (guillotine choke)"], 131 | ["event", "ufc on fuel tv : struve vs miocic", "ufc 146", "ufc live : hardy vs lytle"], 132 | ["round", "3", "1", "5"], 133 | ["time", "5:00", "3:51", "4:16"], 134 | ["location", "nottingham , england", "las vegas , nevada , united states", "milwaukee , wisconsin , united states"] 135 | ] 136 | } 137 | */ 138 | statement : the record of the match was a 10 - 3 (1) score , resulting in a win in round 5 with a time of 5:00 minutes. 
139 | similar words link to columns : 140 | the record of the match was a 10 - 3 (1) score -> record 141 | the record -> record 142 | in round -> round 143 | a time -> time 144 | column value link to columns : 145 | 10 - 3 (1) -> record 146 | 5 -> round 147 | 5:00 minutes -> time 148 | semantic sentence link to columns : 149 | resulting in a win -> res 150 | The answer is : f_col([res, record, round, time]) 151 | 152 | /* 153 | { 154 | "table_caption": "list of largest airlines in central america & the caribbean", 155 | "columns": ["rank", "airline", "country", "fleet size", "remarks"], 156 | "table_column_priority": [ 157 | ["rank", "1", "2", "3"], 158 | ["airline", "caribbean airlines", "liat", "cubana de aviaci\u00e3 cubicn"], 159 | ["country", "trinidad and tobago", "antigua and barbuda", "cuba"], 160 | ["fleet size", "22", "17", "14"], 161 | ["remarks", "largest airline in the caribbean", "second largest airline in the caribbean", "operational since 1929"] 162 | ] 163 | } 164 | */ 165 | statement : the remark on airline of dutch antilles express with fleet size over 4 is curacao second national carrier. 
166 | similar words link to columns : 167 | the remark -> remarks 168 | on airline -> airline 169 | fleet size -> fleet size 170 | column value link to columns : 171 | dutch antilles -> country 172 | 4 -> fleet size 173 | curacao second national carrier -> remarks 174 | semantic sentence link to columns : 175 | None 176 | The answer is : f_col([airline, fleet size, remarks]) 177 | 178 | /* 179 | { 180 | "table_caption": "cnbc prime 's the profit 200", 181 | "columns": ["year", "date", "driver", "team", "manufacturer", "laps", "-", "race time", "average speed (mph)"], 182 | "table_column_priority": [ 183 | ["year", "1990", "1990", "1991"], 184 | ["date", "july 15", "october 14", "july 14"], 185 | ["driver", "tommy ellis", "rick mast", "kenny wallace"], 186 | ["team", "john jackson", "ag dillard motorsports", "rusty wallace racing"], 187 | ["manufacturer", "buick", "buick", "pontiac"], 188 | ["laps", "300", "250", "300"], 189 | ["-", "317.4 (510.805)", "264.5 (425.671)", "317.4 (510.805)"], 190 | ["race time", "3:41:58", "2:44:37", "2:54:38"], 191 | ["average speed (mph)", "85.797", "94.405", "109.093"] 192 | ] 193 | } 194 | */ 195 | statement : on june 26th , 2010 kyle busch drove a total of 211.6 miles at an average speed of 110.673 miles per hour. 
196 | similar words link to columns : 197 | drove -> driver 198 | column value link to columns : 199 | june 26th , 2010 -> date, year 200 | a total of 211.6 miles -> - 201 | semantic sentence link to columns : 202 | kyle busch drove -> driver 203 | an average speed of 110.673 miles per hour -> average speed (mph) 204 | The answer is : f_col([year, date, driver, -, average speed (mph)]) 205 | 206 | /* 207 | { 208 | "table_caption": "2000 ansett australia cup", 209 | "columns": ["home team", "home team score", "away team", "away team score", "ground", "crowd", "date"], 210 | "table_column_priority": [ 211 | ["home team", "brisbane lions", "kangaroos", "richmond"], 212 | ["home team score", "13.6 (84)", "10.16 (76)", "11.16 (82)"], 213 | ["away team", "sydney", "richmond", "brisbane lions"], 214 | ["away team score", "17.10 (112)", "9.11 (65)", "15.9 (99)"], 215 | ["ground", "bundaberg rum stadium", "waverley park", "north hobart oval"], 216 | ["crowd", "8818", "16512", "4908"], 217 | ["date", "friday , 28 january", "friday , 28 january", "saturday , 5 february"] 218 | ] 219 | } 220 | */ 221 | statement : sydney scored the same amount of points in the first game of the 2000 afl ansett australia cup as their opponent did in their second. 222 | similar words link to columns : 223 | scored -> away team score, home team score 224 | column value link to columns : 225 | sydney -> away team, home team 226 | semantic sentence link to columns : 227 | their opponent -> home team, away team 228 | scored the same amount of points -> away team score, home team score 229 | first game -> date 230 | their second -> date 231 | sydney scored -> home team, away team, home team score, away team score 232 | The answer is : f_col([away team, home team, away team score, home team score, date])""" 233 | 234 | 235 | select_row_demo = """Using f_row() api to select relevant rows in the given table that support or oppose the statement. 
236 | Please use f_row([*]) to select all rows in the table. 237 | 238 | /* 239 | table caption : 1972 vfl season. 240 | col : home team | home team score | away team | away team score | venue | crowd | date 241 | row 1 : st kilda | 13.12 (90) | melbourne | 13.11 (89) | moorabbin oval | 18836 | 19 august 1972 242 | row 2 : south melbourne | 9.12 (66) | footscray | 11.13 (79) | lake oval | 9154 | 19 august 1972 243 | row 3 : richmond | 20.17 (137) | fitzroy | 13.22 (100) | mcg | 27651 | 19 august 1972 244 | row 4 : geelong | 17.10 (112) | collingwood | 17.9 (111) | kardinia park | 23108 | 19 august 1972 245 | row 5 : north melbourne | 8.12 (60) | carlton | 23.11 (149) | arden street oval | 11271 | 19 august 1972 246 | row 6 : hawthorn | 15.16 (106) | essendon | 12.15 (87) | vfl park | 36749 | 19 august 1972 247 | */ 248 | statement : the away team with the highest score is fitzroy. 249 | explain : the statement want to check the highest away team score. we need to compare score of away team fitzroy with all others, so we need all rows. use * to represent all rows in the table. 250 | The answer is : f_row([*]) 251 | 252 | /* 253 | table caption : list of largest airlines in central america & the caribbean. 
254 | col : rank | airline | country | fleet size | remarks 255 | row 1 : 1 | caribbean airlines | trinidad and tobago | 22 | largest airline in the caribbean 256 | row 2 : 2 | liat | antigua and barbuda | 17 | second largest airline in the caribbean 257 | row 3 : 3 | cubana de aviaciã cubicn | cuba | 14 | operational since 1929 258 | row 4 : 4 | inselair | curacao | 12 | operational since 2006 259 | row 5 : 5 | dutch antilles express | curacao | 4 | curacao second national carrier 260 | row 6 : 6 | air jamaica | trinidad and tobago | 5 | parent company is caribbean airlines 261 | row 7 : 7 | tiara air | aruba | 3 | aruba 's national airline 262 | */ 263 | statement : the remark on airline of dutch antilles express with fleet size over 4 is curacao second national carrier. 264 | explain : the statement want to check a record in the table. we cannot find a record perfectly satisfied the statement, the most relevant row is row 5, which describes dutch antilles express airline, remarks is curacao second national carrier and fleet size is 4 not over 4. 265 | The answer is : f_row([row 5]) 266 | 267 | /* 268 | table caption : list of longest - serving soap opera actors. 
269 | col : actor | character | soap opera | years | duration 270 | row 1 : tom jordon | charlie kelly | fair city | 1989- | 25 years 271 | row 2 : tony tormey | paul brennan | fair city | 1989- | 25 years 272 | row 3 : jim bartley | bela doyle | fair city | 1989- | 25 years 273 | row 4 : sarah flood | suzanne halpin | fair city | 1989 - 2013 | 24 years 274 | row 5 : pat nolan | barry o'hanlon | fair city | 1989 - 2011 | 22 years 275 | row 6 : martina stanley | dolores molloy | fair city | 1992- | 22 years 276 | row 7 : joan brosnan walsh | mags kelly | fair city | 1989 - 2009 | 20 years 277 | row 8 : jean costello | rita doyle | fair city | 1989 - 2008 , 2010 | 19 years 278 | row 9 : ciara o'callaghan | yvonne gleeson | fair city | 1991 - 2004 , 2008- | 19 years 279 | row 10 : celia murphy | niamh cassidy | fair city | 1995- | 19 years 280 | row 39 : tommy o'neill | john deegan | fair city | 2001- | 13 years 281 | row 40 : seamus moran | mike gleeson | fair city | 1996 - 2008 | 12 years 282 | row 41 : rebecca smith | annette daly | fair city | 1997 - 2009 | 12 years 283 | row 42 : grace barry | mary - ann byrne | glenroe | 1990 - 2001 | 11 years 284 | row 43 : gemma doorly | sarah o'leary | fair city | 2001 - 2011 | 10 years 285 | */ 286 | statement : seamus moran and rebecca smith were in soap operas for a duration of 12 years. 287 | explain : the statement want to check seamus moran and rebecca smith in the table. row 40 describes seamus moran were in soap operas for a duration of 12 years. row 41 describes rebecca smith were in soap operas for a duration of 12 years 288 | The answer is : f_row([row 40, row 41]) 289 | 290 | /* 291 | table caption : jeep grand cherokee. 
292 | col : years | displacement | engine | power | torque 293 | row 1 : 1999 - 2004 | 4.0l (242cid) | power tech i6 | - | 3000 rpm 294 | row 2 : 1999 - 2004 | 4.7l (287cid) | powertech v8 | - | 3200 rpm 295 | row 3 : 2002 - 2004 | 4.7l (287cid) | high output powertech v8 | - | - 296 | row 4 : 1999 - 2001 | 3.1l diesel | 531 ohv diesel i5 | - | - 297 | row 5 : 2002 - 2004 | 2.7l diesel | om647 diesel i5 | - | - 298 | */ 299 | statement : the jeep grand cherokee with the om647 diesel i5 had the third lowest numbered displacement. 300 | explain : the statement want to check the om647 diesel i5 had third lowest numbered displacement. so we need first three low numbered displacement and all rows that power is om647 diesel i5. 301 | The answer is : f_row([row 5, row 4, row 1])""" 302 | 303 | -------------------------------------------------------------------------------- /utils/chain.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Chain-of-Table authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | import copy 17 | import re 18 | from tqdm import tqdm 19 | import numpy as np 20 | from utils.helper import table2string 21 | from collections import defaultdict 22 | import pickle 23 | import os 24 | 25 | import multiprocessing as mp 26 | 27 | from operations import * 28 | 29 | 30 | def fixed_chain_exec_mp(llm, init_samples, fixed_op_list, n_proc=10, chunk_size=50): 31 | history = {} 32 | final_result = None 33 | 34 | chain_header = copy.deepcopy(init_samples) 35 | chain_key = "" 36 | 37 | for i, (op_name, solver_func, kargs, llm_kargs) in enumerate(fixed_op_list): 38 | chain_key += f"->{op_name}" if i > 0 else op_name 39 | chain_header = conduct_single_solver_mp( 40 | llm=llm, 41 | all_samples=chain_header, 42 | solver_func=solver_func, 43 | tqdm_tag=op_name, 44 | n_proc=n_proc, 45 | chunk_size=chunk_size, 46 | llm_options=llm.get_model_options( 47 | **llm_kargs, 48 | ), 49 | **kargs, 50 | ) 51 | 52 | history[f"({i}) {chain_key}"] = chain_header 53 | if i == len(fixed_op_list) - 1: 54 | final_result = chain_header 55 | 56 | return final_result, history 57 | 58 | 59 | def conduct_single_solver(llm, all_samples, solver_func, tqdm_tag=None, **kwargs): 60 | result_samples = [None for _ in range(len(all_samples))] 61 | 62 | for idx in tqdm(range(len(all_samples)), desc=tqdm_tag): 63 | try: 64 | sample = all_samples[idx] 65 | table_info = get_table_info( 66 | sample, 67 | skip_op=kwargs.get("skip_op", []), 68 | first_n_op=kwargs.get("first_n_op", None), 69 | ) 70 | proc_sample = solver_func(sample, table_info, llm, **kwargs) 71 | result_samples[idx] = proc_sample 72 | except Exception as e: 73 | print(f"Error in {idx}th sample: {e}") 74 | continue 75 | return result_samples 76 | 77 | 78 | def _conduct_single_solver_mp_core(arg): 79 | idx, sample, llm, solver_func, kwargs = arg 80 | try: 81 | table_info = get_table_info( 82 | sample, 83 | skip_op=kwargs.get("skip_op", []), 84 | first_n_op=kwargs.get("first_n_op", None), 85 | ) 86 | proc_sample = 
solver_func(sample, table_info, llm, **kwargs) 87 | return idx, proc_sample 88 | except Exception as e: 89 | print(f"Error in {idx}-th sample: {e}") 90 | return idx, None 91 | 92 | 93 | def conduct_single_solver_mp( 94 | llm, all_samples, solver_func, tqdm_tag=None, n_proc=10, chunk_size=50, **kwargs 95 | ): 96 | result_samples = [None for _ in range(len(all_samples))] 97 | 98 | args = [ 99 | (idx, sample, llm, solver_func, kwargs) 100 | for idx, sample in enumerate(all_samples) 101 | ] 102 | 103 | with mp.Pool(n_proc) as p: 104 | for idx, proc_sample in tqdm( 105 | p.imap_unordered(_conduct_single_solver_mp_core, args, chunksize=chunk_size), 106 | total=len(all_samples), 107 | desc=tqdm_tag, 108 | ): 109 | result_samples[idx] = proc_sample 110 | 111 | return result_samples 112 | 113 | 114 | def get_act_func(name): 115 | try: 116 | return eval(f"{name}_act") 117 | except: 118 | 119 | def _default_act(table_text, *args, **kwargs): 120 | return copy.deepcopy(table_text) 121 | 122 | if "query" not in name: 123 | print("Unknown operation: ", name) 124 | return _default_act 125 | 126 | 127 | def get_table_info(sample, skip_op=[], first_n_op=None): 128 | table_text = sample["table_text"] 129 | chain = sample["chain"] 130 | 131 | if first_n_op is not None: 132 | chain = chain[:first_n_op] 133 | 134 | table_info = { 135 | "table_text": table_text, 136 | "act_chain": [], 137 | } 138 | 139 | for operation in chain: 140 | operation_name = operation["operation_name"] 141 | act_func = get_act_func(operation_name) 142 | table_info = act_func(table_info, operation, skip_op=skip_op) 143 | 144 | return table_info 145 | 146 | 147 | def get_table_log(sample, skip_op=[], first_n_op=None): 148 | table_text = sample["table_text"] 149 | chain = sample["chain"] 150 | 151 | if first_n_op is not None: 152 | chain = chain[:first_n_op] 153 | 154 | table_log = [] 155 | 156 | table_info = { 157 | "table_text": table_text, 158 | "act_chain": [], 159 | } 160 | table_log.append(table_info) 161 | 
162 | for operation in chain: 163 | operation_name = operation["operation_name"] 164 | act_func = get_act_func(operation_name) 165 | table_info = act_func(table_info, operation, skip_op=skip_op) 166 | if 'row' in operation_name: 167 | table_info['act_chain'][-1] = table_info['_real_select_rows'] 168 | if 'query' in operation_name: 169 | table_info['act_chain'].append(f'{operation_name}()') 170 | table_info['cotable_result'] = operation['parameter_and_conf'][0][0] 171 | table_log.append(table_info) 172 | 173 | return table_log 174 | 175 | 176 | # Dynmiac Chain Func 177 | 178 | 179 | plan_add_column_demo = """If the table does not have the needed column to tell whether the statement is True or False, we use f_add_column() to add a new column for it. For example, 180 | /* 181 | col : rank | lane | player | time 182 | row 1 : | 5 | olga tereshkova (kaz) | 51.86 183 | row 2 : | 6 | manjeet kaur (ind) | 52.17 184 | row 3 : | 3 | asami tanno (jpn) | 53.04 185 | */ 186 | Statement: there are one athlete from japan. 187 | Function: f_add_column(country of athlete) 188 | Explanation: The statement is about the number of athletes from japan. We need to known the country of each athlete. There is no column of the country of athletes. We add a column "country of athlete".""" 189 | 190 | plan_select_column_demo = """If the table only needs a few columns to tell whether the statement is True or False, we use f_select_column() to select these columns for it. For example, 191 | /* 192 | col : code | county | former province | area (km2) | population | capital 193 | row 1 : 1 | mombasa | coast | 212.5 | 939,370 | mombasa (city) 194 | row 2 : 2 | kwale | coast | 8,270.3 | 649,931 | kwale 195 | row 3 : 3 | kilifi | coast | 12,245.9 | 1,109,735 | kilifi 196 | */ 197 | Statement: momasa is a county with population higher than 500000. 198 | Function: f_select_column(county, population) 199 | Explanation: The statement wants to check momasa county with population higher than 500000. 
We need to know the county and its population. We select the column "county" and column "population".""" 200 | 201 | plan_select_row_demo = """If the table only needs a few rows to tell whether the statement is True or False, we use f_select_row() to select these rows for it. For example, 202 | /* 203 | table caption : jeep grand cherokee. 204 | col : years | displacement | engine | power | torque 205 | row 1 : 1999 - 2004 | 4.0l (242cid) | power tech i6 | - | 3000 rpm 206 | row 2 : 1999 - 2004 | 4.7l (287cid) | powertech v8 | - | 3200 rpm 207 | row 3 : 2002 - 2004 | 4.7l (287cid) | high output powertech v8 | - | - 208 | row 4 : 1999 - 2001 | 3.1l diesel | 531 ohv diesel i5 | - | - 209 | row 5 : 2002 - 2004 | 2.7l diesel | om647 diesel i5 | - | - 210 | */ 211 | Statement: the jeep grand cherokee with the om647 diesel i5 had the third lowest numbered displacement. 212 | Function: f_select_row(row 1, row 4, row 5) 213 | Explanation: The statement wants to check the om647 diesel i5 had third lowest numbered displacement. We need to know the first three low numbered displacement and all rows that power is om647 diesel i5. We select the row 1, row 4, row 5.""" 214 | 215 | plan_group_column_demo = """If the statement is about items with the same value and the number of these items, we use f_group_column() to group the items. For example, 216 | /* 217 | col : district | name | party | residence | first served 218 | row 1 : district 1 | nelson albano | dem | vineland | 2006 219 | row 2 : district 1 | robert andrzejczak | dem | middle twp. | 2013† 220 | row 3 : district 2 | john f. amodeo | rep | margate | 2008 221 | */ 222 | Statement: there are 5 districts are democratic 223 | Function: f_group_column(party) 224 | Explanation: The statement wants to check 5 districts are democratic. We need to know the number of dem in the table. 
We group the rows according to column "party".""" 225 | 226 | plan_sort_column_demo = """If the statement is about the order of items in a column, we use f_sort_column() to sort the items. For example, 227 | /* 228 | col : position | club | played | points 229 | row 1 : 1 | malaga cf | 42 | 79 230 | row 10 : 10 | cp merida | 42 | 59 231 | row 3 : 3 | cd numancia | 42 | 73 232 | */ 233 | Statement: cd numancia placed in the last position. 234 | Function: f_sort_column(position) 235 | Explanation: The statement wants to check about cd numancia in the last position. We need to know the order of position from last to front. We sort the rows according to column "position".""" 236 | 237 | plan_full_demo_simple = """Here are examples of using the operations to tell whether the statement is True or False. 238 | 239 | /* 240 | col : date | division | league | regular season | playoffs | open cup | avg. attendance 241 | row 1 : 2001/01/02 | 2 | usl a-league | 4th, western | quarterfinals | did not qualify | 7,169 242 | row 2 : 2002/08/06 | 2 | usl a-league | 2nd, pacific | 1st round | did not qualify | 6,260 243 | row 5 : 2005/03/24 | 2 | usl first division | 5th | quarterfinals | 4th round | 6,028 244 | */ 245 | Statement: 2005 is the last year where this team was a part of the usl a-league? 246 | Function Chain: f_add_column(year) -> f_select_row(row 1, row 2) -> f_select_column(year, league) -> f_sort_column(year) -> 247 | 248 | */ 249 | col : rank | lane | athlete | time 250 | row 1 : 1 | 6 | manjeet kaur (ind) | 52.17 251 | row 2 : 2 | 5 | olga tereshkova (kaz) | 51.86 252 | row 3 : 3 | 4 | pinki pramanik (ind) | 53.06 253 | */ 254 | Statement: There are 10 athletes from India. 
255 | Function Chain: f_add_column(country of athletes) -> f_select_row(row 1, row 3) -> f_select_column(athlete, country of athletes) -> f_group_column(country of athletes) -> 256 | 257 | /* 258 | col : week | when | kickoff | opponent | results; final score | results; team record | game site | attendance 259 | row 1 : 1 | saturday, april 13 | 7:00 p.m. | at rhein fire | w 27–21 | 1–0 | rheinstadion | 32,092 260 | row 2 : 2 | saturday, april 20 | 7:00 p.m. | london monarchs | w 37–3 | 2–0 | waldstadion | 34,186 261 | row 3 : 3 | sunday, april 28 | 6:00 p.m. | at barcelona dragons | w 33–29 | 3–0 | estadi olímpic de montjuïc | 17,503 262 | */ 263 | Statement: the competition with highest points scored is played on April 20. 264 | Function Chain: f_add_column(points scored) -> f_select_row(*) -> f_select_column(when, points scored) -> f_sort_column(points scored) -> 265 | 266 | /* 267 | col : iso/iec standard | status | wg 268 | row 1 : iso/iec tr 19759 | published (2005) | 20 269 | row 2 : iso/iec 15288 | published (2008) | 7 270 | row 3 : iso/iec 12207 | published (2011) | 7 271 | */ 272 | Statement: 2 standards are published in 2011 273 | Function Chain: f_add_column(year) -> f_select_row(row 3) -> f_select_column(year) -> f_group_column(year) -> 274 | 275 | Here are examples of using the operations to tell whether the statement is True or False.""" 276 | 277 | possible_next_operation_dict = { 278 | "": [ 279 | "add_column", 280 | "select_row", 281 | "select_column", 282 | "group_column", 283 | "sort_column", 284 | ], 285 | "add_column": [ 286 | "select_row", 287 | "select_column", 288 | "group_column", 289 | "sort_column", 290 | "", 291 | ], 292 | "select_row": [ 293 | "select_column", 294 | "group_column", 295 | "sort_column", 296 | "", 297 | ], 298 | "select_column": [ 299 | "group_column", 300 | "sort_column", 301 | "", 302 | ], 303 | "group_column": [ 304 | "sort_column", 305 | "", 306 | ], 307 | "sort_column": [ 308 | "", 309 | ], 310 | } 311 | 312 | 313 | 
def get_operation_name(string):
    """Return the operation name from a single ``f_xxx(...)`` call string.

    E.g. ``"f_select_row(row 1, row 3)"`` -> ``"select_row"``.
    Raises IndexError if no ``f_xxx(...)`` pattern is present.
    """
    res = re.findall(r"f_(.*?)\(.*\)", string)[0]
    return res


def get_all_operation_names(string):
    """Return every operation name found in an ``->``-separated chain string.

    Empty segments (e.g. after the trailing "->" that marks the end of a
    chain) are kept as "" so callers can detect the stop marker.
    """
    operation_names = []
    parts = string.split("->")
    for part in parts:
        part = part.strip()
        if part == "":
            operation_names.append("")
        else:
            res = re.findall(r"f_(.*?)\(.*\)", part)
            if res:
                operation_names.append(res[0])
    return operation_names


def generate_prompt_for_next_step(
    sample,
    debug=False,
    llm=None,
    llm_options=None,
    strategy="top",
):
    """Ask the LLM which table operation to apply next.

    Builds a planning prompt from the per-operation demos, the current table
    and the function chain executed so far, then parses the completion(s) for
    the first operation that is legal at this point in the chain.

    Returns:
        (next_operation, log): next_operation may be "" to signal that the
        chain is finished; log records the prompt/response for debugging.
    """
    table_info = get_table_info(sample)
    act_chain = table_info["act_chain"]

    if debug:
        print("Act Chain: ", act_chain, flush=True)

    # Operations prefixed with "skip" were tried and rejected; they are left
    # out of the prompt chain and excluded from the candidate set below.
    kept_act_chain = [x for x in act_chain if not x.startswith("skip")]
    kept_act_chain_str = " -> ".join(kept_act_chain)
    if kept_act_chain_str:
        kept_act_chain_str += " ->"

    skip_act_chain = [x for x in act_chain if x.startswith("skip")]
    skip_act_chain_op_names = []
    for op in skip_act_chain:
        op = op[len("skip ") :]
        op_name = get_operation_name(op)
        skip_act_chain_op_names.append(op_name)

    if debug:
        print("Kept Act Chain: ", kept_act_chain, flush=True)
        print("Skip Act Chain: ", skip_act_chain, flush=True)

    last_operation = (
        "" if not kept_act_chain else get_operation_name(kept_act_chain[-1])
    )
    possible_next_operations = possible_next_operation_dict[last_operation]
    possible_next_operations = [
        x for x in possible_next_operations if x not in skip_act_chain_op_names
    ]

    if debug:
        print("Last Operation: ", last_operation, flush=True)
        print("Possible Next Operations: ", possible_next_operations, flush=True)

    # Only one legal choice left: no need to query the LLM at all.
    if len(possible_next_operations) == 1:
        log = {
            "act_chain": act_chain,
            "last_operation": last_operation,
            "possible_next_operations": possible_next_operations,
            "prompt": None,
            "response": None,
            "generate_operations": None,
            "next_operation": possible_next_operations[0],
        }
        return possible_next_operations[0], log

    prompt = ""
    for operation in possible_next_operations:
        if operation == "":
            continue
        # Fix: resolve the demo text via the module namespace instead of eval().
        prompt += globals()[f"plan_{operation}_demo"] + "\n\n"

    prompt += plan_full_demo_simple + "\n\n"

    prompt += "/*\n" + table2string(table_info["table_text"]) + "\n*/\n"
    prompt += "Statement: " + sample["statement"] + "\n"

    _possible_next_operations_str = " or ".join(
        [f"f_{op}()" if op != "" else op for op in possible_next_operations]
    )

    if len(possible_next_operations) > 1:
        prompt += (
            f"The next operation must be one of {_possible_next_operations_str}.\n"
        )
    else:
        # Unreachable in practice (the single-candidate case returns above);
        # kept for safety should the early return ever be removed.
        prompt += f"The next operation must be {_possible_next_operations_str}.\n"

    prompt += "Function Chain: " + kept_act_chain_str

    responses = llm.generate_plus_with_score(
        prompt, options=llm_options, end_str="\n\n"
    )

    # Fix: pre-bind so the log dict below cannot hit a NameError when the LLM
    # returns no responses under the voting strategy.
    response = None
    generate_operations = None

    if strategy == "top":
        # Trust the single highest-ranked completion.
        response = responses[0][0]
        generate_operations = get_all_operation_names(response)
        if debug:
            print('Prompt:', prompt.split("\n\n")[-1])
            print('Response:', response)
            print("Generated Operations: ", generate_operations)
        next_operation = ""
        for operation in generate_operations:
            if operation in possible_next_operations:
                next_operation = operation
                break
    elif strategy == "voting":
        # Weight each completion's first legal operation by exp(score) and
        # pick the operation with the highest accumulated confidence.
        next_operation_conf_dict = defaultdict(float)
        for response, score in responses:
            generate_operations = get_all_operation_names(response)
            next_operation = None
            for operation in generate_operations:
                if operation in possible_next_operations:
                    next_operation = operation
                    break
            if next_operation:
                next_operation_conf_dict[next_operation] += np.exp(score)
        if len(next_operation_conf_dict) != 0:
            next_operation_conf_pairs = sorted(
                next_operation_conf_dict.items(), key=lambda x: x[1], reverse=True
            )
            next_operation = next_operation_conf_pairs[0][0]
        else:
            next_operation = ""
    else:
        # Fix: the original fell through and crashed later with a NameError.
        raise NotImplementedError(f"Unknown strategy: {strategy}")

    log = {
        "act_chain": act_chain,
        "last_operation": last_operation,
        "possible_next_operations": possible_next_operations,
        "prompt": prompt,
        "response": response,
        "generate_operations": generate_operations,
        "next_operation": next_operation,
    }

    return next_operation, log


def dynamic_chain_exec_one_sample(
    sample,
    llm,
    llm_options=None,
    strategy="top",
    debug=False,
    operation_parameter_dict=None,
):
    """Run the dynamic chain-of-table loop on a single sample.

    Repeatedly asks the planner for the next operation and applies the
    matching solver until the planner emits "" (end of chain).

    Returns:
        (processed_sample, planning_log_list)
    """
    if operation_parameter_dict is None:
        # op key -> (label, solver function, extra kwargs, decoding options)
        operation_parameter_dict = {
            "add_column": (
                "addColumn",
                add_column_func,
                {},
                llm.get_model_options(
                    temperature=0.0,
                    per_example_max_decode_steps=150,
                    per_example_top_p=1.0,
                ),
            ),
            "select_row": (
                "selectRow",
                select_row_func,
                {},
                llm.get_model_options(
                    temperature=0.5,
                    per_example_max_decode_steps=150,
                    per_example_top_p=1.0,
                    n_sample=8,
                ),
            ),
            "select_column": (
                "selectColumn",
                select_column_func,
                {},
                llm.get_model_options(
                    temperature=0.5,
                    per_example_max_decode_steps=150,
                    per_example_top_p=1.0,
                    n_sample=8,
                ),
            ),
            "group_column": (
                "groupColumn",
                group_column_func,
                dict(skip_op=[]),
                llm.get_model_options(
                    temperature=0.0,
                    per_example_max_decode_steps=150,
                    per_example_top_p=1.0,
                ),
            ),
            "sort_column": (
                "sortColumn",
                sort_column_func,
                dict(skip_op=[]),
                llm.get_model_options(
                    temperature=0.0,
                    per_example_max_decode_steps=150,
                    per_example_top_p=1.0,
                ),
            ),
        }

    dynamic_chain_log = []

    current_sample = copy.deepcopy(sample)
    while True:
        # generate next operation
        next_operation, log = generate_prompt_for_next_step(
            current_sample,
            llm=llm,
            llm_options=llm_options,
            strategy=strategy,
            debug=debug,
        )
        dynamic_chain_log.append(log)

        if debug:
            print(next_operation)

        if next_operation == "":
            break

        # op_name is unused here but kept for parity with the tuple layout.
        op_name, solver_func, kargs, op_llm_options = operation_parameter_dict[
            next_operation
        ]

        table_info = get_table_info(current_sample)

        current_sample = solver_func(
            current_sample, table_info, llm=llm, llm_options=op_llm_options, **kargs
        )
    return current_sample, dynamic_chain_log


def dynamic_chain_exec_with_cache_for_loop(
    all_samples,
    llm,
    llm_options=None,
    strategy="voting",
    cache_dir="./cache/debug",
):
    """Sequential driver with an on-disk per-sample pickle cache.

    Results are cached under cache_dir keyed by sample["id"]; cached samples
    are loaded instead of re-executed. Failures are logged and skipped so the
    corresponding result slots stay None.
    """
    os.makedirs(cache_dir, exist_ok=True)
    result_samples = [None for _ in range(len(all_samples))]
    dynamic_chain_log_list = [None for _ in range(len(all_samples))]

    cache_filename = "case-{}.pkl"

    def _func(idx):
        sample = all_samples[idx]
        sample_id = sample["id"]
        cache_path = os.path.join(cache_dir, cache_filename.format(sample_id))
        if os.path.exists(cache_path):
            # Fix: close the cache file deterministically (was a bare open()).
            with open(cache_path, "rb") as f:
                _, proc_sample, log = pickle.load(f)
        else:
            proc_sample, log = dynamic_chain_exec_one_sample(
                sample, llm=llm, llm_options=llm_options, strategy=strategy
            )
            with open(cache_path, "wb") as f:
                pickle.dump((sample, proc_sample, log), f)
        result_samples[idx] = proc_sample
        dynamic_chain_log_list[idx] = log

    for idx in tqdm(range(len(all_samples)), total=len(all_samples)):
        try:
            _func(idx)
        except Exception as e:
            # Best-effort: report the failure and keep processing the rest.
            print(f"IDX={idx}: {e}", flush=True)

    return result_samples, dynamic_chain_log_list


def _dynamic_chain_exec_with_cache_mp_core(arg):
    """Multiprocessing worker; mirrors the sequential _func above.

    NOTE(review): unlike the sequential driver, the cache file is keyed by
    the sample *index*, not sample["id"] — caches produced by the two
    drivers are not interchangeable. Confirm before unifying.
    """
    idx, sample, llm, llm_options, strategy, cache_dir = arg

    cache_filename = "case-{}.pkl"
    # Fix: resolve the id before the try block so the error message in the
    # handler cannot itself raise NameError when sample["id"] is missing.
    sample_id = sample.get("id", idx)
    try:
        cache_path = os.path.join(cache_dir, cache_filename.format(idx))
        if os.path.exists(cache_path):
            # Fix: close the cache file deterministically (was a bare open()).
            with open(cache_path, "rb") as f:
                _, proc_sample, log = pickle.load(f)
        else:
            proc_sample, log = dynamic_chain_exec_one_sample(
                sample, llm=llm, llm_options=llm_options, strategy=strategy
            )
            with open(cache_path, "wb") as f:
                pickle.dump((sample, proc_sample, log), f)
        return idx, proc_sample, log
    except Exception as e:
        print(f"Error in {sample_id}: {e}", flush=True)
        return idx, None, None


def dynamic_chain_exec_with_cache_mp(
    all_samples,
    llm,
    llm_options=None,
    strategy="voting",
    cache_dir="./results/debug",
    n_proc=10,
    chunk_size=50,
):
    """Parallel driver: fan the samples out over a multiprocessing pool.

    Same caching contract as the worker above; failed samples yield None in
    both result lists.
    """
    os.makedirs(cache_dir, exist_ok=True)
    result_samples = [None for _ in range(len(all_samples))]
    dynamic_chain_log_list = [None for _ in range(len(all_samples))]

    args = [
        (idx, sample, llm, llm_options, strategy, cache_dir)
        for idx, sample in enumerate(all_samples)
    ]

    with mp.Pool(n_proc) as p:
        for idx, proc_sample, log in tqdm(
            p.imap_unordered(
                _dynamic_chain_exec_with_cache_mp_core, args, chunksize=chunk_size
            ),
            total=len(all_samples),
        ):
            result_samples[idx] = proc_sample
            dynamic_chain_log_list[idx] = log

    return result_samples, dynamic_chain_log_list


# --------------------------------------------------------------------------------
# /utils/evaluate.py:
# --------------------------------------------------------------------------------
# Copyright 2024 The Chain-of-Table authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def tabfact_match_func(sample, strategy="top"):
    """Check a finished TabFact sample against its gold label.

    Reads the (answer, confidence) pairs produced by the final query step,
    i.e. sample["chain"][-1]["parameter_and_conf"].

    Args:
        sample: processed sample dict with "chain" and integer "label".
        strategy: "top" takes the highest-ranked answer; "weighted" sums the
            confidences per distinct answer and takes the argmax.

    Returns:
        True iff the chosen answer agrees with sample["label"] (1 = yes).
    """
    results = sample["chain"][-1]["parameter_and_conf"]

    if strategy == "top":
        res = results[0][0]
    elif strategy == "weighted":
        res_conf_dict = {}
        for res, conf in results:
            if res not in res_conf_dict:
                res_conf_dict[res] = 0
            res_conf_dict[res] += conf
        res_conf_rank = sorted(res_conf_dict.items(), key=lambda x: x[1], reverse=True)
        res = res_conf_rank[0][0]
    else:
        raise NotImplementedError

    # Normalize true/false phrasing to yes/no before comparing to the label.
    res = res.lower()
    if res == "true":
        res = "yes"
    if res == "false":
        res = "no"
    if res == "yes" and sample["label"] == 1:
        return True
    elif res == "no" and sample["label"] == 0:
        return True
    else:
        return False


def tabfact_match_func_for_samples(all_samples, strategy="top"):
    """Return the accuracy of tabfact_match_func over all_samples.

    Samples that raise (e.g. malformed chains) are reported and skipped.
    Returns 0.0 when no sample could be evaluated instead of dividing by zero.
    """
    correct_list = []
    for sample in all_samples:
        try:
            if tabfact_match_func(sample, strategy):
                correct_list.append(1)
            else:
                correct_list.append(0)
        except Exception as e:
            # Fix: was a bare `except:` that printed only "Error", hiding the cause.
            print(f"Error: {e}")
            continue
    if not correct_list:
        # Fix: avoid ZeroDivisionError on empty / fully-skipped input.
        return 0.0
    return sum(correct_list) / len(correct_list)


# --------------------------------------------------------------------------------
# /utils/helper.py:
# --------------------------------------------------------------------------------
# Copyright 2024 The Chain-of-Table authors
#
# Licensed under the Apache
# License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import pandas as pd
import json
import re
from _ctypes import PyObj_FromPtr


def table2df(table_text, num_rows=100):
    """Convert [header_row, *data_rows] into a DataFrame, keeping at most num_rows rows."""
    header, rows = table_text[0], table_text[1:]
    rows = rows[:num_rows]
    df = pd.DataFrame(data=rows, columns=header)
    return df


def table2string(
    table_text,
    num_rows=100,
    caption=None,
):
    """Linearize a table into the "col : ... / row i : ..." prompt format.

    Args:
        table_text: [header_row, *data_rows] list of lists.
        num_rows: maximum number of data rows to include.
        caption: optional text emitted as a leading "table caption : ..." line.

    Returns:
        The linearized table, without a trailing newline.
    """
    df = table2df(table_text, num_rows)
    linear_table = ""
    if caption is not None:
        linear_table += "table caption : " + caption + "\n"

    header = "col : " + " | ".join(df.columns) + "\n"
    linear_table += header
    rows = df.values.tolist()
    for row_idx, row in enumerate(rows):
        row = [str(x) for x in row]
        line = "row {} : ".format(row_idx + 1) + " | ".join(row)
        if row_idx != len(rows) - 1:
            # No newline after the last row.
            line += "\n"
        linear_table += line
    return linear_table


class NoIndent(object):
    """Value wrapper: tells MyEncoder to emit this value on a single line."""

    def __init__(self, value):
        self.value = value


class MyEncoder(json.JSONEncoder):
    """JSON encoder that renders NoIndent-wrapped values compactly.

    NoIndent objects are first serialized as "@@<object id>@@" placeholder
    strings, then replaced in a post-pass with the single-line json.dumps of
    the wrapped value (see https://stackoverflow.com/a/15012814/355230).
    """

    FORMAT_SPEC = "@@{}@@"
    regex = re.compile(FORMAT_SPEC.format(r"(\d+)"))

    def __init__(self, **kwargs):
        # Save copy of any keyword argument values needed for use here.
        self.__sort_keys = kwargs.get("sort_keys", None)
        super(MyEncoder, self).__init__(**kwargs)

    def default(self, obj):
        return (
            self.FORMAT_SPEC.format(id(obj))
            if isinstance(obj, NoIndent)
            else super(MyEncoder, self).default(obj)
        )

    def encode(self, obj):
        format_spec = self.FORMAT_SPEC  # Local var to expedite access.
        json_repr = super(MyEncoder, self).encode(obj)  # Default JSON.

        # Replace any marked-up object ids in the JSON repr with the
        # value returned from the json.dumps() of the corresponding
        # wrapped Python object.
        for match in self.regex.finditer(json_repr):
            # Fix: renamed from `id` — it shadowed the builtin id() used in default().
            obj_id = int(match.group(1))
            no_indent = PyObj_FromPtr(obj_id)
            json_obj_repr = json.dumps(no_indent.value, sort_keys=self.__sort_keys)

            # Replace the matched id string with the json formatted
            # representation of the corresponding Python object.
            json_repr = json_repr.replace(
                '"{}"'.format(format_spec.format(obj_id)), json_obj_repr
            )

        return json_repr


# --------------------------------------------------------------------------------
# /utils/llm.py:
# --------------------------------------------------------------------------------
# Copyright 2024 The Chain-of-Table authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import openai
import time
import numpy as np


class ChatGPT:
    """Thin wrapper around the (legacy) OpenAI ChatCompletion endpoint."""

    def __init__(self, model_name, key):
        self.model_name = model_name
        self.key = key

    def get_model_options(
        self,
        temperature=0,
        per_example_max_decode_steps=150,
        per_example_top_p=1,
        n_sample=1,
    ):
        """Translate the project's decoding knobs into OpenAI API kwargs."""
        return dict(
            temperature=temperature,
            n=n_sample,
            top_p=per_example_top_p,
            max_tokens=per_example_max_decode_steps,
        )

    def generate_plus_with_score(self, prompt, options=None, end_str=None):
        """Query the model and return a list of (text, pseudo-log-confidence).

        The API exposes no scores, so the i-th of n choices is assigned the
        fake confidence (n - i) / n, returned as its log. Transient API
        errors are retried with a 60-second back-off, up to 2 retries.
        """
        if options is None:
            options = self.get_model_options()

        messages = [
            {
                "role": "system",
                "content": "I will give you some examples, you need to follow the examples and complete the text, and no other content.",
            },
            {"role": "user", "content": prompt},
        ]

        placeholder = {"choices": [{"message": {"content": "PLACEHOLDER"}}]}
        retry_limit = 2
        attempts = 0
        error = None
        completion = None
        while completion is None:
            try:
                completion = openai.ChatCompletion.create(
                    model=self.model_name,
                    messages=messages,
                    stop=end_str,
                    api_key=self.key,
                    **options
                )
                error = None
            except Exception as e:
                print(str(e), flush=True)
                error = str(e)
                if "This model's maximum context length is" in str(e):
                    # NOTE(review): `error` stays set on this branch, so the
                    # exception is re-raised below despite the placeholder
                    # response — confirm whether that is the intended flow.
                    print(e, flush=True)
                    completion = placeholder
                elif attempts > retry_limit:
                    error = "too many retry times"
                    completion = placeholder
                else:
                    time.sleep(60)
                    attempts += 1
        if error:
            raise Exception(error)

        choices = completion["choices"]
        n_choices = len(choices)
        results = []
        for rank, choice in enumerate(choices):
            text = choice["message"]["content"]
            fake_conf = (n_choices - rank) / n_choices
            results.append((text, np.log(fake_conf)))

        return results

    def generate(self, prompt, options=None, end_str=None):
        """Single-completion convenience wrapper; returns just the text."""
        if options is None:
            options = self.get_model_options()
        # Mutates the caller-supplied options dict, matching prior behavior.
        options["n"] = 1
        return self.generate_plus_with_score(prompt, options, end_str)[0][0]


# --------------------------------------------------------------------------------
# /utils/load_data.py:
# --------------------------------------------------------------------------------
# Copyright 2024 The Chain-of-Table authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import json
from tqdm import tqdm

def load_tabfact_dataset(
    dataset_path,
    raw2clean_path,
    tag="test",
    first_n=-1,
):
    """Load a TabFact jsonl split, attaching cleaned statements and chain slots.

    Args:
        dataset_path: jsonl file with one sample per line.
        raw2clean_path: jsonl mapping raw statements to cleaned statements.
        tag: split name used to build sample ids ("test-0", "test-1", ...).
        first_n: if != -1, only the first n lines are loaded.

    Returns:
        List of sample dicts with "id", "chain" and "cleaned_statement" added.
    """
    tabfact_statement_raw2clean_dict = {}
    with open(raw2clean_path, "r") as f:
        for line in f:
            info = json.loads(line)
            tabfact_statement_raw2clean_dict[info["statement"]] = info[
                "cleaned_statement"
            ]

    # Fix: read the dataset inside a context manager (the original opened the
    # file without closing it, on both branches) and merge the two read paths.
    all_lines = []
    with open(dataset_path, "r") as f:
        for line in f:
            all_lines.append(line)
            if first_n != -1 and len(all_lines) >= first_n:
                break

    dataset = []
    for i, line in tqdm(
        enumerate(all_lines),
        total=len(all_lines),
        desc=f"Loading tabfact-{tag} dataset",
    ):
        info = json.loads(line)
        info["id"] = f"{tag}-{i}"
        info["chain"] = []
        # Fall back to the raw statement when no cleaned version is known.
        if info["statement"] in tabfact_statement_raw2clean_dict:
            info["cleaned_statement"] = tabfact_statement_raw2clean_dict[
                info["statement"]
            ]
        else:
            info["cleaned_statement"] = info["statement"]
        dataset.append(info)
    return dataset


def wrap_input_for_demo(statement, table_caption, table_text, cleaned_statement=None):
    """Build a minimal sample dict (with an empty chain) for the demo notebook."""
    return {
        "statement": statement,
        "table_caption": table_caption,
        "table_text": table_text,
        "cleaned_statement": cleaned_statement
        if cleaned_statement is not None
        else statement,
        "chain": [],
    }


# --------------------------------------------------------------------------------