├── .gitignore
├── LICENSE
├── README.md
├── agents
│   ├── __init__.py
│   ├── history.py
│   ├── modules.py
│   └── selfconscious_blender.py
├── assets
│   └── figure1.png
├── environment.yml
├── eval_dnli.py
├── eval_personachat.py
├── modules
│   ├── __init__.py
│   └── dnli_bert.py
└── tasks
    ├── __init__.py
    ├── build.py
    ├── teachers.py
    ├── test_persona_map.pkl
    ├── train_persona_map.pkl
    └── valid_persona_map.pkl
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .cenv
107 | .venv
108 | env/
109 | venv/
110 | ENV/
111 | env.bak/
112 | venv.bak/
113 |
114 | # Spyder project settings
115 | .spyderproject
116 | .spyproject
117 |
118 | # Rope project settings
119 | .ropeproject
120 |
121 | # mkdocs documentation
122 | /site
123 |
124 | # mypy
125 | .mypy_cache/
126 | .dmypy.json
127 | dmypy.json
128 |
129 | # Pyre type checker
130 | .pyre/
131 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Hyunwoo Kim
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Pragmatic Self-Consciousness for Improving Persona Consistency in Dialogues
2 |
3 | 
4 |
5 | **Official PyTorch implementation of our EMNLP paper:**
6 | [Hyunwoo Kim](https://hyunw.kim), [Byeongchang Kim](https://bckim92.github.io), and [Gunhee Kim](https://vision.snu.ac.kr/gunhee). Will I Sound Like Me? Improving Persona Consistency in Dialogues through Pragmatic Self-Consciousness. _EMNLP_, 2020 [[Paper]](https://arxiv.org/abs/2004.05816)
7 |
8 | * **TL;DR**: Inspired by social cognition and pragmatics, we model _public self-consciousness_ in existing dialogue agents with an imaginary listener to improve consistency. Compared to previous works, our method does not require additional consistency-related labels or training.
9 |
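To build intuition for the imaginary-listener idea, here is a toy, self-contained Rational Speech Acts-style sketch (for illustration only — the actual implementation in `agents/modules.py` runs this computation over transformer logits; the variable names below are our own):

```python
import math

def normalize(log_scores):
    """Turn unnormalized log-scores into a probability distribution."""
    m = max(log_scores)
    exps = [math.exp(x - m) for x in log_scores]
    z = sum(exps)
    return [e / z for e in exps]

# Base speaker S0: P(token | persona) over a toy 3-token vocabulary,
# for the agent's actual persona (index 0) and one distractor persona.
s0 = [
    [0.5, 0.3, 0.2],  # persona 0 (ground truth)
    [0.2, 0.3, 0.5],  # persona 1 (distractor)
]
prior = [0.5, 0.5]  # uniform world prior over personas

def listener(token):
    """Imaginary listener L0: P(persona | token) ∝ S0(token | persona) * prior."""
    scores = [s0[i][token] * prior[i] for i in range(len(s0))]
    z = sum(scores)
    return [s / z for s in scores]

# Self-conscious speaker S1:
# P(token | persona 0) ∝ L0(persona 0 | token)^alpha * S0(token | persona 0)
alpha = 2.0  # rationality parameter, analogous to this repo's --alpha flag
s1 = normalize([alpha * math.log(listener(t)[0]) + math.log(s0[0][t])
                for t in range(3)])
# Tokens that let the listener identify the correct persona are up-weighted,
# so the resulting distribution favors persona-consistent tokens.
```

With these numbers, the first token (most diagnostic of persona 0) gains probability mass relative to the base speaker, while the last token (diagnostic of the distractor) loses it — which is exactly the consistency pressure the method applies at every decoding step.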
10 | An earlier version of this work was also accepted as an oral presentation at the ICLR 2020 [Bridging AI and Cognitive Science (BAICS) workshop](https://baicsworkshop.github.io/).
11 |
12 |
13 | ## Reference
14 |
15 | If you use the materials in this repository as part of any published research, we ask you to cite the following [paper](https://arxiv.org/abs/2004.05816):
16 |
17 | ```bibtex
18 | @inproceedings{Kim:2020:selfc,
19 | title={Will I Sound Like Me? Improving Persona Consistency in Dialogues through Pragmatic Self-Consciousness},
20 | author={Kim, Hyunwoo and Kim, Byeongchang and Kim, Gunhee},
21 | booktitle={EMNLP},
22 | year=2020
23 | }
24 | ```
25 |
26 | ### Have any questions?
27 | Please contact [Hyunwoo Kim](https://hyunw.kim) at hyunw.kim@vl.snu.ac.kr.
28 |
29 | ## Implementation
30 |
31 | ### System Requirements
32 |
33 | * Python 3.6.8
34 | * PyTorch 1.6.0
35 | * CUDA 10.1 supported GPU with at least 12GB memory
36 | * See [environment.yml](https://github.com/skywalker023/pragmatic-consistency/blob/master/environment.yml) for details
37 |
38 | ### Environment setup
39 |
40 | Our code is built on the [ParlAI](https://parl.ai/) framework.
41 | We recommend creating a conda environment as follows
42 |
43 | ```bash
44 | conda env create -f environment.yml
45 | ```
46 |
47 | and activate it with
48 |
49 | ```bash
50 | conda activate pragmatic-consistency
51 | ```
52 |
53 | ## Running Experiments
54 |
55 | ### Self-conscious Blender for its persona
56 |
57 | #### Dialogue NLI
58 |
59 | ```bash
60 | python eval_dnli.py --conscious-target self -t tasks.teachers:SelfConsciousDialogueTeacher --model agents.selfconscious_blender:SelfConsciousBlenderAgent --fp16 false
61 | ```
62 |
63 | #### PersonaChat
64 |
65 | ```bash
66 | python eval_personachat.py --conscious-target self -t tasks.teachers:SelfConsciousDialogueTeacher --model agents.selfconscious_blender:SelfConsciousBlenderAgent --batchsize 48 --fp16 false
67 | ```
68 |
69 | ### Self-conscious Blender for its context
70 |
71 | #### Dialogue NLI
72 |
73 | ```bash
74 | python eval_dnli.py --conscious-target context -t tasks.teachers:ContextConsciousDialogueTeacher --model agents.selfconscious_blender:SelfConsciousBlenderAgent --fp16 false
75 | ```
76 |
77 | #### PersonaChat
78 |
79 | ```bash
80 | python eval_personachat.py --conscious-target context -t tasks.teachers:ContextConsciousDialogueTeacher --model agents.selfconscious_blender:SelfConsciousBlenderAgent --batchsize 48 --fp16 false
81 | ```
82 |
83 | 💡 To run the evaluation with vanilla Blender as is, set `--conscious-target` to `none`.
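For instance, the Dialogue NLI evaluation with vanilla Blender is the same command as above with only the target changed (the analogous substitution works for the PersonaChat commands):

```bash
python eval_dnli.py --conscious-target none -t tasks.teachers:SelfConsciousDialogueTeacher --model agents.selfconscious_blender:SelfConsciousBlenderAgent --fp16 false
```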
84 |
85 |
86 | ## Acknowledgements
87 |
88 | We would like to thank [Reuben Cohn-Gordon](https://reubencohngordon.com/), [Sean Welleck](https://cs.nyu.edu/~welleck/), [Junhyug Noh](https://junhyug.github.io/) and [Jiwan Chung](https://vl.snu.ac.kr/people/jiwanchung.html) for their valuable comments. We also thank the anonymous reviewers for their thoughtful suggestions on this work.
89 |
90 | This research was supported by the Brain Research Program of the National Research Foundation of Korea (NRF) (2017M3C7A1047860), by Institute of Information & communications Technology Planning & Evaluation (IITP) grants funded by the Korea government (MSIT) (No. 2017-0-01772, Video Turing Test; No. 2019-0-01082, SW StarLab), and by the Creative Pioneering Researchers Program through Seoul National University.
91 |
92 |
93 | ## License
94 |
95 | This repository is MIT licensed. See the [LICENSE](https://github.com/skywalker023/pragmatic-consistency/blob/master/LICENSE) file for details.
96 |
--------------------------------------------------------------------------------
/agents/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/skywalker023/pragmatic-consistency/0a9165914a2009f3d3de5b4c4f37bcd745dc2886/agents/__init__.py
--------------------------------------------------------------------------------
/agents/history.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """
8 | This file is derived from parlai/core/seq2seq/seq2seq.py.
9 | In particular, it's derived from an older version that inherits from TorchAgent rather
10 | than TorchGeneratorAgent.
11 | It should be possible to refactor this file to be comparable to the current
12 | parlai/core/seq2seq/seq2seq.py, i.e. inherit from TorchGeneratorAgent - this would
13 | probably reduce the amount of boilerplate in this file.
14 | However, for simplicity and to keep things as similar as possible to the version used
15 | for the paper, we have kept this file mostly the same.
16 | """
17 |
18 | from parlai.core.torch_agent import Batch, History, TorchAgent
19 | from parlai.core.torch_generator_agent import TorchGeneratorAgent
20 | from parlai.utils.torch import padded_tensor, argsort
21 | # from .base_controllable_seq2seq import BaseControllableSeq2seqAgent
22 | # from .util import ConvAI2History
23 | # from .controls import get_ctrl_vec
24 |
25 | import torch
26 | import torch.nn as nn
27 | import torch.nn.functional as F
28 |
29 | from collections import defaultdict, namedtuple, Counter, deque
30 | from operator import attrgetter
31 |
32 | import os
33 | import math
34 | import json
35 | import tempfile
36 | import copy
37 | from itertools import chain
38 |
39 |
40 | def list_to_matrix(l, n):
41 | return [l[i:i+n] for i in range(0, len(l), n)]
42 |
43 |
44 | class SelfConsciousHistory(History):
45 | def __init__(self, *args, **kwargs):
46 | super().__init__(*args, **kwargs)
47 | opt = args[0]
48 | if opt['eval_type'] == 'convai2':
49 | self.add_person_tokens = True
50 | elif opt['eval_type'] == 'dnli':
51 | self.add_person_tokens = False
52 | else:
53 | raise ValueError
54 |
55 | self.world_cardinality = opt.get('world_cardinality', 5)
56 | self.history_distractor_strings = [[] for _ in range(self.world_cardinality)]
57 | self.history_distractor_raw_strings = [[] for _ in range(self.world_cardinality)]
58 | self.history_distractor_vecs = [[] for _ in range(self.world_cardinality)]
59 | # Will be used for TransferTransfo
60 | self.history_token_type_ids = []
61 | self.history_distractor_token_type_ids = [[] for _ in range(self.world_cardinality)]
62 |
63 | def reset(self):
64 | """Clear the history"""
65 | super().reset()
66 | self.history_distractor_strings = [[] for _ in range(self.world_cardinality)]
67 | self.history_distractor_raw_strings = [[] for _ in range(self.world_cardinality)]
68 | self.history_distractor_vecs = [[] for _ in range(self.world_cardinality)]
69 | self.history_token_type_ids = []
70 | self.history_distractor_token_type_ids = [[] for _ in range(self.world_cardinality)]
71 |
72 | def _update_distractor_strings(self, text, idx):
73 | history_strings = self.history_distractor_strings[idx]
74 | if self.size > 0:
75 | while len(history_strings) >= self.size:
76 | history_strings.pop(0)
77 | history_strings.append(text)
78 |
79 | def _update_distractor_raw_strings(self, text, idx):
80 | history_raw_strings = self.history_distractor_raw_strings[idx]
81 | if self.size > 0:
82 | while len(history_raw_strings) >= self.size:
83 | history_raw_strings.pop(0)
84 | history_raw_strings.append(text)
85 |
86 | def _update_distractor_vecs(self, text, idx):
87 | history_vecs = self.history_distractor_vecs[idx]
88 | if self.size > 0:
89 | while len(history_vecs) >= self.size:
90 | history_vecs.pop(0)
91 | history_vecs.append(self.parse(text))
92 |
93 | def _update_token_type_ids(self, text, idx):
94 | pass
95 |
96 | def add_reply_to_distractors(self, model_reply):
97 |
98 | # Update model's response to the history
99 | if model_reply is not None:
100 | for idx in range(self.world_cardinality):
101 | self._update_distractor_raw_strings(model_reply, idx)
102 | # this is causing the repetition of p2 token.
103 | # need to do this only once. not every loop
104 | if self.add_person_tokens and idx == 0:
105 | model_reply = self._add_person_tokens(model_reply, self.p2_token)
106 | self._update_distractor_strings(model_reply, idx)
107 | self._update_distractor_vecs(model_reply, idx)
108 |
109 | # def update_history(self, obs, add_next=None):
110 | def update_history(self, obs, temp_history=None):
111 | """
112 | Update the history with the given observation.
113 | :param add_next:
114 | string to append to history prior to updating it with the
115 | observation
116 | """
117 | # super().update_history(obs, add_next)
118 | super().update_history(obs, temp_history=temp_history)
119 |
120 | # Update previous turn's my response
121 | # if add_next is not None:
122 | # for idx in range(self.world_cardinality):
123 | # self._update_distractor_raw_strings(add_next, idx)
124 | # # this is causing the repetition of p2 token.
125 | # # need to do this only once. not every loop
126 | # if self.add_person_tokens and idx == 0:
127 | # add_next = self._add_person_tokens(add_next, self.p2_token)
128 | # self._update_distractor_strings(add_next, idx)
129 | # self._update_distractor_vecs(add_next, idx)
130 |
131 | # Update current turn's opponent's response
132 | if 'distractor_text' in obs:
133 |             assert len(obs['distractor_text']) == self.world_cardinality, \
134 |                 f"Number of distractor_text must be equal to world_cardinality. ({len(obs['distractor_text'])} vs {self.world_cardinality})"
135 | for idx, distractor_text in enumerate(obs['distractor_text']):
136 | if self.split_on_newln:
137 | next_texts = distractor_text.split('\n')
138 | else:
139 | next_texts = [distractor_text]
140 | for text in next_texts:
141 | self._update_distractor_raw_strings(text, idx)
142 | if self.add_person_tokens:
143 | text = self._add_person_tokens(
144 | distractor_text, self.p1_token, self.add_p1_after_newln
145 | )
146 | self._update_distractor_strings(text, idx)
147 | self._update_distractor_vecs(text, idx)
148 |
149 | def get_history_distractor_str(self):
150 | """Return the list of string version of the distractor histories."""
151 | if len(self.history_distractor_strings[0]) > 0:
152 | return [
153 | self.delimiter.join(history_strings)
154 | for history_strings in self.history_distractor_strings
155 | ]
156 | return None
157 |
158 | def get_history_distractor_vec(self):
159 | """Return a vectorized version of the distractor histories."""
160 | if len(self.history_distractor_vecs[0]) == 0:
161 | return None
162 |
163 | histories = []
164 | for idx in range(self.world_cardinality):
165 | history_vecs = self.history_distractor_vecs[idx]
166 |
167 | # if self.vec_type == 'deque':
168 | # history = deque(maxlen=self.max_len)
169 | # for vec in history_vecs[:-1]:
170 | # history.extend(vec)
171 | # history.extend(self.delimiter_tok)
172 | # history.extend(history_vecs[-1])
173 | # else:
174 | # vec type is a list
175 | history = []
176 | for vec in history_vecs[:-1]:
177 | history += vec
178 | history += self.delimiter_tok
179 | history += history_vecs[-1]
180 |
181 | histories.append(history)
182 | return histories
183 |
184 | def get_token_type_ids(self):
185 | """
186 | Return a vectorized version of the token_type_ids and
187 | distractor_token_type_ids
188 | """
189 | pass
190 |
191 |
192 | class ContextConsciousHistory(History):
193 | def __init__(self, *args, **kwargs):
194 | super().__init__(*args, **kwargs)
195 | opt = args[0]
196 | if opt['eval_type'] == 'convai2':
197 | self.add_person_tokens = True
198 | elif opt['eval_type'] == 'dnli':
199 | self.add_person_tokens = False
200 | else:
201 | raise ValueError
202 |
203 | self.world_cardinality = opt.get('world_cardinality', 5)
204 | self.history_distractor_strings = [[] for _ in range(self.world_cardinality)]
205 | self.history_distractor_raw_strings = [[] for _ in range(self.world_cardinality)]
206 | self.history_distractor_vecs = [[] for _ in range(self.world_cardinality)]
207 | # Will be used for TransferTransfo
208 | self.history_token_type_ids = []
209 | self.history_distractor_token_type_ids = [[] for _ in range(self.world_cardinality)]
210 | self.eval_type = opt.get('eval_type')
211 |
212 | def reset(self):
213 | """Clear the history"""
214 | super().reset()
215 | self.history_distractor_strings = [[] for _ in range(self.world_cardinality)]
216 | self.history_distractor_raw_strings = [[] for _ in range(self.world_cardinality)]
217 | self.history_distractor_vecs = [[] for _ in range(self.world_cardinality)]
218 | self.history_token_type_ids = []
219 | self.history_distractor_token_type_ids = [[] for _ in range(self.world_cardinality)]
220 |
221 | def _update_distractor_strings(self, text, idx):
222 | history_strings = self.history_distractor_strings[idx]
223 | if self.size > 0:
224 | while len(history_strings) >= self.size:
225 | history_strings.pop(0)
226 | history_strings.append(text)
227 |
228 | def _update_distractor_raw_strings(self, text, idx):
229 | history_raw_strings = self.history_distractor_raw_strings[idx]
230 | if self.size > 0:
231 | while len(history_raw_strings) >= self.size:
232 | history_raw_strings.pop(0)
233 | history_raw_strings.append(text)
234 |
235 | def _update_distractor_vecs(self, text, idx):
236 | history_vecs = self.history_distractor_vecs[idx]
237 | if self.size > 0:
238 | while len(history_vecs) >= self.size:
239 | history_vecs.pop(0)
240 | history_vecs.append(self.parse(text))
241 |
242 | def _update_token_type_ids(self, text, idx):
243 | pass
244 |
245 | def add_reply_to_distractors(self, model_reply, obs=None):
246 |
247 | # Update model's response along with distractor responses to the history
248 | if model_reply is not None and 'distractor_text' in obs:
249 | distractor_responses = obs['distractor_text']
250 | assert len(obs['distractor_text']) == self.world_cardinality
251 |
252 | for idx in range(self.world_cardinality):
253 | self._update_distractor_raw_strings(distractor_responses[idx], idx)
254 | if self.add_person_tokens:
255 | distractor_responses[idx] = self._add_person_tokens(distractor_responses[idx], self.p2_token)
256 | self._update_distractor_strings(distractor_responses[idx], idx)
257 | self._update_distractor_vecs(distractor_responses[idx], idx)
258 |
259 | # def update_history(self, obs, add_next=None):
260 | def update_history(self, obs, temp_history=None):
261 | """
262 | Update the history with the given observation.
263 | :param add_next:
264 | string to append to history prior to updating it with the
265 | observation
266 | """
267 | super().update_history(obs, temp_history=temp_history)
268 |
269 | # Update current turn's opponent's response
270 | if self.eval_type == 'convai2':
271 | if 'text' in obs:
272 | for idx in range(self.world_cardinality):
273 | if self.split_on_newln:
274 | next_texts = obs['text'].split('\n')
275 | else:
276 | next_texts = [obs['text']]
277 | for text in next_texts:
278 | self._update_distractor_raw_strings(text, idx)
279 | if self.add_person_tokens:
280 | text = self._add_person_tokens(
281 | obs['text'], self.p1_token, self.add_p1_after_newln
282 | )
283 | self._update_distractor_strings(text, idx)
284 | self._update_distractor_vecs(text, idx)
285 | else:
286 | if 'distractor_text' in obs:
287 | distractor_texts = obs['distractor_text']
288 | for idx, distractor in enumerate(distractor_texts):
289 | self._update_distractor_raw_strings(distractor, idx)
290 | self._update_distractor_strings(distractor, idx)
291 | self._update_distractor_vecs(distractor, idx)
292 |
293 | def get_history_distractor_str(self):
294 | """Return the list of string version of the distractor histories."""
295 | if len(self.history_distractor_strings[0]) > 0:
296 | return [
297 | self.delimiter.join(history_strings)
298 | for history_strings in self.history_distractor_strings
299 | ]
300 | return None
301 |
302 | def get_history_distractor_vec(self):
303 | """Return a vectorized version of the distractor histories."""
304 | if len(self.history_distractor_vecs[0]) == 0:
305 | return None
306 |
307 | histories = []
308 | for idx in range(self.world_cardinality):
309 | history_vecs = self.history_distractor_vecs[idx]
310 |
311 | # if self.vec_type == 'deque':
312 | # history = deque(maxlen=self.max_len)
313 | # for vec in history_vecs[:-1]:
314 | # history.extend(vec)
315 | # history.extend(self.delimiter_tok)
316 | # history.extend(history_vecs[-1])
317 | # else:
318 | # vec type is a list
319 | history = []
320 | for vec in history_vecs[:-1]:
321 | history += vec
322 | history += self.delimiter_tok
323 | history += history_vecs[-1]
324 |
325 | histories.append(history)
326 | return histories
327 |
328 | def get_token_type_ids(self):
329 | """
330 | Return a vectorized version of the token_type_ids and
331 | distractor_token_type_ids
332 | """
333 | pass
334 |
--------------------------------------------------------------------------------
/agents/modules.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | """
7 | Implements NN code for transformers.
8 |
9 | Original paper: https://arxiv.org/abs/1706.03762. (Vaswani, 2017). The
10 | `Annotated Transformer` (Rush, 2018) is an excellent reading guide which explains
11 | much of the mechanics of the Transformer model
12 | (http://nlp.seas.harvard.edu/2018/04/03/attention.html).
13 |
14 | This module also supports special segments (ala BERT;
15 | https://arxiv.org/abs/1810.04805), and a few different variations seen in the
16 | literature (BERT and XLM; https://arxiv.org/abs/1901.07291).
17 | """
18 |
19 | import math
20 | from typing import Dict, Tuple, Optional
21 |
22 | import numpy as np
23 | import torch
24 | import torch.cuda
25 | import torch.nn as nn
26 | import torch.nn.functional as F
27 |
28 | from parlai.core.torch_generator_agent import TorchGeneratorModel
29 | from parlai.agents.transformer.modules import TransformerGeneratorModel
30 |
31 |
32 | class SelfConsciousTransformerModel(TransformerGeneratorModel):
33 | """
34 | Implements a full transformer generator model, with pragmatic self-consciousness.
35 | """
36 |
37 | def __init__(self, opt, dictionary):
38 | super().__init__(opt, dictionary)
39 |
40 | self.alpha = 0.0 if opt['conscious_target'] == 'none' else opt['alpha']
41 | self.beta = opt['beta']
42 | self.world_cardinality = opt['world_cardinality']
43 | self.worldprior = opt['worldprior']
44 | self.target_persona = 0
45 | self.fp16 = opt['fp16']
46 |
47 | def _initialize_worldpriors(self, bsz, seqlen):
48 | """
49 | initialize the world prior with a uniform distribution
50 | """
51 | cardinality = self.world_cardinality
52 |         torch_dtype = torch.half if self.fp16 else torch.float
53 | ones = torch.ones(1, seqlen, cardinality, dtype=torch_dtype, requires_grad=False).cuda()
54 | uniform_world_prior = torch.log(ones / cardinality)
55 | world_priors = uniform_world_prior.repeat(bsz, 1, 1).detach()
56 |
57 | return world_priors
58 |
59 | def _pragmatic_reasoning(self, s0_t, worldprior):
60 | """
61 | run pragmatic reasoning with the base speaker and its imaginary listener
62 | """
63 |
64 | vocab_size = self.embeddings.num_embeddings
65 |
66 | # log-scale
67 | log_score = nn.functional.log_softmax(s0_t, dim=2)
68 | log_score = log_score.squeeze() # (bpsz, vocab)
69 |
70 | # (bsz, world_cardinality, vocab)
71 | log_score = log_score.view(-1, self.world_cardinality, vocab_size)
72 |
73 | # S_0 for L_1
74 | _literal_speaker = log_score.clone()
75 | _literal_speaker, _literal_s_next_token_idxs = torch.max(_literal_speaker, dim=-1, keepdim=True)
76 |
77 | # S_0 for the actual given persona (bsz, vocab)
78 | speaker_prior = log_score.select(1, self.target_persona) # target persona is always index 0
79 |
80 | # S_0 for L_0
81 | # (bsz, vocab, world_cardinality)
82 | log_score = log_score.transpose(dim0=1, dim1=2).contiguous()
83 | log_score = log_score * self.beta
84 |
85 | # L_0 \propto S_0 * p(i)
86 | # worldprior should be broadcasted to all the tokens
87 | # (bsz, vocab, world_cardinality)
88 | listener_posterior = (log_score + worldprior) - torch.logsumexp(log_score + worldprior, 2, keepdim=True)
89 |
90 | # (bsz, vocab)
91 | listener_score = listener_posterior.select(2, self.target_persona) # target persona is always index 0
92 | listener_score = listener_score * self.alpha
93 |
94 | speaker_posterior = (listener_score + speaker_prior) - torch.logsumexp(listener_score + speaker_prior, 1, keepdim=True)
95 |
96 | # need to unsqueeze in the dimension 1
97 | speaker_posterior = speaker_posterior.unsqueeze(1) # (bsz, 1, vocab)
98 |
99 | # L_0 for L_1
100 | _literal_listener = listener_posterior.transpose(dim0=1, dim1=2).contiguous()
101 | _literal_listener = torch.gather(_literal_listener, -1, _literal_s_next_token_idxs)
102 |
103 | pragmatic_listener = (_literal_speaker + _literal_listener) - torch.logsumexp(_literal_speaker + _literal_listener, 1, keepdim=True)
104 | pragmatic_listener = pragmatic_listener.squeeze()
105 |
106 | return speaker_posterior, listener_posterior, pragmatic_listener
107 |
108 | def selfconscious_decode(self, encoder_states, maxlen):
109 | """
110 | greedy decoding with pragmatic self-consciousness
111 | """
112 | bpsz = encoder_states[0].size(0)
113 | bsz = bpsz // self.world_cardinality
114 |
115 | inputs_t = self.START.detach().expand(bpsz, 1)
116 | worldpriors = self._initialize_worldpriors(bsz, maxlen).detach()
117 |
118 | s1_scores = []
119 | incr_state = None
120 |
121 | for t in range(maxlen):
122 | worldprior_t = worldpriors.select(1, t).unsqueeze(1)
123 |
124 | latent, incr_state = self.decoder(inputs_t, encoder_states, incr_state)
125 | _logits = self.output(latent)
126 | # only get the last timestep's logit
127 | s0_t = _logits.select(dim=1, index=-1).unsqueeze(1) # logits shape: (bpsz, 1, vocab)
128 |
129 | # s1_t: (bsz, 1, vocab)
130 | # listener_posterior: (bsz, vocab, world_cardinality)
131 | s1_t, l0_t, l1_t = self._pragmatic_reasoning(s0_t, worldprior_t)
132 | s1_scores.append(s1_t)
133 |
134 | next_token = s1_t.max(2)[1].clone().detach() # next input is current predicted output idx
135 |
136 | idx_for_tile = torch.arange(bsz).repeat(self.world_cardinality, 1).transpose(0, 1).reshape(-1).cuda()
137 | inputs_next_t = torch.index_select(next_token, 0, idx_for_tile)
138 | next_token = next_token.unsqueeze(2)
139 | tiled_next_token = next_token.repeat(1, 1, self.world_cardinality)
140 |
141 | if self.worldprior != 'uniform':
142 | # (bsz, vocab, world_cardinality) -> (bsz, 1, world_cardinality)
143 | updated_world_prior = torch.gather(l0_t, 1, tiled_next_token).clone().detach()
144 | if t + 1 < maxlen:
145 | if self.worldprior == 'L0':
146 | worldpriors[:, t + 1, :] = updated_world_prior.squeeze()
147 | elif self.worldprior == 'L1':
148 | worldpriors[:, t + 1, :] = l1_t
149 | else:
150 | raise NotImplementedError
151 |
152 | # update inputs for next timestep
153 | inputs_t = torch.cat((inputs_t, inputs_next_t), dim=1)
154 |
155 | s1_scores = torch.cat(s1_scores, dim=1) # (bsz, seqlen, vocab)
156 | _, preds = s1_scores.max(dim=2)
157 |
158 | return preds, s1_scores
159 |
160 | def selfconscious_decode_forced(self, encoder_states, ys):
161 | """
162 | faster teacher-forced decoding with pragmatic self-consciousness
163 | """
164 |
165 | bsz = ys.size(0)
166 | seqlen = ys.size(1)
167 | self.longest_label = max(self.longest_label, seqlen)
168 | emb_size = self.encoder.embedding_size
169 | enc_outputs = encoder_states[0].view(bsz * self.world_cardinality, -1, emb_size).contiguous()
170 | enc_outputs_mask = encoder_states[1].view(bsz * self.world_cardinality, -1).contiguous()
171 | enc_states = (enc_outputs, enc_outputs_mask)
172 | bpsz = enc_outputs.size(0)
173 |
174 | # tile ys as much as the world_cardinality
175 | idx_for_tile = torch.arange(bsz).repeat(self.world_cardinality, 1).transpose(0, 1).reshape(-1).cuda()
176 | tiled_ys = torch.index_select(ys, 0, idx_for_tile)
177 |
178 | inputs = tiled_ys.narrow(1, 0, seqlen - 1)
179 | inputs = torch.cat([self.START.detach().expand(bpsz, 1), inputs], 1)
180 | worldpriors = self._initialize_worldpriors(bsz, seqlen).detach()
181 | s1_scores = []
182 |
183 | latent, _ = self.decoder(inputs, enc_states)
184 | base_speaker = self.output(latent)
185 |
186 | for t in range(seqlen):
187 |
188 | s0_t = base_speaker.select(dim=1, index=t).unsqueeze(1) # s0_t: (bpsz, 1, vocab)
189 | worldprior_t = worldpriors.select(dim=1, index=t).unsqueeze(1)
190 |
191 | # s1_t: (bsz, 1, vocab)
192 | # l0_t: (bsz, vocab, world_cardinality)
193 | s1_t, l0_t, l1_t = self._pragmatic_reasoning(s0_t, worldprior_t)
194 | s1_scores.append(s1_t)
195 |
196 | # Update world_prior with listener posterior
197 | if t + 1 < seqlen:
198 | next_tokens = inputs.select(1, t + 1).view(-1, 1) # (bpsz, 1): the next tokens for each bpsz instance
199 | next_tokens = next_tokens.unsqueeze(2)
200 |                 # [0, 1*wc, 2*wc, ..., (bsz-1)*wc] -> indices of the ground-truth personas
201 | target_persona_idxs = torch.arange(bsz).cuda() * (self.world_cardinality)
202 |
203 | # we only need the next token of the ground-truth persona
204 | next_token = torch.index_select(next_tokens, 0, target_persona_idxs) # (bsz, 1, 1)
205 | tiled_next_token = next_token.repeat(1, 1, self.world_cardinality) # (bsz, 1, world_cardinality)
206 |
207 | if self.worldprior != 'uniform':
208 | # (bsz, vocab, world_cardinality) -> (bsz, 1, world_cardinality)
209 | updated_world_prior = torch.gather(l0_t, 1, tiled_next_token).clone().detach()
210 | if self.worldprior == 'L0':
211 | worldpriors[:, t + 1, :] = updated_world_prior.squeeze()
212 | elif self.worldprior == 'L1':
213 | worldpriors[:, t + 1, :] = l1_t
214 | else:
215 | raise NotImplementedError
216 |
217 | s1_scores = torch.cat(s1_scores, 1) # (bsz, seqlen, vocab)
218 | _, preds = s1_scores.max(dim=2)
219 |
220 | return s1_scores, preds
221 |
--------------------------------------------------------------------------------
/agents/selfconscious_blender.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import numpy as np
4 | from itertools import chain
5 |
6 | import torch
7 | import torch.nn.functional as F
8 |
9 | from parlai.core.opt import Opt
10 | from parlai.core.message import Message
11 | from parlai.core.torch_agent import Batch, Output
12 | from parlai.core.torch_generator_agent import PPLMetric
13 | from parlai.core.metrics import SumMetric, AverageMetric
14 | from parlai.utils.torch import padded_tensor
15 | from parlai.utils.misc import warn_once
16 | from parlai.agents.transformer.transformer import (
17 | TransformerGeneratorAgent,
18 | add_common_cmdline_args
19 | )
20 |
21 | from agents.modules import SelfConsciousTransformerModel
22 | from modules.dnli_bert import DnliBert
23 | from agents.history import SelfConsciousHistory, ContextConsciousHistory
24 |
25 |
26 | def list_to_matrix(l, n):
27 | return [l[i:i+n] for i in range(0, len(l), n)]
28 |
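`list_to_matrix` simply chunks a flat list into consecutive rows of length `n`; `batchify` later uses it to regroup the flattened distractor lengths per example. A quick sanity check (the one-liner is repeated so the snippet runs standalone):

```python
def list_to_matrix(l, n):
    # Chunk the flat list l into consecutive rows of length n.
    return [l[i:i+n] for i in range(0, len(l), n)]

print(list_to_matrix([3, 5, 2, 7, 4, 6], 3))  # [[3, 5, 2], [7, 4, 6]]
```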
29 |
30 | class SelfConsciousBlenderAgent(TransformerGeneratorAgent):
31 | """
32 | Implementation of the Self-Conscious Blender Agent.
33 | """
34 |
35 | @classmethod
36 | def add_cmdline_args(cls, argparser):
37 | """
38 | Add command-line arguments specifically for this agent.
39 | """
40 | agent = argparser.add_argument_group('Self-conscious Blender Arguments')
41 | agent.add_argument(
42 | '--conscious-target',
43 | type=str,
44 | choices=['none', 'self', 'context'],
45 | default='self',
46 | help='The target which the agent will be concerned about.',
47 | )
48 | agent.add_argument(
49 | '-a',
50 | '--alpha',
51 | type=float,
52 | default=0,
53 |             help='Rationality parameter for the pragmatic speaker S_1',
54 | )
55 | agent.add_argument(
56 | '-b',
57 | '--beta',
58 | type=float,
59 | default=1,
60 | help='Rationality parameter for Listener',
61 | )
62 | agent.add_argument(
63 | '--world_cardinality',
64 | type=int,
65 | default=3,
66 |             help='Cardinality of world I: number of personas used in the RSA model (including the ground-truth persona)',
67 | )
68 | agent.add_argument(
69 | '--worldprior',
70 | type=str,
71 | choices=['uniform', 'L0', 'L1'],
72 | default='L0',
73 | help='Update world prior with a `uniform` distribution or `L0` or `L1`.',
74 | )
75 | agent.add_argument(
76 | '--use_dnli',
77 | type=bool,
78 | default=True,
79 |             help='Whether to use the DNLI model to measure the consistency score on ConvAI2 or to rerank candidates on DNLI'
80 | )
81 | add_common_cmdline_args(agent)
82 | cls.dictionary_class().add_cmdline_args(argparser)
83 |
84 | super(SelfConsciousBlenderAgent, cls).add_cmdline_args(argparser)
85 | return agent
86 |
87 | def __init__(self, opt: Opt, shared=None):
88 |
89 |         self.task = opt['task'].split(':')[-1].lower()
90 |
91 | if opt['conscious_target'] != 'none':
92 | assert opt['conscious_target'] in self.task, \
93 | "conscious_target (`" + opt['conscious_target'] + "`) must match task type (`" + self.task + "`)"
94 |
95 | SEED = 46
96 | random.seed(SEED)
97 | np.random.seed(SEED)
98 | os.environ['PYTHONHASHSEED'] = str(SEED)
99 | torch.random.manual_seed(SEED)
100 | torch.cuda.manual_seed(SEED)
101 | torch.manual_seed(SEED)
102 | torch.cuda.manual_seed_all(SEED)
103 | torch.backends.cudnn.deterministic = True
104 | torch.backends.cudnn.benchmark = False
105 |
106 | # For public self-consciousness
107 | self.target_persona = opt.get('target_persona', 0)
108 | self.conscious_target = opt.get('conscious_target', 'self')
109 | self.world_cardinality = opt.get('world_cardinality', 3)
110 | self.alpha = 0.0 if self.conscious_target == 'none' else opt.get('alpha', 2.0)
111 | self.beta = opt.get('beta', 1.0)
112 | self.worldprior = opt.get('worldprior', 'L0')
113 |
114 | self.eval_type = opt.get('eval_type')
115 | # self.rank_candidates = opt.get('rank_candidates', True)
116 | self.multigpu = (
117 | opt.get('multigpu', False) and self.use_cuda and (opt.get('batchsize') > 1)
118 | )
119 |
120 | init_model, is_finetune = self._get_init_model(opt, shared)
121 | super().__init__(opt, shared)
122 |
123 | # Implementation is based on beam_size 1
124 | self.beam_size = 1
125 |         warn_once('This implementation assumes beam size 1.')
126 |
127 | # Always rank candidates for the ranking metrics
128 | self.rank_candidates = True
129 |         warn_once('rank-candidates is always True for ranking metrics.')
130 |
131 | if opt['use_dnli']:
132 | if not shared:
133 | self.dnli_model = DnliBert(opt, use_cuda=self.use_cuda)
134 | else:
135 | self.dnli_model = shared['dnli_model']
136 | else:
137 | self.dnli_model = None
138 |
139 | self.id = 'SelfConsciousBlender'
140 |
141 | self.reset()
142 |
143 | def build_model(self, states=None):
144 | """
145 | Build and return model.
146 | """
147 | model = SelfConsciousTransformerModel(self.opt, self.dict)
148 | if self.opt['embedding_type'] != 'random':
149 | self._copy_embeddings(
150 | model.encoder.embeddings.weight, self.opt['embedding_type']
151 | )
152 | return model
153 |
154 | def history_class(self):
155 | return ContextConsciousHistory if 'context' in self.task else SelfConsciousHistory
156 |
157 | def _model_input(self, batch):
158 | """
159 |         Override from TorchGeneratorAgent.
160 |         TorchGeneratorAgent._encoder_input() forwards the result of
161 |         TGA._model_input() to the encoder unchanged, so we substitute
162 |         batch.distractor_text_vec for batch.text_vec here to enable pragmatic decoding.
163 | """
164 | bsz = batch.text_vec.size(0)
165 | distractor_text_vec = batch.distractor_text_vec.view(bsz * self.world_cardinality, -1).contiguous()
166 | return (distractor_text_vec,)
167 |
168 | def selfconscious_greedy_generate(self, batch, maxlen):
169 | """
170 | Greedy decoding with Public Self-Consciousness
171 | """
172 |
173 | bsz = batch.text_vec.size(0)
174 | world_cardinality = self.world_cardinality
175 | embedding_size = self.opt.get('embedding_size')
176 | encoder_states = self.model.encoder(*self._encoder_input(batch))
177 |
178 | preds, scores = self.model.selfconscious_decode(encoder_states, maxlen)
179 |
180 | return preds, scores
181 |
182 | def rank(self, batch):
183 | """
184 | Rank candidates by PPL score
185 | """
186 | bsz = batch.text_vec.size(0)
187 | world_cardinality = self.world_cardinality
188 | embedding_size = self.opt.get('embedding_size')
189 | ranked_candidates = []
190 | cand_ordering = []
191 | encoder_states = self.model.encoder(*self._encoder_input(batch))
192 | batch_dim = encoder_states[0].size(0) # two possibilities: batchsize or batchsize * world_cardinality
193 |
194 | if bsz != batch_dim:
195 | enc_output = encoder_states[0].view(bsz, world_cardinality, -1, embedding_size).contiguous()
196 | enc_output_mask = encoder_states[1].view(bsz, world_cardinality, -1).contiguous()
197 | encoder_states = (enc_output, enc_output_mask)
198 |
199 | for i in range(bsz):
200 | num_cands = len(batch.candidate_vecs[i])
201 | cands, _ = self._pad_tensor(batch.candidate_vecs[i])
202 |             # select the [i]th encoder state num_cands times,
203 |             # since all candidates share the same encoder states
204 | enc = self.model.reorder_encoder_states(encoder_states, [i] * num_cands)
205 |
206 | # enc: (num_cands, world_cardinality, seqlen, emb_size)
207 | # scores: (num_cands, max_len, vocab_size)
208 | scores, _ = self.model.selfconscious_decode_forced(enc, cands)
209 |
210 | cand_losses = F.cross_entropy(
211 | scores.view(num_cands * cands.size(1), -1),
212 | cands.view(-1),
213 | reduction='none',
214 | ).view(num_cands, cands.size(1))
215 | # now cand_losses is cands x seqlen size, but we still need to
216 | # check padding and such
217 | mask = (cands != self.NULL_IDX)
218 | mask = mask.half() if self.fp16 else mask.float()
219 | cand_scores = (-cand_losses * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-9)
220 |
221 | if self.dnli_model is not None and self.eval_type == 'dnli':
222 | cand_scores = torch.unsqueeze(cand_scores, 0)
223 | cand_scores = self.dnli_model.rerank_candidates([batch.observations[i]], cand_scores)
224 | cand_scores = torch.squeeze(cand_scores)
225 |
226 | _, ordering = cand_scores.sort(descending=True)
227 | ranked_candidates.append([batch.candidates[i][o] for o in ordering])
228 | cand_ordering.append(ordering)
229 |
230 | return ranked_candidates, cand_ordering
231 |
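The candidate score in `rank` is the negative cross-entropy averaged over non-pad tokens, so longer candidates are not penalized for their length. A scalar sketch of that masked averaging (plain Python, illustrative only):

```python
def masked_mean_score(token_losses, mask, eps=1e-9):
    # Negative mean loss over non-pad tokens: higher is better.
    total = sum(-loss * m for loss, m in zip(token_losses, mask))
    return total / (sum(mask) + eps)

# Three real tokens and one padding position (mask 0); the pad loss is ignored.
score = masked_mean_score([0.5, 1.0, 2.0, 9.9], [1, 1, 1, 0])
```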
232 | def compute_loss(self, batch, return_output=False):
233 | """
234 | Override from TorchGeneratorAgent
235 | Compute and return the loss for the given batch.
236 |
237 | Easily overridable for customized loss functions.
238 |
239 | If return_output is True, the full output from the call to self.model()
240 | is also returned, via a (loss, model_output) pair.
241 | """
242 | if batch.label_vec is None:
243 | raise ValueError('Cannot compute loss without a label.')
244 |
245 | bsz = batch.text_vec.size(0)
246 | world_cardinality = self.world_cardinality
247 | embedding_size = self.opt.get('embedding_size')
248 | encoder_states = self.model.encoder(*self._encoder_input(batch))
249 |
250 | enc_output = encoder_states[0].view(bsz, world_cardinality, -1, embedding_size).contiguous()
251 | enc_output_mask = encoder_states[1].view(bsz, world_cardinality, -1).contiguous()
252 | encoder_states = (enc_output, enc_output_mask)
253 |
254 | scores, preds = self.model.selfconscious_decode_forced(encoder_states, batch.label_vec)
255 | model_output = (scores, preds, encoder_states)
256 |
257 | score_view = scores.view(-1, scores.size(-1))
258 | loss = self.criterion(score_view, batch.label_vec.view(-1))
259 | loss = loss.view(scores.shape[:-1]).sum(dim=1)
260 | # save loss to metrics
261 | notnull = batch.label_vec.ne(self.NULL_IDX)
262 | target_tokens = notnull.long().sum(dim=-1)
263 | correct = ((batch.label_vec == preds) * notnull).sum(dim=-1)
264 |
265 | self.record_local_metric('loss', AverageMetric.many(loss, target_tokens))
266 | self.record_local_metric('ppl', PPLMetric.many(loss, target_tokens))
267 | self.record_local_metric(
268 | 'token_acc', AverageMetric.many(correct, target_tokens)
269 | )
270 |
271 | # actually do backwards loss
272 | loss = loss.sum()
273 | loss /= target_tokens.sum() # average loss per token
274 |
275 | if return_output:
276 | return (loss, model_output)
277 | else:
278 | return loss
279 |
280 | def _eval_convai2_step(self, batch):
281 | """Evaluate a single batch of examples."""
282 |
283 | assert self.alpha >= 0
284 | if batch.distractor_text_vec is None:
285 | return None
286 |
287 | self.model.eval()
288 |
289 | # 1. Generation
290 |         assert self.beam_size == 1
291 | maxlen = self.label_truncate or 256
292 | if not self.skip_generation:
293 | preds, scores = self.selfconscious_greedy_generate(batch, maxlen)
294 | else:
295 | preds = None
296 |
297 | # 2. Compute PPL with teacher-forced generation
298 | # calculate loss on targets with teacher forcing
299 | loss, model_output = self.compute_loss(batch, return_output=True)
300 | token_losses = self._construct_token_losses(
301 | batch.label_vec, model_output
302 | )
303 |
304 | # 3. Rank candidates by computing PPL for each candidates
305 | if self.rank_candidates:
306 | ranked_cands, ordering = self.rank(batch)
307 | else:
308 | ranked_cands = None
309 |
310 | # 4. Compute consistency score
311 | additional_metrics = [{'c_score': 0.0} for _ in range(len(batch.observations))]
312 | output_texts = [self._v2t(p) for p in preds] if preds is not None else None
313 | if not self.skip_generation:
314 | if self.opt['use_dnli']:
315 | c_scores = []
316 | for text, obs in zip(output_texts, batch.observations):
317 | if 'context' in self.task:
318 | c_score = self.dnli_model.compute_consistency_scores(text, obs['my_context'])
319 | else:
320 | persona_strings = obs['my_persona'].split('\n')
321 | c_score = self.dnli_model.compute_consistency_scores(text, persona_strings)
322 |
323 | c_scores.append(c_score)
324 |
325 | for idx, c_score in enumerate(c_scores):
326 | additional_metrics[idx]['c_score'] = c_score
327 |
328 | return Output(output_texts, ranked_cands, token_losses=token_losses, metrics=additional_metrics)
329 |
330 | def _eval_dnli_step(self, batch):
331 | """Evaluate a single batch of examples."""
332 |
333 | assert self.alpha >= 0
334 |
335 | self.model.eval()
336 | ranked_cands, ordering = self.rank(batch)
337 |
338 | bsz = len(ranked_cands)
339 | dnli_metrics = []
340 | for batch_idx in range(bsz):
341 | dnli_score = {'contradict@1': 0, 'entail@1': 0, 'neutral@1': 0}
342 | top1_idx = ordering[batch_idx][0].item()
343 | if top1_idx == 0:
344 |                 pass  # top-1 candidate is the ground-truth response
345 |                 # dnli_score['dnli_hit@1'] += 1
346 |             elif 0 < top1_idx < 11:
347 | dnli_score['contradict@1'] += 1
348 |             elif 11 <= top1_idx < 21:
349 | dnli_score['entail@1'] += 1
350 | else:
351 | dnli_score['neutral@1'] += 1
352 | dnli_metrics.append(dnli_score)
353 |
354 | return Output(text_candidates=ranked_cands, metrics=dnli_metrics)
355 |
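The index ranges above encode the layout of the DNLI evaluation candidates: index 0 is the gold response, 1 to 10 are contradicting candidates, 11 to 20 are entailing, and the remainder are neutral. The bucketing in isolation (a hypothetical helper mirroring the branches above):

```python
def dnli_bucket(top1_idx):
    # Map the rank-1 candidate index to its DNLI category
    # (0 = ground truth, 1-10 contradicting, 11-20 entailing, rest neutral).
    score = {'contradict@1': 0, 'entail@1': 0, 'neutral@1': 0}
    if top1_idx == 0:
        pass  # ground-truth response ranked first
    elif 0 < top1_idx < 11:
        score['contradict@1'] += 1
    elif 11 <= top1_idx < 21:
        score['entail@1'] += 1
    else:
        score['neutral@1'] += 1
    return score

print(dnli_bucket(12))  # {'contradict@1': 0, 'entail@1': 1, 'neutral@1': 0}
```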
356 | def eval_step(self, batch):
357 |
358 | if self.opt['eval_type'] == 'convai2':
359 | return self._eval_convai2_step(batch)
360 | elif self.opt['eval_type'] == 'dnli':
361 | return self._eval_dnli_step(batch)
362 | else:
363 | raise NotImplementedError
364 |
365 | def self_observe(self, self_message: Message):
366 | """
367 | Override from TorchAgent
368 |         Add the model's reply (or the gold label) to the distractor-history fields of the History class
369 | """
370 | episode_done = self.observation['episode_done']
371 | use_reply = self.opt.get('use_reply', 'label')
372 |
373 | # actually ingest the label
374 | if use_reply == 'none':
375 | # we're not including our own responses anyway.
376 | reply = None
377 | elif use_reply == 'label':
378 | # first look for the true label
379 | label_key = (
380 | 'labels'
381 | if 'labels' in self.observation
382 | else 'eval_labels'
383 | if 'eval_labels' in self.observation
384 | else None
385 | )
386 | if label_key is not None:
387 | lbls = self.observation[label_key]
388 | reply = lbls[0] if len(lbls) == 1 else self.random.choice(lbls)
389 | else:
390 | # otherwise, we use the last output the model generated
391 | if self_message is not None:
392 | reply = self_message['text']
393 | else:
394 | reply = None
395 |
396 | super().self_observe(self_message)
397 |
398 | if episode_done:
399 | return None
400 |
401 | if reply is not None:
402 | if 'context' in self.task:
403 | self.history.add_reply_to_distractors(reply, self.observation)
404 | else:
405 | self.history.add_reply_to_distractors(reply)
406 |
407 | return reply
408 |
409 | def _ordered_cand_scores_to_cand_text(self, ordered_cand_preds, cand_inds, candidates):
410 | cand_replies = [None] * len(candidates)
411 |
412 | for idx, order in enumerate(ordered_cand_preds): # batch_idx, sorted cand_idx
413 | batch_idx = cand_inds[idx]
414 | # get the original sentences from candidates by order
415 | cand_replies[batch_idx] = [candidates[batch_idx][i] for i in order]
416 |
417 | return cand_replies
418 |
419 | def _build_candidates_tensor(self, batch):
420 | if not batch.candidates:
421 | return None, None
422 |
423 | cand_inds = [i for i in range(len(batch.candidates)) if batch.candidates[i]]
424 | cands = [batch.candidate_vecs[i] for i in cand_inds]
425 |
426 | # get the length of the longest candidate in the batch
427 | max_cand_len = max(
428 | [max([cand.size(0) for cand in cands_i]) for cands_i in cands]
429 | )
430 |
431 | for i, c in enumerate(cands): # make each instance in batch.cands to a padded tensor
432 | cands[i] = padded_tensor(c, use_cuda=self.use_cuda,
433 | max_len=max_cand_len,
434 | fp16friendly=self.fp16)[0].unsqueeze(0)
435 |
436 |         # (batchsize, num_cands, max_len + a): "+a" is extra padding added for fp16 alignment
437 | cands = torch.cat(cands, 0)
438 |
439 | return cands, cand_inds
440 |
441 | def vectorize(self, obs, history, **kwargs):
442 | """
443 | Override from TorchAgent
444 | Vectorize the texts in observation
445 | """
446 | super().vectorize(obs, history, **kwargs) # candidate vecs are vectorized here
447 | if not self.is_training:
448 | self._set_distractor_text_vec(obs, history, kwargs['text_truncate'])
449 | return obs
450 |
451 | def _set_text_vec(self, obs, history, truncate):
452 | """
453 | Override from TorchAgent for DNLI evaluation
454 | This will be called in super().vectorize()
455 | """
456 | # WARNING: self.is_training is always False in here
457 |         is_training = 'eval_labels' not in obs
458 |
459 | if is_training or self.opt['eval_type'] == 'convai2':
460 | return super()._set_text_vec(obs, history, truncate)
461 | elif self.opt['eval_type'] == 'dnli':
462 | if 'text' not in obs:
463 | return obs
464 |
465 | # Vectorize the text
466 | if 'text_vec' not in obs:
467 | obs['full_text'] = obs['text']
468 | vec = self.dict.txt2vec(obs['full_text'])
469 | obs['text_vec'] = vec
470 |
471 | # check truncation
472 | if obs.get('text_vec') is not None:
473 | truncated_vec = self._check_truncate(obs['text_vec'], truncate, True)
474 | obs.force_set('text_vec', torch.LongTensor(truncated_vec))
475 | return obs
476 | else:
477 | raise NotImplementedError
478 |
479 | def _set_distractor_text_vec(self, obs, history, truncate):
480 | """
481 | Set 'distractor_text' and 'distractor_text_vec' field in the observation
482 | """
483 | if 'distractor_text' not in obs:
484 | return obs
485 |
486 | if 'distractor_text_vec' not in obs:
487 | # distractor_text is in the SelfConsciousHistory class
488 | distractor_string = history.get_history_distractor_str()
489 |
490 | if distractor_string is None:
491 | return obs
492 |
493 | # Set 'full_distractor_text'
494 | obs['full_distractor_text'] = distractor_string
495 | # distractor_text_vec is also in the SelfConsciousHistory class
496 | # they are already vectorized at SelfConsciousHistory.update_history()
497 | if distractor_string:
498 | obs['distractor_text_vec'] = history.get_history_distractor_vec()
499 |
500 | # Check truncation
501 | if obs.get('distractor_text_vec') is not None:
502 | truncated_vec = [
503 | torch.LongTensor(self._check_truncate(text_vec, truncate, True))
504 | for text_vec in obs['distractor_text_vec']
505 | ]
506 | obs.force_set('distractor_text_vec', truncated_vec)
507 | return obs
508 |
509 | def batchify(self, *args, **kwargs):
510 | """
511 | Override from TorchAgent
512 | Additionally batchify the distractor_text_vec and add it to batch
513 | """
514 | kwargs['sort'] = True # need sort for pack_padded()
515 | batch = super().batchify(*args, **kwargs)
516 |         # we must not sort again after super().batchify()
517 |
518 | exs = batch.observations
519 | d_text_vec, d_lens = None, None
520 | if any('distractor_text_vec' in ex for ex in exs):
521 | # Pad distractor vectors
522 | _d_text_vec = [ex.get('distractor_text_vec', self.EMPTY) for ex in exs]
523 | _d_text_vec_flattened = list(chain(*_d_text_vec))
524 | d_text_vec, d_lens = self._pad_tensor(_d_text_vec_flattened)
525 |
526 | # Reshape to (batch_size, world_cardinality, max_length)
527 | bsz = len(exs)
528 | d_text_vec = d_text_vec.view(bsz, self.world_cardinality, -1)
529 | d_lens = list_to_matrix(d_lens, self.world_cardinality)
530 |
531 | batch = Batch(
532 | distractor_text_vec=d_text_vec,
533 | distractor_text_lengths=d_lens,
534 | **dict(batch)
535 | )
536 |
537 | return batch
538 |
539 | def share(self):
540 | shared = super().share()
541 | if self.opt['use_dnli']:
542 | shared['dnli_model'] = self.dnli_model
543 | return shared
544 |
--------------------------------------------------------------------------------
/assets/figure1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/skywalker023/pragmatic-consistency/0a9165914a2009f3d3de5b4c4f37bcd745dc2886/assets/figure1.png
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: pragmatic-consistency
2 | channels:
3 | - defaults
4 | dependencies:
5 | - cudatoolkit=10.1
6 | - pytorch::pytorch=1.6.0
7 | - python=3.6.8
8 | - pip=20.1.1
9 | - pip:
10 |   # ParlAI (commit from October 1, 2020)
11 | - git+https://github.com/facebookresearch/ParlAI.git@9cd6b6c0e70c72a24e959e4a328cb4093eb7f3de
12 | - torchtext==0.7.0
13 | - spacy==2.3.2
14 | - pytorch-transformers==1.2.0
15 | # For logging
16 | - tqdm
17 | - better-exceptions
18 | # For linting
19 | - pylint
20 | - pycodestyle
21 | - mypy
22 | # For markdown preview
23 | - grip
24 | # etc
25 | - ruamel.yaml
26 | - more_itertools
27 | - isort
28 | - pudb
29 | - jupyter
30 | - orderedset
31 |
--------------------------------------------------------------------------------
/eval_dnli.py:
--------------------------------------------------------------------------------
1 | from parlai.scripts.eval_model import eval_model
2 | from parlai.scripts.eval_model import setup_args as parlai_setupargs
3 |
4 |
5 | def setup_args():
6 | parser = parlai_setupargs()
7 | parser.set_defaults(
8 | model_file='zoo:blender/blender_90M/model',
9 | eval_type='dnli',
10 | metrics='contradict@1,entail@1,neutral@1',
11 | alpha=8,
12 | beta=1,
13 | use_dnli=False
14 | )
15 | return parser
16 |
17 |
18 | if __name__ == '__main__':
19 | parser = setup_args()
20 | opt = parser.parse_args()
21 | eval_model(opt)
22 |
--------------------------------------------------------------------------------
/eval_personachat.py:
--------------------------------------------------------------------------------
1 | from parlai.scripts.eval_model import eval_model
2 | from parlai.scripts.eval_model import setup_args as parlai_setupargs
3 |
4 |
5 | def setup_args():
6 | parser = parlai_setupargs()
7 | parser.set_defaults(
8 | model_file='zoo:blender/blender_90M/model',
9 | eval_type='convai2',
10 | metrics='token_acc,ppl,loss,c_scores,f1',
11 | alpha=2,
12 | beta=0.5
13 | )
14 | return parser
15 |
16 |
17 | if __name__ == '__main__':
18 | parser = setup_args()
19 | opt = parser.parse_args()
20 | eval_model(opt)
21 |
--------------------------------------------------------------------------------
/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/skywalker023/pragmatic-consistency/0a9165914a2009f3d3de5b4c4f37bcd745dc2886/modules/__init__.py
--------------------------------------------------------------------------------
/modules/dnli_bert.py:
--------------------------------------------------------------------------------
1 | import os
2 | from copy import deepcopy
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | import numpy as np
7 | from pytorch_transformers import (
8 | BertForSequenceClassification,
9 | BertTokenizer
10 | )
11 | from parlai.core.build_data import download_from_google_drive
12 |
13 |
14 | def _truncate_seq_pair(tokens_a, tokens_b, max_length):
15 | """Truncates a sequence pair in place to the maximum length."""
16 | # This is a simple heuristic which will always truncate the longer sequence
17 | # one token at a time. This makes more sense than truncating an equal percent
18 | # of tokens from each, since if one sequence is very short then each token
19 | # that's truncated likely contains more information than a longer sequence.
20 | while True:
21 | total_length = len(tokens_a) + len(tokens_b)
22 | if total_length <= max_length:
23 | break
24 | if len(tokens_a) > len(tokens_b):
25 | tokens_a.pop()
26 | else:
27 | tokens_b.pop()
28 |
29 |
30 | class DnliBert(object):
31 | def __init__(self,
32 | opt,
33 | dnli_lambda=1.0,
34 | dnli_k=10,
35 | max_seq_length=128,
36 | use_cuda=True):
37 | self.opt = opt
38 | self.dnli_lambda = dnli_lambda
39 | self.dnli_k = dnli_k
40 | self.max_seq_length = max_seq_length
41 | self.use_cuda = use_cuda
42 | self.mapper = {0: "contradiction",
43 | 1: "entailment",
44 | 2: "neutral"}
45 |
46 | dnli_model, dnli_tokenizer = self._load_dnli_model()
47 | self.dnli_model = dnli_model
48 | self.dnli_tokenizer = dnli_tokenizer
49 |
50 | def _load_dnli_model(self):
51 | # Download pretrained weight
52 | dnli_model_fname = os.path.join(self.opt['datapath'], 'dnli_model.bin')
53 | if not os.path.exists(dnli_model_fname):
54 | print(f"[ Download pretrained dnli model params to {dnli_model_fname}]")
55 | download_from_google_drive(
56 | '1Qawz1pMcV0aGLVYzOgpHPgG5vLSKPOJ1',
57 | dnli_model_fname
58 | )
59 |
60 | # Load pretrained weight
61 | print(f"[ Load pretrained dnli model from {dnli_model_fname}]")
62 | model_state_dict = torch.load(dnli_model_fname)
63 | dnli_model = BertForSequenceClassification.from_pretrained('bert-base-uncased', state_dict=model_state_dict, num_labels=3)
64 | if self.use_cuda:
65 | dnli_model.cuda()
66 | dnli_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
67 |
68 | return dnli_model, dnli_tokenizer
69 |
70 | def rerank_candidates(self, observations, cand_scores):
71 | sorted_cand_values, sorted_cand_indices = cand_scores.sort(1, descending=True)
72 |
73 | for batch_idx, obs in enumerate(observations):
74 | full_text = obs['full_text']
75 | personas = []
76 | for text in full_text.split('\n'):
77 | if 'your persona:' in text:
78 | personas.append(text.replace('your persona:', ''))
79 | else:
80 | break
81 | candidates = obs['label_candidates']
82 |
83 | tok_candidates = [self.dnli_tokenizer.tokenize(sent) for sent in candidates]
84 | tok_personas = [self.dnli_tokenizer.tokenize(sent) for sent in personas]
85 |
86 | dnli_scores = self._compute_dnli_scores(tok_candidates, tok_personas)
87 | s_1 = sorted_cand_values[batch_idx, 0]
88 | s_k = sorted_cand_values[batch_idx, self.dnli_k - 1]
89 |
90 | _lambda = self.dnli_lambda
91 | cand_scores[batch_idx] = cand_scores[batch_idx] - _lambda * (s_1 - s_k) * dnli_scores
92 |
93 | return cand_scores
94 |
95 | def compute_consistency_scores(self, pred, personas):
96 | """
97 |         pred must be a string and personas a list of strings
98 | """
99 | max_seq_length = self.max_seq_length
100 |
101 | pred_tokenized = self.dnli_tokenizer.tokenize(pred)
102 | personas_tokenized = [self.dnli_tokenizer.tokenize(sent.replace('your persona:', '')) for sent in personas]
103 |
104 | all_input_ids = []
105 | all_input_mask = []
106 | all_segment_ids = []
107 | for idx, persona_tokenized in enumerate(personas_tokenized):
108 | _pred_tokenized = deepcopy(pred_tokenized)
109 | _persona_tokenized = deepcopy(persona_tokenized)
110 | _truncate_seq_pair(_pred_tokenized, _persona_tokenized, max_seq_length - 3)
111 |
112 | tokens = ["[CLS]"] + _pred_tokenized + ["[SEP]"]
113 | segment_ids = [0] * len(tokens)
114 | tokens += _persona_tokenized + ["[SEP]"]
115 | segment_ids += [1] * (len(_persona_tokenized) + 1)
116 |
117 | input_ids = self.dnli_tokenizer.convert_tokens_to_ids(tokens)
118 | input_mask = [1] * len(input_ids)
119 | padding = [0] * (max_seq_length - len(input_ids))
120 | input_ids += padding
121 | input_mask += padding
122 | segment_ids += padding
123 |
124 | all_input_ids.append(input_ids)
125 | all_input_mask.append(input_mask)
126 | all_segment_ids.append(segment_ids)
127 |
128 | # Convert inputs to tensors
129 | all_input_ids = torch.tensor(all_input_ids, dtype=torch.long)
130 | all_input_mask = torch.tensor(all_input_mask, dtype=torch.long)
131 | all_segment_ids = torch.tensor(all_segment_ids, dtype=torch.long)
132 | if self.use_cuda:
133 | all_input_ids = all_input_ids.cuda()
134 | all_input_mask = all_input_mask.cuda()
135 | all_segment_ids = all_segment_ids.cuda()
136 |
137 | # Inference
138 | self.dnli_model.eval()
139 | with torch.no_grad():
140 | logits = self.dnli_model(all_input_ids, all_segment_ids, all_input_mask)
141 | probs = F.softmax(logits[0], dim=1)
142 |
143 | probs = probs.detach().cpu().numpy()
144 | idx_max = np.argmax(probs, axis=1)
145 | val_max = np.max(probs, axis=1)
146 |
147 | consistency_score = 0.0
148 | for pred_idx in idx_max:
149 | if pred_idx == 0: # contradict
150 | consistency_score -= 1.0
151 | elif pred_idx == 1: # entailment
152 | consistency_score += 1.0
153 | elif pred_idx == 2: # neutral
154 | consistency_score += 0.0
155 |
156 | return consistency_score
157 |
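The returned score sums the per-persona NLI outcomes: +1 for each entailment, -1 for each contradiction, 0 for neutral. The aggregation in isolation (a hypothetical helper mirroring the loop above):

```python
def aggregate_consistency(pred_labels):
    # pred_labels: list of argmax class indices, one per persona sentence
    # (0 = contradiction, 1 = entailment, 2 = neutral).
    score = 0.0
    for idx in pred_labels:
        if idx == 0:
            score -= 1.0
        elif idx == 1:
            score += 1.0
    return score

print(aggregate_consistency([1, 1, 0, 2]))  # 1.0
```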
158 | def _compute_dnli_scores(self, tok_candidates, tok_personas):
159 | max_seq_length = self.max_seq_length
160 |
161 | dnli_scores = []
162 | for cand_idx, tok_candidate in enumerate(tok_candidates):
163 | all_input_ids = []
164 | all_input_mask = []
165 | all_segment_ids = []
166 | for tok_persona in tok_personas:
167 | # Prepare inputs
168 | # [CLS] candidates [SEP] persona [SEP]
169 | _tok_candidate = deepcopy(tok_candidate)
170 | _tok_persona = deepcopy(tok_persona)
171 | # Account for [CLS], [SEP], [SEP] with "- 3"
172 | _truncate_seq_pair(_tok_candidate, _tok_persona, max_seq_length - 3)
173 |
174 | # Make inputs
175 | tokens = ["[CLS]"] + _tok_candidate + ["[SEP]"]
176 | segment_ids = [0] * len(tokens)
177 | tokens += _tok_persona + ["[SEP]"]
178 | segment_ids += [1] * (len(_tok_persona) + 1)
179 |
180 | input_ids = self.dnli_tokenizer.convert_tokens_to_ids(tokens)
181 | input_mask = [1] * len(input_ids)
182 | padding = [0] * (max_seq_length - len(input_ids))
183 | input_ids += padding
184 | input_mask += padding
185 | segment_ids += padding
186 |
187 | all_input_ids.append(input_ids)
188 | all_input_mask.append(input_mask)
189 | all_segment_ids.append(segment_ids)
190 |
191 | # Convert inputs to tensors
192 | all_input_ids = torch.tensor(all_input_ids, dtype=torch.long)
193 | all_input_mask = torch.tensor(all_input_mask, dtype=torch.long)
194 | all_segment_ids = torch.tensor(all_segment_ids, dtype=torch.long)
195 | if self.use_cuda:
196 | all_input_ids = all_input_ids.cuda()
197 | all_input_mask = all_input_mask.cuda()
198 | all_segment_ids = all_segment_ids.cuda()
199 |
200 | # Inference
201 | self.dnli_model.eval()
202 | with torch.no_grad():
203 | logits = self.dnli_model(all_input_ids, all_segment_ids, all_input_mask)
204 | probs = F.softmax(logits[0], dim=1)
205 |
206 | probs = probs.detach().cpu().numpy()
207 | idx_max = np.argmax(probs, axis=1)
208 | val_max = np.max(probs, axis=1)
209 | dnli_score = np.max((idx_max == 0) * val_max)
210 | dnli_scores.append(dnli_score)
211 | dnli_scores = torch.tensor(dnli_scores, dtype=torch.float)
212 | if self.use_cuda:
213 | dnli_scores = dnli_scores.cuda()
214 | return dnli_scores
215 |
--------------------------------------------------------------------------------
/tasks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/skywalker023/pragmatic-consistency/0a9165914a2009f3d3de5b4c4f37bcd745dc2886/tasks/__init__.py
--------------------------------------------------------------------------------
/tasks/build.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import os
8 | import parlai.core.params as params
9 | import parlai.core.build_data as build_data
10 |
11 |
12 | FOLDER_NAME = 'self_conscious_dialogue'
13 |
14 |
15 | def build(opt):
16 | dpath = os.path.join(opt['datapath'], FOLDER_NAME)
17 | # version 1.0: initial release
18 | version = '1.0'
19 |
20 | # check whether data had been previously built
21 | if not build_data.built(dpath, version_string=version):
22 | print('[building data: ' + dpath + ']')
23 |
24 | # make a clean directory if needed
25 | if build_data.built(dpath):
26 | # if an older version exists, remove those outdated files.
27 | build_data.remove_dir(dpath)
28 | build_data.make_dir(dpath)
29 |
30 | #########################
31 | # ConvAI2 (PersonaChat)
32 | #########################
33 | fname = 'data_v1.tar.gz'
34 | url = 'https://parl.ai/downloads/controllable_dialogue/' + fname
35 | build_data.download(url, dpath, fname)
36 | build_data.untar(dpath, fname)
37 |
38 | fname = 'convai2_fix_723.tgz'
39 | url = 'http://parl.ai/downloads/convai2/' + fname
40 | build_data.download(url, dpath, fname)
41 | build_data.untar(dpath, fname)
42 |
43 | #########################
44 | # Dialogue NLI
45 | #########################
46 | fname = 'dialogue_nli.zip'
47 | gd_id = '1WtbXCv3vPB5ql6w0FVDmAEMmWadbrCuG'
48 | build_data.download_from_google_drive(gd_id, os.path.join(dpath, fname))
49 | build_data.untar(dpath, fname)
50 |
51 | fname = 'dialogue_nli_evaluation.zip'
52 | gd_id = '1sllq30KMJzEVQ4C0-a9ShSLSPIZc3iMi'
53 | build_data.download_from_google_drive(gd_id, os.path.join(dpath, fname))
54 | build_data.untar(dpath, fname)
55 |
56 | #########################
57 | # Distractor personas
58 | #########################
59 | fname = 'train_sorted_50_personas.json'
60 | gd_id = '1SGFdJqyNYeepKFqwMLv4Ym717QQTtpi8'
61 | build_data.download_from_google_drive(gd_id, os.path.join(dpath, fname))
62 | fname = 'valid_sorted_50_personas.json'
63 | gd_id = '1A7oVKmjJ1EZTh6-3Gio4XQo81QgnTGGi'
64 | build_data.download_from_google_drive(gd_id, os.path.join(dpath, fname))
65 | fname = 'dnli_sorted_50_personas.json'
66 | gd_id = '1wlIkVcBZoGQd3rbI7XWNhuq4rvw9FyoP'
67 | build_data.download_from_google_drive(gd_id, os.path.join(dpath, fname))
68 |
69 | print("Data has been placed in " + dpath)
70 |
71 | build_data.mark_done(dpath, version)
72 |
73 |
74 | def make_path(opt, fname):
75 | return os.path.join(opt['datapath'], FOLDER_NAME, fname)
76 |
77 |
78 | if __name__ == '__main__':
79 | opt = params.ParlaiParser().parse_args(print_args=False)
80 | build(opt)
81 |
--------------------------------------------------------------------------------
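`build()` above follows ParlAI's standard versioned-download pattern: a marker file records the dataset version, a matching marker makes later calls a no-op, and a version bump wipes and rebuilds the directory. The sketch below is a hypothetical, simplified stand-in for `parlai.core.build_data` (the real helpers record more metadata and handle downloads); it only illustrates the control flow.

```python
import os
import shutil
import tempfile

def built(dpath, version_string=None):
    """Return True if dpath contains a .built marker (matching the version, if given)."""
    marker = os.path.join(dpath, '.built')
    if not os.path.isfile(marker):
        return False
    if version_string is None:
        return True
    with open(marker) as f:
        return f.read().strip() == version_string

def mark_done(dpath, version_string=''):
    """Write the .built marker so future builds skip the download."""
    with open(os.path.join(dpath, '.built'), 'w') as f:
        f.write(version_string)

def build(dpath, version='1.0'):
    if built(dpath, version_string=version):
        return False  # already up to date, nothing to do
    if built(dpath):
        shutil.rmtree(dpath)  # older version exists: remove outdated files
    os.makedirs(dpath, exist_ok=True)
    # ... download and untar data here ...
    mark_done(dpath, version)
    return True

dpath = os.path.join(tempfile.mkdtemp(), 'self_conscious_dialogue')
first = build(dpath)   # performs the "download" and writes the marker
second = build(dpath)  # skipped: marker matches the version
```

Re-running the script is therefore cheap: only a changed `version` string triggers a fresh download.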
/tasks/teachers.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import os
3 | import math
4 | import json
5 | import random
6 | from operator import itemgetter
7 | from orderedset import OrderedSet
8 | from collections import defaultdict
9 | import pickle
10 |
11 | from parlai.utils.misc import warn_once, str_to_msg
12 | from parlai.core.message import Message
13 | from parlai.core.torch_agent import TorchAgent
14 | from parlai.core.teachers import (
15 | ParlAIDialogTeacher,
16 | FixedDialogTeacher,
17 | FbDeprecatedDialogTeacher,
18 | DialogData
19 | )
20 |
21 | from .build import build, make_path
22 |
23 | __PATH__ = os.path.abspath(os.path.dirname(__file__))
24 |
25 |
26 | def _path(opt):
27 | build(opt)
28 | datatype = opt['datatype'].split(':')[0]
29 | if datatype == 'test':
30 | warn_once("WARNING: Test set not included. Setting datatype to valid.")
31 | datatype = 'valid'
32 | return make_path(opt, datatype + '.txt'), datatype
33 |
34 |
35 | def _split_persona_and_context(text, eval_type='convai2'):
36 | if 'your persona:' not in text:
37 | return None, text
38 | else:
39 | if eval_type == 'convai2':
40 | texts = text.split('\n')
41 | return '\n'.join(texts[:-1]), texts[-1]
42 |         elif eval_type == 'dnli':
43 | texts = text.split('\n')
44 | last_idx = 0
45 |             for idx, line in enumerate(texts):
46 |                 if 'your persona:' in line:
47 |                     last_idx = idx
48 | persona_texts = texts[:last_idx+1]
49 | context_texts = texts[last_idx+1:]
50 | return '\n'.join(persona_texts), '\n'.join(context_texts)
51 |
52 |
53 | def _split_personas_and_context(text):
54 | if 'your persona:' not in text:
55 | return text, text, text
56 | else:
57 | your_personas = []
58 | partner_personas = []
59 | context = []
60 | texts = text.split('\n')
61 |         for line in texts:
62 |             if line.startswith('your persona:'):
63 |                 your_personas.append(line)
64 |             elif line.startswith("partner's persona:"):
65 |                 partner_personas.append(line)
66 |             else:
67 |                 context.append(line)
68 |
69 | return '\n'.join(your_personas), '\n'.join(partner_personas), context
70 |
71 |
72 | class SelfConsciousDialogueTeacher(FixedDialogTeacher):
73 | """
74 | Teacher (i.e. input data supplier) for the Self-conscious Agent.
75 | SelfConsciousDialogueTeacher (SCDT) supplies data input
76 | along with the distractors to the Self-conscious Agent.
77 | """
78 | def __init__(self, opt, shared=None):
79 | super().__init__(opt, shared)
80 | self.opt = opt
81 |
82 | datapath, datatype = _path(opt)
83 |
84 | if not shared:
85 | self.episodes = []
86 | self.num_exs = 0
87 | self._setup_data(datapath, datatype)
88 | else:
89 | self.episodes = shared['episodes']
90 | self.num_exs = sum(len(e) for e in self.episodes)
91 | self.id = 'self_conscious_dialogue'
92 | self.reset()
93 |
94 | @staticmethod
95 | def add_cmdline_args(argparser):
96 | agent = argparser.add_argument_group('Self Conscious Dialogue Teacher arguments')
97 | agent.add_argument(
98 | '--eval-type',
99 | type=str,
100 | choices=['convai2', 'dnli'],
101 | default='dnli',
102 | help='Which validation data to use',
103 | )
104 |
105 | def _setup_data(self, path, datatype):
106 |
107 |         random.seed(46)  # fixed seed so distractor sampling is reproducible
108 |
109 | # Data loading with script of ParlAIDialogTeacher
110 | print(f"[Loading ParlAI text data: {path}]")
111 |
112 | # Read data from ConvAI2
113 | convai2_datapath = make_path(self.opt, f'{datatype}_both_original.txt')
114 | convai2_episodes = self._load_convai2_data(convai2_datapath)
115 |
116 | # Get persona pool
117 | all_personas, persona_to_idx = self._get_persona_pool(self.opt)
118 | sorted_personas = self._get_sorted_persona_pool(datatype)
119 |
120 |
121 | if self.opt['eval_type'] == 'convai2':
122 | self.episodes = []
123 | self.num_exs = 0
124 | eps = []
125 | with open(path) as read:
126 | for line in read:
127 | msg = str_to_msg(line.rstrip('\n'))
128 | if msg:
129 | self.num_exs += 1
130 | eps.append(msg)
131 | if msg.get('episode_done', False):
132 | self.episodes.append(eps)
133 | eps = []
134 | if len(eps) > 0:
135 | # add last episode
136 | eps[-1].force_set('episode_done', True)
137 | self.episodes.append(eps)
138 | # Add label candidates and partner's persona
139 | for episode_idx, episode in enumerate(self.episodes):
140 | for turn_idx, turn in enumerate(episode):
141 | convai2_turn = convai2_episodes[episode_idx][turn_idx]
142 | convai2_text = convai2_turn[0]
143 | label_candidates = convai2_turn[3]
144 |
145 | turn['label_candidates'] = label_candidates
146 | if turn_idx == 0:
147 | my_persona, partner_persona, _ = _split_personas_and_context(convai2_text)
148 | turn['partner_persona'] = partner_persona
149 | turn['my_persona'] = my_persona
150 | else:
151 | turn['partner_persona'] = episode[0]['partner_persona']
152 | turn['my_persona'] = episode[0]['my_persona']
153 | elif self.opt['eval_type'] == 'dnli':
154 | self.episodes = []
155 | self.num_exs = 0
156 | for eval_set in ['attributes', 'havenot', 'likedislike']:
157 | datapath = make_path(self.opt, f'{datatype}_{eval_set}.jsonl')
158 | with open(datapath, 'r') as fp:
159 | for line in fp:
160 | msg = json.loads(line)
161 | msg['eval_set'] = eval_set
162 | msg['episode_done'] = True
163 |
164 | # Make 'text'
165 | persona_lines = [f'your persona: {x[:-2]}.' for x in msg['persona']]
166 | utts = msg['prefix']
167 |
168 | p1_token, p2_token = TorchAgent.P1_TOKEN, TorchAgent.P2_TOKEN
169 | lines = persona_lines
170 | # Identify the dialogue lines. It's assumed that p1 goes first.
171 | for i, utt in enumerate(utts):
172 | if i % 2 == 0:
173 | lines.append(f'{p1_token} {utt}')
174 | else:
175 | lines.append(f'{p2_token} {utt}')
176 | text = '\n'.join(lines)
177 |
178 | msg['text'] = text
179 |
180 | # Make 'label_candidates'
181 | cands = msg['candidates']
182 | msg['label_candidates'] = cands['label'] + cands['neg'][:10] \
183 | + cands['similar'][:10] + cands['rand'][:10]
184 |
185 | # Remove unused attributes
186 | del msg['persona']
187 | del msg['prefix']
188 | del msg['triple']
189 | del msg['relevant_persona_sentence']
190 | del msg['candidates']
191 |
192 | self.episodes.append([msg])
193 | self.num_exs += 1
194 |
195 | # Add distractor personas
196 | if self.opt['world_cardinality'] > 0:
197 | num_all_personas = len(all_personas)
198 | persona_indices = list(range(num_all_personas))
199 | world_cardinality = self.opt['world_cardinality']
200 | for episode in self.episodes:
201 | gt_persona, first_context = _split_persona_and_context(episode[0]['text'], self.opt['eval_type'])
202 | gt_persona_idx = persona_to_idx.get(gt_persona, -1)
203 |
204 | # Choose random distractor personas
205 | distractor_indices = random.sample(persona_indices, world_cardinality - 1)
206 | while gt_persona_idx in distractor_indices:
207 | # Resample if gt_persona is sampled
208 | distractor_indices = random.sample(persona_indices, world_cardinality - 1)
209 | distractor_personas = itemgetter(*distractor_indices)(all_personas)
210 | distractor_personas = list(distractor_personas)
211 |
212 | # Make it to 'distractor_text'
213 | for turn_idx, turn in enumerate(episode):
214 | if turn_idx == 0:
215 | turn['distractor_text'] = [
216 | '\n'.join([persona, first_context])
217 | for persona in [gt_persona] + distractor_personas
218 | ]
219 | else:
220 | turn['distractor_text'] = [turn['text']] * world_cardinality
221 |
222 | def _get_persona_pool(self, opt, remove_duplicate=True):
223 | print("[loading persona pool from convai2 training data]")
224 | # Get episodes from training dataset
225 | datapath = make_path(opt, 'train.txt')
226 | episodes = []
227 | eps = []
228 | with open(datapath) as read:
229 | for line in read:
230 | msg = str_to_msg(line.rstrip('\n'))
231 | if msg:
232 | # self.num_exs += 1
233 | eps.append(msg)
234 | if msg.get('episode_done', False):
235 | episodes.append(eps)
236 | eps = []
237 | if len(eps) > 0:
238 | # add last episode
239 | eps[-1].force_set('episode_done', True)
240 | episodes.append(eps)
241 |
242 | # Extract personas from episodes
243 | persona_set = OrderedSet()
244 | for episode in episodes:
245 | first_turn = episode[0]
246 | text = first_turn['text']
247 | persona, _ = _split_persona_and_context(text)
248 | persona_set.add(persona)
249 |
250 | # Remove duplicate
251 | if remove_duplicate:
252 | train_persona_fname = os.path.join(__PATH__, 'train_persona_map.pkl')
253 | with open(train_persona_fname, 'rb') as fp:
254 | _train_personas = pickle.load(fp)
255 | train_personas = []
256 | for personas in _train_personas.values():
257 | longest_idx = 0
258 | longest_length = -1
259 | for idx, persona in enumerate(personas):
260 | if len(persona) > longest_length:
261 | longest_idx = idx
262 | longest_length = len(persona)
263 |                 selected_persona = map(lambda x: f"your persona: {x}.", personas[longest_idx])
264 | selected_persona = '\n'.join(selected_persona)
265 | train_personas.append(selected_persona)
266 | persona_set = OrderedSet()
267 | for train_persona in train_personas:
268 | persona_set.add(train_persona)
269 |
270 | all_personas = []
271 | persona_to_idx = {}
272 | for i, persona in enumerate(persona_set):
273 | all_personas.append(persona)
274 | persona_to_idx[persona] = i
275 |
276 | print(f"Total {len(all_personas)} personas in dataset")
277 |
278 | return all_personas, persona_to_idx
279 |
280 | def _get_sorted_persona_pool(self, datatype):
281 | print("[loading sorted persona pool from convai2 training data]")
282 | eval_type = self.opt['eval_type']
283 | if eval_type == 'convai2':
284 | datapath = make_path(self.opt, 'valid_sorted_50_personas.json')
285 | elif eval_type == 'dnli':
286 | datapath = make_path(self.opt, 'dnli_sorted_50_personas.json')
287 | else:
288 |             raise ValueError("eval_type must be either 'convai2' or 'dnli'")
289 |
290 | with open(datapath, 'r') as fp:
291 | sorted_personas = json.load(fp)
292 | sorted_personas['idx2persona'] = sorted_personas['train_personas']
293 | sorted_personas['persona2idx'] = {}
294 | for idx, persona in enumerate(sorted_personas['train_personas']):
295 | sorted_personas['persona2idx'][persona] = idx
296 |
297 | return sorted_personas
298 |
299 | def _load_convai2_data(self, datapath):
300 | """
301 | Read data in the fbdialog format.
302 | Returns ``(x, y, r, c)`` tuples.
303 | ``x`` represents a query, ``y`` represents the labels, ``r`` represents
304 | any reward, and ``c`` represents any label_candidates.
305 |         An fbdialog episode is translated into tuples like the following:
306 | ::
307 | x: 'Sam went to the kitchen\nPat gave Sam the milk\nWhere is the milk?'
308 | y: ['kitchen']
309 | r: '1'
310 | c: ['hallway', 'kitchen', 'bathroom']
311 | new_episode = True (this is the first example in the episode)
312 | ::
313 |             x: 'Sam went to the hallway\nPat went to the bathroom\nWhere is the
314 | milk?'
315 | y: ['hallway']
316 | r: '1'
317 | c: ['hallway', 'kitchen', 'bathroom']
318 | new_episode = False (this is the second example in the episode)
319 | """
320 |         self.cloze = False  # attribute expected by FbDeprecatedDialogTeacher.setup_data
321 | convai2_dataloader = FbDeprecatedDialogTeacher.setup_data(self, datapath)
322 | convai2_episodes = []
323 | for episode in DialogData._read_episode(self, convai2_dataloader):
324 | convai2_episodes.append(episode)
325 | del self.cloze
326 | return convai2_episodes
327 |
328 | def share(self):
329 | shared = super().share()
330 | shared['episodes'] = self.episodes
331 | return shared
332 |
333 | def num_examples(self):
334 | return self.num_exs
335 |
336 | def num_episodes(self):
337 | return len(self.episodes)
338 |
339 |     def get(self, episode_idx, entry_idx=0):
340 | return self.episodes[episode_idx][entry_idx]
341 |
342 |
343 | class ContextConsciousDialogueTeacher(SelfConsciousDialogueTeacher):
344 | def _setup_data(self, path, datatype):
345 |         # random.seed(self.opt['random_seed'])
346 |         random.seed(46)  # fixed seed so the same distractor personas are picked across runs
347 | # Data loading with script of ParlAIDialogTeacher
348 | print(f"[Loading ParlAI text data: {path}]")
349 |
350 | # Read data from ConvAI2
351 | convai2_datapath = make_path(self.opt, f'{datatype}_both_original.txt')
352 | convai2_episodes = self._load_convai2_data(convai2_datapath)
353 |
354 | if self.opt['eval_type'] == 'convai2':
355 | self.episodes = []
356 | self.num_exs = 0
357 | eps = []
358 | with open(path) as read:
359 | for line in read:
360 | msg = str_to_msg(line.rstrip('\n'))
361 | if msg:
362 | self.num_exs += 1
363 | eps.append(msg)
364 | if msg.get('episode_done', False):
365 | self.episodes.append(eps)
366 | eps = []
367 | if len(eps) > 0:
368 | # add last episode
369 | eps[-1].force_set('episode_done', True)
370 | self.episodes.append(eps)
371 | # Add label candidates and partner's persona
372 | for episode_idx, episode in enumerate(self.episodes):
373 | for turn_idx, turn in enumerate(episode):
374 | convai2_turn = convai2_episodes[episode_idx][turn_idx]
375 | convai2_text = convai2_turn[0]
376 | label_candidates = convai2_turn[3]
377 |
378 | turn['label_candidates'] = label_candidates
379 | if turn_idx == 0:
380 | my_persona, partner_persona, _ = _split_personas_and_context(convai2_text)
381 | turn['partner_persona'] = partner_persona
382 | turn['my_persona'] = my_persona
383 | else:
384 | turn['partner_persona'] = episode[0]['partner_persona']
385 | turn['my_persona'] = episode[0]['my_persona']
386 | elif self.opt['eval_type'] == 'dnli':
387 | self.episodes = []
388 | self.num_exs = 0
389 | for eval_set in ['attributes', 'havenot', 'likedislike']:
390 | datapath = make_path(self.opt, f'{datatype}_{eval_set}.jsonl')
391 | with open(datapath, 'r') as fp:
392 | for line in fp:
393 | msg = json.loads(line)
394 | msg['eval_set'] = eval_set
395 | msg['episode_done'] = True
396 |
397 | # Make 'text'
398 | persona_lines = [f'your persona: {x[:-2]}.' for x in msg['persona']]
399 | utts = msg['prefix']
400 |
401 | p1_token, p2_token = TorchAgent.P1_TOKEN, TorchAgent.P2_TOKEN
402 | lines = persona_lines
403 | # Identify the dialogue lines. It's assumed that p1 goes first.
404 | for i, utt in enumerate(utts):
405 | if i % 2 == 0:
406 | lines.append(f'{p1_token} {utt}')
407 | else:
408 | lines.append(f'{p2_token} {utt}')
409 | text = '\n'.join(lines)
410 |
411 | msg['text'] = text
412 |
413 | # Make 'label_candidates'
414 | cands = msg['candidates']
415 | msg['label_candidates'] = cands['label'] + cands['neg'][:10] \
416 | + cands['similar'][:10] + cands['rand'][:10]
417 |
418 | # Remove unused attributes
419 | del msg['persona']
420 | del msg['prefix']
421 | del msg['triple']
422 | del msg['relevant_persona_sentence']
423 | del msg['candidates']
424 |
425 | self.episodes.append([msg])
426 | self.num_exs += 1
427 |
428 | # Get dialogue history pool
429 | context_pool = self._get_context_pool(self.opt)
430 |
431 | # Add distractor history
432 | if self.opt['world_cardinality'] > 0:
433 | for episode in self.episodes:
434 | gt_persona, first_context = _split_persona_and_context(episode[0]['text'], self.opt['eval_type'])
435 |
436 | # Select distractor history
437 | if self.opt['eval_type'] == 'convai2':
438 | num_turn = len(episode)
439 | else:
440 | dialogue = first_context.split('\n')
441 | num_turn = math.ceil(len(dialogue)/2)
442 | if num_turn < min(context_pool.keys()):
443 |                         # original_num_turn = num_turn
444 | num_turn = min(context_pool.keys())
445 |
446 | context_indices = list(range(len(context_pool[num_turn])))
447 |
448 | distractor_c_indices = random.sample(context_indices, self.opt['world_cardinality'] - 1)
449 | distractor_contexts = itemgetter(*distractor_c_indices)(context_pool[num_turn])
450 |
451 | # Make it to 'distractor_text'
452 | if self.opt['eval_type'] == 'convai2':
453 | for turn_idx, turn in enumerate(episode):
454 | turn['distractor_text'] = turn['labels'] + [c[turn_idx] for c in distractor_contexts]
455 | if turn_idx == 0:
456 | turn['my_context'] = turn['labels']
457 | else:
458 | turn['my_context'] = episode[turn_idx - 1]['my_context'] + turn['labels']
459 | else:
460 | # DNLI
461 | distractor_text = [episode[0]['text']]
462 | for c in distractor_contexts:
463 | copied_dialogue = copy.deepcopy(dialogue)
464 | for turn_idx, utterance in enumerate(copied_dialogue):
465 | if turn_idx % 2 == 1:
466 |                                 copied_dialogue[turn_idx] = f'{p2_token} {c[turn_idx // 2]}'
467 | distractor_context = '\n'.join([gt_persona] + copied_dialogue)
468 | distractor_text.append(distractor_context)
469 | episode[0]['distractor_text'] = distractor_text
470 |
471 | def _get_context_pool(self, opt):
472 | print("[loading history pool from convai2 training data]")
473 | datapath = make_path(opt, 'train.txt')
474 | episodes = []
475 | eps = []
476 | with open(datapath) as read:
477 | for line in read:
478 | msg = str_to_msg(line.rstrip('\n'))
479 | if msg:
480 | eps.append(msg)
481 | if msg.get('episode_done', False):
482 | episodes.append(eps)
483 | eps = []
484 | if len(eps) > 0:
485 | # add last episode
486 | eps[-1].force_set('episode_done', True)
487 | episodes.append(eps)
488 |
489 | context_pool = defaultdict(list)
490 | for ep in episodes:
491 | context_pool[len(ep)].append([turn['labels'][0] for turn in ep])
492 |
493 | return dict(context_pool)
494 |
495 |
496 | class DefaultTeacher(SelfConsciousDialogueTeacher):
497 | pass
498 |
--------------------------------------------------------------------------------
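The persona/context splitting at the top of teachers.py is easy to check in isolation. Below is a minimal standalone demo of the 'convai2' branch of `_split_persona_and_context`: every line except the last is treated as persona, and the last line is the query. The function here is a simplified re-statement for illustration, not an import from the repo.

```python
def split_persona_and_context(text):
    """Split a ConvAI2-style first turn into (persona block, last context line)."""
    if 'your persona:' not in text:
        return None, text
    lines = text.split('\n')
    return '\n'.join(lines[:-1]), lines[-1]

episode_text = (
    "your persona: i like cats.\n"
    "your persona: i work at a zoo.\n"
    "hi, how are you today?"
)
persona, context = split_persona_and_context(episode_text)
```

Texts with no `your persona:` prefix pass through unchanged as context, which is why later turns of an episode are left untouched.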
/tasks/test_persona_map.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/skywalker023/pragmatic-consistency/0a9165914a2009f3d3de5b4c4f37bcd745dc2886/tasks/test_persona_map.pkl
--------------------------------------------------------------------------------
/tasks/train_persona_map.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/skywalker023/pragmatic-consistency/0a9165914a2009f3d3de5b4c4f37bcd745dc2886/tasks/train_persona_map.pkl
--------------------------------------------------------------------------------
/tasks/valid_persona_map.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/skywalker023/pragmatic-consistency/0a9165914a2009f3d3de5b4c4f37bcd745dc2886/tasks/valid_persona_map.pkl
--------------------------------------------------------------------------------
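For reference, the distractor-persona selection in `SelfConsciousDialogueTeacher._setup_data` (teachers.py above) boils down to rejection sampling: draw `world_cardinality - 1` indices from the pool and redraw whenever the ground-truth persona lands in the sample. A standalone sketch with illustrative names not taken from the repo (note that `itemgetter(*picked)` only returns a tuple when given two or more indices):

```python
import random
from operator import itemgetter

def sample_distractors(all_personas, gt_persona_idx, world_cardinality, seed=46):
    """Pick world_cardinality - 1 personas, resampling if the ground truth is drawn."""
    rng = random.Random(seed)
    indices = list(range(len(all_personas)))
    picked = rng.sample(indices, world_cardinality - 1)
    while gt_persona_idx in picked:
        picked = rng.sample(indices, world_cardinality - 1)
    return list(itemgetter(*picked)(all_personas))

pool = [f'your persona: p{i}.' for i in range(10)]
distractors = sample_distractors(pool, gt_persona_idx=3, world_cardinality=4)
```

The resample loop keeps the draw uniform over all samples that exclude the ground truth, rather than biasing the pool by deleting the index up front.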