├── chatarena
│   ├── __init__.py
│   ├── environments
│   │   ├── __init__.py
│   │   ├── base.py
│   │   └── conversation.py
│   ├── backends
│   │   ├── human.py
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── hf_transformers.py
│   │   ├── anthropic.py
│   │   ├── cohere.py
│   │   └── openai.py
│   ├── utils.py
│   ├── message.py
│   ├── database.py
│   ├── config.py
│   ├── agent.py
│   ├── ui
│   │   └── cli.py
│   └── arena.py
├── imgs
│   ├── demo.gif
│   └── framework.png
├── dataset
│   └── Readme.txt
├── seed_dataset
│   └── Readme.txt
├── requirements.txt
├── README.md
├── .gitignore
├── instruction.py
├── data_utils.py
├── eval
│   └── eval_generation.py
├── dialog_simulation.py
├── data_preprocess.py
└── LICENSE
/chatarena/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/imgs/demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iwangjian/TopDial/HEAD/imgs/demo.gif
--------------------------------------------------------------------------------
/imgs/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iwangjian/TopDial/HEAD/imgs/framework.png
--------------------------------------------------------------------------------
/dataset/Readme.txt:
--------------------------------------------------------------------------------
1 | Please follow the README.md to download the TopDial dataset and unzip it into this folder.
--------------------------------------------------------------------------------
/seed_dataset/Readme.txt:
--------------------------------------------------------------------------------
1 | Please follow the README.md to download the seed dataset and unzip it to this folder.
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | openai==0.27.2
2 | anthropic==0.2.8
3 | cohere==4.3.1
4 | transformers>=4.27.4
5 | tenacity==8.2.2
6 | rich==13.3.3
7 | prompt_toolkit
8 | py2neo
9 | tqdm
10 |
--------------------------------------------------------------------------------
/chatarena/environments/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import Environment, TimeStep
2 | from .conversation import Conversation, ModeratedConversation
3 | from ..config import EnvironmentConfig
4 |
5 | ALL_ENVIRONMENTS = [
6 | Conversation,
7 | ModeratedConversation,
8 | ]
9 |
10 | ENV_REGISTRY = {env.type_name: env for env in ALL_ENVIRONMENTS}
11 |
12 |
13 | # Load an environment from a config dictionary
14 | def load_environment(config: EnvironmentConfig):
15 | try:
16 | env_cls = ENV_REGISTRY[config["env_type"]]
17 | except KeyError:
18 | raise ValueError(f"Unknown environment type: {config['env_type']}")
19 |
20 | env = env_cls.from_config(config)
21 | return env
22 |
--------------------------------------------------------------------------------
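As a brief illustration of the registry above, here is a minimal usage sketch (not part of the repository; the player names are placeholders) that loads a `Conversation` environment from a plain config:

```python
# Minimal sketch: load_environment dispatches on the "env_type" field.
from chatarena.config import EnvironmentConfig
from chatarena.environments import load_environment

config = EnvironmentConfig(env_type="conversation", player_names=["Alice", "Bob"])
env = load_environment(config)
print(env.type_name)    # "conversation"
print(env.num_players)  # 2
```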
/chatarena/backends/human.py:
--------------------------------------------------------------------------------
1 | from .base import IntelligenceBackend
2 | from ..config import BackendConfig
3 |
4 |
5 | # An Error class for the human backend
6 | class HumanBackendError(Exception):
7 | def __init__(self, agent_name: str):
8 | self.agent_name = agent_name
9 | super().__init__(f"Human backend requires a UI to get input from {agent_name}.")
10 |
11 |
12 | class Human(IntelligenceBackend):
13 | stateful = False
14 | type_name = "human"
15 |
16 | def __init__(self, **kwargs):
17 | super().__init__(**kwargs)
18 |
19 | def to_config(self) -> BackendConfig:
20 | return BackendConfig(backend_type=self.type_name)
21 |
22 | def query(self, agent_name: str, **kwargs) -> str:
23 | raise HumanBackendError(agent_name)
24 |
--------------------------------------------------------------------------------
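A minimal sketch (not repository code) of how a UI layer can intercept `HumanBackendError` to collect input from a real person; the plain `input()` call stands in for whatever prompt the UI actually uses:

```python
from chatarena.backends.human import Human, HumanBackendError

backend = Human()
try:
    backend.query(agent_name="Alice")
except HumanBackendError as e:
    # A UI (e.g. chatarena/ui/cli.py) would prompt the human player here
    action = input(f"[{e.agent_name}] your turn: ")
```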
/chatarena/backends/__init__.py:
--------------------------------------------------------------------------------
1 | from ..config import BackendConfig
2 |
3 | from .base import IntelligenceBackend
4 | from .openai import OpenAIChat
5 | from .cohere import CohereAIChat
6 | from .human import Human
7 | from .hf_transformers import TransformersConversational
8 | from .anthropic import Claude
9 |
10 | ALL_BACKENDS = [
11 | Human,
12 | OpenAIChat,
13 | CohereAIChat,
14 | TransformersConversational,
15 | Claude,
16 | ]
17 |
18 | BACKEND_REGISTRY = {backend.type_name: backend for backend in ALL_BACKENDS}
19 |
20 |
21 | # Load a backend from a config dictionary
22 | def load_backend(config: BackendConfig):
23 | try:
24 | backend_cls = BACKEND_REGISTRY[config.backend_type]
25 | except KeyError:
26 | raise ValueError(f"Unknown backend type: {config.backend_type}")
27 |
28 | backend = backend_cls.from_config(config)
29 | return backend
30 |
--------------------------------------------------------------------------------
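A short usage sketch (an assumption, not repository code) of the registry above; the `human` backend is used because it requires no API key:

```python
from chatarena.backends import BACKEND_REGISTRY, load_backend
from chatarena.config import BackendConfig

# Keys come from each backend's type_name attribute
print(sorted(BACKEND_REGISTRY))  # ['claude', 'cohere-chat', 'human', 'openai-chat', 'transformers:conversational']

backend = load_backend(BackendConfig(backend_type="human"))
print(backend.type_name)  # "human"
```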
/chatarena/utils.py:
--------------------------------------------------------------------------------
1 | class AttributedDict(dict):
2 | """
3 |     A dict subclass whose keys are also accessible as attributes of the instance.
4 | Serializable to JSON.
5 | """
6 |
7 | def __init__(self, *args, **kwargs):
8 | super().__init__(*args, **kwargs)
9 |
10 | def __setattr__(self, key, value):
11 | self[key] = value
12 |
13 | def __getattr__(self, key):
14 | if key in self:
15 | return self[key]
16 |         raise AttributeError(key)
17 |
18 | def __delattr__(self, key):
19 | del self[key]
20 |
21 | # check whether the key is string when adding the key
22 | def __setitem__(self, key, value):
23 | if not isinstance(key, str):
24 | raise ValueError("The key must be a string")
25 | super().__setitem__(key, value)
26 |
27 | def update(self, *args, **kwargs):
28 | for key, value in dict(*args, **kwargs).items():
29 | self[key] = value
30 |
--------------------------------------------------------------------------------
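A short usage sketch for `AttributedDict` (the values are made up): attributes and keys share one namespace, and non-string keys are rejected.

```python
from chatarena.utils import AttributedDict

d = AttributedDict(foo=1)
d.bar = 2                  # attribute writes become dict entries
assert d["bar"] == 2 and d.foo == 1

try:
    d[42] = "nope"         # non-string keys raise ValueError
except ValueError as err:
    print(err)             # "The key must be a string"
```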
/chatarena/backends/base.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | from abc import abstractmethod
3 |
4 | from ..config import BackendConfig, Configurable
5 | from ..message import Message
6 |
7 |
8 | class IntelligenceBackend(Configurable):
9 | """An abstraction of the intelligence source of the agents."""
10 | stateful = None
11 | type_name = None
12 |
13 | @abstractmethod
14 | def __init__(self, **kwargs):
15 | super().__init__(**kwargs) # registers the arguments with Configurable
16 |
17 | def __init_subclass__(cls, **kwargs):
18 | # check if the subclass has the required attributes
19 | for required in ('stateful', 'type_name',):
20 | if getattr(cls, required) is None:
21 | raise TypeError(f"Can't instantiate abstract class {cls.__name__} without {required} attribute defined")
22 | return super().__init_subclass__(**kwargs)
23 |
24 | def to_config(self) -> BackendConfig:
25 | self._config_dict["backend_type"] = self.type_name
26 | return BackendConfig(**self._config_dict)
27 |
28 | @abstractmethod
29 | def query(self, agent_name: str, role_desc: str, history_messages: List[Message], global_prompt: str = None,
30 | request_msg: Message = None, *args, **kwargs) -> str:
31 | raise NotImplementedError
32 |
33 | @abstractmethod
34 | async def async_query(self, agent_name: str, role_desc: str, history_messages: List[Message],
35 | global_prompt: str = None, request_msg: Message = None, *args, **kwargs) -> str:
36 | """Async querying"""
37 | raise NotImplementedError
38 |
39 | # reset the state of the backend
40 | def reset(self):
41 | if self.stateful:
42 | raise NotImplementedError
43 | else:
44 | pass
45 |
--------------------------------------------------------------------------------
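To illustrate the contract above, here is a hypothetical `EchoBackend` (not in the repository): every subclass must set `stateful` and `type_name`, otherwise `__init_subclass__` raises a `TypeError` when the class is defined.

```python
from typing import List
from chatarena.backends.base import IntelligenceBackend
from chatarena.message import Message

class EchoBackend(IntelligenceBackend):
    stateful = False
    type_name = "echo"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)  # registers kwargs with Configurable

    def query(self, agent_name: str, role_desc: str, history_messages: List[Message],
              global_prompt: str = None, request_msg: Message = None, *args, **kwargs) -> str:
        # Echo the last message back, or greet when the history is empty
        return history_messages[-1].content if history_messages else f"Hello from {agent_name}"

backend = EchoBackend()
print(backend.query(agent_name="Alice", role_desc="echoer", history_messages=[]))
```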
/chatarena/environments/base.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import List, Dict
3 | from abc import abstractmethod
4 |
5 | from ..message import Message
6 | from ..utils import AttributedDict
7 | from ..config import Configurable, EnvironmentConfig
8 |
9 |
10 | @dataclass
11 | class TimeStep(AttributedDict):
12 | observation: List[Message]
13 | reward: Dict[str, float]
14 | terminal: bool
15 |
16 |
17 | class Environment(Configurable):
18 | """
19 |     The environment that the agents interact with.
20 | """
21 | type_name = None
22 |
23 | @abstractmethod
24 | def __init__(self, player_names: List[str], **kwargs):
25 | super().__init__(player_names=player_names, **kwargs) # registers the arguments with Configurable
26 | self.player_names = player_names
27 |
28 | def __init_subclass__(cls, **kwargs):
29 |         # default type_name to the lowercased class name if it is not specified
30 | for required in ('type_name',):
31 | if getattr(cls, required) is None:
32 | cls.type_name = cls.__name__.lower()
33 |
34 | return super().__init_subclass__(**kwargs)
35 |
36 | @abstractmethod
37 | def reset(self):
38 | """
39 | reset the environment
40 | """
41 | pass
42 |
43 | def to_config(self) -> EnvironmentConfig:
44 | self._config_dict["env_type"] = self.type_name
45 | return EnvironmentConfig(**self._config_dict)
46 |
47 | @property
48 | def num_players(self) -> int:
49 | """
50 | get the number of players
51 | """
52 | return len(self.player_names)
53 |
54 | @abstractmethod
55 | def get_next_player(self) -> str:
56 | """
57 | get name of the next player
58 | """
59 | pass
60 |
61 | @abstractmethod
62 | def get_observation(self, player_name=None) -> List[Message]:
63 | """
64 | get observation for the player
65 | """
66 | pass
67 |
68 | @abstractmethod
69 | def print(self):
70 | """
71 | print the environment state
72 | """
73 | pass
74 |
75 | @abstractmethod
76 | def step(self, player_name: str, action: str) -> TimeStep:
77 | """
78 | step function that is called by the arena
79 | Args:
80 | player_name: the name of the player
81 | action: the action that the agents wants to take
82 | Returns:
83 |             timestep: the timestep that contains the observation, reward, and terminal flag
84 | """
85 | pass
86 |
87 | @abstractmethod
88 | def check_action(self, action: str, player_name: str) -> bool:
89 | """
90 | check whether the action is valid
91 | """
92 | return True
93 |
94 | @abstractmethod
95 | def is_terminal(self) -> bool:
96 | """
97 | check whether the environment is in terminal state
98 | """
99 | pass
100 |
101 | def get_zero_rewards(self) -> Dict[str, float]:
102 | return {player_name: 0. for player_name in self.player_names}
103 |
--------------------------------------------------------------------------------
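Two details above are easy to miss, so here is a small sketch (the class is hypothetical, not in the repository): `__init_subclass__` falls back to the lowercased class name when `type_name` is unset, and `TimeStep` is simultaneously a dataclass and an `AttributedDict`.

```python
from typing import List
from chatarena.environments.base import Environment, TimeStep

class EchoRoom(Environment):
    def __init__(self, player_names: List[str], **kwargs):
        super().__init__(player_names=player_names, **kwargs)

print(EchoRoom.type_name)  # "echoroom", derived automatically

step = TimeStep(observation=[], reward={"Alice": 0.0}, terminal=False)
print(step.terminal, step["terminal"])  # attribute and key views agree
```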
/chatarena/message.py:
--------------------------------------------------------------------------------
1 | from typing import List, Union
2 | from dataclasses import dataclass, field
3 | import time
4 | from uuid import uuid1
5 | import hashlib
6 |
7 | # Reserved agent names
8 | SYSTEM_NAME = "System"
9 | MODERATOR_NAME = "Moderator"
10 |
11 |
12 | def _hash(text: str):
13 |     hex_dig = hashlib.sha256(text.encode()).hexdigest()
14 |     return hex_dig
15 |
16 |
17 | @dataclass
18 | class Message:
19 | agent_name: str
20 | content: str # it can be an image or a text
21 | turn: int
22 |     timestamp: int = field(default_factory=time.time_ns)
23 | visible_to: Union[str, List[str]] = 'all'
24 | msg_type: str = "text"
25 | logged: bool = False # Whether the message is logged in the database
26 |
27 | @property
28 | def msg_hash(self):
29 | # Generate a unique message id given the content, timestamp and role
30 | return _hash(
31 | f"agent: {self.agent_name}\ncontent: {self.content}\ntimestamp: {str(self.timestamp)}\nturn: {self.turn}\nmsg_type: {self.msg_type}")
32 |
33 |
34 | class MessagePool():
35 | """
36 | A message pool to manage the messages. This allows a unified treatment of the visibility of the messages.
37 | Draft design:
38 | The message pool is a list of (named) tuples, where each tuple has (turn, role, content).
39 |
40 |     There should be two potential configurations for step definition: multiple players can act in the same turn (e.g., rock-paper-scissors).
41 |     The agents can only see the messages that
42 |     1) were sent before the current turn, and
43 |     2) are visible to the current role.
44 | """
45 |
46 | def __init__(self):
47 | self.conversation_id = str(uuid1())
48 | self._messages: List[Message] = [] # TODO: for the sake of thread safety, use a queue instead
49 | self._last_message_idx = 0
50 |
51 | def reset(self):
52 | self._messages = []
53 |
54 | def append_message(self, message: Message):
55 | self._messages.append(message)
56 |
57 | def print(self):
58 | for message in self._messages:
59 | print(f"[{message.agent_name}->{message.visible_to}]: {message.content}")
60 |
61 | @property
62 | def last_turn(self):
63 | if len(self._messages) == 0:
64 | return 0
65 | else:
66 | return self._messages[-1].turn
67 |
68 | @property
69 | def last_message(self):
70 | if len(self._messages) == 0:
71 | return None
72 | else:
73 | return self._messages[-1]
74 |
75 | def get_all_messages(self) -> List[Message]:
76 | return self._messages
77 |
78 | def get_visible_messages(self, agent_name, turn: int) -> List[Message]:
79 | """
80 | get the messages that are visible to the agents before the specified turn
81 | """
82 |
83 | # Get the messages before the current turn
84 | prev_messages = [message for message in self._messages if message.turn < turn]
85 |
86 | visible_messages = []
87 | for message in prev_messages:
88 |             if message.visible_to == "all" or agent_name in message.visible_to or agent_name == MODERATOR_NAME:
89 | visible_messages.append(message)
90 | return visible_messages
91 |
--------------------------------------------------------------------------------
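A runnable sketch of the visibility rules described in the `MessagePool` docstring (names and contents are made up): only messages from earlier turns that are visible to the asking agent are returned.

```python
from chatarena.message import Message, MessagePool

pool = MessagePool()
pool.append_message(Message(agent_name="Alice", content="hi all", turn=0))
pool.append_message(Message(agent_name="Bob", content="psst", turn=0, visible_to=["Alice"]))
pool.append_message(Message(agent_name="Carol", content="later", turn=1))

# Turn-1 messages are excluded; Bob's whisper is hidden from Carol.
print([m.content for m in pool.get_visible_messages("Carol", turn=1)])  # ['hi all']
print([m.content for m in pool.get_visible_messages("Alice", turn=1)])  # ['hi all', 'psst']
```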
/README.md:
--------------------------------------------------------------------------------
1 | # TopDial
2 | This repository contains code and data for the paper [Target-oriented Proactive Dialogue Systems with Personalization: Problem Formulation and Dataset Curation](http://arxiv.org/abs/2310.07397) accepted by EMNLP 2023.
3 |
4 | ## Overview
5 |
6 | ![TopDial framework](imgs/framework.png)
7 |
8 | Target-oriented dialogue systems, designed to proactively steer conversations toward predefined targets or accomplish specific system-side goals, are an exciting area in conversational AI. In this work, by formulating a `<dialogue act, topic>` pair as the conversation target, we explore a novel problem of personalized target-oriented dialogue by considering personalization during the target accomplishment process. However, there remains a pressing need for high-quality datasets, and building one from scratch requires tremendous human effort. To address this, we propose an automatic dataset curation framework using a role-playing approach. Based on this framework, we construct a large-scale personalized target-oriented dialogue dataset, **TopDial**, which comprises about 18K multi-turn dialogues.
9 |
10 |
11 | ## Dataset
12 | We have uploaded the curated **TopDial** dataset to [Google Drive](https://drive.google.com/file/d/1AWyjmUxYlppNCKkdK46riMF_2y0XeSI8/view?usp=sharing).
13 |
14 |
15 | ## Dataset Curation
16 |
17 |
18 | ### Requirements
19 | We use [Neo4j](https://neo4j.com/) as the graph database tool to process the domain knowledge graphs in the seed dataset. Please install it by following the [official guide](https://neo4j.com/docs/operations-manual/current/installation/). The required Python packages are listed in `requirements.txt`. Please install them by running:
20 | ```bash
21 | pip install -r requirements.txt
22 | ```
23 |
24 | ### Seed Dataset
25 | We use the [re-purposed version](https://github.com/iwangjian/Color4Dial) of the DuRecDial 2.0 dataset as the seed dataset.
26 |
27 | ### Step 1: Preprocessing the seed dataset
28 | ```bash
29 | python data_preprocess.py --seed_dataset_dir ${seed_dataset_dir} --cache_dir ${cache_dir}
30 | ```
31 | Running this script will generate the following files in the specified cache directory:
32 | `cache_dialogue_{train|dev|test_seen|test_unseen}.jsonl`
33 |
34 |
35 | ### Step 2: Dataset curation
36 | ```bash
37 | # set your OpenAI API key
38 | export OPENAI_API_KEY=""
39 |
40 | python -u dialog_simulation.py --cached_seed_path ${cached_seed_path} \
41 | --output_dir ${output_dir} \
42 | --max_interaction_step ${max_interaction_step}
43 | ```
44 | Running the above script will look like this:
45 | ![Demo](imgs/demo.gif)
46 |
47 | If you prefer not to print the instructions and the synthesized conversations to the console, set `--show_description` and `--show_message` to `false`.
48 |
49 |
50 | ## Acknowledgement
51 | Our code is partially based on the implementation of [ChatArena](https://github.com/Farama-Foundation/chatarena). We thank the authors for their excellent work.
52 |
53 |
54 | ## Citation
55 | If you use our data or code in your work, please kindly cite our paper:
56 | ```bibtex
57 | @inproceedings{wang-etal-2023-target,
58 | title = "Target-oriented Proactive Dialogue Systems with Personalization: Problem Formulation and Dataset Curation",
59 | author = "Wang, Jian and
60 | Cheng, Yi and
61 | Lin, Dongding and
62 | Leong, Chak Tou and
63 | Li, Wenjie",
64 | booktitle = "Proceedings of the 2023 Conference on Empirical Methods in Natural Language Processing (EMNLP)",
65 | month = dec,
66 | year = "2023",
67 | address = "Singapore",
68 | publisher = "Association for Computational Linguistics",
69 | }
70 | ```
71 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | dataset/
6 | seed_dataset/
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | share/python-wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | MANIFEST
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .nox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | *.py,cover
52 | .hypothesis/
53 | .pytest_cache/
54 | cover/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | .pybuilder/
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | # For a library or package, you might want to ignore these files since the code is
89 | # intended to run in multiple environments; otherwise, check them in:
90 | # .python-version
91 |
92 | # pipenv
93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
96 | # install all needed dependencies.
97 | #Pipfile.lock
98 |
99 | # poetry
100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
101 | # This is especially recommended for binary packages to ensure reproducibility, and is more
102 | # commonly ignored for libraries.
103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
104 | #poetry.lock
105 |
106 | # pdm
107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
108 | #pdm.lock
109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
110 | # in version control.
111 | # https://pdm.fming.dev/#use-with-ide
112 | .pdm.toml
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
--------------------------------------------------------------------------------
/chatarena/backends/hf_transformers.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | from tenacity import retry, stop_after_attempt, wait_random_exponential
3 |
4 | from .base import IntelligenceBackend
5 | from ..message import Message, SYSTEM_NAME as SYSTEM
6 |
7 | # Try to import the transformers package
8 | try:
9 | import transformers
10 | from transformers import pipeline
11 | from transformers.pipelines.conversational import Conversation, ConversationalPipeline
12 | except ImportError:
13 | is_transformers_available = False
14 | else:
15 | is_transformers_available = True
16 |
17 |
18 | class TransformersConversational(IntelligenceBackend):
19 | """
20 | Interface to the Transformers ConversationalPipeline
21 | """
22 | stateful = False
23 | type_name = "transformers:conversational"
24 |
25 | def __init__(self, model: str, device: int = -1, **kwargs):
26 | super().__init__(model=model, device=device, **kwargs)
27 | self.model = model
28 | self.device = device
29 |
30 | assert is_transformers_available, "Transformers package is not installed"
31 | self.chatbot = pipeline(task="conversational", model=self.model, device=self.device)
32 |
33 | @retry(stop=stop_after_attempt(6), wait=wait_random_exponential(min=1, max=60))
34 | def _get_response(self, conversation: Conversation):
35 | conversation = self.chatbot(conversation)
36 | response = conversation.generated_responses[-1]
37 | return response
38 |
39 | @staticmethod
40 | def _msg_template(agent_name, content):
41 | return f"[{agent_name}]: {content}"
42 |
43 | def query(self, agent_name: str, role_desc: str, history_messages: List[Message], global_prompt: str = None,
44 | request_msg: Message = None, *args, **kwargs) -> str:
45 | user_inputs, generated_responses = [], []
46 | all_messages = [(SYSTEM, global_prompt), (SYSTEM, role_desc)] if global_prompt else [(SYSTEM, role_desc)]
47 |
48 | for msg in history_messages:
49 | all_messages.append((msg.agent_name, msg.content))
50 | if request_msg:
51 | all_messages.append((SYSTEM, request_msg.content))
52 |
53 | prev_is_user = False # Whether the previous message is from the user
54 | for i, message in enumerate(all_messages):
55 | if i == 0:
56 | assert message[0] == SYSTEM # The first message should be from the system
57 |
58 | if message[0] != agent_name:
59 | if not prev_is_user:
60 | user_inputs.append(self._msg_template(message[0], message[1]))
61 | else:
62 | user_inputs[-1] += "\n" + self._msg_template(message[0], message[1])
63 | prev_is_user = True
64 | else:
65 | if prev_is_user:
66 | generated_responses.append(message[1])
67 | else:
68 | generated_responses[-1] += "\n" + message[1]
69 | prev_is_user = False
70 |
71 | assert len(user_inputs) == len(generated_responses) + 1
72 | past_user_inputs = user_inputs[:-1]
73 | new_user_input = user_inputs[-1]
74 |
75 | # Recreate a conversation object from the history messages
76 | conversation = Conversation(text=new_user_input, past_user_inputs=past_user_inputs,
77 | generated_responses=generated_responses)
78 |
79 | # Get the response
80 | response = self._get_response(conversation)
81 | return response
82 |
83 | # conversation = Conversation("Going to the movies tonight - any suggestions?")
84 | #
85 | # # Steps usually performed by the model when generating a response:
86 | # # 1. Mark the user input as processed (moved to the history)
87 | # conversation.mark_processed()
88 | # # 2. Append a model response
89 | # conversation.append_response("The Big lebowski.")
90 | #
91 | # conversation.add_user_input("Is it good?")
92 |
--------------------------------------------------------------------------------
/chatarena/backends/anthropic.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | import os
3 | import re
4 | import logging
5 | from tenacity import retry, stop_after_attempt, wait_random_exponential
6 |
7 | from .base import IntelligenceBackend
8 | from ..message import Message, SYSTEM_NAME as SYSTEM
9 |
10 | try:
11 | import anthropic
12 | except ImportError:
13 | is_anthropic_available = False
14 | logging.warning("anthropic package is not installed")
15 | else:
16 | anthropic_api_key = os.environ.get('ANTHROPIC_API_KEY')
17 | if anthropic_api_key is None:
18 | logging.warning("Anthropic API key is not set. Please set the environment variable ANTHROPIC_API_KEY")
19 | is_anthropic_available = False
20 | else:
21 | is_anthropic_available = True
22 |
23 | DEFAULT_MAX_TOKENS = 256
24 | DEFAULT_MODEL = "claude-v1"
25 |
26 |
27 | class Claude(IntelligenceBackend):
28 | """
29 |     Interface to Claude, offered by Anthropic.
30 | """
31 | stateful = False
32 | type_name = "claude"
33 |
34 | def __init__(self, max_tokens: int = DEFAULT_MAX_TOKENS, model: str = DEFAULT_MODEL, **kwargs):
35 | assert is_anthropic_available, "anthropic package is not installed or the API key is not set"
36 | super().__init__(max_tokens=max_tokens, model=model, **kwargs)
37 |
38 | self.max_tokens = max_tokens
39 | self.model = model
40 |
41 | self.client = anthropic.Client(os.environ['ANTHROPIC_API_KEY'])
42 |
43 | @retry(stop=stop_after_attempt(6), wait=wait_random_exponential(min=1, max=60))
44 | def _get_response(self, prompt: str):
45 | response = self.client.completion(
46 | prompt=prompt,
47 | stop_sequences=[anthropic.HUMAN_PROMPT],
48 | model=self.model,
49 | max_tokens_to_sample=self.max_tokens,
50 | )
51 |
52 | response = response['completion'].strip()
53 | return response
54 |
55 | def query(self, agent_name: str, role_desc: str, history_messages: List[Message], global_prompt: str = None,
56 | request_msg: Message = None, *args, **kwargs) -> str:
57 | """
58 | format the input and call the Claude API
59 | args:
60 | agent_name: the name of the agent
61 | role_desc: the description of the role of the agent
62 |             global_prompt: the global prompt (the description of the environment)
63 | history_messages: the history of the conversation, or the observation for the agent
64 | request_msg: the request from the system to guide the agent's next response
65 | """
66 | all_messages = [(SYSTEM, global_prompt), (SYSTEM, role_desc)] if global_prompt else [(SYSTEM, role_desc)]
67 |
68 | for message in history_messages:
69 | all_messages.append((message.agent_name, message.content))
70 | if request_msg:
71 | all_messages.append((SYSTEM, request_msg.content))
72 |
73 | prompt = ""
74 | prev_is_human = False # Whether the previous message is from human (in anthropic, the human is the user)
75 | for i, message in enumerate(all_messages):
76 | if i == 0:
77 | assert message[0] == SYSTEM # The first message should be from the system
78 |
79 | if message[0] == agent_name:
80 | if prev_is_human:
81 | prompt = f"{prompt}{anthropic.AI_PROMPT} {message[1]}"
82 | else:
83 | prompt = f"{prompt}\n\n{message[1]}"
84 | prev_is_human = False
85 | else:
86 | if prev_is_human:
87 | prompt = f"{prompt}\n\n[{message[0]}]: {message[1]}"
88 | else:
89 | prompt = f"{prompt}{anthropic.HUMAN_PROMPT}\n[{message[0]}]: {message[1]}"
90 | prev_is_human = True
91 | assert prev_is_human # The last message should be from the human
92 | # Add the AI prompt for Claude to generate the response
93 | prompt = f"{prompt}{anthropic.AI_PROMPT}"
94 |
95 |         response = self._get_response(prompt)
96 |
97 | # Remove the agent name if the response starts with it
98 | response = re.sub(rf"^\s*\[{agent_name}]:?", "", response).strip()
99 |
100 | return response
101 |
--------------------------------------------------------------------------------
/chatarena/backends/cohere.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | import os
3 | from tenacity import retry, stop_after_attempt, wait_random_exponential
4 |
5 | from .base import IntelligenceBackend
6 | from ..message import Message
7 |
8 | # Try to import the cohere package and check whether the API key is set
9 | try:
10 | import cohere
11 | except ImportError:
12 | is_cohere_available = False
13 | else:
14 | if os.environ.get('COHEREAI_API_KEY') is None:
15 | is_cohere_available = False
16 | else:
17 | is_cohere_available = True
18 |
19 | # Default config follows the [Cohere documentation](https://cohere-sdk.readthedocs.io/en/latest/cohere.html#cohere.client.Client.chat)
20 | DEFAULT_TEMPERATURE = 0.8
21 | DEFAULT_MAX_TOKENS = 200
22 | DEFAULT_MODEL = "command-xlarge"
23 |
24 |
25 | class CohereAIChat(IntelligenceBackend):
26 | """
27 | Interface to the Cohere API
28 | """
29 | stateful = True
30 | type_name = "cohere-chat"
31 |
32 | def __init__(self, temperature: float = DEFAULT_TEMPERATURE, max_tokens: int = DEFAULT_MAX_TOKENS,
33 | model: str = DEFAULT_MODEL, **kwargs):
34 | super().__init__(temperature=temperature, max_tokens=max_tokens, model=model, **kwargs)
35 |
36 | self.temperature = temperature
37 | self.max_tokens = max_tokens
38 | self.model = model
39 |
40 | assert is_cohere_available, "Cohere package is not installed or the API key is not set"
41 | self.client = cohere.Client(os.environ.get('COHEREAI_API_KEY'))
42 |
43 | # Stateful variables
44 | self.session_id = None # The session id for the last conversation
45 | self.last_msg_hash = None # The hash of the last message of the last conversation
46 |
47 | def reset(self):
48 | self.session_id = None
49 | self.last_msg_hash = None
50 |
51 | @retry(stop=stop_after_attempt(6), wait=wait_random_exponential(min=1, max=60))
52 | def _get_response(self, new_message: str, persona_prompt: str):
53 | response = self.client.chat(
54 | new_message,
55 | persona_prompt=persona_prompt,
56 | temperature=self.temperature,
57 | max_tokens=self.max_tokens,
58 | session_id=self.session_id
59 | )
60 |
61 | self.session_id = response.session_id # Update the session id
62 | return response.reply
63 |
64 | def query(self, agent_name: str, role_desc: str, history_messages: List[Message], global_prompt: str = None,
65 | request_msg: Message = None, *args, **kwargs) -> str:
66 | """
67 | format the input and call the Cohere API
68 | args:
69 | agent_name: the name of the agent
70 | role_desc: the description of the role of the agent
71 |             global_prompt: the global prompt (the description of the environment)
72 | history_messages: the history of the conversation, or the observation for the agent
73 | request_msg: the request for the CohereAI
74 | """
75 | # Find the index of the last message of the last conversation
76 | new_message_start_idx = 0
77 | if self.last_msg_hash is not None:
78 | for i, message in enumerate(history_messages):
79 | if message.msg_hash == self.last_msg_hash:
80 | new_message_start_idx = i + 1
81 | break
82 |
83 | new_messages = history_messages[new_message_start_idx:]
84 | assert len(new_messages) > 0, "No new messages found (this should not happen)"
85 |
86 | new_conversations = []
87 | for message in new_messages:
88 | if message.agent_name != agent_name:
89 |                 # Since there is more than one player, we need to distinguish between the players
90 | new_conversations.append(f"[{message.agent_name}]: {message.content}")
91 |
92 | if request_msg:
93 | new_conversations.append(f"[{request_msg.agent_name}]: {request_msg.content}")
94 |
95 | # Concatenate all new messages into one message because the Cohere API only accepts one message
96 | new_message = "\n".join(new_conversations)
97 | persona_prompt = f"Environment:\n{global_prompt}\n\nYour role:\n{role_desc}"
98 |
99 | response = self._get_response(new_message, persona_prompt)
100 |
101 | # Only update the last message hash if the API call is successful
102 | self.last_msg_hash = new_messages[-1].msg_hash
103 |
104 | return response
105 |
--------------------------------------------------------------------------------
/chatarena/database.py:
--------------------------------------------------------------------------------
1 | """
2 | Datastore module for chatarena.
3 | This module provides utilities for storing the messages and the game results into a database.
4 | Currently, it supports Supabase.
5 | """
6 | import json
7 | import os
8 | from typing import List
9 | import uuid
10 |
11 | from .arena import Arena
12 | from .message import Message
13 |
14 | # Attempt importing Supabase
15 | try:
16 | import supabase
17 |
18 | # Get the Supabase URL and secret key from environment variables
19 | SUPABASE_URL = os.environ.get("SUPABASE_URL", "")
20 | SUPABASE_SECRET_KEY = os.environ.get("SUPABASE_SECRET_KEY", "")
21 | assert SUPABASE_URL and SUPABASE_SECRET_KEY
22 | except (ImportError, AssertionError):
23 | supabase_available = False
24 | else:
25 | supabase_available = True
26 |
27 |
28 | # Store the messages into the Supabase database
29 | class SupabaseDB:
30 | def __init__(self):
31 | assert supabase_available and SUPABASE_URL and SUPABASE_SECRET_KEY
32 | supabase_client = supabase.create_client(SUPABASE_URL, SUPABASE_SECRET_KEY)
33 | self.client = supabase_client
34 |
35 | # Save Arena state to Supabase
36 | def save_arena(self, arena: Arena):
37 | # Save the environment config
38 | self._save_environment(arena)
39 |
40 | # Save the player configs
41 | self._save_player_configs(arena)
42 |
43 | # Save the messages
44 | self.save_messages(arena)
45 |
46 | # Save the environment config of the arena
47 | def _save_environment(self, arena: Arena):
48 | env = arena.environment
49 | env_config = env.to_config()
50 | moderator_config = env_config.pop("moderator", None)
51 |
52 | arena_row = {
53 | "arena_id": str(arena.uuid),
54 | "global_prompt": arena.global_prompt,
55 | "env_type": env_config["env_type"],
56 | "env_config": json.dumps(env_config),
57 | }
58 | self.client.table("Arena").insert(arena_row).execute()
59 |
60 | # Get the moderator config
61 | if moderator_config:
62 | moderator_row = {
63 | "moderator_id": str(uuid.uuid5(arena.uuid, json.dumps(moderator_config))),
64 | "arena_id": str(arena.uuid),
65 | "role_desc": moderator_config["role_desc"],
66 | "terminal_condition": moderator_config["terminal_condition"],
67 | "backend_type": moderator_config["backend"]["backend_type"],
68 | "temperature": moderator_config["backend"]["temperature"],
69 | "max_tokens": moderator_config["backend"]["max_tokens"],
70 | }
71 | self.client.table("Moderator").insert(moderator_row).execute()
72 |
73 | # Save the player configs of the arena
74 | def _save_player_configs(self, arena: Arena):
75 | player_rows = []
76 | for player in arena.players:
77 | player_config = player.to_config()
78 | player_row = {
79 | "player_id": str(uuid.uuid5(arena.uuid, json.dumps(player_config))),
80 | "arena_id": str(arena.uuid),
81 | "name": player.name,
82 | "role_desc": player_config["role_desc"],
83 | "backend_type": player_config["backend"]["backend_type"],
84 | "temperature": player_config["backend"].get("temperature", None),
85 | "max_tokens": player_config["backend"].get("max_tokens", None),
86 | }
87 | player_rows.append(player_row)
88 |
89 | self.client.table("Player").insert(player_rows).execute()
90 |
91 | # Save the messages
92 | def save_messages(self, arena: Arena, messages: List[Message] = None):
93 | if messages is None:
94 | messages = arena.environment.get_observation()
95 |
96 | # Filter messages that are already logged
97 | messages = [msg for msg in messages if not msg.logged]
98 |
99 | message_rows = []
100 | for message in messages:
101 | message_row = {
102 | "message_id": str(uuid.uuid5(arena.uuid, message.msg_hash)),
103 | "arena_id": str(arena.uuid),
104 | "agent_name": message.agent_name,
105 | "content": message.content,
106 | "turn": message.turn,
107 | "timestamp": str(message.timestamp),
108 | "msg_type": message.msg_type,
109 | "visible_to": json.dumps(message.visible_to),
110 | }
111 | message_rows.append(message_row)
112 |
113 | self.client.table("Message").insert(message_rows).execute()
114 |
115 | # Mark the messages as logged
116 | for message in messages:
117 | message.logged = True
118 |
119 |
120 | # Log the arena results into the Supabase database
121 | def log_arena(arena: Arena, database=None):
122 | if database is None:
123 | pass
124 | else:
125 | database.save_arena(arena)
126 |
127 |
128 | # Log the messages into the Supabase database
129 | def log_messages(arena: Arena, messages: List[Message], database=None):
130 | if database is None:
131 | pass
132 | else:
133 | database.save_messages(arena, messages)
134 |
--------------------------------------------------------------------------------
/chatarena/config.py:
--------------------------------------------------------------------------------
1 | import json
2 | import copy
3 | from abc import abstractmethod
4 |
5 | from .utils import AttributedDict
6 |
7 |
8 | class Config(AttributedDict):
9 | """
10 | Config class to manage the configuration of the games.
11 | The class has a few useful methods to load and save the config.
12 | """
13 |
14 | # convert dict to Config recursively
15 | def __init__(self, *args, **kwargs):
16 | super().__init__(*args, **kwargs)
17 | for key, value in self.items():
18 | if isinstance(value, dict):
19 | self[key] = init_config(value) # convert dict to Config recursively
20 | # convert list of dict to list of Config recursively
21 | elif isinstance(value, list) and len(value) > 0:
22 | self[key] = [init_config(item) if isinstance(item, dict) else item for item in value]
23 |
24 | def save(self, path: str):
25 | # save config to file
26 | with open(path, "w") as f:
27 | json.dump(self, f, indent=4)
28 |
29 | @classmethod
30 | def load(cls, path: str):
31 | # load config from file
32 | with open(path, "r") as f:
33 | config = json.load(f)
34 | return cls(config)
35 |
36 | def deepcopy(self):
37 | # get the config class so that subclasses can be copied in the correct class
38 | config_class = self.__class__
39 | # make a deep copy of the config
40 | return config_class(copy.deepcopy(self))
41 |
42 |
43 | class Configurable:
44 | """
45 | Configurable is an interface for classes that can be initialized with a config.
46 | """
47 |
48 | def __init__(self, **kwargs):
49 | self._config_dict = kwargs
50 |
51 | @classmethod
52 | def from_config(cls, config: Config):
53 | return cls(**config)
54 |
55 | def to_config(self) -> Config:
56 | # Convert the _config_dict to Config
57 | return Config(**self._config_dict)
58 |
59 | def save_config(self, path: str):
60 | self.to_config().save(path)
61 |
62 |
63 | class EnvironmentConfig(Config):
64 | """
65 | EnvironmentConfig contains a env_type field to indicate the name of the environment.
66 | """
67 |
68 | def __init__(self, *args, **kwargs):
69 | super().__init__(*args, **kwargs)
70 | # check if the env_type field is specified
71 | if "env_type" not in self:
72 | raise ValueError("The env_type field is not specified")
73 |
74 |
75 | class BackendConfig(Config):
76 | """
77 | BackendConfig contains a backend_type field to indicate the name of the backend.
78 | """
79 |
80 | def __init__(self, *args, **kwargs):
81 | super().__init__(*args, **kwargs)
82 | # check if the backend_type field is specified
83 | if "backend_type" not in self:
84 | raise ValueError("The backend_type field is not specified")
85 |
86 |
87 | class AgentConfig(Config):
88 | """
89 | AgentConfig contains role_desc and backend fields.
90 | """
91 |
92 | def __init__(self, *args, **kwargs):
93 | super().__init__(*args, **kwargs)
94 | # check if the role_desc field is specified
95 | if "role_desc" not in self:
96 | raise ValueError("The role_desc field is not specified")
97 | # check if the backend field is specified
98 | if "backend" not in self:
99 | raise ValueError("The backend field is not specified")
100 | # Make sure the backend field is a BackendConfig
101 | if not isinstance(self["backend"], BackendConfig):
102 | raise ValueError("The backend field must be a BackendConfig")
103 |
104 |
105 | class ArenaConfig(Config):
106 | """
107 | ArenaConfig contains a list of AgentConfig.
108 | """
109 |
110 | def __init__(self, *args, **kwargs):
111 | super().__init__(*args, **kwargs)
112 | # check if the players field is specified and it is List[AgentConfig]
113 | if "players" not in self:
114 | raise ValueError("The players field is not specified")
115 | if not isinstance(self["players"], list):
116 | raise ValueError("The players field must be a list")
117 | for player in self["players"]:
118 | if not isinstance(player, AgentConfig):
119 | raise ValueError("The players field must be a list of AgentConfig")
120 |
121 | # check if environment field is specified and it is EnvironmentConfig
122 | if "environment" not in self:
123 | raise ValueError("The environment field is not specified")
124 | if not isinstance(self["environment"], EnvironmentConfig):
125 | raise ValueError("The environment field must be an EnvironmentConfig")
126 |
127 |
128 | # Initialize with different config class depending on whether the config is for environment or backend
129 | def init_config(config: dict):
130 | if not isinstance(config, dict):
131 | raise ValueError("The config must be a dict")
132 |
133 | # check if the config is for environment or backend
134 | if "env_type" in config:
135 | return EnvironmentConfig(config)
136 | elif "backend_type" in config:
137 | return BackendConfig(config)
138 | elif "role_desc" in config:
139 | return AgentConfig(config)
140 | elif "players" in config:
141 | return ArenaConfig(config)
142 | else:
143 | return Config(config)
144 |
--------------------------------------------------------------------------------
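A brief sketch of `init_config`'s dispatch (the values are made up): nested dicts are promoted to the matching `Config` subclass based on which marker key they contain.

```python
from chatarena.config import init_config

cfg = init_config({
    "role_desc": "You are a helpful assistant.",
    "backend": {"backend_type": "openai-chat"},
})
print(type(cfg).__name__)          # AgentConfig   (keyed by "role_desc")
print(type(cfg.backend).__name__)  # BackendConfig (keyed by "backend_type")
```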
/chatarena/agent.py:
--------------------------------------------------------------------------------
1 | from typing import List, Union
2 | import re
3 | from tenacity import RetryError
4 | import logging
5 | import uuid
6 | from abc import abstractmethod
7 | import asyncio
8 |
9 | from .backends import IntelligenceBackend, load_backend
10 | from .message import Message, SYSTEM_NAME
11 | from .config import AgentConfig, Configurable, BackendConfig
12 |
13 | # A special signal sent by the player to indicate that it is not possible to continue the conversation, and it requests to end the conversation.
14 | # It contains a random UUID string to avoid being exploited by any of the players.
15 | SIGNAL_END_OF_CONVERSATION = f"<<<<<<END_OF_CONVERSATION>>>>>>{uuid.uuid4()}"
16 |
17 |
18 | class Agent(Configurable):
19 |
20 | @abstractmethod
21 | def __init__(self, name: str, role_desc: str, global_prompt: str = None, *args, **kwargs):
22 | super().__init__(name=name, role_desc=role_desc, global_prompt=global_prompt, **kwargs)
23 | self.name = name
24 | self.role_desc = role_desc
25 | self.global_prompt = global_prompt
26 |
27 |
28 | class Player(Agent):
29 | """
30 |     Player of the game. It takes the observation from the environment and returns an action.
31 | """
32 |
33 | def __init__(self, name: str, role_desc: str, backend: Union[BackendConfig, IntelligenceBackend],
34 | global_prompt: str = None, **kwargs):
35 |
36 | if isinstance(backend, BackendConfig):
37 | backend_config = backend
38 | backend = load_backend(backend_config)
39 | elif isinstance(backend, IntelligenceBackend):
40 | backend_config = backend.to_config()
41 | else:
42 | raise ValueError(f"backend must be a BackendConfig or an IntelligenceBackend, but got {type(backend)}")
43 |
44 | assert name != SYSTEM_NAME, f"Player name cannot be {SYSTEM_NAME}, which is reserved for the system."
45 |
46 | # Register the fields in the _config
47 | super().__init__(name=name, role_desc=role_desc, backend=backend_config,
48 | global_prompt=global_prompt, **kwargs)
49 |
50 | self.backend = backend
51 |
52 | def to_config(self) -> AgentConfig:
53 | return AgentConfig(
54 | name=self.name,
55 | role_desc=self.role_desc,
56 | backend=self.backend.to_config(),
57 | global_prompt=self.global_prompt,
58 | )
59 |
60 | def act(self, observation: List[Message]) -> str:
61 | """
62 | Call the agents to generate a response (equivalent to taking an action).
63 | """
64 | try:
65 | response = self.backend.query(agent_name=self.name, role_desc=self.role_desc,
66 | history_messages=observation, global_prompt=self.global_prompt,
67 | request_msg=None)
68 | except RetryError as e:
69 | logging.warning(f"Agent {self.name} failed to generate a response. "
70 | f"Error: {e.last_attempt.exception()}. "
71 | f"Sending signal to end the conversation.")
72 | response = SIGNAL_END_OF_CONVERSATION
73 |
74 | return response
75 |
76 | def __call__(self, observation: List[Message]) -> str:
77 | return self.act(observation)
78 |
79 | async def async_act(self, observation: List[Message]) -> str:
80 | """
81 | Async call the agents to generate a response (equivalent to taking an action).
82 | """
83 | try:
84 |             response = await self.backend.async_query(agent_name=self.name, role_desc=self.role_desc,
85 |                                                       history_messages=observation, global_prompt=self.global_prompt,
86 |                                                       request_msg=None)
87 | except RetryError as e:
88 | logging.warning(f"Agent {self.name} failed to generate a response. "
89 | f"Error: {e.last_attempt.exception()}. "
90 | f"Sending signal to end the conversation.")
91 | response = SIGNAL_END_OF_CONVERSATION
92 |
93 | return response
94 |
95 | def reset(self):
96 | self.backend.reset()
97 |
98 |
99 | class Moderator(Player):
100 | """
101 | A special type of player that moderates the conversation (usually used as a component of environment).
102 | """
103 |
104 | def __init__(self, role_desc: str, backend: Union[BackendConfig, IntelligenceBackend],
105 | terminal_condition: str, global_prompt: str = None, **kwargs):
106 | name = "Moderator"
107 | super().__init__(name=name, role_desc=role_desc, backend=backend, global_prompt=global_prompt, **kwargs)
108 |
109 | self.terminal_condition = terminal_condition
110 |
111 | def to_config(self) -> AgentConfig:
112 | return AgentConfig(
113 | name=self.name,
114 | role_desc=self.role_desc,
115 | backend=self.backend.to_config(),
116 | terminal_condition=self.terminal_condition,
117 | global_prompt=self.global_prompt,
118 | )
119 |
120 | def is_terminal(self, history: List[Message], *args, **kwargs) -> bool:
121 | """
122 | check whether the conversation is over
123 | """
124 | # If the last message is the signal, then the conversation is over
125 | if history[-1].content == SIGNAL_END_OF_CONVERSATION:
126 | return True
127 |
128 | try:
129 | request_msg = Message(agent_name=self.name, content=self.terminal_condition, turn=-1)
130 | response = self.backend.query(agent_name=self.name, role_desc=self.role_desc, history_messages=history,
131 | global_prompt=self.global_prompt, request_msg=request_msg, *args, **kwargs)
132 | except RetryError as e:
133 | logging.warning(f"Agent {self.name} failed to generate a response. "
134 | f"Error: {e.last_attempt.exception()}.")
135 | return True
136 |
137 | if re.match(r"yes|y|yea|yeah|yep|yup|sure|ok|okay|alright", response, re.IGNORECASE):
138 | # print(f"Decision: {response}. Conversation is ended by moderator.")
139 | return True
140 | else:
141 | return False
142 |
--------------------------------------------------------------------------------
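A small construction sketch (names are illustrative): building a `Player` from a `BackendConfig`; the `human` backend is used because it can be instantiated without an API key.

```python
from chatarena.agent import Player
from chatarena.config import BackendConfig

alice = Player(name="Alice", role_desc="A friendly interlocutor.",
               backend=BackendConfig(backend_type="human"))
print(alice.backend.type_name)         # "human"
print(alice.to_config()["role_desc"])  # round-trips through AgentConfig
```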
/chatarena/backends/openai.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | import os
3 | import re
4 | import logging
5 | from tenacity import retry, stop_after_attempt, wait_random, wait_random_exponential
6 |
7 | from .base import IntelligenceBackend
8 | from ..message import Message, SYSTEM_NAME, MODERATOR_NAME
9 |
10 | try:
11 | import openai
12 | except ImportError:
13 | is_openai_available = False
14 | logging.warning("openai package is not installed")
15 | else:
16 | openai.api_key = os.environ.get("OPENAI_API_KEY")
17 | if openai.api_key is None:
18 | logging.warning("OpenAI API key is not set. Please set the environment variable OPENAI_API_KEY")
19 | is_openai_available = False
20 | else:
21 | is_openai_available = True
22 |
23 | # Default config follows the OpenAI playground
24 | DEFAULT_TEMPERATURE = 0.7
25 | DEFAULT_MAX_TOKENS = 256
26 | DEFAULT_MODEL = "gpt-3.5-turbo"
27 |
28 | END_OF_MESSAGE = "" # End of message token specified by us not OpenAI
29 | STOP = ("<|endoftext|>", END_OF_MESSAGE) # End of sentence token
30 | BASE_PROMPT = f"The messages always end with the token {END_OF_MESSAGE}."
31 |
32 |
33 | class OpenAIChat(IntelligenceBackend):
34 | """
35 | Interface to the ChatGPT style model with system, user, assistant roles separation
36 | """
37 | stateful = False
38 | type_name = "openai-chat"
39 |
40 | def __init__(self, temperature: float = DEFAULT_TEMPERATURE, max_tokens: int = DEFAULT_MAX_TOKENS,
41 | model: str = DEFAULT_MODEL, merge_other_agents_as_one_user: bool = True, **kwargs):
42 | """
43 | instantiate the OpenAIChat backend
44 | args:
45 | temperature: the temperature of the sampling
46 | max_tokens: the maximum number of tokens to sample
47 | model: the model to use
48 | merge_other_agents_as_one_user: whether to merge messages from other agents as one user message
49 | """
50 | assert is_openai_available, "openai package is not installed or the API key is not set"
51 | super().__init__(temperature=temperature, max_tokens=max_tokens, model=model,
52 | merge_other_agents_as_one_user=merge_other_agents_as_one_user, **kwargs)
53 |
54 | self.temperature = temperature
55 | self.max_tokens = max_tokens
56 | self.model = model
57 | self.merge_other_agent_as_user = merge_other_agents_as_one_user
58 |
59 | @retry(stop=stop_after_attempt(5), wait=wait_random_exponential(min=1, max=60)) # Modified retry strategy
60 | def _get_response(self, messages):
61 | completion = openai.ChatCompletion.create(
62 | model=self.model,
63 | messages=messages,
64 | temperature=self.temperature,
65 | max_tokens=self.max_tokens,
66 | stop=STOP
67 | )
68 |
69 | response = completion.choices[0]['message']['content']
70 | response = response.strip()
71 | return response
72 |
73 | def query(self, agent_name: str, role_desc: str, history_messages: List[Message], global_prompt: str = None,
74 | request_msg: Message = None, *args, **kwargs) -> str:
75 | """
76 | format the input and call the ChatGPT/GPT-4 API
77 | args:
78 | agent_name: the name of the agent
79 | role_desc: the description of the role of the agent
80 |             global_prompt: the global prompt (the description of the environment)
81 | history_messages: the history of the conversation, or the observation for the agent
82 | request_msg: the request from the system to guide the agent's next response
83 | """
84 |
85 | # Merge the role description and the global prompt as the system prompt for the agent
86 | if global_prompt: # Prepend the global prompt if it exists
87 | system_prompt = f"{global_prompt.strip()}\n\nYour name: {agent_name}\n\nYour role: {role_desc}"
88 | else:
89 | system_prompt = f"You are {agent_name}.\n\nYour role: {role_desc}"
90 |
91 | all_messages = [(SYSTEM_NAME, system_prompt)]
92 | for msg in history_messages:
93 | if msg.agent_name == SYSTEM_NAME:
94 | all_messages.append((SYSTEM_NAME, msg.content))
95 | else: # non-system messages are suffixed with the end of message token
96 | all_messages.append((msg.agent_name, f"{msg.content}{END_OF_MESSAGE}"))
97 |
98 | if request_msg is not None:
99 | all_messages.append((SYSTEM_NAME, request_msg.content))
100 |         else:  # The default request message that reminds the agent of its role and instructs it to speak
101 | all_messages.append((SYSTEM_NAME, f"Now you speak, {agent_name}.{END_OF_MESSAGE}"))
102 |
103 | messages = []
104 | for i, msg in enumerate(all_messages):
105 | if i == 0:
106 | assert msg[0] == SYSTEM_NAME # The first message should be from the system
107 | messages.append({"role": "system", "content": msg[1]})
108 | else:
109 | if msg[0] == agent_name:
110 | messages.append({"role": "assistant", "content": msg[1]})
111 | else:
112 | if messages[-1]["role"] == "user": # last message is from user
113 | if self.merge_other_agent_as_user:
114 | messages[-1]["content"] = f"{messages[-1]['content']}\n\n[{msg[0]}]: {msg[1]}"
115 | else:
116 | messages.append({"role": "user", "content": f"[{msg[0]}]: {msg[1]}"})
117 | elif messages[-1]["role"] == "assistant": # consecutive assistant messages
118 | # Merge the assistant messages
119 | messages[-1]["content"] = f"{messages[-1]['content']}\n{msg[1]}"
120 | elif messages[-1]["role"] == "system":
121 | messages.append({"role": "user", "content": f"[{msg[0]}]: {msg[1]}"})
122 | else:
123 | raise ValueError(f"Invalid role: {messages[-1]['role']}")
124 |
125 |         response = self._get_response(messages)
126 |
127 | # Remove the agent name if the response starts with it
128 | response = re.sub(rf"^\s*\[.*]:", "", response).strip()
129 | response = re.sub(rf"^\s*{re.escape(agent_name)}\s*:", "", response).strip()
130 | # Remove the tailing end of message token
131 | response = re.sub(rf"{END_OF_MESSAGE}$", "", response).strip()
132 |
133 | return response
134 |
--------------------------------------------------------------------------------
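For concreteness, a hand-worked sketch (written out by hand, not produced by running the code) of the list that `query()` assembles when agent `Alice` has one incoming message from `Bob` and the default request message, with `merge_other_agents_as_one_user=True`: other speakers are tagged `[name]:` and folded into a single user turn.

```python
messages = [
    {"role": "system",
     "content": "You are Alice.\n\nYour role: ..."},
    {"role": "user",
     "content": "[Bob]: Hi Alice!<EOS>\n\n[System]: Now you speak, Alice.<EOS>"},
]
```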
/chatarena/environments/conversation.py:
--------------------------------------------------------------------------------
1 | from typing import List, Union
2 |
3 | from .base import TimeStep, Environment
4 | from ..message import Message, MessagePool
5 | from ..agent import Moderator, SIGNAL_END_OF_CONVERSATION
6 | from ..config import EnvironmentConfig, AgentConfig
7 |
8 |
9 | class Conversation(Environment):
10 | """
11 | Turn-based fully observable conversation environment.
12 | Next speaker order is either parallel or round-robin.
13 | """
14 | type_name = "conversation"
15 |
16 | def __init__(self, player_names: List[str], parallel: bool = False, **kwargs):
17 | super().__init__(player_names=player_names, parallel=parallel, **kwargs)
18 |
19 | self.parallel = parallel
20 |
21 | # The "state" of the environment is maintained by the message pool
22 | self.message_pool = MessagePool()
23 |
24 | self._current_turn = 0
25 | self._next_player_idx = 0
26 |
27 | def reset(self):
28 | self._current_turn = 0
29 | self._next_player_idx = 0
30 | self.message_pool.reset()
31 |
32 | init_timestep = TimeStep(observation=[],
33 | reward=self.get_zero_rewards(),
34 | terminal=False)
35 | return init_timestep
36 |
37 | def to_config(self) -> EnvironmentConfig:
38 | return EnvironmentConfig(env_type=self.type_name, player_names=self.player_names, parallel=self.parallel)
39 |
40 | def print(self):
41 | self.message_pool.print()
42 |
43 | def get_next_player(self) -> str:
44 | """
45 | get the next player
46 | """
47 | return self.player_names[self._next_player_idx]
48 |
49 | def get_observation(self, player_name=None) -> List[Message]:
50 | """
51 | get observation for the player
52 | """
53 | if player_name is None:
54 | return self.message_pool.get_all_messages()
55 | else:
56 | return self.message_pool.get_visible_messages(player_name, turn=self._current_turn)
57 |
58 | def is_terminal(self) -> bool:
59 | """
60 | check if the conversation is over
61 | """
62 |         # The conversation is over if the last message is the end-of-conversation signal
63 |         last_message = self.message_pool.last_message
64 |         return last_message is not None and last_message.content == SIGNAL_END_OF_CONVERSATION
65 |
66 | def step(self, player_name: str, action: str) -> TimeStep:
67 | """
68 | step function that is called by the arena
69 | Args:
70 | player_name: the name of the player that takes the action
71 | action: the action that the agents wants to take
72 | """
73 | message = Message(agent_name=player_name, content=action, turn=self._current_turn)
74 | self.message_pool.append_message(message)
75 |
76 | # Update the counters
77 | if not self.parallel or self._next_player_idx == 0:
78 | self._current_turn += 1
79 | self._next_player_idx = (self._next_player_idx + 1) % self.num_players
80 |
81 | timestep = TimeStep(observation=self.get_observation(),
82 | reward=self.get_zero_rewards(),
83 | terminal=self.is_terminal()) # Return all the messages
84 | return timestep
85 |
86 |
87 | class ModeratedConversation(Conversation):
88 | """
89 | Turn-based fully observable conversation environment.
90 | Next speaker order is either parallel or round-robin.
91 | Moderator is a special agent that can see all messages and can decide whether the conversation is over.
92 | """
93 |
94 | type_name = "moderated_conversation"
95 |
96 | def __init__(self, player_names: List[str], moderator: Union[Moderator, AgentConfig],
97 | parallel: bool = False, moderator_visibility="all", moderator_period="turn", **kwargs):
98 |
99 | super().__init__(player_names=player_names, parallel=parallel, **kwargs)
100 |
101 | if isinstance(moderator, AgentConfig):
102 | moderator_config = moderator
103 | moderator = Moderator.from_config(moderator_config)
104 | elif not isinstance(moderator, Moderator):
105 | raise ValueError("moderator must be either an AgentConfig or a Moderator instance.")
106 |
107 | self.moderator = moderator
108 | self.moderator_visibility = moderator_visibility
109 | self.moderator_period = moderator_period
110 |
111 | def to_config(self) -> EnvironmentConfig:
112 |         # This environment contains some special config arguments that need to be handled specially
113 | return EnvironmentConfig(env_type=self.type_name, player_names=self.player_names, parallel=self.parallel,
114 | moderator=self.moderator.to_config(), moderator_visibility=self.moderator_visibility,
115 | moderator_period=self.moderator_period)
116 |
117 | def step(self, player_name: str, action: str) -> TimeStep:
118 | """
119 | step function that is called by the arena
120 | Args:
121 | player_name: the name of the player that takes the action
122 |             action: the action that the agent wants to take
123 | """
124 | message = Message(agent_name=player_name, content=action, turn=self._current_turn)
125 | self.message_pool.append_message(message)
126 |
127 | # Round-robin order for the next player
128 | self._next_player_idx = (self._next_player_idx + 1) % self.num_players
129 |
130 | if self.moderator_period == "turn" or \
131 | (self.moderator_period == "round" and self._next_player_idx == 0):
132 | # Moderator's turn
133 | moderator_history = self.message_pool.get_all_messages()
134 |
135 | # Moderator's response is not used
136 | #moderator_response = self.moderator(moderator_history)
137 | #moderator_message = Message(agent_name=self.moderator.name,
138 | # content=moderator_response,
139 | # turn=self._current_turn,
140 | # visible_to=self.moderator_visibility)
141 | #self.message_pool.append_message(moderator_message)
142 |
143 | # We only use Moderator to determine whether the conversation should be ended
144 | terminal = self.moderator.is_terminal(moderator_history) or self.is_terminal()
145 | else:
146 | terminal = self.is_terminal()
147 |
148 | # Update the counters
149 | if not self.parallel or self._next_player_idx == 0:
150 | self._current_turn += 1
151 |
152 | timestep = TimeStep(observation=self.get_observation(),
153 | reward=self.get_zero_rewards(),
154 | terminal=terminal) # Return all the messages
155 | return timestep
156 |
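157 | # Usage sketch (illustrative, commented out so the module stays import-only;
158 | # the player names "Alice" and "Bob" are hypothetical):
159 | #
160 | #   env = Conversation(player_names=["Alice", "Bob"])
161 | #   env.reset()
162 | #   env.step("Alice", "Hi Bob!")                      # round-robin: Bob speaks next
163 | #   timestep = env.step("Bob", SIGNAL_END_OF_CONVERSATION)
164 | #   assert timestep.terminal                          # is_terminal() picks up the signal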
--------------------------------------------------------------------------------
/chatarena/ui/cli.py:
--------------------------------------------------------------------------------
1 | from prompt_toolkit import prompt
2 | from prompt_toolkit.completion import WordCompleter
3 | from prompt_toolkit.styles import Style
4 | from rich.console import Console
5 | from rich.text import Text
6 | from rich.color import ANSI_COLOR_NAMES
7 | import random
8 |
9 | from ..arena import Arena, TooManyInvalidActions
10 | from ..backends.human import HumanBackendError
11 | from ..agent import SIGNAL_END_OF_CONVERSATION
12 |
13 | ASCII_ART = r"""
14 | _________ .__ __ _____
15 | \_ ___ \ | |__ _____ _/ |_ / _ \ _______ ____ ____ _____
16 | / \ \/ | | \ \__ \ \ __\ / /_\ \ \_ __ \W/ __ \ / \ \__ \
17 | \ \____| Y \ / __ \_ | | / | \ | | \/\ ___/ | | \ / __ \_
18 | \______ /|___| /(____ / |__| \____|__ / |__| \___ >|___| /(____ /
19 | \/ \/ \/ \/ \/ \/ \/
20 | """
21 |
22 | visible_colors = [color for color in ANSI_COLOR_NAMES.keys() if
23 | color not in ["black", "white", "red", "green"] and "grey" not in color]
24 |
25 | MAX_STEPS = 5
26 |
27 | import logging
28 |
29 | # Set logging level to ERROR
30 | logging.getLogger().setLevel(logging.ERROR)
31 |
32 |
33 | class ArenaCLI:
34 | """
35 | The CLI user interface for ChatArena.
36 | """
37 |
38 | def __init__(self, arena: Arena):
39 | self.arena = arena
40 |
41 | def launch(self, max_steps: int = None, interactive: bool = True, show_description: bool = True, show_message: bool = True):
42 | """
43 | Run the CLI
44 | """
45 | if not interactive and max_steps is None:
46 | max_steps = MAX_STEPS
47 |
48 | console = Console()
49 | # Print ascii art
50 | #console.print(ASCII_ART, style="bold dark_orange3")
51 | timestep = self.arena.reset()
52 | #console.print("🏟 Chat Arena Initialized!", style="bold green")
53 |
54 | env = self.arena.environment
55 | players = self.arena.players
56 |
57 | env_desc = self.arena.global_prompt
58 | num_players = env.num_players
59 | player_colors = random.sample(visible_colors, num_players) # sample different colors for players
60 | name_to_color = dict(zip(env.player_names, player_colors))
61 | # System and Moderator messages are printed in red
62 | name_to_color["System"] = "red"
63 | name_to_color["Moderator"] = "red"
64 |
65 | # Print the player name, role_desc and backend_type
66 | if show_description:
67 | console.print(f"[bold green underline]Environment ({env.type_name}) description:[/]\n{env_desc}")
68 |
69 | for i, player in enumerate(players):
70 | player_name = Text(f"[{player.name} ({player.backend.type_name})] Role Description:")
71 | player_name.stylize(f"bold {name_to_color[player.name]} underline")
72 | console.print(player_name)
73 | console.print(player.role_desc)
74 | if show_message:
75 | console.print("\n========= Arena Start! ==========\n", style="bold green")
76 |
77 | step = 0
78 | while not timestep.terminal:
79 | if interactive:
80 | command = prompt([('class:command', "command (n/r/q/s/h) > ")],
81 | style=Style.from_dict({'command': 'blue'}),
82 | completer=WordCompleter(
83 | ['next', 'n', 'reset', 'r', 'exit', 'quit', 'q', 'help', 'h', 'save', 's']))
84 | command = command.strip()
85 |
86 | if command == "help" or command == "h":
87 | console.print("Available commands:")
88 |                     console.print("  [bold]next or n or <Enter>[/]: next step")
89 | console.print(" [bold]exit or quit or q[/]: exit the game")
90 | console.print(" [bold]help or h[/]: print this message")
91 | console.print(" [bold]reset or r[/]: reset the game")
92 | console.print(" [bold]save or s[/]: save the history to file")
93 | continue
94 | elif command == "exit" or command == "quit" or command == "q":
95 | break
96 | elif command == "reset" or command == "r":
97 | timestep = self.arena.reset()
98 | console.print("\n========= Arena Reset! ==========\n", style="bold green")
99 | continue
100 | elif command == "next" or command == "n" or command == "":
101 | pass
102 | elif command == "save" or command == "s":
103 | # Prompt to get the file path
104 | file_path = prompt([('class:command', "save file path > ")],
105 | style=Style.from_dict({'command': 'blue'}))
106 | file_path = file_path.strip()
107 | # Save the history to file
108 | self.arena.save_history(file_path)
109 | # Print the save success message
110 | console.print(f"History saved to {file_path}", style="bold green")
111 | else:
112 | console.print(f"Invalid command: {command}", style="bold red")
113 | continue
114 |
115 | try:
116 | timestep = self.arena.step()
117 | except HumanBackendError as e:
118 | # Handle human input and recover with the game update
119 | human_player_name = env.get_next_player()
120 | if interactive:
121 | human_input = prompt(
122 | [('class:user_prompt', f"Type your input for {human_player_name}: ")],
123 | style=Style.from_dict({'user_prompt': 'ansicyan underline'})
124 | )
125 |                     # Feed the human input into the environment so the game can continue
126 | timestep = env.step(human_player_name, human_input)
127 | else:
128 | raise e # cannot recover from this error in non-interactive mode
129 | except TooManyInvalidActions as e:
130 | # Print the error message
131 | console.print(f"Too many invalid actions: {e}", style="bold red")
132 | break
133 |
134 | # The messages that are not yet logged
135 | messages = [msg for msg in env.get_observation() if not msg.logged]
136 | # Print the new messages
137 | for msg in messages:
138 | message_text = Text(f"[{msg.agent_name}->{msg.visible_to}]: {msg.content}")
139 | message_text.stylize(f"bold {name_to_color[msg.agent_name]}", 0,
140 | len(f"[{msg.agent_name}->{msg.visible_to}]:"))
141 | if show_message:
142 | console.print(message_text)
143 | msg.logged = True
144 |
145 | step += 1
146 | if max_steps is not None and step >= max_steps:
147 | break
148 |
149 | if show_message:
150 | console.print("\n========= Arena Ended! ==========\n", style="bold red")
151 |
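152 | # Usage sketch (illustrative; the config path is hypothetical):
153 | #
154 | #   from chatarena.arena import Arena
155 | #   arena = Arena.from_config("examples/topdial_arena.json")
156 | #   ArenaCLI(arena).launch(max_steps=10, interactive=False)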
--------------------------------------------------------------------------------
/chatarena/arena.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict, Union
2 | import uuid
3 | import json
4 | import csv
5 | import logging
6 |
7 | from .agent import Player
8 | from .environments import Environment, TimeStep, load_environment
9 | from .backends import Human
10 | from .config import ArenaConfig
11 |
12 |
13 | class TooManyInvalidActions(Exception):
14 | pass
15 |
16 |
17 | class Arena:
18 | """
19 | Utility class that manages the game environment and players
20 | """
21 |
22 | def __init__(self, players: List[Player], environment: Environment, global_prompt: str = None):
23 | # Create a container for the players and environment and reset the game
24 | self.players = players
25 | self.environment = environment
26 | self.global_prompt = global_prompt
27 |
28 | self.current_timestep = environment.reset()
29 | self.uuid = uuid.uuid4() # Generate a unique id for the game
30 | self.invalid_actions_retry = 5
31 |
32 | @property
33 | def num_players(self):
34 | return self.environment.num_players
35 |
36 | @property
37 | def name_to_player(self) -> Dict[str, Player]:
38 | return {player.name: player for player in self.players}
39 |
40 | def reset(self) -> TimeStep:
41 | # Reset the environment
42 | self.current_timestep = self.environment.reset()
43 | # Reset the players
44 | for player in self.players:
45 | player.reset()
46 | # Reset the uuid
47 | self.uuid = uuid.uuid4()
48 | return self.current_timestep
49 |
50 | def step(self) -> TimeStep:
51 | """
52 | Take a step in the game: one player takes an action and the environment updates
53 | """
54 | player_name = self.environment.get_next_player()
55 | player = self.name_to_player[player_name] # get the player object
56 | observation = self.environment.get_observation(player_name) # get the observation for the player
57 |
58 | timestep = None
59 | for i in range(self.invalid_actions_retry): # try to take an action for a few times
60 | action = player(observation) # take an action
61 | if self.environment.check_action(action, player_name): # action is valid
62 | timestep = self.environment.step(player_name, action) # update the environment
63 | break
64 | else: # action is invalid
65 | logging.warning(f"{player_name} made an invalid action {action}")
66 | continue
67 |
68 | if timestep is None: # if the player made invalid actions for too many times, terminate the game
69 | warning_msg = f"{player_name} has made invalid actions for {self.invalid_actions_retry} times. Terminating the game."
70 | logging.warning(warning_msg)
71 | raise TooManyInvalidActions(warning_msg)
72 |
73 | return timestep
74 |
75 | def next_is_human(self):
76 | """
77 | check if the next player is human
78 | """
79 | player_name = self.environment.get_next_player()
80 | player = self.name_to_player[player_name]
81 | return isinstance(player.backend, Human)
82 |
83 | def run(self, num_steps: int = 1):
84 | """
85 |         run the game for up to num_steps steps
86 | """
87 | for i in range(num_steps):
88 | timestep = self.step()
89 | if timestep.terminal:
90 | break
91 |
92 | @classmethod
93 | def from_config(cls, config: Union[str, ArenaConfig]):
94 | """
95 | create an arena from a config
96 | """
97 | # If config is a path, load the config
98 | if isinstance(config, str):
99 | config = ArenaConfig.load(config)
100 |
101 | global_prompt = config.get("global_prompt", None)
102 |
103 | # Create the players
104 | players = []
105 | for player_config in config.players:
106 |             # Add the global prompt to the player config
107 | if global_prompt is not None:
108 | player_config["global_prompt"] = global_prompt
109 |
110 | player = Player.from_config(player_config)
111 | players.append(player)
112 |
113 | # Check that the player names are unique
114 | player_names = [player.name for player in players]
115 | assert len(player_names) == len(set(player_names)), "Player names must be unique"
116 |
117 | # Create the environment
118 | config.environment["player_names"] = player_names # add the player names to the environment config
119 | env = load_environment(config.environment)
120 |
121 | return cls(players, env, global_prompt=global_prompt)
122 |
123 | def to_config(self) -> ArenaConfig:
124 | """
125 | convert the arena to a config
126 | """
127 | # return {
128 | # "players": [player.to_config() for player in self.players],
129 | # "environment": self.environment.to_config(),
130 | # "global_prompt": self.global_prompt
131 | # }
132 | return ArenaConfig(
133 | players=[player.to_config() for player in self.players],
134 | environment=self.environment.to_config(),
135 | global_prompt=self.global_prompt
136 | )
137 |
138 | def launch_cli(self, max_steps: int = None, interactive: bool = True, show_description: bool = True, show_message: bool = True):
139 | """
140 | launch the command line interface
141 | """
142 | from chatarena.ui.cli import ArenaCLI
143 | cli = ArenaCLI(self)
144 | cli.launch(max_steps=max_steps, interactive=interactive, show_description=show_description, show_message=show_message)
145 |
146 | def save_config(self, path: str):
147 | """
148 | save the config to a file
149 | """
150 | config = self.to_config()
151 | config.save(path)
152 |
153 | def save_history(self, path: str):
154 | """
155 | save the history of the game to a file
156 | Supports csv and json formats.
157 | """
158 | messages = self.environment.get_observation()
159 | message_rows = []
160 |
161 | if path.endswith(".csv"):
162 | header = ["agent_name", "content", "turn", "timestamp", "visible_to", "msg_type"]
163 | for message in messages:
164 | message_row = [
165 | message.agent_name,
166 | message.content,
167 | message.turn,
168 | str(message.timestamp),
169 | message.visible_to,
170 | message.msg_type,
171 | ]
172 | message_rows.append(message_row)
173 |
174 |             with open(path, "w", newline="") as f:  # newline="" avoids blank rows in csv output on Windows
175 | writer = csv.writer(f)
176 | writer.writerow(header)
177 | writer.writerows(message_rows)
178 | elif path.endswith(".json"):
179 | for message in messages:
180 | message_row = {
181 | "agent_name": message.agent_name,
182 | "content": message.content,
183 | "turn": message.turn,
184 | "timestamp": str(message.timestamp),
185 | "visible_to": message.visible_to,
186 | "msg_type": message.msg_type,
187 | }
188 | message_rows.append(message_row)
189 |
190 | with open(path, "w") as f:
191 | json.dump(message_rows, f, indent=4)
192 | else:
193 |             raise ValueError(f"Unsupported file format: {path} (expected .csv or .json)")
194 |
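195 | # Usage sketch (illustrative; "arena_config.json" is a hypothetical path):
196 | #
197 | #   arena = Arena.from_config("arena_config.json")
198 | #   arena.run(num_steps=10)
199 | #   arena.save_history("history.json")   # ".csv" is also supported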
--------------------------------------------------------------------------------
/instruction.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from typing import List
3 |
4 |
5 | def create_instruct(
6 | target: List[str],
7 | simulated_profile: dict,
8 | simulated_personality: dict,
9 | assistant_name: str,
10 | domain_knowledge: List[List],
11 | seed_conversation: dict
12 | ):
13 | """Create instructions about the conversation environment and roles."""
14 | domain = ""
15 | target_action = target[0].lower()
16 | if "movie" in target_action:
17 | domain = "movie"
18 | elif "music" in target_action:
19 | domain = "music"
20 | elif "food" in target_action:
21 | domain = "food"
22 | elif "poi" in target_action:
23 | domain = "poi"
24 | else:
25 | raise ValueError("Invalid target action: {}".format(target_action))
26 |
27 | # Describe the environment (shared by all roles)
28 | if domain == "movie" or domain == "music":
29 | env_desc = "You are participating in a conversation about music or movies."
30 | else:
31 | env_desc = "You are participating in a conversation about delicious food or point-of-interest (POI)."
32 |
33 | # Describe the user
34 | user_desc = "You are {}, ".format(simulated_profile["Name"])
35 | profile_desc = ""
36 |
37 | if simulated_profile["Occupation"] == "Student":
38 | if simulated_profile["Gender"] == "Male":
39 | profile_desc = "a male student in the age range of {}, living in {}".format(simulated_profile["Age Range"].lower(), simulated_profile["Residence"])
40 | else:
41 | profile_desc = "a female student in the age range of {}, living in {}".format(simulated_profile["Age Range"].lower(), simulated_profile["Residence"])
42 | elif simulated_profile["Occupation"] == "Employed":
43 | if simulated_profile["Gender"] == "Male":
44 | profile_desc = "a man in the age range of {}, working in a company and living in {}".format(simulated_profile["Age Range"].lower(), simulated_profile["Residence"])
45 | else:
46 | profile_desc = "a woman in the age range of {}, working in a company and living in {}".format(simulated_profile["Age Range"].lower(), simulated_profile["Residence"])
47 | else:
48 | if simulated_profile["Gender"] == "Male":
49 | profile_desc = "a retired man in the age range of {}, living in {}".format(simulated_profile["Age Range"].lower(), simulated_profile["Residence"])
50 | else:
51 | profile_desc = "a retired woman in the age range of {}, living in {}".format(simulated_profile["Age Range"].lower(), simulated_profile["Residence"])
52 | user_desc += profile_desc + ".\n\n"
53 |
54 | user_desc += "Based on your past experiences, you have the following preferences:\n"
55 | if domain == "movie" or domain == "music":
56 | for k in ["Accepted movies", "Accepted music", "Accepted celebrities", "Rejected movies", "Rejected music"]:
57 | kk = k.replace("Accepted", "liked").replace("Rejected", "disliked")
58 | user_desc += "Your {}: {}.\n".format(kk, simulated_profile[k]) if simulated_profile.get(k, "") != "" else ""
59 | else:
60 | for k in ["Accepted food", "Accepted POI"]:
61 | kk = k.replace("Accepted", "liked")
62 | user_desc += "Your {}: {}.\n".format(kk, simulated_profile[k]) if simulated_profile.get(k, "") != "" else ""
63 | user_desc += "\n"
64 |
65 |
66 | user_desc += "Based on the Big-5 personality traits, your personality is measured as:\n"
67 | for k, v in simulated_personality.items():
68 | user_desc += "For {}, you are {}.\n".format(k, v)
69 | user_desc += "\n"
70 |
71 | user_desc += "Your response should match your profile and personality, and be concise (no longer than 30 words).\n"
72 | user_desc += "You don't need to recommend anything, but feel free to express your personal interests."
73 |
74 | gender_desc = "his" if simulated_profile["Gender"] == "Male" else "her"
75 | if domain == "movie" or domain == "music":
76 | for k in ["Accepted movies", "Accepted music", "Accepted celebrities", "Rejected movies", "Rejected music"]:
77 | kk = k.replace("Accepted", "liked").replace("Rejected", "disliked")
78 | profile_desc += "; {} {}: {}".format(gender_desc, kk, simulated_profile[k]) if simulated_profile.get(k, "") != "" else ""
79 | else:
80 | for k in ["Accepted food", "Accepted POI"]:
81 | kk = k.replace("Accepted", "liked")
82 | profile_desc += "; {} {}: {}".format(gender_desc, kk, simulated_profile[k]) if simulated_profile.get(k, "") != "" else ""
83 | profile_desc += "."
84 |
85 | user_dict = {
86 | "name": simulated_profile["Name"],
87 | "role_desc": user_desc,
88 | }
89 |
90 | # Describe the assistant
91 | if domain == "movie":
92 | assistant_desc = "You are {}, a movie enthusiast who enjoys a variety of films.\n".format(assistant_name)
93 | elif domain == "music":
94 | assistant_desc = "You are {}, a music enthusiast who enjoys a variety of music.\n".format(assistant_name)
95 | elif domain == "food":
96 | assistant_desc = "You are {}, a foodie who enjoys delicious food.\n".format(assistant_name)
97 | elif domain == "poi":
98 | assistant_desc = "You are {}, a food enthusiast who is interested in exploring different restaurants.\n".format(assistant_name)
99 | else:
100 | raise ValueError("Invalid domain: {}".format(domain))
101 |
102 | assistant_desc += "You are conversing with {}, whose profile is below: \n## {}\n\n".format(simulated_profile["Name"], profile_desc)
103 | assistant_desc += "Your goal is to proactively lead the conversation with {} towards the target {} \"{}\".\n".format(simulated_profile["Name"], domain, target[1])
104 | assistant_desc += "To start the conversation, please begin with a greeting and avoid mentioning the target {}.\n".format(domain)
105 | assistant_desc += "As the conversation progresses, use your domain knowledge to steer the discussed topic towards the target {} step by step.\n".format(domain)
106 | assistant_desc += "Be informative and engaging while providing insights to arouse {}'s interest.\n".format(simulated_profile["Name"])
107 | assistant_desc += "Remember to ultimately recommend \"{}\" as the focus of the conversation.\n".format(target[1])
108 | assistant_desc += "Your words at each turn should be concise (no longer than 30 words).\n\n"
109 | assistant_desc += "You may access the following domain knowledge for conversation: \n## {}.".format(domain_knowledge)
110 |
111 | assistant_dict = {
112 | "name": assistant_name,
113 | "role_desc": assistant_desc,
114 | }
115 |
116 | # Describe the moderator
117 | moderator_desc = "You are the moderator of a conversation. You need to determine whether the discussion between Role-S and Role-U should come to an immediate end.\n"
118 | moderator_desc += "The conversation should conclude under the following two conditions:\n"
119 | moderator_desc += "(1) If Role-S completes {} recommendation on \"{}\" and Role-U accepts it, and Role-S no longer takes the initiative for two rounds.\n".format(domain, target[1])
120 | moderator_desc += "(2) If Role-U explicitly rejects Role-S's recommendation on \"{}\" when Role-S has tried to recommend it for the second time.\n".format(target[1])
121 | moderator_desc += "In either of these cases, the conversation should be brought to an immediate end.\n\n"
122 |
123 | moderator_desc += "For example, here is a conversation:\n## {}".format(seed_conversation["seed_continue"])
124 | moderator_desc += "Should the conversation end? The answer is no.\n\n"
125 | moderator_desc += "Here is another conversation:\n## {}".format(seed_conversation["seed_end"])
126 | moderator_desc += "Should the conversation end? The answer is yes."
127 |
128 |
129 | terminal_condition = "Now, for the above discussion between {} (Role-S) and {} (Role-U), should the conversation end? Answer yes or no.".format(assistant_name, simulated_profile["Name"])
130 |
131 | moderator_dict = {
132 | "role_desc": moderator_desc,
133 | "terminal_condition": terminal_condition
134 | }
135 |
136 | return (env_desc, user_dict, assistant_dict, moderator_dict)
137 |
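138 | # Usage sketch (illustrative; all field values below are hypothetical stubs --
139 | # real profiles, targets, and knowledge come from the preprocessed seed data):
140 | #
141 | #   env_desc, user_d, assistant_d, moderator_d = create_instruct(
142 | #       target=["Movie recommendation", "The Matrix"],
143 | #       simulated_profile={"Name": "Alice", "Gender": "Female", "Occupation": "Student",
144 | #                          "Age Range": "18-25", "Residence": "Beijing"},
145 | #       simulated_personality={"openness": "intellectual, imaginative, and curious"},
146 | #       assistant_name="Bot",
147 | #       domain_knowledge=[["The Matrix", "Stars", "Keanu Reeves"]],
148 | #       seed_conversation={"seed_continue": "...", "seed_end": "..."},
149 | #   )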
--------------------------------------------------------------------------------
/data_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import re
3 | import random
4 |
5 |
6 | def find_word_in_string(w, s):
7 | return re.compile(r"\b({0})\b".format(w), flags=re.IGNORECASE).search(s)
8 |
9 | def normalize_profile(profile: dict, domain: str):
10 |     """Normalize a profile based on the specific domain."""
11 | norm_profile = {}
12 | for slot_k, slot_value in profile.items():
13 | if slot_k == "Age Range":
14 | norm_profile[slot_k] = slot_value.replace("years old", "").strip()
15 | elif slot_k == "Accepted Music": # mismatched slot key in raw data
16 | if "Accepted music" in norm_profile.keys():
17 | norm_profile["Accepted music"] += "; " + slot_value
18 | else:
19 | norm_profile["Accepted music"] = slot_value
20 | elif slot_k == "Accepted movie": # mismatched slot key in raw data
21 | if "Accepted movies" in norm_profile.keys():
22 | norm_profile["Accepted movies"] += "; " + slot_value
23 | else:
24 | norm_profile["Accepted movies"] = slot_value
25 | else:
26 | if slot_k in norm_profile.keys():
27 | norm_profile[slot_k] += "; " + slot_value
28 | else:
29 | norm_profile[slot_k] = slot_value
30 |
31 | for slot_k, slot_v in norm_profile.items():
32 | if "Accepted" in slot_k or "Rejected" in slot_k:
33 | if len(slot_v.split("; ")) > 2:
34 | norm_profile[slot_k] = "; ".join(slot_v.split("; ")[:2])
35 |
36 | # remove unnecessary slots for a specific domain
37 | assert domain in ["movie", "music", "food", "poi"]
38 | if "Accepted news" in norm_profile.keys():
39 | norm_profile.pop("Accepted news")
40 | if "Favorite news" in norm_profile.keys():
41 | norm_profile.pop("Favorite news")
42 | if "Reject" in norm_profile.keys():
43 | norm_profile.pop("Reject")
44 |
45 | if domain == "food" or domain == "poi":
46 | if "Accepted movies" in norm_profile.keys():
47 | norm_profile.pop("Accepted movies")
48 | if "Accepted music" in norm_profile.keys():
49 | norm_profile.pop("Accepted music")
50 | if "Accepted celebrities" in norm_profile.keys():
51 | norm_profile.pop("Accepted celebrities")
52 | if "Rejected movies" in norm_profile.keys():
53 | norm_profile.pop("Rejected movies")
54 | if "Rejected music" in norm_profile.keys():
55 | norm_profile.pop("Rejected music")
56 | else:
57 | if "Accepted food" in norm_profile.keys():
58 | norm_profile.pop("Accepted food")
59 | if "Accepted POI" in norm_profile.keys():
60 | norm_profile.pop("Accepted POI")
61 |
62 | return norm_profile
63 |
64 |
65 | def sample_profile(profile_slots, target_topic, domain):
66 |     """Sample a profile whose slot values avoid the target topic."""
67 | sampled_profile = {}
68 | for slot_key, slot_values in profile_slots.items():
69 | sampled_value = random.choice(slot_values)
70 | while sampled_value in target_topic or target_topic in sampled_value:
71 | sampled_value = random.choice(slot_values)
72 | sampled_profile[slot_key] = sampled_value
73 | # check age range and occupation
74 | if sampled_profile["Age Range"] == "Under 18":
75 | sampled_profile["Occupation"] = "Student"
76 | elif sampled_profile["Age Range"] == "18-25" or sampled_profile["Age Range"] == "26-35":
77 | sampled_profile["Occupation"] = random.choice(["Student", "Employed"])
78 | elif sampled_profile["Age Range"] == "36-50":
79 | sampled_profile["Occupation"] = "Employed"
80 | else:
81 | sampled_profile["Occupation"] = random.choice(["Employed", "Retired"])
82 |
83 | normed_profile = normalize_profile(sampled_profile, domain)
84 | return normed_profile
85 |
86 |
87 | def check_kg_exceed(kg_list, max_len):
88 | limit_len = max_len - len(kg_list)
89 | kg_str = " ".join([" ".join(kg) for kg in kg_list])
90 | kg_tokens = kg_str.split(" ")
91 | if len(kg_tokens) > limit_len:
92 | return True
93 | else:
94 | return False
95 |
96 | def check_topic_covered(sampled_kg, topic_path):
97 | sampled_objs = set()
98 | for triple in sampled_kg:
99 | s, p, o = triple
100 | sampled_objs.add(s)
101 | sampled_objs.add(o)
102 | is_covered = True
103 | for t in topic_path:
104 | if t != "NULL" and t not in sampled_objs:
105 | is_covered = False
106 | break
107 | return is_covered
108 |
109 | def get_outer_kg(kg_list, sampled_kg, topic_path):
110 | topic_list = []
111 | for t in topic_path:
112 | if t != "NULL":
113 | topic_list.append(t)
114 |
115 | sampled_objs = set()
116 | for triple in sampled_kg:
117 | s, p, o = triple
118 | sampled_objs.add(s)
119 | sampled_objs.add(o)
120 |
121 | tmp_kg = {}
122 | for t in topic_list:
123 | if t not in sampled_objs:
124 | for triple in kg_list:
125 | s, p, o = triple
126 | if s == t:
127 | if s in tmp_kg:
128 | tmp_kg[s].append(triple)
129 | else:
130 | tmp_kg[s] = [triple]
131 | elif o == t:
132 | if o in tmp_kg:
133 | tmp_kg[o].append(triple)
134 | else:
135 | tmp_kg[o] = [triple]
136 | outer_kg = []
137 | for k, v_list in tmp_kg.items():
138 | spo = random.sample(v_list, 1)
139 | outer_kg.append(spo[0])
140 |
141 | return outer_kg
142 |
143 | def sample_knowledge(raw_kg_list, target, topic_path, user_utt="", bot_utt="", max_len=300):
144 | kg_list = []
145 | for kg in raw_kg_list:
146 | s, p, o = kg
147 | if p == "Stars":
148 | if len(o.split()) <= 40:
149 | kg_list.append(kg)
150 | else:
151 | kg_list.append(kg)
152 |
153 | topic_trans = []
154 | kg_topic_path = []
155 | for t in topic_path:
156 | if t != "NULL":
157 | kg_topic_path.append(t)
158 | if len(kg_topic_path) > 1:
159 | for j in range(1, len(kg_topic_path)-1):
160 | topic_trans.append([kg_topic_path[j], kg_topic_path[j-1]])
161 |
162 | sampled_kg = []
163 | for kg in kg_list:
164 | s, p, o = kg
165 |
166 | if target[0] == "Food recommendation" and target[1] == "Marinated Fish" and p == "Specials" and o == "Marinated Fish":
167 | pass
168 | elif s == target[1] or o == target[1]:
169 | if not kg in sampled_kg:
170 | sampled_kg.append(kg)
171 | if "℃" in o and "℃" in bot_utt:
172 | if not kg in sampled_kg:
173 | sampled_kg.append(kg)
174 | if p == "Perfect for having" and (o.lower() in user_utt.lower() or o.lower() in bot_utt.lower()):
175 | if not kg in sampled_kg:
176 | sampled_kg.append(kg)
177 | if p.lower() in user_utt.lower() or p.lower() in bot_utt.lower():
178 | if p == "Sings" and s == topic_path[0]:
179 | pass
180 | elif p == "Achievement" and s == topic_path[0]:
181 | if o.lower() in bot_utt.lower():
182 | if not kg in sampled_kg:
183 | sampled_kg.append(kg)
184 | elif p == "Awards" and s == topic_path[0]:
185 | if o.lower() in bot_utt.lower():
186 | if not kg in sampled_kg:
187 | sampled_kg.append(kg)
188 | else:
189 | if s == topic_path[0] or o == topic_path[0]:
190 | if not kg in sampled_kg:
191 | sampled_kg.append(kg)
192 | if s == topic_path[0]:
193 | if o.lower() in bot_utt.lower():
194 | if not kg in sampled_kg:
195 | sampled_kg.append(kg)
196 | for tp in topic_trans:
197 | src, tgt = tp
198 | if (src == s and tgt in o) or (src in o and tgt == s):
199 | if not kg in sampled_kg:
200 | sampled_kg.append(kg)
201 | # check which topic not in sampled knowledge
202 | outer_kg = get_outer_kg(kg_list, sampled_kg, topic_path)
203 | if len(outer_kg) > 0:
204 | sampled_kg += outer_kg
205 |
206 | noised_kg = []
207 | for kg in kg_list:
208 | if not kg in sampled_kg:
209 | noised_kg.append(kg)
210 | random.shuffle(noised_kg)
211 |
212 |     num_sampling = 1
213 |     tmp_kg = []
214 |     while True:
215 |         if num_sampling > len(noised_kg):
216 |             break
217 |         tmp_kg = random.sample(noised_kg, num_sampling)
218 |         check_kg = sampled_kg + tmp_kg
219 |         if check_kg_exceed(check_kg, max_len=max_len):
220 |             break
221 |         num_sampling += 1
222 | sampled_kg += tmp_kg[:-1]
223 | random.shuffle(sampled_kg)
224 |
225 | return sampled_kg
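226 |
227 | # Usage sketch for normalize_profile (illustrative; the raw slot values are
228 | # hypothetical):
229 | #
230 | #   raw = {"Age Range": "18-25 years old", "Accepted movie": "The Matrix",
231 | #          "Accepted food": "Hotpot"}
232 | #   normalize_profile(raw, domain="movie")
233 | #   # -> {"Age Range": "18-25", "Accepted movies": "The Matrix"}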
--------------------------------------------------------------------------------
/eval/eval_generation.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import argparse
4 | import json
5 | import numpy as np
6 | from collections import Counter
7 | import nltk
8 | from nltk.translate import bleu_score
9 | from nltk.translate.bleu_score import SmoothingFunction
10 |
11 |
12 | def calc_bleu(hyps, refs):
13 |     """Calculate BLEU-1/2 scores."""
14 | bleu_1 = []
15 | bleu_2 = []
16 | for hyp, ref in zip(hyps, refs):
17 | try:
18 | score = bleu_score.sentence_bleu(
19 | [ref], hyp,
20 | smoothing_function=SmoothingFunction().method1,
21 | weights=[1, 0, 0, 0])
22 |         except Exception:
23 | score = 0
24 | bleu_1.append(score)
25 | try:
26 | score = bleu_score.sentence_bleu(
27 | [ref], hyp,
28 | smoothing_function=SmoothingFunction().method1,
29 | weights=[0, 1, 0, 0])
30 |         except Exception:
31 | score = 0
32 | bleu_2.append(score)
33 | bleu_1 = np.average(bleu_1)
34 | bleu_2 = np.average(bleu_2)
35 | avg_bleu = (bleu_1 + bleu_2) / 2
36 | return bleu_1, bleu_2, avg_bleu
37 |
38 |
39 | def calc_knowledge_f1(hyps, knowledge_refs, knowledge_alls):
40 |     """Calculate knowledge F1 score."""
41 | golden_total = 0.0
42 | pred_total = 0.0
43 | hit_total = 0.0
44 | for response, golden_kd, all_kd in zip(hyps, knowledge_refs, knowledge_alls):
45 | golden_total += len(golden_kd)
46 | for kd in golden_kd:
47 | if is_obj_hit(response, kd):
48 | hit_total += 1
49 | for kd in all_kd:
50 | if is_obj_hit(response, kd):
51 | pred_total += 1
52 | p = hit_total / pred_total if pred_total > 0 else 0
53 | r = hit_total / golden_total if golden_total > 0 else 0
54 | f1 = 2 * p * r / (p + r) if p + r > 0 else 0
55 | return f1
56 |
57 |
58 | def calc_persona_f1(hyps, persona_refs, persona_alls):
59 |     """Calculate persona F1 score."""
60 | golden_total = 0.0
61 | pred_total = 0.0
62 | hit_total = 0.0
63 | for response, golden_persona, all_persona in zip(hyps, persona_refs, persona_alls):
64 | golden_total += len(golden_persona)
65 | for persona in golden_persona:
66 | if is_obj_hit(response, persona, threshold=0.8):
67 | hit_total += 1
68 | for persona in all_persona:
69 | if is_obj_hit(response, persona, threshold=0.8):
70 | pred_total += 1
71 | p = hit_total / pred_total if pred_total > 0 else 0
72 | r = hit_total / golden_total if golden_total > 0 else 0
73 | f1 = 2 * p * r / (p + r) if p + r > 0 else 0
74 | return f1
75 |
76 | def calc_succ(eval_fp, gold_fp):
77 | all_eval, all_gold = [], []
78 | with open(eval_fp, 'r', encoding='utf-8') as fr:
79 | for line in fr:
80 | sample = json.loads(line)
81 | all_eval.append(sample)
82 | with open(gold_fp, 'r', encoding='utf-8') as fr:
83 | for line in fr:
84 | raw_sample = json.loads(line)
85 | sample = {
86 | "id": raw_sample["id"],
87 | "target": raw_sample["target"],
88 | "response": raw_sample["response"]
89 | }
90 | all_gold.append(sample)
91 | assert len(all_eval) == len(all_gold)
92 |
93 | topic_hit, topic_total = 0, 0
94 | movie_hit, music_hit, poi_hit, food_hit = 0, 0, 0, 0
95 | movie_total, music_total, poi_total, food_total = 0, 0, 0, 0
96 |
97 | for idx, gold_sample in enumerate(all_gold):
98 | if gold_sample["target"][1].lower() in gold_sample["response"].lower():
99 | topic_total += 1
100 | eval_action = gold_sample["target"][0]
101 | eval_topic = gold_sample["target"][1]
102 |
103 | # eval target turn and neighboring turns
104 | eval_list = get_eval_response(idx, all_eval, all_gold)
105 |
106 | eval_topic = " ".join(nltk.word_tokenize(eval_topic))
107 | eval_list = [" ".join(nltk.word_tokenize(eval_response)) for eval_response in eval_list]
108 |
109 | if is_topic_hit(eval_topic, eval_list):
110 | topic_hit += 1
111 |
112 | if eval_action == "Movie recommendation":
113 | movie_total += 1
114 | if is_topic_hit(eval_topic, eval_list):
115 | movie_hit += 1
116 | elif eval_action == "Music recommendation" or eval_action == "Play music":
117 | music_total += 1
118 | if is_topic_hit(eval_topic, eval_list):
119 | music_hit += 1
120 | elif eval_action == "POI recommendation":
121 | poi_total += 1
122 | if is_topic_hit(eval_topic, eval_list):
123 | poi_hit += 1
124 | elif eval_action == "Food recommendation":
125 | food_total += 1
126 | if is_topic_hit(eval_topic, eval_list):
127 | food_hit += 1
128 | succ_rate = float(topic_hit) / topic_total
129 | movie_rec_sr = float(movie_hit) / movie_total
130 | music_rec_sr = float(music_hit) / music_total
131 | poi_rec_sr = float(poi_hit) / poi_total
132 | food_rec_sr = float(food_hit) / food_total
133 | print("Succ.: {:.2f}%".format(succ_rate*100))
134 | print("Succ.-Movie: {}/{} = {:.2f}%".format(movie_hit, movie_total, movie_rec_sr*100))
135 | print("Succ.-Music: {}/{} = {:.2f}%".format(music_hit, music_total, music_rec_sr*100))
136 | print("Succ.-POI: {}/{} = {:.2f}%".format(poi_hit, poi_total, poi_rec_sr*100))
137 | print("Succ.-Food: {}/{} = {:.2f}%".format(food_hit, food_total, food_rec_sr*100))
138 |
139 |
140 | def get_eval_response(idx, eval_samples, gold_samples):
141 | eval_list = [eval_samples[idx]["response"]]
142 | dialog_id = gold_samples[idx]["id"]
143 | if idx - 1 >= 0 and gold_samples[idx-1]["id"] == dialog_id:
144 | eval_list.append(eval_samples[idx-1]["response"])
145 | if idx + 1 < len(gold_samples) and gold_samples[idx+1]["id"] == dialog_id:
146 | eval_list.append(eval_samples[idx+1]["response"])
147 | return eval_list
148 |
149 | def is_topic_hit(topic, candidates):
150 | for cand in candidates:
151 | if topic.lower() in cand.lower():
152 | return True
153 | return False
154 |
155 | def is_obj_hit(utterance_toks, obj_str, threshold=0.55):
156 | utterance = " ".join(utterance_toks)
157 | flag = False
158 | if obj_str in utterance:
159 | flag = True
160 | else:
161 | # English word-level
162 | common = Counter(utterance.split()) & Counter(obj_str.split())
163 | # knowledge recall
164 | hit_char_total = sum(common.values())
165 | golden_char_total = len(obj_str)
166 | recall = hit_char_total / golden_char_total if golden_char_total > 0 else 0
167 | if recall >= threshold:
168 | flag = True
169 | return flag
170 |
171 | def label_knowledge(utterance_toks, kg_list, lower_case=True):
172 | gold_knowledge = []
173 | all_objs = set()
174 | for triple in kg_list:
175 | assert len(triple) == 3
176 | all_objs.add(triple[0].lower() if lower_case else triple[0])
177 | all_objs.add(triple[2].lower() if lower_case else triple[2])
178 | for obj in all_objs:
179 | if is_obj_hit(utterance_toks, obj):
180 | gold_knowledge.append(obj)
181 | all_objs = list(all_objs)
182 | return all_objs, gold_knowledge
183 |
184 | def label_persona(utterance_toks, persona_dict, lower_case=True):
185 | all_personas = []
186 | gold_persona = []
187 | for k, v in persona_dict.items():
188 | if v != '' and v != ' ':
189 | all_personas.append(v.lower() if lower_case else v)
190 | for persona in all_personas:
191 | if is_obj_hit(utterance_toks, persona, threshold=0.8):
192 | gold_persona.append(persona)
193 | return all_personas, gold_persona
194 |
195 |
196 | def load_data(fp, is_gold=False, lower_case=True):
197 | samples = []
198 | all_knowledges, gold_knowledges = [], []
199 | all_personas, gold_personas = [], []
200 | with open(fp, 'r', encoding='utf-8') as fr:
201 | for idx, line in enumerate(fr):
202 | sample = json.loads(line)
203 | response = sample["response"].lower() if lower_case else sample["response"]
204 | # English word-level
205 | sentence_toks = nltk.word_tokenize(response)
206 | samples.append(sentence_toks)
207 | if is_gold:
208 | knowledge = sample["knowledge"]
209 |                 all_k, gold_k = label_knowledge(sentence_toks, knowledge, lower_case=lower_case)
210 |                 all_knowledges.append(all_k)
211 |                 gold_knowledges.append(gold_k)
212 | persona = sample["user_profile"]
213 | all_p, gold_p = label_persona(sentence_toks, persona, lower_case=lower_case)
214 | all_personas.append(all_p)
215 | gold_personas.append(gold_p)
216 | if is_gold:
217 |         assert len(samples) == len(all_knowledges) == len(gold_knowledges) and \
218 |             len(samples) == len(all_personas) == len(gold_personas), \
219 |             "Mismatched numbers of samples and annotations"
220 | return (samples, all_knowledges, gold_knowledges, all_personas, gold_personas)
221 | else:
222 | return samples
223 |
224 |
225 | if __name__ == "__main__":
226 | parser = argparse.ArgumentParser()
227 | parser.add_argument("--eval_file", type=str)
228 | parser.add_argument("--gold_file", type=str)
229 | args = parser.parse_args()
230 |
231 | preds = load_data(args.eval_file)
232 |     refs, all_knowledges, ref_knowledges, all_personas, ref_personas = load_data(args.gold_file, is_gold=True)
233 | assert len(preds) == len(refs)
234 |
235 | # calculate bleu
236 | bleu1, bleu2, avg_bleu = calc_bleu(preds, refs)
237 |
238 | # calculate knowledge-F1
239 |     kg_f1 = calc_knowledge_f1(preds, ref_knowledges, all_knowledges)
240 |
241 | # calculate persona-F1
242 |     persona_f1 = calc_persona_f1(preds, ref_personas, all_personas)
243 |
244 | output_str = "Avg. BLEU: %.3f\n" % avg_bleu
245 | output_str += "Knowledge F1: %.2f%%\n" % (kg_f1 * 100)
246 | output_str += "Persona F1: %.2f%%" % (persona_f1 * 100)
247 |
248 | print(output_str)
249 |
250 | # calculate target success
251 | calc_succ(args.eval_file, args.gold_file)
252 |
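253 | # Usage sketch (illustrative; file names are hypothetical). Both files are JSONL
254 | # with one object per line: the eval file needs a "response" field, and the gold
255 | # file additionally needs "id", "target", "knowledge", and "user_profile".
256 | #
257 | #   python eval/eval_generation.py --eval_file output/eval.jsonl --gold_file data/gold.jsonl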
--------------------------------------------------------------------------------
/dialog_simulation.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import time
3 | import json
4 | import os
5 | import random
6 | import argparse
7 | from tqdm import tqdm
8 | from chatarena.agent import Player, Moderator
9 | from chatarena.backends import OpenAIChat
10 | from chatarena.environments.conversation import ModeratedConversation
11 | from chatarena.arena import Arena
12 | from data_utils import find_word_in_string
13 | from instruction import create_instruct
14 |
15 |
16 | def parse_args():
17 | parser = argparse.ArgumentParser()
18 |
19 | parser.add_argument("--cached_seed_path", type=str, required=True,
20 | help="The cached seed dialog file.")
21 | parser.add_argument("--profile_path", type=str, default="seed_dataset/caches/db_slot/slot_profiles.json",
22 | help="The user profile slot-values file.")
23 | parser.add_argument("--output_dir", type=str, default="data/TopDial",
24 | help="The output directory to save the simulated dialog data.")
25 |     parser.add_argument("--max_interaction_step", type=int, default=12,
26 | help="The max number of interaction steps, i.e., 2 * max rounds.")
27 | parser.add_argument("--model_name", type=str, default="gpt-3.5-turbo",
28 | help="The chat model to use.")
29 | parser.add_argument("--temperature", type=float, default=0.75,
30 | help="The temperature to use in sampling.")
31 | parser.add_argument("--max_system_tokens", type=int, default=100,
32 | help="The max number of tokens to generate for the system.")
33 | parser.add_argument("--max_user_tokens", type=int, default=80,
34 | help="The max number of tokens to generate for the user.")
35 | parser.add_argument("--max_moderator_tokens", type=int, default=10,
36 | help="The max number of tokens to generate for the moderator.")
37 | parser.add_argument("--show_description", type=str2bool, default="true",
38 | help="Whether to show the role description.")
39 | parser.add_argument("--show_message", type=str2bool, default="true",
40 | help="Whether to show the conversation messages.")
41 | parser.add_argument("--random_seed", type=int, default=42)
42 | return parser.parse_args()
43 |
44 | def str2bool(v):
45 | if v.lower() in ('true', 'yes', 't', 'y', '1'):
46 | return True
47 |     elif v.lower() in ('false', 'no', 'f', 'n', '0'):
48 | return False
49 | else:
50 | raise argparse.ArgumentTypeError("Unsupported value encountered.")
51 |
52 | def clean_utterance(s):
53 | s = s.strip()
54 | for start_str in ['[1]', '[2]', '[3]', '[4]', '[5]', '[6]', '[7]', '[8]', '[9]']:
55 | if s.startswith(start_str):
56 | s = s[len(start_str):].strip()
57 | return s
58 |
59 | def prompt_conversation(raw_goal, conversation):
60 | """Prompt the conversation context."""
61 | conversation_ctx = ""
62 | for idx, utt in enumerate(conversation):
63 | utt = clean_utterance(utt)
64 | if "User Initiative" in raw_goal:
65 | if idx % 2 == 0:
66 | conversation_ctx += f"[Role-U]: {utt}\n\n"
67 | else:
68 | conversation_ctx += f"[Role-S]: {utt}\n\n"
69 | else:
70 | if idx % 2 == 0:
71 | conversation_ctx += f"[Role-S]: {utt}\n\n"
72 | else:
73 | conversation_ctx += f"[Role-U]: {utt}\n\n"
74 | return conversation_ctx
75 |
76 | def sample_seed_conversation(raw_goal, conversation):
77 | """Sample seed conversations (continue | end)."""
78 | conv_lens = len(conversation)
79 |     continue_len = random.choice(range(1, max(2, int(conv_lens * 0.6))))  # guard against an empty range for short seeds
80 | conv_continue = prompt_conversation(raw_goal, conversation[:continue_len])
81 | conv_end = prompt_conversation(raw_goal, conversation)
82 | seed_conv = {
83 | "seed_continue": conv_continue,
84 | "seed_end": conv_end
85 | }
86 | return seed_conv
87 |
88 | def sample_assistant_role(profile_slots, user_profile):
89 | """Sample an assistant role."""
90 | all_names = profile_slots["Name"]
91 | user_name = user_profile["Name"]
92 | sampled_name = random.choice(all_names)
93 | while find_word_in_string(sampled_name, user_name):
94 | sampled_name = random.choice(all_names)
95 | return sampled_name
96 |
97 | def sample_personality():
98 | """Sample a personality based on Big Five personality traits."""
99 | personalities = {
100 | "agreeableness": ["trustworthy, straightforward, and generous", "unreliable, complicated, meager, and boastful"],
101 | "conscientiousness": ["efficient, organized, and careful", "inefficient, careless, and sloppy"],
102 | "extraversion": ["outgoing, energetic, and talkative", "shy, reserved, and quiet"],
103 | "neuroticism": ["sensitive, nervous, and insecure", "secure, confident, and calm"],
104 | "openness": ["intellectual, imaginative, and curious", "unimaginative, uncreative, and conventional"]
105 | }
106 | sampled_personality = {}
107 | for trait, values in personalities.items():
108 | sampled_personality[trait] = random.choice(values)
109 | return sampled_personality
110 |
111 |
112 | def generate_dialog_data(
113 | profile_path,
114 | seed_path,
115 | output_dir,
116 | max_interaction_step=10,
117 | model_name="gpt-3.5-turbo",
118 | temperature=0.75,
119 | max_system_tokens=100,
120 | max_user_tokens=80,
121 | max_moderator_tokens=10,
122 | show_description=True,
123 | show_message=True,
124 | ):
125 | """Generate dialog data from a seed dialog file."""
126 | profile_slots = json.load(open(profile_path, "r", encoding='utf-8'))
127 | print(f"Loaded user profiles with {len(profile_slots)} slot keys.")
128 |
129 | seed_dialogs = []
130 | with open(seed_path, "r", encoding='utf-8') as f:
131 | for line in f:
132 | seed_dialogs.append(json.loads(line))
133 | print(f"Loaded {len(seed_dialogs)} cached dialogs.")
134 |
135 | if not os.path.exists(output_dir):
136 | os.makedirs(output_dir)
137 | if "test_seen" in seed_path:
138 | output_path = os.path.join(output_dir, "dialogue_test_seen.jsonl")
139 | elif "test_unseen" in seed_path:
140 | output_path = os.path.join(output_dir, "dialogue_test_unseen.jsonl")
141 | elif "dev" in seed_path:
142 | output_path = os.path.join(output_dir, "dialogue_dev.jsonl")
143 | else:
144 | output_path = os.path.join(output_dir, "dialogue_train.jsonl")
145 |
146 | with open(output_path, "w", encoding='utf-8') as fw:
147 | for seed_dialog in tqdm(seed_dialogs):
148 | simulated_profile = seed_dialog["user_profile"]
149 | sampled_knowledge = seed_dialog["knowledge"]
150 | target = seed_dialog["target"]
151 |
152 | conversation = seed_dialog["seed_conversation"]
153 | seed_conv = sample_seed_conversation(seed_dialog["original_goal"], conversation)
154 |
155 | # randomly sample a personality
156 | simulated_personality = sample_personality()
157 | assistant_name = sample_assistant_role(profile_slots, simulated_profile)
158 |
159 | env_desc, user_dict, assistant_dict, moderator_dict = create_instruct(
160 | target=target,
161 | simulated_profile=simulated_profile,
162 | simulated_personality=simulated_personality,
163 | assistant_name=assistant_name,
164 | domain_knowledge=sampled_knowledge,
165 | seed_conversation=seed_conv
166 | )
167 | assistant = Player(
168 | name=assistant_dict["name"], backend=OpenAIChat(model=model_name, temperature=temperature, max_tokens=max_system_tokens),
169 | role_desc=assistant_dict["role_desc"], global_prompt=env_desc
170 | )
171 | user = Player(
172 | name=user_dict["name"], backend=OpenAIChat(model=model_name, temperature=temperature, max_tokens=max_user_tokens),
173 | role_desc=user_dict["role_desc"], global_prompt=env_desc
174 | )
175 | moderator = Moderator(
176 | backend=OpenAIChat(model=model_name, temperature=temperature, max_tokens=max_moderator_tokens),
177 | role_desc=moderator_dict["role_desc"], terminal_condition=moderator_dict["terminal_condition"]
178 | )
179 | # let assistant start the conversation
180 | env = ModeratedConversation(player_names=[p.name for p in [assistant, user]], moderator=moderator, moderator_period="round")
181 | arena = Arena(players=[assistant, user], environment=env, global_prompt=env_desc)
182 |
183 | arena.launch_cli(max_steps=max_interaction_step, show_description=show_description, show_message=show_message, interactive=False)
184 |
185 | #print("Save? (y/n)")
186 | #if input() == "n":
187 | # continue
188 |
189 | # save the simulated dialog to file
190 | messages = env.get_observation()
191 | simulated_convs = []
192 | for msg in messages:
193 | if msg.agent_name == assistant.name:
194 | utt = {"system": msg.content}
195 | else:
196 | utt = {"user": msg.content}
197 | simulated_convs.append(utt)
198 |
199 | write_line = {
200 | "id": "s_" + str(seed_dialog["id"]),
201 | "user_profile": simulated_profile,
202 | "user_personality": simulated_personality,
203 | "knowledge": sampled_knowledge,
204 | "target": target,
205 | "conversation": simulated_convs
206 | }
207 | fw.write(json.dumps(write_line, ensure_ascii=False) + "\n")
208 | fw.flush()
209 |
210 | print("Sleeping for 5 seconds...")
211 | time.sleep(5)
212 |
213 | #print("Continue? (y/n)")
214 | #if input() == "n":
215 | # break
216 |
217 |
218 | if __name__ == '__main__':
219 | args = parse_args()
220 | random.seed(args.random_seed)
221 |
222 | generate_dialog_data(args.profile_path, args.cached_seed_path, args.output_dir,
223 | max_interaction_step=args.max_interaction_step,
224 | model_name=args.model_name,
225 | temperature=args.temperature,
226 | max_system_tokens=args.max_system_tokens,
227 | max_user_tokens=args.max_user_tokens,
228 | max_moderator_tokens=args.max_moderator_tokens,
229 | show_description=args.show_description,
230 | show_message=args.show_message)
231 |
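232 | # Usage sketch (illustrative; the cache path follows the naming used by
233 | # data_preprocess.py, and OPENAI_API_KEY is assumed to be set for OpenAIChat):
234 | #
235 | #   python dialog_simulation.py \
236 | #       --cached_seed_path seed_dataset/caches/cache_seed_dialogue_train.jsonl \
237 | #       --output_dir data/TopDial --max_interaction_step 12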
--------------------------------------------------------------------------------
/data_preprocess.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import json
3 | import os
4 | import random
5 | import argparse
6 | from tqdm import tqdm
7 | from py2neo import Graph
8 | from data_utils import normalize_profile, sample_profile, sample_knowledge
9 |
10 |
11 | def parse_args():
12 | parser = argparse.ArgumentParser()
13 |
14 | parser.add_argument(
15 | "--seed_dataset_dir",
16 | type=str,
17 | default="seed_dataset/DuRecDial2",
18 | help="The seed dataset directory."
19 | )
20 | parser.add_argument(
21 | "--cache_dir",
22 | type=str,
23 | default="seed_dataset/caches",
24 | help="The cached data directory."
25 | )
26 | parser.add_argument(
27 | "--num_instance_per_seed",
28 | type=int,
29 | default=3,
30 | help="The number of instances to curate for each seed dialog.",
31 | )
32 | parser.add_argument(
33 | "--random_seed",
34 | type=int,
35 | default=42,
36 | )
37 | return parser.parse_args()
38 |
39 |
40 | def extract_profile(data_fp_list, save_fp=None):
41 | """Extract all user profile slots from the given data file."""
42 |
43 | SLOT_KEYS = [
44 | "Age Range", "Name", "Gender", "Residence", "Occupation", "POI",
45 | "Accepted movies", "Accepted music", "Accepted celebrities", "Accepted food", "Accepted POI",
46 | "Reject", "Rejected movies", "Rejected music"
47 | ]
48 | ALL_SLOTS = dict()
49 | for k in SLOT_KEYS:
50 | ALL_SLOTS[k] = set()
51 |
52 | for data_fp in data_fp_list:
53 | with open(data_fp, 'r', encoding='utf-8') as fp:
54 | for line in fp:
55 | sample = json.loads(line.strip())
56 | for slot in sample['user_profile']:
57 | if slot in SLOT_KEYS:
58 | slot_value = list(sample['user_profile'][slot].split("; "))
59 | for v in slot_value:
60 | if slot == "Age Range":
61 | v = v.replace("years old", "").strip()
62 | ALL_SLOTS[slot].add(v)
63 | elif slot == "Accepted Music":
64 | slot_value = list(sample['user_profile'][slot].split("; "))
65 | for v in slot_value:
66 | ALL_SLOTS["Accepted music"].add(v)
67 | elif slot == "Accepted movie":
68 | slot_value = list(sample['user_profile'][slot].split("; "))
69 | for v in slot_value:
70 | ALL_SLOTS["Accepted movies"].add(v)
71 | else:
72 | print("Out of slot keys: ", slot)
73 | for k in ALL_SLOTS:
74 | ALL_SLOTS[k] = list(ALL_SLOTS[k])
75 | print(k, len(ALL_SLOTS[k]))
76 | if save_fp is not None:
77 | with open(save_fp, 'w', encoding='utf-8') as fp:
78 | json.dump(ALL_SLOTS, fp, indent=4, ensure_ascii=False)
79 | print("Saved to {}".format(save_fp))
80 |
81 |
82 | def exe_query(graph: Graph, query: str):
83 | triple_dict = {}
84 | results = graph.run(query).data()
85 | for res in results:
86 | s = "{}".format(res['s.value'])
87 | r = "{}".format(res['type(r)'])
88 | o = "{}".format(res['o.value'])
89 | kk = "{}__REL__{}".format(s, r)
90 | if kk in triple_dict.keys():
91 | triple_dict[kk].append(o)
92 | else:
93 | triple_dict[kk] = [o]
94 | triples = []
95 | for kk, vv in triple_dict.items():
96 | s, r = kk.split("__REL__")
97 | o = random.choice(vv)
98 | triples.append([s, r, o])
99 |
100 | return triples
101 |
102 | def ground_knowledge(graph, data_fp_list, profile_fp, save_dir, num_instance_per_seed=3):
103 | """Ground seed dialogs with domain knowledge and comments."""
104 |
105 | profile_slots = json.load(open(profile_fp, "r", encoding='utf-8'))
106 | print(f"Loaded user profiles with {len(profile_slots)} slot keys.")
107 |
108 | for data_fp in data_fp_list:
109 | seed_dialogs = []
110 | with open(data_fp, "r", encoding='utf-8') as f:
111 | for line in f:
112 | seed_dialogs.append(json.loads(line))
113 | print(f"Loaded {len(seed_dialogs)} seed dialogs from {data_fp}.")
114 |
115 | save_fp = os.path.join(save_dir, "cache_{}".format(data_fp.split("/")[-1]))
116 | with open(save_fp, "w", encoding='utf-8') as fw:
117 | for seed_dialog in tqdm(seed_dialogs):
118 | user_profile = seed_dialog["user_profile"]
119 | knowledge = seed_dialog["knowledge_graph"]
120 | target = seed_dialog["target"]
121 |
122 | domain = ""
123 | target_action = target[0].lower()
124 | if "movie" in target_action:
125 | domain = "movie"
126 | elif "music" in target_action:
127 | domain = "music"
128 | elif "food" in target_action:
129 | domain = "food"
130 | elif "poi" in target_action:
131 | domain = "poi"
132 | else:
133 | raise ValueError("Invalid target action: {}".format(target_action))
134 |
135 | for idx in range(num_instance_per_seed):
136 | if idx == 0:
137 | # adopt raw user profile
138 | simulated_profile = normalize_profile(user_profile, domain)
139 | else:
140 | # sample a profile different from raw user profile
141 | simulated_profile = sample_profile(profile_slots, target_topic=target[1], domain=domain)
142 |
143 | sampled_knowledge = sample_knowledge(knowledge, target, topic_path=seed_dialog["topic_path"], max_len=300)
144 |
145 | # sample comment about the target topic
146 | query_t = 'MATCH (s)-[r]->(o) WHERE s.value="{}" AND type(r)="{}" RETURN s.value, type(r), o.value'.format(target[1], "Comments")
147 | target_comments = exe_query(graph, query_t)
148 | if len(target_comments) > 0:
149 | target_comment = random.choice(target_comments)
150 | sampled_knowledge.append(target_comment)
151 |
152 | profile_knowledge = []
153 | for slot_key, slot_value in simulated_profile.items():
154 | if "movies" in slot_key or "music" in slot_key:
155 | # sample domain knowledge about movies/music
156 | entities = slot_value.split("; ")
157 | for ent in entities:
158 | query_t = 'MATCH (s)-[r]->(o) WHERE s.value="{}" AND (type(r)="{}" OR type(r)="{}" OR type(r)="{}" OR type(r)="{}") RETURN s.value, type(r), o.value'.format(
159 | ent, "Stars", "Sings", "Type", "Comments")
160 | triples = exe_query(graph, query_t)
161 | if len(triples) > 0:
162 | ss_triples = random.choice(triples)
163 | profile_knowledge.append(ss_triples)
164 | elif "celebrities" in slot_key:
165 | # sample domain knowledge about celebrities
166 | entities = slot_value.split("; ")
167 | for ent in entities:
168 | query_t = 'MATCH (s)-[r]->(o) WHERE s.value="{}" AND (type(r)="{}" OR type(r)="{}" OR type(r)="{}") RETURN s.value, type(r), o.value'.format(
169 | ent, "Intro", "Achievement", "Comments")
170 | triples = exe_query(graph, query_t)
171 | if len(triples) > 0:
172 | ss_triples = random.choice(triples)
173 | profile_knowledge.append(ss_triples)
174 | elif "food" in slot_key or "Accepted POI" in slot_key:
175 | # sample domain knowledge about food/POI
176 | entities = slot_value.split("; ")
177 | for ent in entities:
178 | query_t = 'MATCH (s)-[r]->(o) WHERE s.value="{}" AND (type(r)="{}" OR type(r)="{}" OR type(r)="{}" OR type(r)="{}") RETURN s.value, type(r), o.value'.format(
179 | ent, "Price per person", "Rating", "Address", "Comments")
180 | triples = exe_query(graph, query_t)
181 | if len(triples) > 0:
182 | ss_triples = random.choice(triples)
183 | profile_knowledge.append(ss_triples)
184 |                     seen_triples = {"__SEP__".join(triple) for triple in sampled_knowledge}
185 |                     for triple in profile_knowledge:
186 |                         if "__SEP__".join(triple) not in seen_triples:
187 |                             sampled_knowledge.append(triple)
188 |                             seen_triples.add("__SEP__".join(triple))
189 |
190 | new_dialog = {
191 | "id": str(seed_dialog["id"]) + "_{}".format(idx),
192 | "original_goal": seed_dialog["original_goal"],
193 | "user_profile": simulated_profile,
194 | "knowledge": sampled_knowledge,
195 | "target": target,
196 | "seed_conversation": seed_dialog["conversation"],
197 | "seed_action_path": seed_dialog["action_path"],
198 | "seed_topic_path": seed_dialog["topic_path"],
199 | }
200 | line = json.dumps(new_dialog, ensure_ascii=False)
201 | fw.write(line + "\n")
202 | fw.flush()
203 | print("Saved {} simulated dialogs to {}.".format(num_instance_per_seed * len(seed_dialogs), save_fp))
204 |
205 |
206 | if __name__ == "__main__":
207 | args = parse_args()
208 | random.seed(args.random_seed)
209 |
210 | train_fp = os.path.join(args.data_dir, "seed_dialogue_train.jsonl")
211 | dev_fp = os.path.join(args.data_dir, "seed_dialogue_dev.jsonl")
212 |     test_seen_fp = os.path.join(args.data_dir, "seed_dialogue_test_seen.jsonl")
213 | test_unseen_fp = os.path.join(args.data_dir, "seed_dialogue_test_unseen.jsonl")
214 |
215 | if not os.path.exists(args.cache_dir):
216 | os.makedirs(args.cache_dir)
217 |
218 | # prepare user profile slots
219 | saved_dir = os.path.join(args.cache_dir, "db_slot")
220 | if not os.path.exists(saved_dir):
221 | os.makedirs(saved_dir)
222 |
223 | saved_profile_fp = os.path.join(saved_dir, "slot_profiles.json")
224 | if not os.path.exists(saved_profile_fp):
225 | print("Extracting user profile slot-values...")
226 | extract_profile(data_fp_list=[train_fp, dev_fp, test_seen_fp, test_unseen_fp], save_fp=saved_profile_fp)
227 | else:
228 | print("File exists: {}".format(saved_profile_fp))
229 |
230 | # prepare domain knowledge and topic-related comments
231 | # set neo4j database connection (username: neo4j, password: neo4j)
232 | graph = Graph("http://localhost:7474", auth=("neo4j", "neo4j"))
233 | ground_knowledge(graph, data_fp_list=[train_fp, dev_fp, test_seen_fp, test_unseen_fp],
234 | profile_fp=saved_profile_fp, save_dir=args.cache_dir,
235 | num_instance_per_seed=args.num_instance_per_seed)
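236 |
237 | # Usage sketch (flag names are inferred from the attributes used above; the
238 | # actual flags are defined by parse_args() earlier in this file):
239 | #   1. start a local Neo4j server preloaded with the domain knowledge graph
240 | #      (HTTP port 7474, auth neo4j/neo4j);
241 | #   2. python data_preprocess.py --data_dir <seed-data-dir> --cache_dir <cache-dir> \
242 | #      --random_seed <int> --num_instance_per_seed <int>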
--------------------------------------------------------------------------------
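For reference, the comment-lookup step in ground_knowledge() can be reproduced in isolation. The sketch below is illustrative rather than part of the pipeline: it assumes the same local Neo4j instance as the script's main block (http://localhost:7474, auth neo4j/neo4j) is running and already populated, and the topic string is a placeholder. Judging by how its results are joined with "__SEP__" elsewhere in the file, exe_query() returns rows in a comparable [subject, relation, object] form.

from py2neo import Graph
import random

# connect to the local Neo4j database, mirroring the script's main block
graph = Graph("http://localhost:7474", auth=("neo4j", "neo4j"))

# fetch "Comments" triples for a target topic (placeholder value)
topic = "example_topic"
query = ('MATCH (s)-[r]->(o) WHERE s.value="{}" AND type(r)="{}" '
         'RETURN s.value, type(r), o.value').format(topic, "Comments")

# py2neo returns one dict per row, keyed by the RETURN expressions
rows = graph.run(query).data()
triples = [[row["s.value"], row["type(r)"], row["o.value"]] for row in rows]

# pick one comment at random, as ground_knowledge() does
if triples:
    print(random.choice(triples))
--------------------------------------------------------------------------------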
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------