├── __init__.py ├── models ├── __init__.py └── gce.py ├── version.txt ├── .dockerignore ├── BayLearn_2018_Dialogue.pdf ├── =1.9 ├── command_train_woz.sh ├── command_train_mwoz.sh ├── .gitignore ├── .idea ├── vcs.xml ├── misc.xml ├── modules.xml ├── glad.iml └── workspace.xml ├── command_eval_woz.sh ├── requirements.txt ├── command_eval_mwoz.sh ├── .nfs831404ea9438756800000067 ├── README.md ├── Dockerfile ├── utils.py ├── LICENSE ├── evaluate.py ├── preprocess_data.py ├── train.py ├── preprocess_data_mwoz.py ├── dataset.py └── create_data.py /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /version.txt: -------------------------------------------------------------------------------- 1 | 0.1 2 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | exp/ 2 | data/ 3 | Dockerfile 4 | .git/ 5 | *.py[cod] 6 | -------------------------------------------------------------------------------- /BayLearn_2018_Dialogue.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elnaaz/GCE-Model/HEAD/BayLearn_2018_Dialogue.pdf -------------------------------------------------------------------------------- /=1.9: -------------------------------------------------------------------------------- 1 | Requirement already satisfied: six in /System/Library/Frameworks/Python.framework/Versions/2.7/Extras/lib/python (1.4.1) 2 | -------------------------------------------------------------------------------- /command_train_woz.sh: 
-------------------------------------------------------------------------------- 1 | 2 | MODEL=gce_woz 3 | DATA=woz 4 | GPU=0 5 | EPOCH=200 6 | 7 | CUDA_VISIBLE_DEVICES=$GPU python train.py --gpu $GPU -n $MODEL --data $DATA --epoch $EPOCH 8 | -------------------------------------------------------------------------------- /command_train_mwoz.sh: -------------------------------------------------------------------------------- 1 | 2 | MODEL=gce_multiwoz 3 | DATA=multi_woz 4 | GPU=0 5 | EPOCH=200 6 | 7 | CUDA_VISIBLE_DEVICES=$GPU python train.py --gpu $GPU -n $MODEL --data $DATA --epoch $EPOCH 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | Mar*/ 3 | .DS_Store 4 | *.py[cod] 5 | *.json 6 | *.json[~] 7 | *.save 8 | *.log 9 | *.model 10 | *.t7 11 | *.npy 12 | *.flist 13 | *.zip 14 | *.gzip 15 | *.tar 16 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /command_eval_woz.sh: -------------------------------------------------------------------------------- 1 | 2 | DIR=exp/woz/gce 3 | MODEL=gce_woz 4 | DATA=woz 5 | GPU=0 6 | CHECKPOINT=$DIR/$MODEL 7 | 8 | CUDA_VISIBLE_DEVICES=$GPU python evaluate.py --gpu $GPU --data $DATA --split test --dsave $CHECKPOINT 9 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | protobuf==3.4.0 2 | requests==2.18.4 3 | stanza==0.3 4 | 
tqdm==4.19.1.post1 5 | vocab==0.0.3 6 | embeddings==0.0.4 7 | http://download.pytorch.org/whl/cu90/torch-0.4.0-cp36-cp36m-linux_x86_64.whl 8 | numpy==1.13.1 9 | -------------------------------------------------------------------------------- /command_eval_mwoz.sh: -------------------------------------------------------------------------------- 1 | 2 | DIR=exp/multi_woz/gce 3 | MODEL=gce_multiwoz 4 | DATA=multi_woz 5 | GPU=0 6 | CHECKPOINT=$DIR/$MODEL 7 | 8 | CUDA_VISIBLE_DEVICES=$GPU python evaluate.py --gpu $GPU --data $DATA --split test --dsave $CHECKPOINT 9 | -------------------------------------------------------------------------------- /.nfs831404ea9438756800000067: -------------------------------------------------------------------------------- 1 | 2 | MODEL=global_no_rnn_conditioned_v6 3 | GLAD=GLADEncoder_global_no_rnn_conditioned_v6 4 | DATA=woz 5 | GPU=2 6 | EPOCH=2000 7 | 8 | CUDA_VISIBLE_DEVICES=$GPU python train.py --gpu $GPU --encoder $GLAD -n $MODEL --data $DATA --epoch $EPOCH 9 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/glad.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Globally-Conditioned Encoder (GCE) For Neural Dialogue State Tracking 2 | 3 | This repository contains an implementation of the [Toward Scalable Neural Dialogue State Tracking Model](https://arxiv.org/abs/1812.00899). 
4 | If you use this in your work, please cite the following 5 | 6 | 7 | ``` 8 | @inproceedings{nouri2018gce, 9 | title={ Toward Scalable Neural Dialogue State Tracking }, 10 | author={ Nouri, Elnaz and Hosseini-Asl, Ehsan }, 11 | booktitle={ NeurIPS 2018, 2nd Conversational AI workshop }, 12 | year={ 2018 }, 13 | arxiv={ https://arxiv.org/abs/1812.00899 } 14 | } 15 | ``` 16 | 17 | 18 | # Install dependencies 19 | 20 | For more details on installation, please check below github repository 21 | ``` 22 | https://github.com/salesforce/glad 23 | ``` 24 | 25 | 26 | 27 | # Contribution 28 | 29 | Pull requests are welcome! 30 | If you have any questions, please create an issue or contact the corresponding author at `Elnaz.Nouri microsoft com`. 31 | 32 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:9.0-base-ubuntu16.04 2 | 3 | # install Miniconda 4 | ENV LANG=C.UTF-8 LC_ALL=C.UTF-8 5 | ENV PATH /opt/conda/bin:$PATH 6 | 7 | RUN apt-get update --fix-missing && apt-get install -y wget bzip2 ca-certificates \ 8 | libglib2.0-0 libxext6 libsm6 libxrender1 \ 9 | git mercurial subversion 10 | 11 | RUN wget --quiet https://repo.continuum.io/miniconda/Miniconda3-4.4.10-Linux-x86_64.sh -O ~/miniconda.sh && \ 12 | /bin/bash ~/miniconda.sh -b -p /opt/conda && \ 13 | rm ~/miniconda.sh && \ 14 | /opt/conda/bin/conda clean -tipsy && \ 15 | ln -s /opt/conda/etc/profile.d/conda.sh /etc/profile.d/conda.sh && \ 16 | echo ". /opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc && \ 17 | echo "conda activate base" >> ~/.bashrc 18 | 19 | # copy GLAD 20 | RUN mkdir -p /opt/glad 21 | WORKDIR /opt/glad 22 | 23 | # install dependencies 24 | COPY requirements.txt . 25 | RUN pip install -r requirements.txt 26 | 27 | # copy source 28 | COPY . . 
def load_dataset(splits=('train', 'dev', 'test'), data='woz'):
    """Load the annotated splits together with ontology, vocabulary and embeddings.

    Args:
        splits: names of the split files to read; each one is looked up as
            ``NAME.json`` inside the annotation directory.
        data: ``'woz'`` selects the WoZ annotation directory (``dann_woz``);
            any other value selects the MultiWOZ one (``dann_mwoz``).

    Returns:
        A ``(dataset, ontology, vocab, E)`` tuple. ``dataset`` maps every
        requested split name to a ``Dataset``; ``E`` is the decoded content
        of ``emb.json``.
    """
    dann = dann_woz if data == 'woz' else dann_mwoz
    with open(os.path.join(dann, 'ontology.json')) as f:
        ontology = Ontology.from_dict(json.load(f))
    with open(os.path.join(dann, 'vocab.json')) as f:
        vocab = Vocab.from_dict(json.load(f))
    with open(os.path.join(dann, 'emb.json')) as f:
        E = json.load(f)
    dataset = {}
    for split in splits:
        with open(os.path.join(dann, '{}.json'.format(split))) as f:
            # logging.warn is a deprecated alias; logging.warning is the
            # supported spelling and emits the same record.
            logging.warning('loading split {}'.format(split))
            dataset[split] = Dataset.from_dict(json.load(f))

    logging.info('dataset sizes: {}'.format(pformat({k: len(v) for k, v in dataset.items()})))
    return dataset, ontology, vocab, E
5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import logging 4 | from argparse import ArgumentParser, Namespace 5 | from pprint import pprint 6 | from utils import load_dataset, load_model 7 | from models.glad import GLAD_ENCODERS 8 | import ipdb 9 | 10 | if __name__ == '__main__': 11 | parser = ArgumentParser() 12 | parser.add_argument('--dsave', help='save location of model') 13 | parser.add_argument('--split', help='split to evaluate on', default='dev') 14 | parser.add_argument('--gpu', type=int, help='gpu to use', default=None) 15 | parser.add_argument('--fout', help='optional save file to store the predictions') 16 | parser.add_argument('--encoder', help='which encoder to use', default='GLADEncoder', choices=GLAD_ENCODERS) 17 | parser.add_argument('--use_elmo', help='use elmo embeddings', action='store_true') 18 | parser.add_argument('--data', help='dataset', default='woz', choices=['woz', 'multi_woz']) 19 | args = parser.parse_args() 20 | 21 | logging.basicConfig(level=logging.INFO) 22 | 23 | with open(os.path.join(args.dsave, 'config.json')) as f: 24 | args_save = Namespace(**json.load(f)) 25 | args_save.gpu = args.gpu 26 | if not hasattr(args_save, 'encoder'): 27 | args_save.encoder = args.encoder 28 | pprint(args_save) 29 | 30 | dataset, ontology, vocab, Eword = load_dataset(data=args.data) 31 | 32 | model = load_model(args_save.model, args.use_elmo, args_save, ontology, vocab) 33 | model.load_best_save(directory=args.dsave) 34 | if args.gpu is not None: 35 | #model.cuda(args.gpu) 36 | model.cuda(0) 37 | 38 | logging.info('Making predictions for {} dialogues and {} turns'.format(len(dataset[args.split]), len(list(dataset[args.split].iter_turns())))) 39 | preds = model.run_pred(dataset[args.split], args_save) 40 | pprint(dataset[args.split].evaluate_preds(preds)) 41 | 42 | if args.fout: 43 | 
def download(url, to_file):
    """Stream the resource at ``url`` into the local file ``to_file``.

    The body is first written to ``to_file + '.part'`` and only renamed to
    ``to_file`` once the transfer finished. The script decides whether to
    download by checking that the target file exists, so an error page or a
    truncated archive stored under the final name would never be fetched
    again.

    Args:
        url: address to fetch.
        to_file: destination path of the downloaded file.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    r = requests.get(url, stream=True)
    try:
        # Fail before touching the disk when the server reports an error.
        r.raise_for_status()
        partial = to_file + '.part'
        with open(partial, 'wb') as f:
            for chunk in r.iter_content(chunk_size=1024):
                if chunk:  # skip keep-alive chunks
                    f.write(chunk)
        os.replace(partial, to_file)
    finally:
        # Always hand the connection back, also when the transfer failed.
        r.close()
['ontology', 'vocab', 'emb']): 55 | if not os.path.isdir(dann): 56 | os.makedirs(dann) 57 | dataset = {} 58 | ontology = Ontology() 59 | vocab = Vocab() 60 | vocab.word2index(['', ''], train=True) 61 | for s in splits: 62 | fname = '{}.json'.format(s) 63 | logging.warn('Annotating {}'.format(s)) 64 | dataset[s] = Dataset.annotate_raw(os.path.join(draw, fname)) 65 | dataset[s].numericalize_(vocab) 66 | ontology = ontology + dataset[s].extract_ontology() 67 | with open(os.path.join(dann, fname), 'wt') as f: 68 | json.dump(dataset[s].to_dict(), f) 69 | ontology.numericalize_(vocab) 70 | with open(os.path.join(dann, 'ontology.json'), 'wt') as f: 71 | json.dump(ontology.to_dict(), f) 72 | with open(os.path.join(dann, 'vocab.json'), 'wt') as f: 73 | json.dump(vocab.to_dict(), f) 74 | 75 | logging.warn('Computing word embeddings') 76 | embeddings = [GloveEmbedding(), KazumaCharEmbedding()] 77 | E = [] 78 | for w in tqdm(vocab._index2word): 79 | e = [] 80 | for emb in embeddings: 81 | e += emb.emb(w, default='zero') 82 | E.append(e) 83 | with open(os.path.join(dann, 'emb.json'), 'wt') as f: 84 | json.dump(E, f) 85 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser 3 | from utils import load_dataset, get_models, load_model 4 | import os 5 | import logging 6 | import numpy as np 7 | from pprint import pprint 8 | import torch 9 | from random import seed 10 | # from models.glad import GLAD_ENCODERS 11 | from models.gce import GCEencoder 12 | 13 | 14 | def run(args): 15 | pprint(args) 16 | logging.basicConfig(level=logging.INFO) 17 | 18 | np.random.seed(args.seed) 19 | torch.manual_seed(args.seed) 20 | seed(args.seed) 21 | 22 | dataset, ontology, vocab, Eword = load_dataset(data=args.data) 23 | 24 | model = load_model(args.model, args, ontology, vocab) 25 | 
model.save_config() 26 | model.load_emb(Eword) 27 | 28 | model = model.to(model.device) 29 | if args.resume: 30 | logging.info('Load best model') 31 | model.load_best_save(directory=args.resume) 32 | logging.info('Starting train') 33 | model.run_train(dataset['train'], dataset['dev'], args) 34 | elif not args.test: 35 | logging.info('Starting train') 36 | model.run_train(dataset['train'], dataset['dev'], args) 37 | else: 38 | model.load_best_save(directory=args.dout) 39 | model = model.to(model.device) 40 | logging.info('Running dev evaluation') 41 | dev_out = model.run_eval(dataset['dev'], args) 42 | pprint(dev_out) 43 | 44 | 45 | def get_args(): 46 | parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) 47 | parser.add_argument('--dexp', help='root experiment folder', default='exp') 48 | parser.add_argument('--data', help='dataset', default='woz', choices=['woz', 'multi_woz']) 49 | parser.add_argument('--model', help='which model to use', default='gce', choices=get_models()) 50 | # parser.add_argument('--use_elmo', help='use elmo embeddings', action='store_true') 51 | # parser.add_argument('--encoder', help='which encoder to use', default='GLADEncoder', choices=GLAD_ENCODERS) 52 | parser.add_argument('--epoch', help='max epoch to run for', default=50, type=int) 53 | parser.add_argument('--demb', help='word embedding size', default=400, type=int) 54 | parser.add_argument('--dhid', help='hidden state size', default=200, type=int) 55 | parser.add_argument('--batch_size', help='batch size', default=50, type=int) 56 | parser.add_argument('--lr', help='learning rate', default=1e-3, type=float) 57 | parser.add_argument('--stop', help='slot to early stop on', default='joint_goal') 58 | parser.add_argument('--resume', help='save directory to resume from') 59 | parser.add_argument('-n', '--nick', help='nickname for model', default='default') 60 | parser.add_argument('--seed', default=42, help='random seed', type=int) 61 | parser.add_argument('--test', 
def get_args():
    """Parse the command-line options of the MultiWOZ preprocessing script.

    This function runs at import time (the module calls ``get_args()`` at top
    level) and the module is imported by ``utils.py``, so it also sees the
    command line of ``train.py`` and ``evaluate.py``. Two precautions keep
    those entry points working:

    * unknown options are ignored instead of aborting the process, and
    * prefix matching is disabled, so a foreign ``--data woz`` is not
      silently interpreted as ``--data_dir woz``.

    Returns:
        The parsed namespace with ``data_dir`` and ``only_domain``.
    """
    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter, allow_abbrev=False)
    parser.add_argument('--data_dir', help='root experiment folder', default='data/multi_woz/')
    parser.add_argument('--only_domain', help='', default='')
    # Options that belong to whichever script imported this module are left alone.
    args, _unknown = parser.parse_known_args()
    return args
def missing_files(d, files):
    """Tell whether any expected JSON file is absent from directory ``d``.

    Args:
        d: directory to look in.
        files: base names; each is checked as ``NAME.json`` inside ``d``.

    Returns:
        ``True`` if at least one of the files does not exist as a regular
        file, ``False`` otherwise (including when ``files`` is empty).
    """
    for stem in files:
        candidate = os.path.join(d, '{}.json'.format(stem))
        if not os.path.isfile(candidate):
            return True
    return False
class Turn:
    """A single dialogue turn: transcripts, labels and their numeric encoding."""

    def __init__(self, turn_id, transcript, turn_label, belief_state, system_acts, system_transcript, num=None):
        self.id = turn_id
        self.transcript = transcript
        self.turn_label = turn_label
        self.belief_state = belief_state
        self.system_acts = system_acts
        self.system_transcript = system_transcript
        # filled by numericalize_; an empty/None value starts a fresh dict
        self.num = num if num else {}

    def to_dict(self):
        """Return the turn as a plain dict whose keys match the constructor arguments."""
        return dict(
            turn_id=self.id,
            transcript=self.transcript,
            turn_label=self.turn_label,
            belief_state=self.belief_state,
            system_acts=self.system_acts,
            system_transcript=self.system_transcript,
            num=self.num,
        )

    @classmethod
    def from_dict(cls, alld, i, d):
        """Rebuild a turn from the dict produced by ``to_dict``.

        ``alld`` (all turn dicts of the dialogue) and ``i`` (index of this
        turn) are accepted but not used at the moment.
        """
        return cls(**d)

    @classmethod
    def annotate_raw(cls, raw, only_domain=""):
        """Build a turn from one raw turn record.

        Args:
            raw: dict with the keys ``turn_idx``, ``transcript``,
                ``system_transcript``, ``system_acts``, ``turn_label`` and
                ``belief_state``.
            only_domain: when non-empty, keep only labels/belief entries whose
                slot name contains this string; otherwise keep those whose
                slot prefix (text before the first ``-``) is listed in
                ``ACTIVATED_DOMAINS``.
        """
        system_acts = []
        for act in raw['system_acts']:
            if isinstance(act, list):
                slot, value = act
                system_acts.append(['inform', *slot.split(), '=', *value.split()])
            else:
                system_acts.append(['request', *act.split()])

        # NOTE: fix inconsistencies in data label
        fix = {'centre': 'center', 'areas': 'area', 'phone number': 'number'}

        def canonical(text):
            stripped = text.strip()
            return fix.get(stripped, stripped)

        if only_domain != "":
            def keep(slot_name):
                return only_domain in slot_name
        else:
            def keep(slot_name):
                return slot_name.split("-")[0] in ACTIVATED_DOMAINS

        return cls(
            turn_id=raw['turn_idx'],
            transcript=annotate(raw['transcript']),
            system_acts=system_acts,
            turn_label=[[canonical(s), canonical(v)] for s, v in raw['turn_label'] if keep(s)],
            # a belief entry is filtered by the name of its first slot only
            belief_state=[bs for bs in raw['belief_state'] if keep(bs["slots"][0][0])],
            system_transcript=annotate(raw['system_transcript']),
        )

    def numericalize_(self, vocab):
        """Store vocabulary indices of the transcripts and system acts in ``self.num``.

        Every word is lower-cased and registered in ``vocab`` (``train=True``).
        NOTE(review): the empty-string tokens placed around each sequence look
        like start/end-of-sequence markers whose text was lost -- confirm
        against the preprocessing scripts.
        """
        def encode(tokens):
            return vocab.word2index(tokens, train=True)

        self.num['transcript'] = encode([''] + [w.lower() for w in self.transcript + ['']])
        self.num['system_transcript'] = encode([''] + [w.lower() for w in self.system_transcript + ['']])
        encoded_acts = []
        # one extra act consisting of a single empty token is always appended
        for act in self.system_acts + [['']]:
            encoded_acts.append(encode([''] + [w.lower() for w in act] + ['']))
        self.num['system_acts'] = encoded_acts
class Dataset:
    """A list of dialogues with helpers for serialisation, batching and scoring."""

    def __init__(self, dialogues):
        self.dialogues = dialogues

    def __len__(self):
        return len(self.dialogues)

    def iter_turns(self):
        """Yield every turn of every dialogue, in dialogue order."""
        for d in self.dialogues:
            yield from d.turns

    def to_dict(self):
        return {'dialogues': [d.to_dict() for d in self.dialogues]}

    @classmethod
    def from_dict(cls, d):
        return cls([Dialogue.from_dict(dd) for dd in d['dialogues']])

    @classmethod
    def annotate_raw(cls, fname, only_domain=""):
        """Read the raw JSON file ``fname`` and annotate its dialogues.

        With a non-empty ``only_domain`` only dialogues whose ``"domains"``
        entry contains that value are kept.
        """
        with open(fname) as f:
            data = json.load(f)
        if only_domain != "":
            return cls([Dialogue.annotate_raw(d, only_domain) for d in tqdm(data) if only_domain in d["domains"]])
        return cls([Dialogue.annotate_raw(d) for d in tqdm(data)])

    def numericalize_(self, vocab):
        for t in self.iter_turns():
            t.numericalize_(vocab)

    def extract_ontology(self):
        """Collect the slots and per-slot values that occur in the turn labels.

        Slot names and values are lower-cased. The same lower-cased slot name
        is used for the slot list and as key of the value table, so the two
        always agree.
        """
        slots = set()
        values = defaultdict(set)
        for t in self.iter_turns():
            for s, v in t.turn_label:
                slot = s.lower()
                slots.add(slot)
                values[slot].add(v.lower())
        return Ontology(sorted(slots), {k: sorted(v) for k, v in values.items()})

    def batch(self, batch_size, shuffle=False):
        """Yield consecutive lists of at most ``batch_size`` turns."""
        turns = list(self.iter_turns())
        if shuffle:
            np.random.shuffle(turns)
        for i in tqdm(range(0, len(turns), batch_size)):
            yield turns[i:i+batch_size]

    def evaluate_preds(self, preds, Ontology=None):
        """Score per-turn predictions against the gold labels.

        Args:
            preds: one collection of ``(slot, value)`` pairs per turn, in the
                order produced by ``iter_turns``.
            Ontology: optional object with a ``slots`` sequence; its length
                normalises the slot accuracy. When it is missing (or has no
                ``slots`` attribute) slot accuracy is not computed instead of
                failing.

        Returns:
            Dict with ``turn_inform``, ``turn_request`` and ``joint_goal``
            (means over all turns) plus ``slot_acc`` when an ontology was
            supplied.
        """
        request = []
        inform = []
        joint_goal = []
        slot_acc = []
        fix = {'centre': 'center', 'areas': 'area', 'phone number': 'number'}
        onto_slots = getattr(Ontology, 'slots', None)
        if onto_slots is not None:
            print("len of Ontology", len(onto_slots))
        i = 0
        for d in self.dialogues:
            # predicted state accumulates over the turns of one dialogue
            pred_state = {}
            for t in d.turns:
                gold_request = set([(s, v) for s, v in t.turn_label if s == 'request'])
                gold_inform = set([(s, v) for s, v in t.turn_label if s != 'request'])
                pred_request = set([(s, v) for s, v in preds[i] if s == 'request'])
                pred_inform = set([(s, v) for s, v in preds[i] if s != 'request'])
                request.append(gold_request == pred_request)
                inform.append(gold_inform == pred_inform)

                if onto_slots is not None:
                    slot_acc.append(self.compute_acc(gold_inform, pred_inform, len(onto_slots)))

                gold_recovered = set()
                pred_recovered = set()
                for s, v in pred_inform:
                    pred_state[s] = v
                for b in t.belief_state:
                    for s, v in b['slots']:
                        if b['act'] != 'request':
                            gold_recovered.add((b['act'], fix.get(s.strip(), s.strip()), fix.get(v.strip(), v.strip())))
                for s, v in pred_state.items():
                    pred_recovered.add(('inform', s, v))
                joint_goal.append(gold_recovered == pred_recovered)
                i += 1
        scores = {'turn_inform': np.mean(inform), 'turn_request': np.mean(request), 'joint_goal': np.mean(joint_goal)}
        if onto_slots is not None:
            scores['slot_acc'] = np.mean(slot_acc)
        return scores

    def compute_acc(self, gold, pred, nb_slot):
        """Return the fraction of ``nb_slot`` slots that are neither missed nor wrongly predicted.

        A gold pair absent from ``pred`` counts as one miss; a predicted pair
        absent from ``gold`` counts as one error unless its slot was already
        counted as missed.
        """
        miss_gold = 0
        miss_slot = set()
        for g in gold:
            if g not in pred:
                miss_gold += 1
                miss_slot.add(g[0])
        wrong_pred = 0
        for p in pred:
            if p not in gold and p[0] not in miss_slot:
                wrong_pred += 1
        return (nb_slot - miss_gold - wrong_pred) / float(nb_slot)

    def record_preds(self, preds, to_file):
        """Write the dataset as JSON to ``to_file`` with a sorted ``pred`` list added to every turn."""
        data = self.to_dict()
        i = 0
        for d in data['dialogues']:
            for t in d['turns']:
                t['pred'] = sorted(list(preds[i]))
                i += 1
        with open(to_file, 'wt') as f:
            json.dump(data, f)
def position_encoding_init(n_position, d_pos_vec, device=None):
    """Build the sinusoid position encoding table.

    Row 0 is kept as an all-zero vector for the padding position. For every
    other position ``pos`` and column ``j`` the angle is
    ``pos / 10000 ** (2 * (j // 2) / d_pos_vec)``; even columns hold its sine
    and odd columns its cosine.

    Args:
        n_position: number of positions (rows).
        d_pos_vec: size of each encoding vector (columns).
        device: where to place the table. By default the current CUDA device
            is used when CUDA is available and the CPU otherwise, so the
            function no longer fails on machines without a GPU.

    Returns:
        A float32 tensor of shape ``(n_position, d_pos_vec)``.
    """
    positions = np.arange(n_position, dtype=np.float64)[:, None]
    columns = np.arange(d_pos_vec)
    # position 0 yields angle 0 everywhere, i.e. the zero padding row
    position_enc = positions / np.power(10000, 2 * (columns // 2) / d_pos_vec)

    position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2])  # dim 2i
    position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2])  # dim 2i+1

    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    return torch.from_numpy(position_enc).type(torch.FloatTensor).to(device)
def attend(seq, cond, lens):
    """
    dot-product attention over `seq`, driven by the condition `cond`.

    `cond` is expanded to the shape of `seq` (an extra axis is inserted at
    position 1 first when it has fewer than three dimensions), multiplied
    element-wise with `seq` and summed over axis 2 to obtain the raw scores.
    For every row whose entry in `lens` is shorter than the longest one, the
    scores from that length onwards are set to -inf so the softmax over axis 1
    gives them zero weight.

    returns `(context, scores)`: the score-weighted sum of `seq` over axis 1
    and the normalised attention weights.
    """
    if cond.ndimension() >= 3:
        aligned = cond.expand_as(seq)
    else:
        aligned = cond.unsqueeze(1).expand_as(seq)
    scores = torch.mul(aligned, seq).sum(2)

    longest = max(lens)
    for row, length in enumerate(lens):
        if length >= longest:
            continue
        # written through .data, exactly as before, so autograd does not track the masking
        scores.data[row, length:] = -np.inf

    scores = F.softmax(scores, dim=1)
    weighted = torch.mul(scores.unsqueeze(2).expand_as(seq), seq)
    return weighted.sum(1), scores
class SelfAttention_gce(nn.Module):
    """
    convolutional self-attention conditioned on a single vector.

    The condition is concatenated to every time step of the input, a width-5
    1-d convolution (padding 2, so the length is preserved) turns each time
    step into one score, and the softmax of those scores weights the input.
    """

    def __init__(self, dhid, dropout=0.):
        """
        dhid: feature size of the input; the convolution reads `2 * dhid`
            channels (condition + input).
        dropout: accepted for interface compatibility; it is not applied.
        """
        super().__init__()
        self.conv = nn.Conv1d(2 * dhid, 1, 5, padding=2)

    def forward(self, inp, lens, cond):
        """
        inp: tensor of shape (batch, seq_len, d_feat).
        lens: sequence lengths; currently unused.
            NOTE(review): padded time steps are therefore not masked out of the softmax.
        cond: condition that must be expandable to the shape of `inp` after a
            leading axis is added.

        returns the attended context of shape (batch, 1, d_feat).
        """
        concat = torch.cat((cond.unsqueeze(0).expand_as(inp), inp), dim=2)
        # Conv1d expects (batch, channels, seq_len)
        attention = self.conv(concat.transpose(2, 1))
        scores = F.softmax(attention, dim=2)
        return scores.bmm(inp)
class Model(nn.Module):
    """
    the GLAD scoring model described in https://arxiv.org/abs/1805.09655.

    For every slot of the ontology the model scores each candidate value from
    the user utterance and from the previous system acts, and combines both
    scores through a sigmoid. It also bundles the training loop, evaluation
    and checkpoint management.
    """

    def __init__(self, args, ontology, vocab):
        """
        args: namespace providing at least `demb`, `dhid`, `dropout` (a dict),
            `gpu`, `lr`, `dout` and `stop`.
        ontology: object exposing `slots`, `values` and `num`.
        vocab: vocabulary supporting `len()` and `word2index`.
        """
        super().__init__()
        self.optimizer = None
        self.args = args
        self.vocab = vocab
        self.ontology = ontology
        self.emb_fixed = FixedEmbedding(len(vocab), args.demb, dropout=args.dropout.get('emb', 0.2))
        self.encoder = GCEencoder

        self.utt_encoder = self.encoder(args.demb, args.dhid, self.ontology.slots, dropout=args.dropout)
        self.act_encoder = self.encoder(args.demb, args.dhid, self.ontology.slots, dropout=args.dropout)
        self.ont_encoder = self.encoder(args.demb, args.dhid, self.ontology.slots, dropout=args.dropout)
        self.utt_scorer = nn.Linear(2 * args.dhid, 1)
        self.score_weight = nn.Parameter(torch.Tensor([0.5]))

    @property
    def device(self):
        """CUDA when a gpu was requested and is available, the CPU otherwise."""
        if self.args.gpu is not None and torch.cuda.is_available():
            return torch.device('cuda')
        else:
            return torch.device('cpu')

    def set_optimizer(self):
        """Create the Adam optimizer over all parameters with `args.lr`."""
        self.optimizer = optim.Adam(self.parameters(), lr=self.args.lr)

    def load_emb(self, Eword):
        """Copy the pretrained embedding matrix `Eword` into the fixed embedding."""
        new = self.emb_fixed.weight.data.new
        self.emb_fixed.weight.data.copy_(new(Eword))

    def forward(self, batch):
        """
        Score every (slot, value) pair for each example in `batch`.

        returns `(loss, scores)`: `scores` maps each slot to a nested list of
        per-example, per-value probabilities. `loss` is the summed binary
        cross entropy over slots in training mode and a zero tensor otherwise.
        """
        # convert to variables and look up embeddings
        eos = self.vocab.word2index('')
        ontology = {s: pad(v, self.emb_fixed, self.device, pad=eos) for s, v in self.ontology.num.items()}
        utterance, utterance_len = pad([e.num['transcript'] for e in batch], self.emb_fixed, self.device, pad=eos)
        acts = [pad(e.num['system_acts'], self.emb_fixed, self.device, pad=eos) for e in batch]
        ys = {}
        for s in self.ontology.slots:
            # for each slot, compute the scores for each value

            # the slot is represented by the embedding of its first word
            s_new = s.split()[0]
            # built on the CPU and moved to self.device, so CPU-only runs work too
            s_idx = torch.LongTensor([self.vocab.word2index(s_new)]).to(self.device)
            s_emb = self.emb_fixed(s_idx)
            H_utt, c_utt = self.utt_encoder(utterance, utterance_len, slot=s, slot_emb=s_emb)
            _, C_acts = list(zip(*[self.act_encoder(a, a_len, slot=s, slot_emb=s_emb) for a, a_len in acts]))
            _, C_vals = self.ont_encoder(ontology[s][0], ontology[s][1], slot=s, slot_emb=s_emb)

            # compute the utterance score
            q_utts = []
            for c_val in C_vals:
                c_val = c_val.squeeze(0)
                q_utt, _ = attend(H_utt, c_val.unsqueeze(0).expand(len(batch), *c_val.size()), lens=utterance_len)
                q_utts.append(q_utt)
            y_utts = self.utt_scorer(torch.stack(q_utts, dim=1)).squeeze(2)

            # compute the previous action score
            q_acts = []
            for i, C_act in enumerate(C_acts):
                q_act, _ = attend(C_act.unsqueeze(0), c_utt[i].unsqueeze(0), lens=[C_act.size(0)])
                q_acts.append(q_act)

            # squeeze only the singleton context axis: a bare squeeze() would also
            # drop the batch / value axis when it has size 1 and break the mm
            y_acts = torch.cat(q_acts, dim=0).squeeze(1).mm(C_vals.squeeze(1).transpose(0, 1))

            # combine the scores
            ys[s] = torch.sigmoid(y_utts + self.score_weight * y_acts)

        if self.training:
            # create label variable and compute loss
            labels = {s: [len(self.ontology.values[s]) * [0] for i in range(len(batch))] for s in self.ontology.slots}
            for i, e in enumerate(batch):
                for s, v in e.turn_label:
                    labels[s][i][self.ontology.values[s].index(v)] = 1
            labels = {s: torch.Tensor(m).to(self.device) for s, m in labels.items()}

            loss = 0
            for s in self.ontology.slots:
                loss += F.binary_cross_entropy(ys[s], labels[s])
        else:
            loss = torch.Tensor([0]).to(self.device)
        return loss, {s: v.data.tolist() for s, v in ys.items()}

    def get_train_logger(self):
        """Return a logger that appends to `train.log` inside `args.dout`."""
        logger = logging.getLogger('train-{}'.format(self.__class__.__name__))
        logger.setLevel(logging.INFO)
        formatter = logging.Formatter('%(asctime)s [%(threadName)-12.12s] [%(levelname)-5.5s] %(message)s')
        file_handler = logging.FileHandler(os.path.join(self.args.dout, 'train.log'))
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
        return logger

    def run_train(self, train, dev, args):
        """
        Train for `args.epoch` epochs, evaluating on `train` and `dev` after
        each one. Whenever the dev metric named by `args.stop` does not get
        worse, a checkpoint is saved, old checkpoints are pruned and the dev
        predictions are written to `dev.pred.json`.
        """
        track = defaultdict(list)
        iteration = 0
        best = {}
        logger = self.get_train_logger()
        if self.optimizer is None:
            self.set_optimizer()

        for epoch in range(args.epoch):
            logger.info('starting epoch {}'.format(epoch))

            # train and update parameters
            self.train()
            for batch in train.batch(batch_size=args.batch_size, shuffle=True):
                iteration += 1
                self.zero_grad()
                loss, scores = self.forward(batch)
                loss.backward()
                self.optimizer.step()
                track['loss'].append(loss.item())

            # evaluate on train and dev
            summary = {'iteration': iteration, 'epoch': epoch}
            for k, v in track.items():
                summary[k] = sum(v) / len(v)
            summary.update({'eval_train_{}'.format(k): v for k, v in self.run_eval(train, args).items()})
            summary.update({'eval_dev_{}'.format(k): v for k, v in self.run_eval(dev, args).items()})

            # do early stopping saves
            stop_key = 'eval_dev_{}'.format(args.stop)
            train_key = 'eval_train_{}'.format(args.stop)
            if best.get(stop_key, 0) <= summary[stop_key]:
                best_dev = '{:f}'.format(summary[stop_key])
                best_train = '{:f}'.format(summary[train_key])
                best.update(summary)
                self.save(
                    best,
                    identifier='epoch={epoch},iter={iteration},train_{key}={train},dev_{key}={dev}'.format(
                        epoch=epoch, iteration=iteration, train=best_train, dev=best_dev, key=args.stop,
                    )
                )
                self.prune_saves()
                dev.record_preds(
                    preds=self.run_pred(dev, self.args),
                    to_file=os.path.join(self.args.dout, 'dev.pred.json'),
                )
            summary.update({'best_{}'.format(k): v for k, v in best.items()})
            logger.info(pformat(summary))
            track.clear()

    def extract_predictions(self, scores, threshold=0.5):
        """
        Turn per-value probabilities into sets of `(slot, value)` predictions.

        Every `request` value above `threshold` is kept; for any other slot
        only the highest scoring value above `threshold` is kept.
        """
        batch_size = len(list(scores.values())[0])
        predictions = [set() for _ in range(batch_size)]
        for s in self.ontology.slots:
            for i, p in enumerate(scores[s]):
                triggered = [(v, p_v) for v, p_v in zip(self.ontology.values[s], p) if p_v > threshold]
                if s == 'request':
                    # we can have multiple requests predictions
                    predictions[i] |= {(s, v) for v, _ in triggered}
                elif triggered:
                    # only extract the top inform prediction (first one wins on ties)
                    best_value, _ = max(triggered, key=lambda pair: pair[1])
                    predictions[i].add((s, best_value))
        return predictions

    def run_pred(self, dev, args):
        """Predict `(slot, value)` sets for every example of `dev` in eval mode."""
        self.eval()
        predictions = []
        # no gradients are needed for inference; skipping them saves memory
        with torch.no_grad():
            for batch in dev.batch(batch_size=args.batch_size):
                loss, scores = self.forward(batch)
                predictions += self.extract_predictions(scores)
        return predictions

    def run_eval(self, dev, args):
        """Predict on `dev` and return the metrics computed by `dev.evaluate_preds`."""
        predictions = self.run_pred(dev, args)
        return dev.evaluate_preds(predictions)

    def save_config(self):
        """Dump `args` as JSON to `config.json` inside `args.dout`."""
        fname = '{}/config.json'.format(self.args.dout)
        with open(fname, 'wt') as f:
            logging.info('saving config to {}'.format(fname))
            json.dump(vars(self.args), f, indent=2)

    @classmethod
    def load_config(cls, fname, ontology, vocab=None, **kwargs):
        """
        Build a model from a JSON config written by `save_config`.

        fname: path of the JSON config file.
        ontology: ontology handed to the constructor.
        vocab: vocabulary handed to the constructor (required by it).
        kwargs: overrides for individual config entries.

        raises TypeError when `vocab` is missing.
        """
        from types import SimpleNamespace

        if vocab is None:
            raise TypeError('load_config() needs a vocab to construct the model')
        with open(fname) as f:
            logging.info('loading config from {}'.format(fname))
            config = json.load(f)
        # a plain object() accepts no attributes, so collect the config in a namespace
        args = SimpleNamespace(**{k: kwargs.get(k, v) for k, v in config.items()})
        return cls(args, ontology, vocab)

    def save(self, summary, identifier):
        """Write args, weights, `summary` and optimizer state to `<dout>/<identifier>.t7`."""
        fname = '{}/{}.t7'.format(self.args.dout, identifier)
        logging.info('saving model to {}'.format(fname))
        state = {
            'args': vars(self.args),
            'model': self.state_dict(),
            'summary': summary,
            'optimizer': self.optimizer.state_dict(),
        }
        torch.save(state, fname)

    def load(self, fname):
        """Restore weights and optimizer state from the checkpoint `fname`."""
        logging.info('loading model from {}'.format(fname))
        # map_location lets a checkpoint written on a GPU load on a CPU-only host
        state = torch.load(fname, map_location=self.device)
        self.load_state_dict(state['model'])
        self.set_optimizer()
        self.optimizer.load_state_dict(state['optimizer'])

    def get_saves(self, directory=None):
        """
        List the `.t7` checkpoints in `directory` (default `args.dout`).

        returns `(score, path)` tuples sorted by descending dev score, where
        the score is parsed from the `dev_<args.stop>=<score>` part of the
        file name. Raises `Exception` when no checkpoint is found.
        """
        if directory is None:
            directory = self.args.dout
        pattern = re.compile(r'dev_{}=([0-9\.]+)'.format(re.escape(self.args.stop)))
        scores = []
        for fname in os.listdir(directory):
            if not fname.endswith('.t7'):
                continue
            dev_acc = pattern.findall(fname)
            if dev_acc:
                # the greedy match also swallows the dot of the '.t7' extension
                score = float(dev_acc[0].strip('.'))
                scores.append((score, os.path.join(directory, fname)))
        if not scores:
            raise Exception('No files found!')
        scores.sort(key=lambda tup: tup[0], reverse=True)
        return scores

    def prune_saves(self, n_keep=5):
        """Delete all but the `n_keep` best checkpoints in `args.dout`."""
        scores_and_files = self.get_saves()
        for score, fname in scores_and_files[n_keep:]:
            os.remove(fname)

    def load_best_save(self, directory):
        """Load the best scoring checkpoint from `directory` (default `args.dout`)."""
        if directory is None:
            directory = self.args.dout

        # get_saves raises when the directory holds no checkpoint
        scores_and_files = self.get_saves(directory=directory)
        score, fname = scores_and_files[0]
        self.load(fname)
def insertSpace(token, text):
    """Surround occurrences of `token` in `text` with single spaces.

    An occurrence that sits directly between two digits (e.g. the dot in
    ``3.5``) is left untouched. Otherwise a space is inserted in front of the
    token when the preceding character is not a space, and one more is
    inserted one character after the token start when the character following
    the token is not a space.
    """
    digit = re.compile('[0-9]')
    token_len = len(token)

    pos = text.find(token, 0)
    while pos != -1:
        between_digits = (
            pos + 1 < len(text)
            and digit.match(text[pos - 1])
            and digit.match(text[pos + 1])
        )
        if not between_digits:
            if text[pos - 1] != ' ':
                text = ''.join((text[:pos], ' ', text[pos:]))
                pos += 1
            after = pos + token_len
            if after < len(text) and text[after] != ' ':
                text = ''.join((text[:pos + 1], ' ', text[pos + 1:]))
        # resume the search just past the current occurrence
        pos = text.find(token, pos + 1)
    return text
]', '', m) + text[eidx:] 89 | 90 | # weird unicode bug 91 | text = re.sub(u"(\u2018|\u2019)", "'", text) 92 | 93 | if clean_value: 94 | # replace time and and price 95 | text = re.sub(timepat, ' [value_time] ', text) 96 | text = re.sub(pricepat, ' [value_price] ', text) 97 | #text = re.sub(pricepat2, '[value_price]', text) 98 | 99 | # replace st. 100 | text = text.replace(';', ',') 101 | text = re.sub('$\/', '', text) 102 | text = text.replace('/', ' and ') 103 | 104 | # replace other special characters 105 | text = text.replace('-', ' ') 106 | text = re.sub('[\"\<>@\(\)]', '', text) # remove 107 | 108 | # insert white space before and after tokens: 109 | for token in ['?', '.', ',', '!']: 110 | text = insertSpace(token, text) 111 | 112 | # insert white space for 's 113 | text = insertSpace('\'s', text) 114 | 115 | # replace it's, does't, you'd ... etc 116 | text = re.sub('^\'', '', text) 117 | text = re.sub('\'$', '', text) 118 | text = re.sub('\'\s', ' ', text) 119 | text = re.sub('\s\'', ' ', text) 120 | for fromx, tox in replacements: 121 | text = ' ' + text + ' ' 122 | text = text.replace(fromx, tox)[1:-1] 123 | 124 | # remove multiple spaces 125 | text = re.sub(' +', ' ', text) 126 | 127 | # concatenate numbers 128 | tmp = text 129 | tokens = text.split() 130 | i = 1 131 | while i < len(tokens): 132 | if re.match(u'^\d+$', tokens[i]) and \ 133 | re.match(u'\d+$', tokens[i - 1]): 134 | tokens[i - 1] += tokens[i] 135 | del tokens[i] 136 | else: 137 | i += 1 138 | text = ' '.join(tokens) 139 | 140 | return text 141 | 142 | def fixDelex(filename, data, data2, idx, idx_acts): 143 | """Given system dialogue acts fix automatic delexicalization.""" 144 | try: 145 | turn = data2[filename.strip('.json')][str(idx_acts)] 146 | except: 147 | return data 148 | 149 | if not isinstance(turn, str) and not isinstance(turn, unicode): 150 | for k, act in turn.items(): 151 | if 'Attraction' in k: 152 | if 'restaurant_' in data['log'][idx]['text']: 153 | data['log'][idx]['text'] = 
def getDialogueAct(filename, data, data2, idx, idx_acts):
    """Collect the request/inform system acts annotated for one turn.

    Args:
        filename: dialogue file name; a trailing ``.json`` is removed to get
            the key into `data2`.
        data: unused, kept for interface compatibility.
        data2: annotations indexed as ``data2[dialogue_id][str(idx_acts)]``.
        idx: unused, kept for interface compatibility.
        idx_acts: index of the annotated turn.

    Returns:
        A list holding the lower-cased slot name for every ``request`` act and
        a ``[slot, normalized value]`` pair for every ``inform`` act. The list
        is empty when the turn has no annotation or the annotation is a plain
        string instead of a mapping.
    """
    acts = []

    # str.strip('.json') strips a *set of characters* from both ends and can eat
    # parts of the name itself; remove the extension explicitly instead
    suffix = '.json'
    dialogue_id = filename[:-len(suffix)] if filename.endswith(suffix) else filename
    try:
        turn = data2[dialogue_id][str(idx_acts)]
    except (KeyError, IndexError, TypeError):
        return acts

    # `unicode` only exists on Python 2
    try:
        string_types = (str, unicode)
    except NameError:
        string_types = (str,)
    if isinstance(turn, string_types):
        return acts

    for act_name, pairs in turn.items():
        act_type = act_name.split('-')[1].lower()
        if act_type == 'request':
            acts.extend(pair[0].lower() for pair in pairs)
        elif act_type == 'inform':
            acts.extend([pair[0].lower(), normalize(pair[1].lower(), False)] for pair in pairs)

    return acts
booking = [] 204 | #print(domain,len(bstate[domain]['book'].keys())) 205 | for slot in sorted(bstate[domain]['book'].keys()): 206 | if slot == 'booked': 207 | if len(bstate[domain]['book']['booked'])!=0: 208 | booking.append(1) 209 | # summary_bvalue.append("book {} {}:{}".format(domain, slot, "Yes")) 210 | else: 211 | booking.append(0) 212 | else: 213 | if bstate[domain]['book'][slot] != "": 214 | booking.append(1) 215 | summary_bvalue.append(["{}-book {}".format(domain, slot.strip().lower()), normalize(bstate[domain]['book'][slot].strip().lower(), False)]) #(["book", domain, slot, bstate[domain]['book'][slot]]) 216 | else: 217 | booking.append(0) 218 | if domain == 'train': 219 | if 'people' not in bstate[domain]['book'].keys(): 220 | booking.append(0) 221 | if 'ticket' not in bstate[domain]['book'].keys(): 222 | booking.append(0) 223 | summary_bstate += booking 224 | 225 | for slot in bstate[domain]['semi']: 226 | slot_enc = [0, 0, 0] # not mentioned, dontcare, filled 227 | if bstate[domain]['semi'][slot] == 'not mentioned': 228 | slot_enc[0] = 1 229 | elif bstate[domain]['semi'][slot] in ['dont care', 'dontcare', "don't care", "do not care"]: 230 | slot_enc[1] = 1 231 | summary_bvalue.append(["{}-{}".format(domain, slot.strip().lower()), "dontcare"]) #(["semi", domain, slot, "dontcare"]) 232 | elif bstate[domain]['semi'][slot]: 233 | summary_bvalue.append(["{}-{}".format(domain, slot.strip().lower()), normalize(bstate[domain]['semi'][slot].strip().lower(), False)]) #(["semi", domain, slot, bstate[domain]['semi'][slot]]) 234 | if slot_enc != [0, 0, 0]: 235 | domain_active = True 236 | summary_bstate += slot_enc 237 | 238 | # quasi domain-tracker 239 | if domain_active: 240 | summary_bstate += [1] 241 | active_domain.append(domain) 242 | else: 243 | summary_bstate += [0] 244 | 245 | #print(len(summary_bstate)) 246 | assert len(summary_bstate) == 94 247 | if get_domain: 248 | return active_domain 249 | else: 250 | return summary_bstate, summary_bvalue 251 | 252 | 
def analyze_dialogue(dialogue, maxlen):
    """Cleaning procedure for all kinds of errors in text and annotation.

    Args:
        dialogue: mapping with a ``goal`` entry and a ``log`` list of turns
            that alternate user (even index) and system (odd index). Every
            turn has a ``text``; system turns also carry ``metadata``.
        maxlen: maximum number of whitespace-separated words per turn.

    Returns:
        ``None`` when the dialogue has an odd number of turns, contains a turn
        longer than `maxlen` words or contains non-ASCII text. Otherwise a
        dict with the copied ``goal``, the user turns under ``usr_log`` and
        the system turns under ``sys_log``. System turns are annotated in
        place with ``belief_summary`` and ``belief_value_summary``.
    """
    d = dialogue
    log = d['log']
    # do all the necessary postprocessing
    if len(log) % 2 != 0:
        print('odd # of turns')
        return None  # odd number of turns, wrong dialogue

    d_pp = {'goal': d['goal']}  # for now we just copy the goal
    usr_turns = []
    sys_turns = []
    for i, turn in enumerate(log):
        text = turn['text']
        if len(text.split()) > maxlen:
            print('too long')
            return None  # too long sentence, wrong dialogue
        if not is_ascii(text):
            print('not ascii')
            return None

        if i % 2 == 0:  # usr turn
            usr_turns.append(turn)
        else:  # sys turn
            belief_summary, belief_value_summary = get_summary_bstate(turn['metadata'])
            turn['belief_summary'] = str(belief_summary)
            turn['belief_value_summary'] = belief_value_summary
            sys_turns.append(turn)

    d_pp['usr_log'] = usr_turns
    d_pp['sys_log'] = sys_turns

    return d_pp
def get_ds_diff(prev_d, crnt_d):
    """Return the entries of `crnt_d` whose value differs from `prev_d`.

    Both mappings are walked in parallel in their iteration order, and the
    keys at each position are asserted to be equal. An empty dict is returned
    when either argument is empty or None.
    """
    changed = {}
    # Sometimes, metadata is an empty dictionary, bug?
    if prev_d and crnt_d:
        for prev_item, crnt_item in zip(prev_d.items(), crnt_d.items()):
            prev_key, prev_val = prev_item
            crnt_key, crnt_val = crnt_item
            assert prev_key == crnt_key
            if prev_val != crnt_val:  # updated
                changed[crnt_key] = crnt_val
    return changed
def buildDelexDict(origin_sent, delex_sent):
    """Map delexicalised token positions to the original words they replaced.

    The two sentences are split on whitespace and aligned with
    ``difflib.SequenceMatcher``. For every matching block except the last two
    returned by ``get_matching_blocks``, the original tokens lying between the
    end of that block and the start of the next one are joined with spaces and
    stored under the index in the delexicalised sentence at which the block
    ends.
    """
    origin_tokens = origin_sent.split()
    matcher = difflib.SequenceMatcher(None, delex_sent.split(), origin_tokens)
    blocks = matcher.get_matching_blocks()

    dictionary = {}
    for block, following in zip(blocks[:-2], blocks[1:]):
        gap = origin_tokens[block.b + block.size:following.b]
        dictionary[block.a + block.size] = " ".join(gap)
    return dictionary
def main():
    """Build the WOZ-like dialogues, then split them into train/dev/test files."""
    print('Create WOZ-like dialogues. Get yourself a coffee, this might take a while.')
    dialogues = createData()
    print('Divide dialogues...')
    divideData(dialogues)
117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | annotate 227 | client 228 | Turn 229 | turns 230 | reque 231 | vocab 232 | GLADEncoder_global_no_rnn_conditioned_v5 233 | pad 234 | FixedEmbedding 235 | GLAD_ENCODERS 236 | encoder 237 | position_encoding_init 238 | load_model 239 | 240 | 241 | 242 | 244 | 245 | 259 | 260 | 261 | 266 | 267 | 268 | 269 | 270 | 271 | 272 | 273 | 274 | 275 | 276 | 277 | 278 | 279 | 280 | 281 | 282 | 283 | 284 | 285 | 286 | 287 | 288 | 289 | 290 | 291 | 292 | 293 | 294 | 295 |