├── chinatravel ├── environment │ ├── __init__.py │ └── tools │ │ ├── __init__.py │ │ ├── poi │ │ └── apis.py │ │ ├── accommodations │ │ └── apis.py │ │ ├── restaurants │ │ └── apis.py │ │ ├── attractions │ │ └── apis.py │ │ └── intercity_transport │ │ └── apis.py ├── evaluation │ ├── __init__.py │ ├── default_splits │ │ ├── preference_base50.txt │ │ ├── p0base50.txt │ │ ├── p1base50.txt │ │ ├── p2base50.txt │ │ ├── p3base50.txt │ │ ├── p4base50.txt │ │ ├── p5base50.txt │ │ ├── medium.txt │ │ ├── human.txt │ │ └── easy.txt │ ├── utils.py │ ├── schema_constraint.py │ ├── test.py │ ├── output_schema.json │ ├── commonsense_constraint.py │ ├── hard_constraint.py │ └── rank.py ├── agent │ ├── nesy_verifier │ │ ├── __init__.py │ │ ├── verifier │ │ │ └── personal_constraint_nl.py │ │ └── prompts │ │ │ └── poi_selection.py │ ├── pure_neuro_agent │ │ ├── __init__.py │ │ ├── prompts │ │ │ ├── __init__.py │ │ │ ├── query.json │ │ │ └── output_schema.json │ │ └── pure_neuro_agent.py │ ├── tpc_agent │ │ ├── tpc_llm.py │ │ └── tpc_agent.py │ ├── UrbanTrip │ │ ├── tpc_llm.py │ │ └── utils.py │ ├── nesy_agent │ │ ├── prompts │ │ │ └── __init__.py │ │ ├── plan_for_check │ │ │ ├── day1.json │ │ │ └── day2.json │ │ └── utils.py │ ├── utils.py │ ├── load_model.py │ └── base.py ├── .gitignore ├── symbol_verification │ ├── preference.py │ ├── readme.md │ └── concept_func.py └── data │ └── load_datasets.py ├── images └── overview.png ├── .gitignore ├── download_llm.sh ├── requirements.txt ├── run_tpc.py ├── TPC@AIC2025 └── readme.md ├── run_exp.py ├── eval_exp.py ├── eval_tpc.py └── README.md /chinatravel/environment/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /chinatravel/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /chinatravel/agent/nesy_verifier/__init__.py: -------------------------------------------------------------------------------- 1 | from .llm_modulo import LLMModuloAgent -------------------------------------------------------------------------------- /images/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAMDASZ-ML/ChinaTravel/HEAD/images/overview.png -------------------------------------------------------------------------------- /chinatravel/agent/pure_neuro_agent/__init__.py: -------------------------------------------------------------------------------- 1 | from agent.pure_neuro_agent.pure_neuro_agent import ActAgent, ReActAgent 2 | 3 | __all__ = ["ActAgent", "ReActAgent"] 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | run_exp.sh 3 | eval_exp.sh 4 | cache/ 5 | results/ 6 | eval_res/ 7 | debug.sh 8 | eval_1000.sh 9 | human150.sh 10 | results.md 11 | your_tpc_scores.json -------------------------------------------------------------------------------- /chinatravel/.gitignore: -------------------------------------------------------------------------------- 1 | environment/database/* 2 | local_llm/* 3 | data/* 4 | !data/load_datasets.py 5 | evaluation/default_splits/* 6 | !evaluation/default_splits/easy.txt 7 | !evaluation/default_splits/medium.txt 8 | !evaluation/default_splits/human.txt -------------------------------------------------------------------------------- /chinatravel/environment/tools/__init__.py: -------------------------------------------------------------------------------- 1 | from .attractions.apis import Attractions 2 | from .accommodations.apis import Accommodations 3 | from .restaurants.apis import Restaurants 4 | from .intercity_transport.apis import IntercityTransport 5 | from .transportation.apis import Transportation 6 | from .poi.apis import Poi 7 | 8 | __all__ = [ 9 | "Attractions", 10 | "Accommodations", 11 | "Restaurants", 12 | "IntercityTransport", 13 | "Transportation", 14 | "Poi", 15 | ] 16 | -------------------------------------------------------------------------------- /download_llm.sh: -------------------------------------------------------------------------------- 1 | modelscope download --model Qwen/Qwen3-8B --local_dir chinatravel/local_llm/Qwen3-8B 2 | modelscope download --model Qwen/Qwen3-4B --local_dir chinatravel/local_llm/Qwen3-4B 3 | modelscope download --model LLM-Research/Meta-Llama-3.1-8B-Instruct --local_dir chinatravel/local_llm/Meta-Llama-3.1-8B-Instruct 4 | modelscope download --model LLM-Research/Llama-3.2-3B-Instruct --local_dir chinatravel/local_llm/Llama-3.2-3B-Instruct 5 | modelscope download --model LLM-Research/Mistral-7B-Instruct-v0.3 --local_dir chinatravel/local_llm/Mistral-7B-Instruct-v0.3 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | geopy==2.4.1 2 | numpy==1.26.4 3 | pandas==2.2.3 4 | jsonschema==4.23.0 5 | tqdm==4.66.6 6 | openai==1.53.0 7 | json_repair==0.30.0 8 | scikit-learn==1.5.2 9 | modelscope==1.20.0 10 | packaging==24.2 11 | transformers==4.51.3 12 | vllm==0.8.5 13 | fuzzywuzzy==0.18.0 14 | func_timeout==4.3.5 15 | matplotlib==3.9.4 16 | seaborn==0.13.2 17 | datasets==3.3.2 18 | httpx==0.27.2 19 | accelerate==1.6.0 20 | z3-solver==4.14.1.0 21 | flashinfer-python==0.2.5 22 | tiktoken==0.9.0 23 | 24 | -f https://download.pytorch.org/whl/cu121 25 | torch==2.6.0 -------------------------------------------------------------------------------- /chinatravel/agent/pure_neuro_agent/prompts/__init__.py: -------------------------------------------------------------------------------- 1 | from agent.pure_neuro_agent.prompts.prompts import ( 2 | ZEROSHOT_ACT_INSTRUCTION, 3 | ZEROSHOT_REACT_INSTRUCTION, 4 | ZEROSHOT_REACT_INSTRUCTION_GLM4, 5 | ONESHOT_REACT_INSTRUCTION, 6 | ONESHOT_REACT_INSTRUCTION_GLM4, 7 | DIRECT_PROMPT, 8 | ) 9 | 10 | __all__ = [ 11 | "ZEROSHOT_ACT_INSTRUCTION", 12 | "ZEROSHOT_REACT_INSTRUCTION", 13 | "ZEROSHOT_REACT_INSTRUCTION_GLM4", 14 | "ONESHOT_REACT_INSTRUCTION", 15 | "ONESHOT_REACT_INSTRUCTION_GLM4", 16 | "DIRECT_PROMPT", 17 | ] 18 | -------------------------------------------------------------------------------- /chinatravel/agent/pure_neuro_agent/prompts/query.json: -------------------------------------------------------------------------------- 1 | { 2 | "start_city": "北京", 3 | "target_city": "南京", 4 | "hard_logic": [ 5 | "days==2", 6 | "people_number==1", 7 | "cost<=3000", 8 | "tickets==1", 9 | "rooms==1", 10 | "transport_type<={'metro'}", 11 | "{'江浙菜'} <= food_type", 12 | "{'南京博物院'}<=attraction_names" 13 | ], 14 | "preference": { 15 | "attraction": "历史景点", 16 | "transport": "地铁" 17 | }, 18 | "nature_language": "当前位置北京。我一个人想去南京玩2天,预算3000人民币,尽量多坐地铁,喜欢吃江浙菜,想去南京博物院,请给我一个旅行规划。" 19 | } -------------------------------------------------------------------------------- /chinatravel/agent/tpc_agent/tpc_llm.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | 5 | project_root_path = os.path.dirname( 6 | os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 7 | ) 8 | if project_root_path not in sys.path: 9 | sys.path.append(project_root_path) 10 | if os.path.dirname(project_root_path) not in sys.path: 11 | sys.path.append(os.path.dirname(project_root_path)) 12 | 13 | from agent.llms import AbstractLLM 14 | 15 | 16 | class TPCLLM(AbstractLLM): 17 | def __init__(self): 18 | super().__init__() 19 | self.name = "EmptyLLM" 20 | 21 | def _get_response(self, messages, one_line, json_mode): 22 | return "Empty LLM response" -------------------------------------------------------------------------------- /chinatravel/agent/UrbanTrip/tpc_llm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | project_root_path = os.path.dirname( 5 | os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 6 | ) 7 | if project_root_path not in sys.path: 8 | sys.path.append(project_root_path) 9 | if os.path.dirname(project_root_path) not in sys.path: 10 | sys.path.append(os.path.dirname(project_root_path)) 11 | 12 | from agent.llms import AbstractLLM 13 | 14 | 15 | class TPCLLM(AbstractLLM): 16 | def __init__(self): 17 | super().__init__() 18 | self.name = "EmptyLLM" 19 | 20 | def _get_response(self, messages, one_line, json_mode): 21 | return "Empty LLM response" 22 | 23 | -------------------------------------------------------------------------------- /chinatravel/agent/tpc_agent/tpc_agent.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import time 4 | import argparse 5 | import pandas as pd 6 | import json 7 | import numpy as np 8 | 9 | sys.path.append("./../../../") 10 | project_root_path = os.path.dirname( 11 | os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 12 | ) 13 | 14 | if project_root_path not in sys.path: 15 | sys.path.insert(0, project_root_path) 16 | 17 | 18 | from agent.base import AbstractAgent, BaseAgent 19 | 20 | class TPCAgent(BaseAgent): 21 | def __init__(self, **kwargs): 22 | super().__init__(name="TPC", **kwargs) 23 | 24 | def run(self, query, prob_idx, oralce_translation=False): 25 | 26 | 27 | self.reset_clock() 28 | 29 | result = { 30 | "itinerary": [], 31 | "elapsed_time(sec)": time.time() - self.start_clock, 32 | } 33 | 34 | return False, result -------------------------------------------------------------------------------- /chinatravel/agent/nesy_agent/prompts/__init__.py: -------------------------------------------------------------------------------- 1 | from chinatravel.agent.nesy_agent.prompts.prompts import ( 2 | NEXT_POI_TYPE_INSTRUCTION, 3 | INTERCITY_TRANSPORT_GO_INSTRUCTION, 4 | INTERCITY_TRANSPORT_BACK_INSTRUCTION, 5 | HOTEL_RANKING_INSTRUCTION, 6 | ATTRACTION_RANKING_INSTRUCTION, 7 | RESTAURANT_RANKING_INSTRUCTION, 8 | SELECT_POI_TIME_INSTRUCTION, 9 | ROOMS_PLANNING_INSTRUCTION, 10 | INNERCITY_TRANSPORTS_SELECTION_INSTRUCTION, 11 | BUDGETS_INSTRUCTION, 12 | NL2SL_INSTRUCTION, 13 | NL2SL_INSTRUCTION_V2, 14 | ) 15 | 16 | __all__ = [ 17 | "NEXT_POI_TYPE_INSTRUCTION", 18 | "INTERCITY_TRANSPORT_GO_INSTRUCTION", 19 | "INTERCITY_TRANSPORT_BACK_INSTRUCTION", 20 | "HOTEL_RANKING_INSTRUCTION", 21 | "ATTRACTION_RANKING_INSTRUCTION", 22 | "RESTAURANT_RANKING_INSTRUCTION", 23 | "SELECT_POI_TIME_INSTRUCTION", 24 | "ROOMS_PLANNING_INSTRUCTION", 25 | "INNERCITY_TRANSPORTS_SELECTION_INSTRUCTION", 26 | "NL2SL_INSTRUCTION", 27 | "NL2SL_INSTRUCTION_V2", 28 | "BUDGETS_INSTRUCTION", 29 | ] 30 | -------------------------------------------------------------------------------- /chinatravel/evaluation/default_splits/preference_base50.txt: -------------------------------------------------------------------------------- 1 | e20241028160428605482 2 | m20241028165012424316 3 | t20241206211714065799 4 | m20241028164944055132 5 | m20241028164800617250 6 | m20241028164913584976 7 | m20241028164927523772 8 | m20241028164647615645 9 | m20241028164641561951 10 | e20241028160417133031 11 | m20241028164637595271 12 | m20241028164804685984 13 | m20241028164623257189 14 | m20241028164922818260 15 | t20241206172005248006 16 | e20241028160500856634 17 | m20241028164741002971 18 | m20241028164822843097 19 | m20241028164642633824 20 | t20241206211714066235 21 | e20241028160311039073 22 | e20241028160400903418 23 | t20241206211714064921 24 | m20241028164830108978 25 | e20241028160445320665 26 | m20241028164936662445 27 | e20241028160248698752 28 | m20241028164843213370 29 | m20241028164854688667 30 | t20241206211714065676 31 | t20241206211714065174 32 | t20241206211714064535 33 | t20241206211714065114 34 | e20241028160410629558 35 | e20241028160253742452 36 | t20241206211714064387 37 | m20241028164658228656 38 | m20241028164742823033 39 | m20241028164912237931 40 | e20241028160401902010 41 | e20241028160442853706 42 | m20241028164746091821 43 | t20241206211714065305 44 | t20241206172253793024 45 | m20241028164925825092 46 | t20241206211714066103 47 | m20241028164814314214 48 | m20241028164748425604 49 | m20241028164928339245 50 | m20241028164901591817 51 | -------------------------------------------------------------------------------- /chinatravel/evaluation/default_splits/p0base50.txt: -------------------------------------------------------------------------------- 1 | p0e20241028160428605482 2 | p0m20241028165012424316 3 | p0t20241206211714065799 4 | p0m20241028164944055132 5 | p0m20241028164800617250 6 | p0m20241028164913584976 7 | p0m20241028164927523772 8 | p0m20241028164647615645 9 | p0m20241028164641561951 10 | p0e20241028160417133031 11 | p0m20241028164637595271 12 | p0m20241028164804685984 13 | p0m20241028164623257189 14 | p0m20241028164922818260 15 | p0t20241206172005248006 16 | p0e20241028160500856634 17 | p0m20241028164741002971 18 | p0m20241028164822843097 19 | p0m20241028164642633824 20 | p0t20241206211714066235 21 | p0e20241028160311039073 22 | p0e20241028160400903418 23 | p0t20241206211714064921 24 | p0m20241028164830108978 25 | p0e20241028160445320665 26 | p0m20241028164936662445 27 | p0e20241028160248698752 28 | p0m20241028164843213370 29 | p0m20241028164854688667 30 | p0t20241206211714065676 31 | p0t20241206211714065174 32 | p0t20241206211714064535 33 | p0t20241206211714065114 34 | p0e20241028160410629558 35 | p0e20241028160253742452 36 | p0t20241206211714064387 37 | p0m20241028164658228656 38 | p0m20241028164742823033 39 | p0m20241028164912237931 40 | p0e20241028160401902010 41 | p0e20241028160442853706 42 | p0m20241028164746091821 43 | p0t20241206211714065305 44 | p0t20241206172253793024 45 | p0m20241028164925825092 46 | p0t20241206211714066103 47 | p0m20241028164814314214 48 | p0m20241028164748425604 49 | p0m20241028164928339245 50 | p0m20241028164901591817 51 | -------------------------------------------------------------------------------- /chinatravel/evaluation/default_splits/p1base50.txt: -------------------------------------------------------------------------------- 1 | p1e20241028160248698752 2 | p1e20241028160253742452 3 | p1e20241028160311039073 4 | p1e20241028160400903418 5 | p1e20241028160401902010 6 | p1e20241028160410629558 7 | p1e20241028160417133031 8 | p1e20241028160428605482 9 | p1e20241028160442853706 10 | p1e20241028160445320665 11 | p1e20241028160500856634 12 | p1m20241028164623257189 13 | p1m20241028164637595271 14 | p1m20241028164641561951 15 | p1m20241028164642633824 16 | p1m20241028164647615645 17 | p1m20241028164658228656 18 | p1m20241028164741002971 19 | p1m20241028164742823033 20 | p1m20241028164746091821 21 | p1m20241028164748425604 22 | p1m20241028164800617250 23 | p1m20241028164804685984 24 | p1m20241028164814314214 25 | p1m20241028164822843097 26 | p1m20241028164830108978 27 | p1m20241028164843213370 28 | p1m20241028164854688667 29 | p1m20241028164901591817 30 | p1m20241028164912237931 31 | p1m20241028164913584976 32 | p1m20241028164922818260 33 | p1m20241028164925825092 34 | p1m20241028164927523772 35 | p1m20241028164928339245 36 | p1m20241028164936662445 37 | p1m20241028164944055132 38 | p1m20241028165012424316 39 | p1t20241206172005248006 40 | p1t20241206172253793024 41 | p1t20241206211714064387 42 | p1t20241206211714064535 43 | p1t20241206211714064921 44 | p1t20241206211714065114 45 | p1t20241206211714065174 46 | p1t20241206211714065305 47 | p1t20241206211714065676 48 | p1t20241206211714065799 49 | p1t20241206211714066103 50 | p1t20241206211714066235 51 | -------------------------------------------------------------------------------- /chinatravel/evaluation/default_splits/p2base50.txt: -------------------------------------------------------------------------------- 1 | p2e20241028160248698752 2 | p2e20241028160253742452 3 | p2e20241028160311039073 4 | p2e20241028160400903418 5 | p2e20241028160401902010 6 | p2e20241028160410629558 7 | p2e20241028160417133031 8 | p2e20241028160428605482 9 | p2e20241028160442853706 10 | p2e20241028160445320665 11 | p2e20241028160500856634 12 | p2m20241028164623257189 13 | p2m20241028164637595271 14 | p2m20241028164641561951 15 | p2m20241028164642633824 16 | p2m20241028164647615645 17 | p2m20241028164658228656 18 | p2m20241028164741002971 19 | p2m20241028164742823033 20 | p2m20241028164746091821 21 | p2m20241028164748425604 22 | p2m20241028164800617250 23 | p2m20241028164804685984 24 | p2m20241028164814314214 25 | p2m20241028164822843097 26 | p2m20241028164830108978 27 | p2m20241028164843213370 28 | p2m20241028164854688667 29 | p2m20241028164901591817 30 | p2m20241028164912237931 31 | p2m20241028164913584976 32 | p2m20241028164922818260 33 | p2m20241028164925825092 34 | p2m20241028164927523772 35 | p2m20241028164928339245 36 | p2m20241028164936662445 37 | p2m20241028164944055132 38 | p2m20241028165012424316 39 | p2t20241206172005248006 40 | p2t20241206172253793024 41 | p2t20241206211714064387 42 | p2t20241206211714064535 43 | p2t20241206211714064921 44 | p2t20241206211714065114 45 | p2t20241206211714065174 46 | p2t20241206211714065305 47 | p2t20241206211714065676 48 | p2t20241206211714065799 49 | p2t20241206211714066103 50 | p2t20241206211714066235 51 | -------------------------------------------------------------------------------- /chinatravel/evaluation/default_splits/p3base50.txt: -------------------------------------------------------------------------------- 1 | p3e20241028160248698752 2 | p3e20241028160253742452 3 | p3e20241028160311039073 4 | p3e20241028160400903418 5 | p3e20241028160401902010 6 | p3e20241028160410629558 7 | p3e20241028160417133031 8 | p3e20241028160428605482 9 | p3e20241028160442853706 10 | p3e20241028160445320665 11 | p3e20241028160500856634 12 | p3m20241028164623257189 13 | p3m20241028164637595271 14 | p3m20241028164641561951 15 | p3m20241028164642633824 16 | p3m20241028164647615645 17 | p3m20241028164658228656 18 | p3m20241028164741002971 19 | p3m20241028164742823033 20 | p3m20241028164746091821 21 | p3m20241028164748425604 22 | p3m20241028164800617250 23 | p3m20241028164804685984 24 | p3m20241028164814314214 25 | p3m20241028164822843097 26 | p3m20241028164830108978 27 | p3m20241028164843213370 28 | p3m20241028164854688667 29 | p3m20241028164901591817 30 | p3m20241028164912237931 31 | p3m20241028164913584976 32 | p3m20241028164922818260 33 | p3m20241028164925825092 34 | p3m20241028164927523772 35 | p3m20241028164928339245 36 | p3m20241028164936662445 37 | p3m20241028164944055132 38 | p3m20241028165012424316 39 | p3t20241206172005248006 40 | p3t20241206172253793024 41 | p3t20241206211714064387 42 | p3t20241206211714064535 43 | p3t20241206211714064921 44 | p3t20241206211714065114 45 | p3t20241206211714065174 46 | p3t20241206211714065305 47 | p3t20241206211714065676 48 | p3t20241206211714065799 49 | p3t20241206211714066103 50 | p3t20241206211714066235 51 | -------------------------------------------------------------------------------- /chinatravel/evaluation/default_splits/p4base50.txt: -------------------------------------------------------------------------------- 1 | p4e20241028160248698752 2 | p4e20241028160253742452 3 | p4e20241028160311039073 4 | p4e20241028160400903418 5 | p4e20241028160401902010 6 | p4e20241028160410629558 7 | p4e20241028160417133031 8 | p4e20241028160428605482 9 | p4e20241028160442853706 10 | p4e20241028160445320665 11 | p4e20241028160500856634 12 | p4m20241028164623257189 13 | p4m20241028164637595271 14 | p4m20241028164641561951 15 | p4m20241028164642633824 16 | p4m20241028164647615645 17 | p4m20241028164658228656 18 | p4m20241028164741002971 19 | p4m20241028164742823033 20 | p4m20241028164746091821 21 | p4m20241028164748425604 22 | p4m20241028164800617250 23 | p4m20241028164804685984 24 | p4m20241028164814314214 25 | p4m20241028164822843097 26 | p4m20241028164830108978 27 | p4m20241028164843213370 28 | p4m20241028164854688667 29 | p4m20241028164901591817 30 | p4m20241028164912237931 31 | p4m20241028164913584976 32 | p4m20241028164922818260 33 | p4m20241028164925825092 34 | p4m20241028164927523772 35 | p4m20241028164928339245 36 | p4m20241028164936662445 37 | p4m20241028164944055132 38 | p4m20241028165012424316 39 | p4t20241206172005248006 40 | p4t20241206172253793024 41 | p4t20241206211714064387 42 | p4t20241206211714064535 43 | p4t20241206211714064921 44 | p4t20241206211714065114 45 | p4t20241206211714065174 46 | p4t20241206211714065305 47 | p4t20241206211714065676 48 | p4t20241206211714065799 49 | p4t20241206211714066103 50 | p4t20241206211714066235 51 | -------------------------------------------------------------------------------- /chinatravel/evaluation/default_splits/p5base50.txt: -------------------------------------------------------------------------------- 1 | p5e20241028160248698752 2 | p5e20241028160253742452 3 | p5e20241028160311039073 4 | p5e20241028160400903418 5 | p5e20241028160401902010 6 | p5e20241028160410629558 7 | p5e20241028160417133031 8 | p5e20241028160428605482 9 | p5e20241028160442853706 10 | p5e20241028160445320665 11 | p5e20241028160500856634 12 | p5m20241028164623257189 13 | p5m20241028164637595271 14 | p5m20241028164641561951 15 | p5m20241028164642633824 16 | p5m20241028164647615645 17 | p5m20241028164658228656 18 | p5m20241028164741002971 19 | p5m20241028164742823033 20 | p5m20241028164746091821 21 | p5m20241028164748425604 22 | p5m20241028164800617250 23 | p5m20241028164804685984 24 | p5m20241028164814314214 25 | p5m20241028164822843097 26 | p5m20241028164830108978 27 | p5m20241028164843213370 28 | p5m20241028164854688667 29 | p5m20241028164901591817 30 | p5m20241028164912237931 31 | p5m20241028164913584976 32 | p5m20241028164922818260 33 | p5m20241028164925825092 34 | p5m20241028164927523772 35 | p5m20241028164928339245 36 | p5m20241028164936662445 37 | p5m20241028164944055132 38 | p5m20241028165012424316 39 | p5t20241206172005248006 40 | p5t20241206172253793024 41 | p5t20241206211714064387 42 | p5t20241206211714064535 43 | p5t20241206211714064921 44 | p5t20241206211714065114 45 | p5t20241206211714065174 46 | p5t20241206211714065305 47 | p5t20241206211714065676 48 | p5t20241206211714065799 49 | p5t20241206211714066103 50 | p5t20241206211714066235 51 | -------------------------------------------------------------------------------- /chinatravel/agent/nesy_agent/plan_for_check/day1.json: -------------------------------------------------------------------------------- 1 | { 2 | "people_number": 1, 3 | "start_city": "上海", 4 | "target_city": "杭州", 5 | "itinerary": [ 6 | { 7 | "day": 1, 8 | "activities": [ 9 | { 10 | "start_time": "05:26", 11 | "end_time": "07:27", 12 | "start": "上海南站", 13 | "end": "杭州南站", 14 | "price": 65.95, 15 | "cost": 65.95, 16 | "tickets": 1, 17 | "transports": [], 18 | "TrainID": "K335", 19 | "type": "train" 20 | }, 21 | { 22 | "start_time": "07:48", 23 | "end_time": "09:59", 24 | "start": "杭州南站", 25 | "end": "上海南站", 26 | "price": 82.44, 27 | "cost": 82.44, 28 | "tickets": 1, 29 | "transports": [ 30 | { 31 | "start": "杭州南站", 32 | "end": "杭州南站", 33 | "mode": "walk", 34 | "start_time": "07:27", 35 | "end_time": "07:27", 36 | "cost": 0.0, 37 | "distance": 0.0, 38 | "price": 0.0 39 | } 40 | ], 41 | "TrainID": "T78", 42 | "type": "train" 43 | } 44 | ] 45 | } 46 | ], 47 | "search_time_sec": 23.123600244522095, 48 | "llm_inference_time_sec": 22.524855375289917 49 | } -------------------------------------------------------------------------------- /chinatravel/agent/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from numpy import ndarray, integer, floating 3 | import numpy as np 4 | import json 5 | 6 | 7 | def decode_numpy_dict(d): 8 | if isinstance(d, dict): 9 | return {decode_numpy_dict(k): decode_numpy_dict(v) for k, v in d.items()} 10 | elif isinstance(d, list): 11 | return [decode_numpy_dict(i) for i in d] 12 | elif isinstance(d, integer): 13 | return int(d) 14 | elif isinstance(d, floating): 15 | return float(d) 16 | elif isinstance(d, ndarray): 17 | return decode_numpy_dict(d.tolist()) 18 | else: 19 | return d 20 | 21 | 22 | class Logger(object): 23 | def __init__(self, filename="default.log", stream=sys.stdout, debug_mode=False): 24 | 25 | self.debug_mode = debug_mode 26 | self.log = open(filename, "a", encoding="utf-8") 27 | 28 | if self.debug_mode: 29 | self.terminal = stream 30 | 31 | def write(self, message): 32 | self.log.write(message) 33 | 34 | if self.debug_mode: 35 | self.terminal.write(message) 36 | 37 | def flush(self): 38 | pass 39 | 40 | def __del__(self): 41 | self.log.close() 42 | 43 | 44 | class NpEncoder(json.JSONEncoder): 45 | def default(self, obj): 46 | if isinstance(obj, np.integer): 47 | return int(obj) 48 | if isinstance(obj, np.floating): 49 | return float(obj) 50 | if isinstance(obj, np.ndarray): 51 | return obj.tolist() 52 | return super(NpEncoder, self).default(obj) 53 | 54 | 55 | def load_json_file(file_path): 56 | with open(file_path, "r", encoding="utf-8") as f: 57 | return json.load(f) 58 | 59 | 60 | def save_json_file(json_data, file_path): 61 | with open(file_path, "w", encoding="utf8") as dump_f: 62 | json.dump(json_data, dump_f, ensure_ascii=False, indent=4, cls=NpEncoder) 63 | -------------------------------------------------------------------------------- /chinatravel/agent/nesy_verifier/verifier/personal_constraint_nl.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | from chinatravel.environment.tools.accommodations.apis import Accommodations 5 | from chinatravel.environment.tools.restaurants.apis import Restaurants 6 | from chinatravel.environment.tools.attractions.apis import Attractions 7 | from chinatravel.environment.tools.intercity_transport.apis import IntercityTransport 8 | from chinatravel.environment.tools.transportation.apis import Transportation 9 | 10 | from chinatravel.symbol_verification.concept_func import func_dict 11 | from chinatravel.evaluation.utils import load_json_file 12 | 13 | import pandas as pd 14 | 15 | from copy import deepcopy 16 | 17 | accommodation = Accommodations() 18 | restaurants = Restaurants() 19 | attractions = Attractions() 20 | 21 | def collect_personal_error(problem, plan, verbose=False): 22 | 23 | if not 'hard_logic_nl' in problem: 24 | print(f"Data id {problem['uid']}, no hard_logic_nl information.") 25 | return [] 26 | if len(problem["hard_logic_py"]) != len(problem["hard_logic_nl"]): 27 | print(f"Data id {problem['uid']}, hard_logic_py and hard_logic_nl are not consistent.") 28 | return [] 29 | 30 | error_info = [] 31 | for idx, constraint in enumerate(problem["hard_logic_py"]): 32 | vars_dict = deepcopy(func_dict) 33 | vars_dict["plan"] = plan 34 | # exec(constraint, {"__builtins__": {"set": set, "print": print}}, vars_dict) 35 | # results.append(vars_dict.get("result", False)) 36 | try: 37 | # Evaluate the constraint in a safe manner 38 | exec( 39 | constraint, 40 | { 41 | "__builtins__": { 42 | "set": set, 43 | } 44 | }, 45 | vars_dict, 46 | ) 47 | res_i = vars_dict.get("result", False) 48 | # results.append(bool(res_i)) 49 | if not res_i: 50 | error_info.append(f"用户要求未被满足:{problem['hard_logic_nl'][idx]}") 51 | except Exception as e: 52 | if verbose: 53 | print(f"Error evaluating constraint '{constraint}': {e}") 54 | error_info.append(f"Raise Error when evaluating constraint {problem['hard_logic_nl'][idx]}") 55 | # print(results) 56 | return error_info -------------------------------------------------------------------------------- /chinatravel/environment/tools/poi/apis.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | 5 | class Poi: 6 | def __init__(self, base_path: str = "../../database/poi/", en_version=False): 7 | 8 | city_list = [ 9 | "beijing", 10 | "shanghai", 11 | "nanjing", 12 | "suzhou", 13 | "hangzhou", 14 | "shenzhen", 15 | "chengdu", 16 | "wuhan", 17 | "guangzhou", 18 | "chongqing", 19 | ] 20 | curdir = os.path.dirname(os.path.realpath(__file__)) 21 | data_path_list = [ 22 | os.path.join(curdir, f"{base_path}/{city}/poi.json") for city in city_list 23 | ] 24 | self.data = {} 25 | for i, city in enumerate(city_list): 26 | self.data[city] = json.load(open(data_path_list[i], "r", encoding="utf-8")) 27 | city_data = {} 28 | for name_pos in self.data[city]: 29 | name = name_pos["name"] 30 | pos = name_pos["position"] 31 | city_data[name] = tuple(pos) 32 | self.data[city] = city_data 33 | # self.data[city] = [ 34 | # (x["name"], tuple(x["position"])) for x in self.data[city] 35 | # ] 36 | city_cn_list = [ 37 | "北京", 38 | "上海", 39 | "南京", 40 | "苏州", 41 | "杭州", 42 | "深圳", 43 | "成都", 44 | "武汉", 45 | "广州", 46 | "重庆", 47 | ] 48 | for i, city in enumerate(city_list): 49 | self.data[city_cn_list[i]] = self.data.pop(city) 50 | self.city_cn_list = city_cn_list 51 | self.city_list = city_list 52 | 53 | def search(self, city: str, name: str): 54 | if city in self.city_list: 55 | city = self.city_cn_list[self.city_list.index(city)] 56 | city_data = self.data[city] 57 | try: 58 | return city_data[name] 59 | except KeyError: 60 | return f"No such point in the city. Check the point name: {name}." 61 | 62 | 63 | def test(): 64 | poi = Poi() 65 | while True: 66 | query = input("请输入查询的poi名称:") 67 | if query == "exit": 68 | return 69 | print(poi.search("南京", query)) 70 | 71 | 72 | if __name__ == "__main__": 73 | test() 74 | -------------------------------------------------------------------------------- /chinatravel/agent/load_model.py: -------------------------------------------------------------------------------- 1 | def init_agent(kwargs): 2 | from .nesy_agent.rule_driven_rec import RuleDrivenAgent 3 | from .nesy_agent.llm_driven_rec import LLMDrivenAgent 4 | from .pure_neuro_agent.pure_neuro_agent import ActAgent, ReActAgent 5 | from .pure_neuro_agent.prompts import ( 6 | ZEROSHOT_ACT_INSTRUCTION, 7 | ZEROSHOT_REACT_INSTRUCTION, 8 | ZEROSHOT_REACT_INSTRUCTION_GLM4, 9 | ONESHOT_REACT_INSTRUCTION, 10 | ONESHOT_REACT_INSTRUCTION_GLM4, 11 | ) 12 | 13 | from .nesy_verifier import LLMModuloAgent 14 | 15 | from .tpc_agent.tpc_agent import TPCAgent 16 | 17 | if kwargs["method"] == "RuleNeSy": 18 | agent = RuleDrivenAgent( 19 | env=kwargs["env"], 20 | backbone_llm=kwargs["backbone_llm"], 21 | cache_dir=kwargs["cache_dir"], 22 | debug=kwargs["debug"], 23 | ) 24 | elif kwargs["method"] == "LLMNeSy": 25 | agent = LLMDrivenAgent( 26 | **kwargs 27 | ) 28 | elif kwargs["method"] == "Act": 29 | agent = ActAgent( 30 | env=kwargs["env"], 31 | backbone_llm=kwargs["backbone_llm"], 32 | prompt=ZEROSHOT_ACT_INSTRUCTION, 33 | ) 34 | elif kwargs["method"] == "ReAct": 35 | agent = ReActAgent( 36 | env=kwargs["env"], 37 | backbone_llm=kwargs["backbone_llm"], 38 | prompt=( 39 | ONESHOT_REACT_INSTRUCTION 40 | if "glm4" not in kwargs["backbone_llm"].name.lower() 41 | else ONESHOT_REACT_INSTRUCTION_GLM4 42 | ), 43 | ) 44 | elif kwargs["method"] == "ReAct0": 45 | agent = ReActAgent( 46 | env=kwargs["env"], 47 | backbone_llm=kwargs["backbone_llm"], 48 | prompt=( 49 | ZEROSHOT_REACT_INSTRUCTION 50 | if "glm4" not in kwargs["backbone_llm"].name.lower() 51 | else ZEROSHOT_REACT_INSTRUCTION_GLM4 52 | ), 53 | ) 54 | elif kwargs["method"] == "LLM-modulo": 55 | kwargs["model"] = kwargs["backbone_llm"] 56 | kwargs["max_steps"] = kwargs["refine_steps"] 57 | agent = LLMModuloAgent( 58 | **kwargs 59 | ) 60 | elif kwargs["method"] == "TPCAgent": 61 | agent = TPCAgent( 62 | **kwargs 63 | ) 64 | else: 65 | raise Exception("Not Implemented") 66 | return agent 67 | 68 | 69 | def init_llm(llm_name, max_model_len=None): 70 | from .llms import Deepseek, GPT4o, GLM4Plus, Qwen, Mistral, Llama, EmptyLLM 71 | 72 | from .tpc_agent.tpc_llm import TPCLLM 73 | 74 | if llm_name == "deepseek": 75 | llm = Deepseek() 76 | elif llm_name == "gpt-4o": 77 | llm = GPT4o() 78 | elif llm_name == "glm4-plus": 79 | llm = GLM4Plus() 80 | elif "Qwen" in llm_name: 81 | llm = Qwen(llm_name, max_model_len=max_model_len) 82 | elif llm_name == "mistral": 83 | llm = Mistral(max_model_len=max_model_len) 84 | elif "Llama" in llm_name: 85 | llm = Llama(llm_name) 86 | elif llm_name == "rule": 87 | return EmptyLLM() 88 | elif llm_name == "TPCLLM": 89 | llm = TPCLLM() 90 | else: 91 | raise Exception("Not Implemented") 92 | 93 | return llm 94 | -------------------------------------------------------------------------------- /chinatravel/symbol_verification/preference.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | from chinatravel.environment.tools.accommodations.apis import Accommodations 5 | from chinatravel.environment.tools.restaurants.apis import Restaurants 6 | from chinatravel.environment.tools.attractions.apis import Attractions 7 | from chinatravel.environment.tools.intercity_transport.apis import IntercityTransport 8 | from chinatravel.environment.tools.transportation.apis import Transportation 9 | 10 | from chinatravel.symbol_verification.concept_func import func_dict 11 | from chinatravel.evaluation.utils import load_json_file 12 | 13 | import pandas as pd 14 | 15 | from copy import deepcopy 16 | 17 | accommodation = Accommodations() 18 | restaurants = Restaurants() 19 | attractions = Attractions() 20 | 21 | from .concept_func import * 22 | 23 | 24 | def evaluate_preference_py(preference_list, plan, verbose=False): 25 | 26 | 27 | # time_cost = 0 28 | # transport_count = 0 29 | # for activity in allactivities(plan): 30 | # transports = activity_transports(activity) 31 | # if transports!=[]: 32 | # transport_count += 1 33 | # time_cost += innercity_transport_time(transports) 34 | # average_time_cost = time_cost / transport_count if transport_count > 0 else -1 35 | 36 | # print(average_time_cost) 37 | 38 | 39 | # target_poi = '大足石刻' 40 | # poi_list = list() 41 | # total_distance = 0 42 | # poi_count = 0 43 | # city = target_city(plan) 44 | # for activity in allactivities(plan): 45 | # if activity_type(activity) in ['breakfast', 'lunch', 'dinner', 'accommodation', 'attraction']: 46 | # poi_list.append(activity_position(activity)) 47 | # for poi in poi_list: 48 | # total_distance += poi_distance(city, target_poi, poi) 49 | # poi_count += 1 50 | # average_dist_cost = total_distance / poi_count if poi_count > 0 else -1 51 | # print(average_dist_cost) 52 | 53 | results = [] 54 | # hard_logic_py.append(debug_logic_py) 55 | for _, preference_concept, preference_code in preference_list: 56 | vars_dict = deepcopy(func_dict) 57 | vars_dict["plan"] = plan 58 | # exec(constraint, {"__builtins__": {"set": set, "print": print}}, vars_dict) 59 | # results.append(vars_dict.get("result", False)) 60 | try: 61 | # Evaluate the constraint in a safe manner 62 | exec( 63 | preference_code, 64 | { 65 | "__builtins__": { 66 | "set": set, 67 | "list": list, 68 | } 69 | }, 70 | vars_dict, 71 | ) 72 | res_i = vars_dict.get(preference_concept, None) 73 | # if type(res_i) != float: 74 | # raise Exception("The result of the constraint must be a float value.") 75 | # if res_i == -1: 76 | # raise Exception("return -1") 77 | 78 | results.append(float(res_i)) 79 | # results.append(result) 80 | except Exception as e: 81 | if verbose: 82 | print(f"Error evaluating preference '{preference_code}': {e}") 83 | results.append(None) 84 | # print(results) 85 | return results 86 | -------------------------------------------------------------------------------- /chinatravel/evaluation/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import jsonschema 3 | from jsonschema import validate 4 | import os 5 | import pandas as pd 6 | import sys 7 | 8 | project_root_path = os.path.dirname( 9 | os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 10 | ) 11 | 12 | if project_root_path not in sys.path: 13 | sys.path.insert(0, project_root_path) 14 | 15 | from chinatravel.agent.utils import Logger, NpEncoder 16 | from chinatravel.environment.tools import Attractions 17 | 18 | 19 | class AttractionsOODTag(Attractions): 20 | def __init__( 21 | self, base_path: str = os.path.dirname(__file__) + "/eval_annotation/attractions/", en_version=False 22 | ): 23 | super().__init__(en_version=en_version) 24 | city_list = [ 25 | "beijing", 26 | "shanghai", 27 | "nanjing", 28 | "suzhou", 29 | "hangzhou", 30 | "shenzhen", 31 | "chengdu", 32 | "wuhan", 33 | "guangzhou", 34 | "chongqing", 35 | ] 36 | curdir = os.path.dirname(os.path.realpath(__file__)) 37 | self.ood_tag = {} 38 | for city in city_list: 39 | self.ood_tag[city] = pd.read_csv( 40 | os.path.join(curdir, f"{base_path}/{city}/attractions_tag.csv") 41 | ) 42 | city_cn_list = [ 43 | "北京", 44 | "上海", 45 | "南京", 46 | "苏州", 47 | "杭州", 48 | "深圳", 49 | "成都", 50 | "武汉", 51 | "广州", 52 | "重庆", 53 | ] 54 | for i, city in enumerate(city_list): 55 | self.ood_tag[city_cn_list[i]] = self.ood_tag.pop(city) 56 | for city in city_cn_list: 57 | self.data[city] = pd.merge( 58 | self.data[city], self.ood_tag[city], on=["id", "name"], how="left" 59 | ) 60 | # print(self.data[city]) 61 | del self.ood_tag 62 | 63 | def load_json_file(file_path): 64 | with open(file_path, "r", encoding="utf-8") as f: 65 | return json.load(f) 66 | 67 | 68 | def validate_json(json_data, schema): 69 | try: 70 | validate(instance=json_data, schema=schema) 71 | return True 72 | except jsonschema.exceptions.ValidationError as e: 73 | return False 74 | 75 | def save_json_file(json_data, file_path): 76 | with open(file_path, "w", encoding="utf8") as dump_f: 77 | json.dump(json_data, dump_f, ensure_ascii=False, indent=4, cls=NpEncoder) 78 | 79 | 80 | 81 | if __name__ == "__main__": 82 | schema_file_path = "./output_schema.json" 83 | json_file_path_template = "../results/test_20240909091404/query_{}_result.json" 84 | 85 | schema = load_json_file(schema_file_path) 86 | acc = 0 87 | for i in range(10): 88 | try: 89 | json_data = load_json_file(json_file_path_template.format(i)) 90 | if validate_json(json_data, schema): 91 | acc += 1 92 | else: 93 | print("Error {}".format(i)) 94 | except: 95 | print("Error {}".format(i)) 96 | continue 97 | print(acc / 10) 98 | a = AttractionsOODTag() 99 | print(a.select("北京", "id", lambda x: x == 1)) 100 | print(a.select("北京", "name", lambda x: x == "故宫博物院")) 101 | -------------------------------------------------------------------------------- /chinatravel/environment/tools/accommodations/apis.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from pandas import DataFrame 3 | from typing import Callable 4 | from geopy.distance import geodesic 5 | import os 6 | 7 | import sys 8 | 9 | sys.path.append(os.path.join(os.path.dirname(__file__), "..")) 10 | from poi.apis import Poi 11 | 12 | 13 | class Accommodations: 14 | 15 | def __init__( 16 | self, base_path: str = "../../database/accommodations/", en_version=False 17 | ): 18 | curdir = os.path.dirname(os.path.realpath(__file__)) 19 | city_list = [ 20 | "beijing", 21 | "shanghai", 22 | "nanjing", 23 | "suzhou", 24 | "hangzhou", 25 | "shenzhen", 26 | "chengdu", 27 | "wuhan", 28 | "guangzhou", 29 | "chongqing", 30 | ] 31 | data_path_list = [ 32 | os.path.join(curdir, f"{base_path}/{city}/accommodations.csv") 33 | for city in city_list 34 | ] 35 | self.data = {} 36 | for i, city in enumerate(city_list): 37 | self.data[city] = pd.read_csv(data_path_list[i]).dropna() 38 | self.key_type_tuple_list = {} 39 | for city in city_list: 40 | self.key_type_tuple_list[city] = [] 41 | for key in self.data[city].keys(): 42 | self.key_type_tuple_list[city].append( 43 | (key, type(self.data[city].iloc[0][key])) 44 | ) 45 | city_cn_list = [ 46 | "北京", 47 | "上海", 48 | "南京", 49 | "苏州", 50 | "杭州", 51 | "深圳", 52 | "成都", 53 | "武汉", 54 | "广州", 55 | "重庆", 56 | ] 57 | 58 | for i, city in enumerate(city_list): 59 | self.data[city_cn_list[i]] = self.data.pop(city) 60 | self.key_type_tuple_list[city_cn_list[i]] = self.key_type_tuple_list.pop( 61 | city 62 | ) 63 | 64 | self.poi = Poi(en_version=en_version) 65 | 66 | def keys(self, city): 67 | return self.key_type_tuple_list[city] 68 | 69 | def select(self, city, key, func: Callable) -> DataFrame: 70 | if key not in self.data[city].keys(): 71 | return "Key not found." 72 | bool_list = [func(x) for x in self.data[city][key]] 73 | return self.data[city][bool_list] 74 | 75 | def nearby(self, city, point: str, topk: int = None, dist: float = 5) -> DataFrame: 76 | lat_lon = self.poi.search(city, point) 77 | if isinstance(lat_lon, str): 78 | return lat_lon 79 | lat, lon = lat_lon 80 | distance = [ 81 | geodesic((lat, lon), (x, y)).km 82 | for x, y in zip(self.data[city]["lat"], self.data[city]["lon"]) 83 | ] 84 | tmp = self.data[city].copy() 85 | tmp["distance"] = distance 86 | if dist is not None: 87 | tmp = tmp[tmp["distance"] < dist] 88 | tmp = tmp.sort_values(by=["distance"]) 89 | if topk is not None: 90 | return tmp.head(topk) 91 | return tmp 92 | 93 | 94 | if __name__ == "__main__": 95 | 96 | AccommodationsAPI = Accommodations() 97 | print(AccommodationsAPI.keys("南京")) 98 | 99 | # def query_key(key): 100 | # print("query key {}".format(key)) 101 | # print(AccommodationsAPI.get_info(key)) 102 | 103 | # for key in ["Price", "numBed", "hotelName"]: 104 | # query_key(key) 105 | 106 | # def query_nearby(lat=32.040158, lon=118.823291): 107 | 108 | # print("query nearby ({}, {}): ".format(lat, lon)) 109 | # print(AccommodationsAPI.nearby(lat=lat, lon=lon, topk=None, dist=2)) 110 | 111 | # query_nearby() 112 | 113 | # print(AccommodationsAPI.select("numBed", 2)) 114 | 115 | # print(AccommodationsAPI.data['featureHotelType'].unique()) 116 | -------------------------------------------------------------------------------- /chinatravel/evaluation/schema_constraint.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | 4 | # from chinatravel.environment.tools.accommodations.apis import Accommodations 5 | # from chinatravel.environment.tools.restaurants.apis import Restaurants 6 | # from chinatravel.environment.tools.attractions.apis import Attractions 7 | # from chinatravel.environment.tools.intercity_transport.apis import IntercityTransport 8 | # from chinatravel.environment.tools.transportation.apis import Transportation 9 | # from env.tools.transportation.apis import GoTo 10 | # from envs import goto 11 | import json 12 | import os 13 | import sys 14 | from tqdm import tqdm 15 | 16 | 17 | import pandas as pd 18 | 19 | import json 20 | import jsonschema 21 | from jsonschema import validate 22 | 23 | def validate_json(json_data, schema): 24 | try: 25 | validate(instance=json_data, schema=schema) 26 | return True 27 | except jsonschema.exceptions.ValidationError as e: 28 | # print(e) 29 | return False 30 | 31 | def evaluate_schema_constraints(data_index, plan_json_dict, schema): 32 | # assert len(symbolic_input_list)==len(plan_json_list) 33 | 34 | total_correct = 0 35 | result_agg = pd.DataFrame(columns=['data_id', "schema"]) 36 | result_agg['data_id'] = data_index 37 | 38 | pass_id = [] 39 | 40 | for ii, idx in tqdm(enumerate(data_index), total=len(data_index)): 41 | 42 | plan_json = plan_json_dict[idx] 43 | 44 | 45 | succ_flag = 0 46 | try: 47 | if validate_json(plan_json, schema): 48 | succ_flag = 1 49 | pass_id.append(idx) 50 | except: 51 | pass 52 | 53 | 54 | result_agg.loc[ii, "schema"] = succ_flag 55 | total_correct += succ_flag 56 | 57 | total_count=len(data_index) 58 | 59 | return 1. * total_correct / total_count*100, result_agg, pass_id 60 | 61 | if __name__ == "__main__": 62 | 63 | 64 | 65 | from evaluation.utils import load_json_file 66 | # test_example=load_json_file("./example/query_53.json") 67 | # test_plan=load_json_file("./example/plan_53.json") 68 | # evaluate_commonsense_constraints([test_example], [test_plan]) 69 | 70 | # exit(0) 71 | 72 | symbolic_input_list=[] 73 | plan_json_list=[] 74 | 75 | for i in range(1): 76 | test_plan_path='./example/a_result.json'.format(i+1) 77 | test_example_path='./example/a_query.json'.format(i+1) 78 | test_example=load_json_file(test_example_path) 79 | test_plan=load_json_file(test_plan_path) 80 | symbolic_input_list.append(test_example) 81 | plan_json_list.append(test_plan) 82 | macro_accuracy, micro_accuracy, _ =evaluate_commonsense_constraints(symbolic_input_list,plan_json_list) 83 | print('macro: {}%, micro: {}%'.format(macro_accuracy,micro_accuracy)) 84 | 85 | # test_plan_path='./example/plan_4.json' 86 | # test_example_path='./example/query_4.json' 87 | # test_example=load_json_file(test_example_path) 88 | # test_plan=load_json_file(test_plan_path) 89 | 90 | # print(Is_intercity_transport_correct(test_example,test_plan)) 91 | # print(Is_attractions_correct(test_example,test_plan)) 92 | # print(Is_hotels_correct(test_example,test_plan)) 93 | # print(Is_restaurants_correct(test_example,test_plan)) 94 | # print(Is_transport_correct(test_example,test_plan)) 95 | # print(Is_time_correct(test_example,test_plan)) 96 | # print(Is_space_correct(test_example,test_plan)) 97 | 98 | 99 | # pass_flag = True 100 | 101 | 102 | 103 | # info_list = [] 104 | # for func_i in func_list: 105 | # flag, info = func_i(test_example,test_plan) 106 | 107 | # print(info) 108 | 109 | # pass_flag = pass_flag and flag 110 | # info_list.append(info) 111 | 112 | # print("final result: ", pass_flag) 113 | 114 | # for item in info_list: 115 | # print(item) 116 | # print(info_list) 117 | 118 | -------------------------------------------------------------------------------- /chinatravel/evaluation/default_splits/medium.txt: -------------------------------------------------------------------------------- 1 | e20241028160842228543 2 | e20241028160845920703 3 | e20241028160848776495 4 | e20241028160851829629 5 | e20241028160855113315 6 | e20241028160857444675 7 | e20241028160902013119 8 | e20241028160904677400 9 | e20241028160906602324 10 | e20241028160908981620 11 | e20241028160911184742 12 | e20241028160913165493 13 | e20241028160914767827 14 | e20241028160916850792 15 | e20241028160918445815 16 | e20241028160920077469 17 | e20241028160930368370 18 | e20241028160932644544 19 | e20241028160934978800 20 | e20241028160936417511 21 | e20241028160937967706 22 | e20241028160939613971 23 | e20241028160941454573 24 | e20241028160942866027 25 | e20241028160944911300 26 | e20241028160946447653 27 | e20241028160950031404 28 | e20241028160951725307 29 | e20241028160953818897 30 | e20241028160956304302 31 | e20241028160958608314 32 | e20241028161009825606 33 | e20241028161012012769 34 | e20241028161015707983 35 | e20241028161018716295 36 | e20241028161022415715 37 | e20241028161026917449 38 | e20241028161029502862 39 | e20241028161030936732 40 | e20241028161032890887 41 | e20241028161034674798 42 | e20241028161036243215 43 | e20241028161037866126 44 | e20241028161039320663 45 | e20241028161040968006 46 | e20241028161042873410 47 | e20241028161046455300 48 | e20241028161048026142 49 | e20241028161049552536 50 | e20241028161052579774 51 | e20241028161054093646 52 | e20241028161055792780 53 | e20241028161058062780 54 | e20241028161059430567 55 | e20241028161100624021 56 | e20241028161103332044 57 | e20241028161104817595 58 | e20241028161107650861 59 | e20241028161109476697 60 | e20241028161111138617 61 | e20241028161113744118 62 | e20241028161115178119 63 | e20241028161117649877 64 | e20241028161119161993 65 | e20241028161120613448 66 | e20241028161122171567 67 | e20241028161123919808 68 | e20241028161125262867 69 | e20241028161126940854 70 | e20241028161128372981 71 | e20241028161130188821 72 | e20241028161131864997 73 | e20241028161133868794 74 | e20241028161135356352 75 | e20241028161138879144 76 | e20241028161142239258 77 | e20241028161143787919 78 | e20241028161145339341 79 | e20241028161147059952 80 | e20241028161148219472 81 | e20241028161150400837 82 | e20241028161152287295 83 | e20241028161153932651 84 | e20241028161156053328 85 | e20241028161158132216 86 | e20241028161200307413 87 | e20241028161202594158 88 | e20241028161205343103 89 | e20241028161207342223 90 | e20241028161208995433 91 | e20241028161210773442 92 | e20241028161212423674 93 | e20241028161214096539 94 | e20241028161215791938 95 | e20241028161217125222 96 | e20241028161218972541 97 | e20241028161220874335 98 | e20241028161223030634 99 | e20241028161225372884 100 | e20241028161228110024 101 | e20241028161230900633 102 | e20241028161232523372 103 | e20241028161234891857 104 | e20241028161236945933 105 | e20241028161238698866 106 | e20241028161241981481 107 | e20241028161244805289 108 | e20241028161247080504 109 | e20241028161249891005 110 | e20241028161251275160 111 | e20241028161252903955 112 | e20241028161254604670 113 | e20241028161256713544 114 | e20241028161258161531 115 | e20241028161302837074 116 | e20241028161304176860 117 | e20241028161305426441 118 | e20241028161307197168 119 | e20241028161311009105 120 | e20241028161312783307 121 | e20241028161314922311 122 | e20241028161318465643 123 | e20241028161320378455 124 | e20241028161322437990 125 | e20241028161323947938 126 | e20241028161327496043 127 | e20241028161330699097 128 | e20241028161333580458 129 | e20241028161334777418 130 | e20241028161337180365 131 | e20241028161340383778 132 | e20241028161341975729 133 | e20241028161343498529 134 | e20241028161345670208 135 | e20241028161347428726 136 | e20241028161351544529 137 | e20241028161353723055 138 | e20241028161355587050 139 | e20241028161357703933 140 | e20241028161358807726 141 | e20241028161401963912 142 | e20241028161403238048 143 | e20241028161404895626 144 | e20241028161406841505 145 | e20241028161411151157 146 | e20241028161414324221 147 | e20241028161417958159 148 | e20241028161422443540 149 | e20241028161425307700 150 | e20241028161428027700 -------------------------------------------------------------------------------- /chinatravel/evaluation/default_splits/human.txt: -------------------------------------------------------------------------------- 1 | h20241029143447759844 2 | h20241029143450090032 3 | h20241029143451793119 4 | h20241029143453343691 5 | h20241029143455115600 6 | h20241029143456235562 7 | h20241029143457494209 8 | h20241029143458760702 9 | h20241029143500599951 10 | h20241029143502525564 11 | h20241029143504074288 12 | h20241029143506461809 13 | h20241029143508251643 14 | h20241029143509607671 15 | h20241029143510880048 16 | h20241029143512507336 17 | h20241029143514054328 18 | h20241029143515190249 19 | h20241029143518931601 20 | h20241029143520277982 21 | h20241029143521646393 22 | h20241029143522848652 23 | h20241029143524548106 24 | h20241029143526214076 25 | h20241029143528503277 26 | h20241029143530430866 27 | h20241029143532210869 28 | h20241029143533982475 29 | h20241029143536587211 30 | h20241029143538018046 31 | h20241029143539603238 32 | h20241029143541319291 33 | h20241029143542560410 34 | h20241029143543943871 35 | h20241029143546424651 36 | h20241029143547722063 37 | h20241029143549773288 38 | h20241029143551657496 39 | h20241029143553510979 40 | h20241029143554849031 41 | h20241029143556181559 42 | h20241029143558223472 43 | h20241029143602442546 44 | h20241029143604280402 45 | h20241029143606240237 46 | h20241029143608049345 47 | h20241029143609292913 48 | h20241029143611928866 49 | h20241029143613424655 50 | h20241029143614813545 51 | h20241029143616010079 52 | h20241029143617975256 53 | h20241029143619654143 54 | h20241029143620980391 55 | h20241029143622387332 56 | h20241029143623982133 57 | h20241029143625307702 58 | h20241029143627200710 59 | h20241029143628653733 60 | h20241029143630351554 61 | h20241029143632058747 62 | h20241029143633338303 63 | h20241029143636163145 64 | h20241029143637732099 65 | h20241029143639901784 66 | h20241029143641719716 67 | h20241029143644154862 68 | h20241029143645157524 69 | h20241029143646523600 70 | h20241029143648613072 71 | h20241029143650282919 72 | h20241029143652031792 73 | h20241029143653372872 74 | h20241029143656081463 75 | h20241029143657365639 76 | h20241029143658659295 77 | h20241029143659983667 78 | h20241029143701815147 79 | h20241029143704161138 80 | h20241029143706437033 81 | h20241029143709097610 82 | h20241029143711334809 83 | h20241029143714502665 84 | h20241029143718654427 85 | h20241029143720725398 86 | h20241029143722458026 87 | h20241029143724478281 88 | h20241029143727719855 89 | h20241029143729369677 90 | h20241029143730769592 91 | h20241029143732290071 92 | h20241029143733781691 93 | h20241029143735439292 94 | h20241029143736524841 95 | h20241029143738057260 96 | h20241029143739523099 97 | h20241029143741942032 98 | h20241029143743586105 99 | h20241029143745264225 100 | h20241029143747397251 101 | h20241029143749478632 102 | h20241029143750930343 103 | h20241029143752714587 104 | h20241029143754236194 105 | h20241029143759740039 106 | h20241029143803188730 107 | h20241029143804839319 108 | h20241029143806507479 109 | h20241029143807949268 110 | h20241029143809635485 111 | h20241029143811168339 112 | h20241029143812547929 113 | h20241029143814437710 114 | h20241029143816266519 115 | h20241029143818664124 116 | h20241029143820591758 117 | h20241029143823020071 118 | h20241029143824748448 119 | h20241029143826541906 120 | h20241029143828857993 121 | h20241029143830714886 122 | h20241029143832205713 123 | h20241029143834377206 124 | h20241029143835520306 125 | h20241029143837173547 126 | h20241029143839056472 127 | h20241029143840382242 128 | h20241029143842373194 129 | h20241029143844076573 130 | h20241029143845365386 131 | h20241029143846901147 132 | h20241029143848307913 133 | h20241029143850298727 134 | h20241029143851561679 135 | h20241029143853055859 136 | h20241029143854711129 137 | h20241029143857404500 138 | h20241029143858969810 139 | h20241029143900796443 140 | h20241029143902858522 141 | h20241029143904148477 142 | h20241029143906048411 143 | h20241029143908194169 144 | h20241029143909714190 145 | h20241029143911770965 146 | h20241029143913181191 147 | h20241029143915014327 148 | h20241029143916604779 149 | h20241029143918467169 150 | h20241029143920003060 151 | h20241029143921648851 152 | h20241029143923769644 153 | h20241029143925294432 154 | h20241029143926788464 155 | -------------------------------------------------------------------------------- /chinatravel/agent/nesy_agent/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.feature_extraction.text import TfidfVectorizer 3 | from sklearn.metrics.pairwise import cosine_similarity 4 | 5 | 6 | 7 | class TimeOutError(Exception): 8 | def __init__(self, message="Searching TIME OUT !!!"): 9 | self.message = message 10 | super().__init__(self.message) 11 | 12 | 13 | def time_compare_if_earlier_equal(time_1, time_2): 14 | 15 | time1 = float(time_1.split(":")[0])*60 + float(time_1.split(":")[1]) 16 | time2 = float(time_2.split(":")[0])*60 + float(time_2.split(":")[1]) 17 | 18 | 19 | return time1 <= time2 20 | 21 | 22 | def add_time_delta(time1, time_delta): 23 | 24 | hour, minu = int(time1.split(":")[0]), int(time1.split(":")[1]) 25 | 26 | min_new = minu + time_delta 27 | 28 | if min_new >= 60: 29 | hour_new = hour + int(min_new / 60) 30 | min_new = min_new % 60 31 | else: 32 | hour_new = hour 33 | 34 | if hour_new < 10: 35 | time_new = "0" + str(hour_new) + ":" 36 | else: 37 | time_new = str(hour_new) + ":" 38 | if min_new < 10: 39 | 40 | time_new = time_new + "0" + str(min_new) 41 | else: 42 | time_new = time_new + str(min_new) 43 | 44 | return time_new 45 | 46 | def calc_cost_from_itinerary_wo_intercity(itinerary, people_number): 47 | total_cost = 0 48 | for day in itinerary: 49 | for activity in day["activities"]: 50 | 51 | for transport in activity.get("transports", []): 52 | 53 | mode = transport["mode"] 54 | if mode=='taxi': 55 | if 'cars' in transport.keys(): 56 | total_cost += transport.get('cars',0)*transport.get("cost", 0) 57 | else: 58 | total_cost += transport.get('tickets',0)*transport.get("cost", 0) 59 | if mode=='metro': 60 | total_cost += transport.get('tickets',0)*transport.get("cost", 0) 61 | 62 | 63 | # if activity["type"] == "airplane": 64 | # total_cost += activity.get('tickets',0)*activity.get("cost", 0) 65 | 66 | # if activity["type"] == "train": 67 | # total_cost += activity.get('tickets',0)*activity.get("cost", 0) 68 | 69 | if activity["type"] == "breakfest" or activity["type"] == "lunch" or activity["type"] == "dinner": 70 | total_cost += activity.get('cost',0)*people_number 71 | 72 | # if activity["type"] == "accommodation": 73 | # total_cost += activity.get('rooms',0)*activity.get("cost", 0) 74 | 75 | if activity["type"] == "attraction": 76 | total_cost += activity.get('tickets',0)*activity.get("cost", 0) 77 | return total_cost 78 | 79 | def mmr_algorithm(name_list,score,lambda_value=0.3): 80 | selected_indices = [] 81 | remaining_indices = list(range(len(name_list))) 82 | 83 | tfidf_vectorizer = TfidfVectorizer() 84 | 85 | while len(selected_indices) < len(name_list): 86 | if len(selected_indices) == 0: 87 | mmr_scores = np.ones(len(name_list)) 88 | else: 89 | selected_names = [name.split()[0] for name in name_list[selected_indices]] 90 | remaining_names = [name.split()[0] for name in name_list[remaining_indices]] 91 | 92 | tfidf_matrix = tfidf_vectorizer.fit_transform(np.concatenate((selected_names, remaining_names))) 93 | similarity_matrix = cosine_similarity(tfidf_matrix) 94 | 95 | selected_similarities = similarity_matrix[:len(selected_names), len(selected_names):] 96 | remaining_similarities = similarity_matrix[len(selected_names):, len(selected_names):] 97 | 98 | mmr_scores = lambda_value*score[remaining_indices] - (1 - lambda_value) * np.max(selected_similarities, axis=0) 99 | 100 | max_index = np.argmax(mmr_scores) 101 | selected_indices.append(remaining_indices[max_index]) 102 | del remaining_indices[max_index] 103 | 104 | return mmr_scores 105 | -------------------------------------------------------------------------------- /chinatravel/agent/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 5 | 6 | import json 7 | from json import JSONDecodeError 8 | 9 | from abc import ABC, abstractmethod 10 | from agent.utils import decode_numpy_dict 11 | 12 | 13 | def is_jsonable(x): 14 | try: 15 | json.dumps(x) 16 | return True 17 | except JSONDecodeError: 18 | return False 19 | 20 | 21 | class AgentError(Exception): 22 | def __init__(self, message): 23 | self.message = message 24 | super().__init__(self.message) 25 | 26 | def __str__(self): 27 | return self.message 28 | 29 | 30 | class AgentReturnError(AgentError): 31 | pass 32 | 33 | 34 | class AgentReturnInfoError(AgentError): 35 | pass 36 | 37 | 38 | class DecodeLogError(AgentError): 39 | pass 40 | 41 | 42 | class AgentReturnInfo: 43 | """ 44 | This class is used to store the return information of an agent. 45 | It contains two attributes: ans and log. 46 | ans: The answer of the agent. Should be a string. 47 | log: The log of the agent. It can be transformed to json. 48 | """ 49 | 50 | def __init__(self, ans, log: dict = {}): 51 | if not isinstance(ans, str): 52 | raise AgentReturnInfoError("ans must be a string") 53 | if not is_jsonable(log): 54 | raise AgentReturnInfoError("log must be a json object") 55 | try: 56 | log = decode_numpy_dict(log) 57 | except Exception as e: 58 | raise DecodeLogError(f"Error when decoding log: {e}") 59 | self.data = {"ans": ans, "log": log} 60 | 61 | def __getitem__(self, key): 62 | return self.data[key] 63 | 64 | 65 | class AbstractAgent(ABC): 66 | def __init__(self, env): 67 | 68 | self._env = env 69 | self._ans = None 70 | self._log = None 71 | 72 | def __call__(self, query): 73 | self._reset() # Default reset, reset env, ans and log 74 | self.reset() # Reset the agent, just those added in the subclass 75 | 76 | return_info = self.run(query) 77 | 78 | if not isinstance(return_info, AgentReturnInfo): 79 | raise AgentReturnError( 80 | "Return value must be an instance of AgentReturnInfo" 81 | ) 82 | 83 | return return_info 84 | 85 | @property 86 | def env(self): 87 | return self._env 88 | 89 | @property 90 | def ans(self): 91 | return self._ans 92 | 93 | @property 94 | def log(self): 95 | return self._log 96 | 97 | def _reset(self): 98 | self._env.reset() 99 | self._ans = None 100 | self._log = None 101 | 102 | @abstractmethod 103 | def run(self, query) -> AgentReturnInfo: 104 | """ 105 | You should implement this method to run the agent. 106 | Return the answer and log in the form of AgentReturnInfo. 107 | """ 108 | pass 109 | 110 | @abstractmethod 111 | def reset(self): 112 | pass 113 | 114 | 115 | import time 116 | 117 | from agent.load_model import init_llm 118 | 119 | class BaseAgent: 120 | def __init__(self, name, **kwargs): 121 | self.name = name 122 | 123 | self.env = kwargs.get('env', None) 124 | 125 | self.log_dir = kwargs.get('log_dir', "logs") 126 | if not os.path.exists(self.log_dir): 127 | os.makedirs(self.log_dir) 128 | print(f"Created log directory: {self.log_dir}") 129 | 130 | 131 | model = kwargs.get('backbone_llm', None) 132 | if type(model) == str: 133 | self.backbone_llm = init_llm(model) 134 | else: 135 | self.backbone_llm = model 136 | self.model_name = self.backbone_llm.name 137 | 138 | 139 | self.llm_inference_time_count = 0 140 | self.start_clock = 0 141 | 142 | def reset_clock(self): 143 | self.start_clock = time.time() 144 | 145 | 146 | def act(self, observation, reward, done, info): 147 | """Act based on the observation and reward.""" 148 | raise NotImplementedError 149 | 150 | def reset(self): 151 | """Reset the agent.""" 152 | raise NotImplementedError 153 | 154 | -------------------------------------------------------------------------------- /chinatravel/environment/tools/restaurants/apis.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from pandas import DataFrame 3 | from typing import Callable 4 | import os 5 | from geopy.distance import geodesic 6 | 7 | import sys 8 | 9 | sys.path.append(os.path.join(os.path.dirname(__file__), "..")) 10 | from poi.apis import Poi 11 | 12 | 13 | class Restaurants: 14 | def __init__(self, base_path: str = "../../database/restaurants"): 15 | city_list = [ 16 | "beijing", 17 | "shanghai", 18 | "nanjing", 19 | "suzhou", 20 | "hangzhou", 21 | "shenzhen", 22 | "chengdu", 23 | "wuhan", 24 | "guangzhou", 25 | "chongqing", 26 | ] 27 | self.data = {} 28 | curdir = os.path.dirname(os.path.realpath(__file__)) 29 | for city in city_list: 30 | path = os.path.join(curdir, base_path, city, "restaurants_" + city + ".csv") 31 | self.data[city] = pd.read_csv(path) 32 | 33 | self.key_type_tuple_list_map = {} 34 | for city in city_list: 35 | self.key_type_tuple_list_map[city] = [] 36 | for key in self.data[city].keys(): 37 | self.key_type_tuple_list_map[city].append( 38 | (key, type(self.data[city][key][0])) 39 | ) 40 | self.cuisine_list_map = {} 41 | for city in city_list: 42 | self.cuisine_list_map[city] = self.data[city]["cuisine"].unique() 43 | city_cn_list = [ 44 | "北京", 45 | "上海", 46 | "南京", 47 | "苏州", 48 | "杭州", 49 | "深圳", 50 | "成都", 51 | "武汉", 52 | "广州", 53 | "重庆", 54 | ] 55 | 56 | for i, city in enumerate(city_list): 57 | self.data[city_cn_list[i]] = self.data.pop(city) 58 | self.key_type_tuple_list_map[city_cn_list[i]] = ( 59 | self.key_type_tuple_list_map.pop(city) 60 | ) 61 | self.cuisine_list_map[city_cn_list[i]] = self.cuisine_list_map.pop(city) 62 | 63 | self.poi = Poi() 64 | 65 | def keys(self, city: str): 66 | return self.key_type_tuple_list_map[city] 67 | 68 | def select(self, city: str, key, func: Callable) -> DataFrame: 69 | if key not in self.data[city].keys(): 70 | return "Key not found." 71 | bool_list = [func(x) for x in self.data[city][key]] 72 | return self.data[city][bool_list] 73 | 74 | def id_is_open(self, city: str, id: int, time: str) -> bool: 75 | match = self.data[city].loc[self.data[city]["id"] == id] 76 | open_time = match["opentime"].values[0] 77 | end_time = match["endtime"].values[0] 78 | open_time = ( 79 | -1 80 | if open_time == "不营业" 81 | else float(open_time.split(":")[0]) + float(open_time.split(":")[1]) / 60 82 | ) 83 | end_time = ( 84 | -1 85 | if end_time == "不营业" 86 | else float(end_time.split(":")[0]) + float(end_time.split(":")[1]) / 60 87 | ) 88 | time = float(time.split(":")[0]) + float(time.split(":")[1]) / 60 89 | if open_time == -1 or end_time == -1: 90 | return False 91 | if open_time < end_time: 92 | return open_time <= time <= end_time 93 | else: 94 | return open_time <= time or time <= end_time 95 | 96 | def nearby(self, city: str, point: str, topk: int = None, dist=2) -> DataFrame: 97 | lat_lon = self.poi.search(city, point) 98 | if isinstance(lat_lon, str): 99 | return lat_lon 100 | lat, lon = lat_lon 101 | distance = [ 102 | geodesic((lat, lon), (x, y)).km 103 | for x, y in zip(self.data[city]["lat"], self.data[city]["lon"]) 104 | ] 105 | tmp = self.data[city].copy() 106 | tmp["distance"] = distance 107 | tmp = tmp.sort_values(by=["distance"]) 108 | if topk is None: 109 | return tmp[tmp["distance"] <= dist] 110 | return tmp[tmp["distance"] <= dist].head(topk) 111 | 112 | def restaurants_with_recommended_food(self, city: str, food: str): 113 | return self.data[city][self.data[city]["recommendedfood"].str.contains(food)] 114 | 115 | def get_cuisine_list(self, city: str): 116 | return self.cuisine_list_map[city] 117 | 118 | 119 | if __name__ == "__main__": 120 | a = Restaurants() 121 | -------------------------------------------------------------------------------- /run_tpc.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | 5 | import sys 6 | import os 7 | import json 8 | from func_timeout import func_timeout, FunctionTimedOut 9 | 10 | project_root_path = os.path.dirname(os.path.abspath(__file__)) 11 | if project_root_path not in sys.path: 12 | sys.path.insert(0, project_root_path) 13 | 14 | from copy import deepcopy 15 | 16 | from chinatravel.data.load_datasets import load_query, save_json_file 17 | from chinatravel.agent.load_model import init_agent, init_llm 18 | from chinatravel.environment.world_env import WorldEnv 19 | 20 | 21 | if __name__ == "__main__": 22 | 23 | parser = argparse.ArgumentParser(description="argparse testing") 24 | parser.add_argument( 25 | "--splits", 26 | "-s", 27 | type=str, 28 | default="tpc_phase1", 29 | help="query subset", 30 | ) 31 | parser.add_argument("--index", "-id", type=str, default=None, help="query index") 32 | parser.add_argument( 33 | "--skip", "-sk", type=int, default=0, help="skip if the plan exists" 34 | ) 35 | parser.add_argument( 36 | "--agent", 37 | "-a", 38 | type=str, 39 | default=None, 40 | choices=["TPCAgent"], 41 | ) 42 | parser.add_argument( 43 | "--llm", 44 | "-l", 45 | type=str, 46 | default=None 47 | ) 48 | parser.add_argument( 49 | "--timeout", 50 | "-t", 51 | type=int, 52 | default=300, 53 | help="Timeout in seconds for each query", 54 | ) 55 | 56 | parser.add_argument('--oracle_translation', action='store_true', help='Set this flag to enable oracle translation.') 57 | 58 | args = parser.parse_args() 59 | 60 | print(args) 61 | 62 | query_index, query_data = load_query(args) 63 | print(len(query_index), "samples") 64 | 65 | if args.index is not None: 66 | query_index = [args.index] 67 | 68 | cache_dir = os.path.join(project_root_path, "cache") 69 | 70 | method = args.agent + "_" + args.llm 71 | if args.agent == "LLM-modulo": 72 | method += f"_{args.refine_steps}steps" 73 | 74 | if not args.oracle_translation: 75 | raise Exception("LLM-modulo must use oracle translation") 76 | 77 | if args.oracle_translation: 78 | method = method + "_oracletranslation" 79 | 80 | res_dir = os.path.join( 81 | project_root_path, "results", method 82 | ) 83 | log_dir = os.path.join( 84 | project_root_path, "cache", method 85 | ) 86 | if not os.path.exists(res_dir): 87 | os.makedirs(res_dir) 88 | if not os.path.exists(log_dir): 89 | os.makedirs(log_dir) 90 | 91 | print("res_dir: ", res_dir) 92 | print("log_dir:", log_dir) 93 | 94 | kwargs = { 95 | "method": args.agent, 96 | "env": WorldEnv(), 97 | "backbone_llm": init_llm(args.llm), 98 | "cache_dir": cache_dir, 99 | "log_dir": log_dir, 100 | "debug": True, 101 | } 102 | agent = init_agent(kwargs) 103 | 104 | succ_count, eval_count = 0, 0 105 | 106 | for i, data_idx in enumerate(query_index): 107 | 108 | sys.stdout = sys.__stdout__ 109 | print("------------------------------") 110 | print( 111 | "Process [{}/{}], Success [{}/{}]:".format( 112 | i, len(query_index), succ_count, eval_count 113 | ) 114 | ) 115 | print("data uid: ", data_idx) 116 | 117 | if args.skip and os.path.exists(os.path.join(res_dir, f"{data_idx}.json")): 118 | continue 119 | eval_count += 1 120 | query_i = query_data[data_idx] 121 | print(query_i) 122 | try: 123 | # succ, plan = agent.run(query_i, prob_idx=data_idx, oralce_translation=args.oracle_translation) 124 | succ, plan = func_timeout( 125 | args.timeout, 126 | agent.run, 127 | args=(query_i,), 128 | kwargs=dict( 129 | prob_idx=data_idx, oralce_translation=args.oracle_translation 130 | ), 131 | ) 132 | except FunctionTimedOut: 133 | # print(f"⚠️ 任务 {data_idx} 超过 {args.timeout}s 被中断。") 134 | succ, plan = 0, {"error": f"timeout after {args.timeout}s"} 135 | 136 | except Exception as e: 137 | # print(f"❌ 执行任务 {data_idx} 出错: {e}") 138 | succ, plan = 0, {"error": str(e)} 139 | 140 | if succ: 141 | succ_count += 1 142 | 143 | save_json_file( 144 | json_data=plan, file_path=os.path.join(res_dir, f"{data_idx}.json") 145 | ) 146 | -------------------------------------------------------------------------------- /chinatravel/agent/UrbanTrip/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | from sklearn.feature_extraction.text import TfidfVectorizer 4 | from sklearn.metrics.pairwise import cosine_similarity 5 | 6 | 7 | class TimeOutError(Exception): 8 | def __init__(self, message="Searching TIME OUT !!!"): 9 | self.message = message 10 | super().__init__(self.message) 11 | 12 | def time_to_minutes(tstr: str) -> int: 13 | """'HH:MM' -> 分钟""" 14 | h, m = map(int, tstr.split(":")) 15 | return h * 60 + m 16 | 17 | def minutes_to_time(minutes: int) -> str: 18 | """分钟 -> 'HH:MM'""" 19 | h = minutes // 60 20 | m = minutes % 60 21 | return f"{h:02d}:{m:02d}" 22 | 23 | def time_compare_if_earlier_equal(time_1, time_2): 24 | time1 = float(time_1.split(":")[0]) * 60 + float(time_1.split(":")[1]) 25 | time2 = float(time_2.split(":")[0]) * 60 + float(time_2.split(":")[1]) 26 | 27 | return time1 <= time2 28 | 29 | 30 | def add_time_delta(time1, time_delta): 31 | hour, minu = int(time1.split(":")[0]), int(time1.split(":")[1]) 32 | 33 | min_new = minu + time_delta 34 | 35 | if min_new >= 60: 36 | hour_new = hour + int(min_new / 60) 37 | min_new = min_new % 60 38 | else: 39 | hour_new = hour 40 | 41 | if hour_new < 10: 42 | time_new = "0" + str(hour_new) + ":" 43 | else: 44 | time_new = str(hour_new) + ":" 45 | if min_new < 10: 46 | 47 | time_new = time_new + "0" + str(min_new) 48 | else: 49 | time_new = time_new + str(min_new) 50 | 51 | return time_new 52 | 53 | def get_time_delta(time1, time2): 54 | """ 55 | 计算 time1 和 time2 的时间差(分钟数) 56 | 假设 time1 早于 time2,格式为 HH:MM 57 | """ 58 | hour1, min1 = int(time1.split(":")[0]), int(time1.split(":")[1]) 59 | hour2, min2 = int(time2.split(":")[0]), int(time2.split(":")[1]) 60 | 61 | delta_minutes = (hour2 - hour1) * 60 + (min2 - min1) 62 | return delta_minutes 63 | 64 | 65 | def calc_cost_from_itinerary_wo_intercity(itinerary, people_number): 66 | total_cost = 0 67 | for day in itinerary: 68 | for activity in day["activities"]: 69 | 70 | for transport in activity.get("transports", []): 71 | 72 | mode = transport["mode"] 73 | if mode == 'taxi': 74 | if 'cars' in transport.keys(): 75 | total_cost += transport.get('cars', 0) * transport.get("cost", 0) 76 | else: 77 | total_cost += transport.get('tickets', 0) * transport.get("cost", 0) 78 | if mode == 'metro': 79 | total_cost += transport.get('tickets', 0) * transport.get("cost", 0) 80 | 81 | # if activity["type"] == "airplane": 82 | # total_cost += activity.get('tickets',0)*activity.get("cost", 0) 83 | 84 | # if activity["type"] == "train": 85 | # total_cost += activity.get('tickets',0)*activity.get("cost", 0) 86 | 87 | if activity["type"] == "breakfest" or activity["type"] == "lunch" or activity["type"] == "dinner": 88 | total_cost += activity.get('cost', 0) * people_number 89 | 90 | # if activity["type"] == "accommodation": 91 | # total_cost += activity.get('rooms',0)*activity.get("cost", 0) 92 | 93 | if activity["type"] == "attraction": 94 | total_cost += activity.get('tickets', 0) * activity.get("cost", 0) 95 | return total_cost 96 | 97 | 98 | def mmr_algorithm(name_list, score, lambda_value=0.3): 99 | selected_indices = [] 100 | remaining_indices = list(range(len(name_list))) 101 | 102 | tfidf_vectorizer = TfidfVectorizer() 103 | 104 | while len(selected_indices) < len(name_list): 105 | if len(selected_indices) == 0: 106 | mmr_scores = np.ones(len(name_list)) 107 | else: 108 | selected_names = [name.split()[0] for name in name_list[selected_indices]] 109 | remaining_names = [name.split()[0] for name in name_list[remaining_indices]] 110 | 111 | tfidf_matrix = tfidf_vectorizer.fit_transform(np.concatenate((selected_names, remaining_names))) 112 | similarity_matrix = cosine_similarity(tfidf_matrix) 113 | 114 | selected_similarities = similarity_matrix[:len(selected_names), len(selected_names):] 115 | remaining_similarities = similarity_matrix[len(selected_names):, len(selected_names):] 116 | 117 | mmr_scores = lambda_value * score[remaining_indices] - (1 - lambda_value) * np.max(selected_similarities, 118 | axis=0) 119 | 120 | max_index = np.argmax(mmr_scores) 121 | selected_indices.append(remaining_indices[max_index]) 122 | del remaining_indices[max_index] 123 | 124 | return mmr_scores -------------------------------------------------------------------------------- /chinatravel/environment/tools/attractions/apis.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from pandas import DataFrame 3 | from typing import Callable 4 | import os 5 | from geopy.distance import geodesic 6 | 7 | import sys 8 | 9 | sys.path.append(os.path.join(os.path.dirname(__file__), "..")) 10 | from poi.apis import Poi 11 | 12 | 13 | class Attractions: 14 | def __init__( 15 | self, 16 | base_path: str = "../../database/attractions", 17 | en_version=False, 18 | ): 19 | city_list = [ 20 | "beijing", 21 | "shanghai", 22 | "nanjing", 23 | "suzhou", 24 | "hangzhou", 25 | "shenzhen", 26 | "chengdu", 27 | "wuhan", 28 | "guangzhou", 29 | "chongqing", 30 | ] 31 | curdir = os.path.dirname(os.path.realpath(__file__)) 32 | data_path_list = [ 33 | os.path.join(curdir, f"{base_path}/{city}/attractions.csv") 34 | for city in city_list 35 | ] 36 | 37 | self.data = {} 38 | for i, city in enumerate(city_list): 39 | self.data[city] = pd.read_csv(data_path_list[i]) 40 | self.key_type_tuple_list_map = {} 41 | for city in city_list: 42 | self.key_type_tuple_list_map[city] = [] 43 | for key in self.data[city].keys(): 44 | self.key_type_tuple_list_map[city].append( 45 | (key, type(self.data[city][key][0])) 46 | ) 47 | self.type_list_map = {} 48 | for city in city_list: 49 | self.type_list_map[city] = self.data[city]["type"].unique() 50 | city_cn_list = [ 51 | "北京", 52 | "上海", 53 | "南京", 54 | "苏州", 55 | "杭州", 56 | "深圳", 57 | "成都", 58 | "武汉", 59 | "广州", 60 | "重庆", 61 | ] 62 | 63 | for i, city in enumerate(city_list): 64 | self.data[city_cn_list[i]] = self.data.pop(city) 65 | self.key_type_tuple_list_map[city_cn_list[i]] = ( 66 | self.key_type_tuple_list_map.pop(city) 67 | ) 68 | self.type_list_map[city_cn_list[i]] = self.type_list_map.pop(city) 69 | 70 | self.poi = Poi() 71 | 72 | def keys(self, city: str): 73 | return self.key_type_tuple_list_map[city] 74 | 75 | def select(self, city: str, key, func: Callable) -> DataFrame: 76 | if key not in self.data[city].keys(): 77 | return "Key not found." 78 | bool_list = [func(x) for x in self.data[city][key]] 79 | return self.data[city][bool_list] 80 | 81 | def id_is_open(self, city: str, id: int, time: str) -> bool: 82 | # open_time = self.data[city]["opentime"][id] 83 | # end_time = self.data[city]["endtime"][id] 84 | 85 | match = self.data[city].loc[self.data[city]["id"] == id] 86 | open_time = match["opentime"].values[0] 87 | end_time = match["endtime"].values[0] 88 | 89 | open_time = float(open_time.split(":")[0]) + float(open_time.split(":")[1]) / 60 90 | end_time = float(end_time.split(":")[0]) + float(end_time.split(":")[1]) / 60 91 | time = float(time.split(":")[0]) + float(time.split(":")[1]) / 60 92 | if open_time < end_time: 93 | return open_time <= time <= end_time 94 | else: 95 | return open_time <= time or time <= end_time 96 | 97 | def nearby(self, city: str, point: str, topk: int = None, dist=2) -> DataFrame: 98 | lat_lon = self.poi.search(city, point) 99 | if isinstance(lat_lon, str): 100 | return lat_lon 101 | lat, lon = lat_lon 102 | distance = [ 103 | geodesic((lat, lon), (x, y)).km 104 | for x, y in zip(self.data[city]["lat"], self.data[city]["lon"]) 105 | ] 106 | tmp = self.data[city].copy() 107 | tmp["distance"] = distance 108 | tmp = tmp.sort_values(by=["distance"]) 109 | if topk is None: 110 | return tmp[tmp["distance"] <= dist] 111 | return tmp[tmp["distance"] <= dist].head(topk) 112 | 113 | def get_type_list(self, city: str): 114 | return self.type_list_map[city] 115 | 116 | 117 | if __name__ == "__main__": 118 | a = Attractions() 119 | print(a.get_type_list("南京")) 120 | # print(a.data) 121 | # print(a.get_info("Name")) 122 | # info_list, _ = a.get_info("Name") 123 | # print(a.get_info_for_index(info_list, 0)) 124 | # print(a.get_info_for_index(info_list, [0, 1])) 125 | # print(a.nearby(a.data.iloc[0]['Latitude'], a.data.iloc[0]['Longitude'])) 126 | # print(a.select("Name", "夫子庙")) 127 | # print(a.id_is_open(0, "10:00")) 128 | # print(a.select('Type', lambda x: x == '公园')) 129 | # print(a.data) 130 | -------------------------------------------------------------------------------- /chinatravel/evaluation/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import json 5 | 6 | 7 | project_root_path = os.path.dirname( 8 | os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 9 | ) 10 | if project_root_path not in sys.path: 11 | sys.path.insert(0, project_root_path) 12 | if os.path.join(project_root_path, "chinatravel") not in sys.path: 13 | sys.path.insert(0, os.path.join(project_root_path, "chinatravel")) 14 | 15 | from chinatravel.evaluation.utils import load_json_file, validate_json 16 | 17 | from chinatravel.evaluation.commonsense_constraint import evaluate_commonsense_constraints 18 | from chinatravel.evaluation.hard_constraint import evaluate_hard_constraints 19 | from chinatravel.evaluation.preference import evaluate_preference 20 | 21 | METHOD_LIST = [ 22 | "example" "act_Deepseek_zeroshot", 23 | "act_GPT4o_zeroshot", 24 | "react_Deepseek_zeroshot", 25 | "react_GPT4o_zeroshot", 26 | "react_GLM4Plus_zeroshot", 27 | "react_Deepseek_oneshot", 28 | "react_GPT4o_oneshot", 29 | "naive_ns_Deepseek", 30 | "naive_ns_GPT4o", 31 | "naive_ns_GLM4Plus", 32 | ] 33 | 34 | 35 | 36 | def load_result(args, query_index, verbose=False): 37 | 38 | def load_result_for_method(method): 39 | plans = {} 40 | for query_id in query_index: 41 | result_file = os.path.join( 42 | "../results/", method, "{}.json".format(query_id) 43 | ) 44 | 45 | try: 46 | if os.path.exists(result_file): 47 | result = load_json_file(result_file) 48 | plans[query_id] = result 49 | else: 50 | plans[query_id] = {} 51 | except: 52 | plans[query_id] = {} 53 | return plans 54 | 55 | result = {} 56 | if args.method == "all": 57 | method_list = [] 58 | for mi in METHOD_LIST: 59 | if mi != "example": 60 | method_list.append(mi) 61 | else: 62 | method_list = [args.method] 63 | 64 | for method in method_list: 65 | result[method] = load_result_for_method(method) 66 | 67 | if verbose: 68 | print(result) 69 | 70 | return method_list, result 71 | 72 | 73 | if __name__ == "__main__": 74 | 75 | parser = argparse.ArgumentParser() 76 | parser.add_argument("--splits", "-s", type=str, default="example") 77 | parser.add_argument( 78 | "--method", "-m", type=str, default="example" 79 | ) # , choices=METHOD_LIST) 80 | parser.add_argument("--preference", "-p", action="store_true", default=False) 81 | args = parser.parse_args() 82 | 83 | # print(args.splits) 84 | 85 | query_index, query_data = load_query(args) 86 | 87 | method_list, result_data = load_result(args, query_index) 88 | 89 | # print(query_data.keys()) 90 | # print(result_data.keys()) 91 | 92 | if not os.path.exists("eval_res/splits_{}/".format(args.splits)): 93 | os.makedirs("eval_res/splits_{}/".format(args.splits)) 94 | 95 | for method in method_list: 96 | 97 | if not os.path.exists("eval_res/splits_{}/{}/".format(args.splits, method)): 98 | os.makedirs("eval_res/splits_{}/{}/".format(args.splits, method)) 99 | 100 | macro_comm, micro_comm, result_agg = evaluate_commonsense_constraints( 101 | query_index, query_data, result_data[method], verbose=False 102 | ) 103 | 104 | res_file = "eval_res/splits_{}/{}/commonsense.csv".format(args.splits, method) 105 | result_agg.to_csv(res_file, index=False) 106 | print("save to {}".format(res_file)) 107 | 108 | print("Method: {}".format(method)) 109 | print("Commonsense constraints:") 110 | print("micro accuracy: {}".format(micro_comm)) 111 | print("macro accuracy: {}".format(macro_comm)) 112 | 113 | # record the index of the queries that pass the commonsense constraints 114 | commonsense_pass_info = result_agg.iloc[:, 1:] 115 | id_list = result_agg.iloc[:, 0].tolist() 116 | commonsense_pass = [ 117 | id_list[i] 118 | for i in range(len(id_list)) 119 | if commonsense_pass_info.iloc[i].sum() == 0 120 | ] 121 | # record end 122 | 123 | print("Logical constraints:") 124 | macro_logi, micro_logi, result_agg = evaluate_hard_constraints( 125 | query_index, query_data, result_data[method], verbose=False 126 | ) 127 | 128 | print("micro accuracy: {}".format(micro_logi)) 129 | print("macro accuracy: {}".format(macro_logi)) 130 | 131 | res_file = "eval_res/splits_{}/{}/logical.csv".format(args.splits, method) 132 | result_agg.to_csv(res_file, index=False) 133 | print("save to {}".format(res_file)) 134 | if args.preference: 135 | print("Preference:") 136 | result_agg = evaluate_preference( 137 | query_index, 138 | query_data, 139 | result_data[method], 140 | commonsense_pass, 141 | ) 142 | 143 | res_file = "eval_res/splits_{}/{}/preference.csv".format( 144 | args.splits, method 145 | ) 146 | result_agg.to_csv(res_file, index=False) 147 | print("save to {}".format(res_file)) 148 | -------------------------------------------------------------------------------- /chinatravel/evaluation/output_schema.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "http://json-schema.org/draft-07/schema#", 3 | "type": "object", 4 | "properties": { 5 | "people_number": { 6 | "type": "integer" 7 | }, 8 | "start_city": { 9 | "type": "string" 10 | }, 11 | "target_city": { 12 | "type": "string" 13 | }, 14 | "itinerary": { 15 | "type": "array", 16 | "items": { 17 | "type": "object", 18 | "properties": { 19 | "day": { 20 | "type": "integer" 21 | }, 22 | "activities": { 23 | "type": "array", 24 | "items": { 25 | "type": "object", 26 | "properties": { 27 | "type": { 28 | "type": "string", 29 | "enum": ["airplane", "attraction", "lunch", "dinner", "breakfast", "accommodation", "train"] 30 | }, 31 | "start_time": { 32 | "type": "string", 33 | "pattern": "^\\d{2}:\\d{2}$" 34 | }, 35 | "end_time": { 36 | "type": "string", 37 | "pattern": "^\\d{2}:\\d{2}$" 38 | }, 39 | "cost": { 40 | "type": "number" 41 | }, 42 | "price": { 43 | "type": "number" 44 | }, 45 | "tickets": { 46 | "type": "integer" 47 | }, 48 | "position": { 49 | "type": "string" 50 | }, 51 | "transports": { 52 | "type": "array", 53 | "items": { 54 | "type": "object", 55 | "properties": { 56 | "start": { 57 | "type": "string" 58 | }, 59 | "end": { 60 | "type": "string" 61 | }, 62 | "mode": { 63 | "type": "string", 64 | "enum": [ 65 | "walk", 66 | "metro", 67 | "taxi" 68 | ] 69 | }, 70 | "start_time": { 71 | "type": "string", 72 | "pattern": "^\\d{2}:\\d{2}$" 73 | }, 74 | "end_time": { 75 | "type": "string", 76 | "pattern": "^\\d{2}:\\d{2}$" 77 | }, 78 | "cost": { 79 | "type": "number" 80 | }, 81 | "price": { 82 | "type": "number" 83 | }, 84 | "distance": { 85 | "type": "number" 86 | }, 87 | "tickets": { 88 | "type": "integer" 89 | } 90 | }, 91 | "required": [ 92 | "start", 93 | "end", 94 | "mode", 95 | "start_time", 96 | "end_time", 97 | "price", 98 | "cost", 99 | "distance" 100 | ] 101 | } 102 | }, 103 | "room_type": { 104 | "type": "integer" 105 | }, 106 | "rooms": { 107 | "type": "integer" 108 | }, 109 | "FlightID": { 110 | "type": "string" 111 | }, 112 | "TrainID": { 113 | "type": "string" 114 | } 115 | }, 116 | "required": [ 117 | "type", 118 | "start_time", 119 | "end_time", 120 | "cost", 121 | "price", 122 | "transports" 123 | ], 124 | "allOf": [ 125 | { 126 | "if": { 127 | "properties": { 128 | "type": { 129 | "enum": ["airplane", "train"] 130 | } 131 | } 132 | }, 133 | "then": { 134 | "required": ["start", "end"] 135 | } 136 | }, 137 | { 138 | "if": { 139 | "properties": { "type": { "const": "airplane" } } 140 | }, 141 | "then": { 142 | "required": ["FlightID"] 143 | } 144 | }, 145 | { 146 | "if": { 147 | "properties": { "type": { "const": "train" } } 148 | }, 149 | "then": { 150 | "required": ["TrainID"] 151 | } 152 | } 153 | ] 154 | } 155 | } 156 | }, 157 | "required": [ 158 | "day", 159 | "activities" 160 | ] 161 | } 162 | } 163 | }, 164 | "required": [ 165 | "people_number", 166 | "start_city", 167 | "target_city", 168 | "itinerary" 169 | ] 170 | } 171 | -------------------------------------------------------------------------------- /chinatravel/agent/pure_neuro_agent/prompts/output_schema.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "http://json-schema.org/draft-07/schema#", 3 | "type": "object", 4 | "properties": { 5 | "people_number": { 6 | "type": "integer" 7 | }, 8 | "start_city": { 9 | "type": "string" 10 | }, 11 | "target_city": { 12 | "type": "string" 13 | }, 14 | "itinerary": { 15 | "type": "array", 16 | "items": { 17 | "type": "object", 18 | "properties": { 19 | "day": { 20 | "type": "integer" 21 | }, 22 | "activities": { 23 | "type": "array", 24 | "items": { 25 | "type": "object", 26 | "properties": { 27 | "type": { 28 | "type": "string", 29 | "enum": ["airplane", "attraction", "lunch", "dinner", "breakfast", "accommodation", "train"] 30 | }, 31 | "start_time": { 32 | "type": "string", 33 | "pattern": "^\\d{2}:\\d{2}$" 34 | }, 35 | "end_time": { 36 | "type": "string", 37 | "pattern": "^\\d{2}:\\d{2}$" 38 | }, 39 | "cost": { 40 | "type": "number" 41 | }, 42 | "price": { 43 | "type": "number" 44 | }, 45 | "tickets": { 46 | "type": "integer" 47 | }, 48 | "position": { 49 | "type": "string" 50 | }, 51 | "transports": { 52 | "type": "array", 53 | "items": { 54 | "type": "object", 55 | "properties": { 56 | "start": { 57 | "type": "string" 58 | }, 59 | "end": { 60 | "type": "string" 61 | }, 62 | "mode": { 63 | "type": "string", 64 | "enum": [ 65 | "walk", 66 | "metro", 67 | "taxi" 68 | ] 69 | }, 70 | "start_time": { 71 | "type": "string", 72 | "pattern": "^\\d{2}:\\d{2}$" 73 | }, 74 | "end_time": { 75 | "type": "string", 76 | "pattern": "^\\d{2}:\\d{2}$" 77 | }, 78 | "cost": { 79 | "type": "number" 80 | }, 81 | "price": { 82 | "type": "number" 83 | }, 84 | "distance": { 85 | "type": "number" 86 | }, 87 | "tickets": { 88 | "type": "integer" 89 | } 90 | }, 91 | "required": [ 92 | "start", 93 | "end", 94 | "mode", 95 | "start_time", 96 | "end_time", 97 | "price", 98 | "cost", 99 | "distance" 100 | ] 101 | } 102 | }, 103 | "room_type": { 104 | "type": "integer" 105 | }, 106 | "rooms": { 107 | "type": "integer" 108 | }, 109 | "FlightID": { 110 | "type": "string" 111 | }, 112 | "TrainID": { 113 | "type": "string" 114 | } 115 | }, 116 | "required": [ 117 | "type", 118 | "start_time", 119 | "end_time", 120 | "cost", 121 | "price", 122 | "transports" 123 | ], 124 | "allOf": [ 125 | { 126 | "if": { 127 | "properties": { 128 | "type": { 129 | "enum": ["airplane", "train"] 130 | } 131 | } 132 | }, 133 | "then": { 134 | "required": ["start", "end"] 135 | } 136 | }, 137 | { 138 | "if": { 139 | "properties": { "type": { "const": "airplane" } } 140 | }, 141 | "then": { 142 | "required": ["FlightID"] 143 | } 144 | }, 145 | { 146 | "if": { 147 | "properties": { "type": { "const": "train" } } 148 | }, 149 | "then": { 150 | "required": ["TrainID"] 151 | } 152 | } 153 | ] 154 | } 155 | } 156 | }, 157 | "required": [ 158 | "day", 159 | "activities" 160 | ] 161 | } 162 | } 163 | }, 164 | "required": [ 165 | "people_number", 166 | "start_city", 167 | "target_city", 168 | "itinerary" 169 | ] 170 | } 171 | -------------------------------------------------------------------------------- /chinatravel/data/load_datasets.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import json 4 | import numpy as np 5 | from datasets import load_dataset as hg_load_dataset 6 | import ast 7 | 8 | project_root_path = os.path.dirname( 9 | os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 10 | ) 11 | 12 | if project_root_path not in sys.path: 13 | sys.path.insert(0, project_root_path) 14 | 15 | 16 | class NpEncoder(json.JSONEncoder): 17 | def default(self, obj): 18 | if isinstance(obj, np.integer): 19 | return int(obj) 20 | if isinstance(obj, np.floating): 21 | return float(obj) 22 | if isinstance(obj, np.ndarray): 23 | return obj.tolist() 24 | return super(NpEncoder, self).default(obj) 25 | 26 | 27 | def load_query_local(args, version="", verbose=False): 28 | query_data = {} 29 | 30 | # split_config_file = 'default_splits/{}.txt'.format(args.splits) 31 | 32 | split_config_file = os.path.join( 33 | project_root_path, 34 | "chinatravel", 35 | "evaluation", 36 | "default_splits", 37 | "{}.txt".format(args.splits), 38 | ) 39 | 40 | print("config file for testing split: {}".format(split_config_file)) 41 | 42 | query_id_list = [] 43 | with open(split_config_file, "r") as f: 44 | for line in f.readlines(): 45 | line = line.strip() 46 | query_id_list.append(line) 47 | 48 | if verbose: 49 | print(query_id_list) 50 | 51 | data_dir = os.path.join(project_root_path, "chinatravel", "data") 52 | 53 | dir_list = os.listdir(data_dir) 54 | for dir_i in dir_list: 55 | dir_ii = os.path.join(data_dir, dir_i) 56 | if os.path.isdir(dir_ii): 57 | file_list = os.listdir(dir_ii) 58 | 59 | for file_i in file_list: 60 | query_id = file_i.split(".")[0] 61 | if query_id in query_id_list: 62 | data_i = json.load( 63 | open(os.path.join(dir_ii, file_i), encoding="utf-8") 64 | ) 65 | 66 | if hasattr(args, 'oracle_translation') and not args.oracle_translation: 67 | if "hard_logic" in data_i: 68 | del data_i["hard_logic"] 69 | if "hard_logic_py" in data_i: 70 | del data_i["hard_logic_py"] 71 | if "hard_logic_nl" in data_i: 72 | del data_i["hard_logic_nl"] 73 | 74 | query_data[query_id] = data_i 75 | 76 | # print(query_data) 77 | 78 | if verbose: 79 | for query_id in query_id_list: 80 | print(query_id, query_data[query_id]) 81 | 82 | return query_id_list, query_data 83 | 84 | 85 | def load_json_file(file_path): 86 | with open(file_path, "r", encoding="utf-8") as f: 87 | return json.load(f) 88 | 89 | 90 | def save_json_file(json_data, file_path): 91 | with open(file_path, "w", encoding="utf8") as dump_f: 92 | json.dump(json_data, dump_f, ensure_ascii=False, indent=4, cls=NpEncoder) 93 | 94 | 95 | 96 | def load_query(args): 97 | 98 | if not args.splits in ["easy", "medium", "human", "preference_base50", 99 | "preference0_base50", "preference1_base50", "preference2_base50", 100 | "preference3_base50", "preference4_base50", "preference5_base50"]: 101 | return load_query_local(args) 102 | config_name = "default" 103 | if args.splits in ["preference0_base50", "preference1_base50", "preference2_base50", 104 | "preference3_base50", "preference4_base50", "preference5_base50"]: 105 | config_name = "preference" 106 | # elif args.splits in ["human"]: 107 | # config_name = "validation" 108 | # elif args.splits in ["human1000"]: 109 | # config_name = "test" 110 | query_data = hg_load_dataset("LAMDA-NeSy/ChinaTravel", name=config_name)[args.splits].to_list() 111 | 112 | 113 | for data_i in query_data: 114 | if "hard_logic_py" in data_i: 115 | data_i["hard_logic_py"] = ast.literal_eval(data_i["hard_logic_py"]) 116 | 117 | query_id_list = [data_i["uid"] for data_i in query_data] 118 | data_dict = {} 119 | for data_i in query_data: 120 | if not args.oracle_translation: 121 | if "hard_logic" in data_i: 122 | del data_i["hard_logic"] 123 | if "hard_logic_py" in data_i: 124 | del data_i["hard_logic_py"] 125 | if "hard_logic_nl" in data_i: 126 | del data_i["hard_logic_nl"] 127 | 128 | data_dict[data_i["uid"]] = data_i 129 | 130 | return query_id_list, data_dict 131 | 132 | 133 | import argparse 134 | argparser = argparse.ArgumentParser() 135 | argparser.add_argument("--splits", type=str, default="easy") 136 | 137 | if __name__ == "__main__": 138 | 139 | 140 | # from datasets import load_dataset as hg_load_dataset 141 | 142 | # # Login using e.g. `huggingface-cli login` to access this dataset 143 | # ds = hg_load_dataset("LAMDA-NeSy/ChinaTravel") 144 | # print(ds) 145 | # print(ds["easy"].to_list()) 146 | 147 | # exit(0) 148 | args = argparser.parse_args() 149 | query_id_list, query_data = load_query(args) 150 | # print(query_id_list) 151 | # print(query_data) 152 | 153 | for uid in query_id_list: 154 | if uid in query_data: 155 | print(uid, query_data[uid]) 156 | else: 157 | raise ValueError(f"{uid} not in query_data") 158 | -------------------------------------------------------------------------------- /TPC@AIC2025/readme.md: -------------------------------------------------------------------------------- 1 | # Travel Planning Challange @ AIC 2025 2 | 3 | ## 比赛阶段 4 | 5 | ### 初赛 6 | 初赛提供数据集仅供参赛者在前期进行算法验证与调试,选手提交的算法与方案结果不计入最终总分。 7 | 8 | ### 复赛 9 | 复赛提供新的数据集(格式和初赛保持一致), 选手提交的模型由机器进行自动评分,其得分将计入决赛总分。 10 | 复赛截止前提交算法运行结果和代码,赛事方验证算法运行结果,机器评分结果将在比赛网站上展示。 11 | 12 | ### 决赛 13 | 决赛代码不可更改,参赛选手完善提交算法技术报告,进行现场答辩。 14 | 决赛综合成绩由客观评分和主观评分构成,比例为70%和30%。 15 | - 客观评分:复赛阶段提交代码在决赛私有数据集上进行评测验证的得分; 16 | - 主观评分:依据经过标准化处理后的答辩得分。答辩评价将综合考察参赛者的答辩表现,以及所提交的技术方案和代码文档。 17 | 18 | ## 评估指标 19 | 20 | ### 环境约束 21 | 环境约束评价了输出规划方案中的信息是否与提供的沙盒环境信息一致,度量了规划方案的可行性。 22 | [环境约束说明文档](../chinatravel/symbol_verification/readme.md) 23 | 24 | 25 | $$EPR-micro = \frac{\sum_{p\in P}\sum_{\in Env} 1_{passed(c,p)}}{|P|*|Env|}$$ 26 | 27 | 28 | $$EPR-macro = \frac{\sum_{p\in P}\prod_{\in Env} 1_{passed(c,p)}}{|P|}$$ 29 | 30 | ### 条件逻辑约束 31 | 条件逻辑约束评价了输出规划方案中在满足环境约束的前提下对用户个性化需求的满足程度。 32 | 33 | $$C-LPR = \frac{\sum_{p \in P} 1_{passed(Env,p)}\cdot \sum_{c\in C_p} 1_{passed(c,p)}}{\sum_{p \in P}|P|}$$ 34 | 35 | P是输出的规划方案集合,C_p 是方案 p 对应询问中的约束需求集合,passed(c,p) 表示在p中约束c是否被满足。 36 | 37 | ### 硬约束通过率 38 | 硬约束通过率表达了输出规划方案中满足所有环境约束和逻辑约束的比例。 39 | 40 | $$FPR = \frac{\sum_{p \in P} 1_{passed(Env,p)}\cdot \prod_{c\in C_p} 1_{passed(c,p)}}{\sum_{p \in P}|P|}$$ 41 | 42 | ### 偏好评估 43 | TPC 比赛中,我们提供了三个旅行中常见的偏好指标: 44 | 45 | 每天访问的景点数量尽可能多, Daily Average Attractions Visited, DAV,数值归一化到[0.4]作为分数 46 | 47 | $$DAV\text{-}score = (DAV - 0)/4 $$ 48 | 49 | 50 | 平均交通时间尽可能少, Averaged Transportation Time, ATT,数值归一到[15,120] (分钟)作为分数 51 | 52 | $$ATT\text{-}score = \max(\min((120-ATT)/(120-15),1),0) $$ 53 | 54 | 55 | 每天餐饮推荐数量尽可能多, Daily Dining Recommendations, DDR,数值归一到[0,3] 作为分数 56 | 57 | $$DDR\text{-}score = \min((DDR - 0)/(3-0),1) $$ 58 | 59 | ### 最终得分 60 | 61 | Overall Score = 10% * EPR-micro + 10% * EPR-macro + 25% * C-LPR + 40% * FPR + 5% DAV-Score + 5% ATT-Score + 5% DDR-Score 62 | 63 | 64 | ## 环境配置 65 | 请根据ChinaTravel代码库说明进行环境配置 66 | https://github.com/LAMDASZ-ML/ChinaTravel/tree/main 67 | 68 | 69 | ## 数据配置 70 | 71 | 数据集和数据索引下载: 72 | 73 | 74 | [](https://box.nju.edu.cn/d/be342d7958a44fb2ab35/) 75 | 76 | 请在官网报名注册后获取访问密码 77 | 78 | 79 | 请将数据集解压到`chinatravel/data/`目录下。例如:`chinatravel/data/tpc_aic_phase1/` 80 | 将数据索引放到`chinatravel/evaluation/default_splits`目录下。例如:`chinatravel/evaluation/default_splits/tpc_aic_phase1.txt` 81 | 82 | 83 | ## 🛠️ 算法开发 84 | 85 | ### 1. 智能体算法开发 86 | 87 | 我们在`chinatravel/agent/tpc_agent/` 提供了独立的算法开发目录,你可以把算法需要的内容都放到这里。 88 | 89 | 90 | ### 2. 语言模型训练适配 91 | 92 | 支持本地语言模型在旅行规划的适配,你可以在`chinatravel/agent/tpc_agent/tpc_llm.py` 文件中的TPCLLM实例化你的本地模型推理代码。 93 | 94 | 95 | ```python 96 | class TPCLLM(AbstractLLM): 97 | def __init__(self): 98 | super().__init__() 99 | # Initialization logic 100 | self.name = "TPCLLM" 101 | 102 | def _get_response(self, messages, one_line, json_mode): 103 | # Implement the response logic of the LLM 104 | response = "Your LLM response" 105 | if json_mode: 106 | # Handle JSON mode 107 | pass 108 | elif one_line: 109 | # Handle one - line mode 110 | response = response.split("\n")[0] 111 | return response 112 | ``` 113 | 114 | ### 3. 本地算法运行 115 | 完成智能体算法开发后,你可以使用实验脚本运行你的代码。 116 | 117 | 118 | 任务:全流程方案生成 119 | 测试流程中,用户需要实时理解用户自然语言表达的约束需求,并自动化地给出满足约束需求的旅行方案。 120 | 121 | ```bash 122 | python run_tpc.py --splits tpc_aic_phase1 --agent TPCAgent --llm TPCLLM 123 | ``` 124 | 规划结果会保存在:`results/TPCAgent_TPCLLM` 目录。 125 | 126 | 请注意算法推理时禁止使用--oracle_translation使用DSL标注信息,DSL标注信息仅供用于本地测试评估。 127 | 128 | 129 | ### 4. 本地结果获取 130 | 131 | 本地评估代码在`eval_tpc.py`文件中提供。你可以使用以下命令运行评估代码: 132 | 133 | 全流程方案生成 134 | ```bash 135 | python eval_tpc.py --splits tpc_aic_phase1 --method TPCAgent_TPCLLM 136 | ``` 137 | 138 | ### 5. 代码和结果提交 139 | 140 | 141 | 结果压缩包 XXX_code.zip:请将`chinatravel/results/TPCAgent_TPCLLM/`压缩提交。 142 | 143 | 144 | 代码压缩包 XXX_code.zip:请将`chinatravel/agent/tpc_agent/`压缩提交。 145 | 146 | #### 代码提交细则 147 | **在提交代码前,请在本地验证,确保你的算法能正确载入模型权重、正确运行、在results文件夹中能顺利生成结果plan的json文件。** 148 | - 只能打包 `chinatravel/agent/tpc_agent/` 目录内的部分,你的所有代码、模型都应放在这个目录下。官方进行代码复测时,会将你的算法文件夹直接解压到这个位置,这个文件夹外的部分都与当前 chinatravel 给出的代码一致。所以在验证你代码的可复现性时,请保证外部代码不变。 149 | - 模型问题,官方代码复测仅支持离线模型,如果你需要使用Qwen等开源模型,请将对应模型权重下载放在你的算法目录中,即`chinatravel/agent/tpc_agent/`,并检查可以被你正确调用。你可以在 `chinatravel/agent/tpc_agent/tpc_llm.py` 中指定你的模型权重目录,请确保该位置在你的算法目录,即 `chinatravel/agent/tpc_agent/`,下,例如:`path = os.path.join(project_root_path, "chinatravel", "agent", "tpc_agent", "local_llm", "Qwen3")`。 150 | - 如果你需要使用到与当前环境不一致的python包,请将相应的python包离线下载到你的算法目录(`chinatravel/agent/tpc_agent/`)中,并通过源代码载入的方式进行使用。 151 | - 在线代码复核推理命令:`python run_exp.py --splits tpc_phase_2_online_test --agent TPCAgent --llm TPCLLM` 152 | - 在线代码复核评估命令:`python eval_tpc.py --splits tpc_phase_2_online_test --method TPCAgent_TPCLLM` 153 | - 请保障你的提交的文件被正确命名和组织: 154 | - 压缩包名:队伍参赛编号_code.zip,例如 `AIC-2025-XXXXXXXX.zip`。 155 | - 压缩包中第一层为,命名为 `tpc_agent` 的文件夹 和一个命名为 `contact.txt` 的联系方式。 156 | - `tpc_agent` 文件夹中为运行你算法需要的所有内容。 157 | - `contact.txt` 中包含必要的参赛队伍联系方式,包括队伍参赛编号、团队名称、联系方式(邮箱) 158 | - 提交文件解压后大小必须在40G内 159 | 160 | **再次重申:在提交代码前,请在本地验证,确保你的算法能正确载入模型权重、正确运行、在results文件夹中能顺利生成结果plan的json文件。** 161 | 162 | ### 6. 官方评测 (复赛、决赛) 163 | 164 | - 官方验证将在离线设备上进行,该设备配置为:14核Xeon(R) Gold 6348 CPU,100GB RAM,A800-80GB GPU,50GB SSD,驱动程序:550.54.14,CUDA:12.4。 165 | - 算法需要快速响应用户请求,官方评估期间,每个查询将分配5分钟的推理时间,如果超出时间限制,系统将跳至下一个查询。请合理设计你的算法,或使用计时机制在给定的计算资源内完成规划。 166 | - 我们在 run_tpc.py 中给出了限时机制的实现。 167 | - 评估将以离线方式进行,如果你的算法需要使用大型语言模型(LLM),请使用开源模型,例如Qwen3-8B/4B、Llama 3.1-8B等。避免使用外部API,例如DeepSeek API、GPT API等。 168 | - 我们将重复评估五次,取总分的平均值作为最终结果。如果最终结果与用户提交的结果存在显著差异,我们将联系参与者进行确认。无法重现结果的参赛队伍将被取消资格。 169 | 170 | -------------------------------------------------------------------------------- /chinatravel/evaluation/commonsense_constraint.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | 4 | # from chinatravel.environment.tools.accommodations.apis import Accommodations 5 | # from chinatravel.environment.tools.restaurants.apis import Restaurants 6 | # from chinatravel.environment.tools.attractions.apis import Attractions 7 | # from chinatravel.environment.tools.intercity_transport.apis import IntercityTransport 8 | # from chinatravel.environment.tools.transportation.apis import Transportation 9 | # from env.tools.transportation.apis import GoTo 10 | # from envs import goto 11 | import json 12 | import os 13 | import sys 14 | from tqdm import tqdm 15 | 16 | 17 | import pandas as pd 18 | 19 | # accommodation = Accommodations() 20 | # restaurants = Restaurants() 21 | # attractions = Attractions() 22 | # intercity_transport=IntercityTransport() 23 | # innercity_transport=Transportation() 24 | 25 | from chinatravel.symbol_verification.commonsense_constraint import Is_intercity_transport_correct, Is_attractions_correct, Is_hotels_correct, Is_restaurants_correct, Is_transport_correct, Is_time_correct, Is_space_correct 26 | 27 | ''' 28 | Constraints: 29 | Available 30 | 1. Intercity transport information exsits and is objective: ID, time, startpos and endpos need to be correct. 31 | 2. Attractions 32 | 3. Hotels 33 | 4. Restaurants 34 | 5. transportation 35 | 6. Times 36 | 7. space 37 | ''' 38 | 39 | 40 | 41 | def evaluate_commonsense_constraints(data_index, symbolic_input_dict, plan_json_dict, verbose=False): 42 | # assert len(symbolic_input_list)==len(plan_json_list) 43 | 44 | func_list = [Is_intercity_transport_correct, Is_attractions_correct, Is_hotels_correct, Is_restaurants_correct, Is_transport_correct, Is_time_correct, Is_space_correct] 45 | total_correct = 0 46 | 47 | individual_results = [] 48 | results_per_sample=[] 49 | 50 | 51 | result_agg = pd.DataFrame(columns=['data_id']) 52 | result_agg['data_id'] = data_index 53 | 54 | individual_succ = 0 55 | pass_id = [] 56 | 57 | for ii, idx in tqdm(enumerate(data_index), total=len(data_index)): 58 | # for i,(symbolic_input,plan_json) in enumerate(zip(symbolic_input_list,plan_json_list)): 59 | 60 | 61 | 62 | symbolic_input, plan_json = symbolic_input_dict[idx], plan_json_dict[idx] 63 | 64 | if verbose: 65 | print(symbolic_input) 66 | print(plan_json) 67 | try: 68 | for func in func_list: 69 | 70 | table_res, error_info = func(symbolic_input, plan_json, verbose=verbose) 71 | 72 | if verbose: 73 | print(error_info) 74 | 75 | for colum_i in table_res.columns: 76 | if colum_i not in result_agg.columns: 77 | result_agg[colum_i] = 0 78 | 79 | result_agg.loc[ii, colum_i] = table_res[colum_i].loc[0] 80 | 81 | # print(info) 82 | if result_agg.loc[ii][1:].sum() == 0: 83 | individual_succ += 1 84 | pass_id.append(idx) 85 | except Exception as message: 86 | pass 87 | # print("Error: ", message) 88 | # print(symbolic_input) 89 | # print(plan_json) 90 | 91 | 92 | 93 | total_count=len(data_index) 94 | micro_accuracy = 1. - result_agg.drop("data_id", axis=1).sum().sum() / (total_count * (result_agg.shape[1] - 1)) 95 | 96 | macro_accuracy = individual_succ / total_count 97 | 98 | return macro_accuracy*100, micro_accuracy*100, result_agg, pass_id 99 | 100 | if __name__ == "__main__": 101 | 102 | 103 | 104 | from evaluation.utils import load_json_file 105 | # test_example=load_json_file("./example/query_53.json") 106 | # test_plan=load_json_file("./example/plan_53.json") 107 | # evaluate_commonsense_constraints([test_example], [test_plan]) 108 | 109 | # exit(0) 110 | 111 | symbolic_input_list=[] 112 | plan_json_list=[] 113 | 114 | for i in range(1): 115 | test_plan_path='./example/a_result.json'.format(i+1) 116 | test_example_path='./example/a_query.json'.format(i+1) 117 | test_example=load_json_file(test_example_path) 118 | test_plan=load_json_file(test_plan_path) 119 | symbolic_input_list.append(test_example) 120 | plan_json_list.append(test_plan) 121 | macro_accuracy, micro_accuracy, _ =evaluate_commonsense_constraints(symbolic_input_list,plan_json_list) 122 | print('macro: {}%, micro: {}%'.format(macro_accuracy,micro_accuracy)) 123 | 124 | # test_plan_path='./example/plan_4.json' 125 | # test_example_path='./example/query_4.json' 126 | # test_example=load_json_file(test_example_path) 127 | # test_plan=load_json_file(test_plan_path) 128 | 129 | # print(Is_intercity_transport_correct(test_example,test_plan)) 130 | # print(Is_attractions_correct(test_example,test_plan)) 131 | # print(Is_hotels_correct(test_example,test_plan)) 132 | # print(Is_restaurants_correct(test_example,test_plan)) 133 | # print(Is_transport_correct(test_example,test_plan)) 134 | # print(Is_time_correct(test_example,test_plan)) 135 | # print(Is_space_correct(test_example,test_plan)) 136 | 137 | 138 | # pass_flag = True 139 | 140 | 141 | 142 | # info_list = [] 143 | # for func_i in func_list: 144 | # flag, info = func_i(test_example,test_plan) 145 | 146 | # print(info) 147 | 148 | # pass_flag = pass_flag and flag 149 | # info_list.append(info) 150 | 151 | # print("final result: ", pass_flag) 152 | 153 | # for item in info_list: 154 | # print(item) 155 | # print(info_list) 156 | 157 | -------------------------------------------------------------------------------- /chinatravel/environment/tools/intercity_transport/apis.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | from pandas import DataFrame 4 | 5 | 6 | def time2float(time_str): 7 | h, m = time_str.split(":") 8 | return int(h) + int(m) / 60 9 | 10 | 11 | class IntercityTransport: 12 | def __init__(self, path: str = "../../database/intercity_transport/"): 13 | curdir = os.path.dirname(os.path.realpath(__file__)) 14 | self.base_path = os.path.join(curdir, path) 15 | self.airplane_path = self.base_path + "airplane.jsonl" 16 | self.airplane_df = pd.read_json( 17 | self.airplane_path, lines=True, keep_default_dates=False 18 | ) 19 | city_list = [ 20 | "上海", 21 | "北京", 22 | "深圳", 23 | "广州", 24 | "重庆", 25 | "苏州", 26 | "成都", 27 | "杭州", 28 | "武汉", 29 | "南京", 30 | ] 31 | self.train_df_dict = {} 32 | 33 | for start_city in city_list: 34 | for end_city in city_list: 35 | if start_city == end_city: 36 | continue 37 | train_path = ( 38 | self.base_path 39 | + "train/" 40 | + "from_{}_to_{}.json".format(start_city, end_city) 41 | ) 42 | train_df = pd.read_json(train_path) 43 | self.train_df_dict[(start_city, end_city)] = train_df 44 | 45 | def select( 46 | self, start_city, end_city, intercity_type, earliest_leave_time="00:00" 47 | ) -> DataFrame: 48 | if intercity_type not in ["train", "airplane"]: 49 | return "only support intercity_type in ['train','airplane']" 50 | res = self._select(start_city, end_city, intercity_type) 51 | bool_list = [False] * len(res) 52 | for i in range(len(res)): 53 | if time2float(res.loc[i, "BeginTime"]) >= time2float(earliest_leave_time): 54 | bool_list[i] = True 55 | return res[bool_list] 56 | 57 | def _select(self, start_city, end_city, intercity_type) -> DataFrame: 58 | # intercity_type=='train' | 'airplane' 59 | if intercity_type == "airplane": 60 | 61 | if len(self.airplane_df) == 0: 62 | return None 63 | 64 | filtered_flights = self.airplane_df[ 65 | (self.airplane_df["From"].str.contains(start_city)) 66 | & (self.airplane_df["To"].str.contains(end_city)) 67 | ] 68 | sorted_flights = filtered_flights.sort_values(by="BeginTime").reset_index( 69 | drop=True 70 | ) 71 | return sorted_flights 72 | if intercity_type == "train": 73 | 74 | if len(self.train_df_dict[(start_city, end_city)]) == 0: 75 | return None 76 | 77 | filtered_trains = self.train_df_dict[(start_city, end_city)] 78 | sorted_trains = filtered_trains.sort_values(by="BeginTime").reset_index( 79 | drop=True 80 | ) 81 | return sorted_trains 82 | 83 | 84 | if __name__ == "__main__": 85 | a = IntercityTransport() 86 | city_list = [ 87 | "上海", 88 | "北京", 89 | "深圳", 90 | "广州", 91 | "重庆", 92 | "苏州", 93 | "成都", 94 | "杭州", 95 | "武汉", 96 | "南京", 97 | ] 98 | city_en_list = [ 99 | "shanghai", 100 | "beijing", 101 | "shenzhen", 102 | "guangzhou", 103 | "chongqing", 104 | "suzhou", 105 | "chengdu", 106 | "hangzhou", 107 | "wuhan", 108 | "nanjing", 109 | ] 110 | str_list = [] 111 | for i in range(len(city_list)): 112 | for j in range(i + 1, len(city_list)): 113 | tmp_len = 0 114 | 115 | tmp = a.select(city_list[i], city_list[j], "train") 116 | if not isinstance(tmp, DataFrame): 117 | tmp_len = 0 118 | else: 119 | tmp_len = len(tmp) 120 | if tmp_len > 0: 121 | str_list.append( 122 | "('{}','{}','{}',{})".format( 123 | city_en_list[i], city_en_list[j], "train", tmp_len 124 | ) 125 | ) 126 | 127 | tmp = a.select(city_list[j], city_list[i], "train") 128 | if not isinstance(tmp, DataFrame): 129 | tmp_len = 0 130 | else: 131 | tmp_len = len(tmp) 132 | if tmp_len > 0: 133 | str_list.append( 134 | "('{}','{}','{}',{})".format( 135 | city_en_list[j], city_en_list[i], "train", tmp_len 136 | ) 137 | ) 138 | 139 | tmp = a.select(city_list[i], city_list[j], "flight") 140 | if not isinstance(tmp, DataFrame): 141 | tmp_len = 0 142 | else: 143 | tmp_len = len(tmp) 144 | if tmp_len > 0: 145 | str_list.append( 146 | "('{}','{}','{}',{})".format( 147 | city_en_list[i], city_en_list[j], "airplane", tmp_len 148 | ) 149 | ) 150 | 151 | tmp = a.select(city_list[j], city_list[i], "flight") 152 | if not isinstance(tmp, DataFrame): 153 | tmp_len = 0 154 | else: 155 | tmp_len = len(tmp) 156 | if tmp_len > 0: 157 | str_list.append( 158 | "('{}','{}','{}',{})".format( 159 | city_en_list[j], city_en_list[i], "airplane", tmp_len 160 | ) 161 | ) 162 | 163 | print(",\n".join(str_list)) 164 | -------------------------------------------------------------------------------- /run_exp.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | 5 | import sys 6 | import os 7 | import json 8 | 9 | project_root_path = os.path.dirname(os.path.abspath(__file__)) 10 | if project_root_path not in sys.path: 11 | sys.path.insert(0, project_root_path) 12 | 13 | from copy import deepcopy 14 | 15 | from chinatravel.data.load_datasets import load_query, save_json_file 16 | from chinatravel.agent.load_model import init_agent, init_llm 17 | from chinatravel.environment.world_env import WorldEnv 18 | 19 | 20 | if __name__ == "__main__": 21 | 22 | parser = argparse.ArgumentParser(description="argparse testing") 23 | parser.add_argument( 24 | "--splits", 25 | "-s", 26 | type=str, 27 | default="easy", 28 | help="query subset", 29 | ) 30 | parser.add_argument("--index", "-id", type=str, default=None, help="query index") 31 | parser.add_argument( 32 | "--skip", "-sk", type=int, default=0, help="skip if the plan exists" 33 | ) 34 | parser.add_argument('--restart_from', type=str, default=None, help='Restart Data ID') 35 | parser.add_argument( 36 | "--agent", 37 | "-a", 38 | type=str, 39 | default=None, 40 | choices=["RuleNeSy", "LLMNeSy", "LLM-modulo", "ReAct", "ReAct0", "Act", "TPCAgent"], 41 | ) 42 | parser.add_argument( 43 | "--llm", 44 | "-l", 45 | type=str, 46 | default=None 47 | ) 48 | 49 | parser.add_argument('--oracle_translation', action='store_true', help='Set this flag to enable oracle translation.') 50 | parser.add_argument('--preference_search', action='store_true', help='Set this flag to enable preference search.') 51 | parser.add_argument('--refine_steps', type=int, default=10, help='Steps for refine-based method, such as LLM-modulo, Reflection') 52 | 53 | 54 | args = parser.parse_args() 55 | 56 | print(args) 57 | 58 | query_index, query_data = load_query(args) 59 | print(len(query_index), "samples") 60 | 61 | if args.index is not None: 62 | query_index = [args.index] 63 | 64 | cache_dir = os.path.join(project_root_path, "cache") 65 | 66 | method = args.agent + "_" + args.llm 67 | if args.agent == "LLM-modulo": 68 | method += f"_{args.refine_steps}steps" 69 | 70 | if not args.oracle_translation: 71 | raise Exception("LLM-modulo must use oracle translation") 72 | 73 | if args.oracle_translation: 74 | method = method + "_oracletranslation" 75 | if args.preference_search: 76 | method = method + "_preferencesearch" 77 | 78 | res_dir = os.path.join( 79 | project_root_path, "results", method 80 | ) 81 | log_dir = os.path.join( 82 | project_root_path, "cache", method 83 | ) 84 | if not os.path.exists(res_dir): 85 | os.makedirs(res_dir) 86 | if not os.path.exists(log_dir): 87 | os.makedirs(log_dir) 88 | 89 | print("res_dir: ", res_dir) 90 | print("log_dir:", log_dir) 91 | 92 | if args.agent in ["LLM-modulo"]: 93 | max_model_len = 65536 94 | elif args.agent in ["LLMNeSy"]: 95 | max_model_len = 8192 96 | else: 97 | max_model_len = None 98 | kwargs = { 99 | "method": args.agent, 100 | "env": WorldEnv(), 101 | "backbone_llm": init_llm(args.llm, max_model_len=max_model_len), 102 | "cache_dir": cache_dir, 103 | "log_dir": log_dir, 104 | "debug": True, 105 | "refine_steps": args.refine_steps, 106 | } 107 | agent = init_agent(kwargs) 108 | 109 | 110 | white_list = [] 111 | 112 | succ_count, eval_count = 0, 0 113 | 114 | for i, data_idx in enumerate(query_index): 115 | if (args.restart_from is not None) and (data_idx != args.restart_from): 116 | continue 117 | else: 118 | args.restart_from = None 119 | 120 | sys.stdout = sys.__stdout__ 121 | print("------------------------------") 122 | print( 123 | "Process [{}/{}], Success [{}/{}]:".format( 124 | i, len(query_index), succ_count, eval_count 125 | ) 126 | ) 127 | print("data uid: ", data_idx) 128 | 129 | if args.skip and os.path.exists(os.path.join(res_dir, f"{data_idx}.json")): 130 | continue 131 | if i in white_list: 132 | continue 133 | eval_count += 1 134 | query_i = query_data[data_idx] 135 | print(query_i) 136 | if args.agent in ["ReAct", "ReAct0", "Act"]: 137 | plan_log = agent(query_i["nature_language"]) 138 | plan = plan_log["ans"] 139 | if isinstance(plan, str): 140 | try: 141 | plan = json.loads(plan) 142 | except: 143 | plan = {"plan": plan} 144 | plan["input_token_count"] = agent.backbone_llm.input_token_count 145 | plan["output_token_count"] = agent.backbone_llm.output_token_count 146 | plan["input_token_maxx"] = agent.backbone_llm.input_token_maxx 147 | log = plan_log["log"] 148 | save_json_file( 149 | json_data=log, file_path=os.path.join(log_dir, f"{data_idx}.json") 150 | ) 151 | succ = 1 152 | elif args.agent in ["LLM-modulo"]: 153 | 154 | succ, plan = agent.solve(query_i, prob_idx=data_idx, oracle_verifier=True) 155 | 156 | elif args.agent in ["LLMNeSy", "RuleNeSy"]: 157 | succ, plan = agent.run(query_i, load_cache=True, oralce_translation=args.oracle_translation, preference_search=args.preference_search) 158 | 159 | elif args.agent == "TPCAgent": 160 | succ, plan = agent.run(query_i, prob_idx=data_idx, oralce_translation=args.oracle_translation) 161 | 162 | if succ: 163 | succ_count += 1 164 | 165 | save_json_file( 166 | json_data=plan, file_path=os.path.join(res_dir, f"{data_idx}.json") 167 | ) 168 | -------------------------------------------------------------------------------- /chinatravel/agent/nesy_agent/plan_for_check/day2.json: -------------------------------------------------------------------------------- 1 | { 2 | "people_number": 1, 3 | "start_city": "深圳", 4 | "target_city": "广州", 5 | "itinerary": [ 6 | { 7 | "day": 1, 8 | "activities": [ 9 | { 10 | "start_time": "06:25", 11 | "end_time": "06:58", 12 | "start": "深圳北站", 13 | "end": "广州南站", 14 | "price": 62.52, 15 | "cost": 62.52, 16 | "tickets": 1, 17 | "transports": [], 18 | "TrainID": "D941", 19 | "type": "train" 20 | }, 21 | { 22 | "position": "麓苑轩蒸汽火锅酒家", 23 | "type": "lunch", 24 | "price": 134, 25 | "cost": 134, 26 | "start_time": "11:00", 27 | "end_time": "12:00", 28 | "transports": [ 29 | { 30 | "start": "广州南站", 31 | "end": "广州南站-地铁站", 32 | "mode": "walk", 33 | "start_time": "06:58", 34 | "end_time": "07:00", 35 | "cost": 0, 36 | "distance": 0.19, 37 | "price": 0 38 | }, 39 | { 40 | "start": "广州南站-地铁站", 41 | "end": "淘金-地铁站", 42 | "mode": "metro", 43 | "start_time": "07:00", 44 | "end_time": "07:32", 45 | "cost": 5, 46 | "distance": 16.43, 47 | "price": 5, 48 | "tickets": 1 49 | }, 50 | { 51 | "start": "淘金-地铁站", 52 | "end": "麓苑轩蒸汽火锅酒家", 53 | "mode": "walk", 54 | "start_time": "07:32", 55 | "end_time": "07:38", 56 | "cost": 0, 57 | "distance": 0.53, 58 | "price": 0 59 | } 60 | ] 61 | }, 62 | { 63 | "position": "桔子酒店(广州白云国际机场店)", 64 | "type": "accommodation", 65 | "price": 256, 66 | "cost": 256, 67 | "start_time": "13:45", 68 | "end_time": "24:00", 69 | "transports": [ 70 | { 71 | "start": "麓苑轩蒸汽火锅酒家", 72 | "end": "淘金-地铁站", 73 | "mode": "walk", 74 | "start_time": "12:00", 75 | "end_time": "12:06", 76 | "cost": 0, 77 | "distance": 0.53, 78 | "price": 0 79 | }, 80 | { 81 | "start": "淘金-地铁站", 82 | "end": "机场北(2号航站楼)-地铁站", 83 | "mode": "metro", 84 | "start_time": "12:06", 85 | "end_time": "13:03", 86 | "cost": 7, 87 | "distance": 28.84, 88 | "price": 7, 89 | "tickets": 1 90 | }, 91 | { 92 | "start": "机场北(2号航站楼)-地铁站", 93 | "end": "桔子酒店(广州白云国际机场店)", 94 | "mode": "walk", 95 | "start_time": "13:03", 96 | "end_time": "13:45", 97 | "cost": 0, 98 | "distance": 3.52, 99 | "price": 0 100 | } 101 | ], 102 | "room_type": 1, 103 | "rooms": 1 104 | } 105 | ] 106 | }, 107 | { 108 | "day": 2, 109 | "activities": [ 110 | { 111 | "start_time": "05:03", 112 | "end_time": "06:55", 113 | "start": "广州北站", 114 | "end": "深圳东站", 115 | "price": 41.68, 116 | "cost": 41.68, 117 | "tickets": 1, 118 | "transports": [ 119 | { 120 | "start": "桔子酒店(广州白云国际机场店)", 121 | "end": "机场北(2号航站楼)-地铁站", 122 | "mode": "walk", 123 | "start_time": "00:00", 124 | "end_time": "00:42", 125 | "cost": 0, 126 | "distance": 3.52, 127 | "price": 0 128 | }, 129 | { 130 | "start": "机场北(2号航站楼)-地铁站", 131 | "end": "广州北站-地铁站", 132 | "mode": "metro", 133 | "start_time": "00:42", 134 | "end_time": "01:03", 135 | "cost": 4, 136 | "distance": 10.86, 137 | "price": 4, 138 | "tickets": 1 139 | }, 140 | { 141 | "start": "广州北站-地铁站", 142 | "end": "广州北站", 143 | "mode": "walk", 144 | "start_time": "01:03", 145 | "end_time": "01:06", 146 | "cost": 0, 147 | "distance": 0.25, 148 | "price": 0 149 | } 150 | ], 151 | "TrainID": "K9259", 152 | "type": "train" 153 | } 154 | ] 155 | } 156 | ], 157 | "search_time_sec": 66.41920399665833, 158 | "llm_inference_time_sec": 45.75778794288635 159 | } -------------------------------------------------------------------------------- /eval_exp.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import argparse 4 | 5 | import numpy as np 6 | 7 | import sys 8 | import os 9 | import json 10 | 11 | project_root_path = os.path.dirname(os.path.abspath(__file__)) 12 | if project_root_path not in sys.path: sys.path.insert(0, project_root_path) 13 | 14 | 15 | from chinatravel.data.load_datasets import load_query 16 | from chinatravel.evaluation.utils import load_json_file, validate_json 17 | 18 | from chinatravel.evaluation.schema_constraint import evaluate_schema_constraints 19 | from chinatravel.evaluation.commonsense_constraint import evaluate_commonsense_constraints 20 | from chinatravel.evaluation.hard_constraint import evaluate_hard_constraints, evaluate_hard_constraints_v2 21 | from chinatravel.evaluation.preference import evaluate_preference, evaluate_preference_v2 22 | 23 | 24 | METHOD_LIST = [ 25 | "example" "act_Deepseek_zeroshot", 26 | "act_GPT4o_zeroshot", 27 | "react_Deepseek_zeroshot", 28 | "react_GPT4o_zeroshot", 29 | "react_GLM4Plus_zeroshot", 30 | "react_Deepseek_oneshot", 31 | "react_GPT4o_oneshot", 32 | "naive_ns_Deepseek", 33 | "naive_ns_GPT4o", 34 | "naive_ns_GLM4Plus", 35 | ] 36 | 37 | 38 | def load_result(args, query_index, verbose=False): 39 | 40 | def load_result_for_method(method): 41 | plans = {} 42 | for query_id in query_index: 43 | result_file = os.path.join( 44 | "results/", method, "{}.json".format(query_id) 45 | ) 46 | 47 | try: 48 | if os.path.exists(result_file): 49 | result = load_json_file(result_file) 50 | plans[query_id] = result 51 | else: 52 | plans[query_id] = {} 53 | except: 54 | plans[query_id] = {} 55 | return plans 56 | 57 | result = {} 58 | if args.method == "all": 59 | method_list = [] 60 | for mi in METHOD_LIST: 61 | if mi != "example": 62 | method_list.append(mi) 63 | else: 64 | method_list = [args.method] 65 | 66 | for method in method_list: 67 | result[method] = load_result_for_method(method) 68 | 69 | if verbose: 70 | print(result) 71 | 72 | return method_list, result 73 | 74 | if __name__ == "__main__": 75 | 76 | parser = argparse.ArgumentParser() 77 | parser.add_argument("--splits", "-s", type=str, default="example") 78 | parser.add_argument( 79 | "--method", "-m", type=str, default="example" 80 | ) # , choices=METHOD_LIST) 81 | parser.add_argument("--preference", "-p", action="store_true", default=False) 82 | args = parser.parse_args() 83 | 84 | # print(args.splits) 85 | 86 | query_index, query_data = load_query(args) 87 | method_list, result_data = load_result(args, query_index) 88 | 89 | # print(result_data) 90 | 91 | 92 | 93 | schema_file_path = 'chinatravel/evaluation/output_schema.json' 94 | schema = load_json_file(schema_file_path) 95 | 96 | 97 | if not os.path.exists("eval_res/"): 98 | os.makedirs("eval_res/") 99 | if not os.path.exists("eval_res/splits_{}/".format(args.splits)): 100 | os.makedirs("eval_res/splits_{}/".format(args.splits)) 101 | 102 | 103 | 104 | for method in method_list: 105 | 106 | print("method: ", method) 107 | 108 | plan_count = 0 109 | for plan in result_data[method]: 110 | if plan != {}: 111 | plan_count += 1 112 | print("There are {} results...".format(plan_count)) 113 | 114 | 115 | print("Method: {}".format(method)) 116 | 117 | if not os.path.exists("eval_res/splits_{}/{}/".format(args.splits, method)): 118 | os.makedirs("eval_res/splits_{}/{}/".format(args.splits, method)) 119 | 120 | schema_rate, schema_result_agg, schema_pass_id = evaluate_schema_constraints( 121 | query_index, result_data[method], schema=schema 122 | ) 123 | res_file = "eval_res/splits_{}/{}/schema.csv".format(args.splits, method) 124 | schema_result_agg.to_csv(res_file, index=False) 125 | print("save to {}".format(res_file)) 126 | print("Schema Pass Rate:", schema_rate) 127 | 128 | macro_comm, micro_comm, common_result_agg, commonsense_pass_id = evaluate_commonsense_constraints( 129 | query_index, query_data, result_data[method], verbose=False 130 | ) 131 | 132 | res_file = "eval_res/splits_{}/{}/commonsense.csv".format(args.splits, method) 133 | common_result_agg.to_csv(res_file, index=False) 134 | print("save to {}".format(res_file)) 135 | 136 | print("Commonsense constraints:") 137 | print("micro accuracy: {}".format(micro_comm)) 138 | print("macro accuracy: {}".format(macro_comm)) 139 | 140 | 141 | # print("Logical constraints (flat version):") 142 | # macro_logi, micro_logi, logi_result_agg, logi_pass_id_flat = evaluate_hard_constraints( 143 | # query_index, query_data, result_data[method], verbose=False 144 | # ) 145 | 146 | # print("micro accuracy: {}".format(micro_logi)) 147 | # print("macro accuracy: {}".format(macro_logi)) 148 | 149 | # res_file = "eval_res/splits_{}/{}/logical.csv".format(args.splits, method) 150 | # logi_result_agg.to_csv(res_file, index=False) 151 | # print("save to {}".format(res_file)) 152 | 153 | print("Logical constraints (python version):") 154 | macro_logi, micro_logi, conditional_macro_logi, conditional_micro_logi, logi_result_agg, logi_pass_id = evaluate_hard_constraints_v2( 155 | query_index, query_data, result_data[method], env_pass_id=commonsense_pass_id, verbose=False 156 | ) 157 | 158 | 159 | print("micro accuracy: {}".format(micro_logi)) 160 | print("macro accuracy: {}".format(macro_logi)) 161 | 162 | print("conditional micro accuracy: {}".format(conditional_micro_logi)) 163 | print("conditional macro accuracy: {}".format(conditional_macro_logi)) 164 | 165 | 166 | print("Conditional LPR: {}".format(conditional_micro_logi)) 167 | 168 | res_file = "eval_res/splits_{}/{}/logical_py.csv".format(args.splits, method) 169 | logi_result_agg.to_csv(res_file, index=False) 170 | print("save to {}".format(res_file)) 171 | 172 | # record the index of the queries that pass the logical constraints 173 | logical_pass_info = logi_result_agg.iloc[:, 1:] 174 | id_list = logi_result_agg.iloc[:, 0].tolist() 175 | 176 | all_pass_id = list(set(schema_pass_id) & set(commonsense_pass_id) & set(logi_pass_id)) 177 | 178 | 179 | 180 | print("All pass ratio: ", 1. * len(all_pass_id) / len(query_index) * 100) 181 | 182 | if args.preference: 183 | print("Preference:") 184 | result_agg = evaluate_preference_v2( 185 | query_index, 186 | query_data, 187 | result_data[method], 188 | list(set(commonsense_pass_id) & set(logi_pass_id)), 189 | ) 190 | 191 | res_file = "eval_res/splits_{}/{}/preference.csv".format( 192 | args.splits, method 193 | ) 194 | result_agg.to_csv(res_file, index=False) 195 | print("save to {}".format(res_file)) 196 | -------------------------------------------------------------------------------- /chinatravel/agent/pure_neuro_agent/pure_neuro_agent.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | project_root_path = os.path.dirname( 5 | os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 6 | ) 7 | if project_root_path not in sys.path: 8 | sys.path.append(project_root_path) 9 | if os.path.dirname(project_root_path) not in sys.path: 10 | sys.path.append(os.path.dirname(project_root_path)) 11 | 12 | from agent.base import AbstractAgent, AgentReturnInfo 13 | from agent.pure_neuro_agent.prompts import DIRECT_PROMPT 14 | 15 | 16 | class Notebook: 17 | def __init__(self): 18 | self.note = "" 19 | 20 | def write(self, description: str, content: str): 21 | self.note += description.strip() + "\n" 22 | self.note += content.strip() + "\n" 23 | return "NoteBook updated." 24 | 25 | def read(self): 26 | return self.note 27 | 28 | def reset(self): 29 | self.note = "" 30 | 31 | 32 | class ActAgent(AbstractAgent): 33 | def __init__( 34 | self, 35 | env, 36 | backbone_llm, 37 | prompt, 38 | max_steps=50, 39 | plan_prompt=DIRECT_PROMPT, 40 | debug=True, 41 | ): 42 | super().__init__(env) 43 | 44 | self.backbone_llm = backbone_llm 45 | self.max_steps = max_steps 46 | self.debug = debug 47 | self.prompt = prompt 48 | self.plan_prompt = plan_prompt 49 | 50 | self.json_scratchpad = [] 51 | self.cur_step = 0 52 | self.finished = False 53 | self.notebook = Notebook() 54 | 55 | self.next_page_cnt = 0 56 | self.notedown_cnt = 0 57 | 58 | def reset(self): 59 | self._log = [] 60 | self._ans = "" 61 | self.cur_step = 0 62 | self.finished = False 63 | self.json_scratchpad = [] 64 | self.notebook.reset() 65 | self.next_page_cnt = 0 66 | self.notedown_cnt = 0 67 | self.backbone_llm.input_token_count = 0 68 | self.backbone_llm.output_token_count = 0 69 | self.backbone_llm.input_token_maxx = 0 70 | 71 | def run(self, query): 72 | self.reset() 73 | query = self.prompt + query 74 | self.json_scratchpad.append({"role": "user", "content": query}) 75 | self._log.append({"Query": query}) 76 | 77 | while self.cur_step < self.max_steps and not self.finished: 78 | self.step() 79 | if self.finished: 80 | return AgentReturnInfo( 81 | ans=self._ans, 82 | log=self._log, 83 | ) 84 | else: 85 | return AgentReturnInfo( 86 | ans='{"error": "Exceed maximum steps"}', 87 | log=self._log, 88 | ) 89 | 90 | def step(self): 91 | self.cur_step += 1 92 | self.act() 93 | 94 | def act(self): 95 | # Act 96 | self.json_scratchpad.append( 97 | {"role": "user", "content": f"Action[{self.cur_step}]:"} 98 | ) 99 | action = self.backbone_llm(self.json_scratchpad, one_line=True) 100 | self.json_scratchpad.append({"role": "assistant", "content": action}) 101 | self._log.append({f"Action[{self.cur_step}]": action}) 102 | if self.debug: 103 | print(f"Action {self.cur_step}: {action}") 104 | 105 | # Observe 106 | self.json_scratchpad.append( 107 | {"role": "user", "content": f"Observation[{self.cur_step}]:"} 108 | ) 109 | observation = "" 110 | if action.startswith("Action"): 111 | action = action.split(":", 1)[1].strip() 112 | action_cmd = action.split("(")[0].strip() 113 | if action_cmd == "notedown": 114 | notedown = self.notebook.write 115 | try: 116 | observation = eval(action) 117 | self.notedown_cnt += 1 118 | if self.notedown_cnt >= 3: 119 | self.notedown_cnt = 1 120 | observation = ( 121 | observation 122 | + "\n" 123 | + "Please note down everything in one time. Anyway, it is noted down." 124 | ) 125 | except Exception as e: 126 | observation = "Error to note down: " + str(e) 127 | elif action_cmd == "plan": 128 | plan = self.plan 129 | self.finished = True 130 | try: 131 | observation = eval(action) 132 | self._ans = observation 133 | except Exception as e: 134 | observation = "Error to plan: " + str(e) 135 | else: 136 | observation = str(self.env(action)) 137 | if "next_page" in action: 138 | self.next_page_cnt += 1 139 | if self.next_page_cnt >= 3: 140 | self.next_page_cnt = 1 141 | observation = ( 142 | "Use next_page() too many times. Please ensure your action is reasonable. Only call next_page() when you didn't get the expected results." 143 | + "\n" 144 | + observation 145 | ) 146 | if observation == "No data." and "select" in action: 147 | if "cuisine" in action: 148 | observation = "Maybe you need use restaurants_cuisine(city) to learn the cuisine." 149 | elif "attraction" in action and "type" in action: 150 | observation = "Maybe you need use attractions_types(city) to learn the attraction type." 151 | 152 | self.json_scratchpad.append({"role": "user", "content": observation}) 153 | self._log.append({f"Observation[{self.cur_step}]": observation}) 154 | if self.debug: 155 | print(f"Observation {self.cur_step}: {observation}") 156 | 157 | def plan(self, query): 158 | query = self.plan_prompt + self.notebook.read() + query 159 | query = [{"role": "user", "content": query}] 160 | return self.backbone_llm(query, json_mode=True, one_line=False) 161 | 162 | 163 | class ReActAgent(ActAgent): 164 | def think(self): 165 | self.json_scratchpad.append( 166 | {"role": "user", "content": f"Thought[{self.cur_step}]:"} 167 | ) 168 | thought = self.backbone_llm(self.json_scratchpad, one_line=True) 169 | self.json_scratchpad.append({"role": "assistant", "content": thought}) 170 | self._log.append({f"Thought[{self.cur_step}]": thought}) 171 | if self.debug: 172 | print(f"Thought {self.cur_step}: {thought}") 173 | 174 | def step(self): 175 | self.cur_step += 1 176 | self.think() 177 | self.act() 178 | 179 | 180 | if __name__ == "__main__": 181 | from agent.llms import Deepseek, Qwen 182 | from environment.world_env import WorldEnv 183 | from agent.pure_neuro_agent.prompts import ( 184 | ZEROSHOT_ACT_INSTRUCTION, 185 | ONESHOT_REACT_INSTRUCTION, 186 | ) 187 | 188 | print(os.environ.get("OPENAI_API_KEY")) 189 | 190 | # deepseek = Deepseek() 191 | qwen = Qwen() 192 | env = WorldEnv() 193 | # agent = ActAgent(env, deepseek, ZEROSHOT_ACT_INSTRUCTION, debug=True) 194 | agent = ReActAgent(env, qwen, ONESHOT_REACT_INSTRUCTION, debug=True) 195 | query = "当前位置上海。我一个人想去杭州玩1天,预算3000人民币,请给我一个旅行规划。" 196 | results = agent(query) 197 | -------------------------------------------------------------------------------- /chinatravel/agent/nesy_verifier/prompts/poi_selection.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | HOTEL_SELECTION_INSTRUCTION = """ 4 | You are a travel planning assistant. 5 | Your task is to rank all available hotel options based on the user's needs and the provided hotel information. Consider the following factors: 6 | 1. User preferences (e.g., comfort, cost, location). 7 | 2. Hotel features. 8 | 3. Room price per night. 9 | 4. Number of beds per room (numbed=2 for double beds, numbed=1 for single beds). 10 | 5. Proximity to key attractions or points of interest in the target city. 11 | 6. In order to meet the user's needs, provide diverse options from different perspectives such as geographical location, room characteristics, and price. 12 | 13 | Additionally, keep in mind that the user's budget is allocated across multiple expenses, including intercity transportation and daily meals. Ensure that the hotel recommendations fit within the remaining budget constraints after accounting for these costs. Note that the price provided for each hotel is the cost per night per room. If the user has provided a specific budget requirement, ensure that the total cost of the hotel stay, including intercity transportation and daily meals, does not exceed this budget. Leave sufficient space in the budget for daily meals and other travel expenses. 14 | 15 | Please provide a selected list of {required_options} hotel options based on the user's preferences. For each hotel, include the name. 16 | 17 | ***** Example ***** 18 | The user's requirement are: 当前位置上海。我和女朋友打算去苏州玩两天,预算1300元,希望酒店每晚不超过500元。请给我一个旅行规划。 19 | Selected Hotels: ["苏州东太湖智选假日酒店", "嘉宝酒店(苏州观前街十全街店)", "昆山城东开发区亚朵酒店", "苏州吴江宾馆", "桔子酒店(苏州园区奥体中心店)", "漫际酒店(苏州龙湖东吴天街石湖东路地铁站店)", "苏州泊Hotel", "清能宜尚PLUS酒店(苏州吴江东太湖万宝商业广场店)", "维也纳国际酒店(苏州火车站北广场店)", "锦江之星风尚(苏州园区独墅湖高教区店)"] 20 | 21 | ***** Example Ends ***** 22 | Given information: 23 | {hotel_info} 24 | 25 | The user's requirements are: {user_requirements}. 26 | Selected Hotels (please only output the LIST of HOTEL NAME without explanatory information):""" 27 | 28 | ATTRACTION_SELECTION_INSTRUCTION = """ 29 | You are a travel planning assistant. 30 | Your task is to select and rank attractions based on the user's needs and the provided attraction information. Consider the following factors: 31 | 1. Attraction name 32 | 2. Attraction type 33 | 3. Location 34 | 4. Recommended duration 35 | 36 | Additionally, keep in mind that the user's budget is allocated across multiple expenses, including intercity transportation and hotel accommodations. Ensure that the attraction recommendations fit within the remaining budget constraints after accounting for the past cost. 37 | To ensure a comprehensive list, consider a larger pool of candidates and prioritize diversity in attraction type and location. 38 | Please provide a selected list of {required_options} attraction options based on the user's preferences. For each attraction, include the name. 39 | 40 | ***** Example ***** 41 | The user's requirement are: 当前位置上海。我和女朋友打算去苏州玩两天,预算1300元,希望酒店每晚不超过500元。请给我一个旅行规划。 42 | Selected Attractions: ["拙政园", "寒山寺", "枫桥景区", "狮子林", "虎丘山风景名胜区", "留园", "苏州博物馆", "同里古镇", "山塘街", "平江路历史街区", "金鸡湖", "网师园", "沧浪亭", "木渎古镇", "姑苏水上游", "苏州古运河游船(山塘街白居易码头)", "耦园", "西园寺", "甪直古镇", "天平山", "盘门", "锦溪古镇", "苏州博物馆西馆", "同里国家湿地公园", "观前街", "沙家浜风景区", "苏州古运河", "怡园", "金鸡湖游船", "艺圃", "可园", "苏州太湖国家湿地公园", "诚品书店(诚品生活苏州店)", "苏州湾梦幻水世界", "东太湖生态旅游度假区", "尚湖风景区", "虞山文化旅游度假区", "苏州乐园森林水世界", "上方山森林动物世界", "沙溪古镇", "琵琶语评弹茶馆(苏州平江路店)", "千灯古镇", "双桥", "虞山景区", "太湖", "苏州古运河游船(新市桥码头)", "重元寺", "黎里古镇", "香山景区", "林屋洞"] 43 | 44 | ***** Example Ends ***** 45 | Given information: 46 | {attraction_info} 47 | 48 | The user's requirements are: {user_requirements}. 49 | Selected Attractions (please only output the LIST of ATTRACTION NAME without explanatory information):""" 50 | 51 | RESTAURANT_SELECTION_INSTRUCTION = """ 52 | You are a travel planning assistant. 53 | Your task is to select and rank restaurants based on the user's needs and the provided restaurant information. Consider the following factors: 54 | 1. Restaurant name 55 | 2. Cuisine type 56 | 3. Price range 57 | 4. Recommended food 58 | 59 | Additionally, keep in mind that the user's budget needs to cover various expenses throughout the trip, including intercity transportation, hotel accommodations, attractions, and costs related to visiting restaurants. 60 | As the user has three meals each day, when recommending restaurants, ensure that the total cost of meals, along with other expenses, strictly stays within the budget. Note that the price range provided for each restaurant is the average cost per person per meal. 61 | To ensure a comprehensive list, consider a larger pool of candidates and prioritize diversity in restaurant type and location. 62 | Please provide a selected list of {required_options} restaurant options based on the user's preferences. For each restaurant, include the name. 63 | 64 | ***** Example ***** 65 | The user's requirement are: 当前位置上海。我和女朋友打算去苏州玩两天,预算1300元,希望酒店每晚不超过500元。请给我一个旅行规划。 66 | Selected Restaurants: ["得月楼(观前店)", "哑巴生煎(临顿路店)", "裕面堂·精品苏式面馆(石路店)", "孙盛兴奥灶面馆(山塘街店)", "吴记小园楼(西北街店)", "珍珠饭店", "鹤园苏帮菜(平江路店)", "平江桃花源记(平江1店)", "鑫震源·苏式大虾生煎(山塘街店)", "同得兴(十全街店)", "朱新年点心店", "乐惠馄饨店(吴趋坊店)", "明月楼·糕团店(三元坊店)"] 67 | ***** Example Ends ***** 68 | Given information: 69 | {restaurant_info} 70 | 71 | The user's requirements are: {user_requirements}. 72 | Selected Restaurants (please only output the LIST of RESTAURANT NAME without explanatory information):""" 73 | 74 | 75 | TRANSPORT_GO_SELECTION_INSTRUCTION = """ 76 | You are a travel planning assistant. 77 | Now let's plan the journey from the origin city to the destination city. 78 | Your task is to rank all available intercity transport options based on the user's needs and the provided transport information. Consider the following factors: 79 | 1. User preferences (e.g., type, comfort, cost, speed). 80 | 2. Availability and reliability of the transport options. 81 | 82 | Please provide a selected list of {required_options} transport options based on the user's preferences. 83 | 84 | For each train transport, include the TrainID. 85 | For each flight transport, include the FlightID. 86 | 87 | ***** Example ***** 88 | The user's requirement are: 当前位置北京。我和女朋友打算去上海玩两天,预算5300元,希望酒店每晚不超过500元。请给我一个旅行规划。 89 | Selected Transports from 北京 to 上海 : ["G159", "FL81", "FL083", "FL090", "FL082", "D7", "G101", "G103", "G115", "Z281"] 90 | ***** Example Ends ***** 91 | Given information: 92 | 93 | Train Information 94 | {train_info} 95 | 96 | Flight Information 97 | {flight_info} 98 | 99 | The user's requirement are: {user_requirements} 100 | Selected Transports from {origin} to {destination} : (please only output the LIST of TRANSPORT ID without explanatory information):""" 101 | 102 | 103 | 104 | TRANSPORT_BACK_SELECTION_INSTRUCTION = """ 105 | You are a travel planning assistant. 106 | Now let's plan the journey from the destination city to the origin city. 107 | Your task is to rank all available intercity transport options based on the user's needs and the provided transport information. Consider the following factors: 108 | 1. User preferences (e.g., type, comfort, cost, speed). 109 | 2. Availability and reliability of the transport options. 110 | 3. When users don't specify particular requirements, prioritize transport options with later departure times to allow for more travel time. 111 | 112 | Please provide a selected list of {required_options} transport options based on the user's preferences. 113 | 114 | For each train transport, include the TrainID. 115 | For each flight transport, include the FlightID. 116 | 117 | ***** Example ***** 118 | The user's requirement are: 当前位置北京。我和女朋友打算去上海玩两天,预算5300元,希望酒店每晚不超过500元。请给我一个旅行规划。 119 | Selected Transports from 上海 to 北京 : ["G2", "FL009", "FL007", "FL001", "FL004", "G8", "G18", "G26", "G6", "D6"] 120 | ***** Example Ends ***** 121 | Given information: 122 | 123 | Train Information 124 | {train_info} 125 | 126 | Flight Information 127 | {flight_info} 128 | 129 | The user's requirement are: {user_requirements} 130 | Selected Transports from {origin} to {destination} : (please only output the LIST of TRANSPORT ID without explanatory information):""" 131 | 132 | -------------------------------------------------------------------------------- /chinatravel/evaluation/default_splits/easy.txt: -------------------------------------------------------------------------------- 1 | e20241028160248698752 2 | e20241028160251109186 3 | e20241028160253742452 4 | e20241028160255878579 5 | e20241028160258515151 6 | e20241028160300570125 7 | e20241028160301887545 8 | e20241028160303262666 9 | e20241028160306588034 10 | e20241028160307743716 11 | e20241028160309165566 12 | e20241028160311039073 13 | e20241028160312706595 14 | e20241028160314145526 15 | e20241028160317745984 16 | e20241028160319013029 17 | e20241028160320866976 18 | e20241028160321868761 19 | e20241028160323008998 20 | e20241028160324401047 21 | e20241028160332826857 22 | e20241028160336568120 23 | e20241028160339524446 24 | e20241028160342578788 25 | e20241028160344635426 26 | e20241028160347041172 27 | e20241028160349812056 28 | e20241028160351641015 29 | e20241028160353644992 30 | e20241028160355814749 31 | e20241028160358265981 32 | e20241028160400903418 33 | e20241028160401902010 34 | e20241028160403455246 35 | e20241028160404792482 36 | e20241028160406080180 37 | e20241028160407201457 38 | e20241028160408419219 39 | e20241028160409480928 40 | e20241028160410629558 41 | e20241028160412220794 42 | e20241028160415503132 43 | e20241028160417133031 44 | e20241028160418473172 45 | e20241028160419949225 46 | e20241028160421661493 47 | e20241028160424099199 48 | e20241028160428605482 49 | e20241028160430652306 50 | e20241028160433311881 51 | e20241028160435135251 52 | e20241028160438025375 53 | e20241028160440274536 54 | e20241028160441438118 55 | e20241028160442853706 56 | e20241028160444030851 57 | e20241028160445320665 58 | e20241028160500856634 59 | e20241028160502862487 60 | e20241028160504916405 61 | m20241028164637595271 62 | m20241028164638852401 63 | m20241028164640040111 64 | m20241028164641561951 65 | m20241028164642633824 66 | m20241028164643850084 67 | m20241028164645261270 68 | m20241028164646493412 69 | m20241028164647615645 70 | m20241028164648944630 71 | m20241028164650234586 72 | m20241028164651295246 73 | m20241028164652533992 74 | m20241028164654021442 75 | m20241028164655120918 76 | m20241028164655927567 77 | m20241028164657140437 78 | m20241028164658228656 79 | m20241028164727691681 80 | m20241028164728831122 81 | m20241028164731140489 82 | m20241028164732663326 83 | m20241028164734359829 84 | m20241028164736263410 85 | m20241028164737668149 86 | m20241028164738811222 87 | m20241028164739786600 88 | m20241028164741002971 89 | m20241028164742823033 90 | m20241028164743925079 91 | m20241028164905266860 92 | m20241028164906595512 93 | m20241028164908007773 94 | m20241028164909526065 95 | m20241028164911083678 96 | m20241028164912237931 97 | m20241028164913584976 98 | m20241028164914972867 99 | m20241028164916356794 100 | m20241028164917698287 101 | m20241028164918990517 102 | m20241028164920226613 103 | m20241028164921505540 104 | m20241028164922818260 105 | m20241028164924049935 106 | m20241028164925825092 107 | m20241028164927523772 108 | m20241028164928339245 109 | m20241028164929631601 110 | m20241028164931151225 111 | m20241028164932051734 112 | m20241028164933359426 113 | m20241028164935359794 114 | m20241028164936662445 115 | m20241028164938643608 116 | m20241028164940013894 117 | m20241028164941330266 118 | m20241028164942477698 119 | m20241028164944055132 120 | m20241028164945999336 121 | h20241204223651527057 122 | m20241028164659379868 123 | m20241028164700651701 124 | m20241028164702515882 125 | m20241028164704059677 126 | m20241028164705832439 127 | m20241028164707359057 128 | m20241028164708678530 129 | m20241028164710182329 130 | m20241028164711515367 131 | m20241028164712892867 132 | m20241028164714075557 133 | m20241028164715213722 134 | m20241028164716406990 135 | m20241028164717588104 136 | m20241028164718765036 137 | m20241028164721740054 138 | m20241028164723199132 139 | m20241028164724408368 140 | m20241028164726492363 141 | m20241028164806059008 142 | m20241028164807696509 143 | m20241028164809233884 144 | m20241028164810563611 145 | m20241028164812917723 146 | m20241028164814314214 147 | m20241028164815894420 148 | m20241028164819146853 149 | m20241028164820407247 150 | m20241028164821619599 151 | m20241028164947135130 152 | m20241028164948918926 153 | m20241028164949891147 154 | m20241028164951181054 155 | m20241028164953193047 156 | m20241028164954033603 157 | m20241028164954860610 158 | m20241028164956243805 159 | m20241028164957253403 160 | m20241028164958125795 161 | m20241028164958963982 162 | m20241028164959927007 163 | m20241028165000952381 164 | m20241028165003809657 165 | m20241028165005397280 166 | m20241028165006698484 167 | m20241028165007771022 168 | m20241028165009070086 169 | m20241028165010342686 170 | m20241028165011582732 171 | m20241028165012424316 172 | m20241028165013851142 173 | m20241028165014731116 174 | m20241028165016137718 175 | m20241028165017072372 176 | m20241028165017924681 177 | m20241028165022784704 178 | m20241028165023656799 179 | m20241028165024514344 180 | m20241028165027206780 181 | m20241028164620332052 182 | m20241028164622142080 183 | m20241028164623257189 184 | m20241028164624886634 185 | m20241028164626006100 186 | m20241028164627218521 187 | m20241028164628014228 188 | m20241028164628833631 189 | m20241028164630109683 190 | m20241028164631458440 191 | m20241028164632364032 192 | m20241028164633505983 193 | m20241028164634741822 194 | m20241028164636129853 195 | m20241028164745025235 196 | m20241028164746091821 197 | m20241028164747208761 198 | m20241028164748425604 199 | m20241028164750006178 200 | m20241028164751582756 201 | m20241028164752783728 202 | m20241028164754037833 203 | m20241028164755130582 204 | m20241028164756492413 205 | m20241028164758041148 206 | m20241028164759233559 207 | m20241028164800617250 208 | m20241028164802046283 209 | m20241028164803818799 210 | m20241028164804685984 211 | m20241028164822843097 212 | m20241028164824112842 213 | m20241028164825618732 214 | m20241028164827054321 215 | m20241028164828899127 216 | m20241028164830108978 217 | m20241028164831226675 218 | m20241028164833637716 219 | m20241028164835034738 220 | m20241028164836175101 221 | m20241028164837927103 222 | m20241028164839232455 223 | m20241028164840787934 224 | m20241028164841864509 225 | m20241028164843213370 226 | m20241028164846117340 227 | m20241028164847353449 228 | m20241028164848577483 229 | m20241028164849883248 230 | m20241028164851214825 231 | m20241028164852256268 232 | m20241028164853608114 233 | m20241028164854688667 234 | m20241028164856128486 235 | m20241028164857685556 236 | m20241028164859295493 237 | m20241028164900386638 238 | m20241028164901591817 239 | m20241028164903025025 240 | m20241028164904164475 241 | t20241206164116321945 242 | t20241206171736098796 243 | t20241206171751946181 244 | t20241206171807314229 245 | t20241206171819321990 246 | t20241206171834406054 247 | t20241206171855902412 248 | t20241206171912934521 249 | t20241206171925565020 250 | t20241206171940092767 251 | t20241206171952305386 252 | t20241206172005248006 253 | t20241206172016879934 254 | t20241206172028371509 255 | t20241206172139971351 256 | t20241206172253793024 257 | t20241206172345246461 258 | t20241206172420699802 259 | t20241206172609101147 260 | t20241206172708164570 261 | t20241206211714064387 262 | t20241206211714064535 263 | t20241206211714064628 264 | t20241206211714064711 265 | t20241206211714064780 266 | t20241206211714064851 267 | t20241206211714064921 268 | t20241206211714064986 269 | t20241206211714065049 270 | t20241206211714065114 271 | t20241206211714065174 272 | t20241206211714065238 273 | t20241206211714065305 274 | t20241206211714065368 275 | t20241206211714065430 276 | t20241206211714065493 277 | t20241206211714065556 278 | t20241206211714065615 279 | t20241206211714065676 280 | t20241206211714065799 281 | t20241206211714065864 282 | t20241206211714065925 283 | t20241206211714065984 284 | t20241206211714066044 285 | t20241206211714066103 286 | t20241206211714066164 287 | t20241206211714066235 288 | t20241206211714066294 289 | t20241206211714066351 290 | t20241206211714066408 291 | t20241206211714066466 292 | t20241206211714066523 293 | t20241206211714066581 294 | t20241206211714066637 295 | t20241206211714066694 296 | t20241206211714066755 297 | t20241206211714066811 298 | t20241206211714066866 299 | t20241206211714066934 300 | t20241206211714066996 301 | -------------------------------------------------------------------------------- /chinatravel/symbol_verification/readme.md: -------------------------------------------------------------------------------- 1 | # Evaluation Metrics 2 | 3 | ## Commonsense Constraints 4 | 5 | | | Environment Constraints | Semantics | 6 | | ------------------------- | ----------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | 7 | | Cross-city Transportation | Intercity transportation events must occur. | The first event and last event must be cross-city transports. | 8 | | | Available Trains or Airplanes across cities. | The provied TrainID/FlightID, origin and destination shoud be valid in the travel sandbox. | 9 | | | Correct information of price, duration. | The price and duration information should match the travel sandbox. | 10 | | | Detailed cost on inter-city transportation | The travel plan should provide the number of tickets and cost of each inter-city-transportation activity.$cost=price*tickets$ | 11 | | Inner-city Transportation | Available Metro, Taxi or Walking between different positions. | The provided routes should be valid in the travel sandbox. | 12 | | | Correct information of price, distance, and duration. | The price, distance and duration information should match the travel sandbox. | 13 | | | Detailed cost on inner-city transportation | The travel plan should provide the number of tickets/cars and cost of each inter-city-transportation activity. The taxi capacity is 4 people per car.$cost=price*tickets$, $cost=price*cars$ | 14 | | Attractions | Available attractions in the target city | The provided attractions should be valid in the travel sandbox. | 15 | | | Visiting the attractions in their openting time. | Visiting the attractions in their openting time. | 16 | | | Correct information of price. | The price information should match the travel sandbox. | 17 | | | Detailed cost of attraction activity. | The travel plan should provide the number of tickets and cost of each attraction activity.$cost=price*tickets$ | 18 | | | Attraction choices should not be repeated throughout the trip. | Attraction choices should not be repeated throughout the trip. | 19 | | Restaurants | Available restaurants in the target city | The provided restaurants should be valid in the travel sandbox. | 20 | | | Visiting the restaurants in their openting time. | Visiting the restaurants in their openting time. | 21 | | | Correct information of price. | The price information should match the travel sandbox. | 22 | | | Detailed cost of restaurant activity. | The travel plan should provide the number of tickets and cost of each restaurant activity.$cost=price*tickets$ | 23 | | | Restaurant choices should not be repeated throughout the trip. | Restaurant choices should not be repeated throughout the trip. | 24 | | | Breakfast, lunch, and dinner are served at their designated meal times. | Breakfast starts no later than 9:00 and ends no earlier than 06:00. Lunch starts no later than 14:00 and ends no earlier than 11:00. Dinner starts no later than 17:00 and ends no earlier than 20:00. | 25 | | Accommodation | Available accommodation in the target city. | The provided accommodations should be valid in the travel sandbox. | 26 | | | Correct room information of price and type. | The price and room type information should match the travel sandbox. | 27 | | | Detailed cost of accommodation activity. | The travel plan should provide the number of rooms and cost of each accommodation activity.$cost=price*rooms$ | 28 | | | Accomondation is necessary for trips longer than one day | We need a hotel for a trip more than one day. | 29 | | Time | Detailed duration information of each activity. | The travel plan should provide the starting time and ending time of each activity. The ending time must be later than the starting time. | 30 | | | The given activity events occur in chronological order. | Travel Plan lists activities in chronological order. The starting time of each activity should be no earlier than the ending time of the transportation that reaches its position. | 31 | | Space | Events at different positions should provide transport information. | If the current position is different from the previous event position, the transportation route from the previous event position to the current position must be given. | 32 | 33 | ## Personal Constraints 34 | -------------------------------------------------------------------------------- /eval_tpc.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import argparse 4 | 5 | import numpy as np 6 | 7 | import sys 8 | import os 9 | import json 10 | 11 | project_root_path = os.path.dirname(os.path.abspath(__file__)) 12 | if project_root_path not in sys.path: sys.path.insert(0, project_root_path) 13 | 14 | 15 | from chinatravel.data.load_datasets import load_query 16 | from chinatravel.evaluation.utils import load_json_file, validate_json 17 | 18 | from chinatravel.evaluation.schema_constraint import evaluate_schema_constraints 19 | from chinatravel.evaluation.commonsense_constraint import evaluate_commonsense_constraints 20 | from chinatravel.evaluation.hard_constraint import evaluate_hard_constraints, evaluate_hard_constraints_v2 21 | from chinatravel.evaluation.preference import evaluate_preference, evaluate_preference_v2 22 | 23 | 24 | 25 | 26 | 27 | DEFAULT_ATTRACTION_PR=""" 28 | attraction_count = 0 29 | for activity in allactivities(plan): 30 | if activity_type(activity) == 'attraction': 31 | attraction_count += 1 32 | result=attraction_count/(4*day_count(plan)) 33 | """ 34 | 35 | 36 | DEFAULT_TRANS_PR=""" 37 | time_cost = 0 38 | transport_count = 0 39 | for activity in allactivities(plan): 40 | transports = activity_transports(activity) 41 | if transports!=[]: 42 | transport_count += 1 43 | time_cost += innercity_transport_time(transports) 44 | average_time_cost = time_cost / transport_count if transport_count > 0 else -1 45 | result= (-1/105) * average_time_cost + 8/7 46 | """ 47 | DEFAULT_RES_PR=""" 48 | res_count=0 49 | for activity in allactivities(plan): 50 | if activity_type(activity) in ['breakfast', 'lunch', 'dinner']: 51 | res_count+=1 52 | res_count=res_count/(day_count(plan)) 53 | result=res_count/3 54 | """ 55 | DEFAULT_PR=[ 56 | DEFAULT_ATTRACTION_PR, 57 | DEFAULT_TRANS_PR, 58 | DEFAULT_RES_PR 59 | ] 60 | 61 | METHOD_LIST = [ 62 | ] 63 | 64 | from tqdm import tqdm 65 | from chinatravel.symbol_verification.concept_func import func_dict 66 | from copy import deepcopy 67 | 68 | def cal_default_pr_score(query_index, query_data, result_data,all_pass_id): 69 | all_score=[] 70 | def clamp(value): 71 | return max(0.0, min(1.0, value)) 72 | 73 | for ii, idx in enumerate(tqdm(query_index)): 74 | symbolic_input, plan = query_data[idx], result_data[idx] 75 | results = [] 76 | if idx not in all_pass_id: 77 | results=np.zeros(len(DEFAULT_PR)) 78 | continue 79 | for constraint in DEFAULT_PR: 80 | vars_dict = deepcopy(func_dict) 81 | vars_dict["plan"] = plan 82 | 83 | # exec(constraint, {"__builtins__": {"set": set, "print": print}}, vars_dict) 84 | # results.append(vars_dict.get("result", False)) 85 | try: 86 | # Evaluate the constraint in a safe manner 87 | exec( 88 | constraint, 89 | { 90 | "__builtins__": { 91 | "set": set, 92 | } 93 | }, 94 | vars_dict, 95 | ) 96 | res_i = vars_dict.get("result", False) 97 | # print("result: ", res_i) 98 | # print(type(res_i)) 99 | results.append(clamp(res_i)) 100 | except Exception as e: 101 | results.append(0.) 102 | all_score.append(np.array(results)) 103 | if len(all_score)==0: 104 | return np.zeros(len(DEFAULT_PR)) 105 | print(np.mean(all_score,axis=0)) 106 | return np.mean(all_score,axis=0) 107 | 108 | 109 | 110 | def load_result(args, query_index,path, verbose=False): 111 | 112 | def load_result_for_method(path): 113 | plans = {} 114 | for query_id in query_index: 115 | result_file = os.path.join( 116 | path, "{}.json".format(query_id) 117 | ) 118 | 119 | try: 120 | if os.path.exists(result_file): 121 | result = load_json_file(result_file) 122 | plans[query_id] = result 123 | else: 124 | plans[query_id] = {} 125 | except: 126 | plans[query_id] = {} 127 | return plans 128 | 129 | result = {} 130 | 131 | result['default'] = load_result_for_method(path) 132 | 133 | if verbose: 134 | print(result) 135 | 136 | return ['default'], result 137 | 138 | def write_file(file, content): 139 | """ Write content in file. 140 | """ 141 | with open(file, 'a', encoding="utf-8") as f: 142 | f.write(content) 143 | 144 | 145 | if __name__ == "__main__": 146 | 147 | parser = argparse.ArgumentParser() 148 | parser.add_argument("--splits", "-s", type=str, default="example") 149 | parser.add_argument( 150 | "--method", "-m", type=str, default="travel_agent" 151 | ) # , choices=METHOD_LIST) 152 | parser.add_argument("--preference", "-p", action="store_true", default=False) 153 | args = parser.parse_args() 154 | 155 | # print(args.splits) 156 | 157 | 158 | query_index, query_data = load_query(args) 159 | 160 | results_dir = os.path.join("results", args.method) 161 | method_list, result_data = load_result(args, query_index, results_dir) 162 | 163 | schema_file_path = 'chinatravel/evaluation/output_schema.json' 164 | schema = load_json_file(schema_file_path) 165 | 166 | 167 | scores = {} 168 | for method in method_list: 169 | 170 | 171 | 172 | 173 | print("Method: {}".format(args.method)) 174 | 175 | if not os.path.exists("eval_res/splits_{}/{}/".format(args.splits, method)): 176 | os.makedirs("eval_res/splits_{}/{}/".format(args.splits, method)) 177 | 178 | schema_rate, schema_result_agg, schema_pass_id = evaluate_schema_constraints( 179 | query_index, result_data[method], schema=schema 180 | ) 181 | # print("Schema Pass Rate:", schema_rate) 182 | 183 | macro_comm, micro_comm, common_result_agg, commonsense_pass_id = evaluate_commonsense_constraints( 184 | query_index, query_data, result_data[method], verbose=False 185 | ) 186 | 187 | # print("Commonsense constraints:") 188 | print("Mic.EPR {}".format(micro_comm)) 189 | scores['MicEPR'] = micro_comm 190 | print("Mac.EPR: {}".format(macro_comm)) 191 | scores['MacEPR'] = macro_comm 192 | 193 | # print("Logical constraints (python version):") 194 | macro_logi, micro_logi, conditional_macro_logi, conditional_micro_logi, logi_result_agg, logi_pass_id = evaluate_hard_constraints_v2( 195 | query_index, query_data, result_data[method], env_pass_id=commonsense_pass_id, verbose=False 196 | ) 197 | 198 | 199 | 200 | 201 | print("C-LPR: {}".format(conditional_micro_logi)) 202 | scores['C-LPR'] = conditional_micro_logi 203 | 204 | # record the index of the queries that pass the logical constraints 205 | logical_pass_info = logi_result_agg.iloc[:, 1:] 206 | id_list = logi_result_agg.iloc[:, 0].tolist() 207 | 208 | all_pass_id = list(set(schema_pass_id) & set(commonsense_pass_id) & set(logi_pass_id)) 209 | 210 | 211 | print("FPR: ", 1. * len(all_pass_id) / len(query_index) * 100) 212 | fpr= 1. * len(all_pass_id) / len(query_index) * 100 213 | scores['FPR'] = fpr 214 | 215 | pre_res=cal_default_pr_score(query_index,query_data,result_data[method],all_pass_id) 216 | scores['DAV']=pre_res[0]*100 217 | scores['ATT']=pre_res[1]*100 218 | scores['DDR']=pre_res[2]*100 219 | 220 | final_score=0.1*micro_comm+0.1*micro_comm+0.25*conditional_micro_logi+0.05*scores['DAV']+0.05*scores['ATT']+0.05*scores['DDR']+0.4*fpr 221 | print('Overall Score: ',final_score) 222 | scores['overall'] = final_score 223 | print(scores) 224 | 225 | score_file = os.path.join('your_tpc_scores.json') 226 | write_file(score_file, json.dumps(scores)) 227 | if args.preference: 228 | print("Preference:") 229 | result_agg = evaluate_preference_v2( 230 | query_index, 231 | query_data, 232 | result_data[method], 233 | list(set(commonsense_pass_id) & set(logi_pass_id)), 234 | ) 235 | 236 | -------------------------------------------------------------------------------- /chinatravel/evaluation/hard_constraint.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | import sys 5 | import os 6 | 7 | # from chinatravel.environment.tools.accommodations.apis import Accommodations 8 | # from chinatravel.environment.tools.restaurants.apis import Restaurants 9 | # from chinatravel.environment.tools.attractions.apis import Attractions 10 | # from chinatravel.environment.tools.intercity_transport.apis import IntercityTransport 11 | # from chinatravel.environment.tools.transportation.apis import Transportation 12 | 13 | from chinatravel.evaluation.utils import load_json_file 14 | 15 | from chinatravel.symbol_verification.hard_constraint import get_symbolic_concepts, evaluate_constraints, evaluate_constraints_py 16 | 17 | from tqdm import tqdm 18 | 19 | import pandas as pd 20 | 21 | # accommodation = Accommodations() 22 | # restaurants = Restaurants() 23 | # attractions = Attractions() 24 | 25 | 26 | 27 | 28 | def evaluate_hard_constraints(data_index, symbolic_input_dict, plan_json_dict, verbose=False): 29 | 30 | result_agg = pd.DataFrame(columns=['data_id', 31 | 'Trip_Days', 'Trip_People', 32 | 'Required_InterCity_Transport_Type', 33 | 'Required_Transport_Type', 34 | 'Required_Attraction_Type', 'Required_Attraction', 35 | 'Required_Hotel_Type', 'Required_Hotel', 'Required_Room_Type', 'Required_Room_Count', 36 | 'Required_Restruant_Type', 'Required_Restaurants', 37 | 'Budget']) 38 | result_agg['data_id'] = data_index 39 | for col_i in result_agg.columns[1:]: 40 | result_agg[col_i] = 0 41 | 42 | macro_count, macro_succ_count = 0, 0 43 | micro_count, micro_succ_count = 0, 0 44 | 45 | 46 | results=[] 47 | passed_id = [] 48 | 49 | for ii, idx in enumerate(data_index): 50 | symbolic_input, plan_json = symbolic_input_dict[idx], plan_json_dict[idx] 51 | 52 | extracted_vars=get_symbolic_concepts(symbolic_input, plan_json, need_ood=False) 53 | 54 | 55 | 56 | result_ii = evaluate_constraints(extracted_vars, symbolic_input["hard_logic"]) 57 | 58 | 59 | if verbose: 60 | print("symoblic concepts: ", extracted_vars) 61 | print(symbolic_input["hard_logic"]) 62 | print(result_ii) 63 | results.append(result_ii) 64 | 65 | dict_ii = {} 66 | 67 | for j, logical_i in enumerate(symbolic_input["hard_logic"]): 68 | 69 | if "days" in logical_i: 70 | col_name = "Trip_Days" 71 | elif "people_number" in logical_i: 72 | col_name = "People_Number" 73 | elif "tickets" in logical_i: 74 | col_name = "People_Number" 75 | elif "intercity_transport" in logical_i: 76 | col_name = "Required_InterCity_Transport_Type" 77 | elif "transport_type" in logical_i: 78 | col_name = "Transport_Type" 79 | elif "spot_type" in logical_i: 80 | col_name = "Required_Attraction_Type" 81 | elif "attraction_names" in logical_i: 82 | col_name = "Required_Attraction" 83 | elif "hotel_feature" in logical_i: 84 | col_name = "Required_Hotel_Type" 85 | elif "hotel_names" in logical_i: 86 | col_name = "Required_Hotel" 87 | elif "room_type" in logical_i: 88 | col_name = "Required_Room_Type" 89 | elif "rooms" in logical_i: 90 | col_name = "Required_Room_Count" 91 | elif "food_type" in logical_i: 92 | col_name = "Required_Restruant_Type" 93 | elif "restaurant_names" in logical_i: 94 | col_name = "Required_Restaurants" 95 | 96 | elif "cost" in logical_i: 97 | col_name = "Budget" 98 | elif "price" in logical_i: 99 | col_name = "Budget" 100 | 101 | # result_agg[col_name, ii] = result_ii[j] 102 | if not col_name in dict_ii: 103 | dict_ii[col_name] = int(result_ii[j]) 104 | else: 105 | if result_ii[j] == 0: 106 | dict_ii[col_name] = 0 107 | 108 | result_agg.loc[ii] = pd.Series(dict_ii) 109 | 110 | 111 | succ_c_sum = 0 112 | for col in dict_ii.keys(): 113 | # print(col, dict_ii[col]) 114 | succ_c_sum += dict_ii[col] 115 | 116 | # print(dict_ii) 117 | # print(succ_c_sum, len(dict_ii)) 118 | 119 | macro_count += 1 120 | macro_succ_count += (succ_c_sum == len(dict_ii)) 121 | micro_count += len(dict_ii) 122 | micro_succ_count += succ_c_sum 123 | 124 | if succ_c_sum == len(dict_ii): 125 | passed_id.append(idx) 126 | 127 | # macro, micro, _ = calculate_metrics(results) 128 | 129 | macro = macro_succ_count / macro_count 130 | micro = micro_succ_count / micro_count 131 | 132 | return macro*100, micro*100, result_agg, passed_id 133 | 134 | 135 | def evaluate_hard_constraints_v2(data_index, symbolic_input_dict, plan_json_dict, env_pass_id, verbose=False): 136 | 137 | 138 | max_logic_num = 0 139 | for idx in data_index: 140 | max_logic_num = max(max_logic_num, len(symbolic_input_dict[idx]["hard_logic_py"])) 141 | 142 | columns=['data_id'] 143 | for i in range(max_logic_num): 144 | columns.append(f'logic_py_{i}') 145 | result_agg = pd.DataFrame(columns=columns) 146 | # result_agg['data_id'] = data_index 147 | for col_i in result_agg.columns[1:]: 148 | result_agg[col_i] = 0 149 | 150 | macro_count, macro_succ_count = 0, 0 151 | micro_count, micro_succ_count = 0, 0 152 | 153 | conditional_micro_succ_count, conditional_macro_succ_count = 0, 0 154 | 155 | results=[] 156 | passed_id = [] 157 | 158 | for ii, idx in enumerate(tqdm(data_index)): 159 | symbolic_input, plan_json = symbolic_input_dict[idx], plan_json_dict[idx] 160 | result_ii = evaluate_constraints_py(symbolic_input["hard_logic_py"], plan_json, verbose=verbose) 161 | results.append(result_ii) 162 | 163 | # print(symbolic_input) 164 | # print(plan_json) 165 | 166 | if verbose: 167 | for logic_i, res_i in zip(symbolic_input["hard_logic_py"], result_ii): 168 | print(logic_i, "\n", "[", res_i, "]") 169 | 170 | dict_ii = {} 171 | succ_c_sum = 0 172 | for logic_i in range(len(symbolic_input["hard_logic_py"])): 173 | dict_ii[f'logic_py_{logic_i}'] = int(result_ii[logic_i]) 174 | succ_c_sum += int(result_ii[logic_i]) 175 | 176 | 177 | macro_count += 1 178 | macro_succ_count += (succ_c_sum == len(dict_ii)) 179 | micro_count += len(dict_ii) 180 | micro_succ_count += succ_c_sum 181 | 182 | if idx in env_pass_id: 183 | conditional_micro_succ_count += succ_c_sum 184 | conditional_macro_succ_count += (succ_c_sum == len(dict_ii)) 185 | 186 | if succ_c_sum == len(dict_ii): 187 | passed_id.append(idx) 188 | 189 | dict_ii["data_id"] = idx 190 | result_agg.loc[ii] = pd.Series(dict_ii) 191 | 192 | # else: 193 | # print(symbolic_input["hard_logic_py"]) 194 | # print(result_ii) 195 | # print(plan_json) 196 | 197 | macro = macro_succ_count / macro_count 198 | micro = micro_succ_count / micro_count 199 | 200 | c_marco = conditional_macro_succ_count / macro_count 201 | c_micro = conditional_micro_succ_count / micro_count 202 | 203 | print("conditional_micro_succ_count: ", conditional_micro_succ_count) 204 | print("conditional_macro_succ_count: ", conditional_macro_succ_count) 205 | 206 | return macro*100, micro*100, c_marco*100, c_micro*100, result_agg, passed_id 207 | 208 | if __name__ == "__main__": 209 | 210 | symbolic_input_list=[] 211 | plan_json_list=[] 212 | 213 | for i in range(1): 214 | test_plan_path='./example/plan_{}.json'.format(i+1) 215 | test_example_path='./example/query_{}.json'.format(i+1) 216 | test_example=load_json_file(test_example_path) 217 | test_plan=load_json_file(test_plan_path) 218 | symbolic_input_list.append(test_example) 219 | plan_json_list.append(test_plan) 220 | macro_accuracy, micro_accuracy,_=evaluate_hard_constraints(symbolic_input_list,plan_json_list) 221 | print('macro: {}%, micro: {}%'.format(macro_accuracy,micro_accuracy)) -------------------------------------------------------------------------------- /chinatravel/symbol_verification/concept_func.py: -------------------------------------------------------------------------------- 1 | from chinatravel.environment.tools.accommodations.apis import Accommodations 2 | from chinatravel.environment.tools.restaurants.apis import Restaurants 3 | from chinatravel.environment.tools.attractions.apis import Attractions 4 | 5 | 6 | def day_count(plan): 7 | return len(plan["itinerary"]) 8 | 9 | 10 | def people_count(plan): 11 | return plan["people_number"] 12 | 13 | def start_city(plan): 14 | return plan["start_city"] 15 | 16 | def target_city(plan): 17 | return plan["target_city"] 18 | 19 | 20 | def allactivities(plan): 21 | activity_list = [] 22 | for day_activity in plan["itinerary"]: 23 | for act in day_activity["activities"]: 24 | activity_list.append(act) 25 | return activity_list 26 | 27 | 28 | def allactivities_count(plan): 29 | count = 0 30 | for day_activity in plan["itinerary"]: 31 | count += len(day_activity["activities"]) 32 | return count 33 | 34 | 35 | def dayactivities(plan, day): 36 | activity_list = [] 37 | for act in plan["itinerary"][day - 1]["activities"]: 38 | activity_list.append(act) 39 | return activity_list 40 | 41 | 42 | def activity_position(activity): 43 | return activity.get("position", "") 44 | 45 | 46 | def activity_cost(activity): 47 | return activity.get("cost", 0) 48 | 49 | 50 | def activity_price(activity): 51 | return activity.get("price", 0) 52 | 53 | 54 | def activity_type(activity): 55 | return activity.get("type", "") 56 | 57 | 58 | def activity_tickets(activity): 59 | return activity.get("tickets", 0) 60 | 61 | 62 | def activity_transports(activity): 63 | return activity.get("transports", []) 64 | 65 | 66 | def activity_start_time(activity): 67 | return activity.get("start_time") 68 | 69 | 70 | def activity_end_time(activity): 71 | return activity.get("end_time") 72 | 73 | 74 | def activity_time(activity): 75 | 76 | start_time = activity.get("start_time") 77 | end_time = activity.get("end_time") 78 | 79 | if start_time and end_time: 80 | st_h, st_m = map(int, start_time.split(":")) 81 | ed_h, ed_m = map(int, end_time.split(":")) 82 | return (ed_m - st_m) + (ed_h - st_h) * 60 83 | else: 84 | return -1 85 | 86 | 87 | def poi_recommend_time(city, poi): 88 | select = Attractions().select 89 | attrction_info = select(city, key="name", func=lambda x: x == poi).iloc[0] 90 | recommend_time = (attrction_info["recommendmintime"]) * 60 91 | return recommend_time 92 | 93 | 94 | def poi_distance(city, poi1, poi2, start_time="00:00", transport_type="walk"): 95 | from chinatravel.environment.tools.transportation.apis import Transportation 96 | 97 | goto = Transportation().goto 98 | return goto(city, poi1, poi2, start_time, transport_type)[0]["distance"] 99 | 100 | 101 | def innercity_transport_cost(transports, node=None): 102 | """ 103 | 计算市内交通费用 104 | Args: 105 | transports: 交通信息列表 106 | node: 可选的交通类型筛选 ('walk', 'metro', 'taxi' 或 None) 107 | Returns: 108 | float: 选定类型的交通费用总和,如果 node 为 None 则返回所有类型的总和 109 | """ 110 | cost = 0 111 | for transport in transports: 112 | if node is None or transport.get("type") == node: 113 | cost += transport.get("cost", 0) 114 | return cost 115 | 116 | 117 | def innercity_transport_price(transports): 118 | price = 0 119 | for transport in transports: 120 | price += transport["price"] 121 | return price 122 | 123 | 124 | def innercity_transport_distance(transports, mode=None): 125 | """ 126 | 计算市内交通距离 127 | Args: 128 | transports: 交通信息列表 129 | mode: 可选的交通类型筛选 ('walk', 'metro', 'taxi' 或 None) 130 | Returns: 131 | float: 选定类型的交通距离总和,如果 mode 为 None 则返回所有类型的总和 132 | """ 133 | distance = 0 134 | for transport in transports: 135 | if mode is None or transport.get("type") == mode: 136 | distance += transport.get("distance", 0) 137 | return distance 138 | 139 | 140 | def innercity_transport_time(transports, mode=None): 141 | 142 | def calc_time_delta(end_time, start_time): 143 | hour1, minu1 = int(end_time.split(":")[0]), int(end_time.split(":")[1]) 144 | hour2, minu2 = int(start_time.split(":")[0]), int(start_time.split(":")[1]) 145 | 146 | return (hour1 - hour2) * 60 + (minu1 - minu2) 147 | 148 | time_cost = 0 149 | for transport in transports: 150 | time_cost += calc_time_delta(transport["end_time"], transport["start_time"]) 151 | return time_cost 152 | 153 | def metro_tickets(transports): 154 | return transports[1]["tickets"] 155 | 156 | 157 | def taxi_cars(transports): 158 | if len(transports) > 0 and "cars" in transports[0]: 159 | return transports[0]["cars"] 160 | else: 161 | return "invalid input" 162 | 163 | 164 | def room_count(activity): 165 | return activity.get("rooms", 0) 166 | 167 | 168 | def room_type(activity): 169 | return activity.get("room_type", 0) 170 | 171 | 172 | def restaurant_type(activity, target_city): 173 | from chinatravel.environment.tools.restaurants.apis import Restaurants 174 | 175 | restaurants = Restaurants() 176 | select_food_type = restaurants.select( 177 | target_city, key="name", func=lambda x: x == activity["position"] 178 | )["cuisine"] 179 | if not select_food_type.empty: 180 | return select_food_type.iloc[0] 181 | return "empty" 182 | 183 | 184 | def attraction_type(activity, target_city): 185 | from chinatravel.environment.tools.attractions.apis import Attractions 186 | 187 | attractions = Attractions() 188 | select_attr_type = attractions.select( 189 | target_city, key="name", func=lambda x: x == activity["position"] 190 | )["type"] 191 | if not select_attr_type.empty: 192 | return select_attr_type.iloc[0] 193 | return "" 194 | 195 | 196 | def accommodation_type(activity, target_city): 197 | from chinatravel.environment.tools.accommodations.apis import Accommodations 198 | 199 | accommodations = Accommodations() 200 | select_hotel_type = accommodations.select( 201 | target_city, key="name", func=lambda x: x == activity["position"] 202 | )["featurehoteltype"] 203 | if not select_hotel_type.empty: 204 | return select_hotel_type.iloc[0] 205 | return "" 206 | 207 | 208 | def innercity_transport_type(transports): 209 | if len(transports) == 3: 210 | return transports[1]["mode"] 211 | elif len(transports) == 1: 212 | return transports[0]["mode"] 213 | return "empty" 214 | 215 | 216 | def intercity_transport_type(activity): 217 | return activity.get("type", "empty") 218 | 219 | 220 | def innercity_transport_start_time(transports): 221 | return transports[0]["start_time"] 222 | 223 | 224 | def innercity_transport_end_time(transports): 225 | return transports[len(transports) - 1]["end_time"] 226 | 227 | def intercity_transport_origin(activity): 228 | city_list=["上海", "北京", "深圳", "广州", "重庆", "成都", "杭州", "武汉", "南京","苏州"] 229 | if "start" in activity: 230 | for city in city_list: 231 | if city in activity["start"]: 232 | return city 233 | return "" 234 | 235 | def intercity_transport_destination(activity): 236 | city_list=["上海", "北京", "深圳", "广州", "重庆", "成都", "杭州", "武汉", "南京","苏州"] 237 | if "end" in activity: 238 | for city in city_list: 239 | if city in activity["end"]: 240 | return city 241 | return "" 242 | 243 | func_dict = { 244 | "day_count": day_count, 245 | "people_count": people_count, 246 | "start_city": start_city, 247 | "target_city": target_city, 248 | "allactivities": allactivities, 249 | "allactivities_count": allactivities_count, 250 | "activity_position": activity_position, 251 | "activity_cost": activity_cost, 252 | "activity_type": activity_type, 253 | "activity_tickets": activity_tickets, 254 | "activity_transports": activity_transports, 255 | "activity_price": activity_price, 256 | "activity_time": activity_time, 257 | "poi_recommend_time": poi_recommend_time, 258 | "poi_distance": poi_distance, 259 | "metro_tickets": metro_tickets, 260 | "taxi_cars": taxi_cars, 261 | "room_count": room_count, 262 | "room_type": room_type, 263 | "restaurant_type": restaurant_type, 264 | "attraction_type": attraction_type, 265 | "accommodation_type": accommodation_type, 266 | "innercity_transport_type": innercity_transport_type, 267 | "innercity_transport_cost": innercity_transport_cost, 268 | "innercity_transport_price": innercity_transport_price, 269 | "innercity_transport_start_time": innercity_transport_start_time, 270 | "innercity_transport_end_time": innercity_transport_end_time, 271 | "innercity_transport_distance": innercity_transport_distance, 272 | "intercity_transport_type": intercity_transport_type, 273 | "dayactivities": dayactivities, 274 | "activity_start_time": activity_start_time, 275 | "activity_end_time": activity_end_time, 276 | "intercity_transport_origin": intercity_transport_origin, 277 | "intercity_transport_destination": intercity_transport_destination, 278 | "innercity_transport_time": innercity_transport_time, 279 | } 280 | -------------------------------------------------------------------------------- /chinatravel/evaluation/rank.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | 5 | project_root_path = os.path.dirname( 6 | os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 7 | ) 8 | if project_root_path not in sys.path: 9 | sys.path.insert(0, project_root_path) 10 | if os.path.join(project_root_path, "chinatravel") not in sys.path: 11 | sys.path.insert(0, os.path.join(project_root_path, "chinatravel")) 12 | 13 | import json 14 | import pandas as pd 15 | from chinatravel.evaluation.utils import load_json_file 16 | 17 | 18 | class CompareError(Exception): 19 | pass 20 | 21 | 22 | preference_list = [ 23 | "convenient transportation", 24 | "convenient restaurants", 25 | "close to poi", 26 | "less walk", 27 | "more cost on meals", 28 | "more cost on hotel", 29 | "more cost on attractions", 30 | "less cost on meals", 31 | "less cost on hotel", 32 | "less cost on attractions", 33 | "less total cost", 34 | "easy trip", 35 | "more attractions", 36 | "more indoor attractions", 37 | "more outdoor attractions", 38 | "more popular attractions", 39 | "more unpopular attractions", 40 | ] 41 | min_best_list = [ 42 | "convenient_transport", 43 | "convenient_restaurant", 44 | "close_to_poi", 45 | "less_walk", 46 | "less cost on meals", 47 | "less cost on hotel", 48 | "less cost on attractions", 49 | "less total cost", 50 | ] 51 | 52 | max_best_list = [ 53 | "more cost on meals", 54 | "more cost on hotel", 55 | "more cost on attractions", 56 | "easy trip", 57 | "more attractions", 58 | "more indoor attractions", 59 | "more outdoor attractions", 60 | "more popular attractions", 61 | "more unpopular attractions", 62 | ] 63 | 64 | func_result_name_list = [ 65 | "convenient_transport", 66 | "convenient_restaurant", 67 | "near_poi", 68 | "less_walk", 69 | "meal_cost_ratio", 70 | "accommodation_cost_ratio", 71 | "attraction_cost_ratio", 72 | "total_cost", 73 | "attraction_satisfaction", 74 | "attraction_count", 75 | "indoor_attraction_ratio", 76 | "popular_attraction_ratio", 77 | ] 78 | 79 | 80 | def get_funcname_by_preference(preference: str): 81 | if preference == "convenient transportation": 82 | return "convenient_transport" 83 | elif preference == "convenient restaurants": 84 | return "convenient_restaurant" 85 | elif preference == "close to poi" or preference == "close to": 86 | return "near_poi" 87 | elif preference == "less walk": 88 | return "less_walk" 89 | elif preference == "more cost on meals": 90 | return "meal_cost_ratio" 91 | elif preference == "more cost on hotel": 92 | return "accommodation_cost_ratio" 93 | elif preference == "more cost on attractions": 94 | return "attraction_cost_ratio" 95 | elif preference == "less cost on meals": 96 | return "meal_cost_ratio" 97 | elif preference == "less cost on hotel": 98 | return "accommodation_cost_ratio" 99 | elif preference == "less cost on attractions": 100 | return "attraction_cost_ratio" 101 | elif preference == "less total cost": 102 | return "total_cost" 103 | elif preference == "easy trip": 104 | return "attraction_satisfaction" 105 | elif preference == "more attractions": 106 | return "attraction_count" 107 | elif preference == "more indoor attractions": 108 | return "indoor_attraction_ratio" 109 | elif preference == "more outdoor attractions": 110 | return "indoor_attraction_ratio" 111 | elif preference == "more popular attractions": 112 | return "popular_attraction_ratio" 113 | elif preference == "more unpopular attractions": 114 | return "popular_attraction_ratio" 115 | else: 116 | raise CompareError("No such preference {}".format(preference)) 117 | 118 | 119 | def get_rank_with_value(value_list, best_type: str): 120 | pandas_data = pd.Series(value_list) 121 | if best_type == "max": 122 | worst_value = min(pandas_data) - 1 123 | for meta_data in pandas_data: 124 | if meta_data == -1: 125 | meta_data = worst_value 126 | rank = pandas_data.rank(ascending=False, method="min") 127 | elif best_type == "min": 128 | worst_value = max(pandas_data) + 1 129 | for meta_data in pandas_data: 130 | if meta_data == -1: 131 | meta_data = worst_value 132 | rank = pandas_data.rank(ascending=True, method="min") 133 | else: 134 | raise CompareError("best_type should be 'max' or 'min'") 135 | return rank.tolist() 136 | 137 | 138 | def load_query(query_id_list): 139 | query_data = {} 140 | data_dir = os.path.join(project_root_path, "chinatravel/data") 141 | data_dir_list = os.listdir(data_dir) 142 | for dir in data_dir_list: 143 | split_dir_path = os.path.join(data_dir, dir) 144 | if not os.path.isdir(split_dir_path): 145 | continue 146 | file_list = os.listdir(split_dir_path) 147 | for file in file_list: 148 | query_id = file.split(".")[0] 149 | if query_id in query_id_list: 150 | query_data[query_id] = load_json_file(os.path.join(data_dir, dir, file)) 151 | return query_data 152 | 153 | 154 | def rank(method_list, split_list): 155 | # load query data 156 | for split in split_list: 157 | query_id_list = [] 158 | query_id_list_file_path = os.path.join("./default_splits", f"{split}.txt") 159 | with open(query_id_list_file_path, "r") as f: 160 | for line in f.readlines(): 161 | query_id_list.append(line.strip()) 162 | queries = load_query(query_id_list) 163 | 164 | # load result data 165 | results = {} 166 | for query_id in query_id_list: 167 | results[query_id] = {} 168 | for split in split_list: 169 | for method in method_list: 170 | preference_results_path = os.path.join( 171 | project_root_path, 172 | "chinatravel", 173 | "eval/eval_res", 174 | f"splits_{split}", 175 | method, 176 | "preference.csv", 177 | ) 178 | preference_results = pd.read_csv(preference_results_path).to_dict( 179 | orient="records" 180 | ) 181 | for preference_result in preference_results: 182 | query_id = preference_result["data_id"] 183 | results[query_id][method] = preference_result 184 | 185 | # rank 186 | rank_result = [] 187 | for method in method_list: 188 | rank_result.append({"method": method}) 189 | for preference in preference_list: 190 | rank_result[-1][preference] = [] 191 | rank_result[-1]["avg"] = -1 192 | 193 | for query_id in query_id_list: 194 | for preference in preference_list: 195 | if preference == "close to poi": 196 | flag = False 197 | for preference_en in queries[query_id]["preference_en"]: 198 | if "close to" in preference_en: 199 | flag = True 200 | break 201 | if not flag: 202 | continue 203 | elif not preference in queries[query_id]["preference_en"]: 204 | continue 205 | value_list = [] 206 | for method in method_list: 207 | if preference == "convenient transportation": 208 | print(query_id) 209 | print( 210 | f"method: {method}, value: {results[query_id][method][get_funcname_by_preference(preference)]}" 211 | ) 212 | value_list.append( 213 | results[query_id][method][get_funcname_by_preference(preference)] 214 | ) 215 | rank_list = get_rank_with_value( 216 | value_list, "max" if preference in max_best_list else "min" 217 | ) 218 | for i, method in enumerate(method_list): 219 | rank_result[i][ 220 | preference if "close to" not in preference else "close to poi" 221 | ].append(rank_list[i]) 222 | 223 | for i, method in enumerate(method_list): 224 | avg = 0 225 | num = 0 226 | for preference in preference_list: 227 | if len(rank_result[i][preference]) == 0: 228 | print(f"no data for {method} {preference}") 229 | rank_result[i][preference] = -1 230 | continue 231 | avg += sum(rank_result[i][preference]) 232 | num += len(rank_result[i][preference]) 233 | rank_result[i][preference] = sum(rank_result[i][preference]) / len( 234 | rank_result[i][preference] 235 | ) 236 | rank_result[i]["avg"] = avg / num 237 | rank_result = pd.DataFrame(rank_result) 238 | return rank_result 239 | 240 | 241 | if __name__ == "__main__": 242 | method_list = [ 243 | "naive_ns_Deepseek", 244 | "naive_ns_GLM4Plus", 245 | "naive_ns_GPT4o", 246 | # "react_GPT4o_oneshot", 247 | # "react_Deepseek_oneshot", 248 | ] 249 | split_list = ["preference"] 250 | rank_res = rank(method_list, split_list) 251 | print(rank_res) 252 | file_name = f"rank_{time.strftime('%Y%m%d%H%M%S')}.csv" 253 | file_path = os.path.join( 254 | project_root_path, "chinatravel", "eval", "rank_preference", file_name 255 | ) 256 | rank_res.to_csv(file_path, index=False) 257 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |