├── .gitignore ├── LICENSE ├── README.md ├── beam_search.py ├── config.py ├── data_utils.py ├── decoders.py ├── decorators.py ├── encoders.py ├── eval.py ├── eval_f1_acc.py ├── eval_sts.py ├── eval_utils.py ├── generate.py ├── ive.py ├── model_utils.py ├── models.py ├── run-vgvae.sh ├── train.py ├── train_helper.py ├── tree.py └── von_mises_fisher.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Mingda Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # syntactic-template-generation 2 | A PyTorch implementation of [Controllable Paraphrase Generation with a Syntactic Exemplar](https://ttic.uchicago.edu/~mchen/papers/mchen+etal.acl19.pdf) 3 | 4 | ## Requirements 5 | 6 | - Python 3.5 7 | - PyTorch >= 1.0 8 | - NLTK 9 | - [tqdm](https://github.com/tqdm/tqdm) 10 | - [py-rouge](https://github.com/Diego999/py-rouge) 11 | - [zss](https://github.com/timtadh/zhang-shasha) 12 | 13 | ## Resource 14 | 15 | - [data and tags](https://drive.google.com/open?id=1HHDlUT_-WpedL6zNYpcN94cLwed_yyrP) 16 | - [evaluation (including multi-bleu, METEOR and a copy of Stanford CoreNLP)](https://drive.google.com/drive/folders/1FJjvMldeZrJnQd-iVXJ3KGFBLEvsndNY?usp=sharing) 17 | - [syntactic evaluation](https://drive.google.com/drive/folders/1oVjn_3xIDZbkRm50fSHDZ5nKZtJ_BFyD?usp=sharing) 18 | - [pretrained model (VGVAE+LC+WN+WPL)](https://drive.google.com/drive/folders/13pii_XG-szMG2KNSuyDn7iPFDyhnXjXm?usp=sharing) 19 | 20 | ``run-vgvae.sh`` is provided as an example for training new models. 21 | 22 | ## Generation 23 | 24 | #### Generate sentences using beam search (and evaluation) 25 | 26 | ``python generate.py -s PATH_TO_MODEL_PICKLE -v VOCAB_PICKLE -i SYNTACTIC_SEMANTIC_TEMPLATES -r REFERENCE_FILE -bs BEAM_SIZE`` 27 | 28 | The argument ``-r`` is optional. 
class Node:
    """One step of a beam-search hypothesis, linked back to its parent.

    Args:
        parent: the preceding ``Node``, or ``None`` for the root.
        state: decoder state returned together with this token; stored
            unchanged and handed back to the generate function.
        value: token id emitted at this step.
        cost: incremental cost of this step, e.g. ``-log p(value)``.
        extras: optional per-step payload (for example attention weights).
            Defaults to ``None`` so existing four-argument callers still work.
    """

    def __init__(self, parent, state, value, cost, extras=None):
        self.value = value
        self.parent = parent  # parent Node, None for root
        self.state = state  # recurrent layer hidden state
        # Cumulative cost of the sequence up to and including this node.
        self.cum_cost = cost if parent is None else parent.cum_cost + cost
        self.length = 1 if parent is None else parent.length + 1
        # Previously never assigned, which made to_sequence_of_extras()
        # raise AttributeError.
        self.extras = extras
        self._sequence = None  # cache for to_sequence()

    def to_sequence(self):
        """Return the list of nodes from the root to this node (cached)."""
        if self._sequence is None:
            path = []
            node = self
            while node is not None:
                path.append(node)
                node = node.parent
            # Collected leaf-to-root; reverse once instead of insert(0, ...)
            # on every step, which is quadratic in the sequence length.
            path.reverse()
            self._sequence = path
        return self._sequence

    def to_sequence_of_values(self):
        """Return the token ids from the root to this node."""
        return [node.value for node in self.to_sequence()]

    def to_sequence_of_extras(self):
        """Return the ``extras`` payloads from the root to this node."""
        return [node.extras for node in self.to_sequence()]


def beam_search(initial_state, generate_function, start_id, end_id,
                beam_width=4, num_hypotheses=1, max_length=50):
    """Run beam search and return the cheapest finished hypotheses.

    Args:
        initial_state: state passed to ``generate_function`` for the root.
        generate_function: callable ``(state, value) -> (next_state, costs)``
            where ``costs`` is a NumPy array holding one cost per candidate
            token id (lower is better).
        start_id: token id stored in the root node.
        end_id: token id that terminates a hypothesis.
        beam_width: number of partial hypotheses kept after each step, and
            number of successors expanded per hypothesis.
        num_hypotheses: the search stops once this many hypotheses have
            finished; at most this many are returned.
        max_length: maximum number of nodes (including the root) in a
            hypothesis.

    Returns:
        A list of ``[values, cum_cost]`` pairs sorted by ascending cost.
        Hypotheses that terminate with fewer than 3 nodes are discarded.
    """
    prev_history = [
        Node(parent=None, state=initial_state, value=start_id, cost=0.0)]
    hypotheses = []

    for _ in range(max_length):
        # Split the beam into finished and still-open hypotheses.
        history = []
        for n in prev_history:
            if n.value == end_id or n.length == max_length:
                if n.length >= 3:
                    hypotheses.append(n)
            else:
                history.append(n)

        if not history or len(hypotheses) >= num_hypotheses:
            break

        candidates = []
        for n in history:
            next_state, costs = generate_function(n.state, n.value)
            # No point in taking more successors than fit in the beam.
            for token in np.argsort(costs)[:beam_width]:
                candidates.append(Node(
                    parent=n, state=next_state,
                    value=token, cost=costs[token]))

        prev_history = sorted(
            candidates, key=lambda n: n.cum_cost)[:beam_width]

    hypotheses.sort(key=lambda n: n.cum_cost)
    return [[hypo.to_sequence_of_values(), hypo.cum_cost]
            for hypo in hypotheses[:num_hypotheses]]
def str2bool(v):
    """Interpret a command-line string as a boolean.

    Returns ``True`` for 'yes', 'true', 't', '1', 'y' (case-insensitive) and
    ``False`` for any other string. Real ``bool`` values are passed through
    unchanged instead of failing on ``.lower()``.
    """
    if isinstance(v, bool):
        return v
    return v.lower() in ('yes', 'true', 't', '1', 'y')


def get_base_parser():
    """Build the argument parser shared by the training/evaluation scripts.

    Registers a ``"bool"`` argparse type backed by ``str2bool`` and declares
    the option groups: basics, data, model_configs, train_setup and misc.

    Returns:
        argparse.ArgumentParser: the configured parser (nothing is parsed).
    """
    parser = argparse.ArgumentParser(
        description='Controllable Paraphrase Generation using PyTorch')
    parser.register('type', 'bool', str2bool)

    # Basics
    basic_group = parser.add_argument_group('basics')
    basic_group.add_argument(
        '--debug', type="bool", default=False,
        help='activation of debug mode (default: False)')
    basic_group.add_argument(
        '--auto_disconnect', type="bool", default=True,
        help='for slurm (default: True)')
    basic_group.add_argument(
        '--save_prefix', type=str, default="experiments",
        help='saving path prefix')

    # Data files
    data_group = parser.add_argument_group('data')
    data_group.add_argument(
        '--train_path', type=str, default=None, help='data file')
    data_group.add_argument(
        '--train_tag_path', type=str, default=None, help='data file')
    data_group.add_argument(
        '--vocab_file', type=str, default=None, help='vocabulary file')
    data_group.add_argument(
        '--tag_vocab_file', type=str, default=None,
        help='tag vocabulary file')
    data_group.add_argument(
        '--embed_file', type=str, default=None,
        help='pretrained embedding file')
    data_group.add_argument(
        '--dev_inp_path', type=str, default=None, help='data file')
    data_group.add_argument(
        '--dev_ref_path', type=str, default=None, help='data file')
    data_group.add_argument(
        '--test_inp_path', type=str, default=None, help='data file')
    data_group.add_argument(
        '--test_ref_path', type=str, default=None, help='data file')

    # Model configuration
    config_group = parser.add_argument_group('model_configs')
    config_group.add_argument(
        '-lr', '--learning_rate', dest='lr', type=float, default=1e-3,
        help='learning rate')
    config_group.add_argument(
        '-pratio', '--ploss_ratio', dest='pratio', type=float, default=1.0,
        help='ratio of position loss')
    config_group.add_argument(
        '-lratio', '--logloss_ratio', dest='lratio', type=float, default=1.0,
        help='ratio of log loss')
    config_group.add_argument(
        '-plratio', '--para_logloss_ratio', dest='plratio', type=float,
        default=1.0, help='ratio of paraphrase log loss')
    config_group.add_argument(
        '--eps', type=float, default=1e-4,
        help='safety for avoiding numerical issues')
    config_group.add_argument(
        '-edim', '--embed_dim', dest='edim', type=int, default=300,
        help='size of embedding')
    config_group.add_argument(
        '-wr', '--word_replace', dest='wr', type=float, default=0.0,
        help='word replace rate')
    config_group.add_argument(
        '-dp', '--dropout', dest='dp', type=float, default=0.0,
        help='dropout probability')
    config_group.add_argument(
        '-gclip', '--grad_clip', dest='gclip', type=float, default=None,
        help='gradient clipping threshold')

    # Recurrent neural network details
    config_group.add_argument(
        '-ensize', '--encoder_size', dest='ensize', type=int, default=300,
        help='encoder hidden size')
    config_group.add_argument(
        '-desize', '--decoder_size', dest='desize', type=int, default=300,
        help='decoder hidden size')
    config_group.add_argument(
        '--ysize', dest='ysize', type=int, default=100,
        help='size of Gaussian')
    config_group.add_argument(
        '--zsize', dest='zsize', type=int, default=100,
        help='size of Gaussian')

    # Feedforward neural network
    config_group.add_argument(
        '-mhsize', '--mlp_hidden_size', dest='mhsize', type=int, default=100,
        help='size of hidden size')
    config_group.add_argument(
        '-mlplayer', '--mlp_n_layer', dest='mlplayer', type=int, default=3,
        help='number of layer')
    config_group.add_argument(
        '-zmlplayer', '--zmlp_n_layer', dest='zmlplayer', type=int,
        default=3, help='number of layer')
    config_group.add_argument(
        '-ymlplayer', '--ymlp_n_layer', dest='ymlplayer', type=int,
        default=3, help='number of layer')

    # Latent code
    config_group.add_argument(
        '-ncode', '--num_code', dest='ncode', type=int, default=8,
        help='number of latent code')
    config_group.add_argument(
        '-nclass', '--num_class', dest='nclass', type=int, default=2,
        help='size of classes in each latent code')

    # Optimization
    config_group.add_argument(
        '-ps', '--p_scramble', dest='ps', type=float, default=0.,
        help='probability of scrambling')
    config_group.add_argument(
        '--l2', type=float, default=0., help='l2 regularization')
    config_group.add_argument(
        '-vmkl', '--max_vmf_kl_temp', dest='vmkl', type=float, default=1.,
        help='maximum temperature of kl divergence')
    config_group.add_argument(
        '-gmkl', '--max_gauss_kl_temp', dest='gmkl', type=float, default=1.,
        help='maximum temperature of kl divergence')

    # Training details. The encoder/decoder/embedding type options are
    # registered on the 'basics' group, which only affects --help layout.
    setup_group = parser.add_argument_group('train_setup')
    setup_group.add_argument(
        '--save_dir', type=str, default=None, help='model save path')
    basic_group.add_argument(
        '--embed_type', type=str, default="paragram",
        choices=['paragram', 'glove'],
        help='types of embedding: paragram, glove')
    basic_group.add_argument(
        '--yencoder_type', type=str, default="word_avg",
        help='types of encoder')
    basic_group.add_argument(
        '--zencoder_type', type=str, default="word_avg",
        help='types of z encoder')
    basic_group.add_argument(
        '--decoder_type', type=str, default="lstm_z2y",
        help='types of decoder')
    setup_group.add_argument(
        '--n_epoch', type=int, default=5, help='number of epochs')
    setup_group.add_argument(
        '--batch_size', type=int, default=20, help='batch size')
    # Help texts below use implicit string concatenation: a backslash
    # continuation inside a literal would embed the source indentation.
    setup_group.add_argument(
        '--opt', type=str, default='adam',
        choices=['sadam', 'adam', 'sgd', 'rmsprop'],
        help='types of optimizer: adam (default), sgd, rmsprop')
    setup_group.add_argument(
        '--pre_train_emb', type="bool", default=False,
        help='whether to use pretrain embedding')
    setup_group.add_argument(
        '--vocab_size', type=int, default=50000, help='size of vocabulary')

    # Misc
    misc_group = parser.add_argument_group('misc')
    misc_group.add_argument(
        '--print_every', type=int, default=10,
        help='print training details after '
             'this number of iterations')
    misc_group.add_argument(
        '--eval_every', type=int, default=100,
        help='evaluate model after '
             'this number of iterations')
    misc_group.add_argument(
        '--summarize', type="bool", default=False,
        help='whether to summarize training stats '
             '(default: False)')
    return parser
