├── chatarena
│   ├── __init__.py
│   ├── environments
│   │   ├── __init__.py
│   │   ├── base.py
│   │   └── conversation.py
│   ├── backends
│   │   ├── human.py
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── hf_transformers.py
│   │   ├── anthropic.py
│   │   ├── cohere.py
│   │   └── openai.py
│   ├── utils.py
│   ├── message.py
│   ├── database.py
│   ├── config.py
│   ├── agent.py
│   ├── ui
│   │   └── cli.py
│   └── arena.py
├── imgs
│   ├── demo.gif
│   └── framework.png
├── dataset
│   └── Readme.txt
├── seed_dataset
│   └── Readme.txt
├── requirements.txt
├── README.md
├── .gitignore
├── instruction.py
├── data_utils.py
├── eval
│   └── eval_generation.py
├── dialog_simulation.py
├── data_preprocess.py
└── LICENSE
-------------------------------------------------------------------------------- /chatarena/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /imgs/demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iwangjian/TopDial/HEAD/imgs/demo.gif -------------------------------------------------------------------------------- /imgs/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iwangjian/TopDial/HEAD/imgs/framework.png -------------------------------------------------------------------------------- /dataset/Readme.txt: -------------------------------------------------------------------------------- 1 | Please follow the README.md to download the seed dataset and unzip it to this folder. -------------------------------------------------------------------------------- /seed_dataset/Readme.txt: -------------------------------------------------------------------------------- 1 | Please follow the README.md to download the seed dataset and unzip it to this folder.
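The `dataset/` and `seed_dataset/` folders above are placeholders that are populated manually (both are listed in `.gitignore`). A minimal setup sketch, assuming archive names like the ones below (the file names are illustrative, not fixed by this repo):

```bash
# Seed data (the re-purposed DuRecDial 2.0 release) goes into seed_dataset/
unzip seed_dataset.zip -d seed_dataset/
# The curated TopDial release from Google Drive goes into dataset/
unzip TopDial.zip -d dataset/
```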
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | openai==0.27.2 2 | anthropic==0.2.8 3 | cohere==4.3.1 4 | transformers>=4.27.4 5 | tenacity==8.2.2 6 | rich==13.3.3 7 | prompt_toolkit 8 | py2neo 9 | tqdm 10 | -------------------------------------------------------------------------------- /chatarena/environments/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import Environment, TimeStep 2 | from .conversation import Conversation, ModeratedConversation 3 | from ..config import EnvironmentConfig 4 | 5 | ALL_ENVIRONMENTS = [ 6 | Conversation, 7 | ModeratedConversation, 8 | ] 9 | 10 | ENV_REGISTRY = {env.type_name: env for env in ALL_ENVIRONMENTS} 11 | 12 | 13 | # Load an environment from a config dictionary 14 | def load_environment(config: EnvironmentConfig): 15 | try: 16 | env_cls = ENV_REGISTRY[config["env_type"]] 17 | except KeyError: 18 | raise ValueError(f"Unknown environment type: {config['env_type']}") 19 | 20 | env = env_cls.from_config(config) 21 | return env 22 | -------------------------------------------------------------------------------- /chatarena/backends/human.py: -------------------------------------------------------------------------------- 1 | from .base import IntelligenceBackend 2 | from ..config import BackendConfig 3 | 4 | 5 | # An error class for the human backend 6 | class HumanBackendError(Exception): 7 | def __init__(self, agent_name: str): 8 | self.agent_name = agent_name 9 | super().__init__(f"Human backend requires a UI to get input from {agent_name}.") 10 | 11 | 12 | class Human(IntelligenceBackend): 13 | stateful = False 14 | type_name = "human" 15 | 16 | def __init__(self, **kwargs): 17 | super().__init__(**kwargs) 18 | 19 | def to_config(self) -> BackendConfig: 20 | return BackendConfig(backend_type=self.type_name) 21 | 22 | def query(self, agent_name: str, **kwargs) -> str: 23 | raise HumanBackendError(agent_name) 24 | -------------------------------------------------------------------------------- /chatarena/backends/__init__.py: -------------------------------------------------------------------------------- 1 | from ..config import BackendConfig 2 | 3 | from .base import IntelligenceBackend 4 | from .openai import OpenAIChat 5 | from .cohere import CohereAIChat 6 | from .human import Human 7 | from .hf_transformers import TransformersConversational 8 | from .anthropic import Claude 9 | 10 | ALL_BACKENDS = [ 11 | Human, 12 | OpenAIChat, 13 | CohereAIChat, 14 | TransformersConversational, 15 | Claude, 16 | ] 17 | 18 | BACKEND_REGISTRY = {backend.type_name: backend for backend in ALL_BACKENDS} 19 | 20 | 21 | # Load a backend from a config dictionary 22 | def load_backend(config: BackendConfig): 23 | try: 24 | backend_cls = BACKEND_REGISTRY[config.backend_type] 25 | except KeyError: 26 | raise ValueError(f"Unknown backend type: {config.backend_type}") 27 | 28 | backend = backend_cls.from_config(config) 29 | return backend 30 |
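The two `load_*` helpers above resolve a config's `env_type`/`backend_type` string through registries keyed by each class's `type_name`. A minimal sketch of the lookup (values illustrative; the OpenAI backend additionally requires the `openai` package and `OPENAI_API_KEY` to be set):

```python
from chatarena.config import BackendConfig, EnvironmentConfig
from chatarena.backends import load_backend
from chatarena.environments import load_environment

# "openai-chat" and "conversation" are registered type_name strings
backend = load_backend(BackendConfig(backend_type="openai-chat", temperature=0.7))
env = load_environment(EnvironmentConfig(env_type="conversation", player_names=["Alice", "Bob"]))
```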
-------------------------------------------------------------------------------- /chatarena/utils.py: -------------------------------------------------------------------------------- 1 | class AttributedDict(dict): 2 | """ 3 | A dict class whose keys are automatically set as attributes of the class. 4 | Serializable to JSON. 5 | """ 6 | 7 | def __init__(self, *args, **kwargs): 8 | super().__init__(*args, **kwargs) 9 | 10 | def __setattr__(self, key, value): 11 | self[key] = value 12 | 13 | def __getattr__(self, key): 14 | if key in self: 15 | return self[key] 16 | raise AttributeError(key) 17 | 18 | def __delattr__(self, key): 19 | del self[key] 20 | 21 | # check whether the key is a string when adding the key 22 | def __setitem__(self, key, value): 23 | if not isinstance(key, str): 24 | raise ValueError("The key must be a string") 25 | super().__setitem__(key, value) 26 | 27 | def update(self, *args, **kwargs): 28 | for key, value in dict(*args, **kwargs).items(): 29 | self[key] = value 30 | -------------------------------------------------------------------------------- /chatarena/backends/base.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from abc import abstractmethod 3 | 4 | from ..config import BackendConfig, Configurable 5 | from ..message import Message 6 | 7 | 8 | class IntelligenceBackend(Configurable): 9 | """An abstraction of the intelligence source of the agents.""" 10 | stateful = None 11 | type_name = None 12 | 13 | @abstractmethod 14 | def __init__(self, **kwargs): 15 | super().__init__(**kwargs) # registers the arguments with Configurable 16 | 17 | def __init_subclass__(cls, **kwargs): 18 | # check if the subclass has the required attributes 19 | for required in ('stateful', 'type_name',): 20 | if getattr(cls, required) is None: 21 | raise TypeError(f"Can't instantiate abstract class {cls.__name__} without {required} attribute defined") 22 | return super().__init_subclass__(**kwargs) 23 | 24 | def to_config(self) -> BackendConfig: 25 | self._config_dict["backend_type"] = self.type_name 26 | return BackendConfig(**self._config_dict) 27 | 28 | @abstractmethod 29 | def query(self, agent_name: str, role_desc: str, history_messages: List[Message], global_prompt: str = None, 30 | request_msg: Message = None, *args, **kwargs) -> str: 31 | raise NotImplementedError 32 | 33 | @abstractmethod 34 | async def async_query(self, agent_name: str, role_desc: str, history_messages: List[Message], 35 | global_prompt: str = None, request_msg: Message = None, *args, **kwargs) -> str: 36 | """Async querying""" 37 | raise NotImplementedError 38 | 39 | # reset the state of the backend 40 | def reset(self): 41 | if self.stateful: 42 | raise NotImplementedError 43 | else: 44 | pass 45 |
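`__init_subclass__` above rejects any subclass that leaves `stateful` or `type_name` unset at class-definition time. A toy backend satisfying that contract (illustrative only, not part of this repo):

```python
from typing import List
from chatarena.backends.base import IntelligenceBackend
from chatarena.message import Message

class EchoBackend(IntelligenceBackend):
    """Parrots the last visible message back to the caller."""
    stateful = False    # required, else __init_subclass__ raises TypeError
    type_name = "echo"  # the key a registry or config would refer to

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def query(self, agent_name: str, role_desc: str, history_messages: List[Message],
              global_prompt: str = None, request_msg: Message = None, *args, **kwargs) -> str:
        return history_messages[-1].content if history_messages else "..."

    async def async_query(self, *args, **kwargs) -> str:
        return self.query(*args, **kwargs)
```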
-------------------------------------------------------------------------------- /chatarena/environments/base.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Dict 3 | from abc import abstractmethod 4 | 5 | from ..message import Message 6 | from ..utils import AttributedDict 7 | from ..config import Configurable, EnvironmentConfig 8 | 9 | 10 | @dataclass 11 | class TimeStep(AttributedDict): 12 | observation: List[Message] 13 | reward: Dict[str, float] 14 | terminal: bool 15 | 16 | 17 | class Environment(Configurable): 18 | """ 19 | The environment that the agents interact with. 20 | """ 21 | type_name = None 22 | 23 | @abstractmethod 24 | def __init__(self, player_names: List[str], **kwargs): 25 | super().__init__(player_names=player_names, **kwargs) # registers the arguments with Configurable 26 | self.player_names = player_names 27 | 28 | def __init_subclass__(cls, **kwargs): 29 | # check if the subclass has the required attributes 30 | for required in ('type_name',): 31 | if getattr(cls, required) is None: 32 | cls.type_name = cls.__name__.lower() 33 | 34 | return super().__init_subclass__(**kwargs) 35 | 36 | @abstractmethod 37 | def reset(self): 38 | """ 39 | reset the environment 40 | """ 41 | pass 42 | 43 | def to_config(self) -> EnvironmentConfig: 44 | self._config_dict["env_type"] = self.type_name 45 | return EnvironmentConfig(**self._config_dict) 46 | 47 | @property 48 | def num_players(self) -> int: 49 | """ 50 | get the number of players 51 | """ 52 | return len(self.player_names) 53 | 54 | @abstractmethod 55 | def get_next_player(self) -> str: 56 | """ 57 | get name of the next player 58 | """ 59 | pass 60 | 61 | @abstractmethod 62 | def get_observation(self, player_name=None) -> List[Message]: 63 | """ 64 | get observation for the player 65 | """ 66 | pass 67 | 68 | @abstractmethod 69 | def print(self): 70 | """ 71 | print the environment state 72 | """ 73 | pass 74 | 75 | @abstractmethod 76 | def step(self, player_name: str, action: str) -> TimeStep: 77 | """ 78 | step function that is called by the arena 79 | Args: 80 | player_name: the name of the player 81 | action: the action that the agent wants to take 82 | Returns: 83 | timestep: the timestep that contains the observation, reward and terminal flag 84 | """ 85 | pass 86 | 87 | @abstractmethod 88 | def check_action(self, action: str, player_name: str) -> bool: 89 | """ 90 | check whether the action is valid 91 | """ 92 | return True 93 | 94 | @abstractmethod 95 | def is_terminal(self) -> bool: 96 | """ 97 | check whether the environment is in terminal state 98 | """ 99 | pass 100 | 101 | def get_zero_rewards(self) -> Dict[str, float]: 102 | return {player_name: 0. for player_name in self.player_names} 103 |
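Note that, unlike backends, `Environment.__init_subclass__` fills in a missing `type_name` automatically from the lowercased class name. A sketch of that behaviour with a hypothetical subclass (stubbed methods, illustrative only):

```python
from typing import List
from chatarena.environments.base import Environment, TimeStep
from chatarena.message import Message

class Debate(Environment):  # no explicit type_name
    def reset(self): pass
    def get_next_player(self) -> str: return self.player_names[0]
    def get_observation(self, player_name=None) -> List[Message]: return []
    def print(self): pass
    def step(self, player_name: str, action: str) -> TimeStep:
        return TimeStep(observation=[], reward=self.get_zero_rewards(), terminal=False)
    def check_action(self, action: str, player_name: str) -> bool: return True
    def is_terminal(self) -> bool: return False

assert Debate.type_name == "debate"  # derived by __init_subclass__
```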
-------------------------------------------------------------------------------- /chatarena/message.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | from dataclasses import dataclass, field 3 | import time 4 | from uuid import uuid1 5 | import hashlib 6 | 7 | # Preserved roles 8 | SYSTEM_NAME = "System" 9 | MODERATOR_NAME = "Moderator" 10 | 11 | 12 | def _hash(input: str): 13 | hex_dig = hashlib.sha256(input.encode()).hexdigest() 14 | return hex_dig 15 | 16 | 17 | @dataclass 18 | class Message: 19 | agent_name: str 20 | content: str # it can be an image or a text 21 | turn: int 22 | timestamp: int = field(default_factory=time.time_ns) # default_factory so each message gets its own creation time 23 | visible_to: Union[str, List[str]] = 'all' 24 | msg_type: str = "text" 25 | logged: bool = False # Whether the message is logged in the database 26 | 27 | @property 28 | def msg_hash(self): 29 | # Generate a unique message id given the content, timestamp and role 30 | return _hash( 31 | f"agent: {self.agent_name}\ncontent: {self.content}\ntimestamp: {str(self.timestamp)}\nturn: {self.turn}\nmsg_type: {self.msg_type}") 32 | 33 | 34 | class MessagePool(): 35 | """ 36 | A message pool to manage the messages. This allows a unified treatment of the visibility of the messages. 37 | Draft design: 38 | The message pool is a list of (named) tuples, where each tuple has (turn, role, content). 39 | 40 | The design should support two configurations for the step definition: one player acts per turn, or multiple players act in the same turn (as in rock-paper-scissors). 41 | The agents can only see the messages that 42 | 1) are from before the current turn, and 43 | 2) are visible to the current role. 44 | """ 45 | 46 | def __init__(self): 47 | self.conversation_id = str(uuid1()) 48 | self._messages: List[Message] = [] # TODO: for the sake of thread safety, use a queue instead 49 | self._last_message_idx = 0 50 | 51 | def reset(self): 52 | self._messages = [] 53 | 54 | def append_message(self, message: Message): 55 | self._messages.append(message) 56 | 57 | def print(self): 58 | for message in self._messages: 59 | print(f"[{message.agent_name}->{message.visible_to}]: {message.content}") 60 | 61 | @property 62 | def last_turn(self): 63 | if len(self._messages) == 0: 64 | return 0 65 | else: 66 | return self._messages[-1].turn 67 | 68 | @property 69 | def last_message(self): 70 | if len(self._messages) == 0: 71 | return None 72 | else: 73 | return self._messages[-1] 74 | 75 | def get_all_messages(self) -> List[Message]: 76 | return self._messages 77 | 78 | def get_visible_messages(self, agent_name, turn: int) -> List[Message]: 79 | """ 80 | get the messages that are visible to the agents before the specified turn 81 | """ 82 | 83 | # Get the messages before the current turn 84 | prev_messages = [message for message in self._messages if message.turn < turn] 85 | 86 | visible_messages = [] 87 | for message in prev_messages: 88 | if message.visible_to == "all" or agent_name in message.visible_to or agent_name == "Moderator": 89 | visible_messages.append(message) 90 | return visible_messages 91 |
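A small sketch of the visibility rule in `get_visible_messages` (names and contents illustrative):

```python
from chatarena.message import Message, MessagePool

pool = MessagePool()
pool.append_message(Message(agent_name="Alice", content="hi all", turn=0))  # visible_to defaults to "all"
pool.append_message(Message(agent_name="System", content="secret hint", turn=0, visible_to=["Bob"]))

# Only messages from turns earlier than the given turn and visible to "Bob" are returned
visible = pool.get_visible_messages("Bob", turn=1)
assert [m.content for m in visible] == ["hi all", "secret hint"]
```

One caveat of the current check: when `visible_to` is a plain string, `agent_name in message.visible_to` is a substring test, so the list form used above is the safer way to address specific agents.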
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TopDial 2 | This repository contains code and data for the paper [Target-oriented Proactive Dialogue Systems with Personalization: Problem Formulation and Dataset Curation](http://arxiv.org/abs/2310.07397) accepted by EMNLP 2023. 3 | 4 | ## Overview 5 | 6 | ![TopDial dataset curation framework](imgs/framework.png)
7 | 8 | Target-oriented dialogue systems, designed to proactively steer conversations toward predefined targets or accomplish specific system-side goals, are an exciting area in conversational AI. In this work, by formulating a <dialogue act, topic> pair as the conversation target, we explore a novel problem of personalized target-oriented dialogue by considering personalization during the target accomplishment process. However, there remains a pressing need for high-quality datasets, and building one from scratch requires tremendous human effort. To address this, we propose an automatic dataset curation framework using a role-playing approach. Based on this framework, we construct a large-scale personalized target-oriented dialogue dataset, **TopDial**, which comprises about 18K multi-turn dialogues. 9 | 10 | 11 | ## Dataset 12 | We have uploaded the curated **TopDial** dataset to [Google Drive](https://drive.google.com/file/d/1AWyjmUxYlppNCKkdK46riMF_2y0XeSI8/view?usp=sharing). 13 | 14 | 15 | ## Dataset Curation 16 | 17 | 18 | ### Requirements 19 | We use [Neo4j](https://neo4j.com/) as the graph database tool to process the domain knowledge graph in the seed dataset. Please install it by following the [official guide](https://neo4j.com/docs/operations-manual/current/installation/). The required Python packages are listed in `requirements.txt`. Please install them by running: 20 | ```bash 21 | pip install -r requirements.txt 22 | ``` 23 | 24 | ### Seed Dataset 25 | We use the [re-purposed version](https://github.com/iwangjian/Color4Dial) of the DuRecDial 2.0 dataset as the seed dataset. 26 | 27 | ### Step 1: Preprocessing the seed dataset 28 | ```bash 29 | python data_preprocess.py --seed_dataset_dir ${seed_dataset_dir} --cache_dir ${cache_dir} 30 | ``` 31 | Running this script will generate the following files in the specified cache directory: 32 | `cache_dialogue_{train|dev|test_seen|test_unseen}.jsonl` 33 | 34 | 35 | ### Step 2: Dataset curation 36 | ```bash 37 | # set your OpenAI API key 38 | export OPENAI_API_KEY="" 39 | 40 | python -u dialog_simulation.py --cached_seed_path ${cached_seed_path} \ 41 | --output_dir ${output_dir} \ 42 | --max_interaction_step ${max_interaction_step} 43 | ``` 44 | Running the above script will look like this: 45 |
![Demo of running dialog_simulation.py](imgs/demo.gif)
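For concreteness, a full run of Step 1 and Step 2 might look as follows (the directory names and step budget are illustrative placeholders, not values fixed by the scripts):

```bash
export OPENAI_API_KEY="sk-..."  # your own key

python data_preprocess.py --seed_dataset_dir seed_dataset/ --cache_dir caches/

python -u dialog_simulation.py --cached_seed_path caches/cache_dialogue_train.jsonl \
    --output_dir output/ \
    --max_interaction_step 8
```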
46 | 47 | If you prefer not to print the instructions and the synthesized conversations to the console, set `--show_description` and `--show_message` to `false`. 48 | 49 | 50 | ## Acknowledgement 51 | Our code is partially based on the implementation of [ChatArena](https://github.com/Farama-Foundation/chatarena). We thank the authors for their excellent work. 52 | 53 | 54 | ## Citation 55 | If you use our data or code in your work, please cite our work as: 56 | ```bibtex 57 | @inproceedings{wang-etal-2023-target, 58 | title = "Target-oriented Proactive Dialogue Systems with Personalization: Problem Formulation and Dataset Curation", 59 | author = "Wang, Jian and 60 | Cheng, Yi and 61 | Lin, Dongding and 62 | Leong, Chak Tou and 63 | Li, Wenjie", 64 | booktitle = "Proceedings of the 2023 Conference on Empirical Methods in Natural Language Processing (EMNLP)", 65 | month = dec, 66 | year = "2023", 67 | address = "Singapore", 68 | publisher = "Association for Computational Linguistics", 69 | } 70 | ``` 71 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | dataset/ 6 | seed_dataset/ 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # poetry 100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 101 | # This is especially recommended for binary packages to ensure reproducibility, and is more 102 | # commonly ignored for libraries.
103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 104 | #poetry.lock 105 | 106 | # pdm 107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 108 | #pdm.lock 109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 110 | # in version control. 111 | # https://pdm.fming.dev/#use-with-ide 112 | .pdm.toml 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /chatarena/backends/hf_transformers.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from tenacity import retry, stop_after_attempt, wait_random_exponential 3 | 4 | from .base import IntelligenceBackend 5 | from ..message import Message, SYSTEM_NAME as SYSTEM 6 | 7 | # Try to import the transformers package 8 | try: 9 | import transformers 10 | from transformers import pipeline 11 | from transformers.pipelines.conversational import Conversation, ConversationalPipeline 12 | except ImportError: 13 | is_transformers_available = False 14 | else: 15 | is_transformers_available = True 16 | 17 | 18 | class TransformersConversational(IntelligenceBackend): 19 | """ 20 | Interface to the Transformers ConversationalPipeline 21 | """ 22 | stateful = False 23 | type_name = "transformers:conversational" 24 | 25 | def __init__(self, model: str, device: int = -1, **kwargs): 26 | super().__init__(model=model, device=device, **kwargs) 27 | self.model = model 28 | self.device = device 29 | 30 | assert is_transformers_available, "Transformers package is not installed" 31 | self.chatbot = pipeline(task="conversational", model=self.model, device=self.device) 32 | 33 | @retry(stop=stop_after_attempt(6), wait=wait_random_exponential(min=1, max=60)) 34 | def _get_response(self, conversation: Conversation): 35 | conversation = self.chatbot(conversation) 36 | response = conversation.generated_responses[-1] 37 | return response 38 | 39 | @staticmethod 40 | def _msg_template(agent_name, content): 41 | return f"[{agent_name}]: {content}" 42 | 43 | def query(self, agent_name: str, role_desc: str, history_messages: List[Message], global_prompt: str = None, 44 | request_msg: Message = None, *args, **kwargs) -> str: 45 | user_inputs, generated_responses = [], [] 46 | 
all_messages = [(SYSTEM, global_prompt), (SYSTEM, role_desc)] if global_prompt else [(SYSTEM, role_desc)] 47 | 48 | for msg in history_messages: 49 | all_messages.append((msg.agent_name, msg.content)) 50 | if request_msg: 51 | all_messages.append((SYSTEM, request_msg.content)) 52 | 53 | prev_is_user = False # Whether the previous message is from the user 54 | for i, message in enumerate(all_messages): 55 | if i == 0: 56 | assert message[0] == SYSTEM # The first message should be from the system 57 | 58 | if message[0] != agent_name: 59 | if not prev_is_user: 60 | user_inputs.append(self._msg_template(message[0], message[1])) 61 | else: 62 | user_inputs[-1] += "\n" + self._msg_template(message[0], message[1]) 63 | prev_is_user = True 64 | else: 65 | if prev_is_user: 66 | generated_responses.append(message[1]) 67 | else: 68 | generated_responses[-1] += "\n" + message[1] 69 | prev_is_user = False 70 | 71 | assert len(user_inputs) == len(generated_responses) + 1 72 | past_user_inputs = user_inputs[:-1] 73 | new_user_input = user_inputs[-1] 74 | 75 | # Recreate a conversation object from the history messages 76 | conversation = Conversation(text=new_user_input, past_user_inputs=past_user_inputs, 77 | generated_responses=generated_responses) 78 | 79 | # Get the response 80 | response = self._get_response(conversation) 81 | return response 82 | 83 | # conversation = Conversation("Going to the movies tonight - any suggestions?") 84 | # 85 | # # Steps usually performed by the model when generating a response: 86 | # # 1. Mark the user input as processed (moved to the history) 87 | # conversation.mark_processed() 88 | # # 2. Append a model response 89 | # conversation.append_response("The Big lebowski.") 90 | # 91 | # conversation.add_user_input("Is it good?") 92 | -------------------------------------------------------------------------------- /chatarena/backends/anthropic.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import os 3 | import re 4 | import logging 5 | from tenacity import retry, stop_after_attempt, wait_random_exponential 6 | 7 | from .base import IntelligenceBackend 8 | from ..message import Message, SYSTEM_NAME as SYSTEM 9 | 10 | try: 11 | import anthropic 12 | except ImportError: 13 | is_anthropic_available = False 14 | logging.warning("anthropic package is not installed") 15 | else: 16 | anthropic_api_key = os.environ.get('ANTHROPIC_API_KEY') 17 | if anthropic_api_key is None: 18 | logging.warning("Anthropic API key is not set. Please set the environment variable ANTHROPIC_API_KEY") 19 | is_anthropic_available = False 20 | else: 21 | is_anthropic_available = True 22 | 23 | DEFAULT_MAX_TOKENS = 256 24 | DEFAULT_MODEL = "claude-v1" 25 | 26 | 27 | class Claude(IntelligenceBackend): 28 | """ 29 | Interface to the Claude models offered by Anthropic.
30 | """ 31 | stateful = False 32 | type_name = "claude" 33 | 34 | def __init__(self, max_tokens: int = DEFAULT_MAX_TOKENS, model: str = DEFAULT_MODEL, **kwargs): 35 | assert is_anthropic_available, "anthropic package is not installed or the API key is not set" 36 | super().__init__(max_tokens=max_tokens, model=model, **kwargs) 37 | 38 | self.max_tokens = max_tokens 39 | self.model = model 40 | 41 | self.client = anthropic.Client(os.environ['ANTHROPIC_API_KEY']) 42 | 43 | @retry(stop=stop_after_attempt(6), wait=wait_random_exponential(min=1, max=60)) 44 | def _get_response(self, prompt: str): 45 | response = self.client.completion( 46 | prompt=prompt, 47 | stop_sequences=[anthropic.HUMAN_PROMPT], 48 | model=self.model, 49 | max_tokens_to_sample=self.max_tokens, 50 | ) 51 | 52 | response = response['completion'].strip() 53 | return response 54 | 55 | def query(self, agent_name: str, role_desc: str, history_messages: List[Message], global_prompt: str = None, 56 | request_msg: Message = None, *args, **kwargs) -> str: 57 | """ 58 | format the input and call the Claude API 59 | args: 60 | agent_name: the name of the agent 61 | role_desc: the description of the role of the agent 62 | env_desc: the description of the environment 63 | history_messages: the history of the conversation, or the observation for the agent 64 | request_msg: the request from the system to guide the agent's next response 65 | """ 66 | all_messages = [(SYSTEM, global_prompt), (SYSTEM, role_desc)] if global_prompt else [(SYSTEM, role_desc)] 67 | 68 | for message in history_messages: 69 | all_messages.append((message.agent_name, message.content)) 70 | if request_msg: 71 | all_messages.append((SYSTEM, request_msg.content)) 72 | 73 | prompt = "" 74 | prev_is_human = False # Whether the previous message is from human (in anthropic, the human is the user) 75 | for i, message in enumerate(all_messages): 76 | if i == 0: 77 | assert message[0] == SYSTEM # The first message should be from the system 78 | 79 | if message[0] == agent_name: 80 | if prev_is_human: 81 | prompt = f"{prompt}{anthropic.AI_PROMPT} {message[1]}" 82 | else: 83 | prompt = f"{prompt}\n\n{message[1]}" 84 | prev_is_human = False 85 | else: 86 | if prev_is_human: 87 | prompt = f"{prompt}\n\n[{message[0]}]: {message[1]}" 88 | else: 89 | prompt = f"{prompt}{anthropic.HUMAN_PROMPT}\n[{message[0]}]: {message[1]}" 90 | prev_is_human = True 91 | assert prev_is_human # The last message should be from the human 92 | # Add the AI prompt for Claude to generate the response 93 | prompt = f"{prompt}{anthropic.AI_PROMPT}" 94 | 95 | response = self._get_response(prompt, *args, **kwargs) 96 | 97 | # Remove the agent name if the response starts with it 98 | response = re.sub(rf"^\s*\[{agent_name}]:?", "", response).strip() 99 | 100 | return response 101 | -------------------------------------------------------------------------------- /chatarena/backends/cohere.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import os 3 | from tenacity import retry, stop_after_attempt, wait_random_exponential 4 | 5 | from .base import IntelligenceBackend 6 | from ..message import Message 7 | 8 | # Try to import the cohere package and check whether the API key is set 9 | try: 10 | import cohere 11 | except ImportError: 12 | is_cohere_available = False 13 | else: 14 | if os.environ.get('COHEREAI_API_KEY') is None: 15 | is_cohere_available = False 16 | else: 17 | is_cohere_available = True 18 | 19 | # Default config follows the 
[Cohere documentation](https://cohere-sdk.readthedocs.io/en/latest/cohere.html#cohere.client.Client.chat) 20 | DEFAULT_TEMPERATURE = 0.8 21 | DEFAULT_MAX_TOKENS = 200 22 | DEFAULT_MODEL = "command-xlarge" 23 | 24 | 25 | class CohereAIChat(IntelligenceBackend): 26 | """ 27 | Interface to the Cohere API 28 | """ 29 | stateful = True 30 | type_name = "cohere-chat" 31 | 32 | def __init__(self, temperature: float = DEFAULT_TEMPERATURE, max_tokens: int = DEFAULT_MAX_TOKENS, 33 | model: str = DEFAULT_MODEL, **kwargs): 34 | super().__init__(temperature=temperature, max_tokens=max_tokens, model=model, **kwargs) 35 | 36 | self.temperature = temperature 37 | self.max_tokens = max_tokens 38 | self.model = model 39 | 40 | assert is_cohere_available, "Cohere package is not installed or the API key is not set" 41 | self.client = cohere.Client(os.environ.get('COHEREAI_API_KEY')) 42 | 43 | # Stateful variables 44 | self.session_id = None # The session id for the last conversation 45 | self.last_msg_hash = None # The hash of the last message of the last conversation 46 | 47 | def reset(self): 48 | self.session_id = None 49 | self.last_msg_hash = None 50 | 51 | @retry(stop=stop_after_attempt(6), wait=wait_random_exponential(min=1, max=60)) 52 | def _get_response(self, new_message: str, persona_prompt: str): 53 | response = self.client.chat( 54 | new_message, 55 | persona_prompt=persona_prompt, 56 | temperature=self.temperature, 57 | max_tokens=self.max_tokens, 58 | session_id=self.session_id 59 | ) 60 | 61 | self.session_id = response.session_id # Update the session id 62 | return response.reply 63 | 64 | def query(self, agent_name: str, role_desc: str, history_messages: List[Message], global_prompt: str = None, 65 | request_msg: Message = None, *args, **kwargs) -> str: 66 | """ 67 | format the input and call the Cohere API 68 | args: 69 | agent_name: the name of the agent 70 | role_desc: the description of the role of the agent 71 | env_desc: the description of the environment 72 | history_messages: the history of the conversation, or the observation for the agent 73 | request_msg: the request for the CohereAI 74 | """ 75 | # Find the index of the last message of the last conversation 76 | new_message_start_idx = 0 77 | if self.last_msg_hash is not None: 78 | for i, message in enumerate(history_messages): 79 | if message.msg_hash == self.last_msg_hash: 80 | new_message_start_idx = i + 1 81 | break 82 | 83 | new_messages = history_messages[new_message_start_idx:] 84 | assert len(new_messages) > 0, "No new messages found (this should not happen)" 85 | 86 | new_conversations = [] 87 | for message in new_messages: 88 | if message.agent_name != agent_name: 89 | # Since there are more than one player, we need to distinguish between the players 90 | new_conversations.append(f"[{message.agent_name}]: {message.content}") 91 | 92 | if request_msg: 93 | new_conversations.append(f"[{request_msg.agent_name}]: {request_msg.content}") 94 | 95 | # Concatenate all new messages into one message because the Cohere API only accepts one message 96 | new_message = "\n".join(new_conversations) 97 | persona_prompt = f"Environment:\n{global_prompt}\n\nYour role:\n{role_desc}" 98 | 99 | response = self._get_response(new_message, persona_prompt) 100 | 101 | # Only update the last message hash if the API call is successful 102 | self.last_msg_hash = new_messages[-1].msg_hash 103 | 104 | return response 105 | -------------------------------------------------------------------------------- /chatarena/database.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Datastore module for chatarena. 3 | This module provides utilities for storing the messages and the game results into a database. 4 | Currently, it supports Supabase. 5 | """ 6 | import json 7 | import os 8 | from typing import List 9 | import uuid 10 | 11 | from .arena import Arena 12 | from .message import Message 13 | 14 | # Attempt importing Supabase 15 | try: 16 | import supabase 17 | 18 | # Get the Supabase URL and secret key from environment variables 19 | SUPABASE_URL = os.environ.get("SUPABASE_URL", "") 20 | SUPABASE_SECRET_KEY = os.environ.get("SUPABASE_SECRET_KEY", "") 21 | assert SUPABASE_URL and SUPABASE_SECRET_KEY 22 | except (ImportError, AssertionError): 23 | supabase_available = False 24 | else: 25 | supabase_available = True 26 | 27 | 28 | # Store the messages into the Supabase database 29 | class SupabaseDB: 30 | def __init__(self): 31 | assert supabase_available and SUPABASE_URL and SUPABASE_SECRET_KEY 32 | supabase_client = supabase.create_client(SUPABASE_URL, SUPABASE_SECRET_KEY) 33 | self.client = supabase_client 34 | 35 | # Save Arena state to Supabase 36 | def save_arena(self, arena: Arena): 37 | # Save the environment config 38 | self._save_environment(arena) 39 | 40 | # Save the player configs 41 | self._save_player_configs(arena) 42 | 43 | # Save the messages 44 | self.save_messages(arena) 45 | 46 | # Save the environment config of the arena 47 | def _save_environment(self, arena: Arena): 48 | env = arena.environment 49 | env_config = env.to_config() 50 | moderator_config = env_config.pop("moderator", None) 51 | 52 | arena_row = { 53 | "arena_id": str(arena.uuid), 54 | "global_prompt": arena.global_prompt, 55 | "env_type": env_config["env_type"], 56 | "env_config": json.dumps(env_config), 57 | } 58 | self.client.table("Arena").insert(arena_row).execute() 59 | 60 | # Get the moderator config 61 | if moderator_config: 62 | moderator_row = { 63 | "moderator_id": str(uuid.uuid5(arena.uuid, json.dumps(moderator_config))), 64 | "arena_id": str(arena.uuid), 65 | "role_desc": moderator_config["role_desc"], 66 | "terminal_condition": moderator_config["terminal_condition"], 67 | "backend_type": moderator_config["backend"]["backend_type"], 68 | "temperature": moderator_config["backend"]["temperature"], 69 | "max_tokens": moderator_config["backend"]["max_tokens"], 70 | } 71 | self.client.table("Moderator").insert(moderator_row).execute() 72 | 73 | # Save the player configs of the arena 74 | def _save_player_configs(self, arena: Arena): 75 | player_rows = [] 76 | for player in arena.players: 77 | player_config = player.to_config() 78 | player_row = { 79 | "player_id": str(uuid.uuid5(arena.uuid, json.dumps(player_config))), 80 | "arena_id": str(arena.uuid), 81 | "name": player.name, 82 | "role_desc": player_config["role_desc"], 83 | "backend_type": player_config["backend"]["backend_type"], 84 | "temperature": player_config["backend"].get("temperature", None), 85 | "max_tokens": player_config["backend"].get("max_tokens", None), 86 | } 87 | player_rows.append(player_row) 88 | 89 | self.client.table("Player").insert(player_rows).execute() 90 | 91 | # Save the messages 92 | def save_messages(self, arena: Arena, messages: List[Message] = None): 93 | if messages is None: 94 | messages = arena.environment.get_observation() 95 | 96 | # Filter messages that are already logged 97 | messages = [msg for msg in messages if not msg.logged] 98 | 99 | message_rows = [] 100 | for message in messages: 101 | message_row = { 102 |
"message_id": str(uuid.uuid5(arena.uuid, message.msg_hash)), 103 | "arena_id": str(arena.uuid), 104 | "agent_name": message.agent_name, 105 | "content": message.content, 106 | "turn": message.turn, 107 | "timestamp": str(message.timestamp), 108 | "msg_type": message.msg_type, 109 | "visible_to": json.dumps(message.visible_to), 110 | } 111 | message_rows.append(message_row) 112 | 113 | self.client.table("Message").insert(message_rows).execute() 114 | 115 | # Mark the messages as logged 116 | for message in messages: 117 | message.logged = True 118 | 119 | 120 | # Log the arena results into the Supabase database 121 | def log_arena(arena: Arena, database=None): 122 | if database is None: 123 | pass 124 | else: 125 | database.save_arena(arena) 126 | 127 | 128 | # Log the messages into the Supabase database 129 | def log_messages(arena: Arena, messages: List[Message], database=None): 130 | if database is None: 131 | pass 132 | else: 133 | database.save_messages(arena, messages) 134 | -------------------------------------------------------------------------------- /chatarena/config.py: -------------------------------------------------------------------------------- 1 | import json 2 | import copy 3 | from abc import abstractmethod 4 | 5 | from .utils import AttributedDict 6 | 7 | 8 | class Config(AttributedDict): 9 | """ 10 | Config class to manage the configuration of the games. 11 | The class has a few useful methods to load and save the config. 12 | """ 13 | 14 | # convert dict to Config recursively 15 | def __init__(self, *args, **kwargs): 16 | super().__init__(*args, **kwargs) 17 | for key, value in self.items(): 18 | if isinstance(value, dict): 19 | self[key] = init_config(value) # convert dict to Config recursively 20 | # convert list of dict to list of Config recursively 21 | elif isinstance(value, list) and len(value) > 0: 22 | self[key] = [init_config(item) if isinstance(item, dict) else item for item in value] 23 | 24 | def save(self, path: str): 25 | # save config to file 26 | with open(path, "w") as f: 27 | json.dump(self, f, indent=4) 28 | 29 | @classmethod 30 | def load(cls, path: str): 31 | # load config from file 32 | with open(path, "r") as f: 33 | config = json.load(f) 34 | return cls(config) 35 | 36 | def deepcopy(self): 37 | # get the config class so that subclasses can be copied in the correct class 38 | config_class = self.__class__ 39 | # make a deep copy of the config 40 | return config_class(copy.deepcopy(self)) 41 | 42 | 43 | class Configurable: 44 | """ 45 | Configurable is an interface for classes that can be initialized with a config. 46 | """ 47 | 48 | def __init__(self, **kwargs): 49 | self._config_dict = kwargs 50 | 51 | @classmethod 52 | def from_config(cls, config: Config): 53 | return cls(**config) 54 | 55 | def to_config(self) -> Config: 56 | # Convert the _config_dict to Config 57 | return Config(**self._config_dict) 58 | 59 | def save_config(self, path: str): 60 | self.to_config().save(path) 61 | 62 | 63 | class EnvironmentConfig(Config): 64 | """ 65 | EnvironmentConfig contains a env_type field to indicate the name of the environment. 66 | """ 67 | 68 | def __init__(self, *args, **kwargs): 69 | super().__init__(*args, **kwargs) 70 | # check if the env_type field is specified 71 | if "env_type" not in self: 72 | raise ValueError("The env_type field is not specified") 73 | 74 | 75 | class BackendConfig(Config): 76 | """ 77 | BackendConfig contains a backend_type field to indicate the name of the backend. 
78 | """ 79 | 80 | def __init__(self, *args, **kwargs): 81 | super().__init__(*args, **kwargs) 82 | # check if the backend_type field is specified 83 | if "backend_type" not in self: 84 | raise ValueError("The backend_type field is not specified") 85 | 86 | 87 | class AgentConfig(Config): 88 | """ 89 | AgentConfig contains role_desc and backend fields. 90 | """ 91 | 92 | def __init__(self, *args, **kwargs): 93 | super().__init__(*args, **kwargs) 94 | # check if the role_desc field is specified 95 | if "role_desc" not in self: 96 | raise ValueError("The role_desc field is not specified") 97 | # check if the backend field is specified 98 | if "backend" not in self: 99 | raise ValueError("The backend field is not specified") 100 | # Make sure the backend field is a BackendConfig 101 | if not isinstance(self["backend"], BackendConfig): 102 | raise ValueError("The backend field must be a BackendConfig") 103 | 104 | 105 | class ArenaConfig(Config): 106 | """ 107 | ArenaConfig contains a list of AgentConfig. 108 | """ 109 | 110 | def __init__(self, *args, **kwargs): 111 | super().__init__(*args, **kwargs) 112 | # check if the players field is specified and it is List[AgentConfig] 113 | if "players" not in self: 114 | raise ValueError("The players field is not specified") 115 | if not isinstance(self["players"], list): 116 | raise ValueError("The players field must be a list") 117 | for player in self["players"]: 118 | if not isinstance(player, AgentConfig): 119 | raise ValueError("The players field must be a list of AgentConfig") 120 | 121 | # check if environment field is specified and it is EnvironmentConfig 122 | if "environment" not in self: 123 | raise ValueError("The environment field is not specified") 124 | if not isinstance(self["environment"], EnvironmentConfig): 125 | raise ValueError("The environment field must be an EnvironmentConfig") 126 | 127 | 128 | # Initialize with different config class depending on whether the config is for environment or backend 129 | def init_config(config: dict): 130 | if not isinstance(config, dict): 131 | raise ValueError("The config must be a dict") 132 | 133 | # check if the config is for environment or backend 134 | if "env_type" in config: 135 | return EnvironmentConfig(config) 136 | elif "backend_type" in config: 137 | return BackendConfig(config) 138 | elif "role_desc" in config: 139 | return AgentConfig(config) 140 | elif "players" in config: 141 | return ArenaConfig(config) 142 | else: 143 | return Config(config) 144 | -------------------------------------------------------------------------------- /chatarena/agent.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | import re 3 | from tenacity import RetryError 4 | import logging 5 | import uuid 6 | from abc import abstractmethod 7 | import asyncio 8 | 9 | from .backends import IntelligenceBackend, load_backend 10 | from .message import Message, SYSTEM_NAME 11 | from .config import AgentConfig, Configurable, BackendConfig 12 | 13 | # A special signal sent by the player to indicate that it is not possible to continue the conversation, and it requests to end the conversation. 14 | # It contains a random UUID string to avoid being exploited by any of the players. 
15 | SIGNAL_END_OF_CONVERSATION = f"<<<<<<END_OF_CONVERSATION>>>>>>{uuid.uuid4()}" 16 | 17 | 18 | class Agent(Configurable): 19 | 20 | @abstractmethod 21 | def __init__(self, name: str, role_desc: str, global_prompt: str = None, *args, **kwargs): 22 | super().__init__(name=name, role_desc=role_desc, global_prompt=global_prompt, **kwargs) 23 | self.name = name 24 | self.role_desc = role_desc 25 | self.global_prompt = global_prompt 26 | 27 | 28 | class Player(Agent): 29 | """ 30 | Player of the game. It takes the observation from the environment and returns an action. 31 | """ 32 | 33 | def __init__(self, name: str, role_desc: str, backend: Union[BackendConfig, IntelligenceBackend], 34 | global_prompt: str = None, **kwargs): 35 | 36 | if isinstance(backend, BackendConfig): 37 | backend_config = backend 38 | backend = load_backend(backend_config) 39 | elif isinstance(backend, IntelligenceBackend): 40 | backend_config = backend.to_config() 41 | else: 42 | raise ValueError(f"backend must be a BackendConfig or an IntelligenceBackend, but got {type(backend)}") 43 | 44 | assert name != SYSTEM_NAME, f"Player name cannot be {SYSTEM_NAME}, which is reserved for the system." 45 | 46 | # Register the fields in the _config 47 | super().__init__(name=name, role_desc=role_desc, backend=backend_config, 48 | global_prompt=global_prompt, **kwargs) 49 | 50 | self.backend = backend 51 | 52 | def to_config(self) -> AgentConfig: 53 | return AgentConfig( 54 | name=self.name, 55 | role_desc=self.role_desc, 56 | backend=self.backend.to_config(), 57 | global_prompt=self.global_prompt, 58 | ) 59 | 60 | def act(self, observation: List[Message]) -> str: 61 | """ 62 | Call the agents to generate a response (equivalent to taking an action). 63 | """ 64 | try: 65 | response = self.backend.query(agent_name=self.name, role_desc=self.role_desc, 66 | history_messages=observation, global_prompt=self.global_prompt, 67 | request_msg=None) 68 | except RetryError as e: 69 | logging.warning(f"Agent {self.name} failed to generate a response. " 70 | f"Error: {e.last_attempt.exception()}. " 71 | f"Sending signal to end the conversation.") 72 | response = SIGNAL_END_OF_CONVERSATION 73 | 74 | return response 75 | 76 | def __call__(self, observation: List[Message]) -> str: 77 | return self.act(observation) 78 | 79 | async def async_act(self, observation: List[Message]) -> str: 80 | """ 81 | Async call the agents to generate a response (equivalent to taking an action). 82 | """ 83 | try: 84 | response = await self.backend.async_query(agent_name=self.name, role_desc=self.role_desc, 85 | history_messages=observation, global_prompt=self.global_prompt, 86 | request_msg=None) 87 | except RetryError as e: 88 | logging.warning(f"Agent {self.name} failed to generate a response. " 89 | f"Error: {e.last_attempt.exception()}. " 90 | f"Sending signal to end the conversation.") 91 | response = SIGNAL_END_OF_CONVERSATION 92 | 93 | return response 94 | 95 | def reset(self): 96 | self.backend.reset() 97 | 98 | 99 | class Moderator(Player): 100 | """ 101 | A special type of player that moderates the conversation (usually used as a component of the environment). 102 | """ 103 | 104 | def __init__(self, role_desc: str, backend: Union[BackendConfig, IntelligenceBackend], 105 | terminal_condition: str, global_prompt: str = None, **kwargs): 106 | name = "Moderator" 107 | super().__init__(name=name, role_desc=role_desc, backend=backend, global_prompt=global_prompt, **kwargs) 108 | 109 | self.terminal_condition = terminal_condition 110 | 111 | def to_config(self) -> AgentConfig: 112 | return AgentConfig( 113 | name=self.name, 114 | role_desc=self.role_desc, 115 | backend=self.backend.to_config(), 116 | terminal_condition=self.terminal_condition, 117 | global_prompt=self.global_prompt, 118 | ) 119 | 120 | def is_terminal(self, history: List[Message], *args, **kwargs) -> bool: 121 | """ 122 | check whether the conversation is over 123 | """ 124 | # If the last message is the signal, then the conversation is over 125 | if history[-1].content == SIGNAL_END_OF_CONVERSATION: 126 | return True 127 | 128 | try: 129 | request_msg = Message(agent_name=self.name, content=self.terminal_condition, turn=-1) 130 | response = self.backend.query(agent_name=self.name, role_desc=self.role_desc, history_messages=history, 131 | global_prompt=self.global_prompt, request_msg=request_msg, *args, **kwargs) 132 | except RetryError as e: 133 | logging.warning(f"Agent {self.name} failed to generate a response. " 134 | f"Error: {e.last_attempt.exception()}.") 135 | return True 136 | 137 | if re.match(r"yes|y|yea|yeah|yep|yup|sure|ok|okay|alright", response, re.IGNORECASE): 138 | # print(f"Decision: {response}. Conversation is ended by moderator.") 139 | return True 140 | else: 141 | return False 142 |
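A minimal sketch of how `Player` ties a role description to a backend (role text and backend settings are illustrative; running it requires a working backend such as the OpenAI one):

```python
from chatarena.agent import Player
from chatarena.config import BackendConfig
from chatarena.message import Message

player = Player(
    name="Seeker",
    role_desc="You are a movie enthusiast chatting with a recommender.",
    backend=BackendConfig(backend_type="openai-chat"),  # resolved internally via load_backend()
)

observation = [Message(agent_name="Recommender", content="Hi! Seen anything good lately?", turn=0)]
reply = player.act(observation)  # returns SIGNAL_END_OF_CONVERSATION if the backend keeps failing
```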
-------------------------------------------------------------------------------- /chatarena/backends/openai.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import os 3 | import re 4 | import logging 5 | from tenacity import retry, stop_after_attempt, wait_random, wait_random_exponential 6 | 7 | from .base import IntelligenceBackend 8 | from ..message import Message, SYSTEM_NAME, MODERATOR_NAME 9 | 10 | try: 11 | import openai 12 | except ImportError: 13 | is_openai_available = False 14 | logging.warning("openai package is not installed") 15 | else: 16 | openai.api_key = os.environ.get("OPENAI_API_KEY") 17 | if openai.api_key is None: 18 | logging.warning("OpenAI API key is not set. Please set the environment variable OPENAI_API_KEY") 19 | is_openai_available = False 20 | else: 21 | is_openai_available = True 22 | 23 | # Default config follows the OpenAI playground 24 | DEFAULT_TEMPERATURE = 0.7 25 | DEFAULT_MAX_TOKENS = 256 26 | DEFAULT_MODEL = "gpt-3.5-turbo" 27 | 28 | END_OF_MESSAGE = "<EOS>" # End of message token specified by us not OpenAI 29 | STOP = ("<|endoftext|>", END_OF_MESSAGE) # End of sentence token 30 | BASE_PROMPT = f"The messages always end with the token {END_OF_MESSAGE}."
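# Illustrative note: END_OF_MESSAGE is an in-band delimiter added by this backend, not an OpenAI token.
# With BASE_PROMPT in effect, a history message is serialized roughly as "[Alice]: See you tomorrow.<EOS>";
# since <EOS> is in STOP, the API stops generating once the model emits it, and query() strips any trailing <EOS>.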
31 | 32 | 33 | class OpenAIChat(IntelligenceBackend): 34 | """ 35 | Interface to the ChatGPT style model with system, user, assistant roles separation 36 | """ 37 | stateful = False 38 | type_name = "openai-chat" 39 | 40 | def __init__(self, temperature: float = DEFAULT_TEMPERATURE, max_tokens: int = DEFAULT_MAX_TOKENS, 41 | model: str = DEFAULT_MODEL, merge_other_agents_as_one_user: bool = True, **kwargs): 42 | """ 43 | instantiate the OpenAIChat backend 44 | args: 45 | temperature: the temperature of the sampling 46 | max_tokens: the maximum number of tokens to sample 47 | model: the model to use 48 | merge_other_agents_as_one_user: whether to merge messages from other agents as one user message 49 | """ 50 | assert is_openai_available, "openai package is not installed or the API key is not set" 51 | super().__init__(temperature=temperature, max_tokens=max_tokens, model=model, 52 | merge_other_agents_as_one_user=merge_other_agents_as_one_user, **kwargs) 53 | 54 | self.temperature = temperature 55 | self.max_tokens = max_tokens 56 | self.model = model 57 | self.merge_other_agent_as_user = merge_other_agents_as_one_user 58 | 59 | @retry(stop=stop_after_attempt(5), wait=wait_random_exponential(min=1, max=60)) # Modified retry strategy 60 | def _get_response(self, messages): 61 | completion = openai.ChatCompletion.create( 62 | model=self.model, 63 | messages=messages, 64 | temperature=self.temperature, 65 | max_tokens=self.max_tokens, 66 | stop=STOP 67 | ) 68 | 69 | response = completion.choices[0]['message']['content'] 70 | response = response.strip() 71 | return response 72 | 73 | def query(self, agent_name: str, role_desc: str, history_messages: List[Message], global_prompt: str = None, 74 | request_msg: Message = None, *args, **kwargs) -> str: 75 | """ 76 | format the input and call the ChatGPT/GPT-4 API 77 | args: 78 | agent_name: the name of the agent 79 | role_desc: the description of the role of the agent 80 | env_desc: the description of the environment 81 | history_messages: the history of the conversation, or the observation for the agent 82 | request_msg: the request from the system to guide the agent's next response 83 | """ 84 | 85 | # Merge the role description and the global prompt as the system prompt for the agent 86 | if global_prompt: # Prepend the global prompt if it exists 87 | system_prompt = f"{global_prompt.strip()}\n\nYour name: {agent_name}\n\nYour role: {role_desc}" 88 | else: 89 | system_prompt = f"You are {agent_name}.\n\nYour role: {role_desc}" 90 | 91 | all_messages = [(SYSTEM_NAME, system_prompt)] 92 | for msg in history_messages: 93 | if msg.agent_name == SYSTEM_NAME: 94 | all_messages.append((SYSTEM_NAME, msg.content)) 95 | else: # non-system messages are suffixed with the end of message token 96 | all_messages.append((msg.agent_name, f"{msg.content}{END_OF_MESSAGE}")) 97 | 98 | if request_msg is not None: 99 | all_messages.append((SYSTEM_NAME, request_msg.content)) 100 | else: # The default request message that reminds the agent its role and instruct it to speak 101 | all_messages.append((SYSTEM_NAME, f"Now you speak, {agent_name}.{END_OF_MESSAGE}")) 102 | 103 | messages = [] 104 | for i, msg in enumerate(all_messages): 105 | if i == 0: 106 | assert msg[0] == SYSTEM_NAME # The first message should be from the system 107 | messages.append({"role": "system", "content": msg[1]}) 108 | else: 109 | if msg[0] == agent_name: 110 | messages.append({"role": "assistant", "content": msg[1]}) 111 | else: 112 | if messages[-1]["role"] == "user": # last message is from 
user 113 | if self.merge_other_agent_as_user: 114 | messages[-1]["content"] = f"{messages[-1]['content']}\n\n[{msg[0]}]: {msg[1]}" 115 | else: 116 | messages.append({"role": "user", "content": f"[{msg[0]}]: {msg[1]}"}) 117 | elif messages[-1]["role"] == "assistant": # consecutive assistant messages 118 | # Merge the assistant messages 119 | messages[-1]["content"] = f"{messages[-1]['content']}\n{msg[1]}" 120 | elif messages[-1]["role"] == "system": 121 | messages.append({"role": "user", "content": f"[{msg[0]}]: {msg[1]}"}) 122 | else: 123 | raise ValueError(f"Invalid role: {messages[-1]['role']}") 124 | 125 | response = self._get_response(messages, *args, **kwargs) 126 | 127 | # Remove the agent name if the response starts with it 128 | response = re.sub(rf"^\s*\[.*]:", "", response).strip() 129 | response = re.sub(rf"^\s*{re.escape(agent_name)}\s*:", "", response).strip() 130 | # Remove the trailing end of message token 131 | response = re.sub(rf"{END_OF_MESSAGE}$", "", response).strip() 132 | 133 | return response 134 |
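To make the role mapping concrete: suppose Alice is the queried agent, the only history message is from Bob, `request_msg` is `None`, and `merge_other_agents_as_one_user=True` (names and texts illustrative). The loop above would build:

```python
messages = [
    {"role": "system", "content": "You are Alice.\n\nYour role: ..."},
    {"role": "user", "content": "[Bob]: Hi Alice!<EOS>\n\n[System]: Now you speak, Alice.<EOS>"},
]
```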
-------------------------------------------------------------------------------- /chatarena/environments/conversation.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | 3 | from .base import TimeStep, Environment 4 | from ..message import Message, MessagePool 5 | from ..agent import Moderator, SIGNAL_END_OF_CONVERSATION 6 | from ..config import EnvironmentConfig, AgentConfig 7 | 8 | 9 | class Conversation(Environment): 10 | """ 11 | Turn-based fully observable conversation environment. 12 | Next speaker order is either parallel or round-robin. 13 | """ 14 | type_name = "conversation" 15 | 16 | def __init__(self, player_names: List[str], parallel: bool = False, **kwargs): 17 | super().__init__(player_names=player_names, parallel=parallel, **kwargs) 18 | 19 | self.parallel = parallel 20 | 21 | # The "state" of the environment is maintained by the message pool 22 | self.message_pool = MessagePool() 23 | 24 | self._current_turn = 0 25 | self._next_player_idx = 0 26 | 27 | def reset(self): 28 | self._current_turn = 0 29 | self._next_player_idx = 0 30 | self.message_pool.reset() 31 | 32 | init_timestep = TimeStep(observation=[], 33 | reward=self.get_zero_rewards(), 34 | terminal=False) 35 | return init_timestep 36 | 37 | def to_config(self) -> EnvironmentConfig: 38 | return EnvironmentConfig(env_type=self.type_name, player_names=self.player_names, parallel=self.parallel) 39 | 40 | def print(self): 41 | self.message_pool.print() 42 | 43 | def get_next_player(self) -> str: 44 | """ 45 | get the next player 46 | """ 47 | return self.player_names[self._next_player_idx] 48 | 49 | def get_observation(self, player_name=None) -> List[Message]: 50 | """ 51 | get observation for the player 52 | """ 53 | if player_name is None: 54 | return self.message_pool.get_all_messages() 55 | else: 56 | return self.message_pool.get_visible_messages(player_name, turn=self._current_turn) 57 | 58 | def is_terminal(self) -> bool: 59 | """ 60 | check if the conversation is over 61 | """ 62 | # The conversation is over iff the last message is the end-of-conversation signal 63 | last_message = self.message_pool.last_message 64 | return last_message is not None and last_message.content == SIGNAL_END_OF_CONVERSATION 65 | 66 | def step(self, player_name: str, action: str) -> TimeStep: 67 | """ 68 | step function that is called by the arena 69 | Args: 70 | player_name: the name of the player that takes the action 71 | action: the action that the agent wants to take 72 | """ 73 | message = Message(agent_name=player_name, content=action, turn=self._current_turn) 74 | self.message_pool.append_message(message) 75 | 76 | # Update the counters 77 | if not self.parallel or self._next_player_idx == 0: 78 | self._current_turn += 1 79 | self._next_player_idx = (self._next_player_idx + 1) % self.num_players 80 | 81 | timestep = TimeStep(observation=self.get_observation(), 82 | reward=self.get_zero_rewards(), 83 | terminal=self.is_terminal()) # Return all the messages 84 | return timestep 85 | 86 | 87 | class ModeratedConversation(Conversation): 88 | """ 89 | Turn-based fully observable conversation environment. 90 | Next speaker order is either parallel or round-robin. 91 | Moderator is a special agent that can see all messages and can decide whether the conversation is over. 92 | """ 93 | 94 | type_name = "moderated_conversation" 95 | 96 | def __init__(self, player_names: List[str], moderator: Union[Moderator, AgentConfig], 97 | parallel: bool = False, moderator_visibility="all", moderator_period="turn", **kwargs): 98 | 99 | super().__init__(player_names=player_names, parallel=parallel, **kwargs) 100 | 101 | if isinstance(moderator, AgentConfig): 102 | moderator_config = moderator 103 | moderator = Moderator.from_config(moderator_config) 104 | elif not isinstance(moderator, Moderator): 105 | raise ValueError("moderator must be either an AgentConfig or a Moderator instance.") 106 | 107 | self.moderator = moderator 108 | self.moderator_visibility = moderator_visibility 109 | self.moderator_period = moderator_period 110 | 111 | def to_config(self) -> EnvironmentConfig: 112 | # This environment contains some special config arguments that need to be handled specially 113 | return EnvironmentConfig(env_type=self.type_name, player_names=self.player_names, parallel=self.parallel, 114 | moderator=self.moderator.to_config(), moderator_visibility=self.moderator_visibility, 115 | moderator_period=self.moderator_period) 116 | 117 | def step(self, player_name: str, action: str) -> TimeStep: 118 | """ 119 | step function that is called by the arena 120 | Args: 121 | player_name: the name of the player that takes the action 122 | action: the action that the agent wants to take 123 | """ 124 | message = Message(agent_name=player_name, content=action, turn=self._current_turn) 125 | self.message_pool.append_message(message) 126 | 127 | # Round-robin order for the next player 128 | self._next_player_idx = (self._next_player_idx + 1) % self.num_players 129 | 130 | if self.moderator_period == "turn" or \ 131 | (self.moderator_period == "round" and self._next_player_idx == 0): 132 | # Moderator's turn 133 | moderator_history = self.message_pool.get_all_messages() 134 | 135 | # Moderator's response is not used 136 | #moderator_response = self.moderator(moderator_history) 137 | #moderator_message = Message(agent_name=self.moderator.name, 138 | # content=moderator_response, 139 | # turn=self._current_turn, 140 | # visible_to=self.moderator_visibility) 141 | #self.message_pool.append_message(moderator_message) 142 | 143 | # We only use Moderator to determine whether the conversation should be ended 144 | terminal = self.moderator.is_terminal(moderator_history) or self.is_terminal() 145 | else: 146 | terminal = self.is_terminal() 147 | 148 | # Update the counters 149 | if not self.parallel or self._next_player_idx == 0: 150 | self._current_turn += 1 151 | 152 | timestep = TimeStep(observation=self.get_observation(), 153 | reward=self.get_zero_rewards(), 154 | terminal=terminal) # Return all the messages 155 |
return timestep 156 | -------------------------------------------------------------------------------- /chatarena/ui/cli.py: -------------------------------------------------------------------------------- 1 | from prompt_toolkit import prompt 2 | from prompt_toolkit.completion import WordCompleter 3 | from prompt_toolkit.styles import Style 4 | from rich.console import Console 5 | from rich.text import Text 6 | from rich.color import ANSI_COLOR_NAMES 7 | import random 8 | 9 | from ..arena import Arena, TooManyInvalidActions 10 | from ..backends.human import HumanBackendError 11 | from ..agent import SIGNAL_END_OF_CONVERSATION 12 | 13 | ASCII_ART = r""" 14 | _________ .__ __ _____ 15 | \_ ___ \ | |__ _____ _/ |_ / _ \ _______ ____ ____ _____ 16 | / \ \/ | | \ \__ \ \ __\ / /_\ \ \_ __ \W/ __ \ / \ \__ \ 17 | \ \____| Y \ / __ \_ | | / | \ | | \/\ ___/ | | \ / __ \_ 18 | \______ /|___| /(____ / |__| \____|__ / |__| \___ >|___| /(____ / 19 | \/ \/ \/ \/ \/ \/ \/ 20 | """ 21 | 22 | visible_colors = [color for color in ANSI_COLOR_NAMES.keys() if 23 | color not in ["black", "white", "red", "green"] and "grey" not in color] 24 | 25 | MAX_STEPS = 5 26 | 27 | import logging 28 | 29 | # Set logging level to ERROR 30 | logging.getLogger().setLevel(logging.ERROR) 31 | 32 | 33 | class ArenaCLI: 34 | """ 35 | The CLI user interface for ChatArena. 36 | """ 37 | 38 | def __init__(self, arena: Arena): 39 | self.arena = arena 40 | 41 | def launch(self, max_steps: int = None, interactive: bool = True, show_description: bool = True, show_message: bool = True): 42 | """ 43 | Run the CLI 44 | """ 45 | if not interactive and max_steps is None: 46 | max_steps = MAX_STEPS 47 | 48 | console = Console() 49 | # Print ascii art 50 | #console.print(ASCII_ART, style="bold dark_orange3") 51 | timestep = self.arena.reset() 52 | #console.print("🏟 Chat Arena Initialized!", style="bold green") 53 | 54 | env = self.arena.environment 55 | players = self.arena.players 56 | 57 | env_desc = self.arena.global_prompt 58 | num_players = env.num_players 59 | player_colors = random.sample(visible_colors, num_players) # sample different colors for players 60 | name_to_color = dict(zip(env.player_names, player_colors)) 61 | # System and Moderator messages are printed in red 62 | name_to_color["System"] = "red" 63 | name_to_color["Moderator"] = "red" 64 | 65 | # Print the player name, role_desc and backend_type 66 | if show_description: 67 | console.print(f"[bold green underline]Environment ({env.type_name}) description:[/]\n{env_desc}") 68 | 69 | for i, player in enumerate(players): 70 | player_name = Text(f"[{player.name} ({player.backend.type_name})] Role Description:") 71 | player_name.stylize(f"bold {name_to_color[player.name]} underline") 72 | console.print(player_name) 73 | console.print(player.role_desc) 74 | if show_message: 75 | console.print("\n========= Arena Start! 
==========\n", style="bold green") 76 | 77 | step = 0 78 | while not timestep.terminal: 79 | if interactive: 80 | command = prompt([('class:command', "command (n/r/q/s/h) > ")], 81 | style=Style.from_dict({'command': 'blue'}), 82 | completer=WordCompleter( 83 | ['next', 'n', 'reset', 'r', 'exit', 'quit', 'q', 'help', 'h', 'save', 's'])) 84 | command = command.strip() 85 | 86 | if command == "help" or command == "h": 87 | console.print("Available commands:") 88 | console.print(" [bold]next or n or [/]: next step") 89 | console.print(" [bold]exit or quit or q[/]: exit the game") 90 | console.print(" [bold]help or h[/]: print this message") 91 | console.print(" [bold]reset or r[/]: reset the game") 92 | console.print(" [bold]save or s[/]: save the history to file") 93 | continue 94 | elif command == "exit" or command == "quit" or command == "q": 95 | break 96 | elif command == "reset" or command == "r": 97 | timestep = self.arena.reset() 98 | console.print("\n========= Arena Reset! ==========\n", style="bold green") 99 | continue 100 | elif command == "next" or command == "n" or command == "": 101 | pass 102 | elif command == "save" or command == "s": 103 | # Prompt to get the file path 104 | file_path = prompt([('class:command', "save file path > ")], 105 | style=Style.from_dict({'command': 'blue'})) 106 | file_path = file_path.strip() 107 | # Save the history to file 108 | self.arena.save_history(file_path) 109 | # Print the save success message 110 | console.print(f"History saved to {file_path}", style="bold green") 111 | else: 112 | console.print(f"Invalid command: {command}", style="bold red") 113 | continue 114 | 115 | try: 116 | timestep = self.arena.step() 117 | except HumanBackendError as e: 118 | # Handle human input and recover with the game update 119 | human_player_name = env.get_next_player() 120 | if interactive: 121 | human_input = prompt( 122 | [('class:user_prompt', f"Type your input for {human_player_name}: ")], 123 | style=Style.from_dict({'user_prompt': 'ansicyan underline'}) 124 | ) 125 | # If not, the conversation does not stop 126 | timestep = env.step(human_player_name, human_input) 127 | else: 128 | raise e # cannot recover from this error in non-interactive mode 129 | except TooManyInvalidActions as e: 130 | # Print the error message 131 | console.print(f"Too many invalid actions: {e}", style="bold red") 132 | break 133 | 134 | # The messages that are not yet logged 135 | messages = [msg for msg in env.get_observation() if not msg.logged] 136 | # Print the new messages 137 | for msg in messages: 138 | message_text = Text(f"[{msg.agent_name}->{msg.visible_to}]: {msg.content}") 139 | message_text.stylize(f"bold {name_to_color[msg.agent_name]}", 0, 140 | len(f"[{msg.agent_name}->{msg.visible_to}]:")) 141 | if show_message: 142 | console.print(message_text) 143 | msg.logged = True 144 | 145 | step += 1 146 | if max_steps is not None and step >= max_steps: 147 | break 148 | 149 | if show_message: 150 | console.print("\n========= Arena Ended! 
==========\n", style="bold red") 151 | -------------------------------------------------------------------------------- /chatarena/arena.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Union 2 | import uuid 3 | import json 4 | import csv 5 | import logging 6 | 7 | from .agent import Player 8 | from .environments import Environment, TimeStep, load_environment 9 | from .backends import Human 10 | from .config import ArenaConfig 11 | 12 | 13 | class TooManyInvalidActions(Exception): 14 | pass 15 | 16 | 17 | class Arena: 18 | """ 19 | Utility class that manages the game environment and players 20 | """ 21 | 22 | def __init__(self, players: List[Player], environment: Environment, global_prompt: str = None): 23 | # Create a container for the players and environment and reset the game 24 | self.players = players 25 | self.environment = environment 26 | self.global_prompt = global_prompt 27 | 28 | self.current_timestep = environment.reset() 29 | self.uuid = uuid.uuid4() # Generate a unique id for the game 30 | self.invalid_actions_retry = 5 31 | 32 | @property 33 | def num_players(self): 34 | return self.environment.num_players 35 | 36 | @property 37 | def name_to_player(self) -> Dict[str, Player]: 38 | return {player.name: player for player in self.players} 39 | 40 | def reset(self) -> TimeStep: 41 | # Reset the environment 42 | self.current_timestep = self.environment.reset() 43 | # Reset the players 44 | for player in self.players: 45 | player.reset() 46 | # Reset the uuid 47 | self.uuid = uuid.uuid4() 48 | return self.current_timestep 49 | 50 | def step(self) -> TimeStep: 51 | """ 52 | Take a step in the game: one player takes an action and the environment updates 53 | """ 54 | player_name = self.environment.get_next_player() 55 | player = self.name_to_player[player_name] # get the player object 56 | observation = self.environment.get_observation(player_name) # get the observation for the player 57 | 58 | timestep = None 59 | for i in range(self.invalid_actions_retry): # try to take an action for a few times 60 | action = player(observation) # take an action 61 | if self.environment.check_action(action, player_name): # action is valid 62 | timestep = self.environment.step(player_name, action) # update the environment 63 | break 64 | else: # action is invalid 65 | logging.warning(f"{player_name} made an invalid action {action}") 66 | continue 67 | 68 | if timestep is None: # if the player made invalid actions for too many times, terminate the game 69 | warning_msg = f"{player_name} has made invalid actions for {self.invalid_actions_retry} times. Terminating the game." 
70 | logging.warning(warning_msg) 71 | raise TooManyInvalidActions(warning_msg) 72 | 73 | return timestep 74 | 75 | def next_is_human(self): 76 | """ 77 | check if the next player is human 78 | """ 79 | player_name = self.environment.get_next_player() 80 | player = self.name_to_player[player_name] 81 | return isinstance(player.backend, Human) 82 | 83 | def run(self, num_steps: int = 1): 84 | """ 85 | run the game for num_steps steps 86 | """ 87 | for i in range(num_steps): 88 | timestep = self.step() 89 | if timestep.terminal: 90 | break 91 | 92 | @classmethod 93 | def from_config(cls, config: Union[str, ArenaConfig]): 94 | """ 95 | create an arena from a config 96 | """ 97 | # If config is a path, load the config 98 | if isinstance(config, str): 99 | config = ArenaConfig.load(config) 100 | 101 | global_prompt = config.get("global_prompt", None) 102 | 103 | # Create the players 104 | players = [] 105 | for player_config in config.players: 106 | # Add the global prompt to the player config 107 | if global_prompt is not None: 108 | player_config["global_prompt"] = global_prompt 109 | 110 | player = Player.from_config(player_config) 111 | players.append(player) 112 | 113 | # Check that the player names are unique 114 | player_names = [player.name for player in players] 115 | assert len(player_names) == len(set(player_names)), "Player names must be unique" 116 | 117 | # Create the environment 118 | config.environment["player_names"] = player_names # add the player names to the environment config 119 | env = load_environment(config.environment) 120 | 121 | return cls(players, env, global_prompt=global_prompt) 122 | 123 | def to_config(self) -> ArenaConfig: 124 | """ 125 | convert the arena to a config 126 | """ 127 | # return { 128 | # "players": [player.to_config() for player in self.players], 129 | # "environment": self.environment.to_config(), 130 | # "global_prompt": self.global_prompt 131 | # } 132 | return ArenaConfig( 133 | players=[player.to_config() for player in self.players], 134 | environment=self.environment.to_config(), 135 | global_prompt=self.global_prompt 136 | ) 137 | 138 | def launch_cli(self, max_steps: int = None, interactive: bool = True, show_description: bool = True, show_message: bool = True): 139 | """ 140 | launch the command line interface 141 | """ 142 | from chatarena.ui.cli import ArenaCLI 143 | cli = ArenaCLI(self) 144 | cli.launch(max_steps=max_steps, interactive=interactive, show_description=show_description, show_message=show_message) 145 | 146 | def save_config(self, path: str): 147 | """ 148 | save the config to a file 149 | """ 150 | config = self.to_config() 151 | config.save(path) 152 | 153 | def save_history(self, path: str): 154 | """ 155 | save the history of the game to a file 156 | Supports csv and json formats.
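The output format is inferred from the file extension (.csv or .json).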
157 | """ 158 | messages = self.environment.get_observation() 159 | message_rows = [] 160 | 161 | if path.endswith(".csv"): 162 | header = ["agent_name", "content", "turn", "timestamp", "visible_to", "msg_type"] 163 | for message in messages: 164 | message_row = [ 165 | message.agent_name, 166 | message.content, 167 | message.turn, 168 | str(message.timestamp), 169 | message.visible_to, 170 | message.msg_type, 171 | ] 172 | message_rows.append(message_row) 173 | 174 | with open(path, "w") as f: 175 | writer = csv.writer(f) 176 | writer.writerow(header) 177 | writer.writerows(message_rows) 178 | elif path.endswith(".json"): 179 | for message in messages: 180 | message_row = { 181 | "agent_name": message.agent_name, 182 | "content": message.content, 183 | "turn": message.turn, 184 | "timestamp": str(message.timestamp), 185 | "visible_to": message.visible_to, 186 | "msg_type": message.msg_type, 187 | } 188 | message_rows.append(message_row) 189 | 190 | with open(path, "w") as f: 191 | json.dump(message_rows, f, indent=4) 192 | else: 193 | raise ValueError("Invalid file format") 194 | -------------------------------------------------------------------------------- /instruction.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import List 3 | 4 | 5 | def create_instruct( 6 | target: List[str], 7 | simulated_profile: dict, 8 | simulated_personality: dict, 9 | assistant_name: str, 10 | domain_knowledge: List[List], 11 | seed_conversation: dict 12 | ): 13 | """Create instructions about the conversation environment and roles.""" 14 | domain = "" 15 | target_action = target[0].lower() 16 | if "movie" in target_action: 17 | domain = "movie" 18 | elif "music" in target_action: 19 | domain = "music" 20 | elif "food" in target_action: 21 | domain = "food" 22 | elif "poi" in target_action: 23 | domain = "poi" 24 | else: 25 | raise ValueError("Invalid target action: {}".format(target_action)) 26 | 27 | # Describe the environment (shared by all roles) 28 | if domain == "movie" or domain == "music": 29 | env_desc = "You are participating in a conversation about music or movies." 30 | else: 31 | env_desc = "You are participating in a conversation about delicious food or point-of-interest (POI)." 
32 | 33 | # Describe the user 34 | user_desc = "You are {}, ".format(simulated_profile["Name"]) 35 | profile_desc = "" 36 | 37 | if simulated_profile["Occupation"] == "Student": 38 | if simulated_profile["Gender"] == "Male": 39 | profile_desc = "a male student in the age range of {}, living in {}".format(simulated_profile["Age Range"].lower(), simulated_profile["Residence"]) 40 | else: 41 | profile_desc = "a female student in the age range of {}, living in {}".format(simulated_profile["Age Range"].lower(), simulated_profile["Residence"]) 42 | elif simulated_profile["Occupation"] == "Employed": 43 | if simulated_profile["Gender"] == "Male": 44 | profile_desc = "a man in the age range of {}, working in a company and living in {}".format(simulated_profile["Age Range"].lower(), simulated_profile["Residence"]) 45 | else: 46 | profile_desc = "a woman in the age range of {}, working in a company and living in {}".format(simulated_profile["Age Range"].lower(), simulated_profile["Residence"]) 47 | else: 48 | if simulated_profile["Gender"] == "Male": 49 | profile_desc = "a retired man in the age range of {}, living in {}".format(simulated_profile["Age Range"].lower(), simulated_profile["Residence"]) 50 | else: 51 | profile_desc = "a retired woman in the age range of {}, living in {}".format(simulated_profile["Age Range"].lower(), simulated_profile["Residence"]) 52 | user_desc += profile_desc + ".\n\n" 53 | 54 | user_desc += "Based on your past experiences, you have the following preferences:\n" 55 | if domain == "movie" or domain == "music": 56 | for k in ["Accepted movies", "Accepted music", "Accepted celebrities", "Rejected movies", "Rejected music"]: 57 | kk = k.replace("Accepted", "liked").replace("Rejected", "disliked") 58 | user_desc += "Your {}: {}.\n".format(kk, simulated_profile[k]) if simulated_profile.get(k, "") != "" else "" 59 | else: 60 | for k in ["Accepted food", "Accepted POI"]: 61 | kk = k.replace("Accepted", "liked") 62 | user_desc += "Your {}: {}.\n".format(kk, simulated_profile[k]) if simulated_profile.get(k, "") != "" else "" 63 | user_desc += "\n" 64 | 65 | 66 | user_desc += "Based on the Big-5 personality traits, your personality is measured as:\n" 67 | for k, v in simulated_personality.items(): 68 | user_desc += "For {}, you are {}.\n".format(k, v) 69 | user_desc += "\n" 70 | 71 | user_desc += "Your response should match your profile and personality, and be concise (no longer than 30 words).\n" 72 | user_desc += "You don't need to recommend anything, but feel free to express your personal interests." 73 | 74 | gender_desc = "his" if simulated_profile["Gender"] == "Male" else "her" 75 | if domain == "movie" or domain == "music": 76 | for k in ["Accepted movies", "Accepted music", "Accepted celebrities", "Rejected movies", "Rejected music"]: 77 | kk = k.replace("Accepted", "liked").replace("Rejected", "disliked") 78 | profile_desc += "; {} {}: {}".format(gender_desc, kk, simulated_profile[k]) if simulated_profile.get(k, "") != "" else "" 79 | else: 80 | for k in ["Accepted food", "Accepted POI"]: 81 | kk = k.replace("Accepted", "liked") 82 | profile_desc += "; {} {}: {}".format(gender_desc, kk, simulated_profile[k]) if simulated_profile.get(k, "") != "" else "" 83 | profile_desc += "." 
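# Illustrative only (hypothetical slot values): at this point profile_desc reads like "a male student in the age range of 18-25, living in Beijing; his liked movies: Titanic". It is embedded into the assistant's role description below so the recommendations can be personalized.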
84 | 85 | user_dict = { 86 | "name": simulated_profile["Name"], 87 | "role_desc": user_desc, 88 | } 89 | 90 | # Describe the assistant 91 | if domain == "movie": 92 | assistant_desc = "You are {}, a movie enthusiast who enjoys a variety of films.\n".format(assistant_name) 93 | elif domain == "music": 94 | assistant_desc = "You are {}, a music enthusiast who enjoys a variety of music.\n".format(assistant_name) 95 | elif domain == "food": 96 | assistant_desc = "You are {}, a foodie who enjoys delicious food.\n".format(assistant_name) 97 | elif domain == "poi": 98 | assistant_desc = "You are {}, a food enthusiast who is interested in exploring different restaurants.\n".format(assistant_name) 99 | else: 100 | raise ValueError("Invalid domain: {}".format(domain)) 101 | 102 | assistant_desc += "You are conversing with {}, whose profile is below: \n## {}\n\n".format(simulated_profile["Name"], profile_desc) 103 | assistant_desc += "Your goal is to proactively lead the conversation with {} towards the target {} \"{}\".\n".format(simulated_profile["Name"], domain, target[1]) 104 | assistant_desc += "To start the conversation, please begin with a greeting and avoid mentioning the target {}.\n".format(domain) 105 | assistant_desc += "As the conversation progresses, use your domain knowledge to steer the discussed topic towards the target {} step by step.\n".format(domain) 106 | assistant_desc += "Be informative and engaging while providing insights to arouse {}'s interest.\n".format(simulated_profile["Name"]) 107 | assistant_desc += "Remember to ultimately recommend \"{}\" as the focus of the conversation.\n".format(target[1]) 108 | assistant_desc += "Your words at each turn should be concise (no longer than 30 words).\n\n" 109 | assistant_desc += "You may access the following domain knowledge for conversation: \n## {}.".format(domain_knowledge) 110 | 111 | assistant_dict = { 112 | "name": assistant_name, 113 | "role_desc": assistant_desc, 114 | } 115 | 116 | # Describe the moderator 117 | moderator_desc = "You are the moderator of a conversation. You need to determine whether the discussion between Role-S and Role-U should come to an immediate end.\n" 118 | moderator_desc += "The conversation should conclude under the following two conditions:\n" 119 | moderator_desc += "(1) If Role-S completes {} recommendation on \"{}\" and Role-U accepts it, and Role-S no longer takes the initiative for two rounds.\n".format(domain, target[1]) 120 | moderator_desc += "(2) If Role-U explicitly rejects Role-S's recommendation on \"{}\" when Role-S has tried to recommend it for the second time.\n".format(target[1]) 121 | moderator_desc += "In either of these cases, the conversation should be brought to an immediate end.\n\n" 122 | 123 | moderator_desc += "For example, here is a conversation:\n## {}".format(seed_conversation["seed_continue"]) 124 | moderator_desc += "Should the conversation end? The answer is no.\n\n" 125 | moderator_desc += "Here is another conversation:\n## {}".format(seed_conversation["seed_end"]) 126 | moderator_desc += "Should the conversation end? The answer is yes." 127 | 128 | 129 | terminal_condition = "Now, for the above discussion between {} (Role-S) and {} (Role-U), should the conversation end? 
Answer yes or no.".format(assistant_name, simulated_profile["Name"]) 130 | 131 | moderator_dict = { 132 | "role_desc": moderator_desc, 133 | "terminal_condition": terminal_condition 134 | } 135 | 136 | return (env_desc, user_dict, assistant_dict, moderator_dict) 137 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import re 3 | import random 4 | 5 | 6 | def find_word_in_string(w, s): 7 | return re.compile(r"\b({0})\b".format(w), flags=re.IGNORECASE).search(s) 8 | 9 | def normalize_profile(profile: dict, domain: str): 10 | """Nomalize profile based on specific domain""" 11 | norm_profile = {} 12 | for slot_k, slot_value in profile.items(): 13 | if slot_k == "Age Range": 14 | norm_profile[slot_k] = slot_value.replace("years old", "").strip() 15 | elif slot_k == "Accepted Music": # mismatched slot key in raw data 16 | if "Accepted music" in norm_profile.keys(): 17 | norm_profile["Accepted music"] += "; " + slot_value 18 | else: 19 | norm_profile["Accepted music"] = slot_value 20 | elif slot_k == "Accepted movie": # mismatched slot key in raw data 21 | if "Accepted movies" in norm_profile.keys(): 22 | norm_profile["Accepted movies"] += "; " + slot_value 23 | else: 24 | norm_profile["Accepted movies"] = slot_value 25 | else: 26 | if slot_k in norm_profile.keys(): 27 | norm_profile[slot_k] += "; " + slot_value 28 | else: 29 | norm_profile[slot_k] = slot_value 30 | 31 | for slot_k, slot_v in norm_profile.items(): 32 | if "Accepted" in slot_k or "Rejected" in slot_k: 33 | if len(slot_v.split("; ")) > 2: 34 | norm_profile[slot_k] = "; ".join(slot_v.split("; ")[:2]) 35 | 36 | # remove unnecessary slots for a specific domain 37 | assert domain in ["movie", "music", "food", "poi"] 38 | if "Accepted news" in norm_profile.keys(): 39 | norm_profile.pop("Accepted news") 40 | if "Favorite news" in norm_profile.keys(): 41 | norm_profile.pop("Favorite news") 42 | if "Reject" in norm_profile.keys(): 43 | norm_profile.pop("Reject") 44 | 45 | if domain == "food" or domain == "poi": 46 | if "Accepted movies" in norm_profile.keys(): 47 | norm_profile.pop("Accepted movies") 48 | if "Accepted music" in norm_profile.keys(): 49 | norm_profile.pop("Accepted music") 50 | if "Accepted celebrities" in norm_profile.keys(): 51 | norm_profile.pop("Accepted celebrities") 52 | if "Rejected movies" in norm_profile.keys(): 53 | norm_profile.pop("Rejected movies") 54 | if "Rejected music" in norm_profile.keys(): 55 | norm_profile.pop("Rejected music") 56 | else: 57 | if "Accepted food" in norm_profile.keys(): 58 | norm_profile.pop("Accepted food") 59 | if "Accepted POI" in norm_profile.keys(): 60 | norm_profile.pop("Accepted POI") 61 | 62 | return norm_profile 63 | 64 | 65 | def sample_profile(profile_slots, target_topic, domain): 66 | """Sample a profile different from raw_profile.""" 67 | sampled_profile = {} 68 | for slot_key, slot_values in profile_slots.items(): 69 | sampled_value = random.choice(slot_values) 70 | while sampled_value in target_topic or target_topic in sampled_value: 71 | sampled_value = random.choice(slot_values) 72 | sampled_profile[slot_key] = sampled_value 73 | # check age range and occupation 74 | if sampled_profile["Age Range"] == "Under 18": 75 | sampled_profile["Occupation"] = "Student" 76 | elif sampled_profile["Age Range"] == "18-25" or sampled_profile["Age Range"] == "26-35": 77 | sampled_profile["Occupation"] = random.choice(["Student", 
"Employed"]) 78 | elif sampled_profile["Age Range"] == "36-50": 79 | sampled_profile["Occupation"] = "Employed" 80 | else: 81 | sampled_profile["Occupation"] = random.choice(["Employed", "Retired"]) 82 | 83 | normed_profile = normalize_profile(sampled_profile, domain) 84 | return normed_profile 85 | 86 | 87 | def check_kg_exceed(kg_list, max_len): 88 | limit_len = max_len - len(kg_list) 89 | kg_str = " ".join([" ".join(kg) for kg in kg_list]) 90 | kg_tokens = kg_str.split(" ") 91 | if len(kg_tokens) > limit_len: 92 | return True 93 | else: 94 | return False 95 | 96 | def check_topic_covered(sampled_kg, topic_path): 97 | sampled_objs = set() 98 | for triple in sampled_kg: 99 | s, p, o = triple 100 | sampled_objs.add(s) 101 | sampled_objs.add(o) 102 | is_covered = True 103 | for t in topic_path: 104 | if t != "NULL" and t not in sampled_objs: 105 | is_covered = False 106 | break 107 | return is_covered 108 | 109 | def get_outer_kg(kg_list, sampled_kg, topic_path): 110 | topic_list = [] 111 | for t in topic_path: 112 | if t != "NULL": 113 | topic_list.append(t) 114 | 115 | sampled_objs = set() 116 | for triple in sampled_kg: 117 | s, p, o = triple 118 | sampled_objs.add(s) 119 | sampled_objs.add(o) 120 | 121 | tmp_kg = {} 122 | for t in topic_list: 123 | if t not in sampled_objs: 124 | for triple in kg_list: 125 | s, p, o = triple 126 | if s == t: 127 | if s in tmp_kg: 128 | tmp_kg[s].append(triple) 129 | else: 130 | tmp_kg[s] = [triple] 131 | elif o == t: 132 | if o in tmp_kg: 133 | tmp_kg[o].append(triple) 134 | else: 135 | tmp_kg[o] = [triple] 136 | outer_kg = [] 137 | for k, v_list in tmp_kg.items(): 138 | spo = random.sample(v_list, 1) 139 | outer_kg.append(spo[0]) 140 | 141 | return outer_kg 142 | 143 | def sample_knowledge(raw_kg_list, target, topic_path, user_utt="", bot_utt="", max_len=300): 144 | kg_list = [] 145 | for kg in raw_kg_list: 146 | s, p, o = kg 147 | if p == "Stars": 148 | if len(o.split()) <= 40: 149 | kg_list.append(kg) 150 | else: 151 | kg_list.append(kg) 152 | 153 | topic_trans = [] 154 | kg_topic_path = [] 155 | for t in topic_path: 156 | if t != "NULL": 157 | kg_topic_path.append(t) 158 | if len(kg_topic_path) > 1: 159 | for j in range(1, len(kg_topic_path)-1): 160 | topic_trans.append([kg_topic_path[j], kg_topic_path[j-1]]) 161 | 162 | sampled_kg = [] 163 | for kg in kg_list: 164 | s, p, o = kg 165 | 166 | if target[0] == "Food recommendation" and target[1] == "Marinated Fish" and p == "Specials" and o == "Marinated Fish": 167 | pass 168 | elif s == target[1] or o == target[1]: 169 | if not kg in sampled_kg: 170 | sampled_kg.append(kg) 171 | if "℃" in o and "℃" in bot_utt: 172 | if not kg in sampled_kg: 173 | sampled_kg.append(kg) 174 | if p == "Perfect for having" and (o.lower() in user_utt.lower() or o.lower() in bot_utt.lower()): 175 | if not kg in sampled_kg: 176 | sampled_kg.append(kg) 177 | if p.lower() in user_utt.lower() or p.lower() in bot_utt.lower(): 178 | if p == "Sings" and s == topic_path[0]: 179 | pass 180 | elif p == "Achievement" and s == topic_path[0]: 181 | if o.lower() in bot_utt.lower(): 182 | if not kg in sampled_kg: 183 | sampled_kg.append(kg) 184 | elif p == "Awards" and s == topic_path[0]: 185 | if o.lower() in bot_utt.lower(): 186 | if not kg in sampled_kg: 187 | sampled_kg.append(kg) 188 | else: 189 | if s == topic_path[0] or o == topic_path[0]: 190 | if not kg in sampled_kg: 191 | sampled_kg.append(kg) 192 | if s == topic_path[0]: 193 | if o.lower() in bot_utt.lower(): 194 | if not kg in sampled_kg: 195 | sampled_kg.append(kg) 196 | for 
tp in topic_trans: 197 | src, tgt = tp 198 | if (src == s and tgt in o) or (src in o and tgt == s): 199 | if not kg in sampled_kg: 200 | sampled_kg.append(kg) 201 | # check which topics are not yet covered by the sampled knowledge 202 | outer_kg = get_outer_kg(kg_list, sampled_kg, topic_path) 203 | if len(outer_kg) > 0: 204 | sampled_kg += outer_kg 205 | 206 | noised_kg = [] 207 | for kg in kg_list: 208 | if not kg in sampled_kg: 209 | noised_kg.append(kg) 210 | random.shuffle(noised_kg) 211 | 212 | num_sampling = 1 213 | tmp_kg = [] 214 | while True: 215 | if num_sampling > len(noised_kg): 216 | break 217 | tmp_kg = random.sample(noised_kg, num_sampling) 218 | check_kg = sampled_kg + tmp_kg 219 | if check_kg_exceed(check_kg, max_len=max_len): 220 | break 221 | num_sampling += 1 222 | sampled_kg += tmp_kg[:-1] 223 | random.shuffle(sampled_kg) 224 | 225 | return sampled_kg -------------------------------------------------------------------------------- /eval/eval_generation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import argparse 4 | import json 5 | import numpy as np 6 | from collections import Counter 7 | import nltk 8 | from nltk.translate import bleu_score 9 | from nltk.translate.bleu_score import SmoothingFunction 10 | 11 | 12 | def calc_bleu(hyps, refs): 13 | """ Calculate bleu score """ 14 | bleu_1 = [] 15 | bleu_2 = [] 16 | for hyp, ref in zip(hyps, refs): 17 | try: 18 | score = bleu_score.sentence_bleu( 19 | [ref], hyp, 20 | smoothing_function=SmoothingFunction().method1, 21 | weights=[1, 0, 0, 0]) 22 | except Exception: 23 | score = 0 24 | bleu_1.append(score) 25 | try: 26 | score = bleu_score.sentence_bleu( 27 | [ref], hyp, 28 | smoothing_function=SmoothingFunction().method1, 29 | weights=[0, 1, 0, 0]) 30 | except Exception: 31 | score = 0 32 | bleu_2.append(score) 33 | bleu_1 = np.average(bleu_1) 34 | bleu_2 = np.average(bleu_2) 35 | avg_bleu = (bleu_1 + bleu_2) / 2 36 | return bleu_1, bleu_2, avg_bleu 37 | 38 | 39 | def calc_knowledge_f1(hyps, knowledge_refs, knowledge_alls): 40 | """ Calculate knowledge f1 score """ 41 | golden_total = 0.0 42 | pred_total = 0.0 43 | hit_total = 0.0 44 | for response, golden_kd, all_kd in zip(hyps, knowledge_refs, knowledge_alls): 45 | golden_total += len(golden_kd) 46 | for kd in golden_kd: 47 | if is_obj_hit(response, kd): 48 | hit_total += 1 49 | for kd in all_kd: 50 | if is_obj_hit(response, kd): 51 | pred_total += 1 52 | p = hit_total / pred_total if pred_total > 0 else 0 53 | r = hit_total / golden_total if golden_total > 0 else 0 54 | f1 = 2 * p * r / (p + r) if p + r > 0 else 0 55 | return f1 56 | 57 | 58 | def calc_persona_f1(hyps, persona_refs, persona_alls): 59 | """Calculate persona f1 score""" 60 | golden_total = 0.0 61 | pred_total = 0.0 62 | hit_total = 0.0 63 | for response, golden_persona, all_persona in zip(hyps, persona_refs, persona_alls): 64 | golden_total += len(golden_persona) 65 | for persona in golden_persona: 66 | if is_obj_hit(response, persona, threshold=0.8): 67 | hit_total += 1 68 | for persona in all_persona: 69 | if is_obj_hit(response, persona, threshold=0.8): 70 | pred_total += 1 71 | p = hit_total / pred_total if pred_total > 0 else 0 72 | r = hit_total / golden_total if golden_total > 0 else 0 73 | f1 = 2 * p * r / (p + r) if p + r > 0 else 0 74 | return f1 75 | 76 | def calc_succ(eval_fp, gold_fp): 77 | all_eval, all_gold = [], [] 78 | with open(eval_fp, 'r', encoding='utf-8') as fr: 79 | for line in fr: 80 | sample = json.loads(line) 81 |
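# Each line of the eval file is expected to be one JSON object; only its "response" field is used below when scoring target success.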
all_eval.append(sample) 82 | with open(gold_fp, 'r', encoding='utf-8') as fr: 83 | for line in fr: 84 | raw_sample = json.loads(line) 85 | sample = { 86 | "id": raw_sample["id"], 87 | "target": raw_sample["target"], 88 | "response": raw_sample["response"] 89 | } 90 | all_gold.append(sample) 91 | assert len(all_eval) == len(all_gold) 92 | 93 | topic_hit, topic_total = 0, 0 94 | movie_hit, music_hit, poi_hit, food_hit = 0, 0, 0, 0 95 | movie_total, music_total, poi_total, food_total = 0, 0, 0, 0 96 | 97 | for idx, gold_sample in enumerate(all_gold): 98 | if gold_sample["target"][1].lower() in gold_sample["response"].lower(): 99 | topic_total += 1 100 | eval_action = gold_sample["target"][0] 101 | eval_topic = gold_sample["target"][1] 102 | 103 | # eval target turn and neighboring turns 104 | eval_list = get_eval_response(idx, all_eval, all_gold) 105 | 106 | eval_topic = " ".join(nltk.word_tokenize(eval_topic)) 107 | eval_list = [" ".join(nltk.word_tokenize(eval_response)) for eval_response in eval_list] 108 | 109 | if is_topic_hit(eval_topic, eval_list): 110 | topic_hit += 1 111 | 112 | if eval_action == "Movie recommendation": 113 | movie_total += 1 114 | if is_topic_hit(eval_topic, eval_list): 115 | movie_hit += 1 116 | elif eval_action == "Music recommendation" or eval_action == "Play music": 117 | music_total += 1 118 | if is_topic_hit(eval_topic, eval_list): 119 | music_hit += 1 120 | elif eval_action == "POI recommendation": 121 | poi_total += 1 122 | if is_topic_hit(eval_topic, eval_list): 123 | poi_hit += 1 124 | elif eval_action == "Food recommendation": 125 | food_total += 1 126 | if is_topic_hit(eval_topic, eval_list): 127 | food_hit += 1 128 | succ_rate = float(topic_hit) / topic_total 129 | movie_rec_sr = float(movie_hit) / movie_total 130 | music_rec_sr = float(music_hit) / music_total 131 | poi_rec_sr = float(poi_hit) / poi_total 132 | food_rec_sr = float(food_hit) / food_total 133 | print("Succ.: {:.2f}%".format(succ_rate*100)) 134 | print("Succ.-Movie: {}/{} = {:.2f}%".format(movie_hit, movie_total, movie_rec_sr*100)) 135 | print("Succ.-Music: {}/{} = {:.2f}%".format(music_hit, music_total, music_rec_sr*100)) 136 | print("Succ.-POI: {}/{} = {:.2f}%".format(poi_hit, poi_total, poi_rec_sr*100)) 137 | print("Succ.-Food: {}/{} = {:.2f}%".format(food_hit, food_total, food_rec_sr*100)) 138 | 139 | 140 | def get_eval_response(idx, eval_samples, gold_samples): 141 | eval_list = [eval_samples[idx]["response"]] 142 | dialog_id = gold_samples[idx]["id"] 143 | if idx - 1 >= 0 and gold_samples[idx-1]["id"] == dialog_id: 144 | eval_list.append(eval_samples[idx-1]["response"]) 145 | if idx + 1 < len(gold_samples) and gold_samples[idx+1]["id"] == dialog_id: 146 | eval_list.append(eval_samples[idx+1]["response"]) 147 | return eval_list 148 | 149 | def is_topic_hit(topic, candidates): 150 | for cand in candidates: 151 | if topic.lower() in cand.lower(): 152 | return True 153 | return False 154 | 155 | def is_obj_hit(utterance_toks, obj_str, threshold=0.55): 156 | utterance = " ".join(utterance_toks) 157 | flag = False 158 | if obj_str in utterance: 159 | flag = True 160 | else: 161 | # English word-level 162 | common = Counter(utterance.split()) & Counter(obj_str.split()) 163 | # knowledge recall 164 | hit_char_total = sum(common.values()) 165 | golden_char_total = len(obj_str) 166 | recall = hit_char_total / golden_char_total if golden_char_total > 0 else 0 167 | if recall >= threshold: 168 | flag = True 169 | return flag 170 | 171 | def label_knowledge(utterance_toks, kg_list, 
lower_case=True): 172 | gold_knowledge = [] 173 | all_objs = set() 174 | for triple in kg_list: 175 | assert len(triple) == 3 176 | all_objs.add(triple[0].lower() if lower_case else triple[0]) 177 | all_objs.add(triple[2].lower() if lower_case else triple[2]) 178 | for obj in all_objs: 179 | if is_obj_hit(utterance_toks, obj): 180 | gold_knowledge.append(obj) 181 | all_objs = list(all_objs) 182 | return all_objs, gold_knowledge 183 | 184 | def label_persona(utterance_toks, persona_dict, lower_case=True): 185 | all_personas = [] 186 | gold_persona = [] 187 | for k, v in persona_dict.items(): 188 | if v != '' and v != ' ': 189 | all_personas.append(v.lower() if lower_case else v) 190 | for persona in all_personas: 191 | if is_obj_hit(utterance_toks, persona, threshold=0.8): 192 | gold_persona.append(persona) 193 | return all_personas, gold_persona 194 | 195 | 196 | def load_data(fp, is_gold=False, lower_case=True): 197 | samples = [] 198 | all_knowledges, gold_knowledges = [], [] 199 | all_personas, gold_personas = [], [] 200 | with open(fp, 'r', encoding='utf-8') as fr: 201 | for idx, line in enumerate(fr): 202 | sample = json.loads(line) 203 | response = sample["response"].lower() if lower_case else sample["response"] 204 | # English word-level 205 | sentence_toks = nltk.word_tokenize(response) 206 | samples.append(sentence_toks) 207 | if is_gold: 208 | knowledge = sample["knowledge"] 209 | all_k, gold_k = label_knowledge(sentence_toks, knowledge, lower_case=lower_case) 210 | all_knowledges.append(all_k) 211 | gold_knowledges.append(gold_k) 212 | persona = sample["user_profile"] 213 | all_p, gold_p = label_persona(sentence_toks, persona, lower_case=lower_case) 214 | all_personas.append(all_p) 215 | gold_personas.append(gold_p) 216 | if is_gold: 217 | assert len(samples) == len(all_knowledges) and \ 218 | len(samples) == len(gold_knowledges) and \ 219 | len(samples) == len(all_personas) 220 | return (samples, all_knowledges, gold_knowledges, all_personas, gold_personas) 221 | else: 222 | return samples 223 | 224 | 225 | if __name__ == "__main__": 226 | parser = argparse.ArgumentParser() 227 | parser.add_argument("--eval_file", type=str) 228 | parser.add_argument("--gold_file", type=str) 229 | args = parser.parse_args() 230 | 231 | preds = load_data(args.eval_file) 232 | refs, all_knowledges, ref_knowledges, all_personas, ref_personas = load_data(args.gold_file, is_gold=True) 233 | assert len(preds) == len(refs) 234 | 235 | # calculate bleu 236 | bleu1, bleu2, avg_bleu = calc_bleu(preds, refs) 237 | 238 | # calculate knowledge-F1 239 | kg_f1 = calc_knowledge_f1(preds, ref_knowledges, all_knowledges) 240 | 241 | # calculate persona-F1 242 | persona_f1 = calc_persona_f1(preds, ref_personas, all_personas) 243 | 244 | output_str = "Avg.
BLEU: %.3f\n" % avg_bleu 245 | output_str += "Knowledge F1: %.2f%%\n" % (kg_f1 * 100) 246 | output_str += "Persona F1: %.2f%%" % (persona_f1 * 100) 247 | 248 | print(output_str) 249 | 250 | # calculate target success 251 | calc_succ(args.eval_file, args.gold_file) 252 | -------------------------------------------------------------------------------- /dialog_simulation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import time 3 | import json 4 | import os 5 | import random 6 | import argparse 7 | from tqdm import tqdm 8 | from chatarena.agent import Player, Moderator 9 | from chatarena.backends import OpenAIChat 10 | from chatarena.environments.conversation import ModeratedConversation 11 | from chatarena.arena import Arena 12 | from data_utils import find_word_in_string 13 | from instruction import create_instruct 14 | 15 | 16 | def parse_args(): 17 | parser = argparse.ArgumentParser() 18 | 19 | parser.add_argument("--cached_seed_path", type=str, required=True, 20 | help="The cached seed dialog file.") 21 | parser.add_argument("--profile_path", type=str, default="seed_dataset/caches/db_slot/slot_profiles.json", 22 | help="The user profile slot-values file.") 23 | parser.add_argument("--output_dir", type=str, default="data/TopDial", 24 | help="The output directory to save the simulated dialog data.") 25 | parser.add_argument("--max_interaction_step", type=int, default=12, 26 | help="The max number of interaction steps, i.e., 2 * max rounds.") 27 | parser.add_argument("--model_name", type=str, default="gpt-3.5-turbo", 28 | help="The chat model to use.") 29 | parser.add_argument("--temperature", type=float, default=0.75, 30 | help="The temperature to use in sampling.") 31 | parser.add_argument("--max_system_tokens", type=int, default=100, 32 | help="The max number of tokens to generate for the system.") 33 | parser.add_argument("--max_user_tokens", type=int, default=80, 34 | help="The max number of tokens to generate for the user.") 35 | parser.add_argument("--max_moderator_tokens", type=int, default=10, 36 | help="The max number of tokens to generate for the moderator.") 37 | parser.add_argument("--show_description", type=str2bool, default="true", 38 | help="Whether to show the role description.") 39 | parser.add_argument("--show_message", type=str2bool, default="true", 40 | help="Whether to show the conversation messages.") 41 | parser.add_argument("--random_seed", type=int, default=42) 42 | return parser.parse_args() 43 | 44 | def str2bool(v): 45 | if v.lower() in ('true', 'yes', 't', 'y', '1'): 46 | return True 47 | elif v.lower() in ('false', 'no', 'f', 'n', '0'): 48 | return False 49 | else: 50 | raise argparse.ArgumentTypeError("Unsupported value encountered.") 51 | 52 | def clean_utterance(s): 53 | s = s.strip() 54 | for start_str in ['[1]', '[2]', '[3]', '[4]', '[5]', '[6]', '[7]', '[8]', '[9]']: 55 | if s.startswith(start_str): 56 | s = s[len(start_str):].strip() 57 | return s 58 | 59 | def prompt_conversation(raw_goal, conversation): 60 | """Format the conversation context as a prompt string.""" 61 | conversation_ctx = "" 62 | for idx, utt in enumerate(conversation): 63 | utt = clean_utterance(utt) 64 | if "User Initiative" in raw_goal: 65 | if idx % 2 == 0: 66 | conversation_ctx += f"[Role-U]: {utt}\n\n" 67 | else: 68 | conversation_ctx += f"[Role-S]: {utt}\n\n" 69 | else: 70 | if idx % 2 == 0: 71 | conversation_ctx += f"[Role-S]: {utt}\n\n" 72 | else: 73 | conversation_ctx += f"[Role-U]: {utt}\n\n" 74 | return conversation_ctx 75
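# Illustrative output of prompt_conversation (hypothetical utterances): "[Role-S]: Hi! How was your day?\n\n[Role-U]: Pretty good, thanks!\n\n" -- Role-S denotes the system/recommender side and Role-U the simulated user, matching the moderator instructions in instruction.py.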
| 76 | def sample_seed_conversation(raw_goal, conversation): 77 | """Sample seed conversations (continue | end).""" 78 | conv_lens = len(conversation) 79 | continue_len = random.choice(range(1, int(conv_lens * 0.6))) 80 | conv_continue = prompt_conversation(raw_goal, conversation[:continue_len]) 81 | conv_end = prompt_conversation(raw_goal, conversation) 82 | seed_conv = { 83 | "seed_continue": conv_continue, 84 | "seed_end": conv_end 85 | } 86 | return seed_conv 87 | 88 | def sample_assistant_role(profile_slots, user_profile): 89 | """Sample an assistant role.""" 90 | all_names = profile_slots["Name"] 91 | user_name = user_profile["Name"] 92 | sampled_name = random.choice(all_names) 93 | while find_word_in_string(sampled_name, user_name): 94 | sampled_name = random.choice(all_names) 95 | return sampled_name 96 | 97 | def sample_personality(): 98 | """Sample a personality based on Big Five personality traits.""" 99 | personalities = { 100 | "agreeableness": ["trustworthy, straightforward, and generous", "unreliable, complicated, meager, and boastful"], 101 | "conscientiousness": ["efficient, organized, and careful", "inefficient, careless, and sloppy"], 102 | "extraversion": ["outgoing, energetic, and talkative", "shy, reserved, and quiet"], 103 | "neuroticism": ["sensitive, nervous, and insecure", "secure, confident, and calm"], 104 | "openness": ["intellectual, imaginative, and curious", "unimaginative, uncreative, and conventional"] 105 | } 106 | sampled_personality = {} 107 | for trait, values in personalities.items(): 108 | sampled_personality[trait] = random.choice(values) 109 | return sampled_personality 110 | 111 | 112 | def generate_dialog_data( 113 | profile_path, 114 | seed_path, 115 | output_dir, 116 | max_interaction_step=10, 117 | model_name="gpt-3.5-turbo", 118 | temperature=0.75, 119 | max_system_tokens=100, 120 | max_user_tokens=80, 121 | max_moderator_tokens=10, 122 | show_description=True, 123 | show_message=True, 124 | ): 125 | """Generate dialog data from a seed dialog file.""" 126 | profile_slots = json.load(open(profile_path, "r", encoding='utf-8')) 127 | print(f"Loaded user profiles with {len(profile_slots)} slot keys.") 128 | 129 | seed_dialogs = [] 130 | with open(seed_path, "r", encoding='utf-8') as f: 131 | for line in f: 132 | seed_dialogs.append(json.loads(line)) 133 | print(f"Loaded {len(seed_dialogs)} cached dialogs.") 134 | 135 | if not os.path.exists(output_dir): 136 | os.makedirs(output_dir) 137 | if "test_seen" in seed_path: 138 | output_path = os.path.join(output_dir, "dialogue_test_seen.jsonl") 139 | elif "test_unseen" in seed_path: 140 | output_path = os.path.join(output_dir, "dialogue_test_unseen.jsonl") 141 | elif "dev" in seed_path: 142 | output_path = os.path.join(output_dir, "dialogue_dev.jsonl") 143 | else: 144 | output_path = os.path.join(output_dir, "dialogue_train.jsonl") 145 | 146 | with open(output_path, "w", encoding='utf-8') as fw: 147 | for seed_dialog in tqdm(seed_dialogs): 148 | simulated_profile = seed_dialog["user_profile"] 149 | sampled_knowledge = seed_dialog["knowledge"] 150 | target = seed_dialog["target"] 151 | 152 | conversation = seed_dialog["seed_conversation"] 153 | seed_conv = sample_seed_conversation(seed_dialog["original_goal"], conversation) 154 | 155 | # randomly sample a personality 156 | simulated_personality = sample_personality() 157 | assistant_name = sample_assistant_role(profile_slots, simulated_profile) 158 | 159 | env_desc, user_dict, assistant_dict, moderator_dict = create_instruct( 160 | target=target, 161 | 
simulated_profile=simulated_profile, 162 | simulated_personality=simulated_personality, 163 | assistant_name=assistant_name, 164 | domain_knowledge=sampled_knowledge, 165 | seed_conversation=seed_conv 166 | ) 167 | assistant = Player( 168 | name=assistant_dict["name"], backend=OpenAIChat(model=model_name, temperature=temperature, max_tokens=max_system_tokens), 169 | role_desc=assistant_dict["role_desc"], global_prompt=env_desc 170 | ) 171 | user = Player( 172 | name=user_dict["name"], backend=OpenAIChat(model=model_name, temperature=temperature, max_tokens=max_user_tokens), 173 | role_desc=user_dict["role_desc"], global_prompt=env_desc 174 | ) 175 | moderator = Moderator( 176 | backend=OpenAIChat(model=model_name, temperature=temperature, max_tokens=max_moderator_tokens), 177 | role_desc=moderator_dict["role_desc"], terminal_condition=moderator_dict["terminal_condition"] 178 | ) 179 | # let assistant start the conversation 180 | env = ModeratedConversation(player_names=[p.name for p in [assistant, user]], moderator=moderator, moderator_period="round") 181 | arena = Arena(players=[assistant, user], environment=env, global_prompt=env_desc) 182 | 183 | arena.launch_cli(max_steps=max_interaction_step, show_description=show_description, show_message=show_message, interactive=False) 184 | 185 | #print("Save? (y/n)") 186 | #if input() == "n": 187 | # continue 188 | 189 | # save the simulated dialog to file 190 | messages = env.get_observation() 191 | simulated_convs = [] 192 | for msg in messages: 193 | if msg.agent_name == assistant.name: 194 | utt = {"system": msg.content} 195 | else: 196 | utt = {"user": msg.content} 197 | simulated_convs.append(utt) 198 | 199 | write_line = { 200 | "id": "s_" + str(seed_dialog["id"]), 201 | "user_profile": simulated_profile, 202 | "user_personality": simulated_personality, 203 | "knowledge": sampled_knowledge, 204 | "target": target, 205 | "conversation": simulated_convs 206 | } 207 | fw.write(json.dumps(write_line, ensure_ascii=False) + "\n") 208 | fw.flush() 209 | 210 | print("Sleeping for 5 seconds...") 211 | time.sleep(5) 212 | 213 | #print("Continue? (y/n)") 214 | #if input() == "n": 215 | # break 216 | 217 | 218 | if __name__ == '__main__': 219 | args = parse_args() 220 | random.seed(args.random_seed) 221 | 222 | generate_dialog_data(args.profile_path, args.cached_seed_path, args.output_dir, 223 | max_interaction_step=args.max_interaction_step, 224 | model_name=args.model_name, 225 | temperature=args.temperature, 226 | max_system_tokens=args.max_system_tokens, 227 | max_user_tokens=args.max_user_tokens, 228 | max_moderator_tokens=args.max_moderator_tokens, 229 | show_description=args.show_description, 230 | show_message=args.show_message) 231 | -------------------------------------------------------------------------------- /data_preprocess.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import json 3 | import os 4 | import random 5 | import argparse 6 | from tqdm import tqdm 7 | from py2neo import Graph 8 | from data_utils import normalize_profile, sample_profile, sample_knowledge 9 | 10 | 11 | def parse_args(): 12 | parser = argparse.ArgumentParser() 13 | 14 | parser.add_argument( 15 | "--seed_dataset_dir", 16 | type=str, 17 | default="seed_dataset/DuRecDial2", 18 | help="The seed dataset directory." 19 | ) 20 | parser.add_argument( 21 | "--cache_dir", 22 | type=str, 23 | default="seed_dataset/caches", 24 | help="The cached data directory." 
25 | ) 26 | parser.add_argument( 27 | "--num_instance_per_seed", 28 | type=int, 29 | default=3, 30 | help="The number of instances to curate for each seed dialog.", 31 | ) 32 | parser.add_argument( 33 | "--random_seed", 34 | type=int, 35 | default=42, 36 | ) 37 | return parser.parse_args() 38 | 39 | 40 | def extract_profile(data_fp_list, save_fp=None): 41 | """Extract all user profile slots from the given data file.""" 42 | 43 | SLOT_KEYS = [ 44 | "Age Range", "Name", "Gender", "Residence", "Occupation", "POI", 45 | "Accepted movies", "Accepted music", "Accepted celebrities", "Accepted food", "Accepted POI", 46 | "Reject", "Rejected movies", "Rejected music" 47 | ] 48 | ALL_SLOTS = dict() 49 | for k in SLOT_KEYS: 50 | ALL_SLOTS[k] = set() 51 | 52 | for data_fp in data_fp_list: 53 | with open(data_fp, 'r', encoding='utf-8') as fp: 54 | for line in fp: 55 | sample = json.loads(line.strip()) 56 | for slot in sample['user_profile']: 57 | if slot in SLOT_KEYS: 58 | slot_value = list(sample['user_profile'][slot].split("; ")) 59 | for v in slot_value: 60 | if slot == "Age Range": 61 | v = v.replace("years old", "").strip() 62 | ALL_SLOTS[slot].add(v) 63 | elif slot == "Accepted Music": 64 | slot_value = list(sample['user_profile'][slot].split("; ")) 65 | for v in slot_value: 66 | ALL_SLOTS["Accepted music"].add(v) 67 | elif slot == "Accepted movie": 68 | slot_value = list(sample['user_profile'][slot].split("; ")) 69 | for v in slot_value: 70 | ALL_SLOTS["Accepted movies"].add(v) 71 | else: 72 | print("Out of slot keys: ", slot) 73 | for k in ALL_SLOTS: 74 | ALL_SLOTS[k] = list(ALL_SLOTS[k]) 75 | print(k, len(ALL_SLOTS[k])) 76 | if save_fp is not None: 77 | with open(save_fp, 'w', encoding='utf-8') as fp: 78 | json.dump(ALL_SLOTS, fp, indent=4, ensure_ascii=False) 79 | print("Saved to {}".format(save_fp)) 80 | 81 | 82 | def exe_query(graph: Graph, query: str): 83 | triple_dict = {} 84 | results = graph.run(query).data() 85 | for res in results: 86 | s = "{}".format(res['s.value']) 87 | r = "{}".format(res['type(r)']) 88 | o = "{}".format(res['o.value']) 89 | kk = "{}__REL__{}".format(s, r) 90 | if kk in triple_dict.keys(): 91 | triple_dict[kk].append(o) 92 | else: 93 | triple_dict[kk] = [o] 94 | triples = [] 95 | for kk, vv in triple_dict.items(): 96 | s, r = kk.split("__REL__") 97 | o = random.choice(vv) 98 | triples.append([s, r, o]) 99 | 100 | return triples 101 | 102 | def ground_knowledge(graph, data_fp_list, profile_fp, save_dir, num_instance_per_seed=3): 103 | """Ground seed dialogs with domain knowledge and comments.""" 104 | 105 | profile_slots = json.load(open(profile_fp, "r", encoding='utf-8')) 106 | print(f"Loaded user profiles with {len(profile_slots)} slot keys.") 107 | 108 | for data_fp in data_fp_list: 109 | seed_dialogs = [] 110 | with open(data_fp, "r", encoding='utf-8') as f: 111 | for line in f: 112 | seed_dialogs.append(json.loads(line)) 113 | print(f"Loaded {len(seed_dialogs)} seed dialogs from {data_fp}.") 114 | 115 | save_fp = os.path.join(save_dir, "cache_{}".format(data_fp.split("/")[-1])) 116 | with open(save_fp, "w", encoding='utf-8') as fw: 117 | for seed_dialog in tqdm(seed_dialogs): 118 | user_profile = seed_dialog["user_profile"] 119 | knowledge = seed_dialog["knowledge_graph"] 120 | target = seed_dialog["target"] 121 | 122 | domain = "" 123 | target_action = target[0].lower() 124 | if "movie" in target_action: 125 | domain = "movie" 126 | elif "music" in target_action: 127 | domain = "music" 128 | elif "food" in target_action: 129 | domain = "food" 130 | elif "poi" 
in target_action: 131 | domain = "poi" 132 | else: 133 | raise ValueError("Invalid target action: {}".format(target_action)) 134 | 135 | for idx in range(num_instance_per_seed): 136 | if idx == 0: 137 | # adopt raw user profile 138 | simulated_profile = normalize_profile(user_profile, domain) 139 | else: 140 | # sample a profile different from raw user profile 141 | simulated_profile = sample_profile(profile_slots, target_topic=target[1], domain=domain) 142 | 143 | sampled_knowledge = sample_knowledge(knowledge, target, topic_path=seed_dialog["topic_path"], max_len=300) 144 | 145 | # sample comment about the target topic 146 | query_t = 'MATCH (s)-[r]->(o) WHERE s.value="{}" AND type(r)="{}" RETURN s.value, type(r), o.value'.format(target[1], "Comments") 147 | target_comments = exe_query(graph, query_t) 148 | if len(target_comments) > 0: 149 | target_comment = random.choice(target_comments) 150 | sampled_knowledge.append(target_comment) 151 | 152 | profile_knowledge = [] 153 | for slot_key, slot_value in simulated_profile.items(): 154 | if "movies" in slot_key or "music" in slot_key: 155 | # sample domain knowledge about movies/music 156 | entities = slot_value.split("; ") 157 | for ent in entities: 158 | query_t = 'MATCH (s)-[r]->(o) WHERE s.value="{}" AND (type(r)="{}" OR type(r)="{}" OR type(r)="{}" OR type(r)="{}") RETURN s.value, type(r), o.value'.format( 159 | ent, "Stars", "Sings", "Type", "Comments") 160 | triples = exe_query(graph, query_t) 161 | if len(triples) > 0: 162 | ss_triples = random.choice(triples) 163 | profile_knowledge.append(ss_triples) 164 | elif "celebrities" in slot_key: 165 | # sample domain knowledge about celebrities 166 | entities = slot_value.split("; ") 167 | for ent in entities: 168 | query_t = 'MATCH (s)-[r]->(o) WHERE s.value="{}" AND (type(r)="{}" OR type(r)="{}" OR type(r)="{}") RETURN s.value, type(r), o.value'.format( 169 | ent, "Intro", "Achievement", "Comments") 170 | triples = exe_query(graph, query_t) 171 | if len(triples) > 0: 172 | ss_triples = random.choice(triples) 173 | profile_knowledge.append(ss_triples) 174 | elif "food" in slot_key or "Accepted POI" in slot_key: 175 | # sample domain knowledge about food/POI 176 | entities = slot_value.split("; ") 177 | for ent in entities: 178 | query_t = 'MATCH (s)-[r]->(o) WHERE s.value="{}" AND (type(r)="{}" OR type(r)="{}" OR type(r)="{}" OR type(r)="{}") RETURN s.value, type(r), o.value'.format( 179 | ent, "Price per person", "Rating", "Address", "Comments") 180 | triples = exe_query(graph, query_t) 181 | if len(triples) > 0: 182 | ss_triples = random.choice(triples) 183 | profile_knowledge.append(ss_triples) 184 | knowledge_str_list = ["__SEP__".join(triple) for triple in sampled_knowledge] 185 | for triple in profile_knowledge: 186 | triple_str = "__SEP__".join(triple) 187 | if triple_str not in knowledge_str_list: 188 | sampled_knowledge.append(triple) 189 | 190 | new_dialog = { 191 | "id": str(seed_dialog["id"]) + "_{}".format(idx), 192 | "original_goal": seed_dialog["original_goal"], 193 | "user_profile": simulated_profile, 194 | "knowledge": sampled_knowledge, 195 | "target": target, 196 | "seed_conversation": seed_dialog["conversation"], 197 | "seed_action_path": seed_dialog["action_path"], 198 | "seed_topic_path": seed_dialog["topic_path"], 199 | } 200 | line = json.dumps(new_dialog, ensure_ascii=False) 201 | fw.write(line + "\n") 202 | fw.flush() 203 | print("Saved {} simulated dialogs to {}.".format(num_instance_per_seed * len(seed_dialogs), save_fp)) 204 | 205 | 206 | if __name__ == 
"__main__": 207 | args = parse_args() 208 | random.seed(args.random_seed) 209 | 210 | train_fp = os.path.join(args.data_dir, "seed_dialogue_train.jsonl") 211 | dev_fp = os.path.join(args.data_dir, "seed_dialogue_dev.jsonl") 212 | test_seen_fp = os.path.join(args.data_dir,"seed_dialogue_test_seen.jsonl") 213 | test_unseen_fp = os.path.join(args.data_dir, "seed_dialogue_test_unseen.jsonl") 214 | 215 | if not os.path.exists(args.cache_dir): 216 | os.makedirs(args.cache_dir) 217 | 218 | # prepare user profile slots 219 | saved_dir = os.path.join(args.cache_dir, "db_slot") 220 | if not os.path.exists(saved_dir): 221 | os.makedirs(saved_dir) 222 | 223 | saved_profile_fp = os.path.join(saved_dir, "slot_profiles.json") 224 | if not os.path.exists(saved_profile_fp): 225 | print("Extracting user profile slot-values...") 226 | extract_profile(data_fp_list=[train_fp, dev_fp, test_seen_fp, test_unseen_fp], save_fp=saved_profile_fp) 227 | else: 228 | print("File exists: {}".format(saved_profile_fp)) 229 | 230 | # prepare domain knowledge and topic-related comments 231 | # set neo4j database connection (username: neo4j, password: neo4j) 232 | graph = Graph("http://localhost:7474", auth=("neo4j", "neo4j")) 233 | ground_knowledge(graph, data_fp_list=[train_fp, dev_fp, test_seen_fp, test_unseen_fp], 234 | profile_fp=saved_profile_fp, save_dir=args.cache_dir, 235 | num_instance_per_seed=args.num_instance_per_seed) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 | 
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 | 
7 | 1. Definitions.
8 | 
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 | 
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 | 
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 | 
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 | 
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 | 
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 | 
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 | 
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 | 
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 | 
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 | 
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 | 
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 | 
89 | 4. Redistribution.
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | --------------------------------------------------------------------------------