├── __init__.py
├── models
├── __init__.py
└── gce.py
├── version.txt
├── .dockerignore
├── BayLearn_2018_Dialogue.pdf
├── =1.9
├── command_train_woz.sh
├── command_train_mwoz.sh
├── .gitignore
├── .idea
├── vcs.xml
├── misc.xml
├── modules.xml
├── glad.iml
└── workspace.xml
├── command_eval_woz.sh
├── requirements.txt
├── command_eval_mwoz.sh
├── .nfs831404ea9438756800000067
├── README.md
├── Dockerfile
├── utils.py
├── LICENSE
├── evaluate.py
├── preprocess_data.py
├── train.py
├── preprocess_data_mwoz.py
├── dataset.py
└── create_data.py
/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/version.txt:
--------------------------------------------------------------------------------
1 | 0.1
2 |
--------------------------------------------------------------------------------
/.dockerignore:
--------------------------------------------------------------------------------
1 | exp/
2 | data/
3 | Dockerfile
4 | .git/
5 | *.py[cod]
6 |
--------------------------------------------------------------------------------
/BayLearn_2018_Dialogue.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elnaaz/GCE-Model/HEAD/BayLearn_2018_Dialogue.pdf
--------------------------------------------------------------------------------
/=1.9:
--------------------------------------------------------------------------------
1 | Requirement already satisfied: six in /System/Library/Frameworks/Python.framework/Versions/2.7/Extras/lib/python (1.4.1)
2 |
--------------------------------------------------------------------------------
/command_train_woz.sh:
--------------------------------------------------------------------------------
1 |
2 | MODEL=gce_woz
3 | DATA=woz
4 | GPU=0
5 | EPOCH=200
6 |
7 | CUDA_VISIBLE_DEVICES=$GPU python train.py --gpu $GPU -n $MODEL --data $DATA --epoch $EPOCH
8 |
--------------------------------------------------------------------------------
/command_train_mwoz.sh:
--------------------------------------------------------------------------------
1 |
2 | MODEL=gce_multiwoz
3 | DATA=multi_woz
4 | GPU=0
5 | EPOCH=200
6 |
7 | CUDA_VISIBLE_DEVICES=$GPU python train.py --gpu $GPU -n $MODEL --data $DATA --epoch $EPOCH
8 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | Mar*/
3 | .DS_Store
4 | *.py[cod]
5 | *.json
6 | *.json[~]
7 | *.save
8 | *.log
9 | *.model
10 | *.t7
11 | *.npy
12 | *.flist
13 | *.zip
14 | *.gzip
15 | *.tar
16 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/command_eval_woz.sh:
--------------------------------------------------------------------------------
1 |
2 | DIR=exp/woz/gce
3 | MODEL=gce_woz
4 | DATA=woz
5 | GPU=0
6 | CHECKPOINT=$DIR/$MODEL
7 |
8 | CUDA_VISIBLE_DEVICES=$GPU python evaluate.py --gpu $GPU --data $DATA --split test --dsave $CHECKPOINT
9 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | protobuf==3.4.0
2 | requests==2.18.4
3 | stanza==0.3
4 | tqdm==4.19.1.post1
5 | vocab==0.0.3
6 | embeddings==0.0.4
7 | http://download.pytorch.org/whl/cu90/torch-0.4.0-cp36-cp36m-linux_x86_64.whl
8 | numpy==1.13.1
9 |
--------------------------------------------------------------------------------
/command_eval_mwoz.sh:
--------------------------------------------------------------------------------
1 |
2 | DIR=exp/multi_woz/gce
3 | MODEL=gce_multiwoz
4 | DATA=multi_woz
5 | GPU=0
6 | CHECKPOINT=$DIR/$MODEL
7 |
8 | CUDA_VISIBLE_DEVICES=$GPU python evaluate.py --gpu $GPU --data $DATA --split test --dsave $CHECKPOINT
9 |
--------------------------------------------------------------------------------
/.nfs831404ea9438756800000067:
--------------------------------------------------------------------------------
1 |
2 | MODEL=global_no_rnn_conditioned_v6
3 | GLAD=GLADEncoder_global_no_rnn_conditioned_v6
4 | DATA=woz
5 | GPU=2
6 | EPOCH=2000
7 |
8 | CUDA_VISIBLE_DEVICES=$GPU python train.py --gpu $GPU --encoder $GLAD -n $MODEL --data $DATA --epoch $EPOCH
9 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/glad.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Globally-Conditioned Encoder (GCE) For Neural Dialogue State Tracking
2 |
3 | This repository contains an implementation of the [Toward Scalable Neural Dialogue State Tracking Model](https://arxiv.org/abs/1812.00899).
4 | If you use this in your work, please cite the following:
5 |
6 |
7 | ```
8 | @inproceedings{nouri2018gce,
9 | title={ Toward Scalable Neural Dialogue State Tracking },
10 | author={ Nouri, Elnaz and Hosseini-Asl, Ehsan },
11 | booktitle={ NeurIPS 2018, 2nd Conversational AI workshop },
12 | year={ 2018 },
13 | arxiv={ https://arxiv.org/abs/1812.00899 }
14 | }
15 | ```
16 |
17 |
18 | # Install dependencies
19 |
20 | For installation details, please see the following GitHub repository:
21 | ```
22 | https://github.com/salesforce/glad
23 | ```
24 |
25 |
26 |
27 | # Contribution
28 |
29 | Pull requests are welcome!
30 | If you have any questions, please create an issue or contact the corresponding author at `Elnaz.Nouri [at] microsoft [dot] com`.
31 |
32 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:9.0-base-ubuntu16.04
2 |
3 | # install Miniconda
4 | ENV LANG=C.UTF-8 LC_ALL=C.UTF-8
5 | ENV PATH /opt/conda/bin:$PATH
6 |
7 | RUN apt-get update --fix-missing && apt-get install -y wget bzip2 ca-certificates \
8 | libglib2.0-0 libxext6 libsm6 libxrender1 \
9 | git mercurial subversion
10 |
11 | RUN wget --quiet https://repo.continuum.io/miniconda/Miniconda3-4.4.10-Linux-x86_64.sh -O ~/miniconda.sh && \
12 | /bin/bash ~/miniconda.sh -b -p /opt/conda && \
13 | rm ~/miniconda.sh && \
14 | /opt/conda/bin/conda clean -tipsy && \
15 | ln -s /opt/conda/etc/profile.d/conda.sh /etc/profile.d/conda.sh && \
16 | echo ". /opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc && \
17 | echo "conda activate base" >> ~/.bashrc
18 |
19 | # copy GLAD
20 | RUN mkdir -p /opt/glad
21 | WORKDIR /opt/glad
22 |
23 | # install dependencies
24 | COPY requirements.txt .
25 | RUN pip install -r requirements.txt
26 |
27 | # copy source
28 | COPY . .
29 |
30 | # volumes and environment variables
31 | ENV EMBEDDINGS_ROOT /opt/embeddings
32 | RUN mkdir -p /opt/embeddings
33 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os
4 | from pprint import pformat
5 | from importlib import import_module
6 | from vocab import Vocab
7 | from dataset import Dataset, Ontology
8 | from preprocess_data import dann as dann_woz
9 | from preprocess_data_mwoz import dann as dann_mwoz
10 | import ipdb
11 |
def load_dataset(splits=('train', 'dev', 'test'), data='woz'):
    """Load annotated splits together with the ontology, vocabulary and embeddings.

    Args:
        splits: names of the splits to load; each is read from ``<split>.json``.
        data: ``'woz'`` selects the WoZ annotation directory, any other value
            selects the MultiWOZ annotation directory.

    Returns:
        Tuple ``(dataset, ontology, vocab, E)`` where ``dataset`` maps each split
        name to a ``Dataset`` and ``E`` is the decoded content of ``emb.json``.
    """
    dann = dann_woz if data == 'woz' else dann_mwoz
    with open(os.path.join(dann, 'ontology.json')) as f:
        ontology = Ontology.from_dict(json.load(f))
    with open(os.path.join(dann, 'vocab.json')) as f:
        vocab = Vocab.from_dict(json.load(f))
    with open(os.path.join(dann, 'emb.json')) as f:
        E = json.load(f)
    dataset = {}
    for split in splits:
        with open(os.path.join(dann, '{}.json'.format(split))) as f:
            # logging.warn is a deprecated alias (removed in Python 3.13).
            logging.warning('loading split {}'.format(split))
            dataset[split] = Dataset.from_dict(json.load(f))

    logging.info('dataset sizes: {}'.format(pformat({k: len(v) for k, v in dataset.items()})))
    return dataset, ontology, vocab, E
28 |
29 |
def get_models():
    """Return the sorted names of the model modules found in ``./models``.

    Only ``*.py`` files whose name does not start with an underscore are
    reported. Other entries (caches, byte-code, editor files) are skipped so
    they never show up as selectable model names; previously ``gce.pyc`` would
    have been reported as ``gcec``.

    The directory is resolved relative to the current working directory.
    """
    names = []
    for entry in os.listdir('models'):
        stem, ext = os.path.splitext(entry)
        if ext == '.py' and not entry.startswith('_'):
            names.append(stem)
    return sorted(names)
32 |
33 |
def load_model(model, *args, **kwargs):
    """Instantiate the ``Model`` class exported by ``models.<model>``.

    Args:
        model: module name inside the ``models`` package.
        *args, **kwargs: forwarded unchanged to the ``Model`` constructor.

    Returns:
        The constructed model instance.
    """
    module = import_module('models.{}'.format(model))
    Model = module.Model
    instance = Model(*args, **kwargs)
    logging.info('loaded model {}'.format(Model))
    return instance
39 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2018, Salesforce
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | * Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import logging
4 | from argparse import ArgumentParser, Namespace
5 | from pprint import pprint
6 | from utils import load_dataset, load_model
7 | from models.glad import GLAD_ENCODERS
8 | import ipdb
9 |
if __name__ == '__main__':
    # Evaluate a saved model on one split and optionally dump its predictions.
    parser = ArgumentParser()
    parser.add_argument('--dsave', help='save location of model')
    parser.add_argument('--split', help='split to evaluate on', default='dev')
    parser.add_argument('--gpu', type=int, help='gpu to use', default=None)
    parser.add_argument('--fout', help='optional save file to store the predictions')
    parser.add_argument('--encoder', help='which encoder to use', default='GLADEncoder', choices=GLAD_ENCODERS)
    parser.add_argument('--use_elmo', help='use elmo embeddings', action='store_true')
    parser.add_argument('--data', help='dataset', default='woz', choices=['woz', 'multi_woz'])
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)

    # The training-time configuration is restored from the save directory;
    # only the GPU (and the encoder, when the config has none) comes from the CLI.
    with open(os.path.join(args.dsave, 'config.json')) as f:
        args_save = Namespace(**json.load(f))
        args_save.gpu = args.gpu
        if not hasattr(args_save, 'encoder'):
            args_save.encoder = args.encoder
    pprint(args_save)

    dataset, ontology, vocab, Eword = load_dataset(data=args.data)

    # NOTE(review): train.py calls load_model(model, args, ontology, vocab) without
    # the use_elmo positional argument - confirm which signature Model expects.
    model = load_model(args_save.model, args.use_elmo, args_save, ontology, vocab)
    model.load_best_save(directory=args.dsave)
    if args.gpu is not None:
        # Device 0 is used regardless of --gpu; presumably because the launcher
        # scripts restrict visibility with CUDA_VISIBLE_DEVICES - confirm.
        model.cuda(0)

    logging.info('Making predictions for {} dialogues and {} turns'.format(len(dataset[args.split]), len(list(dataset[args.split].iter_turns()))))
    preds = model.run_pred(dataset[args.split], args_save)
    # evaluate_preds reads `Ontology.slots`; calling it without the ontology
    # falls back to a plain list default and raises AttributeError.
    pprint(dataset[args.split].evaluate_preds(preds, ontology))

    if args.fout:
        with open(args.fout, 'wt') as f:
            # predictions is a list of sets, need to convert to list of lists to make it JSON serializable
            json.dump([list(p) for p in preds], f, indent=2)