class data_processor:
    """Load paired training sentences and tags, build the vocabulary and
    convert everything to index form.

    ``experiment`` must expose ``config`` (parsed arguments) and ``log``
    (a logger); both are read throughout the class.
    """

    # Tokens registered for BOS_IDX / EOS_IDX. The previous code used the
    # empty string for both, so the second assignment overwrote the first
    # and the vocabulary ended up one entry short of its largest index.
    # NOTE(review): confirm these spellings against any cached vocab pickles.
    BOS_WORD = "<bos>"
    EOS_WORD = "<eos>"

    @auto_init_args
    def __init__(self, train_path, experiment):
        self.expe = experiment

    @staticmethod
    def _as_object_array(rows):
        """Pack a list of variable-length sequences into a 1-D object array.

        ``np.array(rows)`` raises on ragged input with recent NumPy and
        silently yields a 2-D array when all rows share a length; filling
        element by element always gives one Python object per row.
        """
        packed = np.empty(len(rows), dtype=object)
        for i, row in enumerate(rows):
            packed[i] = row
        return packed

    def process(self):
        """Run the full preprocessing pipeline.

        Returns:
            tuple: ``(data, W)`` where ``data`` is a ``data_holder`` and
            ``W`` is the pretrained embedding matrix, or ``None`` when
            ``pre_train_emb`` is disabled.
        """
        config = self.expe.config
        prefix = "pre_vocab_" if config.pre_train_emb else "vocab_"
        vocab_file = os.path.join(
            config.vocab_file, prefix + str(config.vocab_size))

        train_data = self._load_sent(
            self.train_path, file_name=self.train_path + ".pkl")

        if config.pre_train_emb:
            W, vocab = \
                self._build_pretrain_vocab(train_data, file_name=vocab_file)
        else:
            W, vocab = \
                self._build_vocab(train_data, file_name=vocab_file)
        self.expe.log.info("vocab size: {}".format(len(vocab)))

        self.expe.log.info("initializing pos bucketing")
        tag_bucket = self._load_tag_bucket(
            vocab, config.tag_vocab_file,
            file_name=vocab_file + "_tag")

        train_tag1 = []
        train_tag2 = []
        with open(config.train_tag_path) as fp:
            for line in fp:
                s1, s2 = line.strip().split("\t")
                train_tag1.append(s1.split(" "))
                train_tag2.append(s2.split(" "))
        assert len(train_tag1) == len(train_data[0])
        assert len(train_tag2) == len(train_data[1])
        train_tag = [self._as_object_array(train_tag1),
                     self._as_object_array(train_tag2)]

        train_data = self._data_to_idx(train_data, vocab)

        def cal_stats(data):
            # Unknown-word and sentence-length statistics over both sides.
            unk_count = 0
            total_count = 0
            leng = []
            for sent1, sent2 in zip(*data):
                leng.append(len(sent1))
                leng.append(len(sent2))
                for w in sent1 + sent2:
                    if w == UNK_IDX:
                        unk_count += 1
                    total_count += 1
            return (unk_count, total_count, unk_count / total_count), \
                (len(leng), max(leng), min(leng), sum(leng) / len(leng))

        train_unk_stats, train_len_stats = cal_stats(train_data)
        self.expe.log.info("#train data: {}, max len: {}, "
                           "min len: {}, avg len: {:.2f}"
                           .format(*train_len_stats))

        self.expe.log.info("#unk in train sentences: {}"
                           .format(train_unk_stats))
        data = data_holder(
            train_data=train_data,
            train_tag=train_tag,
            vocab=vocab,
            tag_bucket=tag_bucket)

        return data, W

    @lazy_execute("_load_from_pickle")
    def _load_tag_bucket(self, vocab, file_path):
        """Map each tag to the vocabulary indices of words carrying it.

        ``file_path`` is a pickle of ``word -> Counter-like`` objects; only
        the two most common tags of every in-vocabulary word are used.
        """
        with open(file_path, "rb") as fp:
            word2tag = pickle.load(fp)
        tag2vocab = {}
        for w, tags in word2tag.items():
            if w in vocab:
                for tag, _ in tags.most_common(2):
                    tag2vocab.setdefault(tag, []).append(vocab[w])
        self.expe.log.info("#tags: {}".format(len(tag2vocab)))
        return tag2vocab

    @lazy_execute("_load_from_pickle")
    def _load_sent(self, path):
        """Read tab-separated sentence pairs, lower-cased and space-split.

        Lines that do not contain exactly two tab-separated fields are
        logged and skipped.
        """
        data_pair1 = []
        data_pair2 = []
        with open(path) as f:
            for line in f:
                line = line.strip().lower()
                if not line:
                    continue
                fields = line.split('\t')
                if len(fields) == 2:
                    data_pair1.append(fields[0].split(" "))
                    data_pair2.append(fields[1].split(" "))
                else:
                    # Join the fields back together: concatenating the split
                    # list to a str raised TypeError on this very path.
                    self.expe.log.warning(
                        "unexpected data: " + "\t".join(fields))
        assert len(data_pair1) == len(data_pair2)
        return data_pair1, data_pair2

    def _data_to_idx(self, data, vocab):
        """Convert both sides of ``data`` from words to vocabulary indices.

        Words missing from ``vocab`` map to ``UNK_IDX``. Each side is
        returned as a 1-D object array holding one index list per sentence.
        """
        idx_pair1 = []
        idx_pair2 = []
        for d1, d2 in zip(*data):
            idx_pair1.append([vocab.get(w, UNK_IDX) for w in d1])
            idx_pair2.append([vocab.get(w, UNK_IDX) for w in d2])
        return (self._as_object_array(idx_pair1),
                self._as_object_array(idx_pair2))

    @staticmethod
    def _read_embedding(path, **open_kwargs):
        """Parse a text embedding file of ``word v1 v2 ...`` lines.

        Returns ``(word_vectors, vocab_embed, embed_dim)``: the word->vector
        dict, a view of its keys and the length of the first vector.
        """
        with open(path, **open_kwargs) as fp:
            # word_vectors: word --> vector
            word_vectors = {}
            for line in fp:
                parts = line.strip("\n").split(" ")
                word_vectors[parts[0]] = np.array(
                    list(map(float, parts[1:])), dtype='float32')
        vocab_embed = word_vectors.keys()
        embed_dim = word_vectors[next(iter(vocab_embed))].shape[0]
        return word_vectors, vocab_embed, embed_dim

    def _load_paragram_embedding(self, path):
        """Load paragram vectors (the file is read as latin-1)."""
        return self._read_embedding(path, encoding="latin-1")

    def _load_glove_embedding(self, path):
        """Load GloVe vectors (the file is read as UTF-8)."""
        return self._read_embedding(path, mode='r', encoding='utf8')

    def _create_vocab_from_data(self, data):
        """Build ``word -> index`` from the most frequent training words.

        Indices 0, 1 and 2 are reserved for unknown, BOS and EOS; regular
        words start at 3 in descending frequency order.
        """
        counts = Counter()
        for sent1, sent2 in zip(*data):
            counts.update(sent1)
            counts.update(sent2)

        ls = counts.most_common(self.expe.config.vocab_size)
        self.expe.log.info(
            '#Words: %d -> %d' % (len(counts), len(ls)))
        for key in ls[:5]:
            self.expe.log.info(key)
        self.expe.log.info('...')
        for key in ls[-5:]:
            self.expe.log.info(key)

        # 0: unk, 1: bos, 2: eos
        vocab = {w: index + 3 for index, (w, _) in enumerate(ls)}
        vocab[UNK_WORD] = UNK_IDX
        vocab[self.BOS_WORD] = BOS_IDX
        vocab[self.EOS_WORD] = EOS_IDX

        return vocab

    @lazy_execute("_load_from_pickle")
    def _build_vocab(self, train_data):
        """Build the vocabulary without pretrained embeddings."""
        vocab = self._create_vocab_from_data(train_data)
        return None, vocab

    @lazy_execute("_load_from_pickle")
    def _build_pretrain_vocab(self, train_data):
        """Build the vocabulary plus an embedding matrix seeded from the
        pretrained file named by ``config.embed_file``.

        Rows for words absent from the pretrained file keep their uniform
        random initialisation.

        Raises:
            NotImplementedError: for an unknown ``config.embed_type``.
        """
        config = self.expe.config
        self.expe.log.info("loading embedding from: {}"
                           .format(config.embed_file))
        embed_type = config.embed_type.lower()
        if embed_type == "glove":
            word_vectors, vocab_embed, embed_dim = \
                self._load_glove_embedding(config.embed_file)
        elif embed_type == "paragram":
            word_vectors, vocab_embed, embed_dim = \
                self._load_paragram_embedding(config.embed_file)
        else:
            raise NotImplementedError(
                "invalid embedding type: {}".format(
                    config.embed_type))

        vocab = self._create_vocab_from_data(train_data)

        bound = np.sqrt(3.0 / embed_dim)
        W = np.random.uniform(
            -bound, bound,
            size=(len(vocab), embed_dim)).astype('float32')
        n = 0
        for w, i in vocab.items():
            if w in vocab_embed:
                W[i, :] = word_vectors[w]
                n += 1
        self.expe.log.info(
            "{}/{} vocabs are initialized with {} embeddings."
            .format(n, len(vocab), config.embed_type))

        return W, vocab

    def _load_from_pickle(self, file_name):
        """Unpickle and return the object stored at ``file_name``."""
        self.expe.log.info("loading from {}".format(file_name))
        # NOTE: pickle executes arbitrary code on load; only point this at
        # cache files produced by this project.
        with open(file_name, "rb") as fp:
            data = pickle.load(fp)
        return data
