├── .gitignore ├── LICENSE ├── README.md ├── agents ├── __init__.py ├── history.py ├── modules.py └── selfconscious_blender.py ├── assets └── figure1.png ├── environment.yml ├── eval_dnli.py ├── eval_personachat.py ├── modules ├── __init__.py └── dnli_bert.py └── tasks ├── __init__.py ├── build.py ├── teachers.py ├── test_persona_map.pkl ├── train_persona_map.pkl └── valid_persona_map.pkl /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .cenv 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Hyunwoo Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pragmatic Self-Consciousness
for Improving Persona Consistency in Dialogues 2 | 3 | ![figure](assets/figure1.png) 4 | 5 | **Official PyTorch implementation of our EMNLP paper:**
6 | [Hyunwoo Kim](https://hyunw.kim), [Byeongchang Kim](https://bckim92.github.io), and [Gunhee Kim](https://vision.snu.ac.kr/gunhee). Will I Sound Like Me? Improving Persona Consistency in Dialogues through Pragmatic Self-Consciousness. _EMNLP_, 2020 [[Paper]](https://arxiv.org/abs/2004.05816) 7 | 8 | * **TL;DR**: Inspired by social cognition and pragmatics, we model _public self-consciousness_ in existing dialogue agents with an imaginary listener to improve consistency. Compared to previous works, our method does not require additional consistency-related labels or training. 9 | 10 | An earlier version of this work was also accepted at the ICLR 2020 [Bridging AI and Cognitive Science (BAICS) workshop](https://baicsworkshop.github.io/) as an oral presentation. 11 | 12 | 13 | ## Reference 14 | 15 | If you use the materials in this repository as part of any published research, we ask you to cite the following [paper](https://arxiv.org/abs/2004.05816): 16 | 17 | ```bibtex 18 | @inproceedings{Kim:2020:selfc, 19 | title={Will I Sound Like Me? Improving Persona Consistency in Dialogues through Pragmatic Self-Consciousness}, 20 | author={Kim, Hyunwoo and Kim, Byeongchang and Kim, Gunhee}, 21 | booktitle={EMNLP}, 22 | year=2020 23 | } 24 | ``` 25 | 26 | ### Have any questions? 27 | Please contact [Hyunwoo Kim](https://hyunw.kim) at hyunw.kim@vl.snu.ac.kr. 28 | 29 | ## Implementation 30 | 31 | ### System Requirements 32 | 33 | * Python 3.6.8 34 | * PyTorch 1.6.0 35 | * A CUDA 10.1-compatible GPU with at least 12GB of memory 36 | * See [environment.yml](https://github.com/skywalker023/pragmatic-consistency/blob/master/environment.yml) for details 37 | 38 | ### Environment setup 39 | 40 | Our code is built on the [ParlAI](https://parl.ai/) framework.
41 | We recommend creating a conda environment as follows: 42 | 43 | ```bash 44 | conda env create -f environment.yml 45 | ``` 46 | 47 | and activating it with 48 | 49 | ```bash 50 | conda activate pragmatic-consistency 51 | ``` 52 | 53 | ## Running Experiments 54 | 55 | ### Self-conscious Blender for its persona 56 | 57 | #### Dialogue NLI 58 | 59 | ```bash 60 | python eval_dnli.py --conscious-target self -t tasks.teachers:SelfConsciousDialogueTeacher --model agents.selfconscious_blender:SelfConsciousBlenderAgent --fp16 false 61 | ``` 62 | 63 | #### PersonaChat 64 | 65 | ```bash 66 | python eval_personachat.py --conscious-target self -t tasks.teachers:SelfConsciousDialogueTeacher --model agents.selfconscious_blender:SelfConsciousBlenderAgent --batchsize 48 --fp16 false 67 | ``` 68 | 69 | ### Self-conscious Blender for its context 70 | 71 | #### Dialogue NLI 72 | 73 | ```bash 74 | python eval_dnli.py --conscious-target context -t tasks.teachers:ContextConsciousDialogueTeacher --model agents.selfconscious_blender:SelfConsciousBlenderAgent --fp16 false 75 | ``` 76 | 77 | #### PersonaChat 78 | 79 | ```bash 80 | python eval_personachat.py --conscious-target context -t tasks.teachers:ContextConsciousDialogueTeacher --model agents.selfconscious_blender:SelfConsciousBlenderAgent --batchsize 48 --fp16 false 81 | ``` 82 | 83 | 💡 To run the evaluation with vanilla Blender as is, set `--conscious-target` to `none`. 84 | 85 | 86 | ## Acknowledgements 87 | 88 | We would like to thank [Reuben Cohn-Gordon](https://reubencohngordon.com/), [Sean Welleck](https://cs.nyu.edu/~welleck/), [Junhyug Noh](https://junhyug.github.io/) and [Jiwan Chung](https://vl.snu.ac.kr/people/jiwanchung.html) for their valuable comments. We also thank the anonymous reviewers for their thoughtful suggestions on this work.
89 | 90 | This research was supported by Brain Research Program by National Research Foundation of Korea (NRF) (2017M3C7A1047860), Institute of Information \& communications Technology Planning \& Evaluation (IITP) grant funded by the Korea government (MSIT) (No. 2017-0-01772, Video Turing Test, No. 2019-0-01082, SW StarLab), and Creative Pioneering Researchers Program through Seoul National University. 91 | 92 | 93 | ## License 94 | 95 | This repository is MIT licensed. See the [LICENSE](https://github.com/skywalker023/pragmatic-consistency/blob/master/LICENSE) file for details. 96 | -------------------------------------------------------------------------------- /agents/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skywalker023/pragmatic-consistency/0a9165914a2009f3d3de5b4c4f37bcd745dc2886/agents/__init__.py -------------------------------------------------------------------------------- /agents/history.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | This file is derived from parlai/core/seq2seq/seq2seq.py. 9 | In particular, it's derived from an older version that inherits from TorchAgent rather 10 | than TorchGeneratorAgent. 11 | It should be possible to refactor this file to be comparable to the current 12 | parlai/core/seq2seq/seq2seq.py, i.e. inherit from TorchGeneratorAgent - this would 13 | probably reduce the amount of boilerplate in this file. 14 | However, for simplicity and to keep things as similar as possible to the version used 15 | for the paper, we have kept this file mostly the same. 
16 | """ 17 | 18 | from parlai.core.torch_agent import Batch, History, TorchAgent 19 | from parlai.core.torch_generator_agent import TorchGeneratorAgent 20 | from parlai.utils.torch import padded_tensor, argsort 21 | # from .base_controllable_seq2seq import BaseControllableSeq2seqAgent 22 | # from .util import ConvAI2History 23 | # from .controls import get_ctrl_vec 24 | 25 | import torch 26 | import torch.nn as nn 27 | import torch.nn.functional as F 28 | 29 | from collections import defaultdict, namedtuple, Counter, deque 30 | from operator import attrgetter 31 | 32 | import os 33 | import math 34 | import json 35 | import tempfile 36 | import copy 37 | from itertools import chain 38 | 39 | 40 | def list_to_matrix(l, n): 41 | return [l[i:i+n] for i in range(0, len(l), n)] 42 | 43 | 44 | class SelfConsciousHistory(History): 45 | def __init__(self, *args, **kwargs): 46 | super().__init__(*args, **kwargs) 47 | opt = args[0] 48 | if opt['eval_type'] == 'convai2': 49 | self.add_person_tokens = True 50 | elif opt['eval_type'] == 'dnli': 51 | self.add_person_tokens = False 52 | else: 53 | raise ValueError 54 | 55 | self.world_cardinality = opt.get('world_cardinality', 5) 56 | self.history_distractor_strings = [[] for _ in range(self.world_cardinality)] 57 | self.history_distractor_raw_strings = [[] for _ in range(self.world_cardinality)] 58 | self.history_distractor_vecs = [[] for _ in range(self.world_cardinality)] 59 | # Will be used for TransferTransfo 60 | self.history_token_type_ids = [] 61 | self.history_distractor_token_type_ids = [[] for _ in range(self.world_cardinality)] 62 | 63 | def reset(self): 64 | """Clear the history""" 65 | super().reset() 66 | self.history_distractor_strings = [[] for _ in range(self.world_cardinality)] 67 | self.history_distractor_raw_strings = [[] for _ in range(self.world_cardinality)] 68 | self.history_distractor_vecs = [[] for _ in range(self.world_cardinality)] 69 | self.history_token_type_ids = [] 70 | 
self.history_distractor_token_type_ids = [[] for _ in range(self.world_cardinality)] 71 | 72 | def _update_distractor_strings(self, text, idx): 73 | history_strings = self.history_distractor_strings[idx] 74 | if self.size > 0: 75 | while len(history_strings) >= self.size: 76 | history_strings.pop(0) 77 | history_strings.append(text) 78 | 79 | def _update_distractor_raw_strings(self, text, idx): 80 | history_raw_strings = self.history_distractor_raw_strings[idx] 81 | if self.size > 0: 82 | while len(history_raw_strings) >= self.size: 83 | history_raw_strings.pop(0) 84 | history_raw_strings.append(text) 85 | 86 | def _update_distractor_vecs(self, text, idx): 87 | history_vecs = self.history_distractor_vecs[idx] 88 | if self.size > 0: 89 | while len(history_vecs) >= self.size: 90 | history_vecs.pop(0) 91 | history_vecs.append(self.parse(text)) 92 | 93 | def _update_token_type_ids(self, text, idx): 94 | pass 95 | 96 | def add_reply_to_distractors(self, model_reply): 97 | 98 | # Update model's response to the history 99 | if model_reply is not None: 100 | for idx in range(self.world_cardinality): 101 | self._update_distractor_raw_strings(model_reply, idx) 102 | # this is causing the repetition of p2 token. 103 | # need to do this only once. not every loop 104 | if self.add_person_tokens and idx == 0: 105 | model_reply = self._add_person_tokens(model_reply, self.p2_token) 106 | self._update_distractor_strings(model_reply, idx) 107 | self._update_distractor_vecs(model_reply, idx) 108 | 109 | # def update_history(self, obs, add_next=None): 110 | def update_history(self, obs, temp_history=None): 111 | """ 112 | Update the history with the given observation. 
113 | :param temp_history: 114 | optional temporary string appended to the end of the history for 115 | this observation only; it is not kept for the next turn 116 | """ 117 | # super().update_history(obs, add_next) 118 | super().update_history(obs, temp_history=temp_history) 119 | 120 | # Update previous turn's my response 121 | # if add_next is not None: 122 | # for idx in range(self.world_cardinality): 123 | # self._update_distractor_raw_strings(add_next, idx) 124 | # # this is causing the repetition of p2 token. 125 | # # need to do this only once. not every loop 126 | # if self.add_person_tokens and idx == 0: 127 | # add_next = self._add_person_tokens(add_next, self.p2_token) 128 | # self._update_distractor_strings(add_next, idx) 129 | # self._update_distractor_vecs(add_next, idx) 130 | 131 | # Update current turn's opponent's response 132 | if 'distractor_text' in obs: 133 | assert len(obs['distractor_text']) == self.world_cardinality, \ 134 | f"Number of distractor_text must be equal to world_cardinality. ({len(obs['distractor_text'])} vs {self.world_cardinality})" 135 | for idx, distractor_text in enumerate(obs['distractor_text']): 136 | if self.split_on_newln: 137 | next_texts = distractor_text.split('\n') 138 | else: 139 | next_texts = [distractor_text] 140 | for text in next_texts: 141 | self._update_distractor_raw_strings(text, idx) 142 | if self.add_person_tokens: 143 | text = self._add_person_tokens( 144 | distractor_text, self.p1_token, self.add_p1_after_newln 145 | ) 146 | self._update_distractor_strings(text, idx) 147 | self._update_distractor_vecs(text, idx) 148 | 149 | def get_history_distractor_str(self): 150 | """Return the list of string versions of the distractor histories.""" 151 | if len(self.history_distractor_strings[0]) > 0: 152 | return [ 153 | self.delimiter.join(history_strings) 154 | for history_strings in self.history_distractor_strings 155 | ] 156 | return None 157 | 158 | def get_history_distractor_vec(self): 159 | """Return a vectorized version of the distractor histories.""" 160 | if
len(self.history_distractor_vecs[0]) == 0: 161 | return None 162 | 163 | histories = [] 164 | for idx in range(self.world_cardinality): 165 | history_vecs = self.history_distractor_vecs[idx] 166 | 167 | # if self.vec_type == 'deque': 168 | # history = deque(maxlen=self.max_len) 169 | # for vec in history_vecs[:-1]: 170 | # history.extend(vec) 171 | # history.extend(self.delimiter_tok) 172 | # history.extend(history_vecs[-1]) 173 | # else: 174 | # vec type is a list 175 | history = [] 176 | for vec in history_vecs[:-1]: 177 | history += vec 178 | history += self.delimiter_tok 179 | history += history_vecs[-1] 180 | 181 | histories.append(history) 182 | return histories 183 | 184 | def get_token_type_ids(self): 185 | """ 186 | Return a vectorized version of the token_type_ids and 187 | distractor_token_type_ids 188 | """ 189 | pass 190 | 191 | 192 | class ContextConsciousHistory(History): 193 | def __init__(self, *args, **kwargs): 194 | super().__init__(*args, **kwargs) 195 | opt = args[0] 196 | if opt['eval_type'] == 'convai2': 197 | self.add_person_tokens = True 198 | elif opt['eval_type'] == 'dnli': 199 | self.add_person_tokens = False 200 | else: 201 | raise ValueError 202 | 203 | self.world_cardinality = opt.get('world_cardinality', 5) 204 | self.history_distractor_strings = [[] for _ in range(self.world_cardinality)] 205 | self.history_distractor_raw_strings = [[] for _ in range(self.world_cardinality)] 206 | self.history_distractor_vecs = [[] for _ in range(self.world_cardinality)] 207 | # Will be used for TransferTransfo 208 | self.history_token_type_ids = [] 209 | self.history_distractor_token_type_ids = [[] for _ in range(self.world_cardinality)] 210 | self.eval_type = opt.get('eval_type') 211 | 212 | def reset(self): 213 | """Clear the history""" 214 | super().reset() 215 | self.history_distractor_strings = [[] for _ in range(self.world_cardinality)] 216 | self.history_distractor_raw_strings = [[] for _ in range(self.world_cardinality)] 217 | 
self.history_distractor_vecs = [[] for _ in range(self.world_cardinality)] 218 | self.history_token_type_ids = [] 219 | self.history_distractor_token_type_ids = [[] for _ in range(self.world_cardinality)] 220 | 221 | def _update_distractor_strings(self, text, idx): 222 | history_strings = self.history_distractor_strings[idx] 223 | if self.size > 0: 224 | while len(history_strings) >= self.size: 225 | history_strings.pop(0) 226 | history_strings.append(text) 227 | 228 | def _update_distractor_raw_strings(self, text, idx): 229 | history_raw_strings = self.history_distractor_raw_strings[idx] 230 | if self.size > 0: 231 | while len(history_raw_strings) >= self.size: 232 | history_raw_strings.pop(0) 233 | history_raw_strings.append(text) 234 | 235 | def _update_distractor_vecs(self, text, idx): 236 | history_vecs = self.history_distractor_vecs[idx] 237 | if self.size > 0: 238 | while len(history_vecs) >= self.size: 239 | history_vecs.pop(0) 240 | history_vecs.append(self.parse(text)) 241 | 242 | def _update_token_type_ids(self, text, idx): 243 | pass 244 | 245 | def add_reply_to_distractors(self, model_reply, obs=None): 246 | 247 | # Update model's response along with distractor responses to the history 248 | if model_reply is not None and 'distractor_text' in obs: 249 | distractor_responses = obs['distractor_text'] 250 | assert len(obs['distractor_text']) == self.world_cardinality 251 | 252 | for idx in range(self.world_cardinality): 253 | self._update_distractor_raw_strings(distractor_responses[idx], idx) 254 | if self.add_person_tokens: 255 | distractor_responses[idx] = self._add_person_tokens(distractor_responses[idx], self.p2_token) 256 | self._update_distractor_strings(distractor_responses[idx], idx) 257 | self._update_distractor_vecs(distractor_responses[idx], idx) 258 | 259 | # def update_history(self, obs, add_next=None): 260 | def update_history(self, obs, temp_history=None): 261 | """ 262 | Update the history with the given observation. 
263 | :param temp_history: 264 | optional temporary string appended to the end of the history for 265 | this observation only; it is not kept for the next turn 266 | """ 267 | super().update_history(obs, temp_history=temp_history) 268 | 269 | # Update current turn's opponent's response 270 | if self.eval_type == 'convai2': 271 | if 'text' in obs: 272 | for idx in range(self.world_cardinality): 273 | if self.split_on_newln: 274 | next_texts = obs['text'].split('\n') 275 | else: 276 | next_texts = [obs['text']] 277 | for text in next_texts: 278 | self._update_distractor_raw_strings(text, idx) 279 | if self.add_person_tokens: 280 | text = self._add_person_tokens( 281 | obs['text'], self.p1_token, self.add_p1_after_newln 282 | ) 283 | self._update_distractor_strings(text, idx) 284 | self._update_distractor_vecs(text, idx) 285 | else: 286 | if 'distractor_text' in obs: 287 | distractor_texts = obs['distractor_text'] 288 | for idx, distractor in enumerate(distractor_texts): 289 | self._update_distractor_raw_strings(distractor, idx) 290 | self._update_distractor_strings(distractor, idx) 291 | self._update_distractor_vecs(distractor, idx) 292 | 293 | def get_history_distractor_str(self): 294 | """Return the list of string versions of the distractor histories.""" 295 | if len(self.history_distractor_strings[0]) > 0: 296 | return [ 297 | self.delimiter.join(history_strings) 298 | for history_strings in self.history_distractor_strings 299 | ] 300 | return None 301 | 302 | def get_history_distractor_vec(self): 303 | """Return a vectorized version of the distractor histories.""" 304 | if len(self.history_distractor_vecs[0]) == 0: 305 | return None 306 | 307 | histories = [] 308 | for idx in range(self.world_cardinality): 309 | history_vecs = self.history_distractor_vecs[idx] 310 | 311 | # if self.vec_type == 'deque': 312 | # history = deque(maxlen=self.max_len) 313 | # for vec in history_vecs[:-1]: 314 | # history.extend(vec) 315 | # history.extend(self.delimiter_tok) 316 | # history.extend(history_vecs[-1]) 317 | # else: 318 | #
vec type is a list 319 | history = [] 320 | for vec in history_vecs[:-1]: 321 | history += vec 322 | history += self.delimiter_tok 323 | history += history_vecs[-1] 324 | 325 | histories.append(history) 326 | return histories 327 | 328 | def get_token_type_ids(self): 329 | """ 330 | Return a vectorized version of the token_type_ids and 331 | distractor_token_type_ids 332 | """ 333 | pass 334 | -------------------------------------------------------------------------------- /agents/modules.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """ 7 | Implements NN code for transformers. 8 | 9 | Original paper: https://arxiv.org/abs/1706.03762. (Vaswani, 2017). The 10 | `Annotated Transformer` (Rush, 2018) is an excellent reading guide which explains 11 | much of the mechanics of the Transformer model 12 | (http://nlp.seas.harvard.edu/2018/04/03/attention.html). 13 | 14 | This module also supports special segments (ala BERT; 15 | https://arxiv.org/abs/1810.04805), and a few different variations seen in the 16 | literature (BERT and XLM; https://arxiv.org/abs/1901.07291). 17 | """ 18 | 19 | import math 20 | from typing import Dict, Tuple, Optional 21 | 22 | import numpy as np 23 | import torch 24 | import torch.cuda 25 | import torch.nn as nn 26 | import torch.nn.functional as F 27 | 28 | from parlai.core.torch_generator_agent import TorchGeneratorModel 29 | from parlai.agents.transformer.modules import TransformerGeneratorModel 30 | 31 | 32 | class SelfConsciousTransformerModel(TransformerGeneratorModel): 33 | """ 34 | Implements a full transformer generator model, with pragmatic self-consciousness. 
35 | """ 36 | 37 | def __init__(self, opt, dictionary): 38 | super().__init__(opt, dictionary) 39 | 40 | self.alpha = 0.0 if opt['conscious_target'] == 'none' else opt['alpha'] 41 | self.beta = opt['beta'] 42 | self.world_cardinality = opt['world_cardinality'] 43 | self.worldprior = opt['worldprior'] 44 | self.target_persona = 0 45 | self.fp16 = opt['fp16'] 46 | 47 | def _initialize_worldpriors(self, bsz, seqlen): 48 | """ 49 | initialize the world prior with a uniform distribution 50 | """ 51 | cardinality = self.world_cardinality 52 | torch_dtype=torch.half if self.fp16 else torch.float 53 | ones = torch.ones(1, seqlen, cardinality, dtype=torch_dtype, requires_grad=False).cuda() 54 | uniform_world_prior = torch.log(ones / cardinality) 55 | world_priors = uniform_world_prior.repeat(bsz, 1, 1).detach() 56 | 57 | return world_priors 58 | 59 | def _pragmatic_reasoning(self, s0_t, worldprior): 60 | """ 61 | run pragmatic reasoning with the base speaker and its imaginary listener 62 | """ 63 | 64 | vocab_size = self.embeddings.num_embeddings 65 | 66 | # log-scale 67 | log_score = nn.functional.log_softmax(s0_t, dim=2) 68 | log_score = log_score.squeeze() # (bpsz, vocab) 69 | 70 | # (bsz, world_cardinality, vocab) 71 | log_score = log_score.view(-1, self.world_cardinality, vocab_size) 72 | 73 | # S_0 for L_1 74 | _literal_speaker = log_score.clone() 75 | _literal_speaker, _literal_s_next_token_idxs = torch.max(_literal_speaker, dim=-1, keepdim=True) 76 | 77 | # S_0 for the actual given persona (bsz, vocab) 78 | speaker_prior = log_score.select(1, self.target_persona) # target persona is always index 0 79 | 80 | # S_0 for L_0 81 | # (bsz, vocab, world_cardinality) 82 | log_score = log_score.transpose(dim0=1, dim1=2).contiguous() 83 | log_score = log_score * self.beta 84 | 85 | # L_0 \propto S_0 * p(i) 86 | # worldprior should be broadcasted to all the tokens 87 | # (bsz, vocab, world_cardinality) 88 | listener_posterior = (log_score + worldprior) - 
torch.logsumexp(log_score + worldprior, 2, keepdim=True) 89 | 90 | # (bsz, vocab) 91 | listener_score = listener_posterior.select(2, self.target_persona) # target persona is always index 0 92 | listener_score = listener_score * self.alpha 93 | 94 | speaker_posterior = (listener_score + speaker_prior) - torch.logsumexp(listener_score + speaker_prior, 1, keepdim=True) 95 | 96 | # need to unsqueeze in the dimension 1 97 | speaker_posterior = speaker_posterior.unsqueeze(1) # (bsz, 1, vocab) 98 | 99 | # L_0 for L_1 100 | _literal_listener = listener_posterior.transpose(dim0=1, dim1=2).contiguous() 101 | _literal_listener = torch.gather(_literal_listener, -1, _literal_s_next_token_idxs) 102 | 103 | pragmatic_listener = (_literal_speaker + _literal_listener) - torch.logsumexp(_literal_speaker + _literal_listener, 1, keepdim=True) 104 | pragmatic_listener = pragmatic_listener.squeeze() 105 | 106 | return speaker_posterior, listener_posterior, pragmatic_listener 107 | 108 | def selfconscious_decode(self, encoder_states, maxlen): 109 | """ 110 | greedy decoding with pragmatic self-consciousness 111 | """ 112 | bpsz = encoder_states[0].size(0) 113 | bsz = bpsz // self.world_cardinality 114 | 115 | inputs_t = self.START.detach().expand(bpsz, 1) 116 | worldpriors = self._initialize_worldpriors(bsz, maxlen).detach() 117 | 118 | s1_scores = [] 119 | incr_state = None 120 | 121 | for t in range(maxlen): 122 | worldprior_t = worldpriors.select(1, t).unsqueeze(1) 123 | 124 | latent, incr_state = self.decoder(inputs_t, encoder_states, incr_state) 125 | _logits = self.output(latent) 126 | # only get the last timestep's logit 127 | s0_t = _logits.select(dim=1, index=-1).unsqueeze(1) # logits shape: (bpsz, 1, vocab) 128 | 129 | # s1_t: (bsz, 1, vocab) 130 | # listener_posterior: (bsz, vocab, world_cardinality) 131 | s1_t, l0_t, l1_t = self._pragmatic_reasoning(s0_t, worldprior_t) 132 | s1_scores.append(s1_t) 133 | 134 | next_token = s1_t.max(2)[1].clone().detach() # next input is 
current predicted output idx 135 | 136 | idx_for_tile = torch.arange(bsz).repeat(self.world_cardinality, 1).transpose(0, 1).reshape(-1).cuda() 137 | inputs_next_t = torch.index_select(next_token, 0, idx_for_tile) 138 | next_token = next_token.unsqueeze(2) 139 | tiled_next_token = next_token.repeat(1, 1, self.world_cardinality) 140 | 141 | if self.worldprior != 'uniform': 142 | # (bsz, vocab, world_cardinality) -> (bsz, 1, world_cardinality) 143 | updated_world_prior = torch.gather(l0_t, 1, tiled_next_token).clone().detach() 144 | if t + 1 < maxlen: 145 | if self.worldprior == 'L0': 146 | worldpriors[:, t + 1, :] = updated_world_prior.squeeze() 147 | elif self.worldprior == 'L1': 148 | worldpriors[:, t + 1, :] = l1_t 149 | else: 150 | raise NotImplementedError 151 | 152 | # update inputs for next timestep 153 | inputs_t = torch.cat((inputs_t, inputs_next_t), dim=1) 154 | 155 | s1_scores = torch.cat(s1_scores, dim=1) # (bsz, seqlen, vocab) 156 | _, preds = s1_scores.max(dim=2) 157 | 158 | return preds, s1_scores 159 | 160 | def selfconscious_decode_forced(self, encoder_states, ys): 161 | """ 162 | faster teacher-forced decoding with pragmatic self-consciousness 163 | """ 164 | 165 | bsz = ys.size(0) 166 | seqlen = ys.size(1) 167 | self.longest_label = max(self.longest_label, seqlen) 168 | emb_size = self.encoder.embedding_size 169 | enc_outputs = encoder_states[0].view(bsz * self.world_cardinality, -1, emb_size).contiguous() 170 | enc_outputs_mask = encoder_states[1].view(bsz * self.world_cardinality, -1).contiguous() 171 | enc_states = (enc_outputs, enc_outputs_mask) 172 | bpsz = enc_outputs.size(0) 173 | 174 | # tile ys as much as the world_cardinality 175 | idx_for_tile = torch.arange(bsz).repeat(self.world_cardinality, 1).transpose(0, 1).reshape(-1).cuda() 176 | tiled_ys = torch.index_select(ys, 0, idx_for_tile) 177 | 178 | inputs = tiled_ys.narrow(1, 0, seqlen - 1) 179 | inputs = torch.cat([self.START.detach().expand(bpsz, 1), inputs], 1) 180 | worldpriors = 
self._initialize_worldpriors(bsz, seqlen).detach() 181 | s1_scores = [] 182 | 183 | latent, _ = self.decoder(inputs, enc_states) 184 | base_speaker = self.output(latent) 185 | 186 | for t in range(seqlen): 187 | 188 | s0_t = base_speaker.select(dim=1, index=t).unsqueeze(1) # s0_t: (bpsz, 1, vocab) 189 | worldprior_t = worldpriors.select(dim=1, index=t).unsqueeze(1) 190 | 191 | # s1_t: (bsz, 1, vocab) 192 | # l0_t: (bsz, vocab, world_cardinality) 193 | s1_t, l0_t, l1_t = self._pragmatic_reasoning(s0_t, worldprior_t) 194 | s1_scores.append(s1_t) 195 | 196 | # Update world_prior with listener posterior 197 | if t + 1 < seqlen: 198 | next_tokens = inputs.select(1, t + 1).view(-1, 1) # (bpsz, 1): the next tokens for each bpsz instance 199 | next_tokens = next_tokens.unsqueeze(2) 200 | # [0, 1*world_cardinality, 2*wc, 3*wc, ..., bpsz - 1wc] -> to get the ground-truth personas 201 | target_persona_idxs = torch.arange(bsz).cuda() * (self.world_cardinality) 202 | 203 | # we only need the next token of the ground-truth persona 204 | next_token = torch.index_select(next_tokens, 0, target_persona_idxs) # (bsz, 1, 1) 205 | tiled_next_token = next_token.repeat(1, 1, self.world_cardinality) # (bsz, 1, world_cardinality) 206 | 207 | if self.worldprior != 'uniform': 208 | # (bsz, vocab, world_cardinality) -> (bsz, 1, world_cardinality) 209 | updated_world_prior = torch.gather(l0_t, 1, tiled_next_token).clone().detach() 210 | if self.worldprior == 'L0': 211 | worldpriors[:, t + 1, :] = updated_world_prior.squeeze() 212 | elif self.worldprior == 'L1': 213 | worldpriors[:, t + 1, :] = l1_t 214 | else: 215 | raise NotImplementedError 216 | 217 | s1_scores = torch.cat(s1_scores, 1) # (bsz, seqlen, vocab) 218 | _, preds = s1_scores.max(dim=2) 219 | 220 | return s1_scores, preds 221 | -------------------------------------------------------------------------------- /agents/selfconscious_blender.py: -------------------------------------------------------------------------------- 1 | import 
os 2 | import random 3 | import numpy as np 4 | from itertools import chain 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | from parlai.core.opt import Opt 10 | from parlai.core.message import Message 11 | from parlai.core.torch_agent import Batch, Output 12 | from parlai.core.torch_generator_agent import PPLMetric 13 | from parlai.core.metrics import SumMetric, AverageMetric 14 | from parlai.utils.torch import padded_tensor 15 | from parlai.utils.misc import warn_once 16 | from parlai.agents.transformer.transformer import ( 17 | TransformerGeneratorAgent, 18 | add_common_cmdline_args 19 | ) 20 | 21 | from agents.modules import SelfConsciousTransformerModel 22 | from modules.dnli_bert import DnliBert 23 | from agents.history import SelfConsciousHistory, ContextConsciousHistory 24 | 25 | 26 | def list_to_matrix(l, n): 27 | return [l[i:i+n] for i in range(0, len(l), n)] 28 | 29 | 30 | class SelfConsciousBlenderAgent(TransformerGeneratorAgent): 31 | """ 32 | Implementation of the Self-Conscious Blender Agent. 33 | """ 34 | 35 | @classmethod 36 | def add_cmdline_args(cls, argparser): 37 | """ 38 | Add command-line arguments specifically for this agent. 
39 | """ 40 | agent = argparser.add_argument_group('Self-conscious Blender Arguments') 41 | agent.add_argument( 42 | '--conscious-target', 43 | type=str, 44 | choices=['none', 'self', 'context'], 45 | default='self', 46 | help='The target which the agent will be concerned about.', 47 | ) 48 | agent.add_argument( 49 | '-a', 50 | '--alpha', 51 | type=float, 52 | default=0, 53 | help='Rationality parameter for S_1(speaker_1)', 54 | ) 55 | agent.add_argument( 56 | '-b', 57 | '--beta', 58 | type=float, 59 | default=1, 60 | help='Rationality parameter for Listener', 61 | ) 62 | agent.add_argument( 63 | '--world_cardinality', 64 | type=int, 65 | default=3, 66 | help='Cardinality of world I := number of personas used in the RSA model (including the ground truth)', 67 | ) 68 | agent.add_argument( 69 | '--worldprior', 70 | type=str, 71 | choices=['uniform', 'L0', 'L1'], 72 | default='L0', 73 | help='Update world prior with a `uniform` distribution or `L0` or `L1`.', 74 | ) 75 | agent.add_argument( 76 | '--use_dnli', 77 | type='bool', 78 | default=True, 79 | help='Whether to use the DNLI model to measure the consistency score in ConvAI2 or to rerank candidates in DNLI' 80 | ) 81 | add_common_cmdline_args(agent) 82 | cls.dictionary_class().add_cmdline_args(argparser) 83 | 84 | super(SelfConsciousBlenderAgent, cls).add_cmdline_args(argparser) 85 | return agent 86 | 87 | def __init__(self, opt: Opt, shared=None): 88 | 89 | self.task = str.lower(opt['task'].split(':')[-1]) 90 | 91 | if opt['conscious_target'] != 'none': 92 | assert opt['conscious_target'] in self.task, \ 93 | "conscious_target (`" + opt['conscious_target'] + "`) must match task type (`" + self.task + "`)" 94 | 95 | SEED = 46 96 | random.seed(SEED) 97 | np.random.seed(SEED) 98 | os.environ['PYTHONHASHSEED'] = str(SEED) 99 | torch.random.manual_seed(SEED) 100 | torch.cuda.manual_seed(SEED) 101 | torch.manual_seed(SEED) 102 | torch.cuda.manual_seed_all(SEED) 103 | torch.backends.cudnn.deterministic = True 104 | torch.backends.cudnn.benchmark = False
105 | 106 | # For public self-consciousness 107 | self.target_persona = opt.get('target_persona', 0) 108 | self.conscious_target = opt.get('conscious_target', 'self') 109 | self.world_cardinality = opt.get('world_cardinality', 3) 110 | self.alpha = 0.0 if self.conscious_target == 'none' else opt.get('alpha', 2.0) 111 | self.beta = opt.get('beta', 1.0) 112 | self.worldprior = opt.get('worldprior', 'L0') 113 | 114 | self.eval_type = opt.get('eval_type') 115 | # self.rank_candidates = opt.get('rank_candidates', True) 116 | self.multigpu = ( 117 | opt.get('multigpu', False) and self.use_cuda and (opt.get('batchsize') > 1) 118 | ) 119 | 120 | init_model, is_finetune = self._get_init_model(opt, shared) 121 | super().__init__(opt, shared) 122 | 123 | # Implementation is based on beam_size 1 124 | self.beam_size = 1 125 | warn_once('This implementation assumes a beam size of 1.') 126 | 127 | # Always rank candidates for the ranking metrics 128 | self.rank_candidates = True 129 | warn_once('rank-candidates is always True for the ranking metrics.') 130 | 131 | if opt['use_dnli']: 132 | if not shared: 133 | self.dnli_model = DnliBert(opt, use_cuda=self.use_cuda) 134 | else: 135 | self.dnli_model = shared['dnli_model'] 136 | else: 137 | self.dnli_model = None 138 | 139 | self.id = 'SelfConsciousBlender' 140 | 141 | self.reset() 142 | 143 | def build_model(self, states=None): 144 | """ 145 | Build and return model.
146 | """ 147 | model = SelfConsciousTransformerModel(self.opt, self.dict) 148 | if self.opt['embedding_type'] != 'random': 149 | self._copy_embeddings( 150 | model.encoder.embeddings.weight, self.opt['embedding_type'] 151 | ) 152 | return model 153 | 154 | def history_class(self): 155 | return ContextConsciousHistory if 'context' in self.task else SelfConsciousHistory 156 | 157 | def _model_input(self, batch): 158 | """ 159 | Override from TorchGeneratorAgent 160 | passes (batch.text_vec,) to TorchGeneratorAgent._encoder_input() 161 | TGA._encoder_input() directly passes the result of TGA._model_input() 162 | change batch.text_vec to batch.distractor_text_vec for pragmatic decoding 163 | """ 164 | bsz = batch.text_vec.size(0) 165 | distractor_text_vec = batch.distractor_text_vec.view(bsz * self.world_cardinality, -1).contiguous() 166 | return (distractor_text_vec,) 167 | 168 | def selfconscious_greedy_generate(self, batch, maxlen): 169 | """ 170 | Greedy decoding with Public Self-Consciousness 171 | """ 172 | 173 | bsz = batch.text_vec.size(0) 174 | world_cardinality = self.world_cardinality 175 | embedding_size = self.opt.get('embedding_size') 176 | encoder_states = self.model.encoder(*self._encoder_input(batch)) 177 | 178 | preds, scores = self.model.selfconscious_decode(encoder_states, maxlen) 179 | 180 | return preds, scores 181 | 182 | def rank(self, batch): 183 | """ 184 | Rank candidates by PPL score 185 | """ 186 | bsz = batch.text_vec.size(0) 187 | world_cardinality = self.world_cardinality 188 | embedding_size = self.opt.get('embedding_size') 189 | ranked_candidates = [] 190 | cand_ordering = [] 191 | encoder_states = self.model.encoder(*self._encoder_input(batch)) 192 | batch_dim = encoder_states[0].size(0) # two possibilities: batchsize or batchsize * world_cardinality 193 | 194 | if bsz != batch_dim: 195 | enc_output = encoder_states[0].view(bsz, world_cardinality, -1, embedding_size).contiguous() 196 | enc_output_mask = encoder_states[1].view(bsz, 
world_cardinality, -1).contiguous() 197 | encoder_states = (enc_output, enc_output_mask) 198 | 199 | for i in range(bsz): 200 | num_cands = len(batch.candidate_vecs[i]) 201 | cands, _ = self._pad_tensor(batch.candidate_vecs[i]) 202 | # get the [i]-th state from encoder_states num_cands times, 203 | # because we need the same encoder_states for each candidate 204 | enc = self.model.reorder_encoder_states(encoder_states, [i] * num_cands) 205 | 206 | # enc: (num_cands, world_cardinality, seqlen, emb_size) 207 | # scores: (num_cands, max_len, vocab_size) 208 | scores, _ = self.model.selfconscious_decode_forced(enc, cands) 209 | 210 | cand_losses = F.cross_entropy( 211 | scores.view(num_cands * cands.size(1), -1), 212 | cands.view(-1), 213 | reduction='none', 214 | ).view(num_cands, cands.size(1)) 215 | # now cand_losses is cands x seqlen size, but we still need to 216 | # check padding and such 217 | mask = (cands != self.NULL_IDX) 218 | mask = mask.half() if self.fp16 else mask.float() 219 | cand_scores = (-cand_losses * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-9) 220 | 221 | if self.dnli_model is not None and self.eval_type == 'dnli': 222 | cand_scores = torch.unsqueeze(cand_scores, 0) 223 | cand_scores = self.dnli_model.rerank_candidates([batch.observations[i]], cand_scores) 224 | cand_scores = torch.squeeze(cand_scores) 225 | 226 | _, ordering = cand_scores.sort(descending=True) 227 | ranked_candidates.append([batch.candidates[i][o] for o in ordering]) 228 | cand_ordering.append(ordering) 229 | 230 | return ranked_candidates, cand_ordering 231 | 232 | def compute_loss(self, batch, return_output=False): 233 | """ 234 | Override from TorchGeneratorAgent 235 | Compute and return the loss for the given batch. 236 | 237 | Easily overridable for customized loss functions. 238 | 239 | If return_output is True, the full output from the call to self.model() 240 | is also returned, via a (loss, model_output) pair.
241 | """ 242 | if batch.label_vec is None: 243 | raise ValueError('Cannot compute loss without a label.') 244 | 245 | bsz = batch.text_vec.size(0) 246 | world_cardinality = self.world_cardinality 247 | embedding_size = self.opt.get('embedding_size') 248 | encoder_states = self.model.encoder(*self._encoder_input(batch)) 249 | 250 | enc_output = encoder_states[0].view(bsz, world_cardinality, -1, embedding_size).contiguous() 251 | enc_output_mask = encoder_states[1].view(bsz, world_cardinality, -1).contiguous() 252 | encoder_states = (enc_output, enc_output_mask) 253 | 254 | scores, preds = self.model.selfconscious_decode_forced(encoder_states, batch.label_vec) 255 | model_output = (scores, preds, encoder_states) 256 | 257 | score_view = scores.view(-1, scores.size(-1)) 258 | loss = self.criterion(score_view, batch.label_vec.view(-1)) 259 | loss = loss.view(scores.shape[:-1]).sum(dim=1) 260 | # save loss to metrics 261 | notnull = batch.label_vec.ne(self.NULL_IDX) 262 | target_tokens = notnull.long().sum(dim=-1) 263 | correct = ((batch.label_vec == preds) * notnull).sum(dim=-1) 264 | 265 | self.record_local_metric('loss', AverageMetric.many(loss, target_tokens)) 266 | self.record_local_metric('ppl', PPLMetric.many(loss, target_tokens)) 267 | self.record_local_metric( 268 | 'token_acc', AverageMetric.many(correct, target_tokens) 269 | ) 270 | 271 | # actually do backwards loss 272 | loss = loss.sum() 273 | loss /= target_tokens.sum() # average loss per token 274 | 275 | if return_output: 276 | return (loss, model_output) 277 | else: 278 | return loss 279 | 280 | def _eval_convai2_step(self, batch): 281 | """Evaluate a single batch of examples.""" 282 | 283 | assert self.alpha >= 0 284 | if batch.distractor_text_vec is None: 285 | return None 286 | 287 | self.model.eval() 288 | 289 | # 1. 
Generation 290 | assert self.beam_size == 1 291 | maxlen = self.label_truncate or 256 292 | if not self.skip_generation: 293 | preds, scores = self.selfconscious_greedy_generate(batch, maxlen) 294 | else: 295 | preds = None 296 | 297 | # 2. Compute PPL with teacher-forced generation 298 | # calculate loss on targets with teacher forcing 299 | loss, model_output = self.compute_loss(batch, return_output=True) 300 | token_losses = self._construct_token_losses( 301 | batch.label_vec, model_output 302 | ) 303 | 304 | # 3. Rank candidates by computing PPL for each candidate 305 | if self.rank_candidates: 306 | ranked_cands, ordering = self.rank(batch) 307 | else: 308 | ranked_cands = None 309 | 310 | # 4. Compute consistency score 311 | additional_metrics = [{'c_score': 0.0} for _ in range(len(batch.observations))] 312 | output_texts = [self._v2t(p) for p in preds] if preds is not None else None 313 | if not self.skip_generation: 314 | if self.opt['use_dnli']: 315 | c_scores = [] 316 | for text, obs in zip(output_texts, batch.observations): 317 | if 'context' in self.task: 318 | c_score = self.dnli_model.compute_consistency_scores(text, obs['my_context']) 319 | else: 320 | persona_strings = obs['my_persona'].split('\n') 321 | c_score = self.dnli_model.compute_consistency_scores(text, persona_strings) 322 | 323 | c_scores.append(c_score) 324 | 325 | for idx, c_score in enumerate(c_scores): 326 | additional_metrics[idx]['c_score'] = c_score 327 | 328 | return Output(output_texts, ranked_cands, token_losses=token_losses, metrics=additional_metrics) 329 | 330 | def _eval_dnli_step(self, batch): 331 | """Evaluate a single batch of examples.""" 332 | 333 | assert self.alpha >= 0 334 | 335 | self.model.eval() 336 | ranked_cands, ordering = self.rank(batch) 337 | 338 | bsz = len(ranked_cands) 339 | dnli_metrics = [] 340 | for batch_idx in range(bsz): 341 | dnli_score = {'contradict@1': 0, 'entail@1': 0, 'neutral@1': 0} 342 | top1_idx = ordering[batch_idx][0].item() 343 | if
top1_idx == 0: 344 | pass 345 | # dnli_metrics['dnli_hit@1'] += 1 346 | elif top1_idx > 0 and top1_idx < 11: 347 | dnli_score['contradict@1'] += 1 348 | elif top1_idx >= 11 and top1_idx < 21: 349 | dnli_score['entail@1'] += 1 350 | else: 351 | dnli_score['neutral@1'] += 1 352 | dnli_metrics.append(dnli_score) 353 | 354 | return Output(text_candidates=ranked_cands, metrics=dnli_metrics) 355 | 356 | def eval_step(self, batch): 357 | 358 | if self.opt['eval_type'] == 'convai2': 359 | return self._eval_convai2_step(batch) 360 | elif self.opt['eval_type'] == 'dnli': 361 | return self._eval_dnli_step(batch) 362 | else: 363 | raise NotImplementedError 364 | 365 | def self_observe(self, self_message: Message): 366 | """ 367 | Override from TorchAgent 368 | Add the model's reply or the label to the distractor histories in the History class 369 | """ 370 | episode_done = self.observation['episode_done'] 371 | use_reply = self.opt.get('use_reply', 'label') 372 | 373 | # actually ingest the label 374 | if use_reply == 'none': 375 | # we're not including our own responses anyway.
376 | reply = None 377 | elif use_reply == 'label': 378 | # first look for the true label 379 | label_key = ( 380 | 'labels' 381 | if 'labels' in self.observation 382 | else 'eval_labels' 383 | if 'eval_labels' in self.observation 384 | else None 385 | ) 386 | if label_key is not None: 387 | lbls = self.observation[label_key] 388 | reply = lbls[0] if len(lbls) == 1 else self.random.choice(lbls) 389 | else: 390 | # otherwise, we use the last output the model generated 391 | if self_message is not None: 392 | reply = self_message['text'] 393 | else: 394 | reply = None 395 | 396 | super().self_observe(self_message) 397 | 398 | if episode_done: 399 | return None 400 | 401 | if reply is not None: 402 | if 'context' in self.task: 403 | self.history.add_reply_to_distractors(reply, self.observation) 404 | else: 405 | self.history.add_reply_to_distractors(reply) 406 | 407 | return reply 408 | 409 | def _ordered_cand_scores_to_cand_text(self, ordered_cand_preds, cand_inds, candidates): 410 | cand_replies = [None] * len(candidates) 411 | 412 | for idx, order in enumerate(ordered_cand_preds): # batch_idx, sorted cand_idx 413 | batch_idx = cand_inds[idx] 414 | # get the original sentences from candidates by order 415 | cand_replies[batch_idx] = [candidates[batch_idx][i] for i in order] 416 | 417 | return cand_replies 418 | 419 | def _build_candidates_tensor(self, batch): 420 | if not batch.candidates: 421 | return None, None 422 | 423 | cand_inds = [i for i in range(len(batch.candidates)) if batch.candidates[i]] 424 | cands = [batch.candidate_vecs[i] for i in cand_inds] 425 | 426 | # get the length of the longest candidate in the batch 427 | max_cand_len = max( 428 | [max([cand.size(0) for cand in cands_i]) for cands_i in cands] 429 | ) 430 | 431 | for i, c in enumerate(cands): # make each instance in batch.cands to a padded tensor 432 | cands[i] = padded_tensor(c, use_cuda=self.use_cuda, 433 | max_len=max_cand_len, 434 | fp16friendly=self.fp16)[0].unsqueeze(0) 435 | 436 | # 
(batchsize, num_cands, max_len + a) +a due to fp16 437 | cands = torch.cat(cands, 0) 438 | 439 | return cands, cand_inds 440 | 441 | def vectorize(self, obs, history, **kwargs): 442 | """ 443 | Override from TorchAgent 444 | Vectorize the texts in observation 445 | """ 446 | super().vectorize(obs, history, **kwargs) # candidate vecs are vectorized here 447 | if not self.is_training: 448 | self._set_distractor_text_vec(obs, history, kwargs['text_truncate']) 449 | return obs 450 | 451 | def _set_text_vec(self, obs, history, truncate): 452 | """ 453 | Override from TorchAgent for DNLI evaluation 454 | This will be called in super().vectorize() 455 | """ 456 | # WARNING: self.is_training is always False in here 457 | is_training = False if 'eval_labels' in obs else True 458 | 459 | if is_training or self.opt['eval_type'] == 'convai2': 460 | return super()._set_text_vec(obs, history, truncate) 461 | elif self.opt['eval_type'] == 'dnli': 462 | if 'text' not in obs: 463 | return obs 464 | 465 | # Vectorize the text 466 | if 'text_vec' not in obs: 467 | obs['full_text'] = obs['text'] 468 | vec = self.dict.txt2vec(obs['full_text']) 469 | obs['text_vec'] = vec 470 | 471 | # check truncation 472 | if obs.get('text_vec') is not None: 473 | truncated_vec = self._check_truncate(obs['text_vec'], truncate, True) 474 | obs.force_set('text_vec', torch.LongTensor(truncated_vec)) 475 | return obs 476 | else: 477 | raise NotImplementedError 478 | 479 | def _set_distractor_text_vec(self, obs, history, truncate): 480 | """ 481 | Set 'distractor_text' and 'distractor_text_vec' field in the observation 482 | """ 483 | if 'distractor_text' not in obs: 484 | return obs 485 | 486 | if 'distractor_text_vec' not in obs: 487 | # distractor_text is in the SelfConsciousHistory class 488 | distractor_string = history.get_history_distractor_str() 489 | 490 | if distractor_string is None: 491 | return obs 492 | 493 | # Set 'full_distractor_text' 494 | obs['full_distractor_text'] = distractor_string 
495 | # distractor_text_vec is also in the SelfConsciousHistory class 496 | # they are already vectorized at SelfConsciousHistory.update_history() 497 | if distractor_string: 498 | obs['distractor_text_vec'] = history.get_history_distractor_vec() 499 | 500 | # Check truncation 501 | if obs.get('distractor_text_vec') is not None: 502 | truncated_vec = [ 503 | torch.LongTensor(self._check_truncate(text_vec, truncate, True)) 504 | for text_vec in obs['distractor_text_vec'] 505 | ] 506 | obs.force_set('distractor_text_vec', truncated_vec) 507 | return obs 508 | 509 | def batchify(self, *args, **kwargs): 510 | """ 511 | Override from TorchAgent 512 | Additionally batchify the distractor_text_vec and add it to batch 513 | """ 514 | kwargs['sort'] = True # need sort for pack_padded() 515 | batch = super().batchify(*args, **kwargs) 516 | sort = False # we must not sort after super().batchify() 517 | 518 | exs = batch.observations 519 | d_text_vec, d_lens = None, None 520 | if any('distractor_text_vec' in ex for ex in exs): 521 | # Pad distractor vectors 522 | _d_text_vec = [ex.get('distractor_text_vec', self.EMPTY) for ex in exs] 523 | _d_text_vec_flattened = list(chain(*_d_text_vec)) 524 | d_text_vec, d_lens = self._pad_tensor(_d_text_vec_flattened) 525 | 526 | # Reshape to (batch_size, world_cardinality, max_length) 527 | bsz = len(exs) 528 | d_text_vec = d_text_vec.view(bsz, self.world_cardinality, -1) 529 | d_lens = list_to_matrix(d_lens, self.world_cardinality) 530 | 531 | batch = Batch( 532 | distractor_text_vec=d_text_vec, 533 | distractor_text_lengths=d_lens, 534 | **dict(batch) 535 | ) 536 | 537 | return batch 538 | 539 | def share(self): 540 | shared = super().share() 541 | if self.opt['use_dnli']: 542 | shared['dnli_model'] = self.dnli_model 543 | return shared 544 | -------------------------------------------------------------------------------- /assets/figure1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/skywalker023/pragmatic-consistency/0a9165914a2009f3d3de5b4c4f37bcd745dc2886/assets/figure1.png -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: pragmatic-consistency 2 | channels: 3 | - defaults 4 | dependencies: 5 | - cudatoolkit=10.1 6 | - pytorch::pytorch=1.6.0 7 | - python=3.6.8 8 | - pip=20.1.1 9 | - pip: 10 | # ParlAI (use commit of 1st of October, 2020) 11 | - git+https://github.com/facebookresearch/ParlAI.git@9cd6b6c0e70c72a24e959e4a328cb4093eb7f3de 12 | - torchtext==0.7.0 13 | - spacy==2.3.2 14 | - pytorch-transformers==1.2.0 15 | # For logging 16 | - tqdm 17 | - better-exceptions 18 | # For linting 19 | - pylint 20 | - pycodestyle 21 | - mypy 22 | # For markdown preview 23 | - grip 24 | # etc 25 | - ruamel.yaml 26 | - more_itertools 27 | - isort 28 | - pudb 29 | - jupyter 30 | - orderedset 31 | -------------------------------------------------------------------------------- /eval_dnli.py: -------------------------------------------------------------------------------- 1 | from parlai.scripts.eval_model import eval_model 2 | from parlai.scripts.eval_model import setup_args as parlai_setupargs 3 | 4 | 5 | def setup_args(): 6 | parser = parlai_setupargs() 7 | parser.set_defaults( 8 | model_file='zoo:blender/blender_90M/model', 9 | eval_type='dnli', 10 | metrics='contradict@1,entail@1,neutral@1', 11 | alpha=8, 12 | beta=1, 13 | use_dnli=False 14 | ) 15 | return parser 16 | 17 | 18 | if __name__ == '__main__': 19 | parser = setup_args() 20 | opt = parser.parse_args() 21 | eval_model(opt) 22 | -------------------------------------------------------------------------------- /eval_personachat.py: -------------------------------------------------------------------------------- 1 | from parlai.scripts.eval_model import eval_model 2 | from parlai.scripts.eval_model import setup_args as 
parlai_setupargs 3 | 4 | 5 | def setup_args(): 6 | parser = parlai_setupargs() 7 | parser.set_defaults( 8 | model_file='zoo:blender/blender_90M/model', 9 | eval_type='convai2', 10 | metrics='token_acc,ppl,loss,c_scores,f1', 11 | alpha=2, 12 | beta=0.5 13 | ) 14 | return parser 15 | 16 | 17 | if __name__ == '__main__': 18 | parser = setup_args() 19 | opt = parser.parse_args() 20 | eval_model(opt) 21 | -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skywalker023/pragmatic-consistency/0a9165914a2009f3d3de5b4c4f37bcd745dc2886/modules/__init__.py -------------------------------------------------------------------------------- /modules/dnli_bert.py: -------------------------------------------------------------------------------- 1 | import os 2 | from copy import deepcopy 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | import numpy as np 7 | from pytorch_transformers import ( 8 | BertForSequenceClassification, 9 | BertTokenizer 10 | ) 11 | from parlai.core.build_data import download_from_google_drive 12 | 13 | 14 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 15 | """Truncates a sequence pair in place to the maximum length.""" 16 | # This is a simple heuristic which will always truncate the longer sequence 17 | # one token at a time. This makes more sense than truncating an equal percent 18 | # of tokens from each, since if one sequence is very short then each token 19 | # that's truncated likely contains more information than a longer sequence. 
20 | while True: 21 | total_length = len(tokens_a) + len(tokens_b) 22 | if total_length <= max_length: 23 | break 24 | if len(tokens_a) > len(tokens_b): 25 | tokens_a.pop() 26 | else: 27 | tokens_b.pop() 28 | 29 | 30 | class DnliBert(object): 31 | def __init__(self, 32 | opt, 33 | dnli_lambda=1.0, 34 | dnli_k=10, 35 | max_seq_length=128, 36 | use_cuda=True): 37 | self.opt = opt 38 | self.dnli_lambda = dnli_lambda 39 | self.dnli_k = dnli_k 40 | self.max_seq_length = max_seq_length 41 | self.use_cuda = use_cuda 42 | self.mapper = {0: "contradiction", 43 | 1: "entailment", 44 | 2: "neutral"} 45 | 46 | dnli_model, dnli_tokenizer = self._load_dnli_model() 47 | self.dnli_model = dnli_model 48 | self.dnli_tokenizer = dnli_tokenizer 49 | 50 | def _load_dnli_model(self): 51 | # Download pretrained weight 52 | dnli_model_fname = os.path.join(self.opt['datapath'], 'dnli_model.bin') 53 | if not os.path.exists(dnli_model_fname): 54 | print(f"[ Download pretrained dnli model params to {dnli_model_fname}]") 55 | download_from_google_drive( 56 | '1Qawz1pMcV0aGLVYzOgpHPgG5vLSKPOJ1', 57 | dnli_model_fname 58 | ) 59 | 60 | # Load pretrained weight 61 | print(f"[ Load pretrained dnli model from {dnli_model_fname}]") 62 | model_state_dict = torch.load(dnli_model_fname) 63 | dnli_model = BertForSequenceClassification.from_pretrained('bert-base-uncased', state_dict=model_state_dict, num_labels=3) 64 | if self.use_cuda: 65 | dnli_model.cuda() 66 | dnli_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True) 67 | 68 | return dnli_model, dnli_tokenizer 69 | 70 | def rerank_candidates(self, observations, cand_scores): 71 | sorted_cand_values, sorted_cand_indices = cand_scores.sort(1, descending=True) 72 | 73 | for batch_idx, obs in enumerate(observations): 74 | full_text = obs['full_text'] 75 | personas = [] 76 | for text in full_text.split('\n'): 77 | if 'your persona:' in text: 78 | personas.append(text.replace('your persona:', '')) 79 | else: 80 | break 81 | 
candidates = obs['label_candidates'] 82 | 83 | tok_candidates = [self.dnli_tokenizer.tokenize(sent) for sent in candidates] 84 | tok_personas = [self.dnli_tokenizer.tokenize(sent) for sent in personas] 85 | 86 | dnli_scores = self._compute_dnli_scores(tok_candidates, tok_personas) 87 | s_1 = sorted_cand_values[batch_idx, 0] 88 | s_k = sorted_cand_values[batch_idx, self.dnli_k - 1] 89 | 90 | _lambda = self.dnli_lambda 91 | cand_scores[batch_idx] = cand_scores[batch_idx] - _lambda * (s_1 - s_k) * dnli_scores 92 | 93 | return cand_scores 94 | 95 | def compute_consistency_scores(self, pred, personas): 96 | """ 97 | pred must be a string, and personas must be a list of strings 98 | """ 99 | max_seq_length = self.max_seq_length 100 | 101 | pred_tokenized = self.dnli_tokenizer.tokenize(pred) 102 | personas_tokenized = [self.dnli_tokenizer.tokenize(sent.replace('your persona:', '')) for sent in personas] 103 | 104 | all_input_ids = [] 105 | all_input_mask = [] 106 | all_segment_ids = [] 107 | for idx, persona_tokenized in enumerate(personas_tokenized): 108 | _pred_tokenized = deepcopy(pred_tokenized) 109 | _persona_tokenized = deepcopy(persona_tokenized) 110 | _truncate_seq_pair(_pred_tokenized, _persona_tokenized, max_seq_length - 3) 111 | 112 | tokens = ["[CLS]"] + _pred_tokenized + ["[SEP]"] 113 | segment_ids = [0] * len(tokens) 114 | tokens += _persona_tokenized + ["[SEP]"] 115 | segment_ids += [1] * (len(_persona_tokenized) + 1) 116 | 117 | input_ids = self.dnli_tokenizer.convert_tokens_to_ids(tokens) 118 | input_mask = [1] * len(input_ids) 119 | padding = [0] * (max_seq_length - len(input_ids)) 120 | input_ids += padding 121 | input_mask += padding 122 | segment_ids += padding 123 | 124 | all_input_ids.append(input_ids) 125 | all_input_mask.append(input_mask) 126 | all_segment_ids.append(segment_ids) 127 | 128 | # Convert inputs to tensors 129 | all_input_ids = torch.tensor(all_input_ids, dtype=torch.long) 130 | all_input_mask = torch.tensor(all_input_mask, dtype=torch.long) 131 |
all_segment_ids = torch.tensor(all_segment_ids, dtype=torch.long) 132 | if self.use_cuda: 133 | all_input_ids = all_input_ids.cuda() 134 | all_input_mask = all_input_mask.cuda() 135 | all_segment_ids = all_segment_ids.cuda() 136 | 137 | # Inference 138 | self.dnli_model.eval() 139 | with torch.no_grad(): 140 | logits = self.dnli_model(all_input_ids, all_segment_ids, all_input_mask) 141 | probs = F.softmax(logits[0], dim=1) 142 | 143 | probs = probs.detach().cpu().numpy() 144 | idx_max = np.argmax(probs, axis=1) 145 | val_max = np.max(probs, axis=1) 146 | 147 | consistency_score = 0.0 148 | for pred_idx in idx_max: 149 | if pred_idx == 0: # contradict 150 | consistency_score -= 1.0 151 | elif pred_idx == 1: # entailment 152 | consistency_score += 1.0 153 | elif pred_idx == 2: # neutral 154 | consistency_score += 0.0 155 | 156 | return consistency_score 157 | 158 | def _compute_dnli_scores(self, tok_candidates, tok_personas): 159 | max_seq_length = self.max_seq_length 160 | 161 | dnli_scores = [] 162 | for cand_idx, tok_candidate in enumerate(tok_candidates): 163 | all_input_ids = [] 164 | all_input_mask = [] 165 | all_segment_ids = [] 166 | for tok_persona in tok_personas: 167 | # Prepare inputs 168 | # [CLS] candidates [SEP] persona [SEP] 169 | _tok_candidate = deepcopy(tok_candidate) 170 | _tok_persona = deepcopy(tok_persona) 171 | # Account for [CLS], [SEP], [SEP] with "- 3" 172 | _truncate_seq_pair(_tok_candidate, _tok_persona, max_seq_length - 3) 173 | 174 | # Make inputs 175 | tokens = ["[CLS]"] + _tok_candidate + ["[SEP]"] 176 | segment_ids = [0] * len(tokens) 177 | tokens += _tok_persona + ["[SEP]"] 178 | segment_ids += [1] * (len(_tok_persona) + 1) 179 | 180 | input_ids = self.dnli_tokenizer.convert_tokens_to_ids(tokens) 181 | input_mask = [1] * len(input_ids) 182 | padding = [0] * (max_seq_length - len(input_ids)) 183 | input_ids += padding 184 | input_mask += padding 185 | segment_ids += padding 186 | 187 | all_input_ids.append(input_ids) 188 | 
all_input_mask.append(input_mask) 189 | all_segment_ids.append(segment_ids) 190 | 191 | # Convert inputs to tensors 192 | all_input_ids = torch.tensor(all_input_ids, dtype=torch.long) 193 | all_input_mask = torch.tensor(all_input_mask, dtype=torch.long) 194 | all_segment_ids = torch.tensor(all_segment_ids, dtype=torch.long) 195 | if self.use_cuda: 196 | all_input_ids = all_input_ids.cuda() 197 | all_input_mask = all_input_mask.cuda() 198 | all_segment_ids = all_segment_ids.cuda() 199 | 200 | # Inference 201 | self.dnli_model.eval() 202 | with torch.no_grad(): 203 | logits = self.dnli_model(all_input_ids, all_segment_ids, all_input_mask) 204 | probs = F.softmax(logits[0], dim=1) 205 | 206 | probs = probs.detach().cpu().numpy() 207 | idx_max = np.argmax(probs, axis=1) 208 | val_max = np.max(probs, axis=1) 209 | dnli_score = np.max((idx_max == 0) * val_max) 210 | dnli_scores.append(dnli_score) 211 | dnli_scores = torch.tensor(dnli_scores, dtype=torch.float) 212 | if self.use_cuda: 213 | dnli_scores = dnli_scores.cuda() 214 | return dnli_scores 215 | -------------------------------------------------------------------------------- /tasks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skywalker023/pragmatic-consistency/0a9165914a2009f3d3de5b4c4f37bcd745dc2886/tasks/__init__.py -------------------------------------------------------------------------------- /tasks/build.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import os 8 | import parlai.core.params as params 9 | import parlai.core.build_data as build_data 10 | 11 | 12 | FOLDER_NAME = 'self_conscious_dialogue' 13 | 14 | 15 | def build(opt): 16 | dpath = os.path.join(opt['datapath'], FOLDER_NAME) 17 | # version 1.0: initial release 18 | version = '1.0' 19 | 20 | # check whether data had been previously built 21 | if not build_data.built(dpath, version_string=version): 22 | print('[building data: ' + dpath + ']') 23 | 24 | # make a clean directory if needed 25 | if build_data.built(dpath): 26 | # if an older version exists, remove those outdated files. 27 | build_data.remove_dir(dpath) 28 | build_data.make_dir(dpath) 29 | 30 | ######################### 31 | # ConvAI2 (PersonaChat) 32 | ######################### 33 | fname = 'data_v1.tar.gz' 34 | url = 'https://parl.ai/downloads/controllable_dialogue/' + fname 35 | build_data.download(url, dpath, fname) 36 | build_data.untar(dpath, fname) 37 | 38 | fname = 'convai2_fix_723.tgz' 39 | url = 'http://parl.ai/downloads/convai2/' + fname 40 | build_data.download(url, dpath, fname) 41 | build_data.untar(dpath, fname) 42 | 43 | ######################### 44 | # Dialogue NLI 45 | ######################### 46 | fname = 'dialogue_nli.zip' 47 | gd_id = '1WtbXCv3vPB5ql6w0FVDmAEMmWadbrCuG' 48 | build_data.download_from_google_drive(gd_id, os.path.join(dpath, fname)) 49 | build_data.untar(dpath, fname) 50 | 51 | fname = 'dialogue_nli_evaluation.zip' 52 | gd_id = '1sllq30KMJzEVQ4C0-a9ShSLSPIZc3iMi' 53 | build_data.download_from_google_drive(gd_id, os.path.join(dpath, fname)) 54 | build_data.untar(dpath, fname) 55 | 56 | ######################### 57 | # Distractor personas 58 | ######################### 59 | fname = 'train_sorted_50_personas.json' 60 | gd_id = '1SGFdJqyNYeepKFqwMLv4Ym717QQTtpi8' 61 | build_data.download_from_google_drive(gd_id, os.path.join(dpath, fname)) 62 | fname = 'valid_sorted_50_personas.json' 63 | gd_id = '1A7oVKmjJ1EZTh6-3Gio4XQo81QgnTGGi' 64 | 
build_data.download_from_google_drive(gd_id, os.path.join(dpath, fname)) 65 | fname = 'dnli_sorted_50_personas.json' 66 | gd_id = '1wlIkVcBZoGQd3rbI7XWNhuq4rvw9FyoP' 67 | build_data.download_from_google_drive(gd_id, os.path.join(dpath, fname)) 68 | 69 | print("Data has been placed in " + dpath) 70 | 71 | build_data.mark_done(dpath, version) 72 | 73 | 74 | def make_path(opt, fname): 75 | return os.path.join(opt['datapath'], FOLDER_NAME, fname) 76 | 77 | 78 | if __name__ == '__main__': 79 | opt = params.ParlaiParser().parse_args(print_args=False) 80 | build(opt) 81 | -------------------------------------------------------------------------------- /tasks/teachers.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import math 4 | import json 5 | import random 6 | from operator import itemgetter 7 | from orderedset import OrderedSet 8 | from collections import defaultdict 9 | import pickle 10 | 11 | from parlai.utils.misc import warn_once, str_to_msg 12 | from parlai.core.message import Message 13 | from parlai.core.torch_agent import TorchAgent 14 | from parlai.core.teachers import ( 15 | ParlAIDialogTeacher, 16 | FixedDialogTeacher, 17 | FbDeprecatedDialogTeacher, 18 | DialogData 19 | ) 20 | 21 | from .build import build, make_path 22 | 23 | __PATH__ = os.path.abspath(os.path.dirname(__file__)) 24 | 25 | 26 | def _path(opt): 27 | build(opt) 28 | datatype = opt['datatype'].split(':')[0] 29 | if datatype == 'test': 30 | warn_once("WARNING: Test set not included. 
Setting datatype to valid.") 31 | datatype = 'valid' 32 | return make_path(opt, datatype + '.txt'), datatype 33 | 34 | 35 | def _split_persona_and_context(text, eval_type='convai2'): 36 | if 'your persona:' not in text: 37 | return None, text 38 | else: 39 | if eval_type == 'convai2': 40 | texts = text.split('\n') 41 | return '\n'.join(texts[:-1]), texts[-1] 42 | elif eval_type =='dnli': 43 | texts = text.split('\n') 44 | last_idx = 0 45 | for idx, text in enumerate(texts): 46 | if 'your persona:' in text: 47 | last_idx = idx 48 | persona_texts = texts[:last_idx+1] 49 | context_texts = texts[last_idx+1:] 50 | return '\n'.join(persona_texts), '\n'.join(context_texts) 51 | 52 | 53 | def _split_personas_and_context(text): 54 | if 'your persona:' not in text: 55 | return text, text, text 56 | else: 57 | your_personas = [] 58 | partner_personas = [] 59 | context = [] 60 | texts = text.split('\n') 61 | for text in texts: 62 | if text.startswith('your persona:'): 63 | your_personas.append(text) 64 | elif text.startswith("partner's persona:"): 65 | partner_personas.append(text) 66 | else: 67 | context.append(text) 68 | 69 | return '\n'.join(your_personas), '\n'.join(partner_personas), context 70 | 71 | 72 | class SelfConsciousDialogueTeacher(FixedDialogTeacher): 73 | """ 74 | Teacher (i.e. input data supplier) for the Self-conscious Agent. 75 | SelfConsciousDialogueTeacher (SCDT) supplies data input 76 | along with the distractors to the Self-conscious Agent. 
77 | """ 78 | def __init__(self, opt, shared=None): 79 | super().__init__(opt, shared) 80 | self.opt = opt 81 | 82 | datapath, datatype = _path(opt) 83 | 84 | if not shared: 85 | self.episodes = [] 86 | self.num_exs = 0 87 | self._setup_data(datapath, datatype) 88 | else: 89 | self.episodes = shared['episodes'] 90 | self.num_exs = sum(len(e) for e in self.episodes) 91 | self.id = 'self_conscious_dialogue' 92 | self.reset() 93 | 94 | @staticmethod 95 | def add_cmdline_args(argparser): 96 | agent = argparser.add_argument_group('Self Conscious Dialogue Teacher arguments') 97 | agent.add_argument( 98 | '--eval-type', 99 | type=str, 100 | choices=['convai2', 'dnli'], 101 | default='dnli', 102 | help='Which validation data to use', 103 | ) 104 | 105 | def _setup_data(self, path, datatype): 106 | 107 | random.seed(46) 108 | 109 | # Data loading with script of ParlAIDialogTeacher 110 | print(f"[Loading ParlAI text data: {path}]") 111 | 112 | # Read data from ConvAI2 113 | convai2_datapath = make_path(self.opt, f'{datatype}_both_original.txt') 114 | convai2_episodes = self._load_convai2_data(convai2_datapath) 115 | 116 | # Get persona pool 117 | all_personas, persona_to_idx = self._get_persona_pool(self.opt) 118 | sorted_personas = self._get_sorted_persona_pool(datatype) 119 | 120 | 121 | if self.opt['eval_type'] == 'convai2': 122 | self.episodes = [] 123 | self.num_exs = 0 124 | eps = [] 125 | with open(path) as read: 126 | for line in read: 127 | msg = str_to_msg(line.rstrip('\n')) 128 | if msg: 129 | self.num_exs += 1 130 | eps.append(msg) 131 | if msg.get('episode_done', False): 132 | self.episodes.append(eps) 133 | eps = [] 134 | if len(eps) > 0: 135 | # add last episode 136 | eps[-1].force_set('episode_done', True) 137 | self.episodes.append(eps) 138 | # Add label candidates and partner's persona 139 | for episode_idx, episode in enumerate(self.episodes): 140 | for turn_idx, turn in enumerate(episode): 141 | convai2_turn = convai2_episodes[episode_idx][turn_idx] 142 | 
convai2_text = convai2_turn[0] 143 | label_candidates = convai2_turn[3] 144 | 145 | turn['label_candidates'] = label_candidates 146 | if turn_idx == 0: 147 | my_persona, partner_persona, _ = _split_personas_and_context(convai2_text) 148 | turn['partner_persona'] = partner_persona 149 | turn['my_persona'] = my_persona 150 | else: 151 | turn['partner_persona'] = episode[0]['partner_persona'] 152 | turn['my_persona'] = episode[0]['my_persona'] 153 | elif self.opt['eval_type'] == 'dnli': 154 | self.episodes = [] 155 | self.num_exs = 0 156 | for eval_set in ['attributes', 'havenot', 'likedislike']: 157 | datapath = make_path(self.opt, f'{datatype}_{eval_set}.jsonl') 158 | with open(datapath, 'r') as fp: 159 | for line in fp: 160 | msg = json.loads(line) 161 | msg['eval_set'] = eval_set 162 | msg['episode_done'] = True 163 | 164 | # Make 'text' 165 | persona_lines = [f'your persona: {x[:-2]}.' for x in msg['persona']] 166 | utts = msg['prefix'] 167 | 168 | p1_token, p2_token = TorchAgent.P1_TOKEN, TorchAgent.P2_TOKEN 169 | lines = persona_lines 170 | # Identify the dialogue lines. It's assumed that p1 goes first. 
171 | for i, utt in enumerate(utts): 172 | if i % 2 == 0: 173 | lines.append(f'{p1_token} {utt}') 174 | else: 175 | lines.append(f'{p2_token} {utt}') 176 | text = '\n'.join(lines) 177 | 178 | msg['text'] = text 179 | 180 | # Make 'label_candidates' 181 | cands = msg['candidates'] 182 | msg['label_candidates'] = cands['label'] + cands['neg'][:10] \ 183 | + cands['similar'][:10] + cands['rand'][:10] 184 | 185 | # Remove unused attributes 186 | del msg['persona'] 187 | del msg['prefix'] 188 | del msg['triple'] 189 | del msg['relevant_persona_sentence'] 190 | del msg['candidates'] 191 | 192 | self.episodes.append([msg]) 193 | self.num_exs += 1 194 | 195 | # Add distractor personas 196 | if self.opt['world_cardinality'] > 0: 197 | num_all_personas = len(all_personas) 198 | persona_indices = list(range(num_all_personas)) 199 | world_cardinality = self.opt['world_cardinality'] 200 | for episode in self.episodes: 201 | gt_persona, first_context = _split_persona_and_context(episode[0]['text'], self.opt['eval_type']) 202 | gt_persona_idx = persona_to_idx.get(gt_persona, -1) 203 | 204 | # Choose random distractor personas 205 | distractor_indices = random.sample(persona_indices, world_cardinality - 1) 206 | while gt_persona_idx in distractor_indices: 207 | # Resample if gt_persona is sampled 208 | distractor_indices = random.sample(persona_indices, world_cardinality - 1) 209 | distractor_personas = itemgetter(*distractor_indices)(all_personas) 210 | distractor_personas = list(distractor_personas) 211 | 212 | # Build the 'distractor_text' field 213 | for turn_idx, turn in enumerate(episode): 214 | if turn_idx == 0: 215 | turn['distractor_text'] = [ 216 | '\n'.join([persona, first_context]) 217 | for persona in [gt_persona] + distractor_personas 218 | ] 219 | else: 220 | turn['distractor_text'] = [turn['text']] * world_cardinality 221 | 222 | def _get_persona_pool(self, opt, remove_duplicate=True): 223 | print("[loading persona pool from convai2 training data]") 224 | # Get episodes
from training dataset 225 | datapath = make_path(opt, 'train.txt') 226 | episodes = [] 227 | eps = [] 228 | with open(datapath) as read: 229 | for line in read: 230 | msg = str_to_msg(line.rstrip('\n')) 231 | if msg: 232 | # self.num_exs += 1 233 | eps.append(msg) 234 | if msg.get('episode_done', False): 235 | episodes.append(eps) 236 | eps = [] 237 | if len(eps) > 0: 238 | # add last episode 239 | eps[-1].force_set('episode_done', True) 240 | episodes.append(eps) 241 | 242 | # Extract personas from episodes 243 | persona_set = OrderedSet() 244 | for episode in episodes: 245 | first_turn = episode[0] 246 | text = first_turn['text'] 247 | persona, _ = _split_persona_and_context(text) 248 | persona_set.add(persona) 249 | 250 | # Remove duplicates, keeping the longest variant of each persona 251 | if remove_duplicate: 252 | train_persona_fname = os.path.join(__PATH__, 'train_persona_map.pkl') 253 | with open(train_persona_fname, 'rb') as fp: 254 | _train_personas = pickle.load(fp) 255 | train_personas = [] 256 | for personas in _train_personas.values(): 257 | longest_idx = 0 258 | longest_length = -1 259 | for idx, persona in enumerate(personas): 260 | if len(persona) > longest_length: 261 | longest_idx = idx 262 | longest_length = len(persona) 263 | selected_persona = map(lambda x: f"your persona: {x}.", personas[longest_idx]) 264 | selected_persona = '\n'.join(selected_persona) 265 | train_personas.append(selected_persona) 266 | persona_set = OrderedSet() 267 | for train_persona in train_personas: 268 | persona_set.add(train_persona) 269 | 270 | all_personas = [] 271 | persona_to_idx = {} 272 | for i, persona in enumerate(persona_set): 273 | all_personas.append(persona) 274 | persona_to_idx[persona] = i 275 | 276 | print(f"Total {len(all_personas)} personas in dataset") 277 | 278 | return all_personas, persona_to_idx 279 | 280 | def _get_sorted_persona_pool(self, datatype): 281 | print("[loading sorted persona pool from convai2 training data]") 282 | eval_type = self.opt['eval_type'] 283 | if eval_type ==
'convai2': 284 | datapath = make_path(self.opt, 'valid_sorted_50_personas.json') 285 | elif eval_type == 'dnli': 286 | datapath = make_path(self.opt, 'dnli_sorted_50_personas.json') 287 | else: 288 | raise ValueError("eval_type must be either 'convai2' or 'dnli'") 289 | 290 | with open(datapath, 'r') as fp: 291 | sorted_personas = json.load(fp) 292 | sorted_personas['idx2persona'] = sorted_personas['train_personas'] 293 | sorted_personas['persona2idx'] = {} 294 | for idx, persona in enumerate(sorted_personas['train_personas']): 295 | sorted_personas['persona2idx'][persona] = idx 296 | 297 | return sorted_personas 298 | 299 | def _load_convai2_data(self, datapath): 300 | """ 301 | Read data in the fbdialog format. 302 | Returns ``(x, y, r, c)`` tuples. 303 | ``x`` represents a query, ``y`` represents the labels, ``r`` represents 304 | any reward, and ``c`` represents any label_candidates. 305 | For example, a bAbI-style episode is translated into the following tuples: 306 | :: 307 | x: 'Sam went to the kitchen\nPat gave Sam the milk\nWhere is the milk?' 308 | y: ['kitchen'] 309 | r: '1' 310 | c: ['hallway', 'kitchen', 'bathroom'] 311 | new_episode = True (this is the first example in the episode) 312 | :: 313 | x: 'Sam went to the hallway\\nPat went to the bathroom\\nWhere is the 314 | milk?'
315 | y: ['hallway'] 316 | r: '1' 317 | c: ['hallway', 'kitchen', 'bathroom'] 318 | new_episode = False (this is the second example in the episode) 319 | """ 320 | self.cloze = False # Set this to use FbDialogTeacher 321 | convai2_dataloader = FbDeprecatedDialogTeacher.setup_data(self, datapath) 322 | convai2_episodes = [] 323 | for episode in DialogData._read_episode(self, convai2_dataloader): 324 | convai2_episodes.append(episode) 325 | del self.cloze 326 | return convai2_episodes 327 | 328 | def share(self): 329 | shared = super().share() 330 | shared['episodes'] = self.episodes 331 | return shared 332 | 333 | def num_examples(self): 334 | return self.num_exs 335 | 336 | def num_episodes(self): 337 | return len(self.episodes) 338 | 339 | def get(self, episode_idx, entry_idx=0): 340 | return self.episodes[episode_idx][entry_idx] 341 | 342 | 343 | class ContextConsciousDialogueTeacher(SelfConsciousDialogueTeacher): 344 | def _setup_data(self, path, datatype): 345 | # random.seed(self.opt['random_seed']) # Fix the seed to pick the same distractor personas 346 | random.seed(46) # Fix the seed to pick the same distractor personas 347 | # Data loading with script of ParlAIDialogTeacher 348 | print(f"[Loading ParlAI text data: {path}]") 349 | 350 | # Read data from ConvAI2 351 | convai2_datapath = make_path(self.opt, f'{datatype}_both_original.txt') 352 | convai2_episodes = self._load_convai2_data(convai2_datapath) 353 | 354 | if self.opt['eval_type'] == 'convai2': 355 | self.episodes = [] 356 | self.num_exs = 0 357 | eps = [] 358 | with open(path) as read: 359 | for line in read: 360 | msg = str_to_msg(line.rstrip('\n')) 361 | if msg: 362 | self.num_exs += 1 363 | eps.append(msg) 364 | if msg.get('episode_done', False): 365 | self.episodes.append(eps) 366 | eps = [] 367 | if len(eps) > 0: 368 | # add last episode 369 | eps[-1].force_set('episode_done', True) 370 | self.episodes.append(eps) 371 | # Add label candidates and partner's persona 372 | for episode_idx, episode in
enumerate(self.episodes): 373 | for turn_idx, turn in enumerate(episode): 374 | convai2_turn = convai2_episodes[episode_idx][turn_idx] 375 | convai2_text = convai2_turn[0] 376 | label_candidates = convai2_turn[3] 377 | 378 | turn['label_candidates'] = label_candidates 379 | if turn_idx == 0: 380 | my_persona, partner_persona, _ = _split_personas_and_context(convai2_text) 381 | turn['partner_persona'] = partner_persona 382 | turn['my_persona'] = my_persona 383 | else: 384 | turn['partner_persona'] = episode[0]['partner_persona'] 385 | turn['my_persona'] = episode[0]['my_persona'] 386 | elif self.opt['eval_type'] == 'dnli': 387 | self.episodes = [] 388 | self.num_exs = 0 389 | for eval_set in ['attributes', 'havenot', 'likedislike']: 390 | datapath = make_path(self.opt, f'{datatype}_{eval_set}.jsonl') 391 | with open(datapath, 'r') as fp: 392 | for line in fp: 393 | msg = json.loads(line) 394 | msg['eval_set'] = eval_set 395 | msg['episode_done'] = True 396 | 397 | # Make 'text' 398 | persona_lines = [f'your persona: {x[:-2]}.' for x in msg['persona']] 399 | utts = msg['prefix'] 400 | 401 | p1_token, p2_token = TorchAgent.P1_TOKEN, TorchAgent.P2_TOKEN 402 | lines = persona_lines 403 | # Identify the dialogue lines. It's assumed that p1 goes first. 
404 | for i, utt in enumerate(utts): 405 | if i % 2 == 0: 406 | lines.append(f'{p1_token} {utt}') 407 | else: 408 | lines.append(f'{p2_token} {utt}') 409 | text = '\n'.join(lines) 410 | 411 | msg['text'] = text 412 | 413 | # Make 'label_candidates' 414 | cands = msg['candidates'] 415 | msg['label_candidates'] = cands['label'] + cands['neg'][:10] \ 416 | + cands['similar'][:10] + cands['rand'][:10] 417 | 418 | # Remove unused attributes 419 | del msg['persona'] 420 | del msg['prefix'] 421 | del msg['triple'] 422 | del msg['relevant_persona_sentence'] 423 | del msg['candidates'] 424 | 425 | self.episodes.append([msg]) 426 | self.num_exs += 1 427 | 428 | # Get dialogue history pool 429 | context_pool = self._get_context_pool(self.opt) 430 | 431 | # Add distractor history 432 | if self.opt['world_cardinality'] > 0: 433 | for episode in self.episodes: 434 | gt_persona, first_context = _split_persona_and_context(episode[0]['text'], self.opt['eval_type']) 435 | 436 | # Select distractor history 437 | if self.opt['eval_type'] == 'convai2': 438 | num_turn = len(episode) 439 | else: 440 | dialogue = first_context.split('\n') 441 | num_turn = math.ceil(len(dialogue)/2) 442 | if num_turn < min(context_pool.keys()): 443 | # original_num_turn = num_turn 444 | num_turn = min(context_pool.keys()) 445 | 446 | context_indices = list(range(len(context_pool[num_turn]))) 447 | 448 | distractor_c_indices = random.sample(context_indices, self.opt['world_cardinality'] - 1) 449 | distractor_contexts = itemgetter(*distractor_c_indices)(context_pool[num_turn]) 450 | 451 | # Build the 'distractor_text' field 452 | if self.opt['eval_type'] == 'convai2': 453 | for turn_idx, turn in enumerate(episode): 454 | turn['distractor_text'] = turn['labels'] + [c[turn_idx] for c in distractor_contexts] 455 | if turn_idx == 0: 456 | turn['my_context'] = turn['labels'] 457 | else: 458 | turn['my_context'] = episode[turn_idx - 1]['my_context'] + turn['labels'] 459 | else: 460 | # DNLI 461 | distractor_text =
[episode[0]['text']] 462 | for c in distractor_contexts: 463 | copied_dialogue = copy.deepcopy(dialogue) 464 | for turn_idx, utterance in enumerate(copied_dialogue): 465 | if turn_idx % 2 == 1: 466 | copied_dialogue[turn_idx] = p2_token + c[turn_idx // 2] 467 | distractor_context = '\n'.join([gt_persona] + copied_dialogue) 468 | distractor_text.append(distractor_context) 469 | episode[0]['distractor_text'] = distractor_text 470 | 471 | def _get_context_pool(self, opt): 472 | print("[loading history pool from convai2 training data]") 473 | datapath = make_path(opt, 'train.txt') 474 | episodes = [] 475 | eps = [] 476 | with open(datapath) as read: 477 | for line in read: 478 | msg = str_to_msg(line.rstrip('\n')) 479 | if msg: 480 | eps.append(msg) 481 | if msg.get('episode_done', False): 482 | episodes.append(eps) 483 | eps = [] 484 | if len(eps) > 0: 485 | # add last episode 486 | eps[-1].force_set('episode_done', True) 487 | episodes.append(eps) 488 | 489 | context_pool = defaultdict(list) 490 | for ep in episodes: 491 | context_pool[len(ep)].append([turn['labels'][0] for turn in ep]) 492 | 493 | return dict(context_pool) 494 | 495 | 496 | class DefaultTeacher(SelfConsciousDialogueTeacher): 497 | pass 498 | -------------------------------------------------------------------------------- /tasks/test_persona_map.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skywalker023/pragmatic-consistency/0a9165914a2009f3d3de5b4c4f37bcd745dc2886/tasks/test_persona_map.pkl -------------------------------------------------------------------------------- /tasks/train_persona_map.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skywalker023/pragmatic-consistency/0a9165914a2009f3d3de5b4c4f37bcd745dc2886/tasks/train_persona_map.pkl -------------------------------------------------------------------------------- /tasks/valid_persona_map.pkl: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/skywalker023/pragmatic-consistency/0a9165914a2009f3d3de5b4c4f37bcd745dc2886/tasks/valid_persona_map.pkl --------------------------------------------------------------------------------
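Both teachers in `tasks/teachers.py` rely on the persona/context splitting convention of `_split_persona_and_context`: for ConvAI2-style input every line but the last is persona, while for DNLI-style input all lines up to the last `your persona:` line are persona. The sketch below mirrors that logic as a standalone function so the convention can be exercised in isolation; the function name is illustrative, not part of the repository's API.

```python
def split_persona_and_context(text, eval_type='convai2'):
    """Split a teacher 'text' field into (persona, context), mirroring
    _split_persona_and_context in tasks/teachers.py."""
    if 'your persona:' not in text:
        return None, text
    lines = text.split('\n')
    if eval_type == 'convai2':
        # ConvAI2: all lines except the last are persona; the last is context
        return '\n'.join(lines[:-1]), lines[-1]
    # DNLI: persona lines may be followed by several dialogue lines,
    # so split after the last line containing 'your persona:'
    last_idx = 0
    for idx, line in enumerate(lines):
        if 'your persona:' in line:
            last_idx = idx
    return '\n'.join(lines[:last_idx + 1]), '\n'.join(lines[last_idx + 1:])
```

This matches the source's behavior, including the edge case where text without any `your persona:` line is returned unchanged with a `None` persona.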