46 |
--------------------------------------------------------------------------------
/preprocess_data.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import os
3 | import json
4 | import logging
5 | import zipfile
6 | import requests
7 | from tqdm import tqdm
8 | from vocab import Vocab
9 | from embeddings import GloveEmbedding, KazumaCharEmbedding
10 | from dataset import Dataset, Ontology
11 |
12 |
# All locations are derived from the directory containing this script.
root_dir = os.path.dirname(__file__)
data_dir = os.path.join(root_dir, 'data', 'woz')

# Remote archive and the local path it is downloaded to.
furl = 'https://mi.eng.cam.ac.uk/~nm480/woz_2.0.zip'
fzip = os.path.join(data_dir, 'woz.zip')

# draw: extracted raw splits; dann: annotated output written by this script.
draw, dann = (os.path.join(data_dir, sub) for sub in ('raw', 'ann'))

splits = ['dev', 'train', 'test']
24 |
25 |
def download(url, to_file):
    """Stream ``url`` into ``to_file``.

    The body is written to ``to_file + '.part'`` and moved into place only once
    the transfer has finished, so an interrupted download never leaves a file
    that later looks complete.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    partial = to_file + '.part'
    r = requests.get(url, stream=True)
    try:
        # Fail loudly instead of saving an HTML error page as the archive.
        r.raise_for_status()
        with open(partial, 'wb') as f:
            for chunk in r.iter_content(chunk_size=1024):
                if chunk:
                    f.write(chunk)
    finally:
        r.close()
    os.replace(partial, to_file)
32 |
33 |
def missing_files(d, files):
    """Return True if any ``<name>.json`` for ``name`` in ``files`` is not a file in ``d``."""
    for name in files:
        candidate = os.path.join(d, '{}.json'.format(name))
        if not os.path.isfile(candidate):
            return True
    return False
36 |
37 |
if __name__ == '__main__':
    # Step 1: fetch the archive unless it is already on disk.
    if not os.path.isfile(fzip):
        os.makedirs(data_dir, exist_ok=True)
        logging.warning('Download from {} to {}'.format(furl, fzip))
        download(furl, fzip)

    # Step 2: extract the raw splits from the archive.
    if missing_files(draw, splits):
        os.makedirs(draw, exist_ok=True)
        with zipfile.ZipFile(fzip) as f:
            logging.warning('Extracting {} to {}'.format(fzip, draw))
            for split in splits:
                with f.open('woz_2.0/woz2_{}.json'.format(split)) as fin, open(os.path.join(draw, '{}.json'.format(split)), 'wb') as fout:
                    fout.write(fin.read())

    # Step 3: annotate, numericalize and write ontology, vocabulary and embeddings.
    if missing_files(dann, files=splits + ['ontology', 'vocab', 'emb']):
        os.makedirs(dann, exist_ok=True)
        dataset = {}
        ontology = Ontology()
        vocab = Vocab()
        # NOTE(review): the token literals were empty strings here, apparently
        # stripped markup; restored to the sequence markers used by the upstream
        # GLAD preprocessing - confirm against the original repository.
        vocab.word2index(['<sos>', '<eos>'], train=True)
        for s in splits:
            fname = '{}.json'.format(s)
            logging.warning('Annotating {}'.format(s))
            dataset[s] = Dataset.annotate_raw(os.path.join(draw, fname))
            dataset[s].numericalize_(vocab)
            ontology = ontology + dataset[s].extract_ontology()
            with open(os.path.join(dann, fname), 'wt') as f:
                json.dump(dataset[s].to_dict(), f)
        ontology.numericalize_(vocab)
        with open(os.path.join(dann, 'ontology.json'), 'wt') as f:
            json.dump(ontology.to_dict(), f)
        with open(os.path.join(dann, 'vocab.json'), 'wt') as f:
            json.dump(vocab.to_dict(), f)

        logging.warning('Computing word embeddings')
        embeddings = [GloveEmbedding(), KazumaCharEmbedding()]
        E = []
        for w in tqdm(vocab._index2word):
            # One row per vocabulary entry: the embeddings are concatenated in order.
            e = []
            for emb in embeddings:
                e += emb.emb(w, default='zero')
            E.append(e)
        with open(os.path.join(dann, 'emb.json'), 'wt') as f:
            json.dump(E, f)
85 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
3 | from utils import load_dataset, get_models, load_model
4 | import os
5 | import logging
6 | import numpy as np
7 | from pprint import pprint
8 | import torch
9 | from random import seed
10 | # from models.glad import GLAD_ENCODERS
11 | from models.gce import GCEencoder
12 |
13 |
def run(args):
    """Build the model described by ``args``, optionally train it, then evaluate on dev.

    Training runs when ``args.resume`` is set (after loading the best save from
    that directory) or when ``args.test`` is false. Otherwise the best save in
    ``args.dout`` is loaded. The dev evaluation result is pretty-printed.
    """
    pprint(args)
    logging.basicConfig(level=logging.INFO)

    # Seed numpy, torch and the stdlib RNG with the same value.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    seed(args.seed)

    dataset, ontology, vocab, Eword = load_dataset(data=args.data)

    model = load_model(args.model, args, ontology, vocab)
    model.save_config()
    model.load_emb(Eword)
    model = model.to(model.device)

    should_train = bool(args.resume) or not args.test
    if should_train:
        if args.resume:
            logging.info('Load best model')
            model.load_best_save(directory=args.resume)
        logging.info('Starting train')
        model.run_train(dataset['train'], dataset['dev'], args)
    else:
        model.load_best_save(directory=args.dout)

    model = model.to(model.device)
    logging.info('Running dev evaluation')
    pprint(model.run_eval(dataset['dev'], args))
43 |
44 |
def get_args():
    """Parse the training command line.

    Besides the declared options, the returned namespace carries:
      * ``dout``: ``<dexp>/<data>/<model>/<nick>``, created if it does not exist;
      * ``dropout``: the ``NAME=RATE`` entries converted to a ``{name: float}`` dict.

    A malformed ``--dropout`` entry is reported as a normal argparse usage error.
    """
    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    parser.add_argument('--dexp', help='root experiment folder', default='exp')
    parser.add_argument('--data', help='dataset', default='woz', choices=['woz', 'multi_woz'])
    parser.add_argument('--model', help='which model to use', default='gce', choices=get_models())
    parser.add_argument('--epoch', help='max epoch to run for', default=50, type=int)
    parser.add_argument('--demb', help='word embedding size', default=400, type=int)
    parser.add_argument('--dhid', help='hidden state size', default=200, type=int)
    parser.add_argument('--batch_size', help='batch size', default=50, type=int)
    parser.add_argument('--lr', help='learning rate', default=1e-3, type=float)
    parser.add_argument('--stop', help='slot to early stop on', default='joint_goal')
    parser.add_argument('--resume', help='save directory to resume from')
    parser.add_argument('-n', '--nick', help='nickname for model', default='default')
    parser.add_argument('--seed', default=42, help='random seed', type=int)
    parser.add_argument('--test', action='store_true', help='run in evaluation only mode')
    parser.add_argument('--gpu', type=int, help='which GPU to use')
    parser.add_argument('--dropout', nargs='*', help='dropout rates', default=['emb=0.2', 'local=0.2', 'global=0.2'])
    args = parser.parse_args()
    args.dout = os.path.join(args.dexp, args.data, args.model, args.nick)

    dropout = {}
    for spec in args.dropout:
        parts = spec.split('=')
        try:
            dropout[parts[0]] = float(parts[1])
        except (IndexError, ValueError):
            parser.error('--dropout expects NAME=RATE entries, got {!r}'.format(spec))
    args.dropout = dropout

    # exist_ok avoids the check-then-create race between concurrent runs.
    os.makedirs(args.dout, exist_ok=True)
    return args
70 |
71 |
72 | if __name__ == '__main__':
73 | args = get_args()
74 | run(args)
75 |
--------------------------------------------------------------------------------
/preprocess_data_mwoz.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import os
3 | import json
4 | import logging
5 | import zipfile
6 | import requests
7 | from tqdm import tqdm
8 | from vocab import Vocab
9 | from embeddings import GloveEmbedding, KazumaCharEmbedding
10 | from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
11 | from dataset import Dataset, Ontology
12 |
def get_args():
    """Return the preprocessing options (``data_dir`` and ``only_domain``).

    This function is invoked at module import time, and the module is imported
    by ``utils.py`` for its ``dann`` path. The command line is therefore parsed
    only when this file is the program being run; on import the defaults are
    returned. Otherwise the importing script's own flags (e.g. ``--gpu``) would
    make this parser abort with "unrecognized arguments".
    """
    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    parser.add_argument('--data_dir', help='root experiment folder', default='data/multi_woz/')
    parser.add_argument('--only_domain', help='', default='')
    # argv=None makes argparse read sys.argv[1:]; an empty list yields the defaults.
    argv = None if __name__ == '__main__' else []
    args = parser.parse_args(argv)
    return args
19 |
# Options are resolved once, at import time; other modules import `dann` from here.
args = get_args()
data_dir, only_domain = args.data_dir, args.only_domain

# draw: raw input splits; dann: annotated output, directory name prefixed by the
# selected domain (plain 'ann' when no domain is selected).
draw = os.path.join(data_dir, 'raw')
dann = os.path.join(data_dir, only_domain + 'ann')

splits = ['dev', 'train', 'test']
36 |
def download(url, to_file):
    """Stream ``url`` into ``to_file``.

    The body is written to ``to_file + '.part'`` and moved into place only once
    the transfer has finished, so an interrupted download never leaves a file
    that later looks complete.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    partial = to_file + '.part'
    r = requests.get(url, stream=True)
    try:
        # Fail loudly instead of saving an HTML error page as the payload.
        r.raise_for_status()
        with open(partial, 'wb') as f:
            for chunk in r.iter_content(chunk_size=1024):
                if chunk:
                    f.write(chunk)
    finally:
        r.close()
    os.replace(partial, to_file)
43 |
44 |
def missing_files(d, files):
    """Return True if any ``<name>.json`` for ``name`` in ``files`` is not a file in ``d``."""
    for name in files:
        candidate = os.path.join(d, '{}.json'.format(name))
        if not os.path.isfile(candidate):
            return True
    return False