new_sent 252 | 253 | def _pad(self, data1, tag1, data2, tag2): 254 | assert len(data1) == len(data2) 255 | max_len1 = max([len(sent) for sent in data1]) 256 | max_len2 = max([len(sent) for sent in data2]) 257 | 258 | input_data1 = \ 259 | np.zeros((len(data1), max_len1)).astype("float32") 260 | input_repl_data1 = \ 261 | np.zeros((len(data1), max_len1)).astype("float32") 262 | input_mask1 = \ 263 | np.zeros((len(data1), max_len1)).astype("float32") 264 | tgt_data1 = \ 265 | np.zeros((len(data1), max_len1 + 2)).astype("float32") 266 | tgt_mask1 = \ 267 | np.zeros((len(data1), max_len1 + 2)).astype("float32") 268 | 269 | input_data2 = \ 270 | np.zeros((len(data2), max_len2)).astype("float32") 271 | input_repl_data2 = \ 272 | np.zeros((len(data2), max_len2)).astype("float32") 273 | input_mask2 = \ 274 | np.zeros((len(data2), max_len2)).astype("float32") 275 | tgt_data2 = \ 276 | np.zeros((len(data2), max_len2 + 2)).astype("float32") 277 | tgt_mask2 = \ 278 | np.zeros((len(data2), max_len2 + 2)).astype("float32") 279 | 280 | for i, (sent1, t1, sent2, t2) in \ 281 | enumerate(zip(data1, tag1, data2, tag2)): 282 | if np.random.choice( 283 | [True, False], 284 | p=[self.p_scramble, 1 - self.p_scramble]).item(): 285 | sent1 = np.random.permutation(sent1) 286 | sent2 = np.random.permutation(sent2) 287 | 288 | input_data1[i, :len(sent1)] = \ 289 | np.asarray(list(sent1)).astype("float32") 290 | input_mask1[i, :len(sent1)] = 1. 291 | 292 | tgt_data1[i, :len(sent1) + 2] = \ 293 | np.asarray([BOS_IDX] + list(sent1) + [EOS_IDX]).astype("float32") 294 | tgt_mask1[i, :len(sent1) + 2] = 1. 295 | 296 | input_data2[i, :len(sent2)] = \ 297 | np.asarray(list(sent2)).astype("float32") 298 | input_mask2[i, :len(sent2)] = 1. 299 | 300 | tgt_data2[i, :len(sent2) + 2] = \ 301 | np.asarray([BOS_IDX] + list(sent2) + [EOS_IDX]).astype("float32") 302 | tgt_mask2[i, :len(sent2) + 2] = 1. 
class decoder_base(nn.Module):
    """Shared embedding/dropout setup and greedy decoding for the decoders.

    Subclasses must implement ``step(yvecs, zvecs, hidden_state, input_word)``
    returning a 4-tuple whose first item is the next hidden state and whose
    third item is the next input word tensor.

    Args:
        vocab_size: number of rows in the embedding table.
        embed_dim: embedding dimensionality.
        embed_init: optional NumPy array copied into the embedding weights.
        dropout: dropout probability.
        log: logger; only used (``log.info``) when ``embed_init`` is given.
    """

    def __init__(self, vocab_size, embed_dim, embed_init, dropout, log):
        super(decoder_base, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.embed = nn.Embedding(vocab_size, embed_dim)
        if embed_init is not None:
            self.embed.weight.data.copy_(torch.from_numpy(embed_init))
            log.info(
                "{} initialized with pretrained word embedding".format(
                    type(self)))

    def greedy_decode(self, yvecs, zvecs, max_len):
        """Greedily decode ``max_len`` tokens for every row of ``yvecs``.

        Decoding starts from a tensor of ones (token id 1) and feeds each
        step's predicted word back in as the next input.

        Returns:
            numpy.ndarray: the per-step predictions concatenated along
            axis 1; an empty ``(batch, 0)`` array when ``max_len <= 0``.
        """
        if max_len <= 0:
            # np.concatenate([]) would raise; return an empty result instead.
            return np.zeros((len(yvecs), 0), dtype=np.int64)

        input_word = torch.ones(len(yvecs), 1).to(yvecs.device)
        batch_gen = []
        hidden_state = None
        # Only NumPy ids are returned, so there is no reason to keep an
        # autograd graph alive across every decoding step.
        with torch.no_grad():
            for _ in range(max_len):
                hidden_state, _, input_word, _ = \
                    self.step(yvecs, zvecs, hidden_state, input_word)
                batch_gen.append(input_word.detach().clone().cpu().numpy())
        return np.concatenate(batch_gen, 1)
class lstm_z2y(decoder_base):
    """LSTM decoder that feeds ``zvecs`` into the recurrent input and
    appends ``yvecs`` to the LSTM output before the vocabulary projection.

    ``mlp_hidden_size``, ``mlp_layer`` and any extra positional/keyword
    arguments are accepted but not used by this class.
    """

    def __init__(self, vocab_size, embed_dim, embed_init,
                 ysize, zsize, mlp_hidden_size,
                 mlp_layer, hidden_size, dropout,
                 log, *args, **kwargs):
        super().__init__(vocab_size, embed_dim, embed_init, dropout, log)
        # Recurrent input = word embedding concatenated with z.
        self.cell = nn.LSTM(
            input_size=zsize + embed_dim,
            hidden_size=hidden_size,
            bidirectional=False,
            batch_first=True)
        # Output projection sees the LSTM state concatenated with y.
        self.hid2vocab = nn.Linear(hidden_size + ysize, vocab_size)

    def forward(self, yvecs, zvecs, tgts, tgts_mask,
                *args, **kwargs):
        """Alias for ``teacher_force``; extra arguments are ignored."""
        return self.teacher_force(yvecs, zvecs, tgts, tgts_mask)

    def pred(self, yvecs, zvecs, tgts, tgts_mask):
        """Return vocabulary scores for every position but the last,
        together with the tensor that was fed to the LSTM.
        """
        _, seq_len = tgts_mask.size()
        embedded = self.dropout(self.embed(tgts.long()))
        z_per_step = zvecs.unsqueeze(1).expand(-1, seq_len, -1)
        y_per_step = yvecs.unsqueeze(1).expand(-1, seq_len, -1)

        rnn_input = torch.cat([embedded, z_per_step], -1)
        rnn_output, _ = model_utils.get_rnn_vecs(
            rnn_input, tgts_mask, self.cell, bidir=False, initial_state=None)
        features = torch.cat([rnn_output, y_per_step], -1)
        # batch size x seq len x vocab size; the final step has no target
        scores = self.hid2vocab(self.dropout(features))
        return scores[:, :-1, :], rnn_input

    def teacher_force(self, yvecs, zvecs, tgts, tgts_mask):
        """Compute the masked, length-normalised cross-entropy of ``tgts``.

        Returns:
            tuple: (mean loss over the batch, LSTM input tensor from
            ``pred``).
        """
        scores, pvecs = self.pred(yvecs, zvecs, tgts, tgts_mask)
        batch_size, seq_len, vocab_size = scores.size()

        flat_scores = scores.contiguous().view(
            batch_size * seq_len, vocab_size)
        gold = tgts[:, 1:].contiguous().view(-1).long()
        token_loss = F.cross_entropy(flat_scores, gold, reduction="none")

        # Targets are shifted by one, so the mask is shifted the same way.
        shifted_mask = tgts_mask[:, 1:]
        masked = token_loss.view(batch_size, seq_len) * shifted_mask
        sent_loss = masked.sum(-1) / shifted_mask.sum(-1)
        return sent_loss.mean(), pvecs

    def step(self, yvecs, zvecs, last_hidden_state, last_output):
        """Advance decoding by one token.

        Returns:
            tuple: (new hidden state, per-word probabilities with index 0
            forced to -10, arg-max word ids, their probabilities).
        """
        word_vec = self.embed(last_output.long())
        rnn_input = torch.cat([word_vec, zvecs.unsqueeze(1)], -1)
        rnn_output, hidden_state = self.cell(rnn_input, last_hidden_state)
        features = torch.cat([rnn_output, yvecs.unsqueeze(1)], -1)
        curr_prob = F.softmax(self.hid2vocab(features), -1)
        # Index 0 is overwritten with a negative value so that max() below
        # can never select it.
        curr_prob[:, :, 0] = -10
        output_prob, input_word = curr_prob.max(-1)
        return hidden_state, curr_prob, input_word, output_prob
(logloss.view(batch_size, seq_len) * 124 | tgts_mask[:, 1:]).sum(-1) / tgts_mask[:, 1:].sum(-1) 125 | return logloss.mean(), pvecs 126 | 127 | def step(self, yvecs, zvecs, last_hidden_state, last_output): 128 | input_embed = self.embed(last_output.long()) 129 | input_vec = torch.cat([input_embed, yvecs.unsqueeze(1)], -1) 130 | curr_output, hidden_state = self.cell(input_vec, last_hidden_state) 131 | curr_output = torch.cat([curr_output, zvecs.unsqueeze(1)], -1) 132 | curr_prob = F.softmax(self.hid2vocab(curr_output), -1) 133 | curr_prob[:, :, 0] = -10 134 | output_prob, input_word = curr_prob.max(-1) 135 | return hidden_state, curr_prob, input_word, output_prob 136 | 137 | 138 | class lstm_yz(decoder_base): 139 | def __init__(self, vocab_size, embed_dim, embed_init, 140 | ysize, zsize, mlp_hidden_size, 141 | mlp_layer, hidden_size, dropout, 142 | log, *args, **kwargs): 143 | super(lstm_yz, self).__init__( 144 | vocab_size, embed_dim, embed_init, dropout, log) 145 | self.cell = nn.LSTM( 146 | zsize + ysize + embed_dim, hidden_size, 147 | bidirectional=False, batch_first=True) 148 | self.hid2vocab = nn.Linear(hidden_size, vocab_size) 149 | 150 | def forward(self, yvecs, zvecs, tgts, tgts_mask, 151 | *args, **kwargs): 152 | return self.teacher_force(yvecs, zvecs, tgts, tgts_mask) 153 | 154 | def pred(self, yvecs, zvecs, tgts, tgts_mask): 155 | bs, sl = tgts_mask.size() 156 | tgts_embed = self.dropout(self.embed(tgts.long())) 157 | ex_input_vecs = zvecs.unsqueeze(1).expand(-1, sl, -1) 158 | ex_input2_vecs = yvecs.unsqueeze(1).expand(-1, sl, -1) 159 | 160 | input_vecs = torch.cat([tgts_embed, ex_input_vecs, ex_input2_vecs], -1) 161 | ori_output_seq, _ = model_utils.get_rnn_vecs( 162 | input_vecs, tgts_mask, self.cell, bidir=False, initial_state=None) 163 | output_seq = ori_output_seq 164 | # batch size x seq len x vocab size 165 | pred = self.hid2vocab(self.dropout(output_seq))[:, :-1, :] 166 | return pred, torch.cat([tgts_embed, ex_input_vecs], -1) 167 | 168 | def 
teacher_force(self, yvecs, zvecs, tgts, tgts_mask): 169 | pred, pvecs = self.pred(yvecs, zvecs, tgts, tgts_mask) 170 | batch_size, seq_len, vocab_size = pred.size() 171 | 172 | pred = pred.contiguous().view(batch_size * seq_len, vocab_size) 173 | logloss = F.cross_entropy( 174 | pred, tgts[:, 1:].contiguous().view(-1).long(), reduction="none") 175 | 176 | logloss = (logloss.view(batch_size, seq_len) * 177 | tgts_mask[:, 1:]).sum(-1) / tgts_mask[:, 1:].sum(-1) 178 | return logloss.mean(), pvecs 179 | 180 | def step(self, yvecs, zvecs, last_hidden_state, last_output): 181 | input_embed = self.embed(last_output.long()) 182 | input_vec = torch.cat( 183 | [input_embed, zvecs.unsqueeze(1), yvecs.unsqueeze(1)], -1) 184 | curr_output, hidden_state = self.cell(input_vec, last_hidden_state) 185 | curr_prob = F.softmax(self.hid2vocab(curr_output), -1) 186 | curr_prob[:, :, 0] = -10 187 | output_prob, input_word = curr_prob.max(-1) 188 | return hidden_state, curr_prob, input_word, output_prob 189 | 190 | 191 | class yz_lstm(decoder_base): 192 | def __init__(self, vocab_size, embed_dim, embed_init, 193 | ysize, zsize, mlp_hidden_size, 194 | mlp_layer, hidden_size, dropout, 195 | log, *args, **kwargs): 196 | super(yz_lstm, self).__init__( 197 | vocab_size, embed_dim, embed_init, dropout, log) 198 | self.cell = nn.LSTM( 199 | embed_dim, hidden_size, 200 | bidirectional=False, batch_first=True) 201 | self.latent2init = nn.Linear(ysize + zsize, hidden_size * 2) 202 | self.hid2vocab = nn.Linear(hidden_size, vocab_size) 203 | 204 | def forward(self, yvecs, zvecs, tgts, tgts_mask, 205 | *args, **kwargs): 206 | return self.teacher_force(yvecs, zvecs, tgts, tgts_mask) 207 | 208 | def pred(self, yvecs, zvecs, tgts, tgts_mask): 209 | bs, sl = tgts_mask.size() 210 | tgts_embed = self.dropout(self.embed(tgts.long())) 211 | init_vecs = self.latent2init(torch.cat([yvecs, zvecs], -1)) 212 | 213 | if isinstance(self.cell, nn.LSTM): 214 | init_vecs = tuple([h.unsqueeze(0).contiguous() for h in 
215 | torch.chunk(init_vecs, 2, -1)]) 216 | 217 | input_vecs = tgts_embed 218 | ori_output_seq, _ = model_utils.get_rnn_vecs( 219 | input_vecs, tgts_mask, self.cell, bidir=False, initial_state=init_vecs) 220 | output_seq = ori_output_seq 221 | # batch size x seq len x vocab size 222 | pred = self.hid2vocab(self.dropout(output_seq))[:, :-1, :] 223 | return pred, torch.cat( 224 | [tgts_embed, zvecs.unsqueeze(1).expand(-1, sl, -1)], -1) 225 | 226 | def teacher_force(self, yvecs, zvecs, tgts, tgts_mask): 227 | pred, pvecs = self.pred(yvecs, zvecs, tgts, tgts_mask) 228 | batch_size, seq_len, vocab_size = pred.size() 229 | 230 | pred = pred.contiguous().view(batch_size * seq_len, vocab_size) 231 | logloss = F.cross_entropy( 232 | pred, tgts[:, 1:].contiguous().view(-1).long(), reduction="none") 233 | 234 | logloss = (logloss.view(batch_size, seq_len) * 235 | tgts_mask[:, 1:]).sum(-1) / tgts_mask[:, 1:].sum(-1) 236 | return logloss.mean(), pvecs 237 | 238 | def step(self, yvecs, zvecs, last_hidden_state, last_output): 239 | input_embed = self.embed(last_output.long()) 240 | input_vec = input_embed 241 | curr_output, hidden_state = self.cell(input_vec, last_hidden_state) 242 | curr_prob = F.softmax(self.hid2vocab(curr_output), -1) 243 | curr_prob[:, :, 0] = -10 244 | output_prob, input_word = curr_prob.max(-1) 245 | return hidden_state, curr_prob, input_word, output_prob 246 | 247 | def greedy_decode(self, yvecs, zvecs, max_len): 248 | input_word = torch.ones(len(yvecs), 1).to(yvecs.device) 249 | batch_gen = [] 250 | init_vecs = self.latent2init(torch.cat([yvecs, zvecs], -1)) 251 | 252 | if isinstance(self.cell, nn.LSTM): 253 | init_vecs = tuple([h.unsqueeze(0).contiguous() for h in 254 | torch.chunk(init_vecs, 2, -1)]) 255 | hidden_state = init_vecs 256 | for _ in range(max_len): 257 | hidden_state, _, input_word, _ = \ 258 | self.step(None, None, hidden_state, input_word) 259 | batch_gen.append(input_word.detach().clone().cpu().numpy()) 260 | batch_gen = 
def auto_init_args(init):
    """Wrap ``__init__`` so every constructor argument is also stored on ``self``.

    Positional arguments are matched to the parameter names of *init*
    (skipping ``self``), keyword arguments are stored under their own
    names, and any parameter that was not supplied is stored with the
    default recorded in the signature.  The original *init* is then called
    with the unchanged arguments.
    """
    import functools

    # Inspect once at decoration time rather than on every construction.
    params = inspect.signature(init).parameters
    arg_names = list(params)[1:]  # skip self

    # wraps() keeps __name__/__doc__ and exposes __wrapped__, so a decorator
    # stacked on top of this one still sees the real signature.
    @functools.wraps(init)
    def new_init(self, *args, **kwargs):
        supplied = set()
        for name, value in zip(arg_names, args):
            setattr(self, name, value)
            supplied.add(name)
        for name, value in kwargs.items():
            setattr(self, name, value)
            supplied.add(name)
        for name in set(arg_names) - supplied:
            setattr(self, name, params[name].default)
        init(self, *args, **kwargs)

    return new_init


def auto_init_pytorch(init):
    """Wrap ``__init__`` to build the optimizer and move the model to its device.

    After *init* runs, ``self.opt`` is set from ``self.init_optimizer`` using
    the ``opt``, ``lr`` and ``l2`` entries of ``self.expe.config``.  Unless
    ``self.expe.config.resume`` is truthy, the object is moved with
    ``self.to(self.device)`` and the transfer is logged.
    """
    import functools

    @functools.wraps(init)
    def new_init(self, *args, **kwargs):
        init(self, *args, **kwargs)
        config = self.expe.config
        self.opt = self.init_optimizer(config.opt, config.lr, config.l2)

        if not config.resume:
            self.to(self.device)
            self.expe.log.info(
                "transferred model to {}".format(self.device))

    return new_init


class lazy_execute:
    """Method decorator that caches a method's result in a pickle file.

    ``lazy_execute("loader")`` wraps a method so that callers must pass a
    ``file_name`` keyword argument.  If that file exists, the result of
    ``self.loader(file_name)`` is returned and the wrapped method is not
    run.  Otherwise the method runs, its result is pickled to ``file_name``
    and returned.
    """

    @auto_init_args
    def __init__(self, func_name):
        # auto_init_args stores func_name on self.
        pass

    def __call__(self, fn):
        import functools

        func_name = self.func_name

        @functools.wraps(fn)
        def new_fn(self, *args, **kwargs):
            # file_name is mandatory; a missing key raises KeyError.
            file_name = kwargs.pop('file_name')
            if os.path.isfile(file_name):
                return getattr(self, func_name)(file_name)

            data = fn(self, *args, **kwargs)
            self.expe.log.info("saving to {}".format(file_name))
            with open(file_name, "wb+") as fp:
                pickle.dump(data, fp, protocol=-1)
            return data

        return new_fn
class encoder_base(nn.Module):
    """Shared encoder state: a dropout layer and a word-embedding table."""

    def __init__(self, vocab_size, embed_dim, embed_init, dropout, log,
                 *args, **kwargs):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.embed = nn.Embedding(vocab_size, embed_dim)
        if embed_init is None:
            return
        # Copy the pretrained numpy matrix into the embedding table.
        pretrained = torch.from_numpy(embed_init)
        self.embed.weight.data.copy_(pretrained)
        log.info(
            "{} initialized with pretrained word embedding".format(
                type(self)))


class word_avg(encoder_base):
    """Encode a sentence as the mask-weighted average of its word embeddings."""

    def __init__(self, vocab_size, embed_dim, embed_init, dropout, log,
                 *args, **kwargs):
        super().__init__(vocab_size, embed_dim, embed_init, dropout, log)

    def forward(self, inputs, mask):
        """Return (per-token embeddings, mask-weighted mean over dimension 1)."""
        token_vecs = self.dropout(self.embed(inputs.long()))
        weights = mask.unsqueeze(-1)
        token_count = mask.sum(1, keepdim=True)
        sent_vecs = (token_vecs * weights).sum(1) / token_count
        return token_vecs, sent_vecs


class bilstm(encoder_base):
    """Encode a sentence as the mask-weighted average of bidirectional LSTM states."""

    def __init__(self, vocab_size, embed_dim, embed_init, hidden_size,
                 dropout, log, *args, **kwargs):
        super().__init__(vocab_size, embed_dim, embed_init, dropout, log)
        self.lstm = nn.LSTM(
            embed_dim, hidden_size, bidirectional=True, batch_first=True)

    def forward(self, inputs, mask, temp=None):
        """Return (per-token embeddings, mask-weighted mean of LSTM outputs).

        ``temp`` is accepted but not used here.
        """
        token_vecs = self.dropout(self.embed(inputs.long()))
        states, _ = model_utils.get_rnn_vecs(
            token_vecs, mask, self.lstm, bidir=True)
        # Zero the masked-out positions before averaging.
        masked_states = self.dropout(states) * mask.unsqueeze(-1)
        token_count = mask.sum(1, keepdim=True)
        return token_vecs, masked_states.sum(1) / token_count
def _count_lines(path):
    """Return the number of lines in the text file at *path*, closing it afterwards."""
    with open(path) as fp:
        return sum(1 for _ in fp)


def _mean(values):
    """Return the arithmetic mean of *values*, or 0.0 when it is empty."""
    return sum(values) / len(values) if values else 0.0


parser = argparse.ArgumentParser()
parser.add_argument('--input_file', '-i', type=str)
parser.add_argument('--ref_file', '-r', type=str)
args = parser.parse_args()

# The two files are compared line by line, so they must be the same length.
n_ref_line = _count_lines(args.ref_file)
n_inp_line = _count_lines(args.input_file)
print("#lines - ref: {}, inp: {}".format(n_ref_line, n_inp_line))
assert n_inp_line == n_ref_line, \
    "#ref {} != #inp {}".format(n_ref_line, n_inp_line)

bleu_score = run_multi_bleu(args.input_file, args.ref_file)
print("bleu", bleu_score)

spe = stanford_parsetree_extractor()
try:
    input_parses = spe.run(args.input_file)
    ref_parses = spe.run(args.ref_file)
finally:
    # Always remove the extractor's temporary output, even if parsing fails.
    spe.cleanup()
assert len(input_parses) == n_inp_line
assert len(ref_parses) == n_inp_line

all_meteor = []
all_ted = []
all_rouge1 = []
all_rouge2 = []
all_rougel = []

rouge_eval = rouge.Rouge(metrics=['rouge-n', 'rouge-l'],
                         max_n=2,
                         limit_length=True,
                         length_limit=100,
                         length_limit_type='words',
                         apply_avg=False,
                         apply_best=False,
                         alpha=0.5,  # Default F1_score
                         weight_factor=1.2,
                         stemming=True)
meteor = Meteor()


def _summary():
    """Format the running averages of every metric collected so far."""
    return (
        "bleu: {:.3f}, rouge-1: {:.3f}, rouge-2: {:.3f}, "
        "rouge-l: {:.3f}, meteor: {:.3f}, syntax-TED: {:.3f}".format(
            bleu_score,
            _mean(all_rouge1) * 100,
            _mean(all_rouge2) * 100,
            _mean(all_rougel) * 100,
            _mean(all_meteor) * 100,
            _mean(all_ted)))


# Open both files in a with-block so the handles are closed when scoring ends.
with open(args.input_file) as input_fp, open(args.ref_file) as ref_fp:
    pbar = tqdm(zip(input_fp, ref_fp, input_parses, ref_parses))

    for input_line, ref_line, input_parse, ref_parse in pbar:
        ted = compute_tree_edit_distance(input_parse, ref_parse)
        ms = meteor._score(input_line.strip(), [ref_line.strip()])
        rs = rouge_eval.get_scores([input_line.strip()], [ref_line.strip()])

        all_rouge1.append(rs['rouge-1'][0]['f'][0])
        all_rouge2.append(rs['rouge-2'][0]['f'][0])
        all_rougel.append(rs['rouge-l'][0]['f'][0])
        all_meteor.append(ms)
        all_ted.append(ted)
        pbar.set_description(_summary())

print(_summary())
82 | -------------------------------------------------------------------------------- /eval_f1_acc.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import argparse 3 | import collections 4 | 5 | import tree 6 | import torch 7 | import models 8 | import data_utils 9 | import train_helper 10 | 11 | import numpy as np 12 | 13 | from tqdm import tqdm 14 | 15 | 16 | MAX_LEN = 30 17 | batch_size = 1000 18 | 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--save_file', '-s', type=str) 22 | parser.add_argument('--vocab_file', '-v', type=str) 23 | parser.add_argument('--data_dir', '-d', type=str) 24 | args = parser.parse_args() 25 | 26 | 27 | def _brackets_helper(node, i, result): 28 | i0 = i 29 | if len(node.children) > 0: 30 | for child in node.children: 31 | i = _brackets_helper(child, i, result) 32 | j0 = i 33 | if len(node.children[0].children) > 0: # don't count preterminals 34 | result[node.label, i0, j0] += 1 35 | else: 36 | j0 = i0 + 1 37 | return j0 38 | 39 | 40 | def brackets(t): 41 | result = collections.defaultdict(int) 42 | _brackets_helper(t.root, 0, result) 43 | return result 44 | 45 | 46 | def cosine_similarity(v1, v2): 47 | prod = (v1 * v2).sum(-1) 48 | v1_norm = (v1 ** 2).sum(-1) ** 0.5 49 | v2_norm = (v2 ** 2).sum(-1) ** 0.5 50 | return prod / (v1_norm * v2_norm) 51 | 52 | 53 | save_dict = torch.load( 54 | args.save_file, 55 | map_location=lambda storage, 56 | loc: storage) 57 | 58 | config = save_dict['config'] 59 | checkpoint = save_dict['state_dict'] 60 | config.debug = True 61 | 62 | with open(args.vocab_file, "rb") as fp: 63 | W, vocab = pickle.load(fp) 64 | 65 | with train_helper.experiment(config, config.save_prefix) as e: 66 | e.log.info("vocab loaded from: {}".format(args.vocab_file)) 67 | model = models.vgvae( 68 | vocab_size=len(vocab), 69 | embed_dim=e.config.edim if W is None else W.shape[1], 70 | embed_init=W, 71 | experiment=e) 72 | model.eval() 73 | 
model.load(checkpointed_state_dict=checkpoint) 74 | e.log.info(model) 75 | 76 | def encode(d): 77 | global vocab, batch_size 78 | new_d = [[vocab.get(w, 0) for w in s.split(" ")] for s in d] 79 | all_y_vecs = [] 80 | all_z_vecs = [] 81 | 82 | for s1, _, m1, s2, _, m2, _, _, _, _, _ in \ 83 | tqdm(data_utils.minibatcher( 84 | data1=np.array(new_d), 85 | tag1=np.array(new_d), 86 | data2=np.array(new_d), 87 | tag2=np.array(new_d), 88 | tag_bucket=None, 89 | batch_size=batch_size, 90 | p_replace=0., 91 | shuffle=False, 92 | p_scramble=0.)): 93 | with torch.no_grad(): 94 | semantics, semantics_mask, synatx, syntax_mask = \ 95 | model.to_tensors(s1, m1, s2, m2) 96 | _, yvecs = model.yencode(semantics, semantics_mask) 97 | _, zvecs = model.zencode(synatx, syntax_mask) 98 | 99 | ymean = model.mean1(yvecs) 100 | ymean = ymean / ymean.norm(dim=-1, keepdim=True) 101 | zmean = model.mean2(zvecs) 102 | 103 | all_y_vecs.append(ymean.cpu().numpy()) 104 | all_z_vecs.append(zmean.cpu().numpy()) 105 | return np.concatenate(all_y_vecs), np.concatenate(all_z_vecs) 106 | 107 | y_tot_pred = {i: [] for i in range(1, MAX_LEN)} 108 | tot_label = {i: [] for i in range(1, MAX_LEN)} 109 | y_results_len = {i: {"match_count": 0, 110 | "parse_count": 0, 111 | "gold_count": 0, 112 | "best_f1": []} 113 | for i in range(1, MAX_LEN)} 114 | 115 | z_tot_pred = {i: [] for i in range(1, MAX_LEN)} 116 | z_results_len = {i: {"match_count": 0, 117 | "parse_count": 0, 118 | "gold_count": 0, 119 | "best_f1": []} 120 | for i in range(1, MAX_LEN)} 121 | 122 | tag2num = {} 123 | for i in range(1, MAX_LEN): 124 | e.log.info("*" * 25 + " Length: {} ".format(i) + "*" * 25) 125 | cand_sents = [] 126 | test_sents = [] 127 | cand_pos = [] 128 | cand_parse = [] 129 | test_pos = [] 130 | test_parse = [] 131 | with open(args.data_dir + "/{}_candidates.txt".format(i)) as cf, \ 132 | open(args.data_dir + "/{}_test.txt".format(i)) as tf: 133 | for line in cf: 134 | sent, pos, parse = line.strip().split("\t") 135 | 
cand_pos.append(pos.strip()) 136 | cand_sents.append(sent.strip()) 137 | cand_parse.append(parse.strip()) 138 | for line in tf: 139 | sent, pos, parse = line.strip().split("\t") 140 | test_pos.append(pos.strip()) 141 | test_sents.append(sent.strip()) 142 | test_parse.append(parse.strip()) 143 | y_test_vecs, z_test_vecs = encode(test_sents) 144 | y_cand_vecs, z_cand_vecs = encode(cand_sents) 145 | e.log.info("#query: {}, #candidate: {}" 146 | .format(len(test_pos), len(cand_pos))) 147 | pbar = tqdm(zip(test_pos, test_parse, y_test_vecs, z_test_vecs)) 148 | for curr_label, gold_parse, y_test_vec, z_test_vec in pbar: 149 | gold = tree.Tree.from_str(gold_parse) 150 | gold_brackets = brackets(gold) 151 | y_results_len[i]["gold_count"] += sum(gold_brackets.values()) 152 | z_results_len[i]["gold_count"] += sum(gold_brackets.values()) 153 | idx = cosine_similarity( 154 | y_test_vec[None, :], y_cand_vecs).argmax(-1) 155 | y_best_pred = cand_pos[idx] 156 | y_best_parse = cand_parse[idx] 157 | idx = cosine_similarity( 158 | z_test_vec[None, :], z_cand_vecs).argmax(-1) 159 | z_best_pred = cand_pos[idx] 160 | z_best_parse = cand_parse[idx] 161 | 162 | curr_label_ = [] 163 | for t in curr_label.strip().split(" "): 164 | if t not in tag2num: 165 | tag2num[t] = len(tag2num) 166 | curr_label_.append(tag2num[t]) 167 | tot_label[i].extend(curr_label_) 168 | y_best_pred_ = [] 169 | for t in y_best_pred.strip().split(" "): 170 | if t not in tag2num: 171 | tag2num[t] = len(tag2num) 172 | y_best_pred_.append(tag2num[t]) 173 | y_tot_pred[i].extend(y_best_pred_) 174 | z_best_pred_ = [] 175 | for t in z_best_pred.strip().split(" "): 176 | if t not in tag2num: 177 | tag2num[t] = len(tag2num) 178 | z_best_pred_.append(tag2num[t]) 179 | z_tot_pred[i].extend(z_best_pred_) 180 | 181 | parse = tree.Tree.from_str(y_best_parse) 182 | parse_brackets = brackets(parse) 183 | delta_parse_count = sum(parse_brackets.values()) 184 | 185 | delta_match_count = 0 186 | for bracket, count in 
parse_brackets.items(): 187 | delta_match_count += min(count, gold_brackets[bracket]) 188 | y_curr_f1 = (2. / (y_results_len[i]["gold_count"] / 189 | float(y_results_len[i]["match_count"] + delta_match_count) + 190 | (y_results_len[i]["parse_count"] + delta_parse_count) / 191 | float(y_results_len[i]["match_count"] + delta_match_count))) \ 192 | if float(y_results_len[i]["match_count"] + delta_match_count) else 0 193 | 194 | y_results_len[i]["match_count"] += delta_match_count 195 | y_results_len[i]["parse_count"] += delta_parse_count 196 | 197 | parse = tree.Tree.from_str(z_best_parse) 198 | parse_brackets = brackets(parse) 199 | delta_parse_count = sum(parse_brackets.values()) 200 | 201 | delta_match_count = 0 202 | for bracket, count in parse_brackets.items(): 203 | delta_match_count += min(count, gold_brackets[bracket]) 204 | z_curr_f1 = (2. / (z_results_len[i]["gold_count"] / 205 | float(z_results_len[i]["match_count"] + delta_match_count) + 206 | (z_results_len[i]["parse_count"] + delta_parse_count) / 207 | float(z_results_len[i]["match_count"] + delta_match_count))) \ 208 | if float(z_results_len[i]["match_count"] + delta_match_count) else 0 209 | 210 | z_results_len[i]["match_count"] += delta_match_count 211 | z_results_len[i]["parse_count"] += delta_parse_count 212 | 213 | pbar.set_description( 214 | "y - curr acc: {:.4f}, f1: {:.4f}, " 215 | "z - curr acc: {:.4f}, f1: {:.4f}".format( 216 | (np.array(y_tot_pred[i]) == np.array(tot_label[i]) 217 | .astype("float32")).mean(), y_curr_f1, 218 | (np.array(z_tot_pred[i]) == np.array(tot_label[i]) 219 | .astype("float32")).mean(), z_curr_f1)) 220 | pbar.close() 221 | 222 | e.log.info( 223 | "y - curr acc: {:.4f}, f1: {:.4f}, " 224 | "z - curr acc: {:.4f}, f1: {:.4f}".format( 225 | (np.array(y_tot_pred[i]) == np.array(tot_label[i]) 226 | .astype("float32")).mean(), y_curr_f1, 227 | (np.array(z_tot_pred[i]) == np.array(tot_label[i]) 228 | .astype("float32")).mean(), z_curr_f1)) 229 | 230 | e.log.info("*" * 25 + " 
EVAL y " + "*" * 25) 231 | e.log.info("*" * 25 + " POS Acc " + "*" * 25) 232 | for i in range(1, MAX_LEN): 233 | e.log.info("length: {}, acc: {:.4f}" 234 | .format(i, (np.array(y_tot_pred[i]) == 235 | np.array(tot_label[i]).astype("float32")).mean())) 236 | 237 | e.log.info("*" * 25 + " Labeled F1 " + "*" * 25) 238 | tot_match = tot_parse = tot_gold = 0 239 | for lens, d in sorted(y_results_len.items()): 240 | e.log.info("length: {}, F1: {:.4f}" 241 | .format(lens, (2. / (d["gold_count"] / float(d["match_count"]) + 242 | d["parse_count"] / float(d["match_count"]))))) 243 | tot_match += d["match_count"] 244 | tot_parse += d["parse_count"] 245 | tot_gold += d["gold_count"] 246 | 247 | all_pred = sum(y_tot_pred.values(), []) 248 | all_label = sum(tot_label.values(), []) 249 | e.log.info("POS Acc: {:.4f}".format( 250 | (np.array(all_pred) == np.array(all_label)).astype("float32").mean())) 251 | 252 | e.log.info( 253 | "Labeled F1: {:.4f}".format( 254 | 2. / (tot_gold / float(tot_match) + tot_parse / float(tot_match)))) 255 | 256 | e.log.info("*" * 25 + " EVAL z " + "*" * 25) 257 | e.log.info("*" * 25 + " POS Acc " + "*" * 25) 258 | 259 | for i in range(1, MAX_LEN): 260 | e.log.info("length: {}, acc: {:.4f}" 261 | .format(i, (np.array(z_tot_pred[i]) == 262 | np.array(tot_label[i]) 263 | .astype("float32")).mean())) 264 | 265 | e.log.info("*" * 25 + " Labeled F1 " + "*" * 25) 266 | tot_match = tot_parse = tot_gold = 0 267 | for lens, d in sorted(z_results_len.items()): 268 | e.log.info("length: {}, F1: {:.4f}" 269 | .format(lens, (2. 
/ (d["gold_count"] / float(d["match_count"]) + 270 | d["parse_count"] / float(d["match_count"]))))) 271 | tot_match += d["match_count"] 272 | tot_parse += d["parse_count"] 273 | tot_gold += d["gold_count"] 274 | 275 | all_pred = sum(z_tot_pred.values(), []) 276 | all_label = sum(tot_label.values(), []) 277 | e.log.info("POS Acc: {:.4f}".format( 278 | (np.array(all_pred) == np.array(all_label)).astype("float32").mean())) 279 | 280 | e.log.info( 281 | "Labeled F1: {:.4f}".format(2. / (tot_gold / float(tot_match) + 282 | tot_parse / float(tot_match)))) 283 | -------------------------------------------------------------------------------- /eval_sts.py: -------------------------------------------------------------------------------- 1 | import nltk 2 | import pickle 3 | import argparse 4 | 5 | import torch 6 | import models 7 | import data_utils 8 | import train_helper 9 | 10 | import numpy as np 11 | 12 | from tqdm import tqdm 13 | from scipy.stats import pearsonr 14 | 15 | 16 | MAX_LEN = 30 17 | batch_size = 1000 18 | 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--save_file', '-s', type=str) 22 | parser.add_argument('--vocab_file', '-v', type=str) 23 | parser.add_argument('--sts_file', '-d', type=str, default="sts-test.csv") 24 | args = parser.parse_args() 25 | 26 | 27 | def cosine_similarity(v1, v2): 28 | prod = (v1 * v2).sum(-1) 29 | v1_norm = (v1 ** 2).sum(-1) ** 0.5 30 | v2_norm = (v2 ** 2).sum(-1) ** 0.5 31 | return prod / (v1_norm * v2_norm) 32 | 33 | 34 | save_dict = torch.load( 35 | args.save_file, 36 | map_location=lambda storage, 37 | loc: storage) 38 | 39 | config = save_dict['config'] 40 | checkpoint = save_dict['state_dict'] 41 | config.debug = True 42 | 43 | with open(args.vocab_file, "rb") as fp: 44 | W, vocab = pickle.load(fp) 45 | 46 | sent1 = [] 47 | sent2 = [] 48 | gold_score = [] 49 | with open(args.sts_file) as fp: 50 | for i, line in enumerate(fp): 51 | d = line.strip().split("\t") 52 | score, s1, s2 = d[4], d[5], d[6] 
53 | sent1.append(nltk.word_tokenize(s1.lower())) 54 | sent2.append(nltk.word_tokenize(s2.lower())) 55 | gold_score.append(float(score)) 56 | 57 | with train_helper.experiment(config, config.save_prefix) as e: 58 | e.log.info("vocab loaded from: {}".format(args.vocab_file)) 59 | e.log.info("data loaded from: {}".format(args.sts_file)) 60 | model = models.vgvae( 61 | vocab_size=len(vocab), 62 | embed_dim=e.config.edim if W is None else W.shape[1], 63 | embed_init=W, 64 | experiment=e) 65 | model.eval() 66 | model.load(checkpointed_state_dict=checkpoint) 67 | e.log.info(model) 68 | 69 | def encode(d): 70 | global vocab 71 | new_d = [[vocab.get(w, 0) for w in s] for s in d] 72 | all_y_vecs = [] 73 | all_z_vecs = [] 74 | 75 | for s1, _, m1, s2, _, m2, _, _, _, _, _ in \ 76 | tqdm(data_utils.minibatcher( 77 | data1=np.array(new_d), 78 | tag1=np.array(new_d), 79 | data2=np.array(new_d), 80 | tag2=np.array(new_d), 81 | tag_bucket=None, 82 | batch_size=100, 83 | p_replace=0., 84 | shuffle=False, 85 | p_scramble=0.)): 86 | with torch.no_grad(): 87 | semantics, semantics_mask, synatx, syntax_mask = \ 88 | model.to_tensors(s1, m1, s2, m2) 89 | _, yvecs = model.yencode(semantics, semantics_mask) 90 | _, zvecs = model.zencode(synatx, syntax_mask) 91 | 92 | ymean = model.mean1(yvecs) 93 | ymean = ymean / ymean.norm(dim=-1, keepdim=True) 94 | zmean = model.mean2(zvecs) 95 | 96 | all_y_vecs.append(ymean.cpu().numpy()) 97 | all_z_vecs.append(zmean.cpu().numpy()) 98 | return np.concatenate(all_y_vecs), np.concatenate(all_z_vecs) 99 | 100 | s1y, s1z = encode(sent1) 101 | s2y, s2z = encode(sent2) 102 | yscore = pearsonr(cosine_similarity(s1y, s2y), gold_score)[0] 103 | zscore = pearsonr(cosine_similarity(s1z, s2z), gold_score)[0] 104 | e.log.info("y score: {:.4f}, z score: {:.4f}".format(yscore, zscore)) 105 | -------------------------------------------------------------------------------- /eval_utils.py: -------------------------------------------------------------------------------- 
def enc(s):
    """Encode the text *s* as UTF-8 bytes."""
    return s.encode('utf-8')


def dec(s):
    """Decode the UTF-8 bytes *s* into text."""
    return s.decode('utf-8')


class Meteor:
    """Drive a METEOR jar over stdin/stdout to score hypotheses against references."""

    def __init__(self):
        # Used to guarantee thread safety.  Created before the process so
        # that __del__ finds it even when launching java fails.
        self.lock = threading.Lock()
        self.meteor_cmd = ['java', '-jar', '-Xmx2G', METEOR_JAR,
                           '-', '-', '-stdio', '-l', 'en', '-norm', '-a',
                           METEOR_DATA]
        self.meteor_p = subprocess.Popen(
            self.meteor_cmd,
            cwd=os.path.dirname(os.path.abspath(__file__)),
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE)

    def _send(self, line):
        """Write one protocol line to the jar and flush it."""
        self.meteor_p.stdin.write(enc(line + "\n"))
        self.meteor_p.stdin.flush()

    def _read_line(self):
        """Read one reply line from the jar as stripped text."""
        return dec(self.meteor_p.stdout.readline()).strip()

    @staticmethod
    def _score_line(hypothesis_str, reference_list):
        """Build a SCORE request; '|||' is the field separator, so strip it from the hypothesis."""
        # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words
        # Removing '|||' can leave a double space; collapse it to one.
        hypothesis_str = hypothesis_str.replace('|||', '').replace('  ', ' ')
        return ' ||| '.join(
            ('SCORE', ' ||| '.join(reference_list), hypothesis_str))

    def compute_score(self, gts, res):
        """Score every hypothesis in *res* against its references in *gts*.

        Both arguments are mappings with identical keys; each ``res`` value
        must hold exactly one hypothesis string and each ``gts`` value is a
        list of reference strings.  Returns ``(score, scores)``: the last
        value printed by the jar and the list of per-item float scores.
        """
        assert gts.keys() == res.keys()
        imgIds = list(gts.keys())

        # The with-block releases the lock even if an assert or pipe error fires.
        with self.lock:
            stats = []
            for i in imgIds:
                assert len(res[i]) == 1
                stats.append(self._stat(res[i][0], gts[i]))

            self._send(' ||| '.join(['EVAL'] + stats))
            # Decode first, then parse: the reply is bytes, not a float.
            scores = [float(self._read_line()) for _ in imgIds]
            score = float(self._read_line())

        return score, scores

    def _stat(self, hypothesis_str, reference_list):
        """Send one SCORE request and return the jar's statistics line (caller holds the lock)."""
        self._send(self._score_line(hypothesis_str, reference_list))
        return self._read_line()

    def _score(self, hypothesis_str, reference_list):
        """Return the METEOR score of one hypothesis against its references."""
        with self.lock:
            self._send(self._score_line(hypothesis_str, reference_list))
            stats = self._read_line()
            # EVAL ||| stats
            self._send('EVAL ||| {}'.format(stats))
            # The jar prints two values for an EVAL request (one average,
            # one overall); read both and return the second.
            self._read_line()
            score = float(self._read_line())
        return score

    def __del__(self):
        proc = getattr(self, "meteor_p", None)
        lock = getattr(self, "lock", None)
        if proc is None or lock is None:
            # __init__ did not finish; there is nothing to shut down.
            return
        with lock:
            proc.stdin.close()
            proc.kill()
            proc.wait()
| for idx, line in enumerate(f): 116 | if idx <= 1: 117 | continue 118 | if line.startswith('Sentence #'): 119 | new_sent = True 120 | new_pos = False 121 | new_parse = False 122 | new_deps = False 123 | if idx == 2: 124 | continue 125 | 126 | sentences.append(data) 127 | count += 1 128 | 129 | data = {'tokens': [], 'pos': [], 'parse': '', 'deps': []} 130 | 131 | # read original sentence 132 | elif new_sent: 133 | new_sent = False 134 | new_pos = True 135 | 136 | elif new_pos and line.startswith("Tokens"): 137 | continue 138 | 139 | # read POS tags 140 | elif new_pos and line.startswith('[Text='): 141 | line = line.strip().split() 142 | w = line[0].split('[Text=')[-1] 143 | pos = line[-1].split('PartOfSpeech=')[-1][:-1] 144 | data['tokens'].append(w) 145 | data['pos'].append(pos) 146 | 147 | # start reading const parses 148 | elif (new_pos or new_parse) and len(line.strip()): 149 | if line.startswith("Constituency parse"): 150 | continue 151 | new_pos = False 152 | new_parse = True 153 | data['parse'] += ' ' + line.strip() 154 | 155 | # start reading deps 156 | elif (new_parse and line.strip() == "") or \ 157 | line.startswith("Dependency Parse"): 158 | new_parse = False 159 | new_deps = True 160 | 161 | elif new_deps and len(line.strip()): 162 | line = line.strip()[:-1].split('(', 1) 163 | rel = line[0] 164 | x1, x2 = line[1].split(', ') 165 | x1 = x1.replace("'", "") 166 | x2 = x2.replace("'", "") 167 | x1 = int(x1.rsplit('-', 1)[-1]) 168 | x2 = int(x2.rsplit('-', 1)[-1]) 169 | data['deps'].append((rel, x1 - 1, x2 - 1)) 170 | 171 | else: 172 | new_deps = False 173 | 174 | sentences.append(data) 175 | 176 | return sentences 177 | 178 | 179 | class stanford_parsetree_extractor: 180 | def __init__(self): 181 | self.stanford_corenlp_path = os.path.join(STANFORD_CORENLP, "*") 182 | print("standford corenlp path:", self.stanford_corenlp_path) 183 | self.output_dir = tempfile.TemporaryDirectory() 184 | self.cmd = ['java', '-cp', self.stanford_corenlp_path, 185 | 
# Command-line script: load a trained vgvae checkpoint, decode one sentence
# per "semantic input <TAB> syntactic input" line with beam search, write
# the generations to a temporary file and optionally score them with eval.py.
parser = argparse.ArgumentParser()
parser.add_argument('--save_file', '-s', type=str)
parser.add_argument('--vocab_file', '-v', type=str)
parser.add_argument('--input_file', '-i', type=str)
parser.add_argument('--ref_file', '-r', type=str)
parser.add_argument('--beam_size', '-bs', type=int, default=10)
args = parser.parse_args()

# NOTE(review): torch.load and pickle.load both unpickle arbitrary objects;
# only point this script at checkpoints/vocabularies from a trusted source.
# map_location keeps every tensor on CPU storage while loading.
save_dict = torch.load(
    args.save_file,
    map_location=lambda storage, loc: storage)

config = save_dict['config']
checkpoint = save_dict['state_dict']
config.debug = True

with open(args.vocab_file, "rb") as fp:
    W, vocab = pickle.load(fp)
inv_vocab = {i: w for w, i in vocab.items()}

# Checkpoints saved with decoder type "lstm" are mapped onto "lstm_z2y".
# NOTE(review): the original indentation is not recoverable from this view;
# the two resets below are assumed to belong to this branch -- confirm.
if config.decoder_type == "lstm":
    config.decoder_type = "lstm_z2y"
    config.ncode = None
    config.nclass = None

with train_helper.experiment(config, config.save_prefix) as e:
    e.log.info("vocab loaded from: {}".format(args.vocab_file))
    model = models.vgvae(
        vocab_size=len(vocab),
        embed_dim=e.config.edim if W is None else W.shape[1],
        embed_init=W,
        experiment=e)
    model.load(checkpointed_state_dict=checkpoint)
    e.log.info(model)
    # Inference only: switch to eval mode once instead of once per batch.
    model.eval()

    semantics_input = []
    syntax_input = []
    e.log.info("loading from: {}".format(args.input_file))
    with open(args.input_file) as fp:
        for line in fp:
            seman_in, syn_in = line.strip().split("\t")
            # words missing from the vocabulary map to index 0
            semantics_input.append(
                [vocab.get(w.lower(), 0) for w in
                 seman_in.strip().split(" ")])
            syntax_input.append(
                [vocab.get(w.lower(), 0) for w in
                 syn_in.strip().split(" ")])
    e.log.info("#evaluation data: {}, {}".format(
        len(semantics_input),
        len(syntax_input)))

    # delete=False: the file must outlive this handle so eval.py can read
    # it.  The context manager closes (and thereby flushes) the handle
    # before the evaluation subprocess starts.
    with tempfile.NamedTemporaryFile(mode='w+', delete=False) as tf:
        e.log.info('generation saving to {}'.format(tf.name))
        e.log.info('beam size: {}'.format(args.beam_size))
        # NOTE(review): np.array on rows of different lengths raises on
        # recent NumPy releases; verify against data_utils.minibatcher
        # before changing the dtype here.
        for s1, _, m1, s2, _, m2, _, _, _, _, _ in \
                tqdm(data_utils.minibatcher(
                    data1=np.array(semantics_input),
                    tag1=np.array(semantics_input),
                    data2=np.array(syntax_input),
                    tag2=np.array(syntax_input),
                    tag_bucket=None,
                    batch_size=1,
                    p_replace=0.,
                    shuffle=False,
                    p_scramble=0.)):
            with torch.no_grad():
                semantics, semantics_mask, syntax, syntax_mask = \
                    model.to_tensors(s1, m1, s2, m2)
                _, yvecs = model.yencode(semantics, semantics_mask)
                _, zvecs = model.zencode(syntax, syntax_mask)

                # the semantic mean is scaled to unit length
                ymean = model.mean1(yvecs)
                ymean = ymean / ymean.norm(dim=-1, keepdim=True)
                zmean = model.mean2(zvecs)

                generate_function = get_gen_fn(
                    model.decode.step, ymean, zmean)
                initial_state = None
                if e.config.decoder_type.lower() == "yz_lstm":
                    init_vecs = model.decode.latent2init(
                        torch.cat([ymean, zmean], -1))
                    initial_state = tuple(
                        h.unsqueeze(0).contiguous() for h in
                        torch.chunk(init_vecs, 2, -1))
                gen = beam_search(
                    initial_state=initial_state,
                    generate_function=generate_function,
                    start_id=BOS_IDX,
                    end_id=EOS_IDX,
                    beam_width=args.beam_size)[0][0]

            # skip the first id of the hypothesis and stop at the first EOS
            curr_gen = []
            for i in gen[1:]:
                if i == EOS_IDX:
                    break
                curr_gen.append(inv_vocab[int(i)])

            tf.write(" ".join(curr_gen) + "\n")

    if args.ref_file is not None:
        e.log.info('running eval.py using reference file {}'
                   .format(args.ref_file))
        # sys.executable: run eval.py with the interpreter (and therefore
        # the environment) that is running this script, not whatever
        # "python" resolves to on PATH.
        # NOTE(review): the child's stdout goes to our stderr and vice
        # versa; kept as in the original -- confirm this swap is intended.
        subprocess.run(
            [sys.executable, 'eval.py', '-i', tf.name, '-r', args.ref_file],
            cwd=os.path.dirname(os.path.abspath(__file__)),
            stdout=sys.stderr,
            stderr=sys.stdout)
def get_mlp(input_size, hidden_size, output_size, n_layer, dropout):
    """Build a feed-forward projection as an ``nn.Sequential``.

    With ``n_layer == 0`` the result is a single linear map from
    ``input_size`` to ``output_size``.  Otherwise it is one
    Linear/ReLU/Dropout stage from ``input_size`` to ``hidden_size``,
    ``n_layer - 1`` further Linear/ReLU/Dropout stages of width
    ``hidden_size``, and a final linear map to ``output_size``.

    Sub-modules are named by their position ("0", "1", ...).
    """
    if n_layer == 0:
        return nn.Sequential(nn.Linear(input_size, output_size))

    # Layers are instantiated in exactly the order they appear in the stack.
    stack = [nn.Linear(input_size, hidden_size),
             nn.ReLU(),
             nn.Dropout(dropout)]
    for _ in range(n_layer - 1):
        stack.append(nn.Linear(hidden_size, hidden_size))
        stack.append(nn.ReLU())
        stack.append(nn.Dropout(dropout))
    stack.append(nn.Linear(hidden_size, output_size))

    return nn.Sequential(*stack)
get_last=False, 35 | filter_zero_length=False): 36 | """ 37 | Args: 38 | inputs: batch_size x seq_len x n_feat 39 | mask: batch_size x seq_len 40 | initial_state: batch_size x num_layers x hidden_size 41 | cell: GRU/LSTM/RNN 42 | """ 43 | if bidir: 44 | seq_lengths = torch.sum(mask, dim=-1) 45 | sorted_len, sorted_idx = seq_lengths.sort(0, descending=True) 46 | sorted_inputs = inputs[sorted_idx.long()] 47 | sorted_len = sorted_len.long().cpu().numpy() 48 | if filter_zero_length: 49 | sorted_len[sorted_len == 0] = 1 50 | packed_seq = torch.nn.utils.rnn.pack_padded_sequence( 51 | sorted_inputs, sorted_len, batch_first=True) 52 | if initial_state is not None: 53 | if isinstance(cell, torch.nn.LSTM): 54 | initial_state = \ 55 | (initial_state[0].index_select(1, sorted_idx.long()), 56 | initial_state[1].index_select(1, sorted_idx.long())) 57 | else: 58 | initial_state = \ 59 | initial_state.index_select(1, sorted_idx.long()) 60 | out, hid = cell(packed_seq, hx=initial_state) 61 | unpacked, unpacked_len = \ 62 | torch.nn.utils.rnn.pad_packed_sequence( 63 | out, batch_first=True) 64 | _, original_idx = sorted_idx.sort(0, descending=False) 65 | output_seq = unpacked[original_idx.long()] 66 | if get_last: 67 | if isinstance(hid, tuple): 68 | if bidir: 69 | hid = tuple([torch.cat([h[0:h.size(0):2], h[1:h.size(0):2]], 2) 70 | for h in hid]) 71 | hid = \ 72 | (hid[0].index_select(1, original_idx.long()), 73 | hid[1].index_select(1, original_idx.long())) 74 | else: 75 | if bidir: 76 | hid = torch.cat( 77 | [hid[0:hid.size(0):2], hid[1:hid.size(0):2]], 2) 78 | hid = hid.index_select(1, original_idx.long()) 79 | else: 80 | output_seq, hid = cell(inputs, hx=initial_state) 81 | return output_seq, hid 82 | 83 | 84 | def gauss_kl_div(mean, logvar, eps=1e-8): 85 | """KL(p||N(0,1)) 86 | args: 87 | mean: batch size x * x dimension 88 | logvar: batch size x * x dimension 89 | 90 | return: 91 | KL divergence: batch size x * 92 | """ 93 | return -0.5 * (1 + logvar - mean.pow(2) - 
# Width of the position-prediction output: pos_loss pads every mask to this
# many positions.
MAX_LEN = 32


class base(nn.Module):
    """Shared utilities for the models: tensors, optimizer, checkpoints.

    ``self.opt`` (the optimizer used by ``optimize``, ``save`` and ``load``)
    is not created in this class.
    NOTE(review): presumably it is attached by a subclass decorator using
    ``init_optimizer`` -- confirm.
    """

    def __init__(self, vocab_size, embed_dim, embed_init, experiment):
        """Store the experiment handle and pick CUDA when it is available.

        ``vocab_size``, ``embed_dim`` and ``embed_init`` are accepted for
        signature compatibility with subclasses and are not used here.
        """
        super(base, self).__init__()
        self.expe = experiment
        self.eps = self.expe.config.eps
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')

    def pos_loss(self, mask, vecs, func):
        """Mean negative log-probability of each position predicting itself.

        Args:
            mask: batch size x seq len, 1 for real positions and 0 for
                padding.
            vecs: input handed to ``func``.
            func: callable returning logits of shape
                batch size x seq len x MAX_LEN (modified in place).

        Returns:
            Scalar tensor: per-sentence loss averaged over the unmasked
            positions, then averaged over the batch.
        """
        batch_size, seq_len = mask.size()
        # batch size x seq len x MAX LEN
        logits = func(vecs)
        if MAX_LEN - seq_len:
            padded = torch.zeros(
                batch_size, MAX_LEN - seq_len).to(mask.device)
            new_mask = 1 - torch.cat([mask, padded], -1)
        else:
            new_mask = 1 - mask
        new_mask = new_mask.unsqueeze(1).expand_as(logits)
        # A comparison yields the mask dtype the installed PyTorch expects
        # (bool on current releases); a ``.byte()`` mask is rejected by
        # ``masked_fill_`` there.
        logits.masked_fill_(new_mask != 0, -float('inf'))
        # probability that position i is classified as position i
        positions = np.arange(int(seq_len))
        loss = F.softmax(logits, -1)[:, positions, positions]
        loss = -(loss + self.eps).log() * mask

        loss = loss.sum(-1) / mask.sum(1)
        return loss.mean()

    def sample_gaussian(self, mean, logvar):
        """Draw ``mean + exp(0.5 * logvar) * noise`` with standard normal noise."""
        sample = mean + torch.exp(0.5 * logvar) * \
            logvar.new_empty(logvar.size()).normal_()
        return sample

    def to_tensor(self, inputs):
        """Return ``inputs`` as a tensor on ``self.device`` (tensors are copied and detached)."""
        if torch.is_tensor(inputs):
            return inputs.clone().detach().to(self.device)
        else:
            return torch.tensor(inputs, device=self.device)

    def to_tensors(self, *inputs):
        """Apply ``to_tensor`` to each input; ``None`` and inputs with a falsy ``.size`` become ``None``."""
        return [self.to_tensor(inputs_)
                if inputs_ is not None and inputs_.size else None
                for inputs_ in inputs]

    def optimize(self, loss):
        """Run one optimizer step on ``loss``, clipping gradients if ``config.gclip`` is set."""
        self.opt.zero_grad()
        loss.backward()
        if self.expe.config.gclip is not None:
            # in-place variant; the un-suffixed ``clip_grad_norm`` is
            # deprecated
            torch.nn.utils.clip_grad_norm_(
                self.parameters(), self.expe.config.gclip)
        self.opt.step()

    def init_optimizer(self, opt_type, learning_rate, weight_decay):
        """Create an optimizer over the trainable parameters.

        Args:
            opt_type: "adam", "rmsprop" or "sgd" (case-insensitive).
            learning_rate: learning rate passed as ``lr``.
            weight_decay: weight decay passed to the optimizer.

        Raises:
            NotImplementedError: for any other ``opt_type``.
        """
        if opt_type.lower() == "adam":
            optimizer = torch.optim.Adam
        elif opt_type.lower() == "rmsprop":
            optimizer = torch.optim.RMSprop
        elif opt_type.lower() == "sgd":
            optimizer = torch.optim.SGD
        else:
            raise NotImplementedError("invalid optimizer: {}".format(opt_type))

        opt = optimizer(
            params=filter(
                lambda p: p.requires_grad, self.parameters()
            ),
            weight_decay=weight_decay,
            lr=learning_rate)

        return opt

    def save(self, dev_bleu, dev_stats, test_bleu, test_stats,
             epoch, iteration=None, name="best"):
        """Write model, optimizer state, scores and config to ``<experiment_dir>/<name>.ckpt``."""
        save_path = os.path.join(self.expe.experiment_dir, name + ".ckpt")
        checkpoint = {
            "dev_bleu": dev_bleu,
            "dev_stats": dev_stats,
            "test_bleu": test_bleu,
            "test_stats": test_stats,
            "epoch": epoch,
            "iteration": iteration,
            "state_dict": self.state_dict(),
            "opt_state_dict": self.opt.state_dict(),
            "config": self.expe.config
        }
        torch.save(checkpoint, save_path)
        self.expe.log.info("model saved to {}".format(save_path))

    def load(self, checkpointed_state_dict=None, name="best"):
        """Restore parameters and move the model to ``self.device``.

        Without ``checkpointed_state_dict`` the file
        ``<experiment_dir>/<name>.ckpt`` is read, the optimizer state is
        restored as well, and ``(epoch, iteration, dev_bleu, test_bleu)``
        is returned (missing entries default to 0).  With a state dict only
        the model parameters are loaded and nothing is returned.
        """
        if checkpointed_state_dict is None:
            save_path = os.path.join(self.expe.experiment_dir, name + ".ckpt")
            # load onto CPU storage first; moved to self.device below
            checkpoint = torch.load(save_path,
                                    map_location=lambda storage,
                                    loc: storage)
            self.load_state_dict(checkpoint['state_dict'])
            self.opt.load_state_dict(checkpoint.get("opt_state_dict"))
            self.expe.log.info("model loaded from {}".format(save_path))
            self.to(self.device)
            # optimizer state tensors were loaded on CPU as well
            for state in self.opt.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.to(self.device)
            self.expe.log.info("transferred model to {}".format(self.device))
            return checkpoint.get('epoch', 0), \
                checkpoint.get('iteration', 0), \
                checkpoint.get('dev_bleu', 0), \
                checkpoint.get('test_bleu', 0)
        else:
            self.load_state_dict(checkpointed_state_dict)
            self.expe.log.info("model loaded from checkpoint.")
            self.to(self.device)
            self.expe.log.info("transferred model to {}".format(self.device))
def sent2param(self, sent, sent_repl, mask):
    """Encode one sentence into the parameters of its two latent variables.

    ``sent`` goes through the y-encoder and ``sent_repl`` through the
    z-encoder, both with the same ``mask``.

    Returns:
        A 5-tuple ``(zembed, mean, var, mean2, logvar2)``: the first output
        of the z-encoder, the ``mean1`` projection of the y-vectors scaled
        to unit length along the last dimension, ``softplus(logvar1) + 1``,
        and the ``mean2`` / ``logvar2`` projections of the z-vectors.
    """
    _, sem_vecs = self.yencode(sent, mask)
    syn_embed, syn_vecs = self.zencode(sent_repl, mask)

    # y-side parameters: unit-length mean, and a second value that is
    # always greater than one
    raw_mean = self.mean1(sem_vecs)
    unit_mean = raw_mean / raw_mean.norm(dim=-1, keepdim=True)
    concentration = F.softplus(self.logvar1(sem_vecs)) + 1

    # z-side parameters: mean and log-variance, returned as projected
    syn_mean = self.mean2(syn_vecs)
    syn_logvar = self.logvar2(syn_vecs)

    return syn_embed, unit_mean, concentration, syn_mean, syn_logvar
def greedy_decode(self, semantics, semantics_mask,
                  synatx, syntax_mask, max_len):
    """Greedy-decode from the mean latent codes of a semantic/syntactic pair.

    The model is switched to eval mode, the four inputs are converted with
    ``to_tensors``, the semantic input is encoded with the y-encoder and
    the syntactic input with the z-encoder.  The ``mean1`` projection is
    scaled to unit length along the last dimension; the ``mean2``
    projection is used as is.

    Returns:
        Whatever ``self.decode.greedy_decode(ymean, zmean, max_len)``
        returns.
    """
    self.eval()

    syn_t, syn_mask_t, sem_t, sem_mask_t = self.to_tensors(
        synatx, syntax_mask, semantics, semantics_mask)

    sem_vecs = self.yencode(sem_t, sem_mask_t)[1]
    syn_vecs = self.zencode(syn_t, syn_mask_t)[1]

    sem_code = self.mean1(sem_vecs)
    sem_code = sem_code / sem_code.norm(dim=-1, keepdim=True)
    syn_code = self.mean2(syn_vecs)

    return self.decode.greedy_decode(sem_code, syn_code, max_len)
8 | --yencoder_type word_avg \ 9 | --zencoder_type bilstm \ 10 | --n_epoch 20 \ 11 | --train_path train.txt \ 12 | --train_tag_path train.tag \ 13 | --tag_vocab_file word2tag.pkl \ 14 | --embed_file glove.6B.100d.txt \ 15 | --embed_type glove \ 16 | --dev_inp_path dev_input.txt \ 17 | --dev_ref_path dev_ref.txt \ 18 | --test_inp_path test_input.txt \ 19 | --test_ref_path test_ref.txt \ 20 | --pre_train_emb 1 \ 21 | --vocab_file vocab \ 22 | --vocab_size 50000 \ 23 | --batch_size 30 \ 24 | --dropout 0.0 \ 25 | --l2 0.0 \ 26 | --word_replace 0.0 \ 27 | --max_vmf_kl_temp 1e-4 \ 28 | --max_gauss_kl_temp 1e-3 \ 29 | --zmlp_n_layer 2 \ 30 | --ymlp_n_layer 2 \ 31 | --mlp_n_layer 3 \ 32 | --para_logloss_ratio 1.0 \ 33 | --ploss_ratio 1.0 \ 34 | --mlp_hidden_size 100 \ 35 | --ysize 50 \ 36 | --zsize 50 \ 37 | --embed_dim 100 \ 38 | --encoder_size 100 \ 39 | --decoder_size 100 \ 40 | --p_scramble 0.0 \ 41 | --print_every 500 \ 42 | --eval_every 5000 \ 43 | --summarize 1 44 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import train_helper 2 | import data_utils 3 | import config 4 | 5 | import models 6 | from tensorboardX import SummaryWriter 7 | 8 | 9 | best_dev_bleu = test_bleu = 0 10 | 11 | 12 | def run(e): 13 | global best_dev_bleu, test_bleu 14 | 15 | e.log.info("*" * 25 + " DATA PREPARATION " + "*" * 25) 16 | dp = data_utils.data_processor( 17 | train_path=e.config.train_path, 18 | experiment=e) 19 | data, W = dp.process() 20 | 21 | e.log.info("*" * 25 + " DATA PREPARATION " + "*" * 25) 22 | e.log.info("*" * 25 + " MODEL INITIALIZATION " + "*" * 25) 23 | 24 | model = models.vgvae( 25 | vocab_size=len(data.vocab), 26 | embed_dim=e.config.edim if W is None else W.shape[1], 27 | embed_init=W, 28 | experiment=e) 29 | 30 | start_epoch = true_it = 0 31 | best_dev_stats = test_stats = None 32 | if e.config.resume: 33 | start_epoch, _, 
best_dev_bleu, test_bleu = \ 34 | model.load(name="latest") 35 | e.log.info( 36 | "resumed from previous checkpoint: start epoch: {}, " 37 | "iteration: {}, best dev bleu: {:.3f}, test bleu: {:.3f}, " 38 | .format(start_epoch, true_it, best_dev_bleu, test_bleu)) 39 | 40 | e.log.info(model) 41 | e.log.info("*" * 25 + " MODEL INITIALIZATION " + "*" * 25) 42 | 43 | if e.config.summarize: 44 | writer = SummaryWriter(e.experiment_dir) 45 | 46 | train_batch = data_utils.minibatcher( 47 | data1=data.train_data[0], 48 | tag1=data.train_tag[0], 49 | data2=data.train_data[1], 50 | tag2=data.train_tag[1], 51 | tag_bucket=data.tag_bucket, 52 | vocab_size=len(data.vocab), 53 | batch_size=e.config.batch_size, 54 | shuffle=True, 55 | p_replace=e.config.wr, 56 | p_scramble=e.config.ps) 57 | 58 | dev_eval = train_helper.evaluator( 59 | e.config.dev_inp_path, e.config.dev_ref_path, 60 | model, data.vocab, data.inv_vocab, e) 61 | test_eval = train_helper.evaluator( 62 | e.config.test_inp_path, e.config.test_ref_path, 63 | model, data.vocab, data.inv_vocab, e) 64 | 65 | e.log.info("Training start ...") 66 | train_stats = train_helper.tracker(["loss", "vmf_kl", "gauss_kl", 67 | "rec_logloss", "para_logloss", 68 | "wploss"]) 69 | 70 | for epoch in range(start_epoch, e.config.n_epoch): 71 | for it, (s1, sr1, m1, s2, sr2, m2, t1, tm1, t2, tm2, _) in \ 72 | enumerate(train_batch): 73 | true_it = it + 1 + epoch * len(train_batch) 74 | 75 | loss, kl, kl2, rec_logloss, para_logloss, wploss = \ 76 | model(s1, sr1, m1, s2, sr2, m2, t1, tm1, 77 | t2, tm2, e.config.vmkl, e.config.gmkl) 78 | model.optimize(loss) 79 | train_stats.update( 80 | {"loss": loss, "vmf_kl": kl, "gauss_kl": kl2, 81 | "para_logloss": para_logloss, "rec_logloss": rec_logloss, 82 | "wploss": wploss}, 83 | len(s1)) 84 | 85 | if (true_it + 1) % e.config.print_every == 0 or \ 86 | (true_it + 1) % len(train_batch) == 0: 87 | summarization = train_stats.summarize( 88 | "epoch: {}, it: {} (max: {}), kl_temp(v|g): {:.2E}|{:.2E}" 89 
| .format(epoch, it, len(train_batch), 90 | e.config.vmkl, e.config.gmkl)) 91 | e.log.info(summarization) 92 | if e.config.summarize: 93 | for name, value in train_stats.stats.items(): 94 | writer.add_scalar( 95 | "train/" + name, value, true_it) 96 | train_stats.reset() 97 | 98 | if (true_it + 1) % e.config.eval_every == 0 or \ 99 | (true_it + 1) % len(train_batch) == 0: 100 | e.log.info("*" * 25 + " DEV SET EVALUATION " + "*" * 25) 101 | 102 | dev_stats, dev_bleu = dev_eval.evaluate("gen_dev") 103 | 104 | e.log.info("*" * 25 + " DEV SET EVALUATION " + "*" * 25) 105 | 106 | if e.config.summarize: 107 | for name, value in dev_stats.items(): 108 | writer.add_scalar( 109 | "dev/" + name, value, true_it) 110 | 111 | if best_dev_bleu < dev_bleu: 112 | best_dev_bleu = dev_bleu 113 | best_dev_stats = dev_stats 114 | 115 | e.log.info("*" * 25 + " TEST SET EVALUATION " + "*" * 25) 116 | 117 | test_stats, test_bleu = test_eval.evaluate("gen_test") 118 | 119 | e.log.info("*" * 25 + " TEST SET EVALUATION " + "*" * 25) 120 | 121 | model.save( 122 | dev_bleu=best_dev_bleu, 123 | dev_stats=best_dev_stats, 124 | test_bleu=test_bleu, 125 | test_stats=test_stats, 126 | iteration=true_it, 127 | epoch=epoch) 128 | 129 | if e.config.summarize: 130 | for name, value in test_stats.items(): 131 | writer.add_scalar( 132 | "test/" + name, value, true_it) 133 | 134 | e.log.info("best dev bleu: {:.4f}, test bleu: {:.4f}" 135 | .format(best_dev_bleu, test_bleu)) 136 | 137 | model.save( 138 | dev_bleu=best_dev_bleu, 139 | dev_stats=best_dev_stats, 140 | test_bleu=test_bleu, 141 | test_stats=test_stats, 142 | iteration=true_it, 143 | epoch=epoch + 1, 144 | name="latest") 145 | 146 | time_per_epoch = (e.elapsed_time / (epoch - start_epoch + 1)) 147 | time_in_need = time_per_epoch * (e.config.n_epoch - epoch - 1) 148 | e.log.info("elapsed time: {:.2f}(h), " 149 | "time per epoch: {:.2f}(h), " 150 | "time needed to finish: {:.2f}(h)" 151 | .format(e.elapsed_time, time_per_epoch, time_in_need)) 152 
def run_multi_bleu(input_file, reference_file):
    """Score ``input_file`` against ``reference_file`` with the BLEU script.

    Runs ``./<MULTI_BLEU_PERL> -lc <reference_file>`` with the contents of
    ``input_file`` on standard input and parses the number that follows
    ``=`` in the first comma-separated field of the last output line.

    The command is passed as an argument list and the hypotheses are fed
    through ``stdin`` instead of a formatted ``shell=True`` string, so
    paths containing spaces or shell metacharacters are handled literally
    and cannot inject commands.

    Args:
        input_file: path of the file with one hypothesis per line.
        reference_file: path of the reference file.

    Returns:
        The BLEU score as a float.

    Raises:
        subprocess.CalledProcessError: if the script exits with a non-zero
            status.
    """
    with open(input_file, "rb") as hyp_fp:
        bleu_output = subprocess.check_output(
            ["./{}".format(MULTI_BLEU_PERL), "-lc", reference_file],
            stdin=hyp_fp,
            stderr=subprocess.STDOUT).decode('utf-8')
    # last line: "<name> = <score>, ..." -> text after "= " up to the comma
    bleu = float(
        bleu_output.strip().split("\n")[-1]
        .split(",")[0].split("=")[1][1:])
    return bleu
class tracker:
    """Accumulate count-weighted sums of named values and report averages.

    NOTE(review): ``self.names`` is never assigned in this class;
    presumably ``auto_init_args`` stores the constructor argument on the
    instance -- confirm.
    """

    @auto_init_args
    def __init__(self, names):
        assert len(names) > 0
        self.reset()

    def __getitem__(self, name):
        """Average of ``name`` so far; 0 while nothing has been counted."""
        if not self.counter:
            return 0
        return self.values.get(name, 0) / self.counter

    def __len__(self):
        return len(self.names)

    def reset(self):
        """Zero every running sum and the counter, and restart the clock."""
        self.values = dict.fromkeys(self.names, 0.)
        self.counter = 0
        self.create_time = time.time()

    def update(self, named_values, count):
        """
        named_values: dictionary with each item as name: value, where
        value exposes ``.item()``; each value is weighted by ``count``
        """
        self.counter += count
        for key in named_values:
            self.values[key] += named_values[key].item() * count

    def summarize(self, output=""):
        """Return ``output`` followed by "name: average" pairs and the elapsed seconds."""
        text = output
        if text:
            text += ", "
        total = self.counter
        text += "".join(
            "{}: {:.3f}, ".format(
                key, self.values[key] / total if total else 0)
            for key in self.names)
        elapsed = time.time() - self.create_time
        text += "elapsed time: {:.1f}(s)".format(elapsed)
        return text

    @property
    def stats(self):
        """Mapping of every tracked name to its current average (0 when empty)."""
        total = self.counter
        if not total:
            return {key: 0 for key in self.values}
        return {key: acc / total for key, acc in self.values.items()}
This can be useful to 100 | separate logs for different runs on the same experiment 101 | experiments_prefix: str, a prefix to the path where 102 | experiment will be saved 103 | """ 104 | 105 | # get all defaults 106 | all_defaults = {} 107 | for key in vars(config): 108 | all_defaults[key] = get_base_parser().get_default(key) 109 | 110 | self.default_config = all_defaults 111 | 112 | config.resume = False 113 | if not config.debug: 114 | if os.path.isdir(self.experiment_dir): 115 | print("log exists: {}".format(self.experiment_dir)) 116 | config.resume = True 117 | 118 | print(config) 119 | self._makedir() 120 | 121 | self._make_misc_dir() 122 | 123 | def _makedir(self): 124 | os.makedirs(self.experiment_dir, exist_ok=True) 125 | 126 | def _make_misc_dir(self): 127 | os.makedirs(self.config.vocab_file, exist_ok=True) 128 | 129 | @property 130 | def experiment_dir(self): 131 | if self.config.debug: 132 | return "./" 133 | else: 134 | # get namespace for each group of args 135 | arg_g = dict() 136 | for group in get_base_parser()._action_groups: 137 | group_d = {a.dest: self.default_config.get(a.dest, None) 138 | for a in group._group_actions} 139 | arg_g[group.title] = argparse.Namespace(**group_d) 140 | 141 | # skip default value 142 | identifier = "" 143 | for key, value in sorted(vars(arg_g["model_configs"]).items()): 144 | if getattr(self.config, key) != value: 145 | identifier += key + str(getattr(self.config, key)) 146 | return os.path.join(self.experiments_prefix, identifier) 147 | 148 | @property 149 | def log_file(self): 150 | return os.path.join(self.experiment_dir, self.logfile_name) 151 | 152 | def register_directory(self, dirname): 153 | directory = os.path.join(self.experiment_dir, dirname) 154 | os.makedirs(directory, exist_ok=True) 155 | setattr(self, dirname, directory) 156 | 157 | def _register_existing_directories(self): 158 | for item in os.listdir(self.experiment_dir): 159 | fullpath = os.path.join(self.experiment_dir, item) 160 | if 
class evaluator:
    """Generate sentences for a fixed evaluation set and score them.

    ``self.model`` and ``self.inv_vocab`` are read in ``evaluate`` but never
    assigned here; presumably ``auto_init_args`` stores the constructor
    arguments as attributes -- confirm in decorators.py.
    """

    @auto_init_args
    def __init__(self, inp_path, ref_path, model, vocab, inv_vocab,
                 experiment):
        """Load the evaluation inputs and the reference sentences.

        Args:
            inp_path: file with one "<semantic input>\\t<syntactic input>"
                pair per line; tokens are separated by single spaces.
            ref_path: file with one reference sentence per line.
            model: object providing ``eval()`` and ``greedy_decode(...)``.
            vocab: mapping from lower-cased token to index; unknown tokens
                are mapped to index 0.
            inv_vocab: lookup from index back to token.
            experiment: object whose ``log`` and ``experiment_dir`` are used.

        Raises:
            ValueError: if an input line does not split into exactly two
                tab-separated fields.
        """
        self.expe = experiment
        self.ref_path = ref_path
        self.semantics_input = []
        self.syntax_input = []
        self.references = []
        self.expe.log.info("loading eval input from: {}".format(inp_path))
        with open(inp_path) as fp:
            for line in fp:
                seman_in, syn_in = line.strip().split("\t")
                self.semantics_input.append(
                    [vocab.get(w.lower(), 0)
                     for w in seman_in.strip().split(" ")])
                self.syntax_input.append(
                    [vocab.get(w.lower(), 0)
                     for w in syn_in.strip().split(" ")])

        self.expe.log.info("loading reference from: {}".format(ref_path))
        with open(ref_path) as fp:
            for line in fp:
                self.references.append(line.strip().lower())
        self.expe.log.info("#data: {}, {}, {}".format(
            len(self.semantics_input),
            len(self.syntax_input), len(self.references)))

    def evaluate(self, gen_fn):
        """Decode every evaluation pair greedily, save and score the output.

        The generated sentences are written one per line to
        "<experiment_dir>/<gen_fn>.txt".

        Args:
            gen_fn: file name (without extension) for the generated output.

        Returns:
            A pair ``(stats, stats['bleu'])`` where ``stats`` has the keys
            "bleu", "rouge1", "rouge2", "rougel" and "meteor". METEOR and
            ROUGE are averaged over sentences and multiplied by 100.

        Raises:
            AssertionError: if the number of generated sentences differs
                from the number of references.
        """
        self.model.eval()
        meteor = Meteor()
        rouge_eval = rouge.Rouge(metrics=['rouge-n', 'rouge-l'],
                                 max_n=2,
                                 limit_length=True,
                                 length_limit=MAX_GEN_LEN,
                                 length_limit_type='words',
                                 apply_avg=False,
                                 apply_best=False,
                                 alpha=0.5,  # Default F1_score
                                 weight_factor=1.2,
                                 stemming=True)
        stats = {"bleu": 0, "rouge1": 0,
                 "rouge2": 0, "rougel": 0, "meteor": 0}
        all_gen = []
        # NOTE(review): the inputs are lists of token-id lists of varying
        # length; recent NumPy releases refuse to build an array from ragged
        # nested lists without dtype=object -- confirm what minibatcher
        # expects before changing this.
        for s1, _, m1, s2, _, m2, _, _, _, _, _ in \
                data_utils.minibatcher(
                    data1=np.array(self.semantics_input),
                    tag1=np.array(self.semantics_input),
                    data2=np.array(self.syntax_input),
                    tag2=np.array(self.syntax_input),
                    tag_bucket=None,
                    batch_size=100,
                    p_replace=0.,
                    shuffle=False,
                    p_scramble=0.):
            with torch.no_grad():
                batch_gen = self.model.greedy_decode(
                    s1, m1, s2, m2, MAX_GEN_LEN)
            for gen in batch_gen:
                # Keep tokens up to, but not including, the first EOS.
                curr_gen = []
                for i in gen:
                    if i == EOS_IDX:
                        break
                    curr_gen.append(self.inv_vocab[int(i)])
                all_gen.append(" ".join(curr_gen))
        assert len(all_gen) == len(self.references), \
            "{} != {}".format(len(all_gen), len(self.references))
        fn = os.path.join(self.expe.experiment_dir, gen_fn + ".txt")
        with open(fn, "w+") as fp:
            for hyp, ref in zip(all_gen, self.references):
                fp.write(hyp)
                fp.write("\n")
                try:
                    stats['meteor'] += meteor._score(hyp, [ref])
                except ValueError:
                    # A ValueError from the scorer counts as a score of 0
                    # for this sentence.
                    stats['meteor'] += 0
                rs = rouge_eval.get_scores([hyp], [ref])
                stats['rouge1'] += rs['rouge-1'][0]['f'][0]
                stats['rouge2'] += rs['rouge-2'][0]['f'][0]
                stats['rougel'] += rs['rouge-l'][0]['f'][0]

        stats['bleu'] = run_multi_bleu(fn, self.ref_path)
        self.expe.log.info("generated sentences saved to: {}".format(fn))

        # An empty evaluation set reports zeros instead of dividing by zero.
        n_sents = len(all_gen) or 1
        for key in ('meteor', 'rouge1', 'rouge2', 'rougel'):
            stats[key] = stats[key] / n_sents * 100

        self.expe.log.info(
            "#Data: {}, bleu: {:.3f}, meteor: {:.3f}, "
            "rouge-1: {:.3f}, rouge-2: {:.3f}, rouge-l: {:.3f}"
            .format(len(all_gen), stats['bleu'], stats['meteor'],
                    stats['rouge1'], stats['rouge2'], stats['rougel']))
        return stats, stats['bleu']
class RootDeletedException(Exception):
    """Raised when a node that has no parent is asked to detach itself."""
    pass


class Node(object):
    """A labelled tree node that tracks its parent and sibling position.

    Attributes are plain and public:
      label    -- the node label
      children -- list of child nodes, in order
      parent   -- the owning node, or None
      order    -- index of this node in ``parent.children`` (0 when the
                  node has no parent)
    """

    def __init__(self, label, children):
        """Create a node and adopt ``children``.

        Children that already belong to another node are detached from it
        first.
        """
        self.label = label
        # Work on a private copy: detaching a child below edits its old
        # parent's list, and the caller may have passed that very list
        # (e.g. ``Node("X", other.children)``), which would otherwise be
        # modified while it is being iterated.
        self.children = list(children)
        for (i, child) in enumerate(self.children):
            if child.parent is not None:
                child.detach()
            child.parent = self
            child.order = i
        self.parent = None
        self.order = 0

    def __str__(self):
        return self.label

    def _subtree_str(self):
        """Return the bracketed form "(label child child ...)" of this
        subtree; a leaf is rendered as its bare label, with parentheses in
        the label emitted verbatim."""
        if len(self.children) != 0:
            return "(%s %s)" % (
                self.label,
                " ".join(child._subtree_str() for child in self.children))
        return '%s' % self.label

    def _renumber(self):
        # Re-derive every child's `order` from its actual list position.
        for j, child in enumerate(self.children):
            child.order = j

    def insert_child(self, i, child):
        """Insert ``child`` before position ``i`` (list-slice semantics, so
        negative and past-the-end indices are accepted)."""
        if child.parent is not None:
            child.detach()
        child.parent = self
        self.children[i:i] = [child]
        # Renumber everything: with a negative or too-large `i` the child
        # does not land at index `i`, so renumbering "from i" is not enough.
        self._renumber()

    def append_child(self, child):
        """Add ``child`` as the last child."""
        if child.parent is not None:
            child.detach()
        child.parent = self
        self.children.append(child)
        child.order = len(self.children) - 1

    def delete_child(self, i):
        """Remove the child at index ``i`` and clear its parent link.

        Raises:
            IndexError: if ``i`` is out of range.
        """
        removed = self.children[i]
        removed.parent = None
        removed.order = 0
        # `del` removes exactly the indexed element, including for negative
        # indices, where the slice `[i:i + 1]` can be empty.
        del self.children[i]
        self._renumber()

    def detach(self):
        """Remove this node from its parent.

        Raises:
            RootDeletedException: if the node has no parent.
        """
        if self.parent is None:
            raise RootDeletedException()
        self.parent.delete_child(self.order)

    def delete_clean(self):
        "Cleans up childless ancestors"
        parent = self.parent
        self.detach()
        if len(parent.children) == 0:
            parent.delete_clean()

    def bottom_up(self):
        """Yield the nodes of this subtree, children before their parent."""
        for child in self.children:
            yield from child.bottom_up()
        yield self

    def top_down(self):
        """Yield the nodes of this subtree, parent before its children."""
        yield self
        for child in self.children:
            yield from child.top_down()

    def leaves(self):
        """Yield the childless nodes of this subtree, left to right."""
        if len(self.children) == 0:
            yield self
        else:
            for child in self.children:
                yield from child.leaves()
class Tree(object):
    """A bracketed tree with a single ``root`` node.

    ``root`` is None when the input could not be parsed or when
    ``remove_empty`` deleted the root; every method treats that as an
    empty tree.
    """

    def __init__(self, root):
        self.root = root

    def __str__(self):
        if self.root is None:
            return ""
        return self.root._subtree_str()

    interior_node = re.compile(r"\s*\(([^\s)]*)")
    close_brace = re.compile(r"\s*\)")
    leaf_node = re.compile(r'\s*([^\s)]+)')

    @staticmethod
    def _scan_tree(s, start=0):
        """Parse one subtree of ``s`` beginning at offset ``start``.

        Returns:
            ``(node, length)`` where ``length`` is the number of characters
            consumed from ``start``, or ``(None, 0)`` if no subtree starts
            there.
        """
        # Match at an offset instead of slicing: re-slicing the remaining
        # text at every step copies it again and again.
        result = Tree.interior_node.match(s, start)
        if result is not None:
            label = result.group(1)
            pos = result.end()
            children = []
            (child, length) = Tree._scan_tree(s, pos)
            while child is not None:
                children.append(child)
                pos += length
                (child, length) = Tree._scan_tree(s, pos)
            result = Tree.close_brace.match(s, pos)
            if result is not None:
                return Node(label, children), result.end() - start
            return None, 0

        result = Tree.leaf_node.match(s, start)
        if result is not None:
            return Node(result.group(1), []), result.end() - start
        return None, 0

    @staticmethod
    def from_str(s):
        """Build a Tree from a bracketed string.

        Text after the first complete subtree is ignored; if nothing can be
        parsed the tree's root is None.
        """
        (root, _) = Tree._scan_tree(s.strip())
        return Tree(root)

    def bottom_up(self):
        """ Traverse the nodes of the tree bottom-up. """
        if self.root is None:
            return iter(())
        return self.root.bottom_up()

    def top_down(self):
        """ Traverse the nodes of the tree top-down. """
        if self.root is None:
            return iter(())
        return self.root.top_down()

    def leaves(self):
        """ Traverse the leaf nodes of the tree. """
        if self.root is None:
            return iter(())
        return self.root.leaves()

    def remove_empty(self):
        """ Remove empty nodes. """
        # Snapshot first: the tree is edited while we walk it.
        nodes = list(self.bottom_up())
        for node in nodes:
            if node.label == '-NONE-':
                try:
                    node.delete_clean()
                except RootDeletedException:
                    # The clean-up reached the root: nothing is left.
                    self.root = None

    def remove_unit(self):
        """ Remove unary nodes by fusing them with their parents. """
        nodes = list(self.bottom_up())
        for node in nodes:
            if len(node.children) == 1:
                child = node.children[0]
                # An only child that is a leaf is kept as it is.
                if len(child.children) > 0:
                    node.label = "%s_%s" % (node.label, child.label)
                    child.detach()
                    for grandchild in list(child.children):
                        node.append_child(grandchild)

    def restore_unit(self):
        """ Restore the unary nodes that were removed by remove_unit(). """
        if self.root is None:
            return

        def visit(node):
            children = [visit(child) for child in node.children]
            # Every "_" in a label is taken as a fusion point.
            labels = node.label.split('_')
            node = Node(labels[-1], children)
            for label in reversed(labels[:-1]):
                node = Node(label, [node])
            return node

        self.root = visit(self.root)

    @staticmethod
    def _binarize_node_right(node):
        # Fold all children but the first into nested "label*" nodes that
        # branch to the right, then hang the result after the first child.
        children = list(node.children)
        children.reverse()
        vlabel = node.label + "*"
        prev = children[0]
        for child in children[1:-1]:
            prev = Node(vlabel, [child, prev])
        node.append_child(prev)

    @staticmethod
    def _binarize_node_left(node):
        # Fold all children but the last into nested "label*" nodes that
        # branch to the left, then hang the result before the last child.
        vlabel = node.label + "*"
        children = list(node.children)
        prev = children[0]
        for child in children[1:-1]:
            prev = Node(vlabel, [prev, child])
        node.insert_child(0, prev)

    def binarize_right(self):
        """ Binarize into a right-branching structure. """
        nodes = list(self.bottom_up())
        for node in nodes:
            if len(node.children) > 2:
                Tree._binarize_node_right(node)

    def binarize_left(self):
        """ Binarize into a left-branching structure. """
        nodes = list(self.bottom_up())
        for node in nodes:
            if len(node.children) > 2:
                Tree._binarize_node_left(node)

    def binarize(self):
        """ Binarize into a left-branching or right-branching structure
        using linguistically motivated heuristics. Currently, the heuristic
        is extremely simple: SQ is right-branching, everything else is
        left-branching. """
        nodes = list(self.bottom_up())
        for node in nodes:
            if len(node.children) > 2:
                if node.label in ['SQ']:
                    Tree._binarize_node_right(node)
                else:
                    Tree._binarize_node_left(node)

    def unbinarize(self):
        """ Undo binarization by removing any nodes ending with *. """
        if self.root is None:
            return

        def visit(node):
            # A starred node dissolves into its (rebuilt) children.
            children = [rebuilt
                        for child in node.children
                        for rebuilt in visit(child)]
            if node.label.endswith('*'):
                return children
            return [Node(node.label, children)]

        roots = visit(self.root)
        assert len(roots) == 1
        self.root = roots[0]

    def normalize(self):
        """Binarize, then fuse unary chains."""
        self.binarize()
        self.remove_unit()

    def unnormalize(self):
        """Undo normalize(): restore unary chains, then drop starred nodes."""
        self.restore_unit()
        self.unbinarize()

    @staticmethod
    def _add_tag(node, tag):
        """Append ``tag`` to the node label: "label##t1,t2,..."."""
        if '##' in node.label:
            node.label += (',%s' % tag)
        else:
            node.label += ('##%s' % tag)

    @staticmethod
    def _remove_tags(node):
        """Strip the "##..." suffix from the node label.

        A label that starts with "##" is left unchanged.
        """
        i = node.label.find('##')
        if i > 0:
            node.label = node.label[:i]
class HypersphericalUniform(torch.nn.Module):
    """Uniform distribution on a hypersphere, parameterised by ``dim``."""

    @property
    def dim(self):
        return self._dim

    def __init__(self, dim):
        super(HypersphericalUniform, self).__init__()
        self._dim = dim

    def entropy(self):
        """Entropy of the distribution, i.e. the log surface area."""
        return self.__log_surface_area()

    def __log_surface_area(self):
        # log(2) + (dim + 1) / 2 * log(pi) - log Gamma((dim + 1) / 2)
        output = math.log(2) + ((self._dim + 1) / 2) * math.log(math.pi) - \
            sp.loggamma((self._dim + 1) / 2).real
        return output


class VonMisesFisher(torch.nn.Module):
    """Von Mises-Fisher distribution with mean direction ``loc`` and
    concentration ``scale``.

    The sampling code reshapes per-sample draws to ``scale``'s shape and
    concatenates them with ``loc``-shaped directions along the last axis,
    so ``scale`` is assumed to have a trailing dimension of size 1 (e.g.
    ``loc`` of shape (batch, m) with ``scale`` of shape (batch, 1)) --
    confirm against the callers.
    """

    @property
    def mean(self):
        return self.loc * (ive(self.__m / 2, self.scale) /
                           ive(self.__m / 2 - 1, self.scale))

    @property
    def stddev(self):
        return self.scale

    def __init__(self, loc, scale):
        super(VonMisesFisher, self).__init__()
        self.loc = loc
        self.scale = scale
        # Dimensionality of the ambient space (size of loc's last axis).
        self.__m = int(loc.size()[-1])
        # First standard basis vector; samples are drawn around it and then
        # reflected onto `loc`.
        self.__e1 = torch.tensor([1.] + [0] * (loc.size()[-1] - 1),
                                 device=loc.device)

    def rsample(self, shape=None):
        """Draw samples.

        Args:
            shape: leading sample dimensions as a sequence of ints; None
                (the default) or an empty sequence draws one sample per
                distribution.

        Returns:
            Tensor of shape ``shape + loc.shape``.
        """
        shape = [] if shape is None else list(shape)

        w = self.__sample_w3(shape=shape) \
            if self.__m == 3 else self.__sample_w_rej(shape=shape)

        # Random direction for the remaining m - 1 coordinates: draw a
        # full-size Gaussian and drop the first component of the last axis.
        v = np.random.normal(0, 1, size=shape + list(self.loc.size()))
        v = v[..., 1:]

        v = torch.from_numpy(v).float().to(self.loc.device)

        v = v / v.norm(dim=-1, keepdim=True)

        x = torch.cat((w, torch.sqrt(1 - (w ** 2)) * v), -1)
        z = self.__householder_rotation(x)

        return z

    def __sample_w3(self, shape):
        # Closed-form draw of the first coordinate, used when m == 3.
        shape = shape + list(self.scale.size())
        u = torch.from_numpy(
            np.random.uniform(low=0, high=1, size=shape)).float()\
            .to(self.loc.device)

        self.__w = torch.stack(
            [torch.log(u), torch.log(1 - u) - 2 * self.scale], dim=0)
        self.__w = VonMisesFisher.logsumexp(self.__w, dim=0)
        self.__w = 1 + self.__w / self.scale
        return self.__w

    def __sample_w_rej(self, shape):
        # Rejection-sampling draw of the first coordinate.
        c = torch.sqrt((4 * (self.scale ** 2)) + (self.__m - 1) ** 2)
        b_true = (-2 * self.scale + c) / (self.__m - 1)

        # using Taylor approximation with a smooth shift for
        # 10 < scale < 11 to avoid numerical errors for large scale
        b_app = (self.__m - 1) / (4 * self.scale)
        # Per-element blend weight in [0, 1]. It must be computed for every
        # entry of `scale`; indexing the result would apply the first
        # entry's weight to the whole batch.
        s = torch.clamp(self.scale - 10, min=0., max=1.)
        b = b_app * s + b_true * (1 - s)

        a = (self.__m - 1 + 2 * self.scale + c) / 4
        d = (4 * a * b) / (1 + b) - (self.__m - 1) * math.log(self.__m - 1)

        self.__b, (self.__e, self.__w) = b, self.__while_loop(b, a, d, shape)
        return self.__w

    def __while_loop(self, b, a, d, shape):
        # Tile the per-distribution constants over the sample dimensions.
        b, a, d = [t.repeat(*shape, *([1] * len(self.scale.size())))
                   for t in (b, a, d)]
        w = torch.zeros_like(b)
        e = torch.zeros_like(b)
        # True where no proposal has been accepted yet.
        pending = torch.ones_like(b, dtype=torch.bool)

        shape = shape + list(self.scale.size())
        max_try = 50 * np.array(shape).prod()
        n_try = 0  # will give up after given patience
        half = (self.__m - 1) / 2
        while pending.any().item():
            if n_try > max_try:
                warnings.warn(
                    "Maximum iterations for rejection sampling exceeded!")
                break
            e_ = torch.from_numpy(
                np.random.beta(half, half,
                               size=shape[:-1]).reshape(shape)).float()\
                .to(self.loc.device)
            u = torch.from_numpy(np.random.uniform(0, 1, size=shape)).float()\
                .to(self.loc.device)

            w_ = (1 - (1 + b) * e_) / (1 - (1 - b) * e_)
            t = (2 * a * b) / (1 - (1 - b) * e_)

            accept = ((self.__m - 1) * t.log() - t + d) > torch.log(u)

            # Keep the masks boolean throughout: writing a uint8 tensor
            # into a bool mask is a dtype mismatch.
            newly_accepted = (pending & accept).detach()
            w = torch.where(newly_accepted, w_, w)
            e = torch.where(newly_accepted, e_, e)

            pending = pending & ~accept
            n_try += 1

        return e, w

    def __householder_rotation(self, x):
        # Reflect x across the hyperplane orthogonal to (e1 - loc); the
        # small constant guards the normalisation when loc equals e1.
        u = (self.__e1 - self.loc)
        u = u / (u.norm(dim=-1, keepdim=True) + 1e-5)
        z = x - 2 * (x * u).sum(-1, keepdim=True) * u
        return z

    def entropy(self):
        output = - self.scale.double() * \
            ive(self.__m / 2, self.scale) / \
            ive((self.__m / 2) - 1, self.scale)

        return output.view(*(output.size()[:-1])) + self._log_normalization()

    def log_prob(self, x):
        return self._log_unnormalized_prob(x) - self._log_normalization()

    def _log_unnormalized_prob(self, x):
        output = self.scale * (self.loc * x).sum(-1, keepdim=True)

        return output.view(*(output.size()[:-1]))

    def _log_normalization(self):
        output = - ((self.__m / 2 - 1) * torch.log(self.scale.double()) -
                    (self.__m / 2) * math.log(2 * math.pi) -
                    (self.scale.double() + torch.log(
                        ive(self.__m / 2 - 1, self.scale))))

        return output.view(*(output.size()[:-1]))

    @staticmethod
    def logsumexp(inputs, dim=None, keepdim=False):
        """Numerically stable logsumexp.

        Args:
            inputs: A Variable with any shape.
            dim: An integer.
            keepdim: A boolean.

        Returns:
            Equivalent of log(sum(exp(inputs), dim=dim, keepdim=keepdim)).
        """
        # For a 1-D array x (any array along a single dimension),
        # log sum exp(x) = s + log sum exp(x - s)
        # with s = max(x) being a common choice.
        if dim is None:
            inputs = inputs.view(-1)
            dim = 0
        s, _ = torch.max(inputs, dim=dim, keepdim=True)
        outputs = s + (inputs - s).exp().sum(dim=dim, keepdim=True).log()
        if not keepdim:
            outputs = outputs.squeeze(dim)
        return outputs

    def kl_div(self):
        """Negative entropy of this distribution plus the entropy of
        ``HypersphericalUniform(m - 1)``."""
        return - self.entropy().float() + \
            HypersphericalUniform(self.__m - 1).entropy()