├── .gitignore ├── README.md ├── code ├── LICENSE ├── README.md ├── data │ └── test_ids.json ├── prompts.py ├── requirements.txt ├── run_generate_system.py ├── run_generate_tomato.py ├── run_gpt.py ├── run_inner_speech.py ├── run_local_llm.py └── src │ ├── nn.py │ └── utils.py ├── dataset ├── LICENSE ├── README.md ├── tomato.json ├── tomato_fb.json ├── tomato_first.json └── tomato_second.json └── overview.png /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .python-version 3 | .env 4 | .venv 5 | *.pyc -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🍅 ToMATO 2 | This is the official repository of our paper: **["ToMATO: Verbalizing the Mental States of Role-Playing LLMs for Benchmarking Theory of Mind"](https://arxiv.org/abs/2501.08838)** 3 | 4 | ![overview](overview.png) 5 | 6 | ToMATO is a new benchmark for evaluating Theory of Mind in LLMs. 7 | ToMATO comprehensively evaluates Theory of Mind in a setting that better aligns with real-world scenarios compared to existing datasets. 8 | ToMATO was generated through newly designed LLM-LLM conversations with information asymmetry, as illustrated above. 9 | 10 | Please cite our work if you find the dataset or codes in this repository useful. 11 | 12 | ``` 13 | @inproceedings{shinoda2025tomato, 14 | title={ToMATO: Verbalizing the Mental States of Role-Playing LLMs for Benchmarking Theory of Mind}, 15 | author={Kazutoshi Shinoda and Nobukatsu Hojo and Kyosuke Nishida and Saki Mizuno and Keita Suzuki and Ryo Masumura and Hiroaki Sugiyama and Kuniko Saito}, 16 | booktitle={AAAI}, 17 | year={2025} 18 | } 19 | ``` 20 | 21 | ## Repository Contents and Licenses 22 | - `dataset/`: ToMATO benchmark presented in our paper 23 | - License: META LLAMA 3 COMMUNITY LICENSE 24 | - `code/`: Codes necessary for reproducing our work 25 | - License: NTT License 26 | 27 | ## Intended Use of Data 28 | Please use the ToMATO benchmark only for evaluation purposes. To avoid test set contamination, please do not use ToMATO for fine-tuning any models. 29 | 30 | ## Disclaimer 31 | The major part of ToMATO was generated using LLMs. Though the quality of the dataset was carefully validated by human annotators, the outputs of LLMs may contain biases. The dataset do not necessarily reflect the views and opinions of the authors and their associated affiliations. 32 | 33 | ## Contact 34 | For any question about our work, please email [kazutoshi.shinoda@ntt.com](mailto:kazutoshi.shinoda@ntt.com) or open an issue. 35 | -------------------------------------------------------------------------------- /code/LICENSE: -------------------------------------------------------------------------------- 1 | SOFTWARE LICENSE AGREEMENT FOR EVALUATION 2 | 3 | This SOFTWARE LICENSE AGREEMENT FOR EVALUATION (this "Agreement") is a legal contract between a person who uses or otherwise accesses or installs the Software (“User(s)”), and Nippon Telegraph and Telephone corporation ("NTT"). 4 | READ THE TERMS AND CONDITIONS OF THIS AGREEMENT CAREFULLY BEFORE INSTALLING OR OTHERWISE ACCESSING OR USING NTT'S PROPRIETARY SOFTWARE ACCOMPANIED BY THIS AGREEMENT (the "SOFTWARE"). THE SOFTWARE IS COPYRIGHTED AND IT IS LICENSED TO USER UNDER THIS AGREEMENT, NOT SOLD TO USER. 
BY INSTALLING OR OTHERWISE ACCESSING OR USING THE SOFTWARE, USER ACKNOWLEDGES THAT USER HAS READ THIS AGREEMENT, THAT USER UNDERSTANDS IT, AND THAT USER ACCEPTS AND AGREES TO BE BOUND BY ITS TERMS. IF AT ANY TIME USER IS NOT WILLING TO BE BOUND BY THE TERMS OF THIS AGREEMENT, USER SHOULD TERMINATE THE INSTALLATION PROCESS, IMMEDIATELY CEASE AND REFRAIN FROM ACCESSING OR USING THE SOFTWARE AND DELETE ANY COPIES USER MAY HAVE. THIS AGREEMENT REPRESENTS THE ENTIRE AGREEMENT BETWEEN USER AND NTT CONCERNING THE SOFTWARE. 5 | 6 | 7 | BACKGROUND 8 | A. NTT is the owner of all rights, including all patent rights, copyrights and trade secret rights, in and to the Software and related documentation except OSS listed in Exhibit A to this Agreement. 9 | B. User wishes to obtain a royalty free license to use the Software to enable User to evaluate, and NTT wishes to grant such a license to User, pursuant and subject to the terms and conditions of this Agreement. 10 | C. As a condition to NTT's provision of the Software to User, NTT has required User to execute this Agreement. 11 | In consideration of these premises, and the mutual promises and conditions in this Agreement, the parties hereby agree as follows: 12 | 1. Grant of Evaluation License. NTT hereby grants to User, and User hereby accepts, under the terms and conditions of this Agreement, a royalty free, nontransferable and nonexclusive license to use the Software internally for the non-commercial purposes of testing, analyzing, and evaluating the methods or mechanisms as shown in the research paper submitted by NTT to a certain academy or technical contest, etc. ("academy"). User may make a reasonable number of backup copies of the Software solely for User's internal use pursuant to the license granted in this Section 1. 13 | 2. Shipment and Installation. NTT will ship or deliver the Software by any method that NTT deems appropriate. User shall be solely responsible for proper installation of the Software. 14 | 3. Term. This Agreement is effective whichever is earlier (i) upon User’s acceptance of the Agreement, or (ii) upon User’s installing, accessing, and using the Software, even if User has not expressly accepted this Agreement. Without prejudice to any other rights, NTT may terminate this Agreement without notice to User (i) if User breaches or fails to comply with any of the limitations or other requirements described herein, and (ii) if NTT receives a notice from the academy stating that the research paper would not be published, and in any such case User agrees that NTT may, in addition to any other remedies it may have at law or in equity, remotely disable the Software. User may terminate this Agreement at any time by User’s decision to terminate the Agreement to NTT and ceasing use of the Software. Upon any termination or expiration of this Agreement for any reason, User agrees to uninstall the Software and either return to NTT the Software and all copies thereof, or to destroy all such materials and provide written verification of such destruction to NTT. 15 | 4. Proprietary Rights 16 | (a) The Software is the valuable, confidential, and proprietary property of NTT, and NTT shall retain exclusive title to this property both during the term and after the termination of this Agreement. Without limitation, User acknowledges that all patent rights, copyrights and trade secret rights in the Software except OSS shall remain the exclusive property of NTT at all times. 
User shall use not less than reasonable care in safeguarding the confidentiality of the Software. 17 | (b) NTT shall not be subject to the obligation of licensing the copyright, patent rights, etc. of author when user hope commercial / noncommercial use of the published / provided software, etc. 18 | (c) USER SHALL NOT, IN WHOLE OR IN PART, AT ANY TIME DURING THE TERM OF OR AFTER THE TERMINATION OF THIS AGREEMENT: (i) SELL, ASSIGN, LEASE, DISTRIBUTE, OR OTHERWISE TRANSFER THE SOFTWARE TO ANY THIRD PARTY; (ii) EXCEPT AS OTHERWISE PROVIDED HEREIN, COPY OR REPRODUCE THE SOFTWARE IN ANY MANNER; (iii) DISCLOSE THE SOFTWARE TO ANY THIRD PARTY, EXCEPT TO USER'S EMPLOYEES WHO REQUIRE ACCESS TO THE SOFTWARE FOR THE PURPOSES OF THIS AGREEMENT; (iv) MODIFY, DISASSEMBLE, DECOMPILE, REVERSE ENGINEER OR TRANSLATE THE SOFTWARE; OR (v) ALLOW ANY PERSON OR ENTITY TO COMMIT ANY OF THE ACTIONS DESCRIBED IN (i) THROUGH (iv) ABOVE. 19 | (d) User shall take appropriate action, by instruction, agreement, or otherwise, with respect to its employees permitted under this Agreement to have access to the Software to ensure that all of User's obligations under this Section 4 shall be satisfied. 20 | 5. Indemnity. User shall defend, indemnify and hold harmless NTT, its agents and employees, from any loss, damage, or liability arising in connection with User's improper or unauthorized use of the Software. NTT SHALL HAVE THE SOLE RIGHT TO CONDUCT DEFEND ANY ACTTION RELATING TO THE SOFTWARE. 21 | 6. Disclaimer. THE SOFTWARE IS LICENSED TO USER "AS IS," WITHOUT ANY TRAINING, MAINTENANCE, OR SERVICE OBLIGATIONS WHATSOEVER ON THE PART OF NTT. NTT MAKES NO EXPRESS OR IMPLIED WARRANTIES OF ANY TYPE WHATSOEVER, INCLUDING WITHOUT LIMITATION THE IMPLIED WARRANTIES OF MERCHANTABILITY, OF FITNESS FOR A PARTICULAR PURPOSE AND OF NON-INFRINGEMENT ON COPYRIGHT OR ANY OTHER RIGHT OF THIRD PARTIES. USER ASSUMES ALL RISKS ASSOCIATED WITH ITS USE OF THE SOFTWARE, INCLUDING WITHOUT LIMITATION RISKS RELATING TO QUALITY, PERFORMANCE, DATA LOSS, AND UTILITY IN A PRODUCTION ENVIRONMENT. 22 | 7. Limitation of Liability. IN NO EVENT SHALL NTT BE LIABLE TO USER OR TO ANY THIRD PARTY FOR ANY INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING BUT NOT LIMITED TO DAMAGES FOR PERSONAL INJURY, PROPERTY DAMAGE, LOST PROFITS, OR OTHER ECONOMIC LOSS, ARISING IN CONNECTION WITH USER'S USE OF OR INABILITY TO USE THE SOFTWARE, IN CONNECTION WITH NTT'S PROVISION OF OR FAILURE TO PROVIDE SERVICES PERTAINING TO THE SOFTWARE, OR AS A RESULT OF ANY DEFECT IN THE SOFTWARE. THIS DISCLAIMER OF LIABILITY SHALL APPLY REGARD¬LESS OF THE FORM OF ACTION THAT MAY BE BROUGHT AGAINST NTT, WHETHER IN CONTRACT OR TORT, INCLUDING WITHOUT LIMITATION ANY ACTION FOR NEGLIGENCE. USER'S SOLE REMEDY IN THE EVENT OF ANY BREACH OF THIS AGREEMENT BY NTT SHALL BE TERMINATION PURSUANT TO SECTION 3. 23 | 8. No Assignment or Sublicense. Neither this Agreement nor any right or license under this Agreement, nor the Software, may be sublicensed, assigned, or otherwise transferred by User without NTT's prior written consent. 24 | 9. OSS. The OSS included in the Software is shown on the "OSS List" in Exhibit A. User shall be subject to the license term of each OSS when using the OSS portion of the Software, and shall be subject to the terms of this document when using the non-OSS portion. 25 | 10. 
General 26 | (a) If any provision, or part of a provision, of this Agreement is or becomes illegal, unenforceable, or invalidated, by operation of law or otherwise, that provision or part shall to that extent be deemed omitted, and the remainder of this Agreement shall remain in full force and effect. 27 | (b) This Agreement is the complete and exclusive statement of the agreement between the parties with respect to the subject matter hereof, and supersedes all written and oral contracts, proposals, and other communications between the parties relating to that subject matter. 28 | (c) Subject to Section 8, this Agreement shall be binding on, and shall inure to the benefit of, the respective successors and assigns of NTT and User. 29 | (d) If either party to this Agreement initiates a legal action or proceeding to enforce or interpret any part of this Agreement, the prevailing party in such action shall be entitled to recover, as an element of the costs of such action and not as damages, its attorneys' fees and other costs associated with such action or proceeding. 30 | (e) This Agreement shall be governed by and interpreted under construed in accordance with the laws of Japan, without reference to conflicts of law principles. All disputes arising out of or in connection with this Agreement shall be finally settled by arbitration in Tokyo in accordance with the Commercial Arbitration Rules of the Japan Commercial Arbitration Association. The arbitration shall be conducted by three (3) arbitrators and in Japanese. The award rendered by the arbitrators shall be final and binding upon the parties. Judgment upon the award may be entered in any court having jurisdiction thereof. 31 | (f) NTT shall not be liable to the User or to any third party for any delay or failure to perform NTT’s obligation set forth under this Agreement due to any cause beyond NTT’s reasonable control. 32 | 33 | EXHIBIT A 34 | Software 35 | 36 | The software and related data include the following files, 37 | - data 38 | - src 39 | - prompts.py 40 | - README.md 41 | - requirements.txt 42 | - run_generate_system.py 43 | - run_generate_tomato.py 44 | - run_gpt.py 45 | - run_inner_speech.py 46 | - run_local_llm.py 47 | 48 | OSS List 49 | Not Applicable -------------------------------------------------------------------------------- /code/README.md: -------------------------------------------------------------------------------- 1 | # Codes for Reproduction 2 | ## Setup 3 | ### Software 4 | The codes are tested with the following libraries. 5 | Please update them if needed. 6 | 7 | * transformers==4.43.2 8 | * torch==2.3.0 9 | * bitsandbytes==0.43.1 10 | * openai==1.3.6 11 | 12 | Other required libraries are listed in `requirements.txt`. 13 | To install them, plese prepare a virtual environment if needed and then execute 14 | ``` 15 | pip install -r requirements.txt 16 | ``` 17 | 18 | ## Generate ToMATO 19 | ### 0. Download SOTOPIA dataset 20 | First, set up a virtual environment for loading SOTOPIA following [this URL](https://docs.sotopia.world/#installation). 21 | Second, follow the instructions in [this URL](https://github.com/sotopia-lab/sotopia/issues/7#issuecomment-1806365778) to get the raw SOTOPIA dataset and launch a docker container. 22 | Then, convert the raw data into json formats as follows, while launching the docker container. 
```
import json  # needed by save_json below

from sotopia.database.persistent_profile import AgentProfile, EnvironmentProfile
from sotopia.database.env_agent_combo_storage import EnvAgentComboStorage

def save_json(data, file):
    with open(file, 'w') as f:
        json.dump(data, f, indent=4)

agent_pks = AgentProfile.all_pks()
agent_pks = list(agent_pks)
agents = []
for pk in agent_pks:
    agents.append(AgentProfile.get(pk=pk))
output_agents = [a.__dict__ for a in agents]
save_json(output_agents, './data/sotopia/agent.json')

environment_pks = EnvironmentProfile.all_pks()
environment_pks = list(environment_pks)
print(len(environment_pks))
environments = []
for pk in environment_pks:
    environments.append(EnvironmentProfile.get(pk=pk))
output_environments = [e.__dict__ for e in environments]
save_json(output_environments, './data/sotopia/environment.json')

all_combos = EnvAgentComboStorage().all_pks()
all_combos = list(all_combos)
combos = []
for combo in all_combos:
    combos.append(EnvAgentComboStorage().get(combo))
output_combos = [c.__dict__ for c in combos]
save_json(output_combos, './data/sotopia/combo.json')
```

### 1. Generate System Prompts
```
python run_generate_system.py
```
This command will generate `data/scenarios/{scenario_name}.json`.

### 2. LLM-LLM Conversations
```
GPU_ID="0"
N_TURN="7"
OUTPUT_DIR="data/conversations/"
SCENARIO_DIR="scenarios/"
SCENARIO_LIST="assignments/gpu0.txt"
MODEL_PATH="meta-llama/Meta-Llama-3-70B-Instruct"
IFS="/"; MODEL_NAME=($MODEL_PATH)
EXP_ID="Conv_${MODEL_NAME[1]}_gpu${GPU_ID}"
IFS=" "

CUDA_VISIBLE_DEVICES=$GPU_ID python -u run_inner_speech.py --exp_id $EXP_ID \
    --scenario_dir $SCENARIO_DIR --scenario_list $SCENARIO_LIST --output_dir $OUTPUT_DIR --n_turn $N_TURN \
    --model_path $MODEL_PATH --conv_mode llama_3 --do_sample --load_4bit
```
This command will generate `data/conversations/{scenario_name}.json`.

### 3. Format ToMATO
```
OUTPUT_DIR="data/tomato"
MODEL_PATH="meta-llama/Meta-Llama-3-70B-Instruct"

python run_generate_tomato.py --output_dir $OUTPUT_DIR \
    --model_path $MODEL_PATH --max_utterances 14 --overwrite
```
This command will generate `data/tomato/tomato.json`.

## Evaluate LLMs on ToMATO
The following commands are for evaluating local and proprietary LLMs on ToMATO.
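Both evaluation scripts below save a `predictions_on_*.json` file under the output directory, keyed by question ID. As a minimal sketch (not part of this repository), such a file could be scored offline as follows, assuming the `a_true` and `a_pred_gen` (or `a_pred`) fields that `run_local_llm.py` and `run_gpt.py` write; the `score_predictions.py` name is only illustrative.
```
import json
import sys


def accuracy(pred_file, pred_key="a_pred_gen"):
    """Fraction of questions whose predicted option letter matches `a_true`."""
    with open(pred_file) as f:
        results = json.load(f)
    # Skip bookkeeping entries (e.g., "CreatedAt") that are not per-question results.
    pairs = [(r["a_true"], r.get(pred_key, "")) for r in results.values()
             if isinstance(r, dict) and "a_true" in r]
    return sum(t == p for t, p in pairs) / max(len(pairs), 1)


if __name__ == "__main__":
    # usage: python score_predictions.py <path to predictions_on_*.json>
    print(accuracy(sys.argv[1]))
```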
94 | 95 | ### Local LLM 96 | ``` 97 | GPU_IDS="0" 98 | MODEL_PATH="meta-llama/Meta-Llama-3-8B-Instruct" 99 | EVAL_BATCH_SIZE="2" 100 | TEST_FILES="data/tomato/tomato.json" 101 | SEED="42" 102 | IFS="/"; MODEL_NAME=($MODEL_PATH); METHOD=${MODEL_NAME[-1]}; IFS=" " 103 | EXP_ID="${METHOD}_${SEED}" 104 | 105 | CUDA_VISIBLE_DEVICES=$GPU_IDS python \ 106 | run_local_llm.py --do_eval --output_dir results/${EXP_ID} --exp_id $EXP_ID \ 107 | --model_path $MODEL_PATH --load_4bit --eval_batch_size $EVAL_BATCH_SIZE \ 108 | --test_files $TEST_FILES \ 109 | --data_type tomato --do_gen_qa --do_sample --seed $SEED 110 | ``` 111 | 112 | ### GPT 113 | ``` 114 | FILE="data/tomato/tomato.json" 115 | MODEL="gpt-4o-mini-2024-07-18" 116 | OUTPUT_DIR="retults/${MODEL}" 117 | SEED="42" 118 | IFS="/"; F=($FILE); VERSION=${F[2]}; IFS=" " 119 | EXP_ID="${MODEL}_${SEED}" 120 | 121 | python run_gpt.py --data_type tomato \ 122 | --test_file $FILE --output_dir $OUTPUT_DIR --log_steps 200 \ 123 | --model $MODEL --exp_id $EXP_ID --do_mc_qa --seed $SEED 124 | ``` 125 | 126 | ## License 127 | The codes are released under NTT Licence, which allows them to be used for research purposes only. 128 | -------------------------------------------------------------------------------- /code/data/test_ids.json: -------------------------------------------------------------------------------- 1 | [ 2 | "09V9Q", 3 | "2CM9M", 4 | "K1AMA", 5 | "TYHSF", 6 | "JG5CA", 7 | "R5FD7", 8 | "VFY0W", 9 | "2NPXT", 10 | "QG808", 11 | "2D7Q3", 12 | "517ZC", 13 | "HHJJT", 14 | "GYC9F", 15 | "142GD", 16 | "ZBX4Z", 17 | "1C7SQ", 18 | "7029K", 19 | "K2FN1", 20 | "TWFNT", 21 | "48NQ6", 22 | "GXTTH", 23 | "7B14X", 24 | "ER913", 25 | "8NXDQ", 26 | "RDWJG", 27 | "GNDNX", 28 | "RDYDR", 29 | "FY4QK", 30 | "17ZKF", 31 | "PTZWM", 32 | "73FS8", 33 | "GJ7MG", 34 | "AMCR1", 35 | "DX6YE", 36 | "WXKA9", 37 | "SFD9A", 38 | "7K99H", 39 | "KDGD9", 40 | "DK59V", 41 | "5GV4A", 42 | "4P9V4", 43 | "K1WK9", 44 | "J3JZJ", 45 | "XHZ91", 46 | "YXZDB", 47 | "7X3A7", 48 | "VJZ4R", 49 | "QFAEB", 50 | "8814W", 51 | "E4E49", 52 | "BC9HH", 53 | "2GV2C", 54 | "45XNV", 55 | "NQYBQ", 56 | "54XDF", 57 | "2XBMC", 58 | "1AP1R", 59 | "PR2QR", 60 | "PWK6G", 61 | "737CD", 62 | "5D0ZQ", 63 | "DG0J4", 64 | "6RST4", 65 | "TEP3Z", 66 | "P4WNP", 67 | "C2XWC", 68 | "R41AQ", 69 | "3QDGZ", 70 | "E41TK", 71 | "F7XJZ", 72 | "CVHBX", 73 | "ETBH3", 74 | "T42HZ", 75 | "V3GSG", 76 | "49BMG", 77 | "607ZP", 78 | "Z9MWM", 79 | "V8886", 80 | "X13Z9", 81 | "D84HQ", 82 | "33J4K", 83 | "V9DRX", 84 | "19T4J", 85 | "65V71", 86 | "CA21K", 87 | "EF851", 88 | "8ZDQX", 89 | "YCXFD", 90 | "P4DWS", 91 | "6KX5S", 92 | "48GQF", 93 | "WMD22", 94 | "2QZFT", 95 | "0DDFT", 96 | "70H4A", 97 | "Y52FZ", 98 | "QRP53", 99 | "T7KSB", 100 | "PBWTV", 101 | "C0F56", 102 | "8NE39", 103 | "ACWBS", 104 | "T6HBP", 105 | "YTBFX", 106 | "Z9BFN", 107 | "WG6QE", 108 | "94HEG", 109 | "TNRSZ", 110 | "VKH9W", 111 | "1TAAX", 112 | "SANSS", 113 | "N2026", 114 | "QECN5", 115 | "RWS8C", 116 | "B4H98", 117 | "DFTMK", 118 | "PK0RG", 119 | "WH9QG", 120 | "J8Z3P", 121 | "PXJQ6", 122 | "CRTFR", 123 | "3GPRK", 124 | "NGFAS", 125 | "XSZCV", 126 | "77D3D", 127 | "9HZ6K", 128 | "G1KQD", 129 | "WSCM8", 130 | "Q3GA6", 131 | "YAY8H", 132 | "00HRE", 133 | "1MEZV", 134 | "1X8Z7", 135 | "GC15T", 136 | "F1T24", 137 | "CS35V", 138 | "2R2Z0", 139 | "51E92", 140 | "0SZ49", 141 | "0RVFV", 142 | "TJE8W", 143 | "21DGC", 144 | "SQ2A4", 145 | "3ZXK7", 146 | "4X6J3", 147 | "P3DHR", 148 | "33RB3", 149 | "K61BA", 150 | "VR3JB", 151 | "HEZ5S", 152 | "EXZM9", 153 | "Q0X52", 154 | 
"HEEVC", 155 | "ZPZCK", 156 | "NP7ND", 157 | "073E7", 158 | "AAYVS", 159 | "KKXNF", 160 | "S5AN1", 161 | "EBB0Y" 162 | ] -------------------------------------------------------------------------------- /code/prompts.py: -------------------------------------------------------------------------------- 1 | cot_prompts = { 2 | "cot": "Let\'s think step by step.", 3 | } 4 | -------------------------------------------------------------------------------- /code/requirements.txt: -------------------------------------------------------------------------------- 1 | bitsandbytes==0.43.1 2 | datasets==2.19.2 3 | flash-attn==2.5.9.post1 4 | numpy==1.26.2 5 | openai==1.3.6 6 | peft==0.4.0 7 | spacy==3.7.2 8 | torch==2.3.0 9 | tqdm==4.66.2 10 | transformers==4.43.3 11 | -------------------------------------------------------------------------------- /code/run_generate_system.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | from src.utils import get_tom_prompt, mental_verb, mental_states, save_json, load_json 3 | 4 | sotopia_combo = load_json('sotopia/combo.json') 5 | sotopia_agent = {a['pk']: a for a in load_json('sotopia/agent.json')} 6 | sotopia_environment = {a['pk']: a for a in load_json('sotopia/environment.json')} 7 | 8 | p_naive_prompt = { 9 | 'Openness to Experience': 'an open', 10 | 'Conscientiousness': 'conscientious', 11 | 'Extraversion': 'extraversive', 12 | 'Agreeableness': 'agreeable', 13 | 'Neuroticism': 'neurotic', 14 | } 15 | p_naive_prompt_reversed = { 16 | 'Openness to Experience': 'a closed', 17 | 'Conscientiousness': 'unconscientious', 18 | 'Extraversion': 'introversive', 19 | 'Agreeableness': 'disagreeable', 20 | 'Neuroticism': 'stable', 21 | } 22 | 23 | 24 | def big_five_to_description(big_five, p_prompt_type='naive'): 25 | if p_prompt_type == 'naive': 26 | p_description = 'You are ' 27 | for i, p in enumerate(big_five.split(';')): 28 | key, value = p.strip().split(' - ') 29 | des_i = p_naive_prompt[key] if value == 'High' else p_naive_prompt_reversed[key] 30 | if i < 4: 31 | p_description += des_i + ', ' 32 | else: 33 | p_description += 'and ' + des_i + ' person.' 34 | else: 35 | raise NotImplementedError(p_prompt_type) 36 | return p_description 37 | 38 | def system_prompt_1st(agent, other, env, mental_state, i): 39 | pronoun = other["gender_pronoun"].lower().split("/") 40 | system = f'Your name is {agent["first_name"]} {agent["last_name"]}, a {agent["age"]}-year-old {agent["occupation"]}.\n' \ 41 | f'You are talking with {other["first_name"]} {other["last_name"]}, a {other["age"]}-year-old {other["occupation"]}.\n' \ 42 | f'The scenario of this conversation: {env["scenario"]}.\nYour goal: {env["agent_goals"][i]}\n' \ 43 | f'Your personality: {big_five_to_description(agent["big_five"], p_prompt_type="naive")}\n' \ 44 | f'Please have a conversation with {pronoun[1]} while thinking about your {mental_state} from ( to ) in one sentence.\n' \ 45 | 'Please generate different thoughts and utterances in different turns.\n' \ 46 | f'After thinking about your {mental_state} briefly, please finish your thought with ) and speak to {pronoun[1]} briefly in one or two sentences based on your thought.\n' \ 47 | 'Output your thought and utterance by strictly following this format: (your thought) "your utterance".' 
48 | return system 49 | 50 | def system_prompt_2nd(agent, other, env, mental_state, i): 51 | pronoun = other["gender_pronoun"].lower().split("/") 52 | system = f'Your name is {agent["first_name"]} {agent["last_name"]}, a {agent["age"]}-year-old {agent["occupation"]}.\n' \ 53 | f'You are talking with {other["first_name"]} {other["last_name"]}, a {other["age"]}-year-old {other["occupation"]}.\n' \ 54 | f'The scenario of this conversation: {env["scenario"]}.\nYour goal: {env["agent_goals"][i]}\n' \ 55 | f'Your personality: {big_five_to_description(agent["big_five"], p_prompt_type="naive")}\n' \ 56 | f'Please have a conversation with {pronoun[1]} while thinking about {other["first_name"]}\'s {mental_state} from ( to ) in one sentence.\n' \ 57 | 'Please generate different thoughts and utterances in different turns.\n' \ 58 | f'After thinking about {other["first_name"]}\'s {mental_state} briefly, please finish your thought with ) and speak to {pronoun[1]} briefly in one or two sentences based on your thought.\n' \ 59 | 'Output your thought and utterance by strictly following this format: (your thought) "your utterance".' 60 | return system 61 | 62 | def main(): 63 | Path('data/scenarios/').mkdir(exist_ok=True, parents=True) 64 | test_ids = load_json('data/test_ids.json') 65 | 66 | for i, combo in enumerate(sotopia_combo): 67 | combo_id = combo['pk'] 68 | if not combo_id[-5:] in test_ids: 69 | continue 70 | for mental_state in mental_states: 71 | scenario = {} 72 | scenario['sotopia'] = {} 73 | scenario['sotopia']['combo_id'] = combo_id 74 | scenario['sotopia']['env_id'] = combo['env_id'] 75 | scenario['sotopia']['agent_ids'] = combo['agent_ids'] 76 | 77 | env = sotopia_environment[combo['env_id']] 78 | agent1 = sotopia_agent[combo['agent_ids'][0]] 79 | agent2 = sotopia_agent[combo['agent_ids'][1]] 80 | pronoun1 = agent1["gender_pronoun"].lower().split("/") 81 | pronoun2 = agent2["gender_pronoun"].lower().split("/") 82 | 83 | system1 = system_prompt_1st(agent1, agent2, env, mental_state, 0) 84 | tom_order1 = 1 85 | 86 | system2 = system_prompt_2nd(agent2, agent1, env, mental_state, 1) 87 | tom_order2 = 2 88 | 89 | scenario['init_inst1'] = '() "Hi, how are you?"' 90 | scenario['init_inst2'] = '() "Hi!"' 91 | 92 | tom_prompt1 = get_tom_prompt(mental_state, tom_order1, pronoun=pronoun2[0]) 93 | tom_prompt2 = get_tom_prompt(mental_state, tom_order2, pronoun=pronoun1[0]) 94 | 95 | scenario['system1'] = system1 96 | scenario['system2'] = system2 97 | scenario['tom_prompt1'] = tom_prompt1 98 | scenario['tom_prompt2'] = tom_prompt2 99 | scenario_name = f'{mental_state}_{combo_id[-5:]}' 100 | save_json(scenario, f'data/scenarios/{scenario_name}.json') 101 | 102 | Path('assignments/').mkdir(parents=True, exist_ok=True) 103 | d = Path('assignments/') 104 | order = 1.5 105 | 106 | gpu_id = 0 107 | fs = {} 108 | fs[0] = open(d / 'gpu0.txt', 'w') 109 | # fs[1] = open(d / 'gpu1.txt', 'w') 110 | # fs[2] = open(d / 'gpu2.txt', 'w') 111 | # fs[3] = open(d / 'gpu3.txt', 'w') 112 | 113 | n = 0 114 | 115 | for i, combo in enumerate(sotopia_combo): 116 | for mental_state in mental_states: 117 | combo_id = combo['pk'] 118 | scenario_name = f'{mental_state}_{combo_id[-5:]}' 119 | fs[gpu_id].write(scenario_name + "\n") 120 | n += 1 121 | gpu_id = (gpu_id + 1) % len(fs) 122 | for i in fs: 123 | fs[i].close() 124 | print(n, 'scenarios in total') 125 | 126 | if __name__ == '__main__': 127 | main() 128 | -------------------------------------------------------------------------------- /code/run_generate_tomato.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import sys 4 | import json 5 | import random 6 | import datetime 7 | import base64 8 | import time 9 | import requests 10 | import traceback 11 | import argparse 12 | from collections import defaultdict 13 | from pathlib import Path 14 | 15 | import numpy as np 16 | from tqdm import tqdm 17 | from transformers import AutoTokenizer 18 | 19 | from src.utils import ( 20 | load_json, save_json, remove_tom, get_formatted_conv, is_match_inner_speech_format, separate_tom, 21 | mental_states, mental_verb, change_pronoun 22 | ) 23 | 24 | 25 | def ArgParser(): 26 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 27 | parser.add_argument("--model_path", default="meta-llama/Meta-Llama-3-70B-Instruct", type=str) 28 | parser.add_argument("--output_dir", default="data/tomato", type=str, help="") 29 | parser.add_argument("--overwrite", default=False, action="store_true", help="") 30 | parser.add_argument("--max_utterances", default=20, type=int) 31 | args = parser.parse_args() 32 | print(args) 33 | return args 34 | 35 | 36 | def convert_messages_to_examples(messages, scenario, scenario_id, sotopia_agent, tokenizer): 37 | agent_ids = scenario['sotopia']['agent_ids'] 38 | agents = [sotopia_agent[i] for i in agent_ids] 39 | agents = [f'{a["first_name"]} {a["last_name"]}' for a in agents] 40 | 41 | is_valid_format = True 42 | n_error = 0 43 | for message in messages['messages']: 44 | res = message[1] 45 | res = res.replace('\n', ' ') 46 | res = re.sub(r'[ ]+', ' ', res).strip() 47 | if not is_match_inner_speech_format(res): 48 | n_error += 1 49 | if n_error % 100 == 0: 50 | print(datetime.datetime.now(), f'Does not match Inner Speech format in {scenario_id}:', res) 51 | is_valid_format = False 52 | if not is_valid_format: 53 | return [] 54 | 55 | all_messages = [] 56 | all_tom = defaultdict(list) 57 | for message in messages['messages']: 58 | agent = message[0] 59 | m = message[1] 60 | m = m.replace('\n', ' ') 61 | m = re.sub(r'[ ]+', ' ', m).strip() 62 | tom, res = separate_tom(m) 63 | agent_id = agent_ids[0] if agent == 'Agent A' else agent_ids[1] 64 | all_messages.append([agent_id, tom, res]) 65 | all_tom[agent_id].append(tom) 66 | 67 | valid_agent_ids = [] 68 | for agent_id in all_tom: 69 | all_tom[agent_id] = list(set(all_tom[agent_id])) 70 | if len(all_tom[agent_id]) >= 4: 71 | valid_agent_ids.append(agent_id) 72 | else: 73 | n_error += 1 74 | if n_error % 100 == 0: 75 | print(datetime.datetime.now(), f'Num. 
of unique options is less than 4 in {scenario_id}:', len(all_tom[agent_id])) 76 | 77 | current_utterances = '' 78 | current_utterances += f'{agents[0]}: "Hi!"\n' 79 | current_utterances += f'{agents[1]}: "Hi, how are you?"\n' 80 | for j, (agent_id, tom, res) in enumerate(all_messages): 81 | agent = sotopia_agent[agent_id] 82 | agent_name = f'{agent["first_name"]} {agent["last_name"]}' 83 | current_utterances += f'{agent_name}: ' + res + '\n' 84 | 85 | mental_state = scenario_id.split('_')[1] 86 | verb = mental_verb[mental_state][0] 87 | third = mental_verb[mental_state][1] 88 | 89 | examples = [] 90 | for j, (agent_id, tom, res) in enumerate(all_messages): 91 | ex = {} 92 | agent = sotopia_agent[agent_id] 93 | agent_name = f'{agent["first_name"]} {agent["last_name"]}' 94 | other_id = agent_ids[(agent_ids.index(agent_id) + 1) % 2] 95 | other = sotopia_agent[other_id] 96 | other_name = f'{other["first_name"]} {other["last_name"]}' 97 | 98 | if not agent_id in valid_agent_ids: 99 | continue 100 | 101 | ## Incorrect option selection 102 | toms = all_tom[agent_id] 103 | incorr = random.sample(list(filter(lambda t: t != tom, toms)), 3) 104 | if not incorr: 105 | print('[IncorrectOptionError]', 'Correct:', tom, ', Candidates:', toms) 106 | continue 107 | corr_idx = random.sample([0, 1, 2, 3], 1)[0] 108 | options = incorr[:corr_idx] + [tom] + incorr[corr_idx:] 109 | options = [o[1:-1] for o in options] 110 | 111 | pronoun = sotopia_agent[agent_id]['gender_pronoun'].split('/')[0] 112 | pronoun_other = sotopia_agent[other_id]['gender_pronoun'].split('/')[0] 113 | if j % 2 == 0: 114 | l = len(f'I {verb}') 115 | options = [f'{pronoun} {third}' + o[l:] for o in options] 116 | q_order = 1 117 | elif j % 2 == 1: 118 | l = len(f'I think that {pronoun_other.lower()} {third}') 119 | options = [f'{pronoun} thinks that {pronoun_other.lower()} {third}' + o[l:] for o in options] 120 | q_order = 2 121 | options = [change_pronoun(o, pronoun.lower(), pronoun_other.lower()) for o in options] 122 | 123 | ex['a_idx'] = corr_idx 124 | ex['a_str'] = options[corr_idx] 125 | for i in range(4): 126 | ex[f'a{i}'] = options[i] 127 | q = generate_question(mental_state, verb, third, q_order, agent["first_name"], other["first_name"], 128 | res, pronoun) 129 | ex['q'] = q 130 | ex['q_id'] = f'{scenario_id}-{j}' 131 | ex['transcript'] = current_utterances 132 | ex['scenario_id'] = scenario_id 133 | examples.append(ex) 134 | return examples 135 | 136 | 137 | def generate_question(mental_state, verb, third, order, name1, name2, 138 | res, pronoun): 139 | assert order in [1, 2] 140 | pronoun = pronoun.lower() 141 | if mental_state == 'emotion': 142 | if order == 1: 143 | q = random.choice([ 144 | f'How does {name1} {verb} when {pronoun} says {res}?', 145 | f'When {name1} says {res}, how does {pronoun} {verb}?' 146 | ]) 147 | elif order == 2: 148 | q = random.choice([ 149 | f'How does {name1} think that {name2} {third} when {pronoun} says {res}?', 150 | f'When {name1} says {res}, how does {pronoun} think that {name2} {third}?' 151 | ]) 152 | elif mental_state == 'intention': 153 | if order == 1: 154 | q = random.choice([ 155 | f'What will {name1} do when {pronoun} says {res}?', 156 | f'When {name1} says {res}, what will {name1} do?' 157 | ]) 158 | elif order == 2: 159 | q = random.choice([ 160 | f'What does {name1} think that {name2} will do when {pronoun} says {res}?', 161 | f'When {name1} says {res}, what does {name1} think that {name2} will do?' 
162 | ]) 163 | else: 164 | if order == 1: 165 | q = random.choice([ 166 | f'What does {name1} {verb} when {pronoun} says {res}?', 167 | f'When {name1} says {res}, what does {name1} {verb}?' 168 | ]) 169 | elif order == 2: 170 | q = random.choice([ 171 | f'What does {name1} think that {name2} {third} when {pronoun} says {res}?', 172 | f'When {name1} says {res}, what does {name1} think that {name2} {third}?' 173 | ]) 174 | return q 175 | 176 | 177 | def main(): 178 | args = ArgParser() 179 | 180 | args.model_name = args.model_path.split('/')[1] 181 | output_dir = Path(args.output_dir) 182 | output_dir.mkdir(parents=True, exist_ok=True) 183 | output_file = output_dir / 'tomato.json' 184 | if not args.overwrite: 185 | if output_file.exists(): 186 | print(datetime.datetime.now(), str(output_file), 'already exists. Please use `--overwrite` option to overwrite the file.') 187 | sys.exit() 188 | 189 | data_dir = Path(f'data/conversations/{args.model_name}') 190 | data = list(data_dir.glob('*.json')) 191 | 192 | sotopia_agent = {a['pk']: a for a in load_json('sotopia/agent.json')} 193 | 194 | tokenizer = AutoTokenizer.from_pretrained(args.model_path, use_fast=False, padding_side='left') 195 | 196 | random.seed(42) 197 | 198 | all_examples = [] 199 | 200 | for i in tqdm(range(len(data))): 201 | scenario_id = data[i].stem 202 | messages = load_json(data[i]) 203 | scenario = load_json(Path('data/scenarios/') / (scenario_id + '.json')) 204 | messages['messages'] = messages['messages'][:args.max_utterances] 205 | examples = convert_messages_to_examples(messages, scenario, scenario_id, sotopia_agent, tokenizer) 206 | all_examples.extend(examples) 207 | 208 | print(datetime.datetime.now(), 'Num. of Examples:', len(all_examples)) 209 | save_json(all_examples, output_file) 210 | print('Saved:', str(output_file)) 211 | 212 | 213 | if __name__ == '__main__': 214 | main() 215 | -------------------------------------------------------------------------------- /code/run_gpt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import sys 4 | import json 5 | import datetime 6 | import base64 7 | import time 8 | import requests 9 | import traceback 10 | import argparse 11 | from collections import defaultdict 12 | 13 | import numpy as np 14 | from pathlib import Path 15 | import cv2 16 | import webvtt 17 | from tqdm import tqdm 18 | import openai 19 | from openai import OpenAI 20 | 21 | client = OpenAI( 22 | api_key=os.getenv("OPENAI_API_KEY"), 23 | organization=os.getenv("OPENAI_ORG_ID")) 24 | 25 | from src.utils import load_json, save_json 26 | from prompts import cot_prompts 27 | 28 | idx_to_option = { 29 | 0: "A", 30 | 1: "B", 31 | 2: "C", 32 | 3: "D", 33 | } 34 | 35 | 36 | def ArgParser(): 37 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 38 | parser.add_argument("--data_type", default="siq", type=str, help="") 39 | parser.add_argument("--output_dir", default="output", type=str, help="") 40 | parser.add_argument("--do_gen_qa", default=False, action="store_true", help="") 41 | parser.add_argument("--do_mc_qa", default=False, action="store_true", help="") 42 | parser.add_argument("--model", default=None, type=str) 43 | parser.add_argument("--exp_id", default="default", type=str) 44 | parser.add_argument("--test_file", default=None, type=str) 45 | parser.add_argument("--cot_prompt", default=None, type=str) 46 | parser.add_argument("--log_steps", default=10, type=int, help="") 47 | 
parser.add_argument("--seed", default=10, type=int, help="") 48 | parser.add_argument("--overwrite", default=False, action="store_true", help="") 49 | parser.add_argument("--debug", default=False, action="store_true", help="") 50 | args = parser.parse_args() 51 | print(args) 52 | return args 53 | 54 | def main(): 55 | args = ArgParser() 56 | 57 | # Load Dataset 58 | if args.data_type == 'tomato': 59 | output_dir = Path(args.output_dir) / args.exp_id 60 | output_dir.mkdir(parents=True, exist_ok=True) 61 | 62 | data_name = args.data_type + '_' + Path(args.test_file).stem 63 | output_name = f'predictions_on_{data_name}_.json' 64 | output_file = output_dir / output_name 65 | if output_file.exists(): 66 | if not args.overwrite: 67 | print(f'[WARNING] {str(output_file)} already exists.') 68 | sys.exit() 69 | results = load_json(output_file) 70 | save_json(results, output_dir / ('old_' + output_name)) 71 | else: 72 | results = {} 73 | elif args.data_type == 'siqa': 74 | output_dir = Path(args.output_dir) / args.exp_id 75 | output_dir.mkdir(parents=True, exist_ok=True) 76 | 77 | data_name = args.data_type + '_' + Path(args.test_file).stem 78 | output_name = f'predictions_on_{data_name}_.json' 79 | output_file = output_dir / output_name 80 | if output_file.exists(): 81 | if not args.overwrite: 82 | print(f'[WARNING] {str(output_file)} already exists.') 83 | sys.exit() 84 | results = load_json(output_file) 85 | save_json(results, output_dir / ('old_' + output_name)) 86 | else: 87 | results = {} 88 | 89 | data_id = args.data_type + '/' + Path(args.test_file).stem 90 | eval_examples = load_json(args.test_file) 91 | if args.debug: 92 | eval_examples = eval_examples[:20] 93 | num_options = 3 94 | 95 | # Multiple-choice QA 96 | if args.do_mc_qa: 97 | true_or_false = [] 98 | errors = [] 99 | 100 | step = 0 101 | for qa in tqdm(eval_examples): 102 | q_id = qa['q_id'] 103 | if q_id in results: 104 | result = results[q_id] 105 | else: 106 | result = {} 107 | 108 | q = qa['q'] 109 | a_str = qa['a_str'] 110 | a_idx = qa['a_idx'] 111 | 112 | result['a_str'] = a_str 113 | result['a_true'] = idx_to_option[a_idx] 114 | if 'scenario_id' in qa: 115 | result['scenario_id'] = qa['scenario_id'] 116 | 117 | transcripts = qa['transcript'] 118 | system_prompt = \ 119 | "You are an expert at understanding human communication. " \ 120 | "Please leverage the information provided and choose the most probable answer to the following question from the options. 
" \ 121 | "Output your final verdict by strictly following this format: [A], [B], [C], or [D]" 122 | instruction = "# Transcripts\n" + transcripts + "\n" \ 123 | "# Question\n" + q + "\n\n" 124 | 125 | if num_options >= 1: 126 | instruction += "# Options\n[A] " + qa['a0'] 127 | if num_options >= 2: 128 | instruction += "\n" + "[B] " + qa['a1'] 129 | if num_options >= 3: 130 | instruction += "\n" + "[C] " + qa['a2'] 131 | if num_options >= 4: 132 | instruction += "\n" + "[D] " + qa['a3'] 133 | if args.cot_prompt is not None: 134 | instruction += "\n\n" + cot_prompts[args.cot_prompt] 135 | PROMPT_MESSAGES = [ 136 | {"role": "system", "content": system_prompt}, 137 | {"role": "user", "content": instruction}, 138 | ] 139 | if step % args.log_steps == 0: 140 | print(PROMPT_MESSAGES) 141 | 142 | params = { 143 | "model": args.model, 144 | "messages": PROMPT_MESSAGES, 145 | "temperature": 0.6, 146 | "top_p": 0.9, 147 | "max_tokens": 500, 148 | "seed": args.seed, 149 | } 150 | if 'a_pred_gen' in result: 151 | print(datetime.datetime.now(), f'q_id: {q_id} has been chosen before, so it is skipped.') 152 | else: 153 | n_trial = 0 154 | max_trial = 10 155 | rest = 10 156 | while n_trial < max_trial: 157 | try: 158 | response = client.chat.completions.create(**params) 159 | a_mc = response.choices[0].message.content 160 | result['a_gen'] = a_mc 161 | if step % args.log_steps == 0: 162 | print(datetime.datetime.now(), f'Question {q_id} - Correct Option:') 163 | print(idx_to_option[a_idx]) 164 | print(datetime.datetime.now(), f'Question {q_id} - Multiple-choice QA:') 165 | print(a_mc) 166 | print(datetime.datetime.now(), f'Question {q_id} - Multiple-choice QA: Suceeded!') 167 | time.sleep(1) 168 | break 169 | except: 170 | traceback.print_exc() 171 | n_trial += 1 172 | print(datetime.datetime.now(), f'Question {q_id} - Multiple-choice QA: ... failed {n_trial} times. Let me try again after {rest} seconds ...') 173 | time.sleep(rest) 174 | if 'a_gen' in result: 175 | a_pred = extract_answer_from_response(result['a_gen']) 176 | result['a_pred_gen'] = a_pred 177 | true_or_false.append(result['a_true'] == a_pred) 178 | else: 179 | errors.append(q_id) 180 | 181 | step += 1 182 | results[q_id] = result 183 | save_json(results, output_file) 184 | 185 | save_json(results, output_file) 186 | print('Saved:', output_file) 187 | if args.do_mc_qa: 188 | print(datetime.datetime.now(), f'Num. 
of Errors: {len(errors)}') 189 | 190 | 191 | def extract_answer_from_response(response): 192 | a_pred = re.search(r'\[[A,B,C,D]\]', response) 193 | if a_pred is not None: 194 | s, e = a_pred.span() 195 | a_pred = response[s:e][1] 196 | return a_pred 197 | else: 198 | return '' 199 | 200 | 201 | if __name__ == '__main__': 202 | main() 203 | -------------------------------------------------------------------------------- /code/run_inner_speech.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import datetime 4 | import base64 5 | import time 6 | import requests 7 | import traceback 8 | import argparse 9 | from collections import defaultdict 10 | from pathlib import Path 11 | import random 12 | 13 | import numpy as np 14 | import torch 15 | from tqdm import tqdm 16 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig 17 | 18 | from src.nn import load_model_hf, step, multi_turn_conversation 19 | from src.utils import load_json, save_json, remove_tom, get_formatted_conv 20 | 21 | 22 | def ArgParser(): 23 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 24 | parser.add_argument("--exp_id", default="default", type=str) 25 | parser.add_argument("--scenario_dir", default="data/scenarios", type=str) 26 | parser.add_argument("--scenario_list", default=None, type=str) 27 | parser.add_argument("--data_dir", default="", type=str, help="") 28 | parser.add_argument("--output_dir", default="output", type=str, help="") 29 | parser.add_argument("--output_suffix", default="", type=str, help="") 30 | parser.add_argument("--n_turn", default=5, type=int, help="") 31 | parser.add_argument("--conv_mode", default="", type=str, help="") 32 | parser.add_argument("--log_steps", default=10, type=int, help="") 33 | parser.add_argument("--debug", default=False, action="store_true", help="") 34 | parser.add_argument("--eval_batch_size", default=8, type=int, help="") 35 | parser.add_argument("--model_path", default="meta-llama/Meta-Llama-3-70B-Instruct", type=str) 36 | parser.add_argument("--disable_inner_speech", default=False, action="store_true", help="") 37 | parser.add_argument("--keep_inner_speech", default=False, action="store_true", help="") 38 | parser.add_argument("--overwrite", default=False, action="store_true", help="") 39 | parser.add_argument("--load_8bit", default=False, action="store_true", help="") 40 | parser.add_argument("--load_4bit", default=False, action="store_true", help="") 41 | parser.add_argument("--do_sample", default=False, action="store_true", help="") 42 | parser.add_argument("--top_p", default=0.9, type=float, help="") 43 | parser.add_argument("--temperature", default=0.6, type=float, help="") 44 | parser.add_argument("--seed", default=42, type=int, help="") 45 | args = parser.parse_args() 46 | print(args) 47 | return args 48 | 49 | 50 | def set_seed(seed): 51 | torch.manual_seed(seed) 52 | 53 | 54 | def main(): 55 | args = ArgParser() 56 | 57 | print(datetime.datetime.now(), f'Started: exp_id = {args.exp_id}') 58 | 59 | set_seed(args.seed) 60 | 61 | if torch.cuda.is_available(): 62 | device = 'cuda' 63 | elif torch.backends.mps.is_available(): 64 | device = "mps" 65 | else: 66 | device = 'cpu' 67 | 68 | model_name = args.model_path.split('/')[1] 69 | output_dir = Path(args.output_dir) 70 | output_dir.mkdir(parents=True, exist_ok=True) 71 | 72 | model, tokenizer = load_model_hf(args.model_path, args.load_8bit, args.load_4bit, device=device) 
73 | 74 | if args.scenario_list is None: 75 | files = list(Path(args.scenario_dir).glob('*.json')) 76 | ids = [f.stem for f in files] 77 | else: 78 | with open(args.scenario_list, 'r') as f: 79 | ids = f.readlines() 80 | ids = [i.strip() for i in ids] 81 | 82 | random.shuffle(ids) 83 | for scenario_id in tqdm(ids): 84 | output_file = output_dir / f'{scenario_id}.json' 85 | if output_file.exists(): 86 | conv_config = load_json(output_file) 87 | if 'this_will_be_generated_by' in conv_config: 88 | if conv_config['this_will_be_generated_by'] != args.exp_id: 89 | print(output_file.name, 'will be generated by others, so it\'s skipped.') 90 | continue 91 | else: 92 | save_json({'this_will_be_generated_by': args.exp_id}, output_dir / f'{scenario_id}.json') 93 | else: 94 | save_json({'this_will_be_generated_by': args.exp_id}, output_dir / f'{scenario_id}.json') 95 | print(output_file.name, 'will be generated.') 96 | scenario = load_json(Path(args.scenario_dir) / f"{scenario_id}.json") 97 | 98 | conv_config = { 99 | "n_turn": args.n_turn, 100 | "agent1": args.model_path, 101 | "agent2": args.model_path, 102 | "init_inst1": scenario["init_inst1"], 103 | "init_inst2": scenario["init_inst2"], 104 | "system1": scenario["system1"], 105 | "system2": scenario["system2"], 106 | "tom_prompt1": scenario["tom_prompt1"], 107 | "tom_prompt2": scenario["tom_prompt2"], 108 | "inner_speech": not args.disable_inner_speech, 109 | "do_sample": args.do_sample, 110 | "top_p": args.top_p, 111 | "temperature": args.temperature, 112 | "keep_inner_speech": args.keep_inner_speech 113 | } 114 | 115 | conv1, conv2, messages = multi_turn_conversation( 116 | model, 117 | tokenizer, 118 | args.conv_mode, 119 | **conv_config) 120 | conv_config["agent1"] = conv1.__dict__ 121 | conv_config["agent2"] = conv2.__dict__ 122 | conv_config["agent1"]["sep_style"] = conv_config["agent1"]["sep_style"].value 123 | conv_config["agent2"]["sep_style"] = conv_config["agent2"]["sep_style"].value 124 | conv_config.update(args.__dict__) 125 | conv_config["messages"] = messages 126 | if "sotopia" in scenario: 127 | conv_config["sotopia"] = scenario["sotopia"] 128 | 129 | save_json(conv_config, output_dir / f'{scenario_id}.json') 130 | 131 | print('Finished!') 132 | 133 | if __name__ == '__main__': 134 | main() 135 | -------------------------------------------------------------------------------- /code/run_local_llm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import json 4 | import random 5 | import datetime 6 | import base64 7 | import time 8 | import requests 9 | import traceback 10 | import argparse 11 | from collections import defaultdict 12 | 13 | import numpy as np 14 | import torch 15 | from torch.nn import CrossEntropyLoss 16 | from transformers import ( 17 | AutoModelForCausalLM, 18 | AutoTokenizer, 19 | BitsAndBytesConfig, 20 | HfArgumentParser, 21 | TrainingArguments, 22 | pipeline, 23 | logging, 24 | ) 25 | from peft import ( 26 | prepare_model_for_kbit_training, 27 | LoraConfig, 28 | PeftModel, 29 | PeftConfig, 30 | get_peft_model 31 | ) 32 | from datasets import Dataset, load_dataset 33 | 34 | from pathlib import Path 35 | from tqdm import tqdm 36 | 37 | from src.utils import load_json, save_json, mental_verb 38 | from src.nn import load_model_hf 39 | from prompts import cot_prompts 40 | 41 | 42 | idx_to_option = { 43 | 0: "A", 44 | 1: "B", 45 | 2: "C", 46 | 3: "D", 47 | } 48 | 49 | data_to_num_options = { 50 | 'tomato': 4, 51 | 'siqa': 3 52 | } 53 | 54 | 
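# NOTE: the module-level loads below assume that the SOTOPIA dumps created in step 0
# of the README (data/sotopia/agent.json, combo.json, environment.json) are present.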
sotopia_agent = {a['pk']: a for a in load_json('data/sotopia/agent.json')} 55 | sotopia_combo = {a['pk'][-5:]: a for a in load_json('data/sotopia/combo.json')} 56 | sotopia_environment = {a['pk']: a for a in load_json('data/sotopia/environment.json')} 57 | 58 | 59 | def ArgParser(): 60 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 61 | parser.add_argument("--method", default="", type=str) 62 | parser.add_argument("--data_type", default="siq", type=str, help="") 63 | parser.add_argument("--test_files", default=[], nargs='*', type=str, help="") 64 | parser.add_argument("--exclude_options_from_prompt", default=False, action="store_true", help="") 65 | parser.add_argument("--cot_prompt", default=None, type=str) 66 | parser.add_argument("--output_dir", default="output", type=str, help="") 67 | parser.add_argument("--output_suffix", default="", type=str, help="") 68 | parser.add_argument("--do_gen_qa", default=False, action="store_true", help="") 69 | parser.add_argument("--do_mc_qa", default=False, action="store_true", help="") 70 | parser.add_argument("--do_sample", default=False, action="store_true", help="") 71 | parser.add_argument("--top_p", default=0.9, type=float, help="") 72 | parser.add_argument("--temperature", default=0.6, type=float, help="") 73 | parser.add_argument("--exp_id", default="default", type=str) 74 | parser.add_argument("--save_steps", default=500, type=int, help="") 75 | parser.add_argument("--log_steps", default=100, type=int, help="") 76 | parser.add_argument("--debug", default=False, action="store_true", help="") 77 | parser.add_argument("--eval_batch_size", default=8, type=int, help="") 78 | parser.add_argument("--model_path", default="", type=str) 79 | parser.add_argument("--overwrite", default=False, action="store_true", help="") 80 | parser.add_argument("--seed", default=42, type=int, help="") 81 | parser.add_argument("--load_8bit", default=False, action="store_true", help="") 82 | parser.add_argument("--load_4bit", default=False, action="store_true", help="") 83 | parser.add_argument("--load_bf16", default=False, action="store_true", help="") 84 | args = parser.parse_args() 85 | print(args) 86 | return args 87 | 88 | 89 | def set_seed(seed): 90 | rnd = random.Random() 91 | rnd.seed(seed) 92 | torch.manual_seed(seed) 93 | np.random.seed(seed) 94 | return rnd 95 | 96 | 97 | def get_model(model_path, load_8bit, load_4bit, load_bf16): 98 | device = 'cuda' 99 | if (Path(model_path) / 'adapter_config.json').exists(): 100 | load_adapter = True 101 | else: 102 | load_adapter = False 103 | 104 | if load_adapter: 105 | print(datetime.datetime.now(), f'Loading Peft Model from {model_path} ...') 106 | peft_config = PeftConfig.from_pretrained(model_path) 107 | model, tokenizer = load_model_hf(peft_config.base_model_name_or_path, load_8bit, load_4bit, load_bf16, device=device) 108 | model = PeftModel.from_pretrained(model, model_path, device_map=device) 109 | processor = context_len = None 110 | else: 111 | print(datetime.datetime.now(), f'Loading Huggingface Model from {model_path} ...') 112 | model, tokenizer = load_model_hf(model_path, load_8bit, load_4bit, load_bf16, device=device) 113 | processor = context_len = None 114 | return tokenizer, model, processor, context_len 115 | 116 | 117 | def compute_clm_loss(logits, labels, reduction, pad_id=0): 118 | shift_logits = logits[..., :-1, :].contiguous() 119 | shift_labels = labels[..., 1:].contiguous() 120 | # Flatten the tokens 121 | vocab_size = logits.size(-1) 122 | loss_fct = 
CrossEntropyLoss(reduction=reduction) 123 | shift_logits = shift_logits.view(-1, vocab_size) 124 | shift_labels = shift_labels.view(-1) 125 | # Enable model/pipeline parallelism 126 | shift_labels = shift_labels.to(shift_logits.device) 127 | loss = loss_fct(shift_logits, shift_labels) 128 | loss = loss * (shift_labels != pad_id) 129 | return loss 130 | 131 | 132 | def generate_prompt_qa( 133 | args, 134 | b, 135 | num_options, 136 | tokenizer, 137 | qa_mode='mcqa', 138 | append_option=False, 139 | option_j=None): 140 | assert qa_mode in ['mcqa', 'genqa'] 141 | context_name = 'Transcript' if args.data_type != 'siqa' else 'Context' 142 | inp = '' 143 | inp += f'# {context_name} \n' + b['transcript'] + '\n\n' 144 | inp += '# Question \n' + b['q'] 145 | if args.exclude_options_from_prompt: 146 | if append_option: 147 | opt = b[f'a{option_j}'] 148 | system_prompt = 'You are an expert at understanding human communication. ' \ 149 | 'Please leverage the information provided and generate an answer in one sentence to the question.' 150 | else: 151 | inp += '\n\n' + '# Options \n' 152 | inp += '[A] ' + b['a0'] + '\n' 153 | if num_options >= 2: 154 | inp += '[B] ' + b['a1'] + '\n' 155 | if num_options >= 3: 156 | inp += '[C] ' + b['a2'] + '\n' 157 | if num_options >= 4: 158 | inp += '[D] ' + b['a3'] + '\n' 159 | if num_options >= 5: 160 | raise NotImplementedError(f'`num_options = {num_options}` is not supported yet') 161 | if append_option: 162 | opt = '[' + idx_to_option[option_j] + ']' 163 | system_prompt = 'You are an expert at understanding human communication. ' \ 164 | 'Please leverage the information provided and choose the most probable answer to the question from the options. ' \ 165 | 'Output your final answer by strictly following this format: [A], [B], [C], or [D]' 166 | if args.do_disable_system_prompt: 167 | # For llms that do not support system prompts, such as Mistral 168 | chat = [ 169 | {"role": "user", "content": system_prompt + '\n\n' + inp} 170 | ] 171 | else: 172 | chat = [ 173 | {"role": "system", "content": system_prompt}, 174 | {"role": "user", "content": inp} 175 | ] 176 | if qa_mode == 'mcqa' and append_option: # for evaluation 177 | chat.append({"role": "assistant", "content": opt}) 178 | if qa_mode == 'genqa' and args.cot_prompt is not None: 179 | res_prefix = cot_prompts[args.cot_prompt] 180 | chat.append({"role": "assistant", "content": res_prefix}) 181 | if not 'google/flan-t5' in args.model_path: 182 | if 'gemma' in args.model_path: 183 | prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=args.cot_prompt is None) 184 | else: 185 | prompt = tokenizer.apply_chat_template(chat, tokenize=False) 186 | else: 187 | # For LLMs Without Multi-Turn Conversation Ability 188 | prompt = "\n\n".join([c["content"] for c in chat]) 189 | if qa_mode == 'genqa' and args.cot_prompt is not None: 190 | if 'gemma' in args.model_path: 191 | if prompt.endswith("\n"): 192 | prompt = prompt[:-len("\n")].strip() 193 | else: 194 | if prompt.endswith(tokenizer.eos_token): 195 | prompt = prompt[:-len(tokenizer.eos_token)].strip() 196 | return prompt 197 | 198 | 199 | def evaluate(args, model, processor, tokenizer, eval_examples, data_name, output_file): 200 | 201 | batch_size = args.eval_batch_size 202 | tokenizer.padding_side = "right" 203 | tokenizer.truncation_side = "right" 204 | if tokenizer.pad_token is None: 205 | tokenizer.pad_token = tokenizer.eos_token 206 | if hasattr(model.config, "max_position_embeddings"): 207 | tokenizer.max_length = 
model.config.max_position_embeddings 208 | else: 209 | tokenizer.max_length = 8192 210 | 211 | if 'Llama-3' in args.model_path: 212 | conv_mode = "llama_3" 213 | else: 214 | conv_mode = "default" 215 | print(datetime.datetime.now(), f'`conv_mode` is set to {conv_mode}') 216 | 217 | n_examples = len(eval_examples) 218 | n_batches = n_examples // batch_size + int(n_examples % batch_size != 0) 219 | all_outputs = {} 220 | 221 | num_options = data_to_num_options[args.data_type] 222 | 223 | for i in tqdm(range(n_batches)): 224 | batch = eval_examples[i*batch_size:(i+1)*batch_size] 225 | 226 | if args.do_mc_qa: 227 | lls = [] 228 | for j in range(num_options): # iterate for num of options 229 | inputs_text = [] 230 | for b in batch: 231 | prompt = generate_prompt_qa( 232 | args, b, num_options, tokenizer, qa_mode='mcqa', append_option=True, option_j=j) 233 | inputs_text.append(prompt) 234 | 235 | inputs = tokenizer(inputs_text, add_special_tokens=False, padding=True, truncation=True, return_tensors='pt') 236 | inputs = {k: inputs[k].cuda() for k in inputs} 237 | with torch.inference_mode(): 238 | outputs = model(**inputs) 239 | lls.append(outputs.loss.view(len(batch), -1).sum(1, keepdim=True).cpu()) 240 | lls = torch.cat(lls, dim=1) 241 | if i % args.log_steps == 0: 242 | print(datetime.datetime.now()) 243 | print('Input', json.dumps(batch[0], indent=4)) 244 | print('Decoded input', tokenizer.decode(inputs['input_ids'][0])) 245 | print('Input IDs', inputs['input_ids'][0]) 246 | print('Output (NLL)', lls[0]) 247 | for b, ll in zip(batch, lls): 248 | ll = ll.cpu().numpy().tolist() 249 | all_outputs[b['q_id']] = { 250 | 'll': ll, 251 | 'pred': idx_to_option[int(np.argmin(ll))] 252 | } 253 | 254 | if args.do_gen_qa: 255 | if 'Llama-2' in args.model_path: 256 | response_template = "[/INST]" 257 | tokenizer.padding_side = "left" 258 | elif 'Llama-3' in args.model_path: 259 | response_template = "<|start_header_id|>assistant<|end_header_id|>" 260 | tokenizer.padding_side = "left" 261 | elif 'mistralai' in args.model_path: 262 | response_template = "[/INST]" 263 | tokenizer.padding_side = "left" 264 | elif 'flan-t5' in args.model_path: 265 | response_template = "" 266 | tokenizer.padding_side = "left" 267 | elif 'gemma' in args.model_path: 268 | response_template = "model" 269 | tokenizer.padding_side = "left" 270 | else: 271 | raise NotImplementedError(args.model_path) 272 | 273 | inputs_text = [] 274 | for b in batch: 275 | prompt = generate_prompt_qa( 276 | args, b, num_options, tokenizer, qa_mode='genqa') 277 | inputs_text.append(prompt) 278 | inputs = tokenizer(inputs_text, add_special_tokens=False, padding=True, truncation=True, return_tensors='pt') 279 | inputs = {k: inputs[k].cuda() for k in inputs} 280 | inputs["max_new_tokens"] = 256 281 | if args.do_sample: 282 | inputs["do_sample"] = True 283 | inputs["top_p"] = args.top_p 284 | inputs["temperature"] = args.temperature 285 | with torch.inference_mode(): 286 | if conv_mode == 'llama_3': 287 | outputs = model.generate( 288 | **inputs, 289 | eos_token_id=[ 290 | tokenizer.eos_token_id, 291 | tokenizer.convert_tokens_to_ids("<|eot_id|>") 292 | ], 293 | pad_token_id=tokenizer.eos_token_id 294 | ) 295 | else: 296 | outputs = model.generate(**inputs) 297 | decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=False) 298 | if i % args.log_steps == 0: 299 | print(datetime.datetime.now()) 300 | print('Input', json.dumps(batch[0], indent=4)) 301 | print('Decoded input', tokenizer.decode(inputs['input_ids'][0])) 302 | print('Input IDs', 
inputs['input_ids'][0]) 303 | print('Decoded output', decoded_outputs[0]) 304 | for b, o in zip(batch, decoded_outputs): 305 | if not b['q_id'] in all_outputs: 306 | all_outputs[b['q_id']] = {} 307 | if response_template != "": 308 | o = o.split(response_template)[-1].strip() 309 | all_outputs[b['q_id']]['gen'] = o 310 | 311 | if output_file.exists(): 312 | results = load_json(output_file) 313 | print(datetime.datetime.now(), f'{str(output_file)} is loaded, and will be overwritten.') 314 | else: 315 | results = {} 316 | true_or_false = [] 317 | true_or_false_gen = [] 318 | for e in eval_examples: 319 | q_id = e['q_id'] 320 | a_corr = e['a_str'] 321 | a_idx = e['a_idx'] 322 | if q_id in results: 323 | result = results[q_id] 324 | else: 325 | result = {} 326 | result['q_id'] = q_id 327 | result['a_str'] = a_corr 328 | result['a_true'] = idx_to_option[a_idx] 329 | if 'pred' in all_outputs[q_id]: 330 | result['a_pred'] = all_outputs[q_id]['pred'] 331 | true_or_false.append(idx_to_option[a_idx] == all_outputs[q_id]['pred']) 332 | if 'll' in all_outputs[q_id]: 333 | result['a_score'] = all_outputs[q_id]['ll'] 334 | if 'gen' in all_outputs[q_id]: 335 | result['a_gen'] = all_outputs[q_id]['gen'] 336 | a_pred_gen = re.search(r'\[[A,B,C,D]\]', all_outputs[q_id]['gen']) 337 | if a_pred_gen is not None: 338 | start, end = a_pred_gen.span() 339 | a_pred_gen = all_outputs[q_id]['gen'][start:end][1] 340 | else: 341 | a_pred_gen = '' 342 | result['a_pred_gen'] = a_pred_gen 343 | true_or_false_gen.append(idx_to_option[a_idx] == a_pred_gen) 344 | if 'scenario_id' in e: 345 | result['scenario_id'] = e['scenario_id'] 346 | results[q_id] = result 347 | if args.do_mc_qa: 348 | acc = float(np.mean(true_or_false)) 349 | print('Accuracy:', acc) 350 | if args.do_gen_qa: 351 | acc = float(np.mean(true_or_false_gen)) 352 | print('Accuracy:', acc) 353 | save_json(results, output_file) 354 | print(datetime.datetime.now(), f'Saved {str(output_file)}.') 355 | return results 356 | 357 | 358 | def main(): 359 | args = ArgParser() 360 | 361 | rnd = set_seed(args.seed) 362 | args.rnd = rnd 363 | 364 | # Disable system prompt for some llms 365 | args.do_disable_system_prompt = ('mistralai' in args.model_path or 'gemma' in args.model_path) 366 | 367 | output_dir = Path(args.output_dir) / args.exp_id 368 | output_dir.mkdir(parents=True, exist_ok=True) 369 | 370 | for test_file in args.test_files: 371 | data_name = args.data_type + '_' + Path(test_file).stem # used for output file name 372 | output_name = f'predictions_on_{data_name}_{args.output_suffix}.json' 373 | output_file = output_dir / output_name 374 | if output_file.exists(): 375 | print(datetime.datetime.now(), f'[WARNING] {str(output_file)} already exists.') 376 | if not args.overwrite: 377 | return None 378 | else: 379 | save_json({'CreatedAt': str(datetime.datetime.now())}, output_file) 380 | eval_examples = load_json(test_file) 381 | 382 | ## Load Model 383 | tokenizer, model, processor, context_len = get_model( 384 | args.model_path, args.load_8bit, args.load_4bit, args.load_bf16) 385 | model.config.use_cache = True 386 | 387 | if args.debug: 388 | eval_examples = eval_examples[:8] 389 | print(datetime.datetime.now(), f'{len(eval_examples)} examples will be evaluated.') 390 | results = evaluate(args, model, processor, tokenizer, eval_examples, data_name, output_file) 391 | 392 | 393 | if __name__ == '__main__': 394 | main() 395 | -------------------------------------------------------------------------------- /code/src/nn.py: 
--------------------------------------------------------------------------------
 1 | import torch
 2 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoConfig, BitsAndBytesConfig
 3 | from src.utils import remove_tom
 4 | 
 5 | 
 6 | def load_model_hf(model_path, load_8bit, load_4bit, load_bf16, device='auto'):
 7 |     kwargs = {"device_map": "auto"}
 8 |     if load_8bit:
 9 |         kwargs['load_in_8bit'] = True
10 |     elif load_4bit:
11 |         kwargs['quantization_config'] = BitsAndBytesConfig(
12 |             load_in_4bit=True,
13 |             bnb_4bit_compute_dtype=torch.float16,
14 |             bnb_4bit_use_double_quant=True,
15 |             bnb_4bit_quant_type='nf4'
16 |         )
17 |     elif load_bf16:
18 |         kwargs['torch_dtype'] = torch.bfloat16
19 |     else:
20 |         kwargs['torch_dtype'] = torch.float16
21 | 
22 |     tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, padding_side='left')
23 |     tokenizer.pad_token = tokenizer.eos_token
24 |     if 'google/flan-t5' in model_path:
25 |         model = AutoModelForSeq2SeqLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
26 |     else:
27 |         model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
28 |     return model, tokenizer
29 | 
30 | 
31 | def step(model, tokenizer, conv_mode, prompts, do_sample=False, top_p=0.9, temperature=0.6):
32 |     inputs = tokenizer(prompts, add_special_tokens=True, return_tensors='pt', padding=True)
33 |     for k in inputs:
34 |         inputs[k] = inputs[k].to(model.device)
35 |     inputs["max_new_tokens"] = 256
36 |     if do_sample:
37 |         inputs["do_sample"] = True
38 |         inputs["top_p"] = top_p
39 |         inputs["temperature"] = temperature
40 |     with torch.inference_mode():
41 |         if conv_mode == 'llama_3':
42 |             outputs = model.generate(
43 |                 **inputs,
44 |                 eos_token_id=[
45 |                     tokenizer.eos_token_id,
46 |                     tokenizer.convert_tokens_to_ids("<|eot_id|>")
47 |                 ],
48 |                 pad_token_id=tokenizer.eos_token_id
49 |             )
50 |         else:
51 |             outputs = model.generate(**inputs)
52 |     decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=False)
53 |     responses = []
54 |     for out in decoded_outputs:
55 |         if conv_mode == 'llama_2':
56 |             res = out.split('[/INST]')[-1].strip().replace("</s>", "")  # strip the Llama-2 EOS token
57 |         elif conv_mode == 'llama_3':
58 |             res = out.split('<|start_header_id|>assistant<|end_header_id|>\n\n')[-1].strip().replace("<|eot_id|>", "")
59 |         responses.append(res)
60 |     return responses
61 | 
62 | 
63 | def multi_turn_conversation(
64 |         model,
65 |         tokenizer,
66 |         conv_mode,
67 |         n_turn=0,
68 |         agent1="",
69 |         agent2="",
70 |         system1="",
71 |         system2="",
72 |         init_inst1="",
73 |         init_inst2="",
74 |         tom_prompt1="",
75 |         tom_prompt2="",
76 |         inner_speech=False,
77 |         do_sample=False,
78 |         top_p=0.9,
79 |         temperature=0.6,
80 |         keep_inner_speech=False):
81 |     conv1 = []
82 |     conv2 = []
83 | 
84 |     conv1.append({'role': 'system', 'content': system1})
85 |     conv2.append({'role': 'system', 'content': system2})
86 | 
87 |     if inner_speech and (not keep_inner_speech):
88 |         init_inst2 = remove_tom(init_inst2)
89 |     conv2.append({'role': 'user', 'content': init_inst2})
90 |     conv2.append({'role': 'assistant', 'content': init_inst1})
91 |     if inner_speech and (not keep_inner_speech):
92 |         init_inst1 = remove_tom(init_inst1)
93 |     conv1.append({'role': 'user', 'content': init_inst1})
94 |     messages = []
95 | 
96 |     for i in range(n_turn):
97 |         prompt1 = tokenizer.apply_chat_template(conv1, tokenize=False)
98 |         if inner_speech:
99 |             prompt1 += tom_prompt1
100 |         res1 = step(model, tokenizer, conv_mode, [prompt1], do_sample=do_sample, top_p=top_p, temperature=temperature)[0]
101 | 
conv1.append({'role': 'assistant', 'content': res1}) 102 | messages.append(['Agent A', res1]) 103 | if inner_speech and (not keep_inner_speech): 104 | res1 = remove_tom(res1) 105 | conv2.append({'role': 'user', 'content': res1}) 106 | prompt2 = tokenizer.apply_chat_template(conv2, tokenize=False) 107 | if inner_speech: 108 | prompt2 += tom_prompt2 109 | res2 = step(model, tokenizer, conv_mode, [prompt2], do_sample=do_sample, top_p=top_p, temperature=temperature)[0] 110 | conv2.append({'role': 'assistant', 'content': res2}) 111 | messages.append(['Agent B', res2]) 112 | if inner_speech and (not keep_inner_speech): 113 | res2 = remove_tom(res2) 114 | conv1.append({'role': 'user', 'content': res2}) 115 | return conv1, conv2, messages 116 | -------------------------------------------------------------------------------- /code/src/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import spacy 3 | nlp = spacy.load('en_core_web_sm') 4 | import json 5 | 6 | def load_json(file): 7 | with open(file, 'r') as f: 8 | return json.load(f) 9 | 10 | def save_json(data, file): 11 | with open(file, 'w') as f: 12 | json.dump(data, f, indent=4) 13 | 14 | def remove_tom(res): 15 | match = re.findall(r'\(.*\)', res) 16 | for m in match: 17 | res = res.replace(m, '') 18 | return res.strip() 19 | 20 | def get_formatted_conv(messages): 21 | out = '' 22 | for m in messages: 23 | out += '> ' + m[0] + '\n' 24 | out += m[1] + '\n\n' 25 | return out 26 | 27 | def is_match_inner_speech_format(res): 28 | match = re.fullmatch(r'\([^\(\)]*\) "[^"\(\)]*"', res) 29 | if match is not None: 30 | return True 31 | else: 32 | return False 33 | 34 | def separate_tom(res): 35 | match = re.search(r'\(.*\)', res) 36 | s, e = match.span() 37 | tom = res[s:e].strip() 38 | res = res.replace(tom, '').strip() 39 | return tom, res 40 | 41 | mental_states = ["emotion", "belief", "intention", "desire", "knowledge"] 42 | mental_verb = { 43 | "emotion": ["feel", "feels", "felt"], 44 | "belief": ["think", "thinks", "thought"], 45 | "intention": ["will", "will", "would"], 46 | "desire": ["want", "wants", "wanted"], 47 | "knowledge": ["know", "knows", "knew"], 48 | } 49 | 50 | def get_tom_prompt(mental_state, order, pronoun=None): 51 | if order == 1: 52 | return f'(I {mental_verb[mental_state][0]}' 53 | elif order == 2: 54 | assert pronoun is not None 55 | return f'(I think that {pronoun} {mental_verb[mental_state][1]}' 56 | elif order == 0: 57 | return '' 58 | else: 59 | raise NotImplementedError(order) 60 | 61 | def change_pronoun(input, pronoun, pronoun_other): 62 | assert pronoun in ['he', 'she', 'they'] 63 | assert pronoun_other in ['he', 'she', 'they'] 64 | if pronoun == 'he': 65 | dict = {'I': 'he', 'me': 'him', 'my': 'his', 'mine': 'his', 'myself': 'himself', 66 | 'we': 'they', 'our': 'their', 'us': 'them', 'ours': 'theirs'} 67 | replace_dict = {'I\'m': 'he\'s', 'I\'ve': 'he has', 'I\'ll': 'he will', 'I\'d': 'he would'} 68 | elif pronoun == 'she': 69 | dict = {'I': 'she', 'me': 'her', 'my': 'her', 'mine': 'hers', 'myself': 'herself', 70 | 'we': 'they', 'our': 'their', 'us': 'them', 'ours': 'theirs'} 71 | replace_dict = {'I\'m': 'she\'s', 'I\'ve': 'she has', 'I\'ll': 'she will', 'I\'d': 'she would'} 72 | elif pronoun == 'they': 73 | dict = {'I': 'they', 'me': 'them', 'my': 'their', 'mine': 'theirs', 'myself': 'themselves', 'I\'m': 'they\'re', 74 | 'we': 'they', 'our': 'their', 'us': 'them', 'ours': 'theirs'} 75 | replace_dict = {'I\'m': 'they\'re', 'I\'ve': 'they have', 'I\'ll': 
'they will', 'I\'d': 'they would'} 76 | 77 | if pronoun_other == 'he': 78 | dict.update({'you': 'he', 'you': 'him', 'your': 'his', 'yours': 'his', 'yourself': 'himself'}) 79 | replace_dict.update({'you\'re': 'he\'s', 'you\'ve': 'he has', 'you\'ll': 'he will'}) 80 | elif pronoun_other == 'she': 81 | dict.update({'you': 'she', 'you': 'her', 'your': 'her', 'yours': 'hers', 'yourself': 'herself'}) 82 | replace_dict.update({'you\'re': 'she\'s', 'you\'ve': 'she has', 'you\'ll': 'she will'}) 83 | elif pronoun_other == 'they': 84 | dict.update({'you': 'they', 'you': 'them', 'your': 'their', 'yours': 'theirs', 'yourself': 'themselves'}) 85 | replace_dict.update({'you\'re': 'they\'re', 'you\'ve': 'they have', 'you\'ll': 'they will'}) 86 | 87 | for k, v in replace_dict.items(): 88 | input = input.replace(k, v) 89 | out = [] 90 | doc = nlp(input) 91 | for sent in doc.sents: 92 | for tok in sent: 93 | out.append(dict.get(tok.text, tok.text)) 94 | out.append(tok.whitespace_) 95 | break 96 | return "".join(out) 97 | -------------------------------------------------------------------------------- /dataset/LICENSE: -------------------------------------------------------------------------------- 1 | META LLAMA 3 COMMUNITY LICENSE AGREEMENT 2 | Meta Llama 3 Version Release Date: April 18, 2024 3 | 4 | “Agreement” means the terms and conditions for use, reproduction, distribution and modification of the Llama Materials set forth herein. 5 | “Documentation” means the specifications, manuals and documentation accompanying Meta Llama 3 distributed by Meta at https://llama.meta.com/get-started/. 6 | “Licensee” or “you” means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf. 7 | “MetaLlama 3” means the foundational large language models and software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing distributed by Meta at https://llama.meta.com/llama-downloads. 8 | “Llama Materials” means, collectively, Meta’s proprietary Meta Llama 3 and Documentation (and any portion thereof) made available under this Agreement. 9 | “Meta” or “we” means Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) and Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland). 10 | By clicking “I Accept” below or by using or distributing any portion or element of the Llama Materials, you agree to be bound by this Agreement. 11 | 12 | 1. License Rights and Redistribution. 13 | a. Grant of Rights. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Meta’s intellectual property or other rights owned by Meta embodied in the Llama Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the Llama Materials. 14 | b. Redistribution and Use. 15 | i. 
If you distribute or make available the Llama Materials (or any derivative works thereof), or a product or service that uses any of them, including another AI model, you shall (A) provide a copy of this Agreement with any such Llama Materials; and (B) prominently display “Built with Meta Llama 3” on a related website, user interface, blogpost, about page, or product documentation. If you use the Llama Materials to create, train, fine tune, or otherwise improve an AI model, which is distributed or made available, you shall also include “Llama 3” at the beginning of any such AI model name. 16 | 17 | ii. If you receive Llama Materials, or any derivative works thereof, from a Licensee as part of an integrated end user product, then Section 2 of this Agreement will not apply to you. 18 | 19 | iii. You must retain in all copies of the Llama Materials that you distribute the following attribution notice within a “Notice” text file distributed as a part of such copies: “Meta Llama 3 is licensed under the Meta Llama 3 Community License, Copyright © Meta Platforms, Inc. All Rights Reserved.” 20 | 21 | iv. Your use of the Llama Materials must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Llama Materials (available at https://llama.meta.com/llama3/use-policy), which is hereby incorporated by reference into this Agreement. 22 | v. You will not use the Llama Materials or any output or results of the Llama Materials to improve any other large language model (excluding Meta Llama 3 or derivative works thereof). 23 | 24 | 2. Additional Commercial Terms. If, on the Meta Llama 3 version release date, the monthly active users of the products or services made available by or for Licensee, or Licensee’s affiliates, is greater than 700 million monthly active users in the preceding calendar month, you must request a license from Meta, which Meta may grant to you in its sole discretion, and you are not authorized to exercise any of the rights under this Agreement unless or until Meta otherwise expressly grants you such rights. 25 | 3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE LLAMA MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND META DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE LLAMA MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE LLAMA MATERIALS AND ANY OUTPUT AND RESULTS. 26 | 4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING. 27 | 5. Intellectual Property. 28 | a. No trademark licenses are granted under this Agreement, and in connection with the Llama Materials, neither Meta nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the Llama Materials or as set forth in this Section 5(a). 
Meta hereby grants you a license to use “Llama 3” (the “Mark”) solely as required to comply with the last sentence of Section 1.b.i. You will comply with Meta’s brand guidelines (currently accessible at https://about.meta.com/brand/resources/meta/company-brand/). All goodwill arising out of your use of the Mark will inure to the benefit of Meta. 29 | b. Subject to Meta’s ownership of Llama Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the Llama Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications. 30 | 31 | c. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Llama Materials or Meta Llama 3 outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the Llama Materials. 32 | 33 | 6. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Llama Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the Llama Materials. Sections 3, 4 and 7 shall survive the termination of this Agreement. 34 | 7. Governing Law and Jurisdiction. This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement. -------------------------------------------------------------------------------- /dataset/README.md: -------------------------------------------------------------------------------- 1 | # ToMATO Benchmark 2 | This directory contains the ToMATO benchmark, which we presented in our paper.
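The splits listed below are plain JSON lists of question entries, so they can be inspected without any of the evaluation code. The following is a minimal sketch, assuming only the `q_id`, `a_idx`, and `a_str` fields that `code/run_local_llm.py` reads from each entry; any other field names in the files are not shown here.

```python
# Minimal sketch: load one ToMATO split and print a few entries.
# Assumes each entry has the 'q_id', 'a_idx' (index of the correct option),
# and 'a_str' (correct answer text) fields used by code/run_local_llm.py.
import json

with open('dataset/tomato_first.json') as f:
    examples = json.load(f)

print(len(examples), 'first-order ToM questions')
for ex in examples[:3]:
    print(ex['q_id'], '| correct option index:', ex['a_idx'], '| answer:', ex['a_str'])
```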
3 | 4 | ## Contents 5 | - tomato.json 6 | - contains all the examples in ToMATO 7 | - tomato_first.json 8 | - contains first-order ToM questions in ToMATO 9 | - tomato_second.json 10 | - contains second-order ToM questions in ToMATO 11 | - tomato_fb.json 12 | - contains false-belief tasks in ToMATO (ToMATO-FB) 13 | 14 | ## Data Format 15 | ``` 16 | { 17 | "a0":