47 |
48 |
if __name__ == '__main__':
    # The raw MultiWOZ splits are expected to already exist under `draw`;
    # this script performs no download or extraction step.
    if missing_files(dann, files=splits + ['ontology', 'vocab', 'emb']):
        os.makedirs(dann, exist_ok=True)
        dataset = {}
        ontology = Ontology()
        vocab = Vocab()
        # NOTE(review): the token literals were empty strings here, apparently
        # stripped markup; restored to the sequence markers used by the upstream
        # GLAD preprocessing - confirm against the original repository.
        vocab.word2index(['<sos>', '<eos>'], train=True)
        for s in splits:
            fname = '{}.json'.format(s)
            logging.warning('Annotating {}'.format(s))
            dataset[s] = Dataset.annotate_raw(os.path.join(draw, fname), only_domain)
            dataset[s].numericalize_(vocab)
            ontology = ontology + dataset[s].extract_ontology()
            with open(os.path.join(dann, fname), 'wt') as f:
                json.dump(dataset[s].to_dict(), f)
        ontology.numericalize_(vocab)
        with open(os.path.join(dann, 'ontology.json'), 'wt') as f:
            json.dump(ontology.to_dict(), f)
        with open(os.path.join(dann, 'vocab.json'), 'wt') as f:
            json.dump(vocab.to_dict(), f)

        logging.warning('Computing word embeddings')
        embeddings = [GloveEmbedding(), KazumaCharEmbedding()]
        E = []
        for w in tqdm(vocab._index2word):
            # One row per vocabulary entry: the embeddings are concatenated in order.
            e = []
            for emb in embeddings:
                e += emb.emb(w, default='zero')
            E.append(e)
        with open(os.path.join(dann, 'emb.json'), 'wt') as f:
            json.dump(E, f)
96 |
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | from collections import defaultdict
3 | import numpy as np
4 | from tqdm import tqdm
5 | from stanza.nlp.corenlp import CoreNLPClient
6 |
7 |
8 | client = None
9 |
10 | # do not change
11 | ACTIVATED_DOMAINS = ["hotel", "train", "restaurant", "attraction", "taxi"]
12 |
13 |
def annotate(sent):
    """Tokenize ``sent`` by splitting on whitespace and return the list of words."""
    # A CoreNLP-based tokenizer used to live here (disabled in the source);
    # plain whitespace splitting is what actually runs.
    return sent.split()
24 |
25 |
class Turn:
    """A single dialogue turn: user transcript, labels and preceding system context."""

    def __init__(self, turn_id, transcript, turn_label, belief_state, system_acts, system_transcript, num=None):
        self.id = turn_id
        self.transcript = transcript
        self.turn_label = turn_label
        self.belief_state = belief_state
        self.system_acts = system_acts
        self.system_transcript = system_transcript
        # Numericalized views of the fields above, filled by numericalize_().
        self.num = num or {}

    def to_dict(self):
        """Return a JSON-serializable dict whose keys match the constructor arguments."""
        return {'turn_id': self.id, 'transcript': self.transcript, 'turn_label': self.turn_label, 'belief_state': self.belief_state, 'system_acts': self.system_acts, 'system_transcript': self.system_transcript, 'num': self.num}

    @classmethod
    def from_dict(cls, alld, i, d):
        """Rebuild a turn from ``d``.

        ``alld`` (all turn dicts of the dialogue) and ``i`` (index of ``d``)
        are accepted but currently unused.
        """
        return cls(**d)

    @classmethod
    def annotate_raw(cls, raw, only_domain=""):
        """Build a turn from a raw turn dict.

        System acts given as ``[slot, value]`` pairs become
        ``['inform', *slot, '=', *value]``; plain strings become
        ``['request', *words]``. Labels and belief-state entries are kept only
        when their slot name contains ``only_domain`` (if given), otherwise
        when the part before the first ``-`` is one of ``ACTIVATED_DOMAINS``.

        NOTE(review): with the default ``only_domain`` a slot name without a
        domain prefix (e.g. a plain ``food``) never matches ``ACTIVATED_DOMAINS``
        and is dropped - confirm this is intended for data without prefixes.
        """
        system_acts = []
        for a in raw['system_acts']:
            if isinstance(a, list):
                s, v = a
                system_acts.append(['inform'] + s.split() + ['='] + v.split())
            else:
                system_acts.append(['request'] + a.split())
        # NOTE: fix inconsistencies in data label
        fix = {'centre': 'center', 'areas': 'area', 'phone number': 'number'}

        if only_domain != "":
            def keep(slot):
                return only_domain in slot
        else:
            def keep(slot):
                return slot.split("-")[0] in ACTIVATED_DOMAINS

        return cls(
            turn_id=raw['turn_idx'],
            transcript=annotate(raw['transcript']),
            system_acts=system_acts,
            turn_label=[[fix.get(s.strip(), s.strip()), fix.get(v.strip(), v.strip())] for s, v in raw['turn_label'] if keep(s)],
            # Filtering looks only at the first slot of each belief-state entry.
            belief_state=[bs for bs in raw['belief_state'] if keep(bs["slots"][0][0])],
            system_transcript=annotate(raw['system_transcript']),
        )

    def numericalize_(self, vocab):
        """Store index sequences for the transcripts and system acts in ``self.num``.

        Every sequence is lower-cased and wrapped in start/end markers, and one
        extra sentinel act is appended after the real system acts.

        NOTE(review): the marker literals were empty strings, which made the
        start, end and sentinel tokens identical; they look like stripped
        markup and were restored to the names used by the upstream GLAD code -
        confirm against the original repository.
        """
        self.num['transcript'] = vocab.word2index(['<sos>'] + [w.lower() for w in self.transcript + ['<eos>']], train=True)
        self.num['system_transcript'] = vocab.word2index(['<sos>'] + [w.lower() for w in self.system_transcript + ['<eos>']], train=True)
        self.num['system_acts'] = [vocab.word2index(['<sos>'] + [w.lower() for w in a] + ['<eos>'], train=True) for a in self.system_acts + [['<sentinel>']]]
80 |
81 |
class Dialogue:
    """An identified, ordered list of turns."""

    def __init__(self, dialogue_id, turns):
        self.id = dialogue_id
        self.turns = turns

    def __len__(self):
        """Number of turns in the dialogue."""
        return len(self.turns)

    def to_dict(self):
        """Return a JSON-serializable dict with the id and the serialized turns."""
        serialized = [turn.to_dict() for turn in self.turns]
        return {'dialogue_id': self.id, 'turns': serialized}

    @classmethod
    def from_dict(cls, d):
        """Rebuild a dialogue; each turn receives the full turn list and its own index."""
        all_turns = d['turns']
        turns = [Turn.from_dict(all_turns, idx, turn) for idx, turn in enumerate(all_turns)]
        return cls(d['dialogue_id'], turns)

    @classmethod
    def annotate_raw(cls, raw, only_domain=""):
        """Build a dialogue from a raw dict, annotating each entry of ``raw['dialogue']``."""
        turns = [Turn.annotate_raw(raw_turn, only_domain) for raw_turn in raw['dialogue']]
        return cls(raw['dialogue_idx'], turns)
102 |
103 |
class Dataset:
    """A collection of dialogues with batching, ontology extraction and scoring helpers."""

    def __init__(self, dialogues):
        self.dialogues = dialogues

    def __len__(self):
        """Number of dialogues."""
        return len(self.dialogues)

    def iter_turns(self):
        """Yield every turn of every dialogue, in order."""
        for d in self.dialogues:
            for t in d.turns:
                yield t

    def to_dict(self):
        """Return a JSON-serializable dict of all dialogues."""
        return {'dialogues': [d.to_dict() for d in self.dialogues]}

    @classmethod
    def from_dict(cls, d):
        """Rebuild a dataset from the output of to_dict()."""
        return cls([Dialogue.from_dict(dd) for dd in d['dialogues']])

    @classmethod
    def annotate_raw(cls, fname, only_domain=""):
        """Load raw dialogues from the JSON file ``fname`` and annotate them.

        When ``only_domain`` is given, only dialogues whose ``"domains"`` entry
        contains it are kept.
        """
        with open(fname) as f:
            data = json.load(f)
            if only_domain != "":
                return cls([Dialogue.annotate_raw(d, only_domain) for d in tqdm(data) if only_domain in d["domains"]])
            else:
                return cls([Dialogue.annotate_raw(d) for d in tqdm(data)])

    def numericalize_(self, vocab):
        """Numericalize every turn in place with ``vocab``."""
        for t in self.iter_turns():
            t.numericalize_(vocab)

    def extract_ontology(self):
        """Collect the lower-cased slots and their lower-cased values from all turn labels."""
        slots = set()
        values = defaultdict(set)
        for t in self.iter_turns():
            for s, v in t.turn_label:
                slot = s.lower()
                slots.add(slot)
                # Key by the lower-cased slot so `values` lines up with `slots`;
                # keying by the raw name lost values for mixed-case slot names.
                values[slot].add(v.lower())
        return Ontology(sorted(list(slots)), {k: sorted(list(v)) for k, v in values.items()})

    def batch(self, batch_size, shuffle=False):
        """Yield lists of at most ``batch_size`` turns, optionally shuffled with numpy's RNG."""
        turns = list(self.iter_turns())
        if shuffle:
            np.random.shuffle(turns)
        for i in tqdm(range(0, len(turns), batch_size)):
            yield turns[i:i+batch_size]

    def evaluate_preds(self, preds, Ontology=None):
        """Score per-turn predictions against the gold labels.

        Args:
            preds: one collection of ``(slot, value)`` pairs per turn, in
                iter_turns() order.
            Ontology: object exposing ``slots``; its length is the denominator
                of the slot accuracy. When omitted (or when it has no ``slots``
                attribute) the distinct lower-cased slot names found in this
                dataset's turn labels are used instead.

        Returns:
            Dict with the mean ``turn_inform``, ``turn_request``, ``joint_goal``
            and ``slot_acc`` over all turns.
        """
        request = []
        inform = []
        joint_goal = []
        slot_acc = []
        fix = {'centre': 'center', 'areas': 'area', 'phone number': 'number'}
        i = 0
        # The previous default was a plain list, which has no `.slots` and
        # raised AttributeError whenever the argument was left out.
        slots = getattr(Ontology, 'slots', None)
        if slots is None:
            slots = {s.lower() for t in self.iter_turns() for s, _ in t.turn_label}
        nb_slot = len(slots)
        print("len of Ontology", nb_slot)
        for d in self.dialogues:
            # Predicted informs accumulate over the dialogue; later turns overwrite a slot.
            pred_state = {}
            for t in d.turns:
                gold_request = set([(s, v) for s, v in t.turn_label if s == 'request'])
                gold_inform = set([(s, v) for s, v in t.turn_label if s != 'request'])
                pred_request = set([(s, v) for s, v in preds[i] if s == 'request'])
                pred_inform = set([(s, v) for s, v in preds[i] if s != 'request'])
                request.append(gold_request == pred_request)
                inform.append(gold_inform == pred_inform)

                slot_acc.append(self.compute_acc(gold_inform, pred_inform, nb_slot))

                gold_recovered = set()
                pred_recovered = set()
                for s, v in pred_inform:
                    pred_state[s] = v
                for b in t.belief_state:
                    for s, v in b['slots']:
                        if b['act'] != 'request':
                            gold_recovered.add((b['act'], fix.get(s.strip(), s.strip()), fix.get(v.strip(), v.strip())))
                for s, v in pred_state.items():
                    pred_recovered.add(('inform', s, v))
                joint_goal.append(gold_recovered == pred_recovered)
                i += 1
        return {'turn_inform': np.mean(inform), 'turn_request': np.mean(request), 'joint_goal': np.mean(joint_goal), 'slot_acc': np.mean(slot_acc)}

    def compute_acc(self, gold, pred, nb_slot):
        """Return ``(nb_slot - missed gold pairs - wrong predictions) / nb_slot``.

        A prediction is counted as wrong only when it is not in ``gold`` and
        its slot was not already counted as a missed gold pair.
        """
        miss_gold = 0
        miss_slot = []

        for g in gold:
            if g not in pred:
                miss_gold += 1
                miss_slot.append(g[0])
        wrong_pred = 0
        for p in pred:
            if p not in gold and p[0] not in miss_slot:
                wrong_pred += 1
        ACC_TOTAL = nb_slot
        ACC = nb_slot - miss_gold - wrong_pred
        ACC = ACC / float(ACC_TOTAL)
        return ACC

    def record_preds(self, preds, to_file):
        """Write the dataset as JSON to ``to_file`` with a sorted ``pred`` list added to each turn."""
        data = self.to_dict()
        i = 0
        for d in data['dialogues']:
            for t in d['turns']:
                t['pred'] = sorted(list(preds[i]))
                i += 1
        with open(to_file, 'wt') as f:
            json.dump(data, f)
