├── WSD ├── utils │ ├── __init__.py │ ├── wordnet_reader_example.py │ ├── glove.py │ ├── path.py │ ├── losses.py │ └── store_result.py └── BiLSTM │ ├── __init__.py │ └── config.py ├── image ├── model.pdf └── model.png ├── inf.sh ├── Pun_Generation └── code │ ├── utils │ ├── misc_utils_test.py │ ├── vocab_utils_test.py │ ├── evaluation_utils_test.py │ ├── standard_hparams_utils.py │ ├── vocab_utils.py │ ├── scripts │ │ ├── bleu.py │ │ └── rouge.py │ ├── common_test_utils.py │ ├── nmt_utils.py │ ├── misc_utils.py │ ├── evaluation_utils.py │ └── iterator_utils_test.py │ ├── dealt.py │ ├── concatenate.py │ ├── attention_model.py │ ├── inference.py │ └── gnmt_model.py ├── Pun_Generation_Forward └── code │ ├── utils │ ├── misc_utils_test.py │ ├── vocab_utils_test.py │ ├── evaluation_utils_test.py │ ├── standard_hparams_utils.py │ ├── vocab_utils.py │ ├── scripts │ │ ├── bleu.py │ │ └── rouge.py │ ├── common_test_utils.py │ ├── nmt_utils.py │ ├── misc_utils.py │ └── evaluation_utils.py │ ├── attention_model.py │ ├── inference.py │ └── gnmt_model.py ├── README.md └── train.sh /WSD/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /image/model.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lishunyao97/Pun-GAN/HEAD/image/model.pdf -------------------------------------------------------------------------------- /image/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lishunyao97/Pun-GAN/HEAD/image/model.png -------------------------------------------------------------------------------- /WSD/BiLSTM/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @version: python2.7 4 | @author: luofuli 5 | @time: 2018/5/3 13:13 6 | """ 7 | 8 | if __name__ == "__main__": 9 | pass -------------------------------------------------------------------------------- /WSD/utils/wordnet_reader_example.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | from nltk.corpus.reader.wordnet import WordNetCorpusReader 3 | 4 | wn = WordNetCorpusReader(YOUR_WORDNET_PATH, '.*') # replace YOUR_WORDNET_PATH with a local WordNet directory; constructing the reader this way also enables method autocompletion 5 | print('wordnet version %s: %s' % (wn.get_version(), YOUR_WORDNET_PATH)) 6 | 7 | print 'get gloss from sensekey......'
8 | key = 'dance%1:04:00::' 9 | lemma = wn.lemma_from_key(key) 10 | synset = lemma.synset() 11 | print synset.definition() 12 | -------------------------------------------------------------------------------- /inf.sh: -------------------------------------------------------------------------------- 1 | python -u ./Pun_Generation/code/nmt.py \ 2 | --infer_batch_size=64 \ 3 | --vocab_prefix=./Pun_Generation/data/1backward/vocab \ 4 | --src=in \ 5 | --tgt=out \ 6 | --out_dir=./Pun_Generation/code/backward_model_path \ 7 | --train_prefix=./Pun_Generation/data/1backward/train \ 8 | --dev_prefix=./Pun_Generation/data/1backward/dev \ 9 | --test_prefix=./Pun_Generation/data/1backward/test \ 10 | --inference_input_file=./Pun_Generation/data/sample_2548 \ 11 | --inference_output_file=./Pun_Generation/code/backward_model_path/first_part_file \ 12 | --beam_width=10 \ 13 | --num_translations_per_input=1 > ./Pun_Generation/code/output_infer.txt 14 | 15 | python ./Pun_Generation/code/dealt.py 2548 16 | 17 | python -u ./Pun_Generation_Forward/code/nmt.py \ 18 | --infer_batch_size=64 \ 19 | --vocab_prefix=./Pun_Generation_Forward/data/2forward/vocab \ 20 | --src=in \ 21 | --tgt=out \ 22 | --out_dir=./Pun_Generation_Forward/code/forward_model_path \ 23 | --train_prefix=./Pun_Generation_Forward/data/2forward/train \ 24 | --dev_prefix=./Pun_Generation_Forward/data/2forward/dev \ 25 | --test_prefix=./Pun_Generation_Forward/data/2forward/test \ 26 | --inference_input_file=./Pun_Generation/code/backward_model_path/dealt_first_part_file \ 27 | --inference_output_file=./Pun_Generation_Forward/code/forward_model_path/second_part_file \ 28 | --beam_width=10 \ 29 | --num_translations_per_input=1 > ./Pun_Generation_Forward/code/output_infer.txt 30 | 31 | python ./Pun_Generation/code/concatenate.py 32 | -------------------------------------------------------------------------------- /Pun_Generation/code/utils/misc_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for vocab_utils.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import tensorflow as tf 23 | 24 | from ..utils import misc_utils 25 | 26 | 27 | class MiscUtilsTest(tf.test.TestCase): 28 | 29 | def testFormatBpeText(self): 30 | bpe_line = ( 31 | b"En@@ ough to make already reluc@@ tant men hesitate to take screening" 32 | b" tests ." 33 | ) 34 | expected_result = ( 35 | b"Enough to make already reluctant men hesitate to take screening tests" 36 | b" ." 
37 | ) 38 | self.assertEqual(expected_result, 39 | misc_utils.format_bpe_text(bpe_line.split(b" "))) 40 | 41 | def testFormatSPMText(self): 42 | spm_line = u"\u2581This \u2581is \u2581a \u2581 te st .".encode("utf-8") 43 | expected_result = "This is a test." 44 | self.assertEqual(expected_result, 45 | misc_utils.format_spm_text(spm_line.split(b" "))) 46 | 47 | 48 | if __name__ == "__main__": 49 | tf.test.main() 50 | -------------------------------------------------------------------------------- /Pun_Generation_Forward/code/utils/misc_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for vocab_utils.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import tensorflow as tf 23 | 24 | from ..utils import misc_utils 25 | 26 | 27 | class MiscUtilsTest(tf.test.TestCase): 28 | 29 | def testFormatBpeText(self): 30 | bpe_line = ( 31 | b"En@@ ough to make already reluc@@ tant men hesitate to take screening" 32 | b" tests ." 33 | ) 34 | expected_result = ( 35 | b"Enough to make already reluctant men hesitate to take screening tests" 36 | b" ." 37 | ) 38 | self.assertEqual(expected_result, 39 | misc_utils.format_bpe_text(bpe_line.split(b" "))) 40 | 41 | def testFormatSPMText(self): 42 | spm_line = u"\u2581This \u2581is \u2581a \u2581 te st .".encode("utf-8") 43 | expected_result = "This is a test." 
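# format_spm_text is expected to join the SentencePiece pieces, mapping the \u2581 word-boundary marker back to spaces and merging marker-less pieces ("te", "st", ".") into a single word.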
44 | self.assertEqual(expected_result, 45 | misc_utils.format_spm_text(spm_line.split(b" "))) 46 | 47 | 48 | if __name__ == "__main__": 49 | tf.test.main() 50 | -------------------------------------------------------------------------------- /Pun_Generation/code/dealt.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | PUNGAN_ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 3 | target_words = [] 4 | with open(PUNGAN_ROOT_PATH + '/Pun_Generation/data/sample_'+sys.argv[1]) as f: 5 | for line in f: 6 | target_words.append(line.strip()) 7 | backward = [] 8 | with open(PUNGAN_ROOT_PATH + '/Pun_Generation/code/backward_model_path/first_part_file') as f: 9 | for line in f: 10 | backward.append(line.strip()) 11 | backward_split = [backward[i:i + 32] for i in range(0, len(backward), 32)] 12 | 13 | with open(PUNGAN_ROOT_PATH + '/Pun_Generation/code/backward_model_path/dealt_first_part_file','w') as fw: 14 | remain = [] 15 | if len(backward_split)%2 == 0 and len(backward_split[-1]) == 32: 16 | unit = len(backward_split)/2 17 | elif len(backward_split)%2 == 1: 18 | unit = len(backward_split)/2 19 | remain = backward_split[-1] 20 | elif len(backward_split)%2 == 0: 21 | unit = len(backward_split)/2 - 1 22 | remain = backward_split[-2] + backward_split[-1] 23 | for i in range(unit): 24 | for j in range(32): 25 | l1 = backward[i * 64 + j].split() 26 | l1.reverse() 27 | l1.append(target_words[i * 64 + j * 2]) 28 | fw.write(' '.join(l1[1:])+'\n') 29 | l2 = backward[i * 64 + 32 + j].split() 30 | l2.reverse() 31 | l2.append(target_words[i * 64 + j * 2 + 1]) 32 | fw.write(' '.join(l2[1:])+'\n') 33 | if remain: 34 | id = unit * 64 35 | for sent in remain[:len(remain)/2]: 36 | l1 = sent.split() 37 | l1.reverse() 38 | l1.append(target_words[id]) 39 | fw.write(' '.join(l1[1:])+'\n') 40 | l2 = sent.split() 41 | l2.reverse() 42 | l2.append(target_words[id + 1]) 43 | fw.write(' '.join(l2[1:])+'\n') 44 | id += 2 45 | -------------------------------------------------------------------------------- /WSD/utils/glove.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import numpy as np 3 | from utils import path 4 | _path = path.WSD_path() 5 | glove_dir = _path.GLOVE_DIR 6 | 7 | 8 | def load_glove(dim,size): 9 | if size=='6B': 10 | path = glove_dir + 'glove.6B/glove.6B.'
+ str(dim) + 'd.txt' 11 | elif size=='42B' and dim==300: 12 | path = glove_dir+'glove.42B.300d.txt' 13 | elif size=='840B' and dim==300: 14 | path = glove_dir+'glove.840B.300d.txt' 15 | else: 16 | print(u'No GloVe model matching the requested dim and size') 17 | exit(-3) 18 | wordvecs = {} 19 | with open(path, 'r') as file: 20 | lines = file.readlines() 21 | for line in lines: 22 | tokens = line.split(' ') 23 | vec = np.array(tokens[1:], dtype=np.float32) 24 | wordvecs[tokens[0]] = vec 25 | 26 | return wordvecs 27 | 28 | 29 | def fill_with_gloves(word_to_id, emb_size, vocab_size, wordvecs=None): 30 | if not wordvecs: 31 | wordvecs = load_glove(emb_size,vocab_size) 32 | 33 | n_words = len(word_to_id) 34 | res = np.zeros([n_words, emb_size], dtype=np.float32) 35 | n_not_found = 0 36 | words_notin = set() 37 | for word, id in word_to_id.iteritems(): 38 | if '#' in word: 39 | word = word.split('#')[0] ## strip the POS tag 40 | 41 | if '-' in word: 42 | words = word.split('-') 43 | elif '_' in word: 44 | words = word.split('_') 45 | else: 46 | words = [word] 47 | 48 | vecs = [] 49 | for w in words: 50 | if w in wordvecs: 51 | vecs.append(wordvecs[w]) # compound words are split into parts and their vectors averaged 52 | if vecs != []: 53 | res[id, :] = np.mean(np.array(vecs), 0) 54 | else: 55 | words_notin.add(word) 56 | n_not_found += 1 57 | res[id, :] = np.random.normal(0.0, 0.1, emb_size) 58 | print 'n words not found in glove word vectors: ' + str(n_not_found) 59 | open('../tmp/word_not_in_glove.txt','w').write((u'\n'.join(words_notin)).encode('utf-8')) 60 | 61 | return res 62 | 63 | 64 | -------------------------------------------------------------------------------- /Pun_Generation/code/utils/vocab_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # ============================================================================== 15 | 16 | """Tests for vocab_utils.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import codecs 23 | import os 24 | import tensorflow as tf 25 | 26 | from ..utils import vocab_utils 27 | 28 | 29 | class VocabUtilsTest(tf.test.TestCase): 30 | 31 | def testCheckVocab(self): 32 | # Create a vocab file 33 | vocab_dir = os.path.join(tf.test.get_temp_dir(), "vocab_dir") 34 | os.makedirs(vocab_dir) 35 | vocab_file = os.path.join(vocab_dir, "vocab_file") 36 | vocab = ["a", "b", "c"] 37 | with codecs.getwriter("utf-8")(tf.gfile.GFile(vocab_file, "wb")) as f: 38 | for word in vocab: 39 | f.write("%s\n" % word) 40 | 41 | # Call vocab_utils 42 | out_dir = os.path.join(tf.test.get_temp_dir(), "out_dir") 43 | os.makedirs(out_dir) 44 | vocab_size, new_vocab_file = vocab_utils.check_vocab( 45 | vocab_file, out_dir) 46 | 47 | # Assert: we expect the code to add <unk>, <s>, </s> and 48 | # create a new vocab file 49 | self.assertEqual(len(vocab) + 3, vocab_size) 50 | self.assertEqual(os.path.join(out_dir, "vocab_file"), new_vocab_file) 51 | new_vocab, _ = vocab_utils.load_vocab(new_vocab_file) 52 | self.assertEqual( 53 | [vocab_utils.UNK, vocab_utils.SOS, vocab_utils.EOS] + vocab, new_vocab) 54 | 55 | 56 | if __name__ == "__main__": 57 | tf.test.main() 58 | -------------------------------------------------------------------------------- /Pun_Generation_Forward/code/utils/vocab_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # ============================================================================== 15 | 16 | """Tests for vocab_utils.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import codecs 23 | import os 24 | import tensorflow as tf 25 | 26 | from ..utils import vocab_utils 27 | 28 | 29 | class VocabUtilsTest(tf.test.TestCase): 30 | 31 | def testCheckVocab(self): 32 | # Create a vocab file 33 | vocab_dir = os.path.join(tf.test.get_temp_dir(), "vocab_dir") 34 | os.makedirs(vocab_dir) 35 | vocab_file = os.path.join(vocab_dir, "vocab_file") 36 | vocab = ["a", "b", "c"] 37 | with codecs.getwriter("utf-8")(tf.gfile.GFile(vocab_file, "wb")) as f: 38 | for word in vocab: 39 | f.write("%s\n" % word) 40 | 41 | # Call vocab_utils 42 | out_dir = os.path.join(tf.test.get_temp_dir(), "out_dir") 43 | os.makedirs(out_dir) 44 | vocab_size, new_vocab_file = vocab_utils.check_vocab( 45 | vocab_file, out_dir) 46 | 47 | # Assert: we expect the code to add <unk>, <s>, </s> and 48 | # create a new vocab file 49 | self.assertEqual(len(vocab) + 3, vocab_size) 50 | self.assertEqual(os.path.join(out_dir, "vocab_file"), new_vocab_file) 51 | new_vocab, _ = vocab_utils.load_vocab(new_vocab_file) 52 | self.assertEqual( 53 | [vocab_utils.UNK, vocab_utils.SOS, vocab_utils.EOS] + vocab, new_vocab) 54 | 55 | 56 | if __name__ == "__main__": 57 | tf.test.main() 58 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pun-GAN: Generative Adversarial Network for Pun Generation 2 | 3 | This repo contains code for the following paper. 4 | 5 | "Pun-GAN: Generative Adversarial Network for Pun Generation". Fuli Luo, Shunyao Li, Pengcheng Yang, Lei Li, Baobao Chang, Zhifang Sui and Xu SUN. EMNLP 2019. 6 | 7 | In this paper, we focus on the task of generating a pun sentence given a pair of word senses. A major challenge for pun generation is the lack of a large-scale pun corpus to guide supervised learning. To remedy this, we propose an adversarial generative network for pun generation (Pun-GAN). It consists of a generator to produce pun sentences, and a discriminator to distinguish between the generated pun sentences and the real sentences with specific word senses. The output of the discriminator is then used as a reward to train the generator via reinforcement learning, encouraging it to produce pun sentences which can support two word senses simultaneously. 8 | 9 | ![model](./image/model.png) 10 | 11 | ## Quick Start 12 | 13 | 1. Pretrain the pun generation model, which is divided into two parts: backward and forward. 14 | 15 | ```bash 16 | # pretrain backward model 17 | cd ./Pun_Generation/code 18 | 19 | python -u nmt.py 20 | --infer_batch_size=64 21 | --out_dir=backward_model_path 22 | --sampling_temperature=0 23 | --pretrain=1 > output_backward.txt 24 | ``` 25 | 26 | ```bash 27 | # pretrain forward model 28 | cd ./Pun_Generation_Forward/code 29 | 30 | python -u nmt.py 31 | --infer_batch_size=64 32 | --out_dir=forward_model_path 33 | --sampling_temperature=0 34 | --pretrain=1 > output_forward.txt 35 | ``` 36 | 37 | 2. Pretrain the word sense disambiguation (WSD) model. 38 | 39 | ```bash 40 | cd ./WSD/BiLSTM 41 | 42 | python train.py 43 | ``` 44 | 45 | 3. Train Pun-GAN. 46 | 47 | ```bash 48 | sh train.sh 49 | ``` 50 | 51 | 4. Run inference.
52 | 53 | ```bash 54 | sh inf.sh 55 | ``` 56 | 57 | ## Data Format 58 | 59 | Sense pairs are required for pun generation. Each sense is represented by its WordNet sense key, and the paired keys are stored in /Pun_Generation/data/samples. 60 | 61 | ``` 62 | rich%3:00:00:: 63 | rich%5:00:00:unwholesome:00 64 | pump%1:06:01:: 65 | pump%2:32:00:: 66 | cleanly%4:02:00:: 67 | cleanly%4:02:02:: 68 | umbrella%1:06:00:: 69 | umbrella%1:04:01:: 70 | revealing%5:00:00:informative:00 71 | reveal%2:39:00:: 72 | partial%5:00:00:inclined:02 73 | partial%5:00:00:incomplete:00 74 | ``` 75 | 76 | ## Dependencies 77 | 78 | ``` 79 | python2.7 80 | tensorflow_gpu==1.4.1 81 | numpy==1.14.2 82 | nltk==3.2.5 83 | ``` 84 | 85 | 86 | 87 | -------------------------------------------------------------------------------- /Pun_Generation/code/concatenate.py: -------------------------------------------------------------------------------- 1 | import os, random 2 | PUNGAN_ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 3 | source_sents = [] 4 | with open(PUNGAN_ROOT_PATH + '/Pun_Generation/code/backward_model_path/dealt_first_part_file') as f: 5 | for line in f: 6 | source_sents.append(line.strip()) 7 | forward = [] 8 | with open(PUNGAN_ROOT_PATH + '/Pun_Generation_Forward/code/forward_model_path/second_part_file') as f: 9 | for line in f: 10 | forward.append(line.strip()) 11 | forward_split = [forward[i:i + 32] for i in range(0, len(forward), 32)] 12 | with open(PUNGAN_ROOT_PATH + '/Pun_Generation/code/backward_model_path/concatenate_file','w') as fw: 13 | remain = [] 14 | if len(forward_split)%2 == 0 and len(forward_split[-1]) == 32: 15 | unit = len(forward_split)/2 16 | elif len(forward_split)%2 == 1: 17 | unit = len(forward_split)/2 18 | remain = forward_split[-1] 19 | elif len(forward_split)%2 == 0: 20 | unit = len(forward_split)/2 - 1 21 | remain = forward_split[-2] + forward_split[-1] 22 | for i in range(unit): 23 | for j in range(32): 24 | l1 = forward[i * 64 + j].split() 25 | l1 = source_sents[i * 64 + j * 2].split() + l1 26 | fw.write(' '.join(l1[:-1])+'\n') 27 | l2 = forward[i * 64 + 32 + j].split() 28 | l2 = source_sents[i * 64 + j * 2 + 1].split() + l2 29 | fw.write(' '.join(l2[:-1])+'\n') 30 | if remain: 31 | id = unit * 64 32 | for sent in remain[:len(remain)/2]: 33 | l1 = sent.split() 34 | l1 = source_sents[id].split() + l1 35 | fw.write(' '.join(l1[:-1])+'\n') 36 | l2 = sent.split() 37 | l2 = source_sents[id + 1].split() + l2 38 | fw.write(' '.join(l2[:-1])+'\n') 39 | id += 2 40 | # randomly choose 50 candidate sentence pairs for human evaluation 41 | result = [] 42 | chosen = [] 43 | with open(PUNGAN_ROOT_PATH + '/Pun_Generation/code/backward_model_path/concatenate_file') as f: 44 | for line in f: 45 | result.append(line) 46 | result_split = [result[i:i + 2] for i in range(0, len(result), 2)] 47 | id_list = [] 48 | for i in range(50): 49 | id = random.randint(0, len(result_split)-1) 50 | while id in id_list: 51 | id = random.randint(0, len(result_split)-1) 52 | id_list.append(id) 53 | # print('len(result_split)', len(result_split)) 54 | # print('id', id) 55 | chosen.extend(result_split[id]) 56 | with open(PUNGAN_ROOT_PATH + '/inf_human.txt', 'w') as fw: 57 | for i in chosen: 58 | fw.write(i) 59 | -------------------------------------------------------------------------------- /Pun_Generation/code/utils/evaluation_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for evaluation_utils.py.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import tensorflow as tf 23 | 24 | from ..utils import evaluation_utils 25 | 26 | 27 | class EvaluationUtilsTest(tf.test.TestCase): 28 | 29 | def testEvaluate(self): 30 | output = "nmt/testdata/deen_output" 31 | ref_bpe = "nmt/testdata/deen_ref_bpe" 32 | ref_spm = "nmt/testdata/deen_ref_spm" 33 | 34 | expected_bleu_score = 22.5855084573 35 | expected_rouge_score = 50.8429782599 36 | 37 | bpe_bleu_score = evaluation_utils.evaluate( 38 | ref_bpe, output, "bleu", "bpe") 39 | bpe_rouge_score = evaluation_utils.evaluate( 40 | ref_bpe, output, "rouge", "bpe") 41 | 42 | self.assertAlmostEqual(expected_bleu_score, bpe_bleu_score) 43 | self.assertAlmostEqual(expected_rouge_score, bpe_rouge_score) 44 | 45 | spm_bleu_score = evaluation_utils.evaluate( 46 | ref_spm, output, "bleu", "spm") 47 | spm_rouge_score = evaluation_utils.evaluate( 48 | ref_spm, output, "rouge", "spm") 49 | 50 | self.assertAlmostEqual(expected_rouge_score, spm_rouge_score) 51 | self.assertAlmostEqual(expected_bleu_score, spm_bleu_score) 52 | 53 | def testAccuracy(self): 54 | pred_output = "nmt/testdata/pred_output" 55 | label_ref = "nmt/testdata/label_ref" 56 | 57 | expected_accuracy_score = 60.00 58 | 59 | accuracy_score = evaluation_utils.evaluate( 60 | label_ref, pred_output, "accuracy") 61 | self.assertAlmostEqual(expected_accuracy_score, accuracy_score) 62 | 63 | def testWordAccuracy(self): 64 | pred_output = "nmt/testdata/pred_output" 65 | label_ref = "nmt/testdata/label_ref" 66 | 67 | expected_word_accuracy_score = 60.00 68 | 69 | word_accuracy_score = evaluation_utils.evaluate( 70 | label_ref, pred_output, "word_accuracy") 71 | self.assertAlmostEqual(expected_word_accuracy_score, word_accuracy_score) 72 | 73 | 74 | if __name__ == "__main__": 75 | tf.test.main() 76 | -------------------------------------------------------------------------------- /Pun_Generation_Forward/code/utils/evaluation_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Tests for evaluation_utils.py.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import tensorflow as tf 23 | 24 | from ..utils import evaluation_utils 25 | 26 | 27 | class EvaluationUtilsTest(tf.test.TestCase): 28 | 29 | def testEvaluate(self): 30 | output = "nmt/testdata/deen_output" 31 | ref_bpe = "nmt/testdata/deen_ref_bpe" 32 | ref_spm = "nmt/testdata/deen_ref_spm" 33 | 34 | expected_bleu_score = 22.5855084573 35 | expected_rouge_score = 50.8429782599 36 | 37 | bpe_bleu_score = evaluation_utils.evaluate( 38 | ref_bpe, output, "bleu", "bpe") 39 | bpe_rouge_score = evaluation_utils.evaluate( 40 | ref_bpe, output, "rouge", "bpe") 41 | 42 | self.assertAlmostEqual(expected_bleu_score, bpe_bleu_score) 43 | self.assertAlmostEqual(expected_rouge_score, bpe_rouge_score) 44 | 45 | spm_bleu_score = evaluation_utils.evaluate( 46 | ref_spm, output, "bleu", "spm") 47 | spm_rouge_score = evaluation_utils.evaluate( 48 | ref_spm, output, "rouge", "spm") 49 | 50 | self.assertAlmostEqual(expected_rouge_score, spm_rouge_score) 51 | self.assertAlmostEqual(expected_bleu_score, spm_bleu_score) 52 | 53 | def testAccuracy(self): 54 | pred_output = "nmt/testdata/pred_output" 55 | label_ref = "nmt/testdata/label_ref" 56 | 57 | expected_accuracy_score = 60.00 58 | 59 | accuracy_score = evaluation_utils.evaluate( 60 | label_ref, pred_output, "accuracy") 61 | self.assertAlmostEqual(expected_accuracy_score, accuracy_score) 62 | 63 | def testWordAccuracy(self): 64 | pred_output = "nmt/testdata/pred_output" 65 | label_ref = "nmt/testdata/label_ref" 66 | 67 | expected_word_accuracy_score = 60.00 68 | 69 | word_accuracy_score = evaluation_utils.evaluate( 70 | label_ref, pred_output, "word_accuracy") 71 | self.assertAlmostEqual(expected_word_accuracy_score, word_accuracy_score) 72 | 73 | 74 | if __name__ == "__main__": 75 | tf.test.main() 76 | -------------------------------------------------------------------------------- /Pun_Generation/code/utils/standard_hparams_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """standard hparams utils.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import tensorflow as tf 23 | 24 | 25 | def create_standard_hparams(): 26 | return tf.contrib.training.HParams( 27 | # Data 28 | src="", 29 | tgt="", 30 | train_prefix="", 31 | dev_prefix="", 32 | test_prefix="", 33 | vocab_prefix="", 34 | embed_prefix="", 35 | out_dir="", 36 | 37 | # Networks 38 | num_units=512, 39 | num_layers=2, 40 | num_encoder_layers=2, 41 | num_decoder_layers=2, 42 | dropout=0.2, 43 | unit_type="lstm", 44 | encoder_type="bi", 45 | residual=False, 46 | time_major=True, 47 | num_embeddings_partitions=0, 48 | 49 | # Attention mechanisms 50 | attention="scaled_luong", 51 | attention_architecture="standard", 52 | output_attention=True, 53 | pass_hidden_state=True, 54 | 55 | # Train 56 | optimizer="sgd", 57 | batch_size=128, 58 | init_op="uniform", 59 | init_weight=0.1, 60 | max_gradient_norm=5.0, 61 | learning_rate=1.0, 62 | warmup_steps=0, 63 | warmup_scheme="t2t", 64 | decay_scheme="luong234", 65 | colocate_gradients_with_ops=True, 66 | num_train_steps=12000, 67 | 68 | # Data constraints 69 | num_buckets=5, 70 | max_train=0, 71 | src_max_len=50, 72 | tgt_max_len=50, 73 | src_max_len_infer=0, 74 | tgt_max_len_infer=0, 75 | 76 | # Data format 77 | sos="<s>", 78 | eos="</s>", 79 | subword_option="", 80 | check_special_token=True, 81 | 82 | # Misc 83 | forget_bias=1.0, 84 | num_gpus=1, 85 | epoch_step=0, # record where we were within an epoch. 86 | steps_per_stats=100, 87 | steps_per_external_eval=0, 88 | share_vocab=False, 89 | metrics=["bleu"], 90 | log_device_placement=False, 91 | random_seed=None, 92 | # only enable beam search during inference when beam_width > 0. 93 | beam_width=0, 94 | length_penalty_weight=0.0, 95 | override_loaded_hparams=True, 96 | num_keep_ckpts=5, 97 | avg_ckpts=False, 98 | 99 | # For inference 100 | inference_indices=None, 101 | infer_batch_size=32, 102 | sampling_temperature=0.0, 103 | num_translations_per_input=1, 104 | ) 105 | -------------------------------------------------------------------------------- /Pun_Generation_Forward/code/utils/standard_hparams_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # ============================================================================== 15 | 16 | """standard hparams utils.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import tensorflow as tf 23 | 24 | 25 | def create_standard_hparams(): 26 | return tf.contrib.training.HParams( 27 | # Data 28 | src="", 29 | tgt="", 30 | train_prefix="", 31 | dev_prefix="", 32 | test_prefix="", 33 | vocab_prefix="", 34 | embed_prefix="", 35 | out_dir="", 36 | 37 | # Networks 38 | num_units=512, 39 | num_layers=2, 40 | num_encoder_layers=2, 41 | num_decoder_layers=2, 42 | dropout=0.2, 43 | unit_type="lstm", 44 | encoder_type="bi", 45 | residual=False, 46 | time_major=True, 47 | num_embeddings_partitions=0, 48 | 49 | # Attention mechanisms 50 | attention="scaled_luong", 51 | attention_architecture="standard", 52 | output_attention=True, 53 | pass_hidden_state=True, 54 | 55 | # Train 56 | optimizer="sgd", 57 | batch_size=128, 58 | init_op="uniform", 59 | init_weight=0.1, 60 | max_gradient_norm=5.0, 61 | learning_rate=1.0, 62 | warmup_steps=0, 63 | warmup_scheme="t2t", 64 | decay_scheme="luong234", 65 | colocate_gradients_with_ops=True, 66 | num_train_steps=12000, 67 | 68 | # Data constraints 69 | num_buckets=5, 70 | max_train=0, 71 | src_max_len=50, 72 | tgt_max_len=50, 73 | src_max_len_infer=0, 74 | tgt_max_len_infer=0, 75 | 76 | # Data format 77 | sos="<s>", 78 | eos="</s>", 79 | subword_option="", 80 | check_special_token=True, 81 | 82 | # Misc 83 | forget_bias=1.0, 84 | num_gpus=1, 85 | epoch_step=0, # record where we were within an epoch. 86 | steps_per_stats=100, 87 | steps_per_external_eval=0, 88 | share_vocab=False, 89 | metrics=["bleu"], 90 | log_device_placement=False, 91 | random_seed=None, 92 | # only enable beam search during inference when beam_width > 0.
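# (in this repo, train.sh keeps beam_width=0 so that sampling is used during adversarial training, while inf.sh and the inference steps use beam_width=10)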
93 | beam_width=0, 94 | length_penalty_weight=0.0, 95 | override_loaded_hparams=True, 96 | num_keep_ckpts=5, 97 | avg_ckpts=False, 98 | 99 | # For inference 100 | inference_indices=None, 101 | infer_batch_size=32, 102 | sampling_temperature=0.0, 103 | num_translations_per_input=1, 104 | ) 105 | -------------------------------------------------------------------------------- /WSD/utils/path.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | 4 | PUNGAN_ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 5 | 6 | class WSD_path(object): 7 | def __init__(self): 8 | ## Lexical Example task dataset 9 | self.LS_DATASET = ['senseval2_LS', 'senseval3_LS'] 10 | ## ALL-words task dataset 11 | self.ALL_WORDS_TRAIN_DATASET = ['semcor', 'semcor+omsti'] 12 | self.ALL_WORDS_TEST_DATASET = ['ALL', 'senseval2', 'senseval3', 'semeval2007', 'semeval2013', 'semeval2015'] 13 | self.ALL_WORDS_VAL_DATASET = 'semeval2007' 14 | 15 | ## Lexical Example task path 16 | self.LS_BASE_PATH =_LS_BASE_PATH= PUNGAN_ROOT_PATH + '/WSD/data/Lexical_Sample_WSD/' 17 | self.LS_OLD_TRAIN_PATH = _LS_BASE_PATH + '{}/train.xml' 18 | self.LS_TRAIN_PATH = _LS_BASE_PATH + '{}/train.new.xml' 19 | self.LS_VAL_OLD_PATH = _LS_BASE_PATH + '{}/test.haskey.xml' 20 | self.LS_VAL_PATH = _LS_BASE_PATH + '{}/test.haskey.new.xml' 21 | self.LS_TEST_PATH = _LS_BASE_PATH + '{}/test.xml' 22 | self.LS_DIC_PATH = _LS_BASE_PATH + '{}/dictionary.new.xml' 23 | self.LS_OLD_DIC_PATH = _LS_BASE_PATH + '{}/dictionary.xml' 24 | self.LS_TEST_KEY_OLD_PATH = _LS_BASE_PATH + '{}/test.key' 25 | self.LS_TEST_KEY_PATH = _LS_BASE_PATH + '{}/test.new.key' 26 | self.LS_SENSEMAP_PATH = _LS_BASE_PATH + '{}/sensemap.txt' 27 | 28 | ## ALL-words task path 29 | self.ALL_WORDS_BASE_PATH = _ALL_WORDS_BASE_PATH = PUNGAN_ROOT_PATH + '/WSD/data/All_Words_WSD/' 30 | # path for all-words train 31 | self.ALL_WORDS_TRAIN_PATH = _ALL_WORDS_BASE_PATH + 'Training_Corpora/{0}/{0}.data.xml' 32 | self.ALL_WORDS_TRAIN_KEY_PATH = _ALL_WORDS_BASE_PATH + 'Training_Corpora/{0}/{0}.gold.key.txt' 33 | self.ALL_WORDS_DIC_PATH = _ALL_WORDS_BASE_PATH + 'Training_Corpora/{0}/{0}.dict.xml' 34 | # path for all_words test 35 | self.ALL_WORDS_TEST_PATH = _ALL_WORDS_BASE_PATH + 'Evaluation_Datasets/{0}/{0}.data.xml' 36 | self.ALL_WORDS_TEST_KEY_PATH = _ALL_WORDS_BASE_PATH + 'Evaluation_Datasets/{0}/{0}.gold.key.txt' 37 | self.ALL_WORDS_TEST_KEY_WPATH = _ALL_WORDS_BASE_PATH + 'Evaluation_Datasets/{0}/{0}.gold.key.withPos.txt' 38 | # MFS / FS result 39 | self.BASE_OTHER_SYSTEM_PATH = _ALL_WORDS_BASE_PATH + 'Output_Systems_ALL/' 40 | self.MFS_PATH = _ALL_WORDS_BASE_PATH + 'Output_Systems_ALL/MFS_{0}.key' 41 | self.WNFS_PATH = _ALL_WORDS_BASE_PATH + 'Output_Systems_ALL/WNFirstsense.key' 42 | 43 | self.WORDNET_PATH = PUNGAN_ROOT_PATH + '/WSD/data/nltk_data/corpora/wordnet' 44 | self.WORDNET171_PATH = PUNGAN_ROOT_PATH + '/WSD/data/nltk_data/corpora/wordnet1.7.1' 45 | self.WORDNET17_PATH = PUNGAN_ROOT_PATH + '/WSD/data/nltk_data/corpora/wordnet1.7' 46 | 47 | self.LS_OLD_WN_PATH = {self.LS_DATASET[0]: self.WORDNET17_PATH, 48 | self.LS_DATASET[1]: self.WORDNET171_PATH} 49 | self.LS_WN_MAP_PATH = {self.LS_DATASET[0]: PUNGAN_ROOT_PATH + '/WSD/data/nltk_data/mappings-upc-2007/mapping-17-30/wn17-30.', # for senseval2 50 | self.LS_DATASET[1]: PUNGAN_ROOT_PATH + '/WSD/data/nltk_data/mappings-upc-2007/mapping-171-30/wn171-30.'} # for senseval3 51 | 52 | self.GLOVE_DIR = PUNGAN_ROOT_PATH + '/WSD/data/' 53 | 54 
| if not os.path.exists(PUNGAN_ROOT_PATH + '/WSD/tmp'): 55 | os.makedirs(PUNGAN_ROOT_PATH + '/WSD/tmp') 56 | self.BACK_OFF_RESULT_PATH = PUNGAN_ROOT_PATH + '/WSD/tmp/back_off_results-{}.txt' 57 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | step=0 3 | n=158 # number of sample files in /Pun_Generation/data/samples 4 | lr=0.01 # learning rate 5 | while(( $step<800 )) # n*5 for 5 epochs 6 | do 7 | sample_num=$[$step%$n] 8 | echo $step 9 | echo '[step 1] backward step one' 10 | python -u ./Pun_Generation/code/nmt.py \ 11 | --infer_batch_size=8 \ 12 | --vocab_prefix=./Pun_Generation/data/1backward/vocab \ 13 | --src=in \ 14 | --tgt=out \ 15 | --out_dir=./Pun_Generation/code/backward_model_path \ 16 | --dev_prefix=./Pun_Generation/data/1backward/dev \ 17 | --test_prefix=./Pun_Generation/data/1backward/test \ 18 | --sample_prefix=./Pun_Generation/data/samples/sample_"$sample_num" \ 19 | --sampling_temperature=1.0 \ 20 | --beam_width=0 \ 21 | --sample_size=32 \ 22 | --first_step=1.0 \ 23 | --learning_rate="$lr" > ./Pun_Generation/code/output1.txt 24 | 25 | echo '[step 2] forward rl' 26 | python -u ./Pun_Generation_Forward/code/nmt.py \ 27 | --infer_batch_size=64 \ 28 | --vocab_prefix=./Pun_Generation_Forward/data/2forward/vocab \ 29 | --src=in \ 30 | --tgt=out \ 31 | --out_dir=./Pun_Generation_Forward/code/forward_model_path \ 32 | --train_prefix=./Pun_Generation_Forward/data/2forward/forward_index \ 33 | --dev_prefix=./Pun_Generation_Forward/data/2forward/dev \ 34 | --test_prefix=./Pun_Generation_Forward/data/2forward/test \ 35 | --sample_prefix=./Pun_Generation/data/1backward/backward_step1.out \ 36 | --reward_prefix=./Pun_Generation_Forward/data/2forward/wsd_train_reward.in \ 37 | --sampling_temperature=-1.0 \ 38 | --beam_width=0 \ 39 | --sample_size=1 \ 40 | --learning_rate="$lr" > ./Pun_Generation_Forward/code/output1_sp.txt 41 | 42 | echo '[step 3] backward step two' 43 | python -u ./Pun_Generation/code/nmt.py \ 44 | --infer_batch_size=64 \ 45 | --vocab_prefix=./Pun_Generation/data/1backward/vocab \ 46 | --src=in \ 47 | --tgt=out \ 48 | --out_dir=./Pun_Generation/code/backward_model_path \ 49 | --train_prefix=./Pun_Generation/data/1backward/backward_step2 \ 50 | --dev_prefix=./Pun_Generation/data/1backward/dev \ 51 | --test_prefix=./Pun_Generation/data/1backward/test \ 52 | --sample_prefix=./Pun_Generation/data/sample_2548 \ 53 | --reward_prefix=./Pun_Generation_Forward/data/2forward/wsd_train_reward.in \ 54 | --beam_width=10 \ 55 | --sample_size=1 \ 56 | --learning_rate="$lr" > ./Pun_Generation/code/output1.txt 57 | 58 | python -u ./Pun_Generation/code/nmt.py \ 59 | --infer_batch_size=64 \ 60 | --vocab_prefix=./Pun_Generation/data/1backward/vocab \ 61 | --src=in \ 62 | --tgt=out \ 63 | --out_dir=./Pun_Generation/code/backward_model_path \ 64 | --train_prefix=./Pun_Generation/data/1backward/train \ 65 | --dev_prefix=./Pun_Generation/data/1backward/dev \ 66 | --test_prefix=./Pun_Generation/data/1backward/test \ 67 | --inference_input_file=./Pun_Generation/data/sample_2548 \ 68 | --inference_output_file=./Pun_Generation/code/backward_model_path/first_part_file \ 69 | --beam_width=10 \ 70 | --num_translations_per_input=1 > ./Pun_Generation/code/output_infer.txt 71 | 72 | python ./Pun_Generation/code/dealt.py 2548 73 | 74 | python -u ./Pun_Generation_Forward/code/nmt.py \ 75 | --infer_batch_size=64 \ 76 | 
--vocab_prefix=./Pun_Generation_Forward/data/2forward/vocab \ 77 | --src=in \ 78 | --tgt=out \ 79 | --out_dir=./Pun_Generation_Forward/code/forward_model_path \ 80 | --train_prefix=./Pun_Generation_Forward/data/2forward/train \ 81 | --dev_prefix=./Pun_Generation_Forward/data/2forward/dev \ 82 | --test_prefix=./Pun_Generation_Forward/data/2forward/test \ 83 | --inference_input_file=./Pun_Generation/code/backward_model_path/dealt_first_part_file \ 84 | --inference_output_file=./Pun_Generation_Forward/code/forward_model_path/second_part_file \ 85 | --beam_width=10 \ 86 | --num_translations_per_input=1 > ./Pun_Generation_Forward/code/output_infer.txt 87 | 88 | python ./Pun_Generation/code/concatenate.py 89 | 90 | let "step++" 91 | done 92 | -------------------------------------------------------------------------------- /Pun_Generation/code/utils/vocab_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Utility to handle vocabularies.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import codecs 23 | import os 24 | import tensorflow as tf 25 | 26 | from tensorflow.python.ops import lookup_ops 27 | 28 | from utils import misc_utils as utils 29 | 30 | 31 | UNK = "<unk>" 32 | SOS = "<s>" 33 | EOS = "</s>" 34 | UNK_ID = 0 35 | 36 | 37 | def load_vocab(vocab_file): 38 | vocab = [] 39 | with codecs.getreader("utf-8")(tf.gfile.GFile(vocab_file, "rb")) as f: 40 | vocab_size = 0 41 | for word in f: 42 | vocab_size += 1 43 | vocab.append(word.strip()) 44 | return vocab, vocab_size 45 | 46 | 47 | def check_vocab(vocab_file, out_dir, check_special_token=True, sos=None, 48 | eos=None, unk=None): 49 | """Check if vocab_file doesn't exist, create from corpus_file.""" 50 | if tf.gfile.Exists(vocab_file): 51 | utils.print_out("# Vocab file %s exists" % vocab_file) 52 | vocab, vocab_size = load_vocab(vocab_file) 53 | if check_special_token: 54 | # Verify if the vocab starts with unk, sos, eos 55 | # If not, prepend those tokens & generate a new vocab file 56 | if not unk: unk = UNK 57 | if not sos: sos = SOS 58 | if not eos: eos = EOS 59 | assert len(vocab) >= 3 60 | if vocab[0] != unk or vocab[1] != sos or vocab[2] != eos: 61 | utils.print_out("The first 3 vocab words [%s, %s, %s]" 62 | " are not [%s, %s, %s]" % 63 | (vocab[0], vocab[1], vocab[2], unk, sos, eos)) 64 | vocab = [unk, sos, eos] + vocab 65 | vocab_size += 3 66 | new_vocab_file = os.path.join(out_dir, os.path.basename(vocab_file)) 67 | with codecs.getwriter("utf-8")( 68 | tf.gfile.GFile(new_vocab_file, "wb")) as f: 69 | for word in vocab: 70 | f.write("%s\n" % word) 71 | vocab_file = new_vocab_file 72 | else: 73 | raise ValueError("vocab_file '%s' does not exist."
% vocab_file) 74 | 75 | vocab_size = len(vocab) 76 | return vocab_size, vocab_file 77 | 78 | 79 | def create_vocab_tables(src_vocab_file, tgt_vocab_file, share_vocab): 80 | """Creates vocab tables for src_vocab_file and tgt_vocab_file.""" 81 | src_vocab_table = lookup_ops.index_table_from_file( 82 | src_vocab_file, default_value=UNK_ID) 83 | if share_vocab: 84 | tgt_vocab_table = src_vocab_table 85 | else: 86 | tgt_vocab_table = lookup_ops.index_table_from_file( 87 | tgt_vocab_file, default_value=UNK_ID) 88 | return src_vocab_table, tgt_vocab_table 89 | 90 | 91 | def load_embed_txt(embed_file): 92 | """Load embed_file into a python dictionary. 93 | 94 | Note: the embed_file should be a Glove formated txt file. Assuming 95 | embed_size=5, for example: 96 | 97 | the -0.071549 0.093459 0.023738 -0.090339 0.056123 98 | to 0.57346 0.5417 -0.23477 -0.3624 0.4037 99 | and 0.20327 0.47348 0.050877 0.002103 0.060547 100 | 101 | Args: 102 | embed_file: file path to the embedding file. 103 | Returns: 104 | a dictionary that maps word to vector, and the size of embedding dimensions. 105 | """ 106 | emb_dict = dict() 107 | emb_size = None 108 | with codecs.getreader("utf-8")(tf.gfile.GFile(embed_file, 'rb')) as f: 109 | for line in f: 110 | tokens = line.strip().split(" ") 111 | word = tokens[0] 112 | vec = list(map(float, tokens[1:])) 113 | emb_dict[word] = vec 114 | if emb_size: 115 | assert emb_size == len(vec), "All embedding size should be same." 116 | else: 117 | emb_size = len(vec) 118 | return emb_dict, emb_size 119 | -------------------------------------------------------------------------------- /Pun_Generation_Forward/code/utils/vocab_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Utility to handle vocabularies.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import codecs 23 | import os 24 | import tensorflow as tf 25 | 26 | from tensorflow.python.ops import lookup_ops 27 | 28 | from utils import misc_utils as utils 29 | 30 | 31 | UNK = "<unk>" 32 | SOS = "<s>" 33 | EOS = "</s>" 34 | UNK_ID = 0 35 | 36 | 37 | def load_vocab(vocab_file): 38 | vocab = [] 39 | with codecs.getreader("utf-8")(tf.gfile.GFile(vocab_file, "rb")) as f: 40 | vocab_size = 0 41 | for word in f: 42 | vocab_size += 1 43 | vocab.append(word.strip()) 44 | return vocab, vocab_size 45 | 46 | 47 | def check_vocab(vocab_file, out_dir, check_special_token=True, sos=None, 48 | eos=None, unk=None): 49 | """Check if vocab_file doesn't exist, create from corpus_file.""" 50 | if tf.gfile.Exists(vocab_file): 51 | utils.print_out("# Vocab file %s exists" % vocab_file) 52 | vocab, vocab_size = load_vocab(vocab_file) 53 | if check_special_token: 54 | # Verify if the vocab starts with unk, sos, eos 55 | # If not, prepend those tokens & generate a new vocab file 56 | if not unk: unk = UNK 57 | if not sos: sos = SOS 58 | if not eos: eos = EOS 59 | assert len(vocab) >= 3 60 | if vocab[0] != unk or vocab[1] != sos or vocab[2] != eos: 61 | utils.print_out("The first 3 vocab words [%s, %s, %s]" 62 | " are not [%s, %s, %s]" % 63 | (vocab[0], vocab[1], vocab[2], unk, sos, eos)) 64 | vocab = [unk, sos, eos] + vocab 65 | vocab_size += 3 66 | new_vocab_file = os.path.join(out_dir, os.path.basename(vocab_file)) 67 | with codecs.getwriter("utf-8")( 68 | tf.gfile.GFile(new_vocab_file, "wb")) as f: 69 | for word in vocab: 70 | f.write("%s\n" % word) 71 | vocab_file = new_vocab_file 72 | else: 73 | raise ValueError("vocab_file '%s' does not exist." % vocab_file) 74 | 75 | vocab_size = len(vocab) 76 | return vocab_size, vocab_file 77 | 78 | 79 | def create_vocab_tables(src_vocab_file, tgt_vocab_file, share_vocab): 80 | """Creates vocab tables for src_vocab_file and tgt_vocab_file.""" 81 | src_vocab_table = lookup_ops.index_table_from_file( 82 | src_vocab_file, default_value=UNK_ID) 83 | if share_vocab: 84 | tgt_vocab_table = src_vocab_table 85 | else: 86 | tgt_vocab_table = lookup_ops.index_table_from_file( 87 | tgt_vocab_file, default_value=UNK_ID) 88 | return src_vocab_table, tgt_vocab_table 89 | 90 | 91 | def load_embed_txt(embed_file): 92 | """Load embed_file into a python dictionary. 93 | 94 | Note: the embed_file should be a Glove formated txt file. Assuming 95 | embed_size=5, for example: 96 | 97 | the -0.071549 0.093459 0.023738 -0.090339 0.056123 98 | to 0.57346 0.5417 -0.23477 -0.3624 0.4037 99 | and 0.20327 0.47348 0.050877 0.002103 0.060547 100 | 101 | Args: 102 | embed_file: file path to the embedding file. 103 | Returns: 104 | a dictionary that maps word to vector, and the size of embedding dimensions. 105 | """ 106 | emb_dict = dict() 107 | emb_size = None 108 | with codecs.getreader("utf-8")(tf.gfile.GFile(embed_file, 'rb')) as f: 109 | for line in f: 110 | tokens = line.strip().split(" ") 111 | word = tokens[0] 112 | vec = list(map(float, tokens[1:])) 113 | emb_dict[word] = vec 114 | if emb_size: 115 | assert emb_size == len(vec), "All embedding size should be same."
116 | else: 117 | emb_size = len(vec) 118 | return emb_dict, emb_size 119 | -------------------------------------------------------------------------------- /Pun_Generation/code/utils/scripts/bleu.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Python implementation of BLEU and smooth-BLEU. 17 | 18 | This module provides a Python implementation of BLEU and smooth-BLEU. 19 | Smooth BLEU is computed following the method outlined in the paper: 20 | Chin-Yew Lin, Franz Josef Och. ORANGE: a method for evaluating automatic 21 | evaluation metrics for machine translation. COLING 2004. 22 | """ 23 | 24 | import collections 25 | import math 26 | 27 | 28 | def _get_ngrams(segment, max_order): 29 | """Extracts all n-grams upto a given maximum order from an input segment. 30 | 31 | Args: 32 | segment: text segment from which n-grams will be extracted. 33 | max_order: maximum length in tokens of the n-grams returned by this 34 | methods. 35 | 36 | Returns: 37 | The Counter containing all n-grams upto max_order in segment 38 | with a count of how many times each n-gram occurred. 39 | """ 40 | ngram_counts = collections.Counter() 41 | for order in range(1, max_order + 1): 42 | for i in range(0, len(segment) - order + 1): 43 | ngram = tuple(segment[i:i+order]) 44 | ngram_counts[ngram] += 1 45 | return ngram_counts 46 | 47 | 48 | def compute_bleu(reference_corpus, translation_corpus, max_order=4, 49 | smooth=False): 50 | """Computes BLEU score of translated segments against one or more references. 51 | 52 | Args: 53 | reference_corpus: list of lists of references for each translation. Each 54 | reference should be tokenized into a list of tokens. 55 | translation_corpus: list of translations to score. Each translation 56 | should be tokenized into a list of tokens. 57 | max_order: Maximum n-gram order to use when computing BLEU score. 58 | smooth: Whether or not to apply Lin et al. 2004 smoothing. 59 | 60 | Returns: 61 | 3-Tuple with the BLEU score, n-gram precisions, geometric mean of n-gram 62 | precisions and brevity penalty. 
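For example, an illustrative call with a single translation and one reference, compute_bleu([[["the", "cat", "sat", "on", "the", "mat"]]], [["the", "cat", "sat", "on", "the", "mat"]]), returns (1.0, [1.0, 1.0, 1.0, 1.0], 1.0, 1.0, 6, 6); the full return value is the 6-tuple (bleu, precisions, bp, ratio, translation_length, reference_length), so a perfect match yields BLEU 1.0. Note the nesting: reference_corpus is a list with one entry per translation, each entry a list of references, and each reference a list of tokens.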
63 | """ 64 | matches_by_order = [0] * max_order 65 | possible_matches_by_order = [0] * max_order 66 | reference_length = 0 67 | translation_length = 0 68 | for (references, translation) in zip(reference_corpus, 69 | translation_corpus): 70 | reference_length += min(len(r) for r in references) 71 | translation_length += len(translation) 72 | 73 | merged_ref_ngram_counts = collections.Counter() 74 | for reference in references: 75 | merged_ref_ngram_counts |= _get_ngrams(reference, max_order) 76 | translation_ngram_counts = _get_ngrams(translation, max_order) 77 | overlap = translation_ngram_counts & merged_ref_ngram_counts 78 | for ngram in overlap: 79 | matches_by_order[len(ngram)-1] += overlap[ngram] 80 | for order in range(1, max_order+1): 81 | possible_matches = len(translation) - order + 1 82 | if possible_matches > 0: 83 | possible_matches_by_order[order-1] += possible_matches 84 | 85 | precisions = [0] * max_order 86 | for i in range(0, max_order): 87 | if smooth: 88 | precisions[i] = ((matches_by_order[i] + 1.) / 89 | (possible_matches_by_order[i] + 1.)) 90 | else: 91 | if possible_matches_by_order[i] > 0: 92 | precisions[i] = (float(matches_by_order[i]) / 93 | possible_matches_by_order[i]) 94 | else: 95 | precisions[i] = 0.0 96 | 97 | if min(precisions) > 0: 98 | p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions) 99 | geo_mean = math.exp(p_log_sum) 100 | else: 101 | geo_mean = 0 102 | 103 | ratio = float(translation_length) / reference_length 104 | 105 | if ratio > 1.0: 106 | bp = 1. 107 | else: 108 | bp = math.exp(1 - 1. / ratio) 109 | 110 | bleu = geo_mean * bp 111 | 112 | return (bleu, precisions, bp, ratio, translation_length, reference_length) 113 | -------------------------------------------------------------------------------- /Pun_Generation_Forward/code/utils/scripts/bleu.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Python implementation of BLEU and smooth-BLEU. 17 | 18 | This module provides a Python implementation of BLEU and smooth-BLEU. 19 | Smooth BLEU is computed following the method outlined in the paper: 20 | Chin-Yew Lin, Franz Josef Och. ORANGE: a method for evaluating automatic 21 | evaluation metrics for machine translation. COLING 2004. 22 | """ 23 | 24 | import collections 25 | import math 26 | 27 | 28 | def _get_ngrams(segment, max_order): 29 | """Extracts all n-grams upto a given maximum order from an input segment. 30 | 31 | Args: 32 | segment: text segment from which n-grams will be extracted. 33 | max_order: maximum length in tokens of the n-grams returned by this 34 | methods. 35 | 36 | Returns: 37 | The Counter containing all n-grams upto max_order in segment 38 | with a count of how many times each n-gram occurred. 
39 | """ 40 | ngram_counts = collections.Counter() 41 | for order in range(1, max_order + 1): 42 | for i in range(0, len(segment) - order + 1): 43 | ngram = tuple(segment[i:i+order]) 44 | ngram_counts[ngram] += 1 45 | return ngram_counts 46 | 47 | 48 | def compute_bleu(reference_corpus, translation_corpus, max_order=4, 49 | smooth=False): 50 | """Computes BLEU score of translated segments against one or more references. 51 | 52 | Args: 53 | reference_corpus: list of lists of references for each translation. Each 54 | reference should be tokenized into a list of tokens. 55 | translation_corpus: list of translations to score. Each translation 56 | should be tokenized into a list of tokens. 57 | max_order: Maximum n-gram order to use when computing BLEU score. 58 | smooth: Whether or not to apply Lin et al. 2004 smoothing. 59 | 60 | Returns: 61 | 3-Tuple with the BLEU score, n-gram precisions, geometric mean of n-gram 62 | precisions and brevity penalty. 63 | """ 64 | matches_by_order = [0] * max_order 65 | possible_matches_by_order = [0] * max_order 66 | reference_length = 0 67 | translation_length = 0 68 | for (references, translation) in zip(reference_corpus, 69 | translation_corpus): 70 | reference_length += min(len(r) for r in references) 71 | translation_length += len(translation) 72 | 73 | merged_ref_ngram_counts = collections.Counter() 74 | for reference in references: 75 | merged_ref_ngram_counts |= _get_ngrams(reference, max_order) 76 | translation_ngram_counts = _get_ngrams(translation, max_order) 77 | overlap = translation_ngram_counts & merged_ref_ngram_counts 78 | for ngram in overlap: 79 | matches_by_order[len(ngram)-1] += overlap[ngram] 80 | for order in range(1, max_order+1): 81 | possible_matches = len(translation) - order + 1 82 | if possible_matches > 0: 83 | possible_matches_by_order[order-1] += possible_matches 84 | 85 | precisions = [0] * max_order 86 | for i in range(0, max_order): 87 | if smooth: 88 | precisions[i] = ((matches_by_order[i] + 1.) / 89 | (possible_matches_by_order[i] + 1.)) 90 | else: 91 | if possible_matches_by_order[i] > 0: 92 | precisions[i] = (float(matches_by_order[i]) / 93 | possible_matches_by_order[i]) 94 | else: 95 | precisions[i] = 0.0 96 | 97 | if min(precisions) > 0: 98 | p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions) 99 | geo_mean = math.exp(p_log_sum) 100 | else: 101 | geo_mean = 0 102 | 103 | ratio = float(translation_length) / reference_length 104 | 105 | if ratio > 1.0: 106 | bp = 1. 107 | else: 108 | bp = math.exp(1 - 1. / ratio) 109 | 110 | bleu = geo_mean * bp 111 | 112 | return (bleu, precisions, bp, ratio, translation_length, reference_length) 113 | -------------------------------------------------------------------------------- /Pun_Generation/code/utils/common_test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Common utility functions for tests.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import tensorflow as tf 23 | 24 | from tensorflow.python.ops import lookup_ops 25 | 26 | from ..utils import iterator_utils 27 | from ..utils import standard_hparams_utils 28 | 29 | 30 | def create_test_hparams(unit_type="lstm", 31 | encoder_type="uni", 32 | num_layers=4, 33 | attention="", 34 | attention_architecture=None, 35 | use_residual=False, 36 | inference_indices=None, 37 | num_translations_per_input=1, 38 | beam_width=0, 39 | init_op="uniform"): 40 | """Create training and inference test hparams.""" 41 | num_residual_layers = 0 42 | if use_residual: 43 | # TODO(rzhao): Put num_residual_layers computation logic into 44 | # `model_utils.py`, so we can also test it here. 45 | num_residual_layers = 2 46 | 47 | standard_hparams = standard_hparams_utils.create_standard_hparams() 48 | 49 | # Networks 50 | standard_hparams.num_units = 5 51 | standard_hparams.num_encoder_layers = num_layers 52 | standard_hparams.num_decoder_layers = num_layers 53 | standard_hparams.dropout = 0.5 54 | standard_hparams.unit_type = unit_type 55 | standard_hparams.encoder_type = encoder_type 56 | standard_hparams.residual = use_residual 57 | standard_hparams.num_residual_layers = num_residual_layers 58 | 59 | # Attention mechanisms 60 | standard_hparams.attention = attention 61 | standard_hparams.attention_architecture = attention_architecture 62 | 63 | # Train 64 | standard_hparams.init_op = init_op 65 | standard_hparams.num_train_steps = 1 66 | standard_hparams.decay_scheme = "" 67 | 68 | # Infer 69 | standard_hparams.tgt_max_len_infer = 100 70 | standard_hparams.beam_width = beam_width 71 | standard_hparams.num_translations_per_input = num_translations_per_input 72 | 73 | # Misc 74 | standard_hparams.forget_bias = 0.0 75 | standard_hparams.random_seed = 3 76 | 77 | # Vocab 78 | standard_hparams.src_vocab_size = 5 79 | standard_hparams.tgt_vocab_size = 5 80 | standard_hparams.eos = "eos" 81 | standard_hparams.sos = "sos" 82 | standard_hparams.src_vocab_file = "" 83 | standard_hparams.tgt_vocab_file = "" 84 | standard_hparams.src_embed_file = "" 85 | standard_hparams.tgt_embed_file = "" 86 | 87 | # For inference.py test 88 | standard_hparams.subword_option = "bpe" 89 | standard_hparams.src = "src" 90 | standard_hparams.tgt = "tgt" 91 | standard_hparams.src_max_len = 400 92 | standard_hparams.tgt_eos_id = 0 93 | standard_hparams.inference_indices = inference_indices 94 | return standard_hparams 95 | 96 | 97 | def create_test_iterator(hparams, mode): 98 | """Create test iterator.""" 99 | src_vocab_table = lookup_ops.index_table_from_tensor( 100 | tf.constant([hparams.eos, "a", "b", "c", "d"])) 101 | tgt_vocab_mapping = tf.constant([hparams.sos, hparams.eos, "a", "b", "c"]) 102 | tgt_vocab_table = lookup_ops.index_table_from_tensor(tgt_vocab_mapping) 103 | if mode == tf.contrib.learn.ModeKeys.INFER: 104 | reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_tensor( 105 | tgt_vocab_mapping) 106 | 107 | src_dataset = tf.data.Dataset.from_tensor_slices( 108 | tf.constant(["a a b b c", "a b b"])) 109 | 110 | if mode != tf.contrib.learn.ModeKeys.INFER: 111 | tgt_dataset = tf.data.Dataset.from_tensor_slices( 112 | tf.constant(["a b c b c", "a b c b"])) 113 | return ( 114 | iterator_utils.get_iterator( 115 | src_dataset=src_dataset, 116 
| tgt_dataset=tgt_dataset, 117 | src_vocab_table=src_vocab_table, 118 | tgt_vocab_table=tgt_vocab_table, 119 | batch_size=hparams.batch_size, 120 | sos=hparams.sos, 121 | eos=hparams.eos, 122 | random_seed=hparams.random_seed, 123 | num_buckets=hparams.num_buckets), 124 | src_vocab_table, 125 | tgt_vocab_table) 126 | else: 127 | return ( 128 | iterator_utils.get_infer_iterator( 129 | src_dataset=src_dataset, 130 | src_vocab_table=src_vocab_table, 131 | eos=hparams.eos, 132 | batch_size=hparams.batch_size), 133 | src_vocab_table, 134 | tgt_vocab_table, 135 | reverse_tgt_vocab_table) 136 | -------------------------------------------------------------------------------- /Pun_Generation_Forward/code/utils/common_test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Common utility functions for tests.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import tensorflow as tf 23 | 24 | from tensorflow.python.ops import lookup_ops 25 | 26 | from ..utils import iterator_utils 27 | from ..utils import standard_hparams_utils 28 | 29 | 30 | def create_test_hparams(unit_type="lstm", 31 | encoder_type="uni", 32 | num_layers=4, 33 | attention="", 34 | attention_architecture=None, 35 | use_residual=False, 36 | inference_indices=None, 37 | num_translations_per_input=1, 38 | beam_width=0, 39 | init_op="uniform"): 40 | """Create training and inference test hparams.""" 41 | num_residual_layers = 0 42 | if use_residual: 43 | # TODO(rzhao): Put num_residual_layers computation logic into 44 | # `model_utils.py`, so we can also test it here. 
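# Note: these hparams tests rely on the default num_layers=4, so the expected number of residual-wrapped layers is simply hard-coded to 2 here.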
45 | num_residual_layers = 2 46 | 47 | standard_hparams = standard_hparams_utils.create_standard_hparams() 48 | 49 | # Networks 50 | standard_hparams.num_units = 5 51 | standard_hparams.num_encoder_layers = num_layers 52 | standard_hparams.num_decoder_layers = num_layers 53 | standard_hparams.dropout = 0.5 54 | standard_hparams.unit_type = unit_type 55 | standard_hparams.encoder_type = encoder_type 56 | standard_hparams.residual = use_residual 57 | standard_hparams.num_residual_layers = num_residual_layers 58 | 59 | # Attention mechanisms 60 | standard_hparams.attention = attention 61 | standard_hparams.attention_architecture = attention_architecture 62 | 63 | # Train 64 | standard_hparams.init_op = init_op 65 | standard_hparams.num_train_steps = 1 66 | standard_hparams.decay_scheme = "" 67 | 68 | # Infer 69 | standard_hparams.tgt_max_len_infer = 100 70 | standard_hparams.beam_width = beam_width 71 | standard_hparams.num_translations_per_input = num_translations_per_input 72 | 73 | # Misc 74 | standard_hparams.forget_bias = 0.0 75 | standard_hparams.random_seed = 3 76 | 77 | # Vocab 78 | standard_hparams.src_vocab_size = 5 79 | standard_hparams.tgt_vocab_size = 5 80 | standard_hparams.eos = "eos" 81 | standard_hparams.sos = "sos" 82 | standard_hparams.src_vocab_file = "" 83 | standard_hparams.tgt_vocab_file = "" 84 | standard_hparams.src_embed_file = "" 85 | standard_hparams.tgt_embed_file = "" 86 | 87 | # For inference.py test 88 | standard_hparams.subword_option = "bpe" 89 | standard_hparams.src = "src" 90 | standard_hparams.tgt = "tgt" 91 | standard_hparams.src_max_len = 400 92 | standard_hparams.tgt_eos_id = 0 93 | standard_hparams.inference_indices = inference_indices 94 | return standard_hparams 95 | 96 | 97 | def create_test_iterator(hparams, mode): 98 | """Create test iterator.""" 99 | src_vocab_table = lookup_ops.index_table_from_tensor( 100 | tf.constant([hparams.eos, "a", "b", "c", "d"])) 101 | tgt_vocab_mapping = tf.constant([hparams.sos, hparams.eos, "a", "b", "c"]) 102 | tgt_vocab_table = lookup_ops.index_table_from_tensor(tgt_vocab_mapping) 103 | if mode == tf.contrib.learn.ModeKeys.INFER: 104 | reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_tensor( 105 | tgt_vocab_mapping) 106 | 107 | src_dataset = tf.data.Dataset.from_tensor_slices( 108 | tf.constant(["a a b b c", "a b b"])) 109 | 110 | if mode != tf.contrib.learn.ModeKeys.INFER: 111 | tgt_dataset = tf.data.Dataset.from_tensor_slices( 112 | tf.constant(["a b c b c", "a b c b"])) 113 | return ( 114 | iterator_utils.get_iterator( 115 | src_dataset=src_dataset, 116 | tgt_dataset=tgt_dataset, 117 | src_vocab_table=src_vocab_table, 118 | tgt_vocab_table=tgt_vocab_table, 119 | batch_size=hparams.batch_size, 120 | sos=hparams.sos, 121 | eos=hparams.eos, 122 | random_seed=hparams.random_seed, 123 | num_buckets=hparams.num_buckets), 124 | src_vocab_table, 125 | tgt_vocab_table) 126 | else: 127 | return ( 128 | iterator_utils.get_infer_iterator( 129 | src_dataset=src_dataset, 130 | src_vocab_table=src_vocab_table, 131 | eos=hparams.eos, 132 | batch_size=hparams.batch_size), 133 | src_vocab_table, 134 | tgt_vocab_table, 135 | reverse_tgt_vocab_table) 136 | -------------------------------------------------------------------------------- /Pun_Generation/code/utils/nmt_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Utility functions specifically for NMT.""" 17 | from __future__ import print_function 18 | 19 | import codecs 20 | import time 21 | import numpy as np 22 | import tensorflow as tf 23 | 24 | from utils import evaluation_utils 25 | from utils import misc_utils as utils 26 | 27 | __all__ = ["decode_and_evaluate", "get_translation"] 28 | 29 | 30 | def decode_and_evaluate(name, 31 | model, 32 | sess, 33 | trans_file, 34 | ref_file, 35 | metrics, 36 | subword_option, 37 | beam_width, 38 | tgt_eos, 39 | num_translations_per_input=1, 40 | decode=True): 41 | """Decode a test set and compute a score according to the evaluation task.""" 42 | # Decode 43 | if decode: 44 | utils.print_out(" decoding to output %s." % trans_file) 45 | 46 | start_time = time.time() 47 | num_sentences = 0 48 | with codecs.getwriter("utf-8")( 49 | tf.gfile.GFile(trans_file, mode="wb")) as trans_f: 50 | trans_f.write("") # Write empty string to ensure file is created. 51 | 52 | #num_translations_per_input = max(min(10, beam_width), 1) 53 | print ("num_translations_per_input", num_translations_per_input) 54 | while True: 55 | try: 56 | infer_logits, nmt_outputs, _ = model.decode(sess) 57 | #print (infer_logits) 58 | if beam_width == 0: 59 | nmt_outputs = np.expand_dims(nmt_outputs, 0) 60 | 61 | batch_size = nmt_outputs.shape[1] 62 | num_sentences += batch_size 63 | #print ("infer_logits",infer_logits) 64 | #print ("nmt_outputs",nmt_outputs) 65 | for sent_id in range(batch_size): 66 | for beam_id in range(num_translations_per_input): 67 | #print ("nmt_outputs[beam_id]",nmt_outputs[beam_id]) 68 | #print ("infer_logits[beam_id]",infer_logits[beam_id]) 69 | translation = get_translation( 70 | nmt_outputs[beam_id], 71 | infer_logits[beam_id], 72 | sent_id, 73 | tgt_eos=tgt_eos, 74 | subword_option=subword_option) 75 | trans_f.write((translation + b"\n").decode("utf-8")) 76 | except tf.errors.OutOfRangeError: 77 | utils.print_time( 78 | " done, num sentences %d, num translations per input %d" % 79 | (num_sentences, num_translations_per_input), start_time) 80 | break 81 | 82 | # Evaluation 83 | evaluation_scores = {} 84 | if ref_file and tf.gfile.Exists(trans_file): 85 | for metric in metrics: 86 | score = evaluation_utils.evaluate( 87 | ref_file, 88 | trans_file, 89 | metric, 90 | subword_option=subword_option) 91 | evaluation_scores[metric] = score 92 | utils.print_out(" %s %s: %.1f" % (metric, name, score)) 93 | 94 | return evaluation_scores 95 | 96 | 97 | def get_translation(nmt_outputs,infer_logits, sent_id, tgt_eos, subword_option): 98 | """Given batch decoding outputs, select a sentence and turn to text.""" 99 | if tgt_eos: tgt_eos = tgt_eos.encode("utf-8") 100 | # Select a sentence 101 | output = nmt_outputs[sent_id, :].tolist() 102 | scores = infer_logits[sent_id] 103 | #fw=open('sample_res/scores_logits_{}'.format(sent_id),'w+') 104 | #for i 
in scores: 105 | #fw.write('\n'.join([' '.join(str(a)for a in e)for e in scores])) 106 | #fw.close() 107 | #print ("output",output) 108 | #print ("scores",scores) 109 | # If there is an eos symbol in outputs, cut them at that point. 110 | if tgt_eos and tgt_eos in output: 111 | output = output[:output.index(tgt_eos)] 112 | 113 | if subword_option == "bpe": # BPE 114 | #print ("subword_option ==bpe") 115 | translation = utils.format_bpe_text(output) 116 | elif subword_option == "spm": # SPM 117 | #print ("subword_option ==spm") 118 | translation = utils.format_spm_text(output) 119 | else: 120 | #print ("scores in format_text!") 121 | translation = utils.format_text(output,scores) 122 | 123 | return translation 124 | 125 | def get_translation_train(nmt_outputs, sent_id, tgt_eos, subword_option): 126 | """Given batch decoding outputs, select a sentence and turn to text.""" 127 | if tgt_eos: tgt_eos = tgt_eos.encode("utf-8") 128 | output = nmt_outputs[sent_id, :].tolist() 129 | if tgt_eos and tgt_eos in output: 130 | output = output[:output.index(tgt_eos)] 131 | 132 | return ' '.join(output) 133 | -------------------------------------------------------------------------------- /Pun_Generation_Forward/code/utils/nmt_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Utility functions specifically for NMT.""" 17 | from __future__ import print_function 18 | 19 | import codecs 20 | import time 21 | import numpy as np 22 | import tensorflow as tf 23 | 24 | from utils import evaluation_utils 25 | from utils import misc_utils as utils 26 | 27 | __all__ = ["decode_and_evaluate", "get_translation"] 28 | 29 | 30 | def decode_and_evaluate(name, 31 | model, 32 | sess, 33 | trans_file, 34 | ref_file, 35 | metrics, 36 | subword_option, 37 | beam_width, 38 | tgt_eos, 39 | num_translations_per_input=1, 40 | decode=True): 41 | """Decode a test set and compute a score according to the evaluation task.""" 42 | # Decode 43 | if decode: 44 | utils.print_out(" decoding to output %s." % trans_file) 45 | 46 | start_time = time.time() 47 | num_sentences = 0 48 | with codecs.getwriter("utf-8")( 49 | tf.gfile.GFile(trans_file, mode="wb")) as trans_f: 50 | trans_f.write("") # Write empty string to ensure file is created. 
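# Decode loop: model.decode(sess) is called repeatedly until the inference
# iterator is exhausted (tf.errors.OutOfRangeError). The outputs are indexed
# beam-first below (nmt_outputs[beam_id][sent_id, :]), and the top
# num_translations_per_input hypotheses per source sentence are written to trans_file.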
51 | 52 | #num_translations_per_input = max(min(10, beam_width), 1) 53 | print ("num_translations_per_input",num_translations_per_input) 54 | while True: 55 | try: 56 | infer_logits, nmt_outputs, _ = model.decode(sess) 57 | #print (infer_logits) 58 | if beam_width == 0: 59 | nmt_outputs = np.expand_dims(nmt_outputs, 0) 60 | 61 | batch_size = nmt_outputs.shape[1] 62 | num_sentences += batch_size 63 | #print ("infer_logits",infer_logits) 64 | #print ("nmt_outputs",nmt_outputs) 65 | for sent_id in range(batch_size): 66 | for beam_id in range(num_translations_per_input): 67 | #print ("nmt_outputs[beam_id]",nmt_outputs[beam_id]) 68 | #print ("infer_logits[beam_id]",infer_logits[beam_id]) 69 | translation = get_translation( 70 | nmt_outputs[beam_id], 71 | infer_logits[beam_id], 72 | sent_id, 73 | tgt_eos=tgt_eos, 74 | subword_option=subword_option) 75 | trans_f.write((translation + b"\n").decode("utf-8")) 76 | except tf.errors.OutOfRangeError: 77 | utils.print_time( 78 | " done, num sentences %d, num translations per input %d" % 79 | (num_sentences, num_translations_per_input), start_time) 80 | break 81 | 82 | # Evaluation 83 | evaluation_scores = {} 84 | if ref_file and tf.gfile.Exists(trans_file): 85 | for metric in metrics: 86 | score = evaluation_utils.evaluate( 87 | ref_file, 88 | trans_file, 89 | metric, 90 | subword_option=subword_option) 91 | evaluation_scores[metric] = score 92 | utils.print_out(" %s %s: %.1f" % (metric, name, score)) 93 | 94 | return evaluation_scores 95 | 96 | 97 | def get_translation(nmt_outputs,infer_logits, sent_id, tgt_eos, subword_option): 98 | """Given batch decoding outputs, select a sentence and turn to text.""" 99 | if tgt_eos: tgt_eos = tgt_eos.encode("utf-8") 100 | # Select a sentence 101 | output = nmt_outputs[sent_id, :].tolist() 102 | scores = infer_logits[sent_id] 103 | #fw=open('sample_res/scores_logits_{}'.format(sent_id),'w+') 104 | #for i in scores: 105 | #fw.write('\n'.join([' '.join(str(a)for a in e)for e in scores])) 106 | #fw.close() 107 | #print ("output",output) 108 | #print ("scores",scores) 109 | # If there is an eos symbol in outputs, cut them at that point. 
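# The subword branches below detokenize the selected hypothesis: "bpe" merges
# "@@"-suffixed pieces, "spm" joins on U+2581, and the default branch calls
# format_text(output, scores), which also appends the absolute score to the sentence.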
110 | if tgt_eos and tgt_eos in output: 111 | output = output[:output.index(tgt_eos)] 112 | 113 | if subword_option == "bpe": # BPE 114 | #print ("subword_option ==bpe") 115 | translation = utils.format_bpe_text(output) 116 | elif subword_option == "spm": # SPM 117 | #print ("subword_option ==spm") 118 | translation = utils.format_spm_text(output) 119 | else: 120 | #print ("scores in format_text!") 121 | translation = utils.format_text(output,scores) 122 | 123 | return translation 124 | 125 | def get_translation_train(nmt_outputs, sent_id, tgt_eos, subword_option): 126 | """Given batch decoding outputs, select a sentence and turn to text.""" 127 | if tgt_eos: tgt_eos = tgt_eos.encode("utf-8") 128 | output = nmt_outputs[sent_id, :].tolist() 129 | if tgt_eos and tgt_eos in output: 130 | output = output[:output.index(tgt_eos)] 131 | 132 | return ' '.join(output) 133 | 134 | -------------------------------------------------------------------------------- /WSD/BiLSTM/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @version: python2.7 4 | @author: luofuli 5 | @time: 2018/3/30 11:03 6 | """ 7 | import sys 8 | import pickle 9 | import random 10 | 11 | 12 | class BiLSTM_Config(object): 13 | """Configuration parameters.""" 14 | 15 | def __init__(self): 16 | self.name = 'CAN' 17 | 18 | # config for data 19 | self.g_len = 40 # max gloss words 20 | self.c_len_f = 10 # forward context words 21 | self.c_len_b = 10 # backward context words 22 | self.expand_type = 0 # 0=none 1=hypernym (hyper), 2=hyponym (hypo) 3=hyper+hypo 4=hierarchy 23 | self.gloss_is_empty = True # ** For comparison (GAS_baseline) 24 | self.back_off_type = 'MFS' # MFS 25 | 26 | # config for word embeddings 27 | self.vocab_name = '6B' # 6B(100/200/300) 42B(300) 840B(300) 28 | self.vocab_size = 100000 # adjusted later according to the actual data, to make indexing easier 29 | self.embedding_size = 300 30 | self.use_pre_trained_emb = False # ** Debug model 31 | self.train_word_emb = False 32 | self.has_pos = False 33 | self.pos_dim = 50 34 | 35 | # config for model 36 | self.batch_size = 32 # 32 is too big for all_words 37 | self.hidden_size = 128 38 | self.forget_bias = 0.0 39 | self.keep_prob = 0.5 40 | 41 | # config for lstm cell 42 | self.dropout = True # master switch for dropout 43 | self.state_dropout = True 44 | self.rnn_type = 'LSTM' # GRU 45 | 46 | # config for train 47 | self.n_epochs = 30 48 | self.lr_start = 0.001 # original lr = 0.001 49 | self.decay_lr = False 50 | self.optimizer_type = 'Adam' # Adagrad 51 | self.momentum = 0.1 52 | self.clip_gradients = True # whether to clip gradients !! noticeably affects training time
53 | self.max_grad_norm = 10 # maximum gradient norm when clipping 54 | self.warm_start = True 55 | 56 | # config for print logs 57 | self.print_batch = True 58 | self.evaluate_gap = 100 # ** how often (in steps) to check whether the val result has improved, and print it 59 | self.store_log_gap = 10 # ** 60 | self.save_best_model = True 61 | self.store_run_time = False # ** record the run time of every graph node (to find the slowest node) 62 | self.run_time_epoch = 1 # ** record timing once every this many epochs 63 | self.show_true_result = True # whether to show the true result (two ways to compute it: 1. scoring function 2. model accuracy + back_off_result accuracy) 64 | 65 | # Validation info 66 | self.sota_score = 0.706 67 | self.validate = True 68 | self.min_no_improvement = 5000 # stop after this many consecutive epochs without improvement (epochs, not steps) 69 | 70 | # with open('../tmp/pos_dic.pkl', 'rb') as f: 71 | # self.pos_to_id = pickle.load(f) 72 | # self.pos = self.pos_to_id.keys() 73 | 74 | self.changed_config = {} 75 | 76 | def store_change(self, param, name): 77 | self.changed_config[name] = param 78 | return param 79 | 80 | def random_config(self): 81 | 82 | self.hidden_size = self.store_change(random.choice([128, 256, 512]), 'hidden_size') 83 | self.lr_start = self.store_change(random.choice([0.01, 0.05, 0.001, 0.0001]), 'lr_start') 84 | self.rnn_type = self.store_change(random.choice(['LSTM', 'GRU']), 'rnn_type') 85 | self.optimizer_type = self.store_change(random.choice(['Adadelta', 'Adagrad', 'Adam']), 'optimizer_type') 86 | # self.use_pre_trained_emb = self.store_change(random.choice([False, True]), 'use_pre_trained_emb') 87 | 88 | def grid_search(self): 89 | for self.hidden_size in [256, 128, 512]: 90 | self.store_change(self.hidden_size, 'hidden_size') 91 | for self.lr_start in [0.001, 0.05]: # 1.0 performs terribly 92 | self.store_change(self.lr_start, 'lr_start') 93 | for self.affinity_method in ['general', 'dot_sum']: # so far the two look about the same +1 94 | self.store_change(self.affinity_method, 'affinity_method') 95 | for self.update_context_encoding in [True, False]: # True looked more useful so far, but on 5/2 False turned out better 96 | self.store_change(self.update_context_encoding, 'update_context_encoding') 97 | yield self 98 | 99 | def get_grid_search_i(self, run_i): 100 | total_i = 0 101 | for conf in self.grid_search(): 102 | total_i += 1 103 | 104 | for i, conf_ in enumerate(self.grid_search()): 105 | if i == (run_i % total_i): 106 | return conf_ 107 | 108 | def degug_model(self): # parameters for debugging the model: stop after one full pass 109 | self.name = 'CAN_debug' 110 | self.use_pre_trained_emb = False 111 | self.evaluate_gap = 1 # how often (in steps) to check whether the val result has improved, and print it 112 | self.store_log_gap = 1 # 113 | self.store_run_time = False # record the run time of every graph node (to find the slowest node) 114 | self.run_time_epoch = 1 # record timing once every this many epochs 115 | self.min_no_improvement = 1 # stop after this many consecutive epochs without improvement (epochs, not steps) 116 | 117 | def detailed_show(self): 118 | self.name = 'CAN_detail' 119 | self.print_batch = True 120 | self.evaluate_gap = 1 # ** how often (in steps) to check whether the val result has improved, and print it 121 | self.store_log_gap = 1 # ** 122 | self.save_best_model = True 123 | self.store_run_time = True # ** record the run time of every graph node (to find the slowest node) 124 | self.run_time_epoch = 1 # ** record timing once every this many epochs 125 | self.show_true_result = True # whether to show the true result (two ways to compute it: 1. scoring function 2. model accuracy + back_off_result accuracy) 126 | 127 | 128 | if __name__ == "__main__": 129 | config = BiLSTM_Config() 130 | config.random_config() 131 | print config.hidden_size 132 | print config.changed_config 133 | 134 | print vars(config) 135 | 136 | i = 0 137 | for conf in config.grid_search(): 138 | i += 1 139 | print i 140 | -------------------------------------------------------------------------------- /WSD/utils/losses.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Implementation of kernel-methods-related loss operations.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from tensorflow.python.framework import dtypes 22 | from tensorflow.python.framework import ops 23 | from tensorflow.python.ops import array_ops 24 | from tensorflow.python.ops import check_ops 25 | from tensorflow.python.ops import math_ops 26 | from tensorflow.python.ops import nn_ops 27 | from tensorflow.python.ops.losses import losses 28 | 29 | 30 | def sparse_multiclass_hinge_loss( 31 | labels, 32 | logits, 33 | weights=1.0, 34 | scope=None, 35 | loss_collection=ops.GraphKeys.LOSSES, 36 | reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS): 37 | """Adds Ops for computing the multiclass hinge loss. 38 | 39 | The implementation is based on the following paper: 40 | On the Algorithmic Implementation of Multiclass Kernel-based Vector Machines 41 | by Crammer and Singer. 42 | link: http://jmlr.csail.mit.edu/papers/volume2/crammer01a/crammer01a.pdf 43 | 44 | This is a generalization of standard (binary) hinge loss. For a given instance 45 | with correct label c*, the loss is given by: 46 | loss = max_{c != c*} logits_c - logits_{c*} + 1. 47 | or equivalently 48 | loss = max_c { logits_c - logits_{c*} + I_{c != c*} } 49 | where I_{c != c*} = 1 if c != c* and 0 otherwise. 50 | 51 | Args: 52 | labels: `Tensor` of shape [batch_size] or [batch_size, 1]. Corresponds to 53 | the ground truth. Each entry must be an index in `[0, num_classes)`. 54 | logits: `Tensor` of shape [batch_size, num_classes] corresponding to the 55 | unscaled logits. Its dtype should be either `float32` or `float64`. 56 | weights: Optional (python) scalar or `Tensor`. If a non-scalar `Tensor`, its 57 | rank should be either 1 ([batch_size]) or 2 ([batch_size, 1]). 58 | scope: The scope for the operations performed in computing the loss. 59 | loss_collection: collection to which the loss will be added. 60 | reduction: Type of reduction to apply to loss. 61 | 62 | Returns: 63 | Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the same 64 | shape as `labels`; otherwise, it is a scalar. 65 | 66 | Raises: 67 | ValueError: If `logits`, `labels` or `weights` have invalid or inconsistent 68 | shapes. 69 | ValueError: If `labels` tensor has invalid dtype. 70 | """ 71 | 72 | with ops.name_scope(scope, 'sparse_multiclass_hinge_loss', (logits, 73 | labels)) as scope: 74 | 75 | # Check logits Tensor has valid rank. 76 | logits_shape = logits.get_shape() 77 | logits_rank = logits_shape.ndims 78 | if logits_rank != 2: 79 | raise ValueError( 80 | 'logits should have rank 2 ([batch_size, num_classes]). 
Given rank is' 81 | ' {}'.format(logits_rank)) 82 | batch_size = logits_shape[0].value 83 | num_classes = array_ops.shape(logits)[array_ops.rank(logits) - 1] 84 | print(batch_size, num_classes) 85 | 86 | logits = math_ops.to_float(logits) 87 | 88 | # Check labels have valid type. 89 | if labels.dtype != dtypes.int32 and labels.dtype != dtypes.int64: 90 | raise ValueError( 91 | 'Invalid dtype for labels: {}. Acceptable dtypes: int32 and int64'. 92 | format(labels.dtype)) 93 | 94 | # Check labels and weights have valid ranks and are consistent. 95 | labels_rank = labels.get_shape().ndims 96 | if labels_rank not in [1, 2]: 97 | raise ValueError( 98 | 'labels should have rank 1 ([batch_size]) or 2 ([batch_size, 1]). ' 99 | 'Given rank is {}'.format(labels_rank)) 100 | # with ops.control_dependencies([ 101 | # check_ops.assert_less(labels, math_ops.cast(num_classes, labels.dtype)) 102 | # ]): 103 | labels = array_ops.reshape(labels, shape=[-1]) 104 | 105 | weights = ops.convert_to_tensor(weights) 106 | weights_rank = weights.get_shape().ndims 107 | if weights_rank not in [0, 1, 2]: 108 | raise ValueError( 109 | 'non-scalar weights should have rank 1 ([batch_size]) or 2 ' 110 | '([batch_size, 1]). Given rank is {}'.format(labels_rank)) 111 | 112 | if weights_rank > 0: 113 | weights = array_ops.reshape(weights, shape=[-1]) 114 | # Check weights and labels have the same number of elements. 115 | weights.get_shape().assert_is_compatible_with(labels.get_shape()) 116 | 117 | # Compute the logits tensor corresponding to the correct class per instance. 118 | example_indices = array_ops.reshape( 119 | math_ops.range(batch_size), shape=[batch_size, 1]) 120 | indices = array_ops.concat( 121 | [ 122 | example_indices, 123 | array_ops.reshape( 124 | math_ops.cast(labels, example_indices.dtype), 125 | shape=[batch_size, 1]) 126 | ], 127 | axis=1) 128 | label_logits = array_ops.reshape( 129 | array_ops.gather_nd(params=logits, indices=indices), 130 | shape=[batch_size, 1]) 131 | 132 | one_cold_labels = array_ops.one_hot( 133 | indices=labels, depth=num_classes, on_value=0.0, off_value=1.0) 134 | margin = logits - label_logits + one_cold_labels 135 | margin = nn_ops.relu(margin) 136 | loss = math_ops.reduce_max(margin, axis=1) 137 | return losses.compute_weighted_loss( 138 | loss, weights, scope, loss_collection, reduction=reduction) 139 | -------------------------------------------------------------------------------- /Pun_Generation/code/utils/misc_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Generally useful utility functions.""" 17 | from __future__ import print_function 18 | 19 | import codecs 20 | import collections 21 | import json 22 | import math 23 | import os 24 | import sys 25 | import time 26 | 27 | import numpy as np 28 | import tensorflow as tf 29 | 30 | 31 | def check_tensorflow_version(): 32 | min_tf_version = "1.4.0-dev20171024" 33 | if tf.__version__ < min_tf_version: 34 | raise EnvironmentError("Tensorflow version must >= %s" % min_tf_version) 35 | 36 | 37 | def safe_exp(value): 38 | """Exponentiation with catching of overflow error.""" 39 | try: 40 | ans = math.exp(value) 41 | except OverflowError: 42 | ans = float("inf") 43 | return ans 44 | 45 | 46 | def print_time(s, start_time): 47 | """Take a start time, print elapsed duration, and return a new time.""" 48 | print("%s, time %ds, %s." % (s, (time.time() - start_time), time.ctime())) 49 | sys.stdout.flush() 50 | return time.time() 51 | 52 | 53 | def print_out(s, f=None, new_line=True): 54 | """Similar to print but with support to flush and output to a file.""" 55 | if isinstance(s, bytes): 56 | s = s.decode("utf-8") 57 | 58 | if f: 59 | f.write(s.encode("utf-8")) 60 | if new_line: 61 | f.write(b"\n") 62 | 63 | # stdout 64 | out_s = s.encode("utf-8") 65 | if not isinstance(out_s, str): 66 | out_s = out_s.decode("utf-8") 67 | print(out_s, end="", file=sys.stdout) 68 | 69 | if new_line: 70 | sys.stdout.write("\n") 71 | sys.stdout.flush() 72 | 73 | 74 | def print_hparams(hparams, skip_patterns=None, header=None): 75 | """Print hparams, can skip keys based on pattern.""" 76 | if header: print_out("%s" % header) 77 | values = hparams.values() 78 | for key in sorted(values.keys()): 79 | if not skip_patterns or all( 80 | [skip_pattern not in key for skip_pattern in skip_patterns]): 81 | print_out(" %s=%s" % (key, str(values[key]))) 82 | 83 | 84 | def load_hparams(model_dir): 85 | """Load hparams from an existing model directory.""" 86 | hparams_file = os.path.join(model_dir, "hparams") 87 | if tf.gfile.Exists(hparams_file): 88 | print_out("# Loading hparams from %s" % hparams_file) 89 | with codecs.getreader("utf-8")(tf.gfile.GFile(hparams_file, "rb")) as f: 90 | try: 91 | hparams_values = json.load(f) 92 | hparams = tf.contrib.training.HParams(**hparams_values) 93 | except ValueError: 94 | print_out(" can't load hparams file") 95 | return None 96 | return hparams 97 | else: 98 | return None 99 | 100 | 101 | def maybe_parse_standard_hparams(hparams, hparams_path): 102 | """Override hparams values with existing standard hparams config.""" 103 | if not hparams_path: 104 | return hparams 105 | 106 | if tf.gfile.Exists(hparams_path): 107 | print_out("# Loading standard hparams from %s" % hparams_path) 108 | with tf.gfile.GFile(hparams_path, "r") as f: 109 | hparams.parse_json(f.read()) 110 | 111 | return hparams 112 | 113 | 114 | def save_hparams(out_dir, hparams): 115 | """Save hparams.""" 116 | hparams_file = os.path.join(out_dir, "hparams") 117 | print_out(" saving hparams to %s" % hparams_file) 118 | with codecs.getwriter("utf-8")(tf.gfile.GFile(hparams_file, "wb")) as f: 119 | f.write(hparams.to_json()) 120 | 121 | 122 | def debug_tensor(s, msg=None, summarize=10): 123 | """Print the shape and value of a tensor at test time. 
Return a new tensor.""" 124 | if not msg: 125 | msg = s.name 126 | return tf.Print(s, [tf.shape(s), s], msg + " ", summarize=summarize) 127 | 128 | 129 | def add_summary(summary_writer, global_step, tag, value): 130 | """Add a new summary to the current summary_writer. 131 | Useful to log things that are not part of the training graph, e.g., tag=BLEU. 132 | """ 133 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]) 134 | summary_writer.add_summary(summary, global_step) 135 | 136 | 137 | def get_config_proto(log_device_placement=False, allow_soft_placement=True, 138 | num_intra_threads=0, num_inter_threads=0): 139 | # GPU options: 140 | # https://www.tensorflow.org/versions/r0.10/how_tos/using_gpu/index.html 141 | config_proto = tf.ConfigProto( 142 | log_device_placement=log_device_placement, 143 | allow_soft_placement=allow_soft_placement) 144 | config_proto.gpu_options.allow_growth = True 145 | 146 | # CPU threads options 147 | if num_intra_threads: 148 | config_proto.intra_op_parallelism_threads = num_intra_threads 149 | if num_inter_threads: 150 | config_proto.inter_op_parallelism_threads = num_inter_threads 151 | 152 | return config_proto 153 | 154 | 155 | def format_text(words,scores): 156 | """Convert a sequence words into sentence.""" 157 | if (not hasattr(words, "__len__") and # for numpy array 158 | not isinstance(words, collections.Iterable)): 159 | words = [words]+[scores] 160 | print ("in if !!!!") 161 | return b" ".join(words)+' '+str(abs(scores)) 162 | 163 | 164 | def format_bpe_text(symbols, delimiter=b"@@"): 165 | """Convert a sequence of bpe words into sentence.""" 166 | words = [] 167 | word = b"" 168 | if isinstance(symbols, str): 169 | symbols = symbols.encode() 170 | delimiter_len = len(delimiter) 171 | for symbol in symbols: 172 | if len(symbol) >= delimiter_len and symbol[-delimiter_len:] == delimiter: 173 | word += symbol[:-delimiter_len] 174 | else: # end of a word 175 | word += symbol 176 | words.append(word) 177 | word = b"" 178 | return b" ".join(words) 179 | 180 | 181 | def format_spm_text(symbols): 182 | """Decode a text in SPM (https://github.com/google/sentencepiece) format.""" 183 | return u"".join(format_text(symbols).decode("utf-8").split()).replace( 184 | u"\u2581", u" ").strip().encode("utf-8") 185 | -------------------------------------------------------------------------------- /Pun_Generation_Forward/code/utils/misc_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Generally useful utility functions.""" 17 | from __future__ import print_function 18 | 19 | import codecs 20 | import collections 21 | import json 22 | import math 23 | import os 24 | import sys 25 | import time 26 | 27 | import numpy as np 28 | import tensorflow as tf 29 | 30 | 31 | def check_tensorflow_version(): 32 | min_tf_version = "1.4.0-dev20171024" 33 | if tf.__version__ < min_tf_version: 34 | raise EnvironmentError("Tensorflow version must >= %s" % min_tf_version) 35 | 36 | 37 | def safe_exp(value): 38 | """Exponentiation with catching of overflow error.""" 39 | try: 40 | ans = math.exp(value) 41 | except OverflowError: 42 | ans = float("inf") 43 | return ans 44 | 45 | 46 | def print_time(s, start_time): 47 | """Take a start time, print elapsed duration, and return a new time.""" 48 | print("%s, time %ds, %s." % (s, (time.time() - start_time), time.ctime())) 49 | sys.stdout.flush() 50 | return time.time() 51 | 52 | 53 | def print_out(s, f=None, new_line=True): 54 | """Similar to print but with support to flush and output to a file.""" 55 | if isinstance(s, bytes): 56 | s = s.decode("utf-8") 57 | 58 | if f: 59 | f.write(s.encode("utf-8")) 60 | if new_line: 61 | f.write(b"\n") 62 | 63 | # stdout 64 | out_s = s.encode("utf-8") 65 | if not isinstance(out_s, str): 66 | out_s = out_s.decode("utf-8") 67 | print(out_s, end="", file=sys.stdout) 68 | 69 | if new_line: 70 | sys.stdout.write("\n") 71 | sys.stdout.flush() 72 | 73 | 74 | def print_hparams(hparams, skip_patterns=None, header=None): 75 | """Print hparams, can skip keys based on pattern.""" 76 | if header: print_out("%s" % header) 77 | values = hparams.values() 78 | for key in sorted(values.keys()): 79 | if not skip_patterns or all( 80 | [skip_pattern not in key for skip_pattern in skip_patterns]): 81 | print_out(" %s=%s" % (key, str(values[key]))) 82 | 83 | 84 | def load_hparams(model_dir): 85 | """Load hparams from an existing model directory.""" 86 | hparams_file = os.path.join(model_dir, "hparams") 87 | if tf.gfile.Exists(hparams_file): 88 | print_out("# Loading hparams from %s" % hparams_file) 89 | with codecs.getreader("utf-8")(tf.gfile.GFile(hparams_file, "rb")) as f: 90 | try: 91 | hparams_values = json.load(f) 92 | hparams = tf.contrib.training.HParams(**hparams_values) 93 | except ValueError: 94 | print_out(" can't load hparams file") 95 | return None 96 | return hparams 97 | else: 98 | return None 99 | 100 | 101 | def maybe_parse_standard_hparams(hparams, hparams_path): 102 | """Override hparams values with existing standard hparams config.""" 103 | if not hparams_path: 104 | return hparams 105 | 106 | if tf.gfile.Exists(hparams_path): 107 | print_out("# Loading standard hparams from %s" % hparams_path) 108 | with tf.gfile.GFile(hparams_path, "r") as f: 109 | hparams.parse_json(f.read()) 110 | 111 | return hparams 112 | 113 | 114 | def save_hparams(out_dir, hparams): 115 | """Save hparams.""" 116 | hparams_file = os.path.join(out_dir, "hparams") 117 | print_out(" saving hparams to %s" % hparams_file) 118 | with codecs.getwriter("utf-8")(tf.gfile.GFile(hparams_file, "wb")) as f: 119 | f.write(hparams.to_json()) 120 | 121 | 122 | def debug_tensor(s, msg=None, summarize=10): 123 | """Print the shape and value of a tensor at test time. 
Return a new tensor.""" 124 | if not msg: 125 | msg = s.name 126 | return tf.Print(s, [tf.shape(s), s], msg + " ", summarize=summarize) 127 | 128 | 129 | def add_summary(summary_writer, global_step, tag, value): 130 | """Add a new summary to the current summary_writer. 131 | Useful to log things that are not part of the training graph, e.g., tag=BLEU. 132 | """ 133 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]) 134 | summary_writer.add_summary(summary, global_step) 135 | 136 | 137 | def get_config_proto(log_device_placement=False, allow_soft_placement=True, 138 | num_intra_threads=0, num_inter_threads=0): 139 | # GPU options: 140 | # https://www.tensorflow.org/versions/r0.10/how_tos/using_gpu/index.html 141 | config_proto = tf.ConfigProto( 142 | log_device_placement=log_device_placement, 143 | allow_soft_placement=allow_soft_placement) 144 | config_proto.gpu_options.allow_growth = True 145 | 146 | # CPU threads options 147 | if num_intra_threads: 148 | config_proto.intra_op_parallelism_threads = num_intra_threads 149 | if num_inter_threads: 150 | config_proto.inter_op_parallelism_threads = num_inter_threads 151 | 152 | return config_proto 153 | 154 | 155 | def format_text(words,scores): 156 | """Convert a sequence words into sentence.""" 157 | if (not hasattr(words, "__len__") and # for numpy array 158 | not isinstance(words, collections.Iterable)): 159 | words = [words]+[scores] 160 | print ("in if !!!!") 161 | return b" ".join(words)+' '+str(abs(scores)) 162 | 163 | 164 | def format_bpe_text(symbols, delimiter=b"@@"): 165 | """Convert a sequence of bpe words into sentence.""" 166 | words = [] 167 | word = b"" 168 | if isinstance(symbols, str): 169 | symbols = symbols.encode() 170 | delimiter_len = len(delimiter) 171 | for symbol in symbols: 172 | if len(symbol) >= delimiter_len and symbol[-delimiter_len:] == delimiter: 173 | word += symbol[:-delimiter_len] 174 | else: # end of a word 175 | word += symbol 176 | words.append(word) 177 | word = b"" 178 | return b" ".join(words) 179 | 180 | 181 | def format_spm_text(symbols): 182 | """Decode a text in SPM (https://github.com/google/sentencepiece) format.""" 183 | return u"".join(format_text(symbols).decode("utf-8").split()).replace( 184 | u"\u2581", u" ").strip().encode("utf-8") 185 | -------------------------------------------------------------------------------- /Pun_Generation/code/utils/evaluation_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Utility for evaluating various tasks, e.g., translation & summarization.""" 17 | import codecs 18 | import os 19 | import re 20 | import subprocess 21 | 22 | import tensorflow as tf 23 | 24 | from utils.scripts import bleu 25 | from utils.scripts import rouge 26 | 27 | 28 | __all__ = ["evaluate"] 29 | 30 | 31 | def evaluate(ref_file, trans_file, metric, subword_option=None): 32 | """Pick a metric and evaluate depending on task.""" 33 | # BLEU scores for translation task 34 | if metric.lower() == "bleu": 35 | evaluation_score = _bleu(ref_file, trans_file, 36 | subword_option=subword_option) 37 | # ROUGE scores for summarization tasks 38 | elif metric.lower() == "rouge": 39 | evaluation_score = _rouge(ref_file, trans_file, 40 | subword_option=subword_option) 41 | elif metric.lower() == "accuracy": 42 | evaluation_score = _accuracy(ref_file, trans_file) 43 | elif metric.lower() == "word_accuracy": 44 | evaluation_score = _word_accuracy(ref_file, trans_file) 45 | else: 46 | raise ValueError("Unknown metric %s" % metric) 47 | 48 | return evaluation_score 49 | 50 | 51 | def _clean(sentence, subword_option): 52 | """Clean and handle BPE or SPM outputs.""" 53 | sentence = sentence.strip() 54 | 55 | # BPE 56 | if subword_option == "bpe": 57 | sentence = re.sub("@@ ", "", sentence) 58 | 59 | # SPM 60 | elif subword_option == "spm": 61 | sentence = u"".join(sentence.split()).replace(u"\u2581", u" ").lstrip() 62 | 63 | return sentence 64 | 65 | 66 | # Follow //transconsole/localization/machine_translation/metrics/bleu_calc.py 67 | def _bleu(ref_file, trans_file, subword_option=None): 68 | """Compute BLEU scores and handling BPE.""" 69 | max_order = 4 70 | smooth = False 71 | 72 | ref_files = [ref_file] 73 | reference_text = [] 74 | for reference_filename in ref_files: 75 | with codecs.getreader("utf-8")( 76 | tf.gfile.GFile(reference_filename, "rb")) as fh: 77 | reference_text.append(fh.readlines()) 78 | 79 | per_segment_references = [] 80 | for references in zip(*reference_text): 81 | reference_list = [] 82 | for reference in references: 83 | reference = _clean(reference, subword_option) 84 | reference_list.append(reference.split(" ")) 85 | per_segment_references.append(reference_list) 86 | 87 | translations = [] 88 | with codecs.getreader("utf-8")(tf.gfile.GFile(trans_file, "rb")) as fh: 89 | for line in fh: 90 | line = _clean(line, subword_option=None) 91 | translations.append(line.split(" ")) 92 | 93 | # bleu_score, precisions, bp, ratio, translation_length, reference_length 94 | bleu_score, _, _, _, _, _ = bleu.compute_bleu( 95 | per_segment_references, translations, max_order, smooth) 96 | return 100 * bleu_score 97 | 98 | 99 | def _rouge(ref_file, summarization_file, subword_option=None): 100 | """Compute ROUGE scores and handling BPE.""" 101 | 102 | references = [] 103 | with codecs.getreader("utf-8")(tf.gfile.GFile(ref_file, "rb")) as fh: 104 | for line in fh: 105 | references.append(_clean(line, subword_option)) 106 | 107 | hypotheses = [] 108 | with codecs.getreader("utf-8")( 109 | tf.gfile.GFile(summarization_file, "rb")) as fh: 110 | for line in fh: 111 | hypotheses.append(_clean(line, subword_option=None)) 112 | 113 | rouge_score_map = rouge.rouge(hypotheses, references) 114 | return 100 * rouge_score_map["rouge_l/f_score"] 115 | 116 | 117 | def _accuracy(label_file, pred_file): 118 | """Compute accuracy, each line contains a label.""" 119 | 120 | with 
codecs.getreader("utf-8")(tf.gfile.GFile(label_file, "rb")) as label_fh: 121 | with codecs.getreader("utf-8")(tf.gfile.GFile(pred_file, "rb")) as pred_fh: 122 | count = 0.0 123 | match = 0.0 124 | for label in label_fh: 125 | label = label.strip() 126 | pred = pred_fh.readline().strip() 127 | if label == pred: 128 | match += 1 129 | count += 1 130 | return 100 * match / count 131 | 132 | 133 | def _word_accuracy(label_file, pred_file): 134 | """Compute accuracy on per word basis.""" 135 | 136 | with codecs.getreader("utf-8")(tf.gfile.GFile(label_file, "r")) as label_fh: 137 | with codecs.getreader("utf-8")(tf.gfile.GFile(pred_file, "r")) as pred_fh: 138 | total_acc, total_count = 0., 0. 139 | for sentence in label_fh: 140 | labels = sentence.strip().split(" ") 141 | preds = pred_fh.readline().strip().split(" ") 142 | match = 0.0 143 | for pos in range(min(len(labels), len(preds))): 144 | label = labels[pos] 145 | pred = preds[pos] 146 | if label == pred: 147 | match += 1 148 | total_acc += 100 * match / max(len(labels), len(preds)) 149 | total_count += 1 150 | return total_acc / total_count 151 | 152 | 153 | def _moses_bleu(multi_bleu_script, tgt_test, trans_file, subword_option=None): 154 | """Compute BLEU scores using Moses multi-bleu.perl script.""" 155 | 156 | # TODO(thangluong): perform rewrite using python 157 | # BPE 158 | if subword_option == "bpe": 159 | debpe_tgt_test = tgt_test + ".debpe" 160 | if not os.path.exists(debpe_tgt_test): 161 | # TODO(thangluong): not use shell=True, can be a security hazard 162 | subprocess.call("cp %s %s" % (tgt_test, debpe_tgt_test), shell=True) 163 | subprocess.call("sed s/@@ //g %s" % (debpe_tgt_test), 164 | shell=True) 165 | tgt_test = debpe_tgt_test 166 | elif subword_option == "spm": 167 | despm_tgt_test = tgt_test + ".despm" 168 | if not os.path.exists(despm_tgt_test): 169 | subprocess.call("cp %s %s" % (tgt_test, despm_tgt_test)) 170 | subprocess.call("sed s/ //g %s" % (despm_tgt_test)) 171 | subprocess.call(u"sed s/^\u2581/g %s" % (despm_tgt_test)) 172 | subprocess.call(u"sed s/\u2581/ /g %s" % (despm_tgt_test)) 173 | tgt_test = despm_tgt_test 174 | cmd = "%s %s < %s" % (multi_bleu_script, tgt_test, trans_file) 175 | 176 | # subprocess 177 | # TODO(thangluong): not use shell=True, can be a security hazard 178 | bleu_output = subprocess.check_output(cmd, shell=True) 179 | 180 | # extract BLEU score 181 | m = re.search("BLEU = (.+?),", bleu_output) 182 | bleu_score = float(m.group(1)) 183 | 184 | return bleu_score 185 | -------------------------------------------------------------------------------- /Pun_Generation_Forward/code/utils/evaluation_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Utility for evaluating various tasks, e.g., translation & summarization.""" 17 | import codecs 18 | import os 19 | import re 20 | import subprocess 21 | 22 | import tensorflow as tf 23 | 24 | from utils.scripts import bleu 25 | from utils.scripts import rouge 26 | 27 | 28 | __all__ = ["evaluate"] 29 | 30 | 31 | def evaluate(ref_file, trans_file, metric, subword_option=None): 32 | """Pick a metric and evaluate depending on task.""" 33 | # BLEU scores for translation task 34 | if metric.lower() == "bleu": 35 | evaluation_score = _bleu(ref_file, trans_file, 36 | subword_option=subword_option) 37 | # ROUGE scores for summarization tasks 38 | elif metric.lower() == "rouge": 39 | evaluation_score = _rouge(ref_file, trans_file, 40 | subword_option=subword_option) 41 | elif metric.lower() == "accuracy": 42 | evaluation_score = _accuracy(ref_file, trans_file) 43 | elif metric.lower() == "word_accuracy": 44 | evaluation_score = _word_accuracy(ref_file, trans_file) 45 | else: 46 | raise ValueError("Unknown metric %s" % metric) 47 | 48 | return evaluation_score 49 | 50 | 51 | def _clean(sentence, subword_option): 52 | """Clean and handle BPE or SPM outputs.""" 53 | sentence = sentence.strip() 54 | 55 | # BPE 56 | if subword_option == "bpe": 57 | sentence = re.sub("@@ ", "", sentence) 58 | 59 | # SPM 60 | elif subword_option == "spm": 61 | sentence = u"".join(sentence.split()).replace(u"\u2581", u" ").lstrip() 62 | 63 | return sentence 64 | 65 | 66 | # Follow //transconsole/localization/machine_translation/metrics/bleu_calc.py 67 | def _bleu(ref_file, trans_file, subword_option=None): 68 | """Compute BLEU scores and handling BPE.""" 69 | max_order = 4 70 | smooth = False 71 | 72 | ref_files = [ref_file] 73 | reference_text = [] 74 | for reference_filename in ref_files: 75 | with codecs.getreader("utf-8")( 76 | tf.gfile.GFile(reference_filename, "rb")) as fh: 77 | reference_text.append(fh.readlines()) 78 | 79 | per_segment_references = [] 80 | for references in zip(*reference_text): 81 | reference_list = [] 82 | for reference in references: 83 | reference = _clean(reference, subword_option) 84 | reference_list.append(reference.split(" ")) 85 | per_segment_references.append(reference_list) 86 | 87 | translations = [] 88 | with codecs.getreader("utf-8")(tf.gfile.GFile(trans_file, "rb")) as fh: 89 | for line in fh: 90 | line = _clean(line, subword_option=None) 91 | translations.append(line.split(" ")) 92 | 93 | # bleu_score, precisions, bp, ratio, translation_length, reference_length 94 | bleu_score, _, _, _, _, _ = bleu.compute_bleu( 95 | per_segment_references, translations, max_order, smooth) 96 | return 100 * bleu_score 97 | 98 | 99 | def _rouge(ref_file, summarization_file, subword_option=None): 100 | """Compute ROUGE scores and handling BPE.""" 101 | 102 | references = [] 103 | with codecs.getreader("utf-8")(tf.gfile.GFile(ref_file, "rb")) as fh: 104 | for line in fh: 105 | references.append(_clean(line, subword_option)) 106 | 107 | hypotheses = [] 108 | with codecs.getreader("utf-8")( 109 | tf.gfile.GFile(summarization_file, "rb")) as fh: 110 | for line in fh: 111 | hypotheses.append(_clean(line, subword_option=None)) 112 | 113 | rouge_score_map = rouge.rouge(hypotheses, references) 114 | return 100 * rouge_score_map["rouge_l/f_score"] 115 | 116 | 117 | def _accuracy(label_file, pred_file): 118 | """Compute accuracy, each line contains a label.""" 119 | 120 | with 
codecs.getreader("utf-8")(tf.gfile.GFile(label_file, "rb")) as label_fh: 121 | with codecs.getreader("utf-8")(tf.gfile.GFile(pred_file, "rb")) as pred_fh: 122 | count = 0.0 123 | match = 0.0 124 | for label in label_fh: 125 | label = label.strip() 126 | pred = pred_fh.readline().strip() 127 | if label == pred: 128 | match += 1 129 | count += 1 130 | return 100 * match / count 131 | 132 | 133 | def _word_accuracy(label_file, pred_file): 134 | """Compute accuracy on per word basis.""" 135 | 136 | with codecs.getreader("utf-8")(tf.gfile.GFile(label_file, "r")) as label_fh: 137 | with codecs.getreader("utf-8")(tf.gfile.GFile(pred_file, "r")) as pred_fh: 138 | total_acc, total_count = 0., 0. 139 | for sentence in label_fh: 140 | labels = sentence.strip().split(" ") 141 | preds = pred_fh.readline().strip().split(" ") 142 | match = 0.0 143 | for pos in range(min(len(labels), len(preds))): 144 | label = labels[pos] 145 | pred = preds[pos] 146 | if label == pred: 147 | match += 1 148 | total_acc += 100 * match / max(len(labels), len(preds)) 149 | total_count += 1 150 | return total_acc / total_count 151 | 152 | 153 | def _moses_bleu(multi_bleu_script, tgt_test, trans_file, subword_option=None): 154 | """Compute BLEU scores using Moses multi-bleu.perl script.""" 155 | 156 | # TODO(thangluong): perform rewrite using python 157 | # BPE 158 | if subword_option == "bpe": 159 | debpe_tgt_test = tgt_test + ".debpe" 160 | if not os.path.exists(debpe_tgt_test): 161 | # TODO(thangluong): not use shell=True, can be a security hazard 162 | subprocess.call("cp %s %s" % (tgt_test, debpe_tgt_test), shell=True) 163 | subprocess.call("sed s/@@ //g %s" % (debpe_tgt_test), 164 | shell=True) 165 | tgt_test = debpe_tgt_test 166 | elif subword_option == "spm": 167 | despm_tgt_test = tgt_test + ".despm" 168 | if not os.path.exists(despm_tgt_test): 169 | subprocess.call("cp %s %s" % (tgt_test, despm_tgt_test)) 170 | subprocess.call("sed s/ //g %s" % (despm_tgt_test)) 171 | subprocess.call(u"sed s/^\u2581/g %s" % (despm_tgt_test)) 172 | subprocess.call(u"sed s/\u2581/ /g %s" % (despm_tgt_test)) 173 | tgt_test = despm_tgt_test 174 | cmd = "%s %s < %s" % (multi_bleu_script, tgt_test, trans_file) 175 | 176 | # subprocess 177 | # TODO(thangluong): not use shell=True, can be a security hazard 178 | bleu_output = subprocess.check_output(cmd, shell=True) 179 | 180 | # extract BLEU score 181 | m = re.search("BLEU = (.+?),", bleu_output) 182 | bleu_score = float(m.group(1)) 183 | 184 | return bleu_score 185 | -------------------------------------------------------------------------------- /WSD/utils/store_result.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @version: python2.7 4 | @author: luofuli 5 | @time: 2018/4/1 19:56 6 | """ 7 | 8 | import os 9 | import sys 10 | # print sys.path # by default the lookup starts from the directory containing this file, so directories such as MemNN, utils and postprocessing cannot be found 11 | # sys.path.insert(0, "..") # make sure the imports below of classes from other folders succeed 12 | 13 | import path 14 | import score 15 | #from postprocessing import ensemble 16 | 17 | _path = path.WSD_path() 18 | 19 | 20 | # Save the result file and the score file 21 | def save_result_and_score(tag, train_dataset, val_dataset, back_off_result, best_result, prob_best_result, 22 | best, config=None, score_path='../tmp/params_score.txt', print_logs=True): 23 | # use best_i rather than the most recently stored i 24 | if train_dataset in _path.LS_DATASET: 25 | gold_key_path = _path.LS_TEST_KEY_PATH.format(val_dataset) 26 | else: 27 | gold_key_path = _path.ALL_WORDS_TEST_KEY_PATH.format(val_dataset) 28 | 29 | if not
os.path.exists('../tmp/results'): 30 | os.makedirs('../tmp/results') 31 | 32 | best_result_path = '../tmp/results/{0}-best-{1}-{2}'.format(tag, str(best['i']), val_dataset) 33 | print('best_result_path: ' + best_result_path) 34 | score.write_result(best_result, back_off_result, best_result_path) 35 | f1s = None 36 | if val_dataset == _path.ALL_WORDS_TEST_DATASET[0]: 37 | f1s, pos_ps = score.score_all(best_result_path, gold_key_path, print_logs=print_logs, logs_level=1) 38 | else: 39 | _, _, f1 = score.score_one(best_result_path, gold_key_path) 40 | 41 | prob_best_path = '../tmp/results/{0}-prob-{1}'.format(tag, val_dataset) 42 | # ensemble_result_path = '../tmp/results/{0}-ensemble-{1}'.format(tag, val_dataset) 43 | # f1b, b_i = ensemble.score_for_prob(prob_best_result, back_off_result, 44 | # gold_key_path, print_logs=False) # luofuli add @2018/2/1 45 | # print('Best prob F1:%s\t(step: %d)' % (f1b, best['i'] + b_i + 1)) 46 | # ensemble_result = ensemble.vote(prob_best_result, back_off_result, prob_best_path, ensemble_result_path) 47 | # f1s_ = None 48 | # if val_dataset == _path.ALL_WORDS_TEST_DATASET[0]: 49 | # f1s_, pos_ps_ = score.score_all(ensemble_result_path, gold_key_path, print_logs=print_logs, logs_level=1) 50 | # else: 51 | # _, _, f1_ = score.score_one(ensemble_result_path, gold_key_path) 52 | # 53 | # print('Writing param score in path:%s, tag: %s' % (score_path, tag)) 54 | # try: 55 | # old = open(score_path).read() 56 | # except Exception: 57 | # old = '' 58 | # 59 | # if f1s and f1s_: 60 | # with open(score_path, 'w') as f: 61 | # f.write( 62 | # '%s%s\tbest\t%.1f\t%.1f' % (old, tag, best['acc_train'] * 100, best['acc_val'] * 100)) 63 | # for f1 in f1s: 64 | # f.write('\t%.1f' % (f1 * 100)) 65 | # for p in pos_ps: 66 | # f.write('\t%.1f' % (p * 100)) 67 | # score.store_params(score_path, val_dataset, config) 68 | # old = open(score_path).read() 69 | # with open(score_path, 'w') as f: 70 | # f.write('%s%s\tensemble\t%.1f\t%.1f' % ( 71 | # old, tag, best['acc_train'] * 100, best['acc_val'] * 100)) 72 | # for f1_ in f1s_: 73 | # f.write('\t%.1f' % (f1_ * 100)) 74 | # for p_ in pos_ps_: 75 | # f.write('\t%.1f' % (p_ * 100)) 76 | # score.store_params(score_path, val_dataset, config) 77 | # else: 78 | # with open(score_path, 'w') as f: 79 | # f.write('%s%s\t%.1f\t%.1f' % (old, tag, best['acc_train'] * 100, best['acc_val'] * 100)) 80 | # f.write('\t%.1f\t%.1f' % (f1 * 100, f1_ * 100)) 81 | # score.store_params(score_path, val_dataset, config) 82 | 83 | 84 | def check_correct(correct_path, result_path, gold_key_path, id_index=0, unsocre_dataset='semeval2007'): 85 | id_to_key = {} 86 | for line in open(gold_key_path): 87 | line = line.split() # defaut \s=[ \f\n\r\t\v] 88 | id = line[id_index] 89 | key = line[id_index + 1] 90 | id_to_key[id] = key 91 | 92 | true_ = {} 93 | with open(result_path) as f2: 94 | for line in f2.readlines(): 95 | id = line.split()[id_index] 96 | sense = line.split()[id_index + 1] 97 | if sense == id_to_key.get(id): 98 | tag = 1 99 | else: 100 | tag = 0 101 | true_[id] = tag 102 | 103 | import numpy as np 104 | 105 | with open(correct_path) as f1: 106 | correct = [] 107 | correct_has_se7 = [] 108 | error = [] 109 | error_id = [] 110 | lines1 = f1.readlines() 111 | for i, line1 in enumerate(lines1): 112 | id = line1.split()[id_index] 113 | tag = int(line1.split()[id_index+1]) 114 | if tag != true_[id]: 115 | error.append(tag) 116 | error_id.append(id) 117 | # print ('error line: %s' % line1) 118 | else: 119 | correct_has_se7.append(tag) 120 | if 
unsocre_dataset not in id: 121 | correct.append(tag) 122 | print('Using back-off: %s' % (len(id_to_key) - len(lines1))) 123 | print('Total score: %s' % (np.mean(correct))) 124 | print('Total score(has se7): %s' % (np.mean(correct_has_se7))) 125 | print('Total score(has se7 + error): %s' % np.mean(error + correct_has_se7)) 126 | print('error number: %s' % len(error)) 127 | 128 | return error_id 129 | 130 | 131 | def write_result(results, back_off_result, path, print_logs=True): 132 | if print_logs: 133 | print('Writing to file:%s' % path) 134 | new_results = results + back_off_result 135 | new_results = sorted(new_results, key=lambda a: a[0]) 136 | with open(path, 'w') as file: 137 | for instance_id, predicted_sense in new_results: 138 | file.write('%s %s\n' % (instance_id, predicted_sense)) 139 | 140 | 141 | if __name__ == "__main__": 142 | e1 = check_correct(correct_path='../tmp/results/correct-1.txt', 143 | result_path='../tmp/results/0CAN-best-1-ALL-1', 144 | gold_key_path=_path.ALL_WORDS_TEST_KEY_PATH.format('ALL')) 145 | 146 | e2 = check_correct(correct_path='../tmp/results/correct-2.txt', 147 | result_path='../tmp/results/0CAN-best-1-ALL-2', 148 | gold_key_path=_path.ALL_WORDS_TEST_KEY_PATH.format('ALL')) 149 | 150 | e3 = check_correct(correct_path='../tmp/results/correct.txt', 151 | result_path='../tmp/results/0CAN-best-1-ALL', 152 | gold_key_path=_path.ALL_WORDS_TEST_KEY_PATH.format('ALL')) 153 | 154 | print((set(e1) & set(e2))) 155 | print((set(e1) & set(e3))) 156 | print((set(e2) & set(e3))) 157 | 158 | print(u'Finally figured out why the val_acc reported during training runs is lower than the score measured on the real test set. The main reasons: ' 159 | u'the MFS results are not included in it (first sense alone gives 2081/2370 = 87.8% accuracy); a function was later added in train.py to handle this. ' 160 | u'In addition, some of the evaluated instances have gold sense keys that never occur in the training data (test_instance_sensekey_not_in_train), 312 instances in total, about 4.3%. ' 161 | u'At test time these are all lumped into class 0, so some predictions land on class 0, which artificially inflates the measured accuracy (this roughly offsets the highly ambiguous cases). ' 162 | u'The final conclusion: the code is fine, nothing is wrong anywhere; the only thing to decide is whether to predict senses that are absent from the training set, because not predicting them will certainly raise F1 ' 163 | u'(precision goes up while recall stays unchanged, so F1 goes up).' 164 | ) 165 | 166 | -------------------------------------------------------------------------------- /Pun_Generation/code/attention_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Attention-based sequence-to-sequence model with dynamic RNN support.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import tensorflow as tf 21 | 22 | import model 23 | import model_helper 24 | 25 | __all__ = ["AttentionModel"] 26 | 27 | 28 | class AttentionModel(model.Model): 29 | """Sequence-to-sequence dynamic model with attention. 30 | 31 | This class implements a multi-layer recurrent neural network as encoder, 32 | and an attention-based decoder.
This is the same as the model described in 33 | (Luong et al., EMNLP'2015) paper: https://arxiv.org/pdf/1508.04025v5.pdf. 34 | This class also allows to use GRU cells in addition to LSTM cells with 35 | support for dropout. 36 | """ 37 | 38 | def __init__(self, 39 | hparams, 40 | mode, 41 | iterator, 42 | source_vocab_table, 43 | target_vocab_table, 44 | reverse_target_vocab_table=None, 45 | scope=None, 46 | extra_args=None): 47 | # Set attention_mechanism_fn 48 | if extra_args and extra_args.attention_mechanism_fn: 49 | self.attention_mechanism_fn = extra_args.attention_mechanism_fn 50 | else: 51 | self.attention_mechanism_fn = create_attention_mechanism 52 | 53 | super(AttentionModel, self).__init__( 54 | hparams=hparams, 55 | mode=mode, 56 | iterator=iterator, 57 | source_vocab_table=source_vocab_table, 58 | target_vocab_table=target_vocab_table, 59 | reverse_target_vocab_table=reverse_target_vocab_table, 60 | scope=scope, 61 | extra_args=extra_args) 62 | 63 | if self.mode == tf.contrib.learn.ModeKeys.INFER: 64 | self.infer_summary = self._get_infer_summary(hparams) 65 | 66 | def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state, 67 | source_sequence_length): 68 | """Build a RNN cell with attention mechanism that can be used by decoder.""" 69 | attention_option = hparams.attention 70 | attention_architecture = hparams.attention_architecture 71 | 72 | if attention_architecture != "standard": 73 | raise ValueError( 74 | "Unknown attention architecture %s" % attention_architecture) 75 | 76 | num_units = hparams.num_units 77 | num_layers = self.num_decoder_layers 78 | num_residual_layers = self.num_decoder_residual_layers 79 | beam_width = hparams.beam_width 80 | 81 | dtype = tf.float32 82 | 83 | # Ensure memory is batch-major 84 | if self.time_major: 85 | memory = tf.transpose(encoder_outputs, [1, 0, 2]) 86 | else: 87 | memory = encoder_outputs 88 | 89 | if self.mode == tf.contrib.learn.ModeKeys.INFER and beam_width > 0: 90 | memory = tf.contrib.seq2seq.tile_batch( 91 | memory, multiplier=beam_width) 92 | source_sequence_length = tf.contrib.seq2seq.tile_batch( 93 | source_sequence_length, multiplier=beam_width) 94 | encoder_state = tf.contrib.seq2seq.tile_batch( 95 | encoder_state, multiplier=beam_width) 96 | batch_size = self.batch_size * beam_width 97 | else: 98 | batch_size = self.batch_size 99 | 100 | attention_mechanism = self.attention_mechanism_fn( 101 | attention_option, num_units, memory, source_sequence_length, self.mode) 102 | 103 | cell = model_helper.create_rnn_cell( 104 | unit_type=hparams.unit_type, 105 | num_units=num_units, 106 | num_layers=num_layers, 107 | num_residual_layers=num_residual_layers, 108 | forget_bias=hparams.forget_bias, 109 | dropout=hparams.dropout, 110 | num_gpus=self.num_gpus, 111 | mode=self.mode, 112 | single_cell_fn=self.single_cell_fn) 113 | 114 | # Only generate alignment in greedy INFER mode. 115 | alignment_history = (self.mode == tf.contrib.learn.ModeKeys.INFER and 116 | beam_width == 0) 117 | cell = tf.contrib.seq2seq.AttentionWrapper( 118 | cell, 119 | attention_mechanism, 120 | attention_layer_size=num_units, 121 | alignment_history=alignment_history, 122 | output_attention=hparams.output_attention, 123 | name="attention") 124 | 125 | # TODO(thangluong): do we need num_layers, num_gpus? 
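# Note: the DeviceWrapper below only pins the attention-wrapped cell onto the
# device of the last decoder layer (via model_helper.get_device_str); it does
# not change the computation. The wrapping order in this method is
# create_rnn_cell -> AttentionWrapper -> DeviceWrapper. A minimal illustrative
# call of the attention factory defined at the bottom of this file (the option
# "scaled_luong" and the 512 units are example values, not fixed by this repo):
#   mechanism = create_attention_mechanism("scaled_luong", 512, memory,
#                                          source_sequence_length, mode)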
126 | cell = tf.contrib.rnn.DeviceWrapper(cell, 127 | model_helper.get_device_str( 128 | num_layers - 1, self.num_gpus)) 129 | 130 | if hparams.pass_hidden_state: 131 | decoder_initial_state = cell.zero_state(batch_size, dtype).clone( 132 | cell_state=encoder_state) 133 | else: 134 | decoder_initial_state = cell.zero_state(batch_size, dtype) 135 | 136 | return cell, decoder_initial_state 137 | 138 | def _get_infer_summary(self, hparams): 139 | if hparams.beam_width > 0: 140 | return tf.no_op() 141 | return _create_attention_images_summary(self.final_context_state) 142 | 143 | 144 | def create_attention_mechanism(attention_option, num_units, memory, 145 | source_sequence_length, mode): 146 | """Create attention mechanism based on the attention_option.""" 147 | del mode # unused 148 | 149 | # Mechanism 150 | if attention_option == "luong": 151 | attention_mechanism = tf.contrib.seq2seq.LuongAttention( 152 | num_units, memory, memory_sequence_length=source_sequence_length) 153 | elif attention_option == "scaled_luong": 154 | attention_mechanism = tf.contrib.seq2seq.LuongAttention( 155 | num_units, 156 | memory, 157 | memory_sequence_length=source_sequence_length, 158 | scale=True) 159 | elif attention_option == "bahdanau": 160 | attention_mechanism = tf.contrib.seq2seq.BahdanauAttention( 161 | num_units, memory, memory_sequence_length=source_sequence_length) 162 | elif attention_option == "normed_bahdanau": 163 | attention_mechanism = tf.contrib.seq2seq.BahdanauAttention( 164 | num_units, 165 | memory, 166 | memory_sequence_length=source_sequence_length, 167 | normalize=True) 168 | else: 169 | raise ValueError("Unknown attention option %s" % attention_option) 170 | 171 | return attention_mechanism 172 | 173 | 174 | def _create_attention_images_summary(final_context_state): 175 | """create attention image and attention summary.""" 176 | attention_images = (final_context_state.alignment_history.stack()) 177 | # Reshape to (batch, src_seq_len, tgt_seq_len,1) 178 | attention_images = tf.expand_dims( 179 | tf.transpose(attention_images, [1, 2, 0]), -1) 180 | # Scale to range [0, 255] 181 | attention_images *= 255 182 | attention_summary = tf.summary.image("attention_images", attention_images) 183 | return attention_summary 184 | -------------------------------------------------------------------------------- /Pun_Generation_Forward/code/attention_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Attention-based sequence-to-sequence model with dynamic RNN support.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import tensorflow as tf 21 | 22 | import model 23 | import model_helper 24 | 25 | __all__ = ["AttentionModel"] 26 | 27 | 28 | class AttentionModel(model.Model): 29 | """Sequence-to-sequence dynamic model with attention. 30 | 31 | This class implements a multi-layer recurrent neural network as encoder, 32 | and an attention-based decoder. This is the same as the model described in 33 | (Luong et al., EMNLP'2015) paper: https://arxiv.org/pdf/1508.04025v5.pdf. 34 | This class also allows to use GRU cells in addition to LSTM cells with 35 | support for dropout. 36 | """ 37 | 38 | def __init__(self, 39 | hparams, 40 | mode, 41 | iterator, 42 | source_vocab_table, 43 | target_vocab_table, 44 | reverse_target_vocab_table=None, 45 | scope=None, 46 | extra_args=None): 47 | # Set attention_mechanism_fn 48 | if extra_args and extra_args.attention_mechanism_fn: 49 | self.attention_mechanism_fn = extra_args.attention_mechanism_fn 50 | else: 51 | self.attention_mechanism_fn = create_attention_mechanism 52 | 53 | super(AttentionModel, self).__init__( 54 | hparams=hparams, 55 | mode=mode, 56 | iterator=iterator, 57 | source_vocab_table=source_vocab_table, 58 | target_vocab_table=target_vocab_table, 59 | reverse_target_vocab_table=reverse_target_vocab_table, 60 | scope=scope, 61 | extra_args=extra_args) 62 | 63 | if self.mode == tf.contrib.learn.ModeKeys.INFER: 64 | self.infer_summary = self._get_infer_summary(hparams) 65 | 66 | def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state, 67 | source_sequence_length): 68 | """Build a RNN cell with attention mechanism that can be used by decoder.""" 69 | attention_option = hparams.attention 70 | attention_architecture = hparams.attention_architecture 71 | 72 | if attention_architecture != "standard": 73 | raise ValueError( 74 | "Unknown attention architecture %s" % attention_architecture) 75 | 76 | num_units = hparams.num_units 77 | num_layers = self.num_decoder_layers 78 | num_residual_layers = self.num_decoder_residual_layers 79 | beam_width = hparams.beam_width 80 | 81 | dtype = tf.float32 82 | 83 | # Ensure memory is batch-major 84 | if self.time_major: 85 | memory = tf.transpose(encoder_outputs, [1, 0, 2]) 86 | else: 87 | memory = encoder_outputs 88 | 89 | if self.mode == tf.contrib.learn.ModeKeys.INFER and beam_width > 0: 90 | memory = tf.contrib.seq2seq.tile_batch( 91 | memory, multiplier=beam_width) 92 | source_sequence_length = tf.contrib.seq2seq.tile_batch( 93 | source_sequence_length, multiplier=beam_width) 94 | encoder_state = tf.contrib.seq2seq.tile_batch( 95 | encoder_state, multiplier=beam_width) 96 | batch_size = self.batch_size * beam_width 97 | else: 98 | batch_size = self.batch_size 99 | 100 | attention_mechanism = self.attention_mechanism_fn( 101 | attention_option, num_units, memory, source_sequence_length, self.mode) 102 | 103 | cell = model_helper.create_rnn_cell( 104 | unit_type=hparams.unit_type, 105 | num_units=num_units, 106 | num_layers=num_layers, 107 | num_residual_layers=num_residual_layers, 108 | forget_bias=hparams.forget_bias, 109 | dropout=hparams.dropout, 110 | num_gpus=self.num_gpus, 111 | mode=self.mode, 112 | single_cell_fn=self.single_cell_fn) 113 | 114 | # Only generate alignment in greedy INFER mode. 
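# Note: the flag below keeps the alignment history only for greedy decoding.
# With beam_width > 0 the memory and source lengths were tiled per beam above,
# and _get_infer_summary returns tf.no_op() instead of an attention image, so
# the history would never be consumed.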
115 | alignment_history = (self.mode == tf.contrib.learn.ModeKeys.INFER and 116 | beam_width == 0) 117 | cell = tf.contrib.seq2seq.AttentionWrapper( 118 | cell, 119 | attention_mechanism, 120 | attention_layer_size=num_units, 121 | alignment_history=alignment_history, 122 | output_attention=hparams.output_attention, 123 | name="attention") 124 | 125 | # TODO(thangluong): do we need num_layers, num_gpus? 126 | cell = tf.contrib.rnn.DeviceWrapper(cell, 127 | model_helper.get_device_str( 128 | num_layers - 1, self.num_gpus)) 129 | 130 | if hparams.pass_hidden_state: 131 | decoder_initial_state = cell.zero_state(batch_size, dtype).clone( 132 | cell_state=encoder_state) 133 | else: 134 | decoder_initial_state = cell.zero_state(batch_size, dtype) 135 | 136 | return cell, decoder_initial_state 137 | 138 | def _get_infer_summary(self, hparams): 139 | if hparams.beam_width > 0: 140 | return tf.no_op() 141 | return _create_attention_images_summary(self.final_context_state) 142 | 143 | 144 | def create_attention_mechanism(attention_option, num_units, memory, 145 | source_sequence_length, mode): 146 | """Create attention mechanism based on the attention_option.""" 147 | del mode # unused 148 | 149 | # Mechanism 150 | if attention_option == "luong": 151 | attention_mechanism = tf.contrib.seq2seq.LuongAttention( 152 | num_units, memory, memory_sequence_length=source_sequence_length) 153 | elif attention_option == "scaled_luong": 154 | attention_mechanism = tf.contrib.seq2seq.LuongAttention( 155 | num_units, 156 | memory, 157 | memory_sequence_length=source_sequence_length, 158 | scale=True) 159 | elif attention_option == "bahdanau": 160 | attention_mechanism = tf.contrib.seq2seq.BahdanauAttention( 161 | num_units, memory, memory_sequence_length=source_sequence_length) 162 | elif attention_option == "normed_bahdanau": 163 | attention_mechanism = tf.contrib.seq2seq.BahdanauAttention( 164 | num_units, 165 | memory, 166 | memory_sequence_length=source_sequence_length, 167 | normalize=True) 168 | else: 169 | raise ValueError("Unknown attention option %s" % attention_option) 170 | 171 | return attention_mechanism 172 | 173 | 174 | def _create_attention_images_summary(final_context_state): 175 | """create attention image and attention summary.""" 176 | attention_images = (final_context_state.alignment_history.stack()) 177 | # Reshape to (batch, src_seq_len, tgt_seq_len,1) 178 | attention_images = tf.expand_dims( 179 | tf.transpose(attention_images, [1, 2, 0]), -1) 180 | # Scale to range [0, 255] 181 | attention_images *= 255 182 | attention_summary = tf.summary.image("attention_images", attention_images) 183 | return attention_summary 184 | -------------------------------------------------------------------------------- /Pun_Generation/code/inference.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """To perform inference on test set given a trained model.""" 17 | from __future__ import print_function 18 | 19 | import codecs 20 | import time 21 | 22 | import tensorflow as tf 23 | 24 | import attention_model 25 | import gnmt_model 26 | import model as nmt_model 27 | import model_helper 28 | from utils import misc_utils as utils 29 | from utils import nmt_utils 30 | 31 | __all__ = ["load_data", "inference", 32 | "single_worker_inference", "multi_worker_inference"] 33 | 34 | 35 | def _decode_inference_indices(model, sess, output_infer, 36 | output_infer_summary_prefix, 37 | inference_indices, 38 | tgt_eos, 39 | subword_option): 40 | """Decoding only a specific set of sentences.""" 41 | utils.print_out(" decoding to output %s , num sents %d." % 42 | (output_infer, len(inference_indices))) 43 | start_time = time.time() 44 | with codecs.getwriter("utf-8")( 45 | tf.gfile.GFile(output_infer, mode="wb")) as trans_f: 46 | trans_f.write("") # Write empty string to ensure file is created. 47 | for decode_id in inference_indices: 48 | nmt_outputs, infer_summary = model.decode(sess) 49 | 50 | # get text translation 51 | assert nmt_outputs.shape[0] == 1 52 | translation = nmt_utils.get_translation( 53 | nmt_outputs, 54 | sent_id=0, 55 | tgt_eos=tgt_eos, 56 | subword_option=subword_option) 57 | 58 | if infer_summary is not None: # Attention models 59 | image_file = output_infer_summary_prefix + str(decode_id) + ".png" 60 | utils.print_out(" save attention image to %s*" % image_file) 61 | image_summ = tf.Summary() 62 | image_summ.ParseFromString(infer_summary) 63 | with tf.gfile.GFile(image_file, mode="w") as img_f: 64 | img_f.write(image_summ.value[0].image.encoded_image_string) 65 | 66 | trans_f.write("%s\n" % translation) 67 | utils.print_out(translation + b"\n") 68 | utils.print_time(" done", start_time) 69 | 70 | 71 | def load_data(inference_input_file, hparams=None): 72 | """Load inference data.""" 73 | with codecs.getreader("utf-8")( 74 | tf.gfile.GFile(inference_input_file, mode="rb")) as f: 75 | inference_data = f.read().splitlines() 76 | 77 | if hparams and hparams.inference_indices: 78 | inference_data = [inference_data[i] for i in hparams.inference_indices] 79 | 80 | return inference_data 81 | 82 | 83 | def inference(ckpt, 84 | inference_input_file, 85 | inference_output_file, 86 | hparams, 87 | num_workers=1, 88 | jobid=0, 89 | scope=None): 90 | """Perform translation.""" 91 | if hparams.inference_indices: 92 | assert num_workers == 1 93 | 94 | if not hparams.attention: 95 | model_creator = nmt_model.Model 96 | elif hparams.attention_architecture == "standard": 97 | model_creator = attention_model.AttentionModel 98 | elif hparams.attention_architecture in ["gnmt", "gnmt_v2"]: 99 | model_creator = gnmt_model.GNMTModel 100 | else: 101 | raise ValueError("Unknown model architecture") 102 | infer_model = model_helper.create_infer_model(model_creator, hparams, scope) 103 | #emb_matrix = model_helper._create_or_load_embed("embedding_encoder", hparams.src_vocab_file, hparams.src_embed_file, 104 | #hparams.src_vocab_size, hparams.batch_size, tf.float32) 105 | #emb_matrix =infer_model.model.embedding_encoder 106 | #print ("emb_matrix",emb_matrix) 107 | if num_workers == 1: 108 | single_worker_inference( 109 | #emb_matrix, 110 | infer_model, 111 | ckpt, 112 | inference_input_file, 113 | inference_output_file, 114 | hparams, 115 | model_creator) 116 | else: 117 | multi_worker_inference( 118 | 
infer_model, 119 | ckpt, 120 | inference_input_file, 121 | inference_output_file, 122 | hparams, 123 | num_workers=num_workers, 124 | jobid=jobid) 125 | 126 | 127 | def single_worker_inference(#emb_matrix, 128 | infer_model, 129 | ckpt, 130 | inference_input_file, 131 | inference_output_file, 132 | hparams, 133 | model_creator): 134 | """Inference with a single worker.""" 135 | output_infer = inference_output_file 136 | 137 | # Read data 138 | infer_data = load_data(inference_input_file, hparams) 139 | #saver = tf.train.Saver() 140 | with tf.Session( 141 | graph=infer_model.graph, config=utils.get_config_proto()) as sess: 142 | loaded_infer_model = model_helper.load_model( 143 | infer_model.model, ckpt, sess, "infer") 144 | sess.run( 145 | infer_model.iterator.initializer, 146 | feed_dict={ 147 | infer_model.src_placeholder: infer_data, 148 | infer_model.batch_size_placeholder: hparams.infer_batch_size 149 | }) 150 | #sess.run(model_creator._build_decoder.eval()) 151 | # Decode 152 | #saver = tf.train.Saver() 153 | #emb=sess.run(emb_matrix) 154 | #fw=open('/home/yuzw/pun/nmt/inference/embedding_ds','w+') 155 | #fw.write('\n'.join( 156 | # [' '.join([str(u) for u in e]) for e in emb])) 157 | #print("emb=sess.run(emb_matrix)",emb) 158 | #save_path = saver.save(sess, "/home/yuzw/pun/nmt/inference/emb.npz") 159 | #print("Model saved in path: %s" % save_path) 160 | utils.print_out("# Start decoding single_worker_inference") 161 | if hparams.inference_indices: 162 | _decode_inference_indices( 163 | loaded_infer_model, 164 | sess, 165 | output_infer=output_infer, 166 | output_infer_summary_prefix=output_infer, 167 | inference_indices=hparams.inference_indices, 168 | tgt_eos=hparams.eos, 169 | subword_option=hparams.subword_option) 170 | 171 | else: 172 | 173 | nmt_utils.decode_and_evaluate( 174 | "infer", 175 | loaded_infer_model, 176 | sess, 177 | output_infer, 178 | ref_file=None, 179 | metrics=hparams.metrics, 180 | subword_option=hparams.subword_option, 181 | beam_width=hparams.beam_width, 182 | tgt_eos=hparams.eos, 183 | num_translations_per_input=hparams.num_translations_per_input) 184 | 185 | 186 | def multi_worker_inference(infer_model, 187 | ckpt, 188 | inference_input_file, 189 | inference_output_file, 190 | hparams, 191 | num_workers, 192 | jobid): 193 | """Inference using multiple workers.""" 194 | assert num_workers > 1 195 | 196 | final_output_infer = inference_output_file 197 | output_infer = "%s_%d" % (inference_output_file, jobid) 198 | output_infer_done = "%s_done_%d" % (inference_output_file, jobid) 199 | 200 | # Read data 201 | infer_data = load_data(inference_input_file, hparams) 202 | 203 | # Split data to multiple workers 204 | total_load = len(infer_data) 205 | load_per_worker = int((total_load - 1) / num_workers) + 1 206 | start_position = jobid * load_per_worker 207 | end_position = min(start_position + load_per_worker, total_load) 208 | infer_data = infer_data[start_position:end_position] 209 | #saver = tf.train.Saver() 210 | with tf.Session( 211 | graph=infer_model.graph, config=utils.get_config_proto()) as sess: 212 | loaded_infer_model = model_helper.load_model( 213 | infer_model.model, ckpt, sess, "infer") 214 | sess.run(infer_model.iterator.initializer, 215 | { 216 | infer_model.src_placeholder: infer_data, 217 | infer_model.batch_size_placeholder: hparams.infer_batch_size 218 | }) 219 | #print (sess.run(tf.GraphKeys.GLOBAL_VARIABLES)) 220 | #saver = tf.train.Saver() 221 | #sess=tf.Session() 222 | #sess.run(infer_model.model.embedding_encoder) 223 | 
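# (The commented-out lines above and below are leftover debugging code for
#  inspecting the encoder embedding matrix and dumping it to disk; they are
#  not required for inference.)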
#saver.save(sess, "/home/yuzw/pun/nmt/inference/emb.npz") 224 | #print("Model saved in path:") 225 | # Decode 226 | utils.print_out("# Start decoding multi_worker_inference") 227 | nmt_utils.decode_and_evaluate( 228 | "infer", 229 | loaded_infer_model, 230 | sess, 231 | output_infer, 232 | ref_file=None, 233 | metrics=hparams.metrics, 234 | subword_option=hparams.subword_option, 235 | beam_width=hparams.beam_width, 236 | tgt_eos=hparams.eos, 237 | num_translations_per_input=hparams.num_translations_per_input) 238 | 239 | # Change file name to indicate the file writing is completed. 240 | tf.gfile.Rename(output_infer, output_infer_done, overwrite=True) 241 | 242 | # Job 0 is responsible for the clean up. 243 | if jobid != 0: return 244 | 245 | # Now write all translations 246 | with codecs.getwriter("utf-8")( 247 | tf.gfile.GFile(final_output_infer, mode="wb")) as final_f: 248 | for worker_id in range(num_workers): 249 | worker_infer_done = "%s_done_%d" % (inference_output_file, worker_id) 250 | while not tf.gfile.Exists(worker_infer_done): 251 | utils.print_out(" waitting job %d to complete." % worker_id) 252 | time.sleep(10) 253 | 254 | with codecs.getreader("utf-8")( 255 | tf.gfile.GFile(worker_infer_done, mode="rb")) as f: 256 | for translation in f: 257 | final_f.write("%s" % translation) 258 | 259 | for worker_id in range(num_workers): 260 | worker_infer_done = "%s_done_%d" % (inference_output_file, worker_id) 261 | tf.gfile.Remove(worker_infer_done) 262 | -------------------------------------------------------------------------------- /Pun_Generation_Forward/code/inference.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """To perform inference on test set given a trained model.""" 17 | from __future__ import print_function 18 | 19 | import codecs 20 | import time 21 | 22 | import tensorflow as tf 23 | 24 | import attention_model 25 | import gnmt_model 26 | import model as nmt_model 27 | import model_helper 28 | from utils import misc_utils as utils 29 | from utils import nmt_utils 30 | 31 | __all__ = ["load_data", "inference", 32 | "single_worker_inference", "multi_worker_inference"] 33 | 34 | 35 | def _decode_inference_indices(model, sess, output_infer, 36 | output_infer_summary_prefix, 37 | inference_indices, 38 | tgt_eos, 39 | subword_option): 40 | """Decoding only a specific set of sentences.""" 41 | utils.print_out(" decoding to output %s , num sents %d." % 42 | (output_infer, len(inference_indices))) 43 | start_time = time.time() 44 | with codecs.getwriter("utf-8")( 45 | tf.gfile.GFile(output_infer, mode="wb")) as trans_f: 46 | trans_f.write("") # Write empty string to ensure file is created. 
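# The loop below decodes exactly one sentence per entry of inference_indices
# (0-based positions into the inference input file, e.g. [0, 5, 12] --
# illustrative values only). For attention models the per-sentence alignment
# image is also written to output_infer_summary_prefix + str(decode_id) + ".png".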
47 | for decode_id in inference_indices: 48 | nmt_outputs, infer_summary = model.decode(sess) 49 | 50 | # get text translation 51 | assert nmt_outputs.shape[0] == 1 52 | translation = nmt_utils.get_translation( 53 | nmt_outputs, 54 | sent_id=0, 55 | tgt_eos=tgt_eos, 56 | subword_option=subword_option) 57 | 58 | if infer_summary is not None: # Attention models 59 | image_file = output_infer_summary_prefix + str(decode_id) + ".png" 60 | utils.print_out(" save attention image to %s*" % image_file) 61 | image_summ = tf.Summary() 62 | image_summ.ParseFromString(infer_summary) 63 | with tf.gfile.GFile(image_file, mode="w") as img_f: 64 | img_f.write(image_summ.value[0].image.encoded_image_string) 65 | 66 | trans_f.write("%s\n" % translation) 67 | utils.print_out(translation + b"\n") 68 | utils.print_time(" done", start_time) 69 | 70 | 71 | def load_data(inference_input_file, hparams=None): 72 | """Load inference data.""" 73 | with codecs.getreader("utf-8")( 74 | tf.gfile.GFile(inference_input_file, mode="rb")) as f: 75 | inference_data = f.read().splitlines() 76 | 77 | if hparams and hparams.inference_indices: 78 | inference_data = [inference_data[i] for i in hparams.inference_indices] 79 | 80 | return inference_data 81 | 82 | 83 | def inference(ckpt, 84 | inference_input_file, 85 | inference_output_file, 86 | hparams, 87 | num_workers=1, 88 | jobid=0, 89 | scope=None): 90 | """Perform translation.""" 91 | if hparams.inference_indices: 92 | assert num_workers == 1 93 | 94 | if not hparams.attention: 95 | model_creator = nmt_model.Model 96 | elif hparams.attention_architecture == "standard": 97 | model_creator = attention_model.AttentionModel 98 | elif hparams.attention_architecture in ["gnmt", "gnmt_v2"]: 99 | model_creator = gnmt_model.GNMTModel 100 | else: 101 | raise ValueError("Unknown model architecture") 102 | infer_model = model_helper.create_infer_model(model_creator, hparams, scope) 103 | #emb_matrix = model_helper._create_or_load_embed("embedding_encoder", hparams.src_vocab_file, hparams.src_embed_file, 104 | #hparams.src_vocab_size, hparams.batch_size, tf.float32) 105 | #emb_matrix =infer_model.model.embedding_encoder 106 | #print ("emb_matrix",emb_matrix) 107 | if num_workers == 1: 108 | single_worker_inference( 109 | #emb_matrix, 110 | infer_model, 111 | ckpt, 112 | inference_input_file, 113 | inference_output_file, 114 | hparams, 115 | model_creator) 116 | else: 117 | multi_worker_inference( 118 | infer_model, 119 | ckpt, 120 | inference_input_file, 121 | inference_output_file, 122 | hparams, 123 | num_workers=num_workers, 124 | jobid=jobid) 125 | 126 | 127 | def single_worker_inference(#emb_matrix, 128 | infer_model, 129 | ckpt, 130 | inference_input_file, 131 | inference_output_file, 132 | hparams, 133 | model_creator): 134 | """Inference with a single worker.""" 135 | output_infer = inference_output_file 136 | 137 | # Read data 138 | infer_data = load_data(inference_input_file, hparams) 139 | #saver = tf.train.Saver() 140 | with tf.Session( 141 | graph=infer_model.graph, config=utils.get_config_proto()) as sess: 142 | loaded_infer_model = model_helper.load_model( 143 | infer_model.model, ckpt, sess, "infer") 144 | sess.run( 145 | infer_model.iterator.initializer, 146 | feed_dict={ 147 | infer_model.src_placeholder: infer_data, 148 | infer_model.batch_size_placeholder: hparams.infer_batch_size 149 | }) 150 | #sess.run(model_creator._build_decoder.eval()) 151 | # Decode 152 | #saver = tf.train.Saver() 153 | #emb=sess.run(emb_matrix) 154 | 
#fw=open('/home/yuzw/pun/nmt/inference/embedding_ds','w+') 155 | #fw.write('\n'.join( 156 | # [' '.join([str(u) for u in e]) for e in emb])) 157 | #print("emb=sess.run(emb_matrix)",emb) 158 | #save_path = saver.save(sess, "/home/yuzw/pun/nmt/inference/emb.npz") 159 | #print("Model saved in path: %s" % save_path) 160 | utils.print_out("# Start decoding single_worker_inference") 161 | if hparams.inference_indices: 162 | _decode_inference_indices( 163 | loaded_infer_model, 164 | sess, 165 | output_infer=output_infer, 166 | output_infer_summary_prefix=output_infer, 167 | inference_indices=hparams.inference_indices, 168 | tgt_eos=hparams.eos, 169 | subword_option=hparams.subword_option) 170 | 171 | else: 172 | 173 | nmt_utils.decode_and_evaluate( 174 | "infer", 175 | loaded_infer_model, 176 | sess, 177 | output_infer, 178 | ref_file=None, 179 | metrics=hparams.metrics, 180 | subword_option=hparams.subword_option, 181 | beam_width=hparams.beam_width, 182 | tgt_eos=hparams.eos, 183 | num_translations_per_input=hparams.num_translations_per_input) 184 | 185 | 186 | def multi_worker_inference(infer_model, 187 | ckpt, 188 | inference_input_file, 189 | inference_output_file, 190 | hparams, 191 | num_workers, 192 | jobid): 193 | """Inference using multiple workers.""" 194 | assert num_workers > 1 195 | 196 | final_output_infer = inference_output_file 197 | output_infer = "%s_%d" % (inference_output_file, jobid) 198 | output_infer_done = "%s_done_%d" % (inference_output_file, jobid) 199 | 200 | # Read data 201 | infer_data = load_data(inference_input_file, hparams) 202 | 203 | # Split data to multiple workers 204 | total_load = len(infer_data) 205 | load_per_worker = int((total_load - 1) / num_workers) + 1 206 | start_position = jobid * load_per_worker 207 | end_position = min(start_position + load_per_worker, total_load) 208 | infer_data = infer_data[start_position:end_position] 209 | #saver = tf.train.Saver() 210 | with tf.Session( 211 | graph=infer_model.graph, config=utils.get_config_proto()) as sess: 212 | loaded_infer_model = model_helper.load_model( 213 | infer_model.model, ckpt, sess, "infer") 214 | sess.run(infer_model.iterator.initializer, 215 | { 216 | infer_model.src_placeholder: infer_data, 217 | infer_model.batch_size_placeholder: hparams.infer_batch_size 218 | }) 219 | #print (sess.run(tf.GraphKeys.GLOBAL_VARIABLES)) 220 | #saver = tf.train.Saver() 221 | #sess=tf.Session() 222 | #sess.run(infer_model.model.embedding_encoder) 223 | #saver.save(sess, "/home/yuzw/pun/nmt/inference/emb.npz") 224 | #print("Model saved in path:") 225 | # Decode 226 | utils.print_out("# Start decoding multi_worker_inference") 227 | nmt_utils.decode_and_evaluate( 228 | "infer", 229 | loaded_infer_model, 230 | sess, 231 | output_infer, 232 | ref_file=None, 233 | metrics=hparams.metrics, 234 | subword_option=hparams.subword_option, 235 | beam_width=hparams.beam_width, 236 | tgt_eos=hparams.eos, 237 | num_translations_per_input=hparams.num_translations_per_input) 238 | 239 | # Change file name to indicate the file writing is completed. 240 | tf.gfile.Rename(output_infer, output_infer_done, overwrite=True) 241 | 242 | # Job 0 is responsible for the clean up. 
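# Sketch of the merge protocol implemented below: each worker renames its
# finished shard to "<inference_output_file>_done_<jobid>"; worker 0 then
# polls until every "_done_*" shard exists, concatenates them in worker order
# into the final output file, and removes the shards. All other workers simply
# return here.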
243 | if jobid != 0: return 244 | 245 | # Now write all translations 246 | with codecs.getwriter("utf-8")( 247 | tf.gfile.GFile(final_output_infer, mode="wb")) as final_f: 248 | for worker_id in range(num_workers): 249 | worker_infer_done = "%s_done_%d" % (inference_output_file, worker_id) 250 | while not tf.gfile.Exists(worker_infer_done): 251 | utils.print_out(" waitting job %d to complete." % worker_id) 252 | time.sleep(10) 253 | 254 | with codecs.getreader("utf-8")( 255 | tf.gfile.GFile(worker_infer_done, mode="rb")) as f: 256 | for translation in f: 257 | final_f.write("%s" % translation) 258 | 259 | for worker_id in range(num_workers): 260 | worker_infer_done = "%s_done_%d" % (inference_output_file, worker_id) 261 | tf.gfile.Remove(worker_infer_done) 262 | -------------------------------------------------------------------------------- /Pun_Generation/code/gnmt_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """GNMT attention sequence-to-sequence model with dynamic RNN support.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | # TODO(rzhao): Use tf.contrib.framework.nest once 1.3 is out. 24 | from tensorflow.python.util import nest 25 | 26 | import attention_model 27 | import model_helper 28 | from utils import misc_utils as utils 29 | 30 | __all__ = ["GNMTModel"] 31 | 32 | 33 | class GNMTModel(attention_model.AttentionModel): 34 | """Sequence-to-sequence dynamic model with GNMT attention architecture. 35 | """ 36 | 37 | def __init__(self, 38 | hparams, 39 | mode, 40 | iterator, 41 | source_vocab_table, 42 | target_vocab_table, 43 | reverse_target_vocab_table=None, 44 | scope=None, 45 | extra_args=None): 46 | super(GNMTModel, self).__init__( 47 | hparams=hparams, 48 | mode=mode, 49 | iterator=iterator, 50 | source_vocab_table=source_vocab_table, 51 | target_vocab_table=target_vocab_table, 52 | reverse_target_vocab_table=reverse_target_vocab_table, 53 | scope=scope, 54 | extra_args=extra_args) 55 | 56 | def _build_encoder(self, hparams): 57 | """Build a GNMT encoder.""" 58 | if hparams.encoder_type == "uni" or hparams.encoder_type == "bi": 59 | return super(GNMTModel, self)._build_encoder(hparams) 60 | 61 | if hparams.encoder_type != "gnmt": 62 | raise ValueError("Unknown encoder_type %s" % hparams.encoder_type) 63 | 64 | # Build GNMT encoder. 
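# The GNMT encoder is one bi-directional layer with uni-directional layers
# stacked on top, i.e. num_uni_layers = num_encoder_layers - 1 (for example, a
# 4-layer encoder gives 1 bi-directional + 3 uni-directional layers; 4 is an
# illustrative value).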
65 | num_bi_layers = 1 66 | num_uni_layers = self.num_encoder_layers - num_bi_layers 67 | utils.print_out(" num_bi_layers = %d" % num_bi_layers) 68 | utils.print_out(" num_uni_layers = %d" % num_uni_layers) 69 | 70 | iterator = self.iterator 71 | source = iterator.source 72 | if self.time_major: 73 | source = tf.transpose(source) 74 | 75 | with tf.variable_scope("encoder") as scope: 76 | dtype = scope.dtype 77 | 78 | # Look up embedding, emp_inp: [max_time, batch_size, num_units] 79 | # when time_major = True 80 | encoder_emb_inp = tf.nn.embedding_lookup(self.embedding_encoder, 81 | source) 82 | 83 | # Execute _build_bidirectional_rnn from Model class 84 | bi_encoder_outputs, bi_encoder_state = self._build_bidirectional_rnn( 85 | inputs=encoder_emb_inp, 86 | sequence_length=iterator.source_sequence_length, 87 | dtype=dtype, 88 | hparams=hparams, 89 | num_bi_layers=num_bi_layers, 90 | num_bi_residual_layers=0, # no residual connection 91 | ) 92 | 93 | uni_cell = model_helper.create_rnn_cell( 94 | unit_type=hparams.unit_type, 95 | num_units=hparams.num_units, 96 | num_layers=num_uni_layers, 97 | num_residual_layers=self.num_encoder_residual_layers, 98 | forget_bias=hparams.forget_bias, 99 | dropout=hparams.dropout, 100 | num_gpus=self.num_gpus, 101 | base_gpu=1, 102 | mode=self.mode, 103 | single_cell_fn=self.single_cell_fn) 104 | 105 | # encoder_outputs: size [max_time, batch_size, num_units] 106 | # when time_major = True 107 | encoder_outputs, encoder_state = tf.nn.dynamic_rnn( 108 | uni_cell, 109 | bi_encoder_outputs, 110 | dtype=dtype, 111 | sequence_length=iterator.source_sequence_length, 112 | time_major=self.time_major) 113 | 114 | # Pass all encoder state except the first bi-directional layer's state to 115 | # decoder. 116 | encoder_state = (bi_encoder_state[1],) + ( 117 | (encoder_state,) if num_uni_layers == 1 else encoder_state) 118 | 119 | return encoder_outputs, encoder_state 120 | 121 | def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state, 122 | source_sequence_length): 123 | """Build a RNN cell with GNMT attention architecture.""" 124 | # Standard attention 125 | if hparams.attention_architecture == "standard": 126 | return super(GNMTModel, self)._build_decoder_cell( 127 | hparams, encoder_outputs, encoder_state, source_sequence_length) 128 | 129 | # GNMT attention 130 | attention_option = hparams.attention 131 | attention_architecture = hparams.attention_architecture 132 | num_units = hparams.num_units 133 | beam_width = hparams.beam_width 134 | 135 | dtype = tf.float32 136 | 137 | if self.time_major: 138 | memory = tf.transpose(encoder_outputs, [1, 0, 2]) 139 | else: 140 | memory = encoder_outputs 141 | 142 | if self.mode == tf.contrib.learn.ModeKeys.INFER and beam_width > 0: 143 | memory = tf.contrib.seq2seq.tile_batch( 144 | memory, multiplier=beam_width) 145 | source_sequence_length = tf.contrib.seq2seq.tile_batch( 146 | source_sequence_length, multiplier=beam_width) 147 | encoder_state = tf.contrib.seq2seq.tile_batch( 148 | encoder_state, multiplier=beam_width) 149 | batch_size = self.batch_size * beam_width 150 | else: 151 | batch_size = self.batch_size 152 | 153 | attention_mechanism = self.attention_mechanism_fn( 154 | attention_option, num_units, memory, source_sequence_length, self.mode) 155 | 156 | cell_list = model_helper._cell_list( # pylint: disable=protected-access 157 | unit_type=hparams.unit_type, 158 | num_units=num_units, 159 | num_layers=self.num_decoder_layers, 160 | num_residual_layers=self.num_decoder_residual_layers, 161 | 
forget_bias=hparams.forget_bias, 162 | dropout=hparams.dropout, 163 | num_gpus=self.num_gpus, 164 | mode=self.mode, 165 | single_cell_fn=self.single_cell_fn, 166 | residual_fn=gnmt_residual_fn 167 | ) 168 | 169 | # Only wrap the bottom layer with the attention mechanism. 170 | attention_cell = cell_list.pop(0) 171 | 172 | # Only generate alignment in greedy INFER mode. 173 | alignment_history = (self.mode == tf.contrib.learn.ModeKeys.INFER and 174 | beam_width == 0) 175 | attention_cell = tf.contrib.seq2seq.AttentionWrapper( 176 | attention_cell, 177 | attention_mechanism, 178 | attention_layer_size=None, # don't use attention layer. 179 | output_attention=False, 180 | alignment_history=alignment_history, 181 | name="attention") 182 | 183 | if attention_architecture == "gnmt": 184 | cell = GNMTAttentionMultiCell( 185 | attention_cell, cell_list) 186 | elif attention_architecture == "gnmt_v2": 187 | cell = GNMTAttentionMultiCell( 188 | attention_cell, cell_list, use_new_attention=True) 189 | else: 190 | raise ValueError( 191 | "Unknown attention_architecture %s" % attention_architecture) 192 | 193 | if hparams.pass_hidden_state: 194 | decoder_initial_state = tuple( 195 | zs.clone(cell_state=es) 196 | if isinstance(zs, tf.contrib.seq2seq.AttentionWrapperState) else es 197 | for zs, es in zip( 198 | cell.zero_state(batch_size, dtype), encoder_state)) 199 | else: 200 | decoder_initial_state = cell.zero_state(batch_size, dtype) 201 | 202 | return cell, decoder_initial_state 203 | 204 | def _get_infer_summary(self, hparams): 205 | # Standard attention 206 | if hparams.attention_architecture == "standard": 207 | return super(GNMTModel, self)._get_infer_summary(hparams) 208 | 209 | # GNMT attention 210 | if hparams.beam_width > 0: 211 | return tf.no_op() 212 | return attention_model._create_attention_images_summary( 213 | self.final_context_state[0]) 214 | 215 | 216 | class GNMTAttentionMultiCell(tf.nn.rnn_cell.MultiRNNCell): 217 | """A MultiCell with GNMT attention style.""" 218 | 219 | def __init__(self, attention_cell, cells, use_new_attention=False): 220 | """Creates a GNMTAttentionMultiCell. 221 | 222 | Args: 223 | attention_cell: An instance of AttentionWrapper. 224 | cells: A list of RNNCell wrapped with AttentionInputWrapper. 225 | use_new_attention: Whether to use the attention generated from current 226 | step bottom layer's output. Default is False. 
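(When use_new_attention is True, the upper layers consume the attention
computed from the current step's bottom-layer output; otherwise they reuse
the attention carried in the incoming state.)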
227 | """ 228 | cells = [attention_cell] + cells 229 | self.use_new_attention = use_new_attention 230 | super(GNMTAttentionMultiCell, self).__init__(cells, state_is_tuple=True) 231 | 232 | def __call__(self, inputs, state, scope=None): 233 | """Run the cell with bottom layer's attention copied to all upper layers.""" 234 | if not nest.is_sequence(state): 235 | raise ValueError( 236 | "Expected state to be a tuple of length %d, but received: %s" 237 | % (len(self.state_size), state)) 238 | 239 | with tf.variable_scope(scope or "multi_rnn_cell"): 240 | new_states = [] 241 | 242 | with tf.variable_scope("cell_0_attention"): 243 | attention_cell = self._cells[0] 244 | attention_state = state[0] 245 | cur_inp, new_attention_state = attention_cell(inputs, attention_state) 246 | new_states.append(new_attention_state) 247 | 248 | for i in range(1, len(self._cells)): 249 | with tf.variable_scope("cell_%d" % i): 250 | 251 | cell = self._cells[i] 252 | cur_state = state[i] 253 | 254 | if self.use_new_attention: 255 | cur_inp = tf.concat([cur_inp, new_attention_state.attention], -1) 256 | else: 257 | cur_inp = tf.concat([cur_inp, attention_state.attention], -1) 258 | 259 | cur_inp, new_state = cell(cur_inp, cur_state) 260 | new_states.append(new_state) 261 | 262 | return cur_inp, tuple(new_states) 263 | 264 | 265 | def gnmt_residual_fn(inputs, outputs): 266 | """Residual function that handles different inputs and outputs inner dims. 267 | 268 | Args: 269 | inputs: cell inputs, this is actual inputs concatenated with the attention 270 | vector. 271 | outputs: cell outputs 272 | 273 | Returns: 274 | outputs + actual inputs 275 | """ 276 | def split_input(inp, out): 277 | out_dim = out.get_shape().as_list()[-1] 278 | inp_dim = inp.get_shape().as_list()[-1] 279 | return tf.split(inp, [out_dim, inp_dim - out_dim], axis=-1) 280 | actual_inputs, _ = nest.map_structure(split_input, inputs, outputs) 281 | def assert_shape_match(inp, out): 282 | inp.get_shape().assert_is_compatible_with(out.get_shape()) 283 | nest.assert_same_structure(actual_inputs, outputs) 284 | nest.map_structure(assert_shape_match, actual_inputs, outputs) 285 | return nest.map_structure(lambda inp, out: inp + out, actual_inputs, outputs) 286 | -------------------------------------------------------------------------------- /Pun_Generation_Forward/code/gnmt_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """GNMT attention sequence-to-sequence model with dynamic RNN support.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | # TODO(rzhao): Use tf.contrib.framework.nest once 1.3 is out. 
24 | from tensorflow.python.util import nest 25 | 26 | import attention_model 27 | import model_helper 28 | from utils import misc_utils as utils 29 | 30 | __all__ = ["GNMTModel"] 31 | 32 | 33 | class GNMTModel(attention_model.AttentionModel): 34 | """Sequence-to-sequence dynamic model with GNMT attention architecture. 35 | """ 36 | 37 | def __init__(self, 38 | hparams, 39 | mode, 40 | iterator, 41 | source_vocab_table, 42 | target_vocab_table, 43 | reverse_target_vocab_table=None, 44 | scope=None, 45 | extra_args=None): 46 | super(GNMTModel, self).__init__( 47 | hparams=hparams, 48 | mode=mode, 49 | iterator=iterator, 50 | source_vocab_table=source_vocab_table, 51 | target_vocab_table=target_vocab_table, 52 | reverse_target_vocab_table=reverse_target_vocab_table, 53 | scope=scope, 54 | extra_args=extra_args) 55 | 56 | def _build_encoder(self, hparams): 57 | """Build a GNMT encoder.""" 58 | if hparams.encoder_type == "uni" or hparams.encoder_type == "bi": 59 | return super(GNMTModel, self)._build_encoder(hparams) 60 | 61 | if hparams.encoder_type != "gnmt": 62 | raise ValueError("Unknown encoder_type %s" % hparams.encoder_type) 63 | 64 | # Build GNMT encoder. 65 | num_bi_layers = 1 66 | num_uni_layers = self.num_encoder_layers - num_bi_layers 67 | utils.print_out(" num_bi_layers = %d" % num_bi_layers) 68 | utils.print_out(" num_uni_layers = %d" % num_uni_layers) 69 | 70 | iterator = self.iterator 71 | source = iterator.source 72 | if self.time_major: 73 | source = tf.transpose(source) 74 | 75 | with tf.variable_scope("encoder") as scope: 76 | dtype = scope.dtype 77 | 78 | # Look up embedding, emp_inp: [max_time, batch_size, num_units] 79 | # when time_major = True 80 | encoder_emb_inp = tf.nn.embedding_lookup(self.embedding_encoder, 81 | source) 82 | 83 | # Execute _build_bidirectional_rnn from Model class 84 | bi_encoder_outputs, bi_encoder_state = self._build_bidirectional_rnn( 85 | inputs=encoder_emb_inp, 86 | sequence_length=iterator.source_sequence_length, 87 | dtype=dtype, 88 | hparams=hparams, 89 | num_bi_layers=num_bi_layers, 90 | num_bi_residual_layers=0, # no residual connection 91 | ) 92 | 93 | uni_cell = model_helper.create_rnn_cell( 94 | unit_type=hparams.unit_type, 95 | num_units=hparams.num_units, 96 | num_layers=num_uni_layers, 97 | num_residual_layers=self.num_encoder_residual_layers, 98 | forget_bias=hparams.forget_bias, 99 | dropout=hparams.dropout, 100 | num_gpus=self.num_gpus, 101 | base_gpu=1, 102 | mode=self.mode, 103 | single_cell_fn=self.single_cell_fn) 104 | 105 | # encoder_outputs: size [max_time, batch_size, num_units] 106 | # when time_major = True 107 | encoder_outputs, encoder_state = tf.nn.dynamic_rnn( 108 | uni_cell, 109 | bi_encoder_outputs, 110 | dtype=dtype, 111 | sequence_length=iterator.source_sequence_length, 112 | time_major=self.time_major) 113 | 114 | # Pass all encoder state except the first bi-directional layer's state to 115 | # decoder. 
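# bi_encoder_state is a (forward_state, backward_state) pair; only the
# backward state is kept and prepended to the uni-directional states, so the
# decoder still receives num_encoder_layers state entries, e.g.
# (bw_state, uni_1, uni_2, uni_3) for a 4-layer encoder (illustrative names).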
116 | encoder_state = (bi_encoder_state[1],) + ( 117 | (encoder_state,) if num_uni_layers == 1 else encoder_state) 118 | 119 | return encoder_outputs, encoder_state 120 | 121 | def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state, 122 | source_sequence_length): 123 | """Build a RNN cell with GNMT attention architecture.""" 124 | # Standard attention 125 | if hparams.attention_architecture == "standard": 126 | return super(GNMTModel, self)._build_decoder_cell( 127 | hparams, encoder_outputs, encoder_state, source_sequence_length) 128 | 129 | # GNMT attention 130 | attention_option = hparams.attention 131 | attention_architecture = hparams.attention_architecture 132 | num_units = hparams.num_units 133 | beam_width = hparams.beam_width 134 | 135 | dtype = tf.float32 136 | 137 | if self.time_major: 138 | memory = tf.transpose(encoder_outputs, [1, 0, 2]) 139 | else: 140 | memory = encoder_outputs 141 | 142 | if self.mode == tf.contrib.learn.ModeKeys.INFER and beam_width > 0: 143 | memory = tf.contrib.seq2seq.tile_batch( 144 | memory, multiplier=beam_width) 145 | source_sequence_length = tf.contrib.seq2seq.tile_batch( 146 | source_sequence_length, multiplier=beam_width) 147 | encoder_state = tf.contrib.seq2seq.tile_batch( 148 | encoder_state, multiplier=beam_width) 149 | batch_size = self.batch_size * beam_width 150 | else: 151 | batch_size = self.batch_size 152 | 153 | attention_mechanism = self.attention_mechanism_fn( 154 | attention_option, num_units, memory, source_sequence_length, self.mode) 155 | 156 | cell_list = model_helper._cell_list( # pylint: disable=protected-access 157 | unit_type=hparams.unit_type, 158 | num_units=num_units, 159 | num_layers=self.num_decoder_layers, 160 | num_residual_layers=self.num_decoder_residual_layers, 161 | forget_bias=hparams.forget_bias, 162 | dropout=hparams.dropout, 163 | num_gpus=self.num_gpus, 164 | mode=self.mode, 165 | single_cell_fn=self.single_cell_fn, 166 | residual_fn=gnmt_residual_fn 167 | ) 168 | 169 | # Only wrap the bottom layer with the attention mechanism. 170 | attention_cell = cell_list.pop(0) 171 | 172 | # Only generate alignment in greedy INFER mode. 173 | alignment_history = (self.mode == tf.contrib.learn.ModeKeys.INFER and 174 | beam_width == 0) 175 | attention_cell = tf.contrib.seq2seq.AttentionWrapper( 176 | attention_cell, 177 | attention_mechanism, 178 | attention_layer_size=None, # don't use attention layer. 
179 | output_attention=False, 180 | alignment_history=alignment_history, 181 | name="attention") 182 | 183 | if attention_architecture == "gnmt": 184 | cell = GNMTAttentionMultiCell( 185 | attention_cell, cell_list) 186 | elif attention_architecture == "gnmt_v2": 187 | cell = GNMTAttentionMultiCell( 188 | attention_cell, cell_list, use_new_attention=True) 189 | else: 190 | raise ValueError( 191 | "Unknown attention_architecture %s" % attention_architecture) 192 | 193 | if hparams.pass_hidden_state: 194 | decoder_initial_state = tuple( 195 | zs.clone(cell_state=es) 196 | if isinstance(zs, tf.contrib.seq2seq.AttentionWrapperState) else es 197 | for zs, es in zip( 198 | cell.zero_state(batch_size, dtype), encoder_state)) 199 | else: 200 | decoder_initial_state = cell.zero_state(batch_size, dtype) 201 | 202 | return cell, decoder_initial_state 203 | 204 | def _get_infer_summary(self, hparams): 205 | # Standard attention 206 | if hparams.attention_architecture == "standard": 207 | return super(GNMTModel, self)._get_infer_summary(hparams) 208 | 209 | # GNMT attention 210 | if hparams.beam_width > 0: 211 | return tf.no_op() 212 | return attention_model._create_attention_images_summary( 213 | self.final_context_state[0]) 214 | 215 | 216 | class GNMTAttentionMultiCell(tf.nn.rnn_cell.MultiRNNCell): 217 | """A MultiCell with GNMT attention style.""" 218 | 219 | def __init__(self, attention_cell, cells, use_new_attention=False): 220 | """Creates a GNMTAttentionMultiCell. 221 | 222 | Args: 223 | attention_cell: An instance of AttentionWrapper. 224 | cells: A list of RNNCell wrapped with AttentionInputWrapper. 225 | use_new_attention: Whether to use the attention generated from current 226 | step bottom layer's output. Default is False. 227 | """ 228 | cells = [attention_cell] + cells 229 | self.use_new_attention = use_new_attention 230 | super(GNMTAttentionMultiCell, self).__init__(cells, state_is_tuple=True) 231 | 232 | def __call__(self, inputs, state, scope=None): 233 | """Run the cell with bottom layer's attention copied to all upper layers.""" 234 | if not nest.is_sequence(state): 235 | raise ValueError( 236 | "Expected state to be a tuple of length %d, but received: %s" 237 | % (len(self.state_size), state)) 238 | 239 | with tf.variable_scope(scope or "multi_rnn_cell"): 240 | new_states = [] 241 | 242 | with tf.variable_scope("cell_0_attention"): 243 | attention_cell = self._cells[0] 244 | attention_state = state[0] 245 | cur_inp, new_attention_state = attention_cell(inputs, attention_state) 246 | new_states.append(new_attention_state) 247 | 248 | for i in range(1, len(self._cells)): 249 | with tf.variable_scope("cell_%d" % i): 250 | 251 | cell = self._cells[i] 252 | cur_state = state[i] 253 | 254 | if self.use_new_attention: 255 | cur_inp = tf.concat([cur_inp, new_attention_state.attention], -1) 256 | else: 257 | cur_inp = tf.concat([cur_inp, attention_state.attention], -1) 258 | 259 | cur_inp, new_state = cell(cur_inp, cur_state) 260 | new_states.append(new_state) 261 | 262 | return cur_inp, tuple(new_states) 263 | 264 | 265 | def gnmt_residual_fn(inputs, outputs): 266 | """Residual function that handles different inputs and outputs inner dims. 267 | 268 | Args: 269 | inputs: cell inputs, this is actual inputs concatenated with the attention 270 | vector. 
271 | outputs: cell outputs 272 | 273 | Returns: 274 | outputs + actual inputs 275 | """ 276 | def split_input(inp, out): 277 | out_dim = out.get_shape().as_list()[-1] 278 | inp_dim = inp.get_shape().as_list()[-1] 279 | return tf.split(inp, [out_dim, inp_dim - out_dim], axis=-1) 280 | actual_inputs, _ = nest.map_structure(split_input, inputs, outputs) 281 | def assert_shape_match(inp, out): 282 | inp.get_shape().assert_is_compatible_with(out.get_shape()) 283 | nest.assert_same_structure(actual_inputs, outputs) 284 | nest.map_structure(assert_shape_match, actual_inputs, outputs) 285 | return nest.map_structure(lambda inp, out: inp + out, actual_inputs, outputs) 286 | -------------------------------------------------------------------------------- /Pun_Generation/code/utils/scripts/rouge.py: -------------------------------------------------------------------------------- 1 | """ROUGE metric implementation. 2 | 3 | Copy from tf_seq2seq/seq2seq/metrics/rouge.py. 4 | This is a modified and slightly extended verison of 5 | https://github.com/miso-belica/sumy/blob/dev/sumy/evaluation/rouge.py. 6 | """ 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | from __future__ import unicode_literals 12 | 13 | import itertools 14 | import numpy as np 15 | 16 | #pylint: disable=C0103 17 | 18 | 19 | def _get_ngrams(n, text): 20 | """Calcualtes n-grams. 21 | 22 | Args: 23 | n: which n-grams to calculate 24 | text: An array of tokens 25 | 26 | Returns: 27 | A set of n-grams 28 | """ 29 | ngram_set = set() 30 | text_length = len(text) 31 | max_index_ngram_start = text_length - n 32 | for i in range(max_index_ngram_start + 1): 33 | ngram_set.add(tuple(text[i:i + n])) 34 | return ngram_set 35 | 36 | 37 | def _split_into_words(sentences): 38 | """Splits multiple sentences into words and flattens the result""" 39 | return list(itertools.chain(*[_.split(" ") for _ in sentences])) 40 | 41 | 42 | def _get_word_ngrams(n, sentences): 43 | """Calculates word n-grams for multiple sentences. 44 | """ 45 | assert len(sentences) > 0 46 | assert n > 0 47 | 48 | words = _split_into_words(sentences) 49 | return _get_ngrams(n, words) 50 | 51 | 52 | def _len_lcs(x, y): 53 | """ 54 | Returns the length of the Longest Common Subsequence between sequences x 55 | and y. 56 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence 57 | 58 | Args: 59 | x: sequence of words 60 | y: sequence of words 61 | 62 | Returns 63 | integer: Length of LCS between x and y 64 | """ 65 | table = _lcs(x, y) 66 | n, m = len(x), len(y) 67 | return table[n, m] 68 | 69 | 70 | def _lcs(x, y): 71 | """ 72 | Computes the length of the longest common subsequence (lcs) between two 73 | strings. The implementation below uses a DP programming algorithm and runs 74 | in O(nm) time where n = len(x) and m = len(y). 75 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence 76 | 77 | Args: 78 | x: collection of words 79 | y: collection of words 80 | 81 | Returns: 82 | Table of dictionary of coord and len lcs 83 | """ 84 | n, m = len(x), len(y) 85 | table = dict() 86 | for i in range(n + 1): 87 | for j in range(m + 1): 88 | if i == 0 or j == 0: 89 | table[i, j] = 0 90 | elif x[i - 1] == y[j - 1]: 91 | table[i, j] = table[i - 1, j - 1] + 1 92 | else: 93 | table[i, j] = max(table[i - 1, j], table[i, j - 1]) 94 | return table 95 | 96 | 97 | def _recon_lcs(x, y): 98 | """ 99 | Returns the Longest Subsequence between x and y. 
100 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence 101 | 102 | Args: 103 | x: sequence of words 104 | y: sequence of words 105 | 106 | Returns: 107 | sequence: LCS of x and y 108 | """ 109 | i, j = len(x), len(y) 110 | table = _lcs(x, y) 111 | 112 | def _recon(i, j): 113 | """private recon calculation""" 114 | if i == 0 or j == 0: 115 | return [] 116 | elif x[i - 1] == y[j - 1]: 117 | return _recon(i - 1, j - 1) + [(x[i - 1], i)] 118 | elif table[i - 1, j] > table[i, j - 1]: 119 | return _recon(i - 1, j) 120 | else: 121 | return _recon(i, j - 1) 122 | 123 | recon_tuple = tuple(map(lambda x: x[0], _recon(i, j))) 124 | return recon_tuple 125 | 126 | 127 | def rouge_n(evaluated_sentences, reference_sentences, n=2): 128 | """ 129 | Computes ROUGE-N of two text collections of sentences. 130 | Sourece: http://research.microsoft.com/en-us/um/people/cyl/download/ 131 | papers/rouge-working-note-v1.3.1.pdf 132 | 133 | Args: 134 | evaluated_sentences: The sentences that have been picked by the summarizer 135 | reference_sentences: The sentences from the referene set 136 | n: Size of ngram. Defaults to 2. 137 | 138 | Returns: 139 | A tuple (f1, precision, recall) for ROUGE-N 140 | 141 | Raises: 142 | ValueError: raises exception if a param has len <= 0 143 | """ 144 | if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0: 145 | raise ValueError("Collections must contain at least 1 sentence.") 146 | 147 | evaluated_ngrams = _get_word_ngrams(n, evaluated_sentences) 148 | reference_ngrams = _get_word_ngrams(n, reference_sentences) 149 | reference_count = len(reference_ngrams) 150 | evaluated_count = len(evaluated_ngrams) 151 | 152 | # Gets the overlapping ngrams between evaluated and reference 153 | overlapping_ngrams = evaluated_ngrams.intersection(reference_ngrams) 154 | overlapping_count = len(overlapping_ngrams) 155 | 156 | # Handle edge case. This isn't mathematically correct, but it's good enough 157 | if evaluated_count == 0: 158 | precision = 0.0 159 | else: 160 | precision = overlapping_count / evaluated_count 161 | 162 | if reference_count == 0: 163 | recall = 0.0 164 | else: 165 | recall = overlapping_count / reference_count 166 | 167 | f1_score = 2.0 * ((precision * recall) / (precision + recall + 1e-8)) 168 | 169 | # return overlapping_count / reference_count 170 | return f1_score, precision, recall 171 | 172 | 173 | def _f_p_r_lcs(llcs, m, n): 174 | """ 175 | Computes the LCS-based F-measure score 176 | Source: http://research.microsoft.com/en-us/um/people/cyl/download/papers/ 177 | rouge-working-note-v1.3.1.pdf 178 | 179 | Args: 180 | llcs: Length of LCS 181 | m: number of words in reference summary 182 | n: number of words in candidate summary 183 | 184 | Returns: 185 | Float. LCS-based F-measure score 186 | """ 187 | r_lcs = llcs / m 188 | p_lcs = llcs / n 189 | beta = p_lcs / (r_lcs + 1e-12) 190 | num = (1 + (beta**2)) * r_lcs * p_lcs 191 | denom = r_lcs + ((beta**2) * p_lcs) 192 | f_lcs = num / (denom + 1e-12) 193 | return f_lcs, p_lcs, r_lcs 194 | 195 | 196 | def rouge_l_sentence_level(evaluated_sentences, reference_sentences): 197 | """ 198 | Computes ROUGE-L (sentence level) of two text collections of sentences. 
199 | http://research.microsoft.com/en-us/um/people/cyl/download/papers/ 200 | rouge-working-note-v1.3.1.pdf 201 | 202 | Calculated according to: 203 | R_lcs = LCS(X,Y)/m 204 | P_lcs = LCS(X,Y)/n 205 | F_lcs = ((1 + beta^2)*R_lcs*P_lcs) / (R_lcs + (beta^2) * P_lcs) 206 | 207 | where: 208 | X = reference summary 209 | Y = Candidate summary 210 | m = length of reference summary 211 | n = length of candidate summary 212 | 213 | Args: 214 | evaluated_sentences: The sentences that have been picked by the summarizer 215 | reference_sentences: The sentences from the referene set 216 | 217 | Returns: 218 | A float: F_lcs 219 | 220 | Raises: 221 | ValueError: raises exception if a param has len <= 0 222 | """ 223 | if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0: 224 | raise ValueError("Collections must contain at least 1 sentence.") 225 | reference_words = _split_into_words(reference_sentences) 226 | evaluated_words = _split_into_words(evaluated_sentences) 227 | m = len(reference_words) 228 | n = len(evaluated_words) 229 | lcs = _len_lcs(evaluated_words, reference_words) 230 | return _f_p_r_lcs(lcs, m, n) 231 | 232 | 233 | def _union_lcs(evaluated_sentences, reference_sentence): 234 | """ 235 | Returns LCS_u(r_i, C) which is the LCS score of the union longest common 236 | subsequence between reference sentence ri and candidate summary C. For example 237 | if r_i= w1 w2 w3 w4 w5, and C contains two sentences: c1 = w1 w2 w6 w7 w8 and 238 | c2 = w1 w3 w8 w9 w5, then the longest common subsequence of r_i and c1 is 239 | "w1 w2" and the longest common subsequence of r_i and c2 is "w1 w3 w5". The 240 | union longest common subsequence of r_i, c1, and c2 is "w1 w2 w3 w5" and 241 | LCS_u(r_i, C) = 4/5. 242 | 243 | Args: 244 | evaluated_sentences: The sentences that have been picked by the summarizer 245 | reference_sentence: One of the sentences in the reference summaries 246 | 247 | Returns: 248 | float: LCS_u(r_i, C) 249 | 250 | ValueError: 251 | Raises exception if a param has len <= 0 252 | """ 253 | if len(evaluated_sentences) <= 0: 254 | raise ValueError("Collections must contain at least 1 sentence.") 255 | 256 | lcs_union = set() 257 | reference_words = _split_into_words([reference_sentence]) 258 | combined_lcs_length = 0 259 | for eval_s in evaluated_sentences: 260 | evaluated_words = _split_into_words([eval_s]) 261 | lcs = set(_recon_lcs(reference_words, evaluated_words)) 262 | combined_lcs_length += len(lcs) 263 | lcs_union = lcs_union.union(lcs) 264 | 265 | union_lcs_count = len(lcs_union) 266 | union_lcs_value = union_lcs_count / combined_lcs_length 267 | return union_lcs_value 268 | 269 | 270 | def rouge_l_summary_level(evaluated_sentences, reference_sentences): 271 | """ 272 | Computes ROUGE-L (summary level) of two text collections of sentences. 
273 | http://research.microsoft.com/en-us/um/people/cyl/download/papers/ 274 | rouge-working-note-v1.3.1.pdf 275 | 276 | Calculated according to: 277 | R_lcs = SUM(1, u)[LCS(r_i,C)]/m 278 | P_lcs = SUM(1, u)[LCS(r_i,C)]/n 279 | F_lcs = ((1 + beta^2)*R_lcs*P_lcs) / (R_lcs + (beta^2) * P_lcs) 280 | 281 | where: 282 | SUM(i,u) = SUM from i through u 283 | u = number of sentences in reference summary 284 | C = Candidate summary made up of v sentences 285 | m = number of words in reference summary 286 | n = number of words in candidate summary 287 | 288 | Args: 289 | evaluated_sentences: The sentences that have been picked by the summarizer 290 | reference_sentence: One of the sentences in the reference summaries 291 | 292 | Returns: 293 | A float: F_lcs 294 | 295 | Raises: 296 | ValueError: raises exception if a param has len <= 0 297 | """ 298 | if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0: 299 | raise ValueError("Collections must contain at least 1 sentence.") 300 | 301 | # total number of words in reference sentences 302 | m = len(_split_into_words(reference_sentences)) 303 | 304 | # total number of words in evaluated sentences 305 | n = len(_split_into_words(evaluated_sentences)) 306 | 307 | union_lcs_sum_across_all_references = 0 308 | for ref_s in reference_sentences: 309 | union_lcs_sum_across_all_references += _union_lcs(evaluated_sentences, 310 | ref_s) 311 | return _f_p_r_lcs(union_lcs_sum_across_all_references, m, n) 312 | 313 | 314 | def rouge(hypotheses, references): 315 | """Calculates average rouge scores for a list of hypotheses and 316 | references""" 317 | 318 | # Filter out hyps that are of 0 length 319 | # hyps_and_refs = zip(hypotheses, references) 320 | # hyps_and_refs = [_ for _ in hyps_and_refs if len(_[0]) > 0] 321 | # hypotheses, references = zip(*hyps_and_refs) 322 | 323 | # Calculate ROUGE-1 F1, precision, recall scores 324 | rouge_1 = [ 325 | rouge_n([hyp], [ref], 1) for hyp, ref in zip(hypotheses, references) 326 | ] 327 | rouge_1_f, rouge_1_p, rouge_1_r = map(np.mean, zip(*rouge_1)) 328 | 329 | # Calculate ROUGE-2 F1, precision, recall scores 330 | rouge_2 = [ 331 | rouge_n([hyp], [ref], 2) for hyp, ref in zip(hypotheses, references) 332 | ] 333 | rouge_2_f, rouge_2_p, rouge_2_r = map(np.mean, zip(*rouge_2)) 334 | 335 | # Calculate ROUGE-L F1, precision, recall scores 336 | rouge_l = [ 337 | rouge_l_sentence_level([hyp], [ref]) 338 | for hyp, ref in zip(hypotheses, references) 339 | ] 340 | rouge_l_f, rouge_l_p, rouge_l_r = map(np.mean, zip(*rouge_l)) 341 | 342 | return { 343 | "rouge_1/f_score": rouge_1_f, 344 | "rouge_1/r_score": rouge_1_r, 345 | "rouge_1/p_score": rouge_1_p, 346 | "rouge_2/f_score": rouge_2_f, 347 | "rouge_2/r_score": rouge_2_r, 348 | "rouge_2/p_score": rouge_2_p, 349 | "rouge_l/f_score": rouge_l_f, 350 | "rouge_l/r_score": rouge_l_r, 351 | "rouge_l/p_score": rouge_l_p, 352 | } 353 | -------------------------------------------------------------------------------- /Pun_Generation_Forward/code/utils/scripts/rouge.py: -------------------------------------------------------------------------------- 1 | """ROUGE metric implementation. 2 | 3 | Copy from tf_seq2seq/seq2seq/metrics/rouge.py. 4 | This is a modified and slightly extended verison of 5 | https://github.com/miso-belica/sumy/blob/dev/sumy/evaluation/rouge.py. 
6 | """ 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | from __future__ import unicode_literals 12 | 13 | import itertools 14 | import numpy as np 15 | 16 | #pylint: disable=C0103 17 | 18 | 19 | def _get_ngrams(n, text): 20 | """Calcualtes n-grams. 21 | 22 | Args: 23 | n: which n-grams to calculate 24 | text: An array of tokens 25 | 26 | Returns: 27 | A set of n-grams 28 | """ 29 | ngram_set = set() 30 | text_length = len(text) 31 | max_index_ngram_start = text_length - n 32 | for i in range(max_index_ngram_start + 1): 33 | ngram_set.add(tuple(text[i:i + n])) 34 | return ngram_set 35 | 36 | 37 | def _split_into_words(sentences): 38 | """Splits multiple sentences into words and flattens the result""" 39 | return list(itertools.chain(*[_.split(" ") for _ in sentences])) 40 | 41 | 42 | def _get_word_ngrams(n, sentences): 43 | """Calculates word n-grams for multiple sentences. 44 | """ 45 | assert len(sentences) > 0 46 | assert n > 0 47 | 48 | words = _split_into_words(sentences) 49 | return _get_ngrams(n, words) 50 | 51 | 52 | def _len_lcs(x, y): 53 | """ 54 | Returns the length of the Longest Common Subsequence between sequences x 55 | and y. 56 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence 57 | 58 | Args: 59 | x: sequence of words 60 | y: sequence of words 61 | 62 | Returns 63 | integer: Length of LCS between x and y 64 | """ 65 | table = _lcs(x, y) 66 | n, m = len(x), len(y) 67 | return table[n, m] 68 | 69 | 70 | def _lcs(x, y): 71 | """ 72 | Computes the length of the longest common subsequence (lcs) between two 73 | strings. The implementation below uses a DP programming algorithm and runs 74 | in O(nm) time where n = len(x) and m = len(y). 75 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence 76 | 77 | Args: 78 | x: collection of words 79 | y: collection of words 80 | 81 | Returns: 82 | Table of dictionary of coord and len lcs 83 | """ 84 | n, m = len(x), len(y) 85 | table = dict() 86 | for i in range(n + 1): 87 | for j in range(m + 1): 88 | if i == 0 or j == 0: 89 | table[i, j] = 0 90 | elif x[i - 1] == y[j - 1]: 91 | table[i, j] = table[i - 1, j - 1] + 1 92 | else: 93 | table[i, j] = max(table[i - 1, j], table[i, j - 1]) 94 | return table 95 | 96 | 97 | def _recon_lcs(x, y): 98 | """ 99 | Returns the Longest Subsequence between x and y. 100 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence 101 | 102 | Args: 103 | x: sequence of words 104 | y: sequence of words 105 | 106 | Returns: 107 | sequence: LCS of x and y 108 | """ 109 | i, j = len(x), len(y) 110 | table = _lcs(x, y) 111 | 112 | def _recon(i, j): 113 | """private recon calculation""" 114 | if i == 0 or j == 0: 115 | return [] 116 | elif x[i - 1] == y[j - 1]: 117 | return _recon(i - 1, j - 1) + [(x[i - 1], i)] 118 | elif table[i - 1, j] > table[i, j - 1]: 119 | return _recon(i - 1, j) 120 | else: 121 | return _recon(i, j - 1) 122 | 123 | recon_tuple = tuple(map(lambda x: x[0], _recon(i, j))) 124 | return recon_tuple 125 | 126 | 127 | def rouge_n(evaluated_sentences, reference_sentences, n=2): 128 | """ 129 | Computes ROUGE-N of two text collections of sentences. 130 | Sourece: http://research.microsoft.com/en-us/um/people/cyl/download/ 131 | papers/rouge-working-note-v1.3.1.pdf 132 | 133 | Args: 134 | evaluated_sentences: The sentences that have been picked by the summarizer 135 | reference_sentences: The sentences from the referene set 136 | n: Size of ngram. 
Defaults to 2. 137 | 138 | Returns: 139 | A tuple (f1, precision, recall) for ROUGE-N 140 | 141 | Raises: 142 | ValueError: raises exception if a param has len <= 0 143 | """ 144 | if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0: 145 | raise ValueError("Collections must contain at least 1 sentence.") 146 | 147 | evaluated_ngrams = _get_word_ngrams(n, evaluated_sentences) 148 | reference_ngrams = _get_word_ngrams(n, reference_sentences) 149 | reference_count = len(reference_ngrams) 150 | evaluated_count = len(evaluated_ngrams) 151 | 152 | # Gets the overlapping ngrams between evaluated and reference 153 | overlapping_ngrams = evaluated_ngrams.intersection(reference_ngrams) 154 | overlapping_count = len(overlapping_ngrams) 155 | 156 | # Handle edge case. This isn't mathematically correct, but it's good enough 157 | if evaluated_count == 0: 158 | precision = 0.0 159 | else: 160 | precision = overlapping_count / evaluated_count 161 | 162 | if reference_count == 0: 163 | recall = 0.0 164 | else: 165 | recall = overlapping_count / reference_count 166 | 167 | f1_score = 2.0 * ((precision * recall) / (precision + recall + 1e-8)) 168 | 169 | # return overlapping_count / reference_count 170 | return f1_score, precision, recall 171 | 172 | 173 | def _f_p_r_lcs(llcs, m, n): 174 | """ 175 | Computes the LCS-based F-measure score 176 | Source: http://research.microsoft.com/en-us/um/people/cyl/download/papers/ 177 | rouge-working-note-v1.3.1.pdf 178 | 179 | Args: 180 | llcs: Length of LCS 181 | m: number of words in reference summary 182 | n: number of words in candidate summary 183 | 184 | Returns: 185 | Float. LCS-based F-measure score 186 | """ 187 | r_lcs = llcs / m 188 | p_lcs = llcs / n 189 | beta = p_lcs / (r_lcs + 1e-12) 190 | num = (1 + (beta**2)) * r_lcs * p_lcs 191 | denom = r_lcs + ((beta**2) * p_lcs) 192 | f_lcs = num / (denom + 1e-12) 193 | return f_lcs, p_lcs, r_lcs 194 | 195 | 196 | def rouge_l_sentence_level(evaluated_sentences, reference_sentences): 197 | """ 198 | Computes ROUGE-L (sentence level) of two text collections of sentences. 199 | http://research.microsoft.com/en-us/um/people/cyl/download/papers/ 200 | rouge-working-note-v1.3.1.pdf 201 | 202 | Calculated according to: 203 | R_lcs = LCS(X,Y)/m 204 | P_lcs = LCS(X,Y)/n 205 | F_lcs = ((1 + beta^2)*R_lcs*P_lcs) / (R_lcs + (beta^2) * P_lcs) 206 | 207 | where: 208 | X = reference summary 209 | Y = Candidate summary 210 | m = length of reference summary 211 | n = length of candidate summary 212 | 213 | Args: 214 | evaluated_sentences: The sentences that have been picked by the summarizer 215 | reference_sentences: The sentences from the referene set 216 | 217 | Returns: 218 | A float: F_lcs 219 | 220 | Raises: 221 | ValueError: raises exception if a param has len <= 0 222 | """ 223 | if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0: 224 | raise ValueError("Collections must contain at least 1 sentence.") 225 | reference_words = _split_into_words(reference_sentences) 226 | evaluated_words = _split_into_words(evaluated_sentences) 227 | m = len(reference_words) 228 | n = len(evaluated_words) 229 | lcs = _len_lcs(evaluated_words, reference_words) 230 | return _f_p_r_lcs(lcs, m, n) 231 | 232 | 233 | def _union_lcs(evaluated_sentences, reference_sentence): 234 | """ 235 | Returns LCS_u(r_i, C) which is the LCS score of the union longest common 236 | subsequence between reference sentence ri and candidate summary C. 
For example 237 | if r_i= w1 w2 w3 w4 w5, and C contains two sentences: c1 = w1 w2 w6 w7 w8 and 238 | c2 = w1 w3 w8 w9 w5, then the longest common subsequence of r_i and c1 is 239 | "w1 w2" and the longest common subsequence of r_i and c2 is "w1 w3 w5". The 240 | union longest common subsequence of r_i, c1, and c2 is "w1 w2 w3 w5" and 241 | LCS_u(r_i, C) = 4/5. 242 | 243 | Args: 244 | evaluated_sentences: The sentences that have been picked by the summarizer 245 | reference_sentence: One of the sentences in the reference summaries 246 | 247 | Returns: 248 | float: LCS_u(r_i, C) 249 | 250 | ValueError: 251 | Raises exception if a param has len <= 0 252 | """ 253 | if len(evaluated_sentences) <= 0: 254 | raise ValueError("Collections must contain at least 1 sentence.") 255 | 256 | lcs_union = set() 257 | reference_words = _split_into_words([reference_sentence]) 258 | combined_lcs_length = 0 259 | for eval_s in evaluated_sentences: 260 | evaluated_words = _split_into_words([eval_s]) 261 | lcs = set(_recon_lcs(reference_words, evaluated_words)) 262 | combined_lcs_length += len(lcs) 263 | lcs_union = lcs_union.union(lcs) 264 | 265 | union_lcs_count = len(lcs_union) 266 | union_lcs_value = union_lcs_count / combined_lcs_length 267 | return union_lcs_value 268 | 269 | 270 | def rouge_l_summary_level(evaluated_sentences, reference_sentences): 271 | """ 272 | Computes ROUGE-L (summary level) of two text collections of sentences. 273 | http://research.microsoft.com/en-us/um/people/cyl/download/papers/ 274 | rouge-working-note-v1.3.1.pdf 275 | 276 | Calculated according to: 277 | R_lcs = SUM(1, u)[LCS(r_i,C)]/m 278 | P_lcs = SUM(1, u)[LCS(r_i,C)]/n 279 | F_lcs = ((1 + beta^2)*R_lcs*P_lcs) / (R_lcs + (beta^2) * P_lcs) 280 | 281 | where: 282 | SUM(i,u) = SUM from i through u 283 | u = number of sentences in reference summary 284 | C = Candidate summary made up of v sentences 285 | m = number of words in reference summary 286 | n = number of words in candidate summary 287 | 288 | Args: 289 | evaluated_sentences: The sentences that have been picked by the summarizer 290 | reference_sentence: One of the sentences in the reference summaries 291 | 292 | Returns: 293 | A float: F_lcs 294 | 295 | Raises: 296 | ValueError: raises exception if a param has len <= 0 297 | """ 298 | if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0: 299 | raise ValueError("Collections must contain at least 1 sentence.") 300 | 301 | # total number of words in reference sentences 302 | m = len(_split_into_words(reference_sentences)) 303 | 304 | # total number of words in evaluated sentences 305 | n = len(_split_into_words(evaluated_sentences)) 306 | 307 | union_lcs_sum_across_all_references = 0 308 | for ref_s in reference_sentences: 309 | union_lcs_sum_across_all_references += _union_lcs(evaluated_sentences, 310 | ref_s) 311 | return _f_p_r_lcs(union_lcs_sum_across_all_references, m, n) 312 | 313 | 314 | def rouge(hypotheses, references): 315 | """Calculates average rouge scores for a list of hypotheses and 316 | references""" 317 | 318 | # Filter out hyps that are of 0 length 319 | # hyps_and_refs = zip(hypotheses, references) 320 | # hyps_and_refs = [_ for _ in hyps_and_refs if len(_[0]) > 0] 321 | # hypotheses, references = zip(*hyps_and_refs) 322 | 323 | # Calculate ROUGE-1 F1, precision, recall scores 324 | rouge_1 = [ 325 | rouge_n([hyp], [ref], 1) for hyp, ref in zip(hypotheses, references) 326 | ] 327 | rouge_1_f, rouge_1_p, rouge_1_r = map(np.mean, zip(*rouge_1)) 328 | 329 | # Calculate ROUGE-2 F1, 
precision, recall scores 330 | rouge_2 = [ 331 | rouge_n([hyp], [ref], 2) for hyp, ref in zip(hypotheses, references) 332 | ] 333 | rouge_2_f, rouge_2_p, rouge_2_r = map(np.mean, zip(*rouge_2)) 334 | 335 | # Calculate ROUGE-L F1, precision, recall scores 336 | rouge_l = [ 337 | rouge_l_sentence_level([hyp], [ref]) 338 | for hyp, ref in zip(hypotheses, references) 339 | ] 340 | rouge_l_f, rouge_l_p, rouge_l_r = map(np.mean, zip(*rouge_l)) 341 | 342 | return { 343 | "rouge_1/f_score": rouge_1_f, 344 | "rouge_1/r_score": rouge_1_r, 345 | "rouge_1/p_score": rouge_1_p, 346 | "rouge_2/f_score": rouge_2_f, 347 | "rouge_2/r_score": rouge_2_r, 348 | "rouge_2/p_score": rouge_2_p, 349 | "rouge_l/f_score": rouge_l_f, 350 | "rouge_l/r_score": rouge_l_r, 351 | "rouge_l/p_score": rouge_l_p, 352 | } 353 | -------------------------------------------------------------------------------- /Pun_Generation/code/utils/iterator_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for iterator_utils.py""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import tensorflow as tf 23 | 24 | from tensorflow.python.ops import lookup_ops 25 | 26 | from ..utils import iterator_utils 27 | 28 | 29 | class IteratorUtilsTest(tf.test.TestCase): 30 | 31 | def testGetIterator(self): 32 | tf.set_random_seed(1) 33 | tgt_vocab_table = src_vocab_table = lookup_ops.index_table_from_tensor( 34 | tf.constant(["a", "b", "c", "eos", "sos"])) 35 | src_dataset = tf.data.Dataset.from_tensor_slices( 36 | tf.constant(["f e a g", "c c a", "d", "c a"])) 37 | tgt_dataset = tf.data.Dataset.from_tensor_slices( 38 | tf.constant(["c c", "a b", "", "b c"])) 39 | hparams = tf.contrib.training.HParams( 40 | random_seed=3, 41 | num_buckets=5, 42 | eos="eos", 43 | sos="sos") 44 | batch_size = 2 45 | src_max_len = 3 46 | iterator = iterator_utils.get_iterator( 47 | src_dataset=src_dataset, 48 | tgt_dataset=tgt_dataset, 49 | src_vocab_table=src_vocab_table, 50 | tgt_vocab_table=tgt_vocab_table, 51 | batch_size=batch_size, 52 | sos=hparams.sos, 53 | eos=hparams.eos, 54 | random_seed=hparams.random_seed, 55 | num_buckets=hparams.num_buckets, 56 | src_max_len=src_max_len, 57 | reshuffle_each_iteration=False) 58 | table_initializer = tf.tables_initializer() 59 | source = iterator.source 60 | target_input = iterator.target_input 61 | target_output = iterator.target_output 62 | src_seq_len = iterator.source_sequence_length 63 | tgt_seq_len = iterator.target_sequence_length 64 | self.assertEqual([None, None], source.shape.as_list()) 65 | self.assertEqual([None, None], target_input.shape.as_list()) 66 | self.assertEqual([None, None], target_output.shape.as_list()) 67 | self.assertEqual([None], 
src_seq_len.shape.as_list()) 68 | self.assertEqual([None], tgt_seq_len.shape.as_list()) 69 | with self.test_session() as sess: 70 | sess.run(table_initializer) 71 | sess.run(iterator.initializer) 72 | 73 | (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v) = ( 74 | sess.run((source, src_seq_len, target_input, target_output, 75 | tgt_seq_len))) 76 | self.assertAllEqual( 77 | [[-1, -1, 0], # "f" == unknown, "e" == unknown, a 78 | [2, 0, 3]], # c a eos -- eos is padding 79 | source_v) 80 | self.assertAllEqual([3, 2], src_len_v) 81 | self.assertAllEqual( 82 | [[4, 2, 2], # sos c c 83 | [4, 1, 2]], # sos b c 84 | target_input_v) 85 | self.assertAllEqual( 86 | [[2, 2, 3], # c c eos 87 | [1, 2, 3]], # b c eos 88 | target_output_v) 89 | self.assertAllEqual([3, 3], tgt_len_v) 90 | 91 | (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v) = ( 92 | sess.run((source, src_seq_len, target_input, target_output, 93 | tgt_seq_len))) 94 | self.assertAllEqual( 95 | [[2, 2, 0]], # c c a 96 | source_v) 97 | self.assertAllEqual([3], src_len_v) 98 | self.assertAllEqual( 99 | [[4, 0, 1]], # sos a b 100 | target_input_v) 101 | self.assertAllEqual( 102 | [[0, 1, 3]], # a b eos 103 | target_output_v) 104 | self.assertAllEqual([3], tgt_len_v) 105 | 106 | with self.assertRaisesOpError("End of sequence"): 107 | sess.run(source) 108 | 109 | def testGetIteratorWithShard(self): 110 | tf.set_random_seed(1) 111 | tgt_vocab_table = src_vocab_table = lookup_ops.index_table_from_tensor( 112 | tf.constant(["a", "b", "c", "eos", "sos"])) 113 | src_dataset = tf.data.Dataset.from_tensor_slices( 114 | tf.constant(["c c a", "f e a g", "d", "c a"])) 115 | tgt_dataset = tf.data.Dataset.from_tensor_slices( 116 | tf.constant(["a b", "c c", "", "b c"])) 117 | hparams = tf.contrib.training.HParams( 118 | random_seed=3, 119 | num_buckets=5, 120 | eos="eos", 121 | sos="sos") 122 | batch_size = 2 123 | src_max_len = 3 124 | iterator = iterator_utils.get_iterator( 125 | src_dataset=src_dataset, 126 | tgt_dataset=tgt_dataset, 127 | src_vocab_table=src_vocab_table, 128 | tgt_vocab_table=tgt_vocab_table, 129 | batch_size=batch_size, 130 | sos=hparams.sos, 131 | eos=hparams.eos, 132 | random_seed=hparams.random_seed, 133 | num_buckets=hparams.num_buckets, 134 | src_max_len=src_max_len, 135 | num_shards=2, 136 | shard_index=1, 137 | reshuffle_each_iteration=False) 138 | table_initializer = tf.tables_initializer() 139 | source = iterator.source 140 | target_input = iterator.target_input 141 | target_output = iterator.target_output 142 | src_seq_len = iterator.source_sequence_length 143 | tgt_seq_len = iterator.target_sequence_length 144 | self.assertEqual([None, None], source.shape.as_list()) 145 | self.assertEqual([None, None], target_input.shape.as_list()) 146 | self.assertEqual([None, None], target_output.shape.as_list()) 147 | self.assertEqual([None], src_seq_len.shape.as_list()) 148 | self.assertEqual([None], tgt_seq_len.shape.as_list()) 149 | with self.test_session() as sess: 150 | sess.run(table_initializer) 151 | sess.run(iterator.initializer) 152 | 153 | (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v) = ( 154 | sess.run((source, src_seq_len, target_input, target_output, 155 | tgt_seq_len))) 156 | self.assertAllEqual( 157 | [[-1, -1, 0], # "f" == unknown, "e" == unknown, a 158 | [2, 0, 3]], # c a eos -- eos is padding 159 | source_v) 160 | self.assertAllEqual([3, 2], src_len_v) 161 | self.assertAllEqual( 162 | [[4, 2, 2], # sos c c 163 | [4, 1, 2]], # sos b c 164 | target_input_v) 165 | 
self.assertAllEqual( 166 | [[2, 2, 3], # c c eos 167 | [1, 2, 3]], # b c eos 168 | target_output_v) 169 | self.assertAllEqual([3, 3], tgt_len_v) 170 | 171 | with self.assertRaisesOpError("End of sequence"): 172 | sess.run(source) 173 | 174 | def testGetIteratorWithSkipCount(self): 175 | tf.set_random_seed(1) 176 | tgt_vocab_table = src_vocab_table = lookup_ops.index_table_from_tensor( 177 | tf.constant(["a", "b", "c", "eos", "sos"])) 178 | src_dataset = tf.data.Dataset.from_tensor_slices( 179 | tf.constant(["c a", "c c a", "d", "f e a g"])) 180 | tgt_dataset = tf.data.Dataset.from_tensor_slices( 181 | tf.constant(["b c", "a b", "", "c c"])) 182 | hparams = tf.contrib.training.HParams( 183 | random_seed=3, 184 | num_buckets=5, 185 | eos="eos", 186 | sos="sos") 187 | batch_size = 2 188 | src_max_len = 3 189 | skip_count = tf.placeholder(shape=(), dtype=tf.int64) 190 | iterator = iterator_utils.get_iterator( 191 | src_dataset=src_dataset, 192 | tgt_dataset=tgt_dataset, 193 | src_vocab_table=src_vocab_table, 194 | tgt_vocab_table=tgt_vocab_table, 195 | batch_size=batch_size, 196 | sos=hparams.sos, 197 | eos=hparams.eos, 198 | random_seed=hparams.random_seed, 199 | num_buckets=hparams.num_buckets, 200 | src_max_len=src_max_len, 201 | skip_count=skip_count, 202 | reshuffle_each_iteration=False) 203 | table_initializer = tf.tables_initializer() 204 | source = iterator.source 205 | target_input = iterator.target_input 206 | target_output = iterator.target_output 207 | src_seq_len = iterator.source_sequence_length 208 | tgt_seq_len = iterator.target_sequence_length 209 | self.assertEqual([None, None], source.shape.as_list()) 210 | self.assertEqual([None, None], target_input.shape.as_list()) 211 | self.assertEqual([None, None], target_output.shape.as_list()) 212 | self.assertEqual([None], src_seq_len.shape.as_list()) 213 | self.assertEqual([None], tgt_seq_len.shape.as_list()) 214 | with self.test_session() as sess: 215 | sess.run(table_initializer) 216 | sess.run(iterator.initializer, feed_dict={skip_count: 3}) 217 | 218 | (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v) = ( 219 | sess.run((source, src_seq_len, target_input, target_output, 220 | tgt_seq_len))) 221 | self.assertAllEqual( 222 | [[-1, -1, 0]], # "f" == unknown, "e" == unknown, a 223 | source_v) 224 | self.assertAllEqual([3], src_len_v) 225 | self.assertAllEqual( 226 | [[4, 2, 2]], # sos c c 227 | target_input_v) 228 | self.assertAllEqual( 229 | [[2, 2, 3]], # c c eos 230 | target_output_v) 231 | self.assertAllEqual([3], tgt_len_v) 232 | 233 | with self.assertRaisesOpError("End of sequence"): 234 | sess.run(source) 235 | 236 | # Re-init iterator with skip_count=0. 
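# The pair ("d", "") never appears in the batches below (its target side is empty),
# so the remaining three examples arrive in two batches.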
237 | sess.run(iterator.initializer, feed_dict={skip_count: 0}) 238 | 239 | (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v) = ( 240 | sess.run((source, src_seq_len, target_input, target_output, 241 | tgt_seq_len))) 242 | self.assertAllEqual( 243 | [[2, 0, 3], # c a eos -- eos is padding 244 | [-1, -1, 0]], # "f" == unknown, "e" == unknown, a 245 | source_v) 246 | self.assertAllEqual([2, 3], src_len_v) 247 | self.assertAllEqual( 248 | [[4, 1, 2], # sos b c 249 | [4, 2, 2]], # sos c c 250 | target_input_v) 251 | self.assertAllEqual( 252 | [[1, 2, 3], # b c eos 253 | [2, 2, 3]], # c c eos 254 | target_output_v) 255 | self.assertAllEqual([3, 3], tgt_len_v) 256 | 257 | (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v) = ( 258 | sess.run((source, src_seq_len, target_input, target_output, 259 | tgt_seq_len))) 260 | self.assertAllEqual( 261 | [[2, 2, 0]], # c c a 262 | source_v) 263 | self.assertAllEqual([3], src_len_v) 264 | self.assertAllEqual( 265 | [[4, 0, 1]], # sos a b 266 | target_input_v) 267 | self.assertAllEqual( 268 | [[0, 1, 3]], # a b eos 269 | target_output_v) 270 | self.assertAllEqual([3], tgt_len_v) 271 | 272 | with self.assertRaisesOpError("End of sequence"): 273 | sess.run(source) 274 | 275 | 276 | def testGetInferIterator(self): 277 | src_vocab_table = lookup_ops.index_table_from_tensor( 278 | tf.constant(["a", "b", "c", "eos", "sos"])) 279 | src_dataset = tf.data.Dataset.from_tensor_slices( 280 | tf.constant(["c c a", "c a", "d", "f e a g"])) 281 | hparams = tf.contrib.training.HParams( 282 | random_seed=3, 283 | eos="eos", 284 | sos="sos") 285 | batch_size = 2 286 | src_max_len = 3 287 | iterator = iterator_utils.get_infer_iterator( 288 | src_dataset=src_dataset, 289 | src_vocab_table=src_vocab_table, 290 | batch_size=batch_size, 291 | eos=hparams.eos, 292 | src_max_len=src_max_len) 293 | table_initializer = tf.tables_initializer() 294 | source = iterator.source 295 | seq_len = iterator.source_sequence_length 296 | self.assertEqual([None, None], source.shape.as_list()) 297 | self.assertEqual([None], seq_len.shape.as_list()) 298 | with self.test_session() as sess: 299 | sess.run(table_initializer) 300 | sess.run(iterator.initializer) 301 | 302 | (source_v, seq_len_v) = sess.run((source, seq_len)) 303 | self.assertAllEqual( 304 | [[2, 2, 0], # c c a 305 | [2, 0, 3]], # c a eos 306 | source_v) 307 | self.assertAllEqual([3, 2], seq_len_v) 308 | 309 | (source_v, seq_len_v) = sess.run((source, seq_len)) 310 | self.assertAllEqual( 311 | [[-1, 3, 3], # "d" == unknown, eos eos 312 | [-1, -1, 0]], # "f" == unknown, "e" == unknown, a 313 | source_v) 314 | self.assertAllEqual([1, 3], seq_len_v) 315 | 316 | with self.assertRaisesOpError("End of sequence"): 317 | sess.run((source, seq_len)) 318 | 319 | 320 | if __name__ == "__main__": 321 | tf.test.main() 322 | --------------------------------------------------------------------------------
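Usage note: both copies of utils/scripts/rouge.py above expose the same module-level entry points. The sketch below is a minimal, assumed usage, not part of the repository: the bare "rouge" import path and the sample sentences are illustrative only. Hypotheses and references are parallel lists of space-tokenized sentence strings, and rouge() returns the nine-key score dictionary built at the end of the file.

    # Minimal sketch; assumes rouge.py is importable from the working directory.
    from rouge import rouge, rouge_n, rouge_l_sentence_level

    hypotheses = ["the boxer threw in the towel"]
    references = ["the tired boxer threw in the towel"]

    # Corpus-level averages, keyed as "rouge_1/f_score", "rouge_2/p_score", etc.
    scores = rouge(hypotheses, references)
    print(scores["rouge_1/f_score"], scores["rouge_l/f_score"])

    # The per-collection helpers can also be called directly.
    f1, precision, recall = rouge_n(hypotheses, references, n=2)
    f_lcs, p_lcs, r_lcs = rouge_l_sentence_level(hypotheses, references)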