216 |
217 |
class Ontology:
    """Slots, the values seen for each slot, and their numericalized forms."""

    def __init__(self, slots=None, values=None, num=None):
        self.slots = slots or []
        self.values = values or {}
        self.num = num or {}

    def __add__(self, another):
        """Return a new ontology with the sorted union of slots and of per-slot values.

        The numericalized forms are not carried over to the result.
        """
        new_slots = sorted(list(set(self.slots + another.slots)))
        new_values = {s: sorted(list(set(self.values.get(s, []) + another.values.get(s, [])))) for s in new_slots}
        return Ontology(new_slots, new_values)

    def __radd__(self, another):
        # Lets sum() work: its integer start value 0 is treated as the identity.
        return self if another == 0 else self.__add__(another)

    def to_dict(self):
        """Return a JSON-serializable dict whose keys match the constructor arguments."""
        return {'slots': self.slots, 'values': self.values, 'num': self.num}

    def numericalize_(self, vocab):
        """Store, per slot, the index sequence of ``"<slot> = <value>"`` for every value.

        NOTE(review): the terminating token literal was an empty string, which
        looks like stripped markup; restored to the end marker used by the
        upstream GLAD code - confirm against the original repository.
        """
        self.num = {}
        for s, vs in self.values.items():
            self.num[s] = [vocab.word2index(annotate('{} = {}'.format(s, v)) + ['<eos>'], train=True) for v in vs]

    @classmethod
    def from_dict(cls, d):
        """Rebuild an ontology from the output of to_dict()."""
        return cls(**d)
244 |
--------------------------------------------------------------------------------
/models/gce.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch import optim
4 | from torch.nn import functional as F
5 | import numpy as np
6 | import logging
7 | import os
8 | import re
9 | import json
10 | from collections import defaultdict
11 | from pprint import pformat
12 | import ipdb
13 |
14 |
def position_encoding_init(n_position, d_pos_vec):
    """Init the sinusoid position encoding table.

    Row 0 is all zeros (kept for the padding position).  For every position
    ``pos >= 1`` and column ``j`` the angle is
    ``pos / 10000 ** (2 * (j // 2) / d_pos_vec)``; even columns store its sine
    and odd columns its cosine.

    :param n_position: number of rows (positions) in the table.
    :param d_pos_vec: number of columns (encoding width).
    :return: ``torch.FloatTensor`` of shape ``(n_position, d_pos_vec)``; it is
        moved to the CUDA device when one is available and stays on the CPU
        otherwise (the previous unconditional ``.cuda()`` failed on CPU-only
        hosts).
    """
    positions = np.arange(n_position, dtype=np.float64).reshape(-1, 1)
    columns = np.arange(d_pos_vec)
    # position 0 yields an all-zero row, i.e. the padding vector
    position_enc = positions / np.power(10000, 2 * (columns // 2) / d_pos_vec)

    position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2])  # dim 2i
    position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2])  # dim 2i+1

    table = torch.from_numpy(position_enc).type(torch.FloatTensor)
    return table.cuda() if torch.cuda.is_available() else table
26 |
27 |
def pad(seqs, emb, device, pad=0):
    """Right-pad index sequences to equal length and pass them through ``emb``.

    :param seqs: list of lists of indices.
    :param emb: callable applied to the padded ``LongTensor`` after it has
        been moved to ``device``.
    :param device: target device for the index tensor.
    :param pad: index used to fill the shorter sequences.
    :return: ``(emb(padded), lens)`` where ``lens`` holds the original lengths.
    """
    lens = [len(seq) for seq in seqs]
    width = max(lens)
    rows = []
    for seq, length in zip(seqs, lens):
        rows.append(seq + [pad] * (width - length))
    index_batch = torch.LongTensor(rows).to(device)
    return emb(index_batch), lens
33 |
34 |
def run_rnn(rnn, inputs, lens):
    """Run ``rnn`` over a padded batch and return outputs in the input order.

    The batch is sorted by decreasing length, packed, run through ``rnn``,
    unpacked with zero padding and finally restored to the caller's order.

    :param rnn: recurrent module called on a packed sequence.
    :param inputs: batch-first padded tensor.
    :param lens: true length of every batch row.
    :return: padded output tensor, rows in the same order as ``inputs``.
    """
    def as_index(positions):
        # index tensor created from ``inputs`` so it matches its device
        return inputs.data.new(positions).long()

    # longest sequence first, as required for packing
    by_length = np.argsort(lens)[::-1].tolist()
    undo = np.argsort(by_length).tolist()

    sorted_inputs = inputs.index_select(0, as_index(by_length))
    sorted_lens = [lens[i] for i in by_length]
    packed_in = nn.utils.rnn.pack_padded_sequence(sorted_inputs, sorted_lens, batch_first=True)
    packed_out, _ = rnn(packed_in)
    unpacked, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True, padding_value=0.)
    return unpacked.index_select(0, as_index(undo))
46 |
47 |
def attend(seq, cond, lens):
    """
    attend over the sequences `seq` using the condition `cond`.

    Scores are the element-wise product of `cond` (broadcast to the shape of
    `seq`) and `seq`, summed over dimension 2.  Positions at or beyond a
    row's length are set to -inf when the row is shorter than the longest
    one, then a softmax over dimension 1 produces the weights.

    Returns `(context, weights)` where `context` is `seq` weighted by the
    scores and summed over dimension 1.
    """
    # a condition with fewer than 3 dims gets a singleton axis at position 1
    if cond.ndimension() < 3:
        cond = cond.unsqueeze(1)
    scores = cond.expand_as(seq).mul(seq).sum(2)

    longest = max(lens)
    for row, length in enumerate(lens):
        if length < longest:
            # written through .data so the masking is not tracked by autograd
            scores.data[row, length:] = -np.inf

    weights = F.softmax(scores, dim=1)
    context = weights.unsqueeze(2).expand_as(seq).mul(seq).sum(1)
    return context, weights
63 |
64 |
class FixedEmbedding(nn.Embedding):
    """
    An `nn.Embedding` whose lookup result is detached from the graph and then
    passed through dropout.

    The extra keyword-only argument `dropout` is the dropout probability; all
    other arguments are forwarded to `nn.Embedding`.
    """

    def __init__(self, *args, dropout=0, **kwargs):
        super().__init__(*args, **kwargs)
        self.dropout = dropout

    def forward(self, *args, **kwargs):
        looked_up = super().forward(*args, **kwargs)
        # cut the result out of the autograd graph in place
        looked_up.detach_()
        # dropout is only active while the module is in training mode
        return F.dropout(looked_up, p=self.dropout, training=self.training)
78 |
79 |
class SelfAttention_gce(nn.Module):
    """Convolutional self-attention conditioned on an extra vector.

    A 1-d convolution (kernel 5, padding 2, one output channel) scores every
    position of the input concatenated with the condition; a softmax over the
    positions turns the scores into weights for a weighted sum of the input.
    """

    def __init__(self, dhid, dropout=0.):
        # `dropout` is accepted but not used by this module
        super().__init__()
        self.conv = nn.Conv1d(2 * dhid, 1, 5, padding=2)

    def forward(self, inp, lens, cond):
        """Return the attention-weighted sum of `inp` (one row per batch item).

        `lens` is accepted but not used, so every position takes part in the
        softmax.
        """
        # unpacking doubles as a check that `inp` has exactly three dimensions
        _batch, _seq_len, _d_feat = inp.size()
        cond_per_step = cond.unsqueeze(0).expand_as(inp)
        features = torch.cat((cond_per_step, inp), dim=2)
        # Conv1d expects channels on dim 1, so features and positions are swapped
        raw_scores = self.conv(features.transpose(2, 1))
        weights = F.softmax(raw_scores, dim=2)
        return weights.bmm(inp)
95 |
96 |
class GCEencoder(nn.Module):
    """
    the GCE encoder described in https://arxiv.org/abs/1805.09655.

    One bidirectional LSTM and one `SelfAttention_gce` are shared by all
    slots; the input is conditioned on a slot by concatenating the slot
    embedding to every step.  A learned scalar per slot (`beta_raw`, squashed
    by a sigmoid) scales the outputs by `1 - beta`.
    """

    def __init__(self, din, dhid, slots, dropout=None):
        """
        :param din: width of the input embeddings (the LSTM sees `2 * din`).
        :param dhid: LSTM hidden size per direction.
        :param slots: list of slot names; fixes the index of each slot's beta.
        :param dropout: optional dict of dropout rates, read with the keys
            `'selfattn'` and `'global'`.
        """
        super().__init__()
        self.dropout = dropout or {}
        self.global_rnn = nn.LSTM(2 * din, dhid, bidirectional=True, batch_first=True)
        self.global_selfattn = SelfAttention_gce(2 * dhid, dropout=self.dropout.get('selfattn', 0.))
        self.slots = slots
        self.beta_raw = nn.Parameter(torch.Tensor(len(slots)))
        nn.init.uniform_(self.beta_raw, -0.01, 0.01)

    def beta(self, slot):
        """Return the sigmoid of the raw mixing weight stored for `slot`."""
        # torch.sigmoid replaces the deprecated F.sigmoid; the value is identical
        return torch.sigmoid(self.beta_raw[self.slots.index(slot)])

    def forward(self, x, x_len, slot, slot_emb, default_dropout=0.2):
        """Encode `x` conditioned on `slot`.

        :param x: padded batch of embedded sequences.
        :param x_len: true length of each sequence in `x`.
        :param slot: slot name, used to look up beta.
        :param slot_emb: slot embedding, broadcast to the shape of `x`.
        :param default_dropout: rate used when `dropout` has no `'global'` key.
        :return: `(h, c)`: the scaled LSTM outputs and the scaled
            self-attention summary computed from `h`.
        """
        beta = self.beta(slot)
        rate = self.dropout.get('global', default_dropout)
        x_new = torch.cat((slot_emb.unsqueeze(0).expand_as(x), x), dim=2)
        global_h = run_rnn(self.global_rnn, x_new, x_len)
        h = F.dropout(global_h, rate, self.training) * (1-beta)
        # c is derived from the already scaled h and is scaled once more
        c = F.dropout(self.global_selfattn(h, x_len, slot_emb), rate, self.training) * (1-beta)
        return h, c
121 |
122 |
class Model(nn.Module):
    """
    the GLAD scoring model described in https://arxiv.org/abs/1805.09655.

    Three `GCEencoder` instances encode the user utterance, the previous
    system acts and the ontology values.  For every slot the model scores
    each candidate value from the utterance and from the system acts and
    combines both scores with a learned weight.
    """

    def __init__(self, args, ontology, vocab):
        """
        :param args: namespace read for `demb`, `dhid`, `dropout`, `gpu`,
            `lr`, `dout` and `stop`.
        :param ontology: object exposing `slots`, `values` and `num`.
        :param vocab: vocabulary supporting `len()` and `word2index`.
        """
        super().__init__()
        self.optimizer = None
        self.args = args
        self.vocab = vocab
        self.ontology = ontology
        self.emb_fixed = FixedEmbedding(len(vocab), args.demb, dropout=args.dropout.get('emb', 0.2))
        self.encoder = GCEencoder

        self.utt_encoder = self.encoder(args.demb, args.dhid, self.ontology.slots, dropout=args.dropout)
        self.act_encoder = self.encoder(args.demb, args.dhid, self.ontology.slots, dropout=args.dropout)
        self.ont_encoder = self.encoder(args.demb, args.dhid, self.ontology.slots, dropout=args.dropout)
        self.utt_scorer = nn.Linear(2 * args.dhid, 1)
        self.score_weight = nn.Parameter(torch.Tensor([0.5]))

    @property
    def device(self):
        """CUDA when `args.gpu` is set and CUDA is available, else CPU."""
        if self.args.gpu is not None and torch.cuda.is_available():
            return torch.device('cuda')
        else:
            return torch.device('cpu')

    def set_optimizer(self):
        """Create the Adam optimizer over all parameters with `args.lr`."""
        self.optimizer = optim.Adam(self.parameters(), lr=self.args.lr)

    def load_emb(self, Eword):
        """Copy the embedding matrix `Eword` into the fixed embedding."""
        new = self.emb_fixed.weight.data.new
        self.emb_fixed.weight.data.copy_(new(Eword))

    def forward(self, batch):
        """Score every ontology value of every slot for each turn in `batch`.

        :param batch: turns exposing `num['transcript']`, `num['system_acts']`
            and (in training mode) `turn_label`.
        :return: `(loss, scores)`; `scores` maps a slot to a nested list with
            one row of value probabilities per turn.  Outside training the
            loss is a zero tensor.
        """
        # convert to variables and look up embeddings
        eos = self.vocab.word2index('')
        ontology = {s: pad(v, self.emb_fixed, self.device, pad=eos) for s, v in self.ontology.num.items()}
        utterance, utterance_len = pad([e.num['transcript'] for e in batch], self.emb_fixed, self.device, pad=eos)
        acts = [pad(e.num['system_acts'], self.emb_fixed, self.device, pad=eos) for e in batch]
        ys = {}
        for s in self.ontology.slots:
            # for each slot, compute the scores for each value

            # the slot is represented by the embedding of its first word
            s_new = s.split()[0]
            # built on self.device (not hard-wired to CUDA) so CPU runs work
            s_index = torch.LongTensor([self.vocab.word2index(s_new)]).to(self.device)
            s_emb = self.emb_fixed(s_index)
            H_utt, c_utt = self.utt_encoder(utterance, utterance_len, slot=s, slot_emb=s_emb)
            _, C_acts = list(zip(*[self.act_encoder(a, a_len, slot=s, slot_emb=s_emb) for a, a_len in acts]))
            _, C_vals = self.ont_encoder(ontology[s][0], ontology[s][1], slot=s, slot_emb=s_emb)

            # compute the utterance score
            q_utts = []
            for c_val in C_vals:
                c_val = c_val.squeeze(0)
                q_utt, _ = attend(H_utt, c_val.unsqueeze(0).expand(len(batch), *c_val.size()), lens=utterance_len)
                q_utts.append(q_utt)
            y_utts = self.utt_scorer(torch.stack(q_utts, dim=1)).squeeze(2)

            # compute the previous action score
            q_acts = []
            for i, C_act in enumerate(C_acts):
                q_act, _ = attend(C_act.unsqueeze(0), c_utt[i].unsqueeze(0), lens=[C_act.size(0)])
                q_acts.append(q_act)

            # squeeze only the singleton middle axis: a bare squeeze() would
            # also drop the batch axis (batch of one turn) or the value axis
            # (slot with a single value) and make mm() fail
            y_acts = torch.cat(q_acts, dim=0).squeeze(1).mm(C_vals.squeeze(1).transpose(0, 1))

            # combine the scores
            ys[s] = torch.sigmoid(y_utts + self.score_weight * y_acts)

        if self.training:
            # create label variable and compute loss
            labels = {s: [len(self.ontology.values[s]) * [0] for i in range(len(batch))] for s in self.ontology.slots}
            for i, e in enumerate(batch):
                for s, v in e.turn_label:
                    labels[s][i][self.ontology.values[s].index(v)] = 1
            labels = {s: torch.Tensor(m).to(self.device) for s, m in labels.items()}

            loss = 0
            for s in self.ontology.slots:
                loss += F.binary_cross_entropy(ys[s], labels[s])
        else:
            loss = torch.Tensor([0]).to(self.device)
        return loss, {s: v.data.tolist() for s, v in ys.items()}

    def get_train_logger(self):
        """Return a logger that appends to `train.log` inside `args.dout`."""
        logger = logging.getLogger('train-{}'.format(self.__class__.__name__))
        logger.setLevel(logging.INFO)
        formatter = logging.Formatter('%(asctime)s [%(threadName)-12.12s] [%(levelname)-5.5s] %(message)s')
        file_handler = logging.FileHandler(os.path.join(self.args.dout, 'train.log'))
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
        return logger

    def run_train(self, train, dev, args):
        """Train for `args.epoch` epochs, evaluating and checkpointing each one.

        After every epoch the model is evaluated on `train` and `dev`; when
        the dev metric named by `args.stop` is at least as good as the best
        seen so far a checkpoint is saved, old checkpoints are pruned and the
        dev predictions are written to `dev.pred.json`.
        """
        track = defaultdict(list)
        iteration = 0
        best = {}
        logger = self.get_train_logger()
        if self.optimizer is None:
            self.set_optimizer()

        for epoch in range(args.epoch):
            logger.info('starting epoch {}'.format(epoch))

            # train and update parameters
            self.train()
            for batch in train.batch(batch_size=args.batch_size, shuffle=True):
                iteration += 1
                self.zero_grad()
                loss, scores = self.forward(batch)
                loss.backward()
                self.optimizer.step()
                track['loss'].append(loss.item())

            # evaluate on train and dev
            summary = {'iteration': iteration, 'epoch': epoch}
            for k, v in track.items():
                summary[k] = sum(v) / len(v)
            summary.update({'eval_train_{}'.format(k): v for k, v in self.run_eval(train, args).items()})
            summary.update({'eval_dev_{}'.format(k): v for k, v in self.run_eval(dev, args).items()})

            # do early stopping saves
            stop_key = 'eval_dev_{}'.format(args.stop)
            train_key = 'eval_train_{}'.format(args.stop)
            if best.get(stop_key, 0) <= summary[stop_key]:
                best_dev = '{:f}'.format(summary[stop_key])
                best_train = '{:f}'.format(summary[train_key])
                best.update(summary)
                self.save(
                    best,
                    identifier='epoch={epoch},iter={iteration},train_{key}={train},dev_{key}={dev}'.format(
                        epoch=epoch, iteration=iteration, train=best_train, dev=best_dev, key=args.stop,
                    )
                )
                self.prune_saves()
                dev.record_preds(
                    preds=self.run_pred(dev, self.args),
                    to_file=os.path.join(self.args.dout, 'dev.pred.json'),
                )
            summary.update({'best_{}'.format(k): v for k, v in best.items()})
            logger.info(pformat(summary))
            track.clear()

    def extract_predictions(self, scores, threshold=0.5):
        """Turn per-value scores into one set of `(slot, value)` pairs per turn.

        For the slot named `'request'` every value scoring above `threshold`
        is kept; for any other slot only the highest scoring value above the
        threshold is kept.
        """
        batch_size = len(list(scores.values())[0])
        predictions = [set() for i in range(batch_size)]
        for s in self.ontology.slots:
            for i, p in enumerate(scores[s]):
                triggered = [(s, v, p_v) for v, p_v in zip(self.ontology.values[s], p) if p_v > threshold]
                if s == 'request':
                    # we can have multiple requests predictions
                    predictions[i] |= {(slot, v) for slot, v, p_v in triggered}
                elif triggered:
                    # only extract the top inform prediction (first one wins ties)
                    top = max(triggered, key=lambda tup: tup[-1])
                    predictions[i].add((top[0], top[1]))
        return predictions

    def run_pred(self, dev, args):
        """Return the extracted predictions for every turn of `dev`."""
        self.eval()
        predictions = []
        # no gradients are needed for inference; skipping them saves memory
        with torch.no_grad():
            for batch in dev.batch(batch_size=args.batch_size):
                loss, scores = self.forward(batch)
                predictions += self.extract_predictions(scores)
        return predictions

    def run_eval(self, dev, args):
        """Predict on `dev` and return `dev.evaluate_preds(...)`."""
        predictions = self.run_pred(dev, args)
        return dev.evaluate_preds(predictions)

    def save_config(self):
        """Write `vars(self.args)` to `config.json` inside `args.dout`."""
        fname = '{}/config.json'.format(self.args.dout)
        with open(fname, 'wt') as f:
            logging.info('saving config to {}'.format(fname))
            json.dump(vars(self.args), f, indent=2)

    @classmethod
    def load_config(cls, fname, ontology, **kwargs):
        """Build a model from a JSON config written by `save_config`.

        :param fname: path of the JSON config file.
        :param ontology: ontology handed to the constructor.
        :param kwargs: must contain `vocab` (handed to the constructor); any
            other keyword overrides the config entry of the same name.
        :raises TypeError: if `vocab` is not supplied.
        """
        from types import SimpleNamespace

        vocab = kwargs.pop('vocab', None)
        if vocab is None:
            raise TypeError("load_config() requires a 'vocab' keyword argument")
        with open(fname) as f:
            logging.info('loading config from {}'.format(fname))
            config = json.load(f)
        # a plain object() accepts no attributes, so a namespace is used
        args = SimpleNamespace(**{k: kwargs.get(k, v) for k, v in config.items()})
        return cls(args, ontology, vocab)

    def save(self, summary, identifier):
        """Save args, weights, summary and optimizer state to `<dout>/<identifier>.t7`."""
        fname = '{}/{}.t7'.format(self.args.dout, identifier)
        logging.info('saving model to {}'.format(fname))
        state = {
            'args': vars(self.args),
            'model': self.state_dict(),
            'summary': summary,
            'optimizer': self.optimizer.state_dict(),
        }
        torch.save(state, fname)

    def load(self, fname):
        """Restore weights and optimizer state from a checkpoint file."""
        logging.info('loading model from {}'.format(fname))
        # map_location lets a checkpoint written on a GPU load on a CPU host
        state = torch.load(fname, map_location=self.device)
        self.load_state_dict(state['model'])
        self.set_optimizer()
        self.optimizer.load_state_dict(state['optimizer'])

    def get_saves(self, directory=None):
        """Return `(score, path)` for each checkpoint, best score first.

        The score is parsed from the `dev_<args.stop>=<number>` part of the
        file name.

        :raises Exception: if no checkpoint with a parsable score is found.
        """
        if directory is None:
            directory = self.args.dout
        # escape the metric name so regex metacharacters in it are literal
        pattern = re.compile(r'dev_{}=([0-9\.]+)'.format(re.escape(self.args.stop)))
        scores = []
        for fname in os.listdir(directory):
            if not fname.endswith('.t7'):
                continue
            dev_acc = pattern.findall(fname)
            if dev_acc:
                # the capture also swallows the dot of the '.t7' extension
                score = float(dev_acc[0].strip('.'))
                scores.append((score, os.path.join(directory, fname)))
        if not scores:
            raise Exception('No files found!')
        scores.sort(key=lambda tup: tup[0], reverse=True)
        return scores

    def prune_saves(self, n_keep=5):
        """Delete all but the `n_keep` best checkpoints in `args.dout`."""
        scores_and_files = self.get_saves()
        if len(scores_and_files) > n_keep:
            for score, fname in scores_and_files[n_keep:]:
                os.remove(fname)

    def load_best_save(self, directory):
        """Load the best checkpoint from `directory` (default `args.dout`)."""
        if directory is None:
            directory = self.args.dout

        # get_saves raises when nothing is found, so the list is never empty
        scores_and_files = self.get_saves(directory=directory)
        score, fname = scores_and_files[0]
        self.load(fname)
359 |
--------------------------------------------------------------------------------
/create_data.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import copy
3 | import json
4 | import os
5 | import re
6 | import shutil
7 | import urllib
8 | from collections import OrderedDict
9 | from io import BytesIO
10 | from zipfile import ZipFile
11 | import difflib
12 | import numpy as np
13 |
np.set_printoptions(precision=3)

np.random.seed(2)


'''
Most of the codes are from https://github.com/budzianowski/multiwoz
'''


# GLOBAL VARIABLES
DICT_SIZE = 400
MAX_LENGTH = 50
IGNORE_KEYS_IN_GOAL = ['eod', 'topic', 'messageLen', 'message']

# Token replacement pairs read from a tab-separated file, one pair per line.
# Both sides are wrapped in spaces so that normalize() only replaces whole,
# space-delimited tokens.  open() inside a with-block works on Python 2 and 3
# and closes the file (the former file() call did neither).
replacements = []
with open('utils/mapping.pair') as fin:
    for line in fin:
        tok_from, tok_to = line.replace('\n', '').split('\t')
        replacements.append((' ' + tok_from + ' ', ' ' + tok_to + ' '))
34 |
35 |
def is_ascii(s):
    """Return True when every character of ``s`` has a code point below 128."""
    for ch in s:
        if ord(ch) >= 128:
            return False
    return True
38 |
def insertSpace(token, text):
    """Return ``text`` with spaces inserted around occurrences of ``token``.

    The text is scanned left to right with ``str.find``.  An occurrence that
    has a digit on both neighbouring positions is left untouched.
    """
    sidx = 0
    while True:
        sidx = text.find(token, sidx)
        if sidx == -1:
            break
        # digit on both sides: skip this occurrence
        # NOTE(review): for sidx == 0, text[sidx - 1] is the LAST character of
        # the text (negative index) - confirm this is intended.
        if sidx + 1 < len(text) and re.match('[0-9]', text[sidx - 1]) and \
                re.match('[0-9]', text[sidx + 1]):
            sidx += 1
            continue
        # insert a space before the token unless one is already there
        if text[sidx - 1] != ' ':
            text = text[:sidx] + ' ' + text[sidx:]
            sidx += 1
        # insert a space one character after sidx when the character at
        # sidx + len(token) is not a space
        # NOTE(review): the insertion point is sidx + 1, not sidx + len(token),
        # so a multi-character token is split after its first character -
        # confirm.
        if sidx + len(token) < len(text) and text[sidx + len(token)] != ' ':
            text = text[:sidx + 1] + ' ' + text[sidx + 1:]
        sidx += 1
    return text
56 |
# Patterns for the time / price placeholders applied when clean_value is true.
# normalize() previously referenced the undefined names `timepat` and
# `pricepat`, so its default clean_value=True raised NameError.
# NOTE(review): patterns taken from the MultiWOZ preprocessing script this
# file says it derives from - confirm they are the intended ones.
_TIME_PATTERN = re.compile(r"\d{1,2}[:]\d{1,2}")
_PRICE_PATTERN = re.compile(r"\d{1,3}[.]\d{1,2}")


def normalize(text, clean_value=True):
    """Lower-case and tokenize ``text`` with spaces.

    :param text: raw utterance or slot value.
    :param clean_value: when true, phone numbers and postcodes are compacted
        and times / prices are replaced by ``[value_time]`` and
        ``[value_price]``.
    :return: the normalized string.
    """
    # lower case every word
    text = text.lower()

    # replace white spaces in front and end
    text = re.sub(r'^\s*|\s*$', '', text)

    # hotel domain pfb30
    text = re.sub(r"b&b", "bed and breakfast", text)
    text = re.sub(r"b and b", "bed and breakfast", text)

    if clean_value:
        # normalize phone number
        ms = re.findall(r'\(?(\d{3})\)?[-.\s]?(\d{3})[-.\s]?(\d{4,5})', text)
        if ms:
            sidx = 0
            for m in ms:
                sidx = text.find(m[0], sidx)
                if text[sidx - 1] == '(':
                    sidx -= 1
                eidx = text.find(m[-1], sidx) + len(m[-1])
                text = text.replace(text[sidx:eidx], ''.join(m))

        # normalize postcode
        ms = re.findall(r'([a-z]{1}[\. ]?[a-z]{1}[\. ]?\d{1,2}[, ]+\d{1}[\. ]?[a-z]{1}[\. ]?[a-z]{1}|[a-z]{2}\d{2}[a-z]{2})',
                        text)
        if ms:
            sidx = 0
            for m in ms:
                sidx = text.find(m, sidx)
                eidx = sidx + len(m)
                text = text[:sidx] + re.sub(r'[,\. ]', '', m) + text[eidx:]

    # weird unicode bug: map curly single quotes to a plain apostrophe
    text = re.sub(u"(\u2018|\u2019)", "'", text)

    if clean_value:
        # replace time and price
        text = re.sub(_TIME_PATTERN, ' [value_time] ', text)
        text = re.sub(_PRICE_PATTERN, ' [value_price] ', text)

    # replace st.
    text = text.replace(';', ',')
    text = re.sub(r'$\/', '', text)
    text = text.replace('/', ' and ')

    # replace other special characters
    text = text.replace('-', ' ')
    text = re.sub(r'["\<>@\(\)]', '', text)  # remove

    # insert white space before and after tokens:
    for token in ['?', '.', ',', '!']:
        text = insertSpace(token, text)

    # insert white space for 's
    text = insertSpace("'s", text)

    # replace it's, does't, you'd ... etc
    text = re.sub(r"^'", '', text)
    text = re.sub(r"'$", '', text)
    text = re.sub(r"'\s", ' ', text)
    text = re.sub(r"\s'", ' ', text)
    for fromx, tox in replacements:
        # pad with spaces so tokens at either end of the text can match too
        text = ' ' + text + ' '
        text = text.replace(fromx, tox)[1:-1]

    # remove multiple spaces
    text = re.sub(' +', ' ', text)

    # concatenate numbers: merge a digit-only token into a preceding token
    # that starts with digits and runs to its end
    tokens = text.split()
    i = 1
    while i < len(tokens):
        if re.match(r'^\d+$', tokens[i]) and \
                re.match(r'\d+$', tokens[i - 1]):
            tokens[i - 1] += tokens[i]
            del tokens[i]
        else:
            i += 1
    text = ' '.join(tokens)

    return text
141 |
def fixDelex(filename, data, data2, idx, idx_acts):
    """Given system dialogue acts fix automatic delexicalization.

    When the acts of turn ``idx_acts`` name a domain (Attraction, Hotel or
    Restaurant) but the text of ``data['log'][idx]`` contains a placeholder
    prefix of one of the other two domains, that other domain name is
    replaced by the annotated one throughout the text.

    :param filename: dialogue file name; a trailing ``.json`` is removed to
        obtain the key into ``data2``.
    :param data: dialogue dict whose ``log[idx]['text']`` may be rewritten.
    :param data2: dialogue-act annotations, ``{name: {turn_number: acts}}``.
    :param idx: index of the turn inside ``data['log']``.
    :param idx_acts: annotation turn number (converted with ``str``).
    :return: ``data`` (modified in place); unchanged when no annotation exists.
    """
    suffix = '.json'
    # str.strip('.json') strips a character SET from both ends and can eat
    # parts of the name itself; remove the suffix explicitly instead
    name = filename[:-len(suffix)] if filename.endswith(suffix) else filename
    try:
        turn = data2[name][str(idx_acts)]
    except (LookupError, TypeError):
        return data

    try:
        text_types = (str, unicode)  # Python 2  # noqa: F821
    except NameError:
        text_types = (str,)  # Python 3
    # a plain string entry carries no act dictionary
    if isinstance(turn, text_types):
        return data

    # (marker in the act key, replacement, wrong domains in checking order)
    rules = (
        ('Attraction', 'attraction', ('restaurant', 'hotel')),
        ('Hotel', 'hotel', ('attraction', 'restaurant')),
        ('Restaurant', 'restaurant', ('attraction', 'hotel')),
    )
    entry = data['log'][idx]
    for k in turn:
        for marker, right, wrong_domains in rules:
            if marker not in k:
                continue
            for wrong in wrong_domains:
                if wrong + '_' in entry['text']:
                    entry['text'] = entry['text'].replace(wrong, right)

    return data
168 |
169 |
def getDialogueAct(filename, data, data2, idx, idx_acts):
    """Collect the request / inform system acts annotated for one turn.

    Act keys are split on ``-``; the second part selects the act type.  For
    ``request`` acts the lower-cased slot name is appended, for ``inform``
    acts a ``[slot, normalized value]`` pair is appended.  Other act types
    are ignored.

    :param filename: dialogue file name; a trailing ``.json`` is removed to
        obtain the key into ``data2``.
    :param data: unused.
    :param data2: dialogue-act annotations, ``{name: {turn_number: acts}}``.
    :param idx: unused.
    :param idx_acts: annotation turn number (converted with ``str``).
    :return: list of acts; empty when the turn has no annotation.
    """
    acts = []
    suffix = '.json'
    # str.strip('.json') strips a character SET from both ends and can eat
    # parts of the name itself; remove the suffix explicitly instead
    name = filename[:-len(suffix)] if filename.endswith(suffix) else filename
    try:
        turn = data2[name][str(idx_acts)]
    except (LookupError, TypeError):
        return acts

    try:
        text_types = (str, unicode)  # Python 2  # noqa: F821
    except NameError:
        text_types = (str,)  # Python 3
    # a plain string entry carries no act dictionary
    if isinstance(turn, text_types):
        return acts

    for k in turn.keys():
        act_type = k.split('-')[1].lower()
        if act_type == 'request':
            for a in turn[k]:
                acts.append(a[0].lower())
        elif act_type == 'inform':
            for a in turn[k]:
                acts.append([a[0].lower(), normalize(a[1].lower(), False)])

    return acts
192 |
193 |
def get_summary_bstate(bstate, get_domain=False):
    """Based on the mturk annotations we form multi-domain belief state

    For each of seven fixed domains the function appends to a flat 0/1 list:
    one flag per booking slot, three flags per 'semi' slot and one
    domain-active flag.  In parallel it collects ``[slot name, value]``
    pairs for filled booking slots and for 'semi' slots that are either
    "dontcare" or filled.

    :param bstate: dict with, per domain, a ``'book'`` dict and a ``'semi'``
        dict.
    :param get_domain: when true, return only the list of active domains.
    :return: ``active_domain`` if ``get_domain`` else
        ``(summary_bstate, summary_bvalue)``.
    :raises AssertionError: if the flat list does not have 94 entries.
    """
    domains = [u'taxi',u'restaurant', u'hospital', u'hotel',u'attraction', u'train', u'police']
    summary_bstate = []
    summary_bvalue = []
    active_domain = []
    for domain in domains:
        domain_active = False

        booking = []
        #print(domain,len(bstate[domain]['book'].keys()))
        # booking slots are visited in sorted order so the layout is stable
        for slot in sorted(bstate[domain]['book'].keys()):
            if slot == 'booked':
                # 'booked' is flagged by whether it is non-empty; it adds no value pair
                if len(bstate[domain]['book']['booked'])!=0:
                    booking.append(1)
                    # summary_bvalue.append("book {} {}:{}".format(domain, slot, "Yes"))
                else:
                    booking.append(0)
            else:
                if bstate[domain]['book'][slot] != "":
                    booking.append(1)
                    summary_bvalue.append(["{}-book {}".format(domain, slot.strip().lower()), normalize(bstate[domain]['book'][slot].strip().lower(), False)]) #(["book", domain, slot, bstate[domain]['book'][slot]])
                else:
                    booking.append(0)
        # the train domain gets a 0 flag for 'people' / 'ticket' when the key is absent
        if domain == 'train':
            if 'people' not in bstate[domain]['book'].keys():
                booking.append(0)
            if 'ticket' not in bstate[domain]['book'].keys():
                booking.append(0)
        summary_bstate += booking

        for slot in bstate[domain]['semi']:
            slot_enc = [0, 0, 0] # not mentioned, dontcare, filled
            if bstate[domain]['semi'][slot] == 'not mentioned':
                slot_enc[0] = 1
            elif bstate[domain]['semi'][slot] in ['dont care', 'dontcare', "don't care", "do not care"]:
                slot_enc[1] = 1
                summary_bvalue.append(["{}-{}".format(domain, slot.strip().lower()), "dontcare"]) #(["semi", domain, slot, "dontcare"])
            elif bstate[domain]['semi'][slot]:
                # NOTE(review): this branch records the value but never sets
                # slot_enc[2], so a filled slot alone does not mark the domain
                # active while 'not mentioned' does - confirm this is intended.
                summary_bvalue.append(["{}-{}".format(domain, slot.strip().lower()), normalize(bstate[domain]['semi'][slot].strip().lower(), False)]) #(["semi", domain, slot, bstate[domain]['semi'][slot]])
            if slot_enc != [0, 0, 0]:
                domain_active = True
            summary_bstate += slot_enc

        # quasi domain-tracker
        if domain_active:
            summary_bstate += [1]
            active_domain.append(domain)
        else:
            summary_bstate += [0]

    #print(len(summary_bstate))
    # the flat encoding must have exactly 94 entries
    assert len(summary_bstate) == 94
    if get_domain:
        return active_domain
    else:
        return summary_bstate, summary_bvalue
251 |
252 |
def analyze_dialogue(dialogue, maxlen):
    """Cleaning procedure for all kinds of errors in text and annotation.

    A dialogue is rejected (``None`` is returned and a short reason is
    printed) when it has an odd number of turns, when a turn has more than
    ``maxlen`` whitespace-separated tokens or when a turn contains non-ASCII
    characters.  Even-indexed turns are user turns, odd-indexed turns are
    system turns; every system turn is annotated in place with
    ``belief_summary`` and ``belief_value_summary``.

    :param dialogue: dict with the keys ``'log'`` and ``'goal'``.
    :param maxlen: maximum number of tokens allowed per turn.
    :return: dict with ``'goal'``, ``'usr_log'`` and ``'sys_log'``, or ``None``.
    """
    d = dialogue
    # do all the necessary postprocessing
    if len(d['log']) % 2 != 0:
        # print() with one argument behaves the same on Python 2 and 3
        print('odd # of turns')
        return None  # odd number of turns, wrong dialogue
    d_pp = {'goal': d['goal']}  # for now we just copy the goal
    usr_turns = []
    sys_turns = []
    for i, turn in enumerate(d['log']):
        text = turn['text']
        if len(text.split()) > maxlen:
            print('too long')
            return None  # too long sentence, wrong dialogue
        if not is_ascii(text):
            print('not ascii')
            return None
        if i % 2 == 0:  # usr turn
            usr_turns.append(turn)
        else:  # sys turn
            belief_summary, belief_value_summary = get_summary_bstate(turn['metadata'])
            turn['belief_summary'] = str(belief_summary)
            turn['belief_value_summary'] = belief_value_summary
            sys_turns.append(turn)
    d_pp['usr_log'] = usr_turns
    d_pp['sys_log'] = sys_turns

    return d_pp
289 |
290 |
def get_dial(dialogue):
    """Extract a dialogue from the file

    Returns ``None`` when ``analyze_dialogue`` rejects the dialogue, otherwise
    a list with one dict per user/system turn pair holding the keys ``usr``,
    ``sys``, ``sys_a``, ``domain`` and ``bvs``.
    """
    d_orig = analyze_dialogue(dialogue, MAX_LENGTH)  # max turn len is 50 words
    if d_orig is None:
        return None

    usr_log = d_orig['usr_log']
    sys_log = d_orig['sys_log']
    user_texts = [t['text'] for t in usr_log]
    system_texts = [t['text'] for t in sys_log]
    system_acts = [t['dialogue_acts'] for t in sys_log]
    belief_values = [t['belief_value_summary'] for t in sys_log]
    domains = [t['domain'] for t in usr_log]

    return [
        {'usr': usr, 'sys': sys_text, 'sys_a': acts, 'domain': dom, 'bvs': bvs}
        for usr, sys_text, acts, dom, bvs
        in zip(user_texts, system_texts, system_acts, domains, belief_values)
    ]
305 |
306 |
def loadData():
    """Download and unpack the MultiWOZ archive unless the data is present.

    When ``data/multi-woz/data.json`` does not exist the archive is fetched,
    extracted into ``data/multi-woz`` and four JSON files are copied from the
    extracted ``MULTIWOZ2 2`` folder into ``data/multi-woz``.
    """
    # urlopen lives in different modules on Python 2 and 3
    try:
        from urllib.request import urlopen  # Python 3
    except ImportError:
        from urllib import urlopen  # Python 2

    target_dir = "data/multi-woz"
    data_url = "data/multi-woz/data.json"
    dataset_url = "https://www.repository.cam.ac.uk/bitstream/handle/1810/280608/MULTIWOZ2.zip?sequence=3&isAllowed=y"
    # create the target even when "data" already exists without it
    # (makedirs also creates the intermediate "data" directory)
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)

    if not os.path.exists(data_url):
        print("Downloading and unzipping the MultiWOZ dataset")
        resp = urlopen(dataset_url)
        try:
            payload = resp.read()
        finally:
            resp.close()
        with ZipFile(BytesIO(payload)) as zip_ref:
            zip_ref.extractall(target_dir)
        for name in ('data.json', 'valListFile.json', 'testListFile.json', 'dialogue_acts.json'):
            shutil.copy('data/multi-woz/MULTIWOZ2 2/' + name, 'data/multi-woz/')
324 |
325 |
def getDomain(idx, log, domains, last_domain):
    """Return the domain assigned to the turn at position ``idx`` of ``log``.

    For ``idx == 1`` the first active domain of that turn's belief state is
    used, falling back to ``domains[0]``.  For later turns the metadata is
    compared with that of turn ``idx - 2``: the first changed domain is
    returned, or the first element of ``last_domain`` when nothing changed.

    :param idx: index of a turn in ``log``.
    :param log: list of turns, each with a ``"metadata"`` entry.
    :param domains: candidate domains of the dialogue (fallback for idx == 1).
    :param last_domain: indexable; its first element is the previous domain.
    """
    if idx == 1:
        active_domains = get_summary_bstate(log[idx]["metadata"], True)
        return active_domains[0] if active_domains else domains[0]

    ds_diff = get_ds_diff(log[idx - 2]["metadata"], log[idx]["metadata"])
    if not ds_diff:  # no clues from dialog states
        crnt_doms = last_domain
    else:
        # dict.keys() is a view on Python 3 and cannot be indexed directly
        crnt_doms = list(ds_diff.keys())
    return crnt_doms[0]  # How about multiple domains in one sentence scenario ?
338 |
339 |
def get_ds_diff(prev_d, crnt_d):
    """Return the entries of ``crnt_d`` whose value differs from ``prev_d``.

    Both dicts are walked in parallel in their own iteration order and the
    keys met at each step are asserted to be equal.  An empty dict is
    returned when either argument is empty.
    """
    # Sometimes, metadata is an empty dictionary, bug?
    if not (prev_d and crnt_d):
        return {}

    changed = {}
    for prev_item, crnt_item in zip(prev_d.items(), crnt_d.items()):
        prev_key, prev_val = prev_item
        crnt_key, crnt_val = crnt_item
        assert prev_key == crnt_key
        if prev_val != crnt_val:  # updated
            changed[crnt_key] = crnt_val
    return changed
351 |
352 |
def createData():
    """Load MultiWOZ, normalize every turn and attach domains and system acts.

    :return: dict mapping a dialogue name to its processed dialogue.
    """
    # download the data
    loadData()

    delex_data = {}

    # with-blocks close the files; open() works on Python 2 and 3
    with open('data/multi-woz/data.json') as fin1:
        data = json.load(fin1)

    with open('data/multi-woz/dialogue_acts.json') as fin2:
        data2 = json.load(fin2)

    for dialogue_name in data:

        dialogue = data[dialogue_name]

        domains = []
        for dom_k, dom_v in dialogue['goal'].items():
            if dom_v and dom_k not in IGNORE_KEYS_IN_GOAL:  # check whether contains some goal entities
                domains.append(dom_k)

        idx_acts = 1
        last_domain = ""
        for idx, turn in enumerate(dialogue['log']):
            # normalization, split and delexicalization of the sentence
            origin_text = normalize(turn['text'], False)
            dialogue['log'][idx]['text'] = origin_text

            if idx % 2 == 1:  # if it's a system turn

                cur_domain = getDomain(idx, dialogue['log'], domains, last_domain)
                last_domain = [cur_domain]

                # the domain is stored on the preceding user turn
                dialogue['log'][idx - 1]['domain'] = cur_domain
                dialogue['log'][idx]['dialogue_acts'] = getDialogueAct(dialogue_name, dialogue, data2, idx, idx_acts)
                idx_acts += 1

            # FIXING delexicalization:
            # NOTE(review): on system turns idx_acts has already been
            # incremented when it reaches fixDelex - confirm this is intended.
            dialogue = fixDelex(dialogue_name, dialogue, data2, idx, idx_acts)

        delex_data[dialogue_name] = dialogue

    return delex_data
405 |
406 |
def buildDelexDict(origin_sent, delex_sent):
    """Map positions in the delexicalized sentence to the original words.

    Both sentences are split on whitespace and aligned with
    ``difflib.SequenceMatcher``.  For every matching block except the last
    two, the key is the token index in ``delex_sent`` right after the block
    and the value is the run of original tokens between this block and the
    next one, joined by spaces.
    """
    origin_tokens = origin_sent.split()
    matcher = difflib.SequenceMatcher(None, delex_sent.split(), origin_tokens)
    blocks = matcher.get_matching_blocks()

    mapping = {}
    # pair each block with its successor; the last two blocks are not used as
    # a starting point
    for current, following in zip(blocks[:-2], blocks[1:]):
        delex_pos = current.a + current.size
        origin_from = current.b + current.size
        origin_to = following.b
        mapping[delex_pos] = " ".join(origin_tokens[origin_from:origin_to])
    return mapping
418 |
419 |
def divideData(data):
    """Given test and validation sets, divide
    the data for three different sets

    Dialogues named in ``testListFile.json`` / ``valListFile.json`` go to the
    test / dev split, all others to the train split (their names are also
    written to ``data/trainListFile``).  Dialogues for which ``get_dial``
    returns nothing are dropped.  The splits are written to
    ``data/{dev,test,train}_dials.json``.

    :param data: dict mapping a dialogue name to its processed dialogue.
    """
    def read_names(path):
        # one dialogue name per line; only the line terminator is removed, so
        # a final line without newline keeps its last character
        with open(path) as fin:
            return set(line.rstrip('\n') for line in fin)

    # sets give constant-time membership tests in the loop below
    test_names = read_names('data/multi-woz/testListFile.json')
    val_names = read_names('data/multi-woz/valListFile.json')

    test_dials = []
    val_dials = []
    train_dials = []

    count_train, count_val, count_test = 0, 0, 0

    with open('data/trainListFile', 'w') as trainListFile:
        for dialogue_name in data:
            dial_item = data[dialogue_name]
            domains = []
            for dom_k, dom_v in dial_item['goal'].items():
                if dom_v and dom_k not in IGNORE_KEYS_IN_GOAL:  # check whether contains some goal entities
                    domains.append(dom_k)

            dial = get_dial(dial_item)
            if not dial:
                continue

            dialogue = {}
            dialogue['dialogue_idx'] = dialogue_name
            dialogue['domains'] = list(set(domains))
            last_bs = []
            dialogue['dialogue'] = []

            for turn_i, turn in enumerate(dial):
                turn_dialog = {}
                # the system side of a turn is what the system said before it
                turn_dialog['system_transcript'] = dial[turn_i-1]['sys'] if turn_i > 0 else ""
                turn_dialog['turn_idx'] = turn_i
                turn_dialog['belief_state'] = [{"slots": [s], "act": "inform"} for s in turn['bvs']]
                # the turn label holds only what is new compared to the previous turn
                turn_dialog['turn_label'] = [bs["slots"][0] for bs in turn_dialog['belief_state'] if bs not in last_bs]
                turn_dialog['transcript'] = turn['usr']
                turn_dialog['system_acts'] = dial[turn_i-1]['sys_a'] if turn_i > 0 else []
                turn_dialog['domain'] = turn['domain']
                last_bs = turn_dialog['belief_state']
                dialogue['dialogue'].append(turn_dialog)

            if dialogue_name in test_names:
                test_dials.append(dialogue)
                count_test += 1
            elif dialogue_name in val_names:
                val_dials.append(dialogue)
                count_val += 1
            else:
                trainListFile.write(dialogue_name + '\n')
                train_dials.append(dialogue)
                count_train += 1

    print("# of dialogues: Train {}, Val {}, Test {}".format(count_train, count_val, count_test))

    # save all dialogues; text mode because json.dump writes str, which a
    # binary-mode file rejects on Python 3
    with open('data/dev_dials.json', 'w') as f:
        json.dump(val_dials, f, indent=4)

    with open('data/test_dials.json', 'w') as f:
        json.dump(test_dials, f, indent=4)

    with open('data/train_dials.json', 'w') as f:
        json.dump(train_dials, f, indent=4)
498 |
499 | # return word_freqs_usr, word_freqs_sys
500 |
501 |
def main():
    """Create the WOZ-like dialogues and write the train/dev/test splits."""
    print('Create WOZ-like dialogues. Get yourself a coffee, this might take a while.')
    dialogues = createData()
    print('Divide dialogues...')
    divideData(dialogues)


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 |
124 |
125 |
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 |
135 |
136 |
137 |
138 |
139 |
140 |
141 |
142 |
143 |
144 |
145 |
146 |
147 |
148 |
149 |
150 |
151 |
152 |
153 |
154 |
155 |
156 |
157 |
158 |
159 |
160 |
161 |
162 |
163 |
164 |
165 |
166 |
167 |
168 |
169 |
170 |
171 |
172 |
173 |
174 |
175 |
176 |
177 |
178 |
179 |
180 |
181 |
182 |
183 |
184 |
185 |
186 |
187 |
188 |
189 |
190 |
191 |
192 |
193 |
194 |
195 |
196 |
197 |
198 |
199 |
200 |
201 |
202 |
203 |
204 |
205 |
206 |
207 |
208 |
209 |
210 |
211 |
212 |
213 |
214 |
215 |
216 |
217 |
218 |
219 |
220 |
221 |
222 |
223 |
224 |
225 |
226 | annotate
227 | client
228 | Turn
229 | turns
230 | reque
231 | vocab
232 | GLADEncoder_global_no_rnn_conditioned_v5
233 | pad
234 | FixedEmbedding
235 | GLAD_ENCODERS
236 | encoder
237 | position_encoding_init
238 | load_model
239 |
240 |
241 |
242 |
243 |
244 |
245 |
259 |
260 |
261 |
262 |
263 |
264 |
265 |
266 |
267 |
268 |
269 |
270 |
271 |
272 |
273 |
274 |
275 |
276 |
277 |
278 |
279 |
280 |
281 |
282 |
283 |
284 |
285 |
286 |
287 |
288 |
289 |
290 |
291 |
292 |
293 |
294 |
295 |
296 |
297 |
298 |
299 |
300 |
301 |
302 |
303 |
304 |
305 |
306 |
307 |
308 |
309 |
310 |
311 |
312 |
313 |
314 |
315 |
316 |
317 |
318 |
319 |
320 |
321 |
322 |
323 |
324 |
325 |
326 |
327 |
328 | 1531526790091
329 |
330 |
331 | 1531526790091
332 |
333 |
334 |
335 |
336 |
337 |
338 |
339 |
340 |
341 |
342 |
343 |
344 |
345 |
346 |
347 |
348 |
349 |
350 |
351 |
352 |
353 |
354 |
355 |
356 |
357 |
358 |
359 |
360 |
361 |
362 |
363 |
364 |
365 |
366 |
367 |
368 |
369 |
370 |
371 |
372 |
373 |
374 |
375 |
376 |
377 |
378 |
379 |
380 |
381 |
382 |
383 |
384 |
385 |
386 |
387 |
388 |
389 |
390 |
391 |
392 |
393 |
394 |
395 |
396 |
397 |
398 |
399 |
400 |
401 |
402 |
403 |
404 |
405 |
406 |
407 |
408 |
409 |
410 |
411 |
412 |
413 |
414 |
415 |
416 |
417 |
418 |
419 |
420 |
421 |
422 |
423 |
424 |
425 |
426 |
427 |
428 |
429 |
430 |
431 |
432 |
433 |
434 |
435 |
436 |
437 |
438 |
439 |
440 |
441 |
442 |
443 |
444 |
445 |
446 |
447 |
448 |
449 |
450 |
451 |
452 |
453 |
454 |
455 |
456 |
457 |
458 |
459 |
460 |
461 |
462 |
463 |
464 |
465 |
466 |
467 |
468 |
469 |
470 |
471 |
472 |
473 |
474 |
475 |
476 |
477 |
478 |
479 |
480 |
481 |
482 |
483 |
484 |
485 |
486 |
487 |
488 |
489 |
490 |
491 |
492 |
493 |
494 |
495 |
496 |
497 |
498 |
499 |
500 |
501 |
502 |
503 |
504 |
505 |
506 |
507 |
508 |
509 |
510 |
511 |
512 |
513 |
514 |
515 |
516 |
517 |
518 |
519 |
520 |
521 |
522 |
523 |
524 |
525 |
526 |
527 |
528 |
529 |
530 |
531 |
532 |
533 |
534 |
535 |
536 |
537 |
538 |
539 |
540 |
541 |
542 |
543 |
544 |
545 |
546 |
547 |
548 |
549 |
550 |
551 |
552 |
553 |
554 |
555 |
556 |
557 |
558 |
559 |
560 |
561 |
562 |
563 |
564 |
565 |
566 |
567 |
568 |
569 |
570 |
571 |
572 |
573 |
574 |
575 |
576 |
577 |
578 |
579 |
580 |
581 |
582 |
583 |
584 |
585 |
586 |
587 |
588 |
589 |
590 |
591 |
592 |
593 |
594 |
595 |
596 |
597 |
598 |
599 |
600 |
601 |
602 |
603 |
604 |
605 |
606 |
607 |
608 |
609 |
610 |
611 |
612 |
613 |
614 |
615 |
616 |
617 |
618 |
619 |
620 |
621 |
622 |
623 |
624 |
625 |
626 |
627 |
628 |
629 |
630 |
631 |
632 |
633 |
634 |
635 |
636 |
637 |
638 |
639 |
640 |
641 |
642 |
643 |
644 |
645 |
646 |
647 |
648 |
649 |
650 |
651 |
652 |
653 |
654 |
655 |
656 |
657 |
658 |
659 |
660 |
661 |
662 |
663 |
664 |
665 |
666 |
667 |
668 |
669 |
670 |
671 |
672 |
673 |
674 |
675 |
676 |
677 |
678 |
679 |
680 |
681 |
682 |
683 |
684 |
685 |
686 |
687 |
688 |
689 |
690 |
691 |
692 |
693 |
694 |
695 |
696 |
697 |
698 |
699 |
700 |
701 |
702 |
703 |
704 |
705 |
706 |
707 |
708 |
709 |
710 |
711 |
712 |
713 |
714 |
715 |
--------------------------------------------------------------------------------