├── LICENSE ├── README.md ├── similarity_estimator ├── networks.py ├── options.py ├── sick_extender.py ├── sim_util.py ├── testing.py └── training.py └── utils ├── data_server.py ├── init_and_storage.py └── parameter_initialization.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 demelin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Siamese Sentence Similarity Classifier for pyTorch 2 | 3 | ## Overview 4 | This repository contains a re-implementation of Mueller's et al., ["Siamese Recurrent Architectures for Learning Sentence Similarity."](https://www.aaai.org/ocs/index.php/AAAI/AAAI16/paper/viewFile/12195/12023) (AAAI, 2016). For the technical details, please refer to the publication. 5 | 6 | ## Training 7 | To train the classifier, execute `similarity_estimator/training.py` after modifying the hard-coded values (such as the training corpus filename) to your own specifications. 8 | 9 | ## Evaluation 10 | To evaluate the performance of a trained model, run the `similarity_estimator/testing.py` script. Again, adjust user-specific values as needed within the script itself. 11 | 12 | ## Note 13 | This re-implementation was completed with personal use in mind and is, as such, not actively maintained. You are, however, very welcome to extend or adjust it according to your own needs, should you find it useful. Happy coding :) . 14 | -------------------------------------------------------------------------------- /similarity_estimator/networks.py: -------------------------------------------------------------------------------- 1 | """ An implementation of the siamese RNN for sentence similarity classification outlined in Mueller et al., 2 | "Siamese Recurrent Architectures for Learning Sentence Similarity." """ 3 | 4 | import os 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from torch.autograd import Variable 10 | from torch.nn.utils import clip_grad_norm 11 | 12 | # Inference with SVR 13 | import pickle 14 | 15 | from utils.parameter_initialization import xavier_normal 16 | 17 | 18 | class LSTMEncoder(nn.Module): 19 | """ Implements the network type integrated within the Siamese RNN architecture. 
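Despite the variable names in initialize_hidden_plus_cell, the initial hidden and cell states are drawn from a standard normal distribution (torch.randn) rather than zeros, and the final hidden state after the last time step serves as the sentence representation.

    Illustrative usage (a sketch, not part of the original code; it assumes an options object exposing
    embedding_dims, hidden_dims and num_layers as defined in options.py, and a Variable `batch` wrapping a
    LongTensor of token indices with shape [seq_len, batch_size]):

        encoder = LSTMEncoder(vocab_size=5000, opt=opt)
        hidden, cell = encoder.initialize_hidden_plus_cell(batch_size=16)
        for t in range(batch.size(0)):
            output, hidden, cell = encoder(16, batch[t, :], hidden, cell)
        sentence_encoding = hidden.squeeze()  # shape: [batch_size, hidden_dims]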
""" 20 | def __init__(self, vocab_size, opt, is_train=False): 21 | super(LSTMEncoder, self).__init__() 22 | self.vocab_size = vocab_size 23 | self.opt = opt 24 | self.name = 'sim_encoder' 25 | 26 | # Layers 27 | self.embedding_table = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=self.opt.embedding_dims, 28 | padding_idx=0, max_norm=None, scale_grad_by_freq=False, sparse=False) 29 | self.lstm_rnn = nn.LSTM(input_size=self.opt.embedding_dims, hidden_size=self.opt.hidden_dims, num_layers=1) 30 | 31 | def initialize_hidden_plus_cell(self, batch_size): 32 | """ Re-initializes the hidden state, cell state, and the forget gate bias of the network. """ 33 | zero_hidden = Variable(torch.randn(1, batch_size, self.opt.hidden_dims)) 34 | zero_cell = Variable(torch.randn(1, batch_size, self.opt.hidden_dims)) 35 | return zero_hidden, zero_cell 36 | 37 | def forward(self, batch_size, input_data, hidden, cell): 38 | """ Performs a forward pass through the network. """ 39 | output = self.embedding_table(input_data).view(1, batch_size, -1) 40 | for _ in range(self.opt.num_layers): 41 | output, (hidden, cell) = self.lstm_rnn(output, (hidden, cell)) 42 | return output, hidden, cell 43 | 44 | 45 | class SiameseClassifier(nn.Module): 46 | """ Sentence similarity estimator implementing a siamese arcitecture. Uses pretrained word2vec embeddings. 47 | Different to the paper, the weights are untied, to avoid exploding/ vanishing gradients. """ 48 | def __init__(self, vocab_size, opt, pretrained_embeddings=None, is_train=False): 49 | super(SiameseClassifier, self).__init__() 50 | self.opt = opt 51 | # Initialize constituent network 52 | self.encoder_a = self.encoder_b = LSTMEncoder(vocab_size, self.opt, is_train) 53 | # Initialize pre-trained embeddings, if given 54 | if pretrained_embeddings is not None: 55 | self.encoder_a.embedding_table.weight.data.copy_(pretrained_embeddings) 56 | # Initialize network parameters 57 | self.initialize_parameters() 58 | # Declare loss function 59 | self.loss_function = nn.MSELoss() 60 | # Initialize network optimizers 61 | self.optimizer_a = optim.Adam(self.encoder_a.parameters(), lr=self.opt.learning_rate, 62 | betas=(self.opt.beta_1, 0.999)) 63 | self.optimizer_b = optim.Adam(self.encoder_a.parameters(), lr=self.opt.learning_rate, 64 | betas=(self.opt.beta_1, 0.999)) 65 | 66 | def forward(self): 67 | """ Performs a single forward pass through the siamese architecture. 
""" 68 | # Checkpoint the encoder state 69 | state_dict = self.encoder_a.state_dict() 70 | 71 | # Obtain the input length (each batch consists of padded sentences) 72 | input_length = self.batch_a.size(0) 73 | 74 | # Obtain sentence encodings from each encoder 75 | hidden_a, cell_a = self.encoder_a.initialize_hidden_plus_cell(self.batch_size) 76 | for t_i in range(input_length): 77 | output_a, hidden_a, cell_a = self.encoder_a(self.batch_size, self.batch_a[t_i, :], hidden_a, cell_a) 78 | 79 | # Restore checkpoint to establish weight-sharing 80 | self.encoder_b.load_state_dict(state_dict) 81 | hidden_b, cell_b = self.encoder_b.initialize_hidden_plus_cell(self.batch_size) 82 | for t_j in range(input_length): 83 | output_b, hidden_b, cell_b = self.encoder_b(self.batch_size, self.batch_b[t_j, :], hidden_b, cell_b) 84 | 85 | # Format sentence encodings as 2D tensors 86 | self.encoding_a = hidden_a.squeeze() 87 | self.encoding_b = hidden_b.squeeze() 88 | 89 | # Obtain similarity score predictions by calculating the Manhattan distance between sentence encodings 90 | if self.batch_size == 1: 91 | self.prediction = torch.exp(-torch.norm((self.encoding_a - self.encoding_b), 1)) 92 | else: 93 | self.prediction = torch.exp(-torch.norm((self.encoding_a - self.encoding_b), 1, 1)) 94 | 95 | def get_loss(self): 96 | """ Calculates the MSE loss between the network predictions and the ground truth. """ 97 | # Loss is the L1 norm of the difference between the obtained sentence encodings 98 | self.loss = self.loss_function(self.prediction, self.labels) 99 | 100 | def load_pretrained_parameters(self): 101 | """ Loads the parameters learned during the pre-training on the SemEval data. """ 102 | pretrained_state_dict_path = os.path.join(self.opt.pretraining_dir, self.opt.pretrained_state_dict) 103 | self.encoder_a.load_state_dict(torch.load(pretrained_state_dict_path)) 104 | print('Pretrained parameters have been successfully loaded into the encoder networks.') 105 | 106 | def initialize_parameters(self): 107 | """ Initializes network parameters. """ 108 | state_dict = self.encoder_a.state_dict() 109 | for key in state_dict.keys(): 110 | if '.weight' in key: 111 | state_dict[key] = xavier_normal(state_dict[key]) 112 | if '.bias' in key: 113 | bias_length = state_dict[key].size()[0] 114 | start, end = bias_length // 4, bias_length // 2 115 | state_dict[key][start:end].fill_(2.5) 116 | self.encoder_a.load_state_dict(state_dict) 117 | 118 | def train_step(self, train_batch_a, train_batch_b, train_labels): 119 | """ Optimizes the parameters of the active networks, i.e. performs a single training step. """ 120 | # Get batches 121 | self.batch_a = train_batch_a 122 | self.batch_b = train_batch_b 123 | self.labels = train_labels 124 | 125 | # Get batch_size for current batch 126 | self.batch_size = self.batch_a.size(1) 127 | 128 | # Get gradients 129 | self.forward() 130 | self.encoder_a.zero_grad() # encoder_a == encoder_b 131 | self.get_loss() 132 | self.loss.backward() 133 | 134 | # Clip gradients 135 | clip_grad_norm(self.encoder_a.parameters(), self.opt.clip_value) 136 | 137 | # Optimize 138 | self.optimizer_a.step() 139 | 140 | def test_step(self, test_batch_a, test_batch_b, test_labels): 141 | """ Performs a single test step. 
""" 142 | # Get batches 143 | self.batch_a = test_batch_a 144 | self.batch_b = test_batch_b 145 | self.labels = test_labels 146 | 147 | # Get batch_size for current batch 148 | self.batch_size = self.batch_a.size(1) 149 | 150 | svr_path = os.path.join(self.opt.save_dir, 'sim_svr.pkl') 151 | if os.path.exists(svr_path): 152 | # Correct predictions via trained SVR 153 | with open(svr_path, 'rb') as f: 154 | sim_svr = pickle.load(f) 155 | self.forward() 156 | self.prediction = Variable(torch.FloatTensor(sim_svr.predict(self.prediction.view(-1, 1).data.numpy()))) 157 | 158 | else: 159 | self.forward() 160 | 161 | self.get_loss() 162 | -------------------------------------------------------------------------------- /similarity_estimator/options.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | class TestingOptions(object): 5 | """ Default options for the siamese similarity estimator network. Use for quick evaluation on home machine. """ 6 | 7 | def __init__(self): 8 | # Data 9 | self.max_sent_len = None 10 | self.pad = True 11 | self.freq_bound = 3 12 | self.shuffle = True 13 | self.sent_select = 'truncate' 14 | self.lower = False 15 | self.num_buckets = 3 16 | 17 | # Network 18 | self.embedding_dims = 300 19 | self.hidden_dims = 50 20 | self.num_layers = 1 21 | self.train_batch_size = 16 22 | self.test_batch_size = 1 23 | self.clip_value = 0.25 24 | self.learning_rate = 0.0001 25 | self.beta_1 = 0.5 26 | 27 | self.pre_training = True 28 | self.num_epochs = 100 29 | 30 | self.start_early_stopping = 2 31 | self.patience = 10 32 | self.start_annealing = 4 33 | self.annealing_factor = 0.75 34 | 35 | # Training 36 | self.report_freq = 1 37 | self.save_freq = 4 38 | self.home_dir = os.path.join(os.path.dirname(__file__), '..') 39 | self.data_dir = os.path.join(self.home_dir, 'data') 40 | self.save_dir = os.path.join(self.home_dir, 'similarity_estimator/models') 41 | self.pretraining_dir = os.path.join(self.save_dir, 'pretraining') 42 | 43 | # Testing 44 | self.num_test_samples = 10 45 | 46 | 47 | class ClusterOptions(object): 48 | """ Default options for the siamese similarity estimator network. Use for deployment on cluster. 
""" 49 | 50 | def __init__(self): 51 | # Data 52 | self.max_sent_len = None 53 | self.pad = True 54 | self.freq_bound = 3 55 | self.shuffle = True 56 | self.sent_select = 'truncate' 57 | self.lower = False 58 | self.num_buckets = 8 59 | 60 | # Network 61 | self.embedding_dims = 300 62 | self.hidden_dims = 50 63 | self.num_layers = 1 64 | self.train_batch_size = 16 65 | self.test_batch_size = 1 66 | self.clip_value = 0.25 # Following pyTorch LM example's default value 67 | self.learning_rate = 0.0001 68 | self.beta_1 = 0.5 69 | 70 | # Training 71 | self.pre_training = True 72 | self.num_epochs = 1000 73 | 74 | # Mostly arbitrary values from here on; for a more informed approach, consult Early Stopping paper 75 | self.start_early_stopping = self.num_epochs // 20 76 | self.patience = self.num_epochs // 50 77 | self.start_annealing = self.num_epochs // 100 78 | self.annealing_factor = 0.75 79 | 80 | self.report_freq = 100 81 | self.save_freq = 20 82 | self.home_dir = os.path.join(os.path.dirname(__file__), '..') 83 | self.data_dir = os.path.join(self.home_dir, 'data') 84 | self.save_dir = os.path.join(self.home_dir, 'similarity_estimator/models') 85 | self.pretraining_dir = os.path.join(self.save_dir, 'pretraining') 86 | 87 | # Testing 88 | self.num_test_samples = 10 89 | -------------------------------------------------------------------------------- /similarity_estimator/sick_extender.py: -------------------------------------------------------------------------------- 1 | """ Generates synthetic data from corpora consisting of individual sentences, such as the SICK corpus by replacing 2 | random words in each sentence with one of their synonyms found in WordNet. Implemented extension strategy owes to: 3 | [1] Mueller et al., "Siamese Recurrent Architectures for Learning Sentence Similarity." 4 | [2] Zhang et al., "Character-level convolutional networks for text classification." 5 | The extensions are, as expected, reasonably noisy. 6 | """ 7 | 8 | import os 9 | import numpy as np 10 | import pandas as pd 11 | 12 | import nltk 13 | from nltk import word_tokenize 14 | from nltk.corpus import wordnet 15 | 16 | from pywsd.lesk import simple_lesk, cosine_lesk, adapted_lesk 17 | import kenlm 18 | 19 | 20 | class SickExtender(object): 21 | """ Extends the SICK sentence similarity corpus with synthetic data generated by substituting synonyms for 22 | random content words. Synonyms are obtained via WordNet's synset.lemmas() lookup following the sense disambiguation 23 | of the word to be replaced which, in turn, relies on the specified Lesk algorithm - simple, cosine, or adapted. 24 | Refer to the pywsd documentation for further information. 
""" 25 | def __init__(self, sick_path, target_directory, lm_path=None, wsd_algorithm='cosine', sampling_parameter=0.5, 26 | min_substitutions=2, num_candidates=5, concatenate_corpora=True): 27 | self.sick_path = sick_path 28 | self.target_directory = target_directory 29 | self.lm_path = lm_path 30 | self.wsd_algorithm = wsd_algorithm 31 | self.sampling_parameter = sampling_parameter 32 | self.min_substitutions = min_substitutions 33 | self.num_candidates = num_candidates 34 | self.concatenate_corpora = concatenate_corpora 35 | self.filtered_path = os.path.join(self.target_directory, 'filtered_sick.txt') 36 | self.noscore_path = os.path.join(self.target_directory, 'noscore_sick.txt') 37 | # Filter the original SICK corpus to match the expected format, and create file for LM training 38 | if not os.path.exists(self.filtered_path) or not os.path.exists(self.noscore_path): 39 | self.filter_sick() 40 | if self.lm_path is None: 41 | raise ValueError('No language model provided! Use the noscore_sick corpus to train an .klm LM, first.') 42 | else: 43 | self.language_model = kenlm.LanguageModel(self.lm_path) 44 | 45 | def create_extension(self): 46 | """ Replaces some words within each line of the given file with their WordNet synonyms. Replacement 47 | limited to noun, verb, adj, and adv, as those are the POS tags utilized by WordNet.""" 48 | # Track the proportion of the corpus already processed 49 | counter = 0 50 | # Create path to the SICK extension corpus 51 | if self.concatenate_corpora: 52 | target_path = os.path.join(self.target_directory, 'extended_sick.txt') 53 | else: 54 | target_path = os.path.join(self.target_directory, 'sick_extension.txt') 55 | # Generate paraphrases via thesaurus-based replacement 56 | print('Commencing with the creation of the synthetic SICK examples.') 57 | with open(self.filtered_path, 'r') as rf: 58 | with open(target_path, 'w') as wf: 59 | for line in rf: 60 | # Get tokens and POS tags, i.e. sentences == [sent1, sent2] 61 | sentences, sim_score = self.line_prep(line) 62 | new_line = list() 63 | for sentence in sentences: 64 | # Store tokens for subsequent reconstruction 65 | tokens = sentence[1] 66 | # Get the most likely synset for each token 67 | disambiguation = self.disambiguate_synset(sentence) 68 | # Replace random words with random synonyms 69 | candidate_list = self.replace_with_synonyms(disambiguation) 70 | if candidate_list is None: 71 | continue 72 | paraphrase = self.pick_candidate(tokens, candidate_list) 73 | new_line.append(paraphrase) 74 | # If nothing could be replaced in either sentence, skip the sentence pair 75 | if len(new_line) < 2: 76 | continue 77 | # Add header 78 | # wf.write('sentence_A\tsentence_B\trelatedness_score') 79 | if self.concatenate_corpora: 80 | wf.write(line) 81 | wf.write(new_line[0] + '\t' + new_line[1] + '\t' + sim_score) 82 | else: 83 | wf.write(new_line[0] + '\t' + new_line[1] + '\t' + sim_score) 84 | 85 | # Basic bookkeeping 86 | counter += 1 87 | if counter % 50 == 0 and counter != 0: 88 | print('Current progress: Line %d.' % counter) 89 | 90 | # For quick testing 91 | # if counter % 50 == 0 and counter != 0: 92 | # break 93 | 94 | print('The extension sentences for the SICK corpus has been successfully generated.\n' 95 | 'It can be found under %s.\n' 96 | 'Total amount of new sentence pairs: %d.' % (target_path, counter)) 97 | 98 | def filter_sick(self): 99 | """ Processes the original S.I.C.K. corpus into a format where each line contains the two compared sentences 100 | followed by their relatedness score. 
""" 101 | # Filter the SICK dataset for sentences and relatedness score only 102 | df_origin = pd.read_table(self.sick_path) 103 | df_classify = df_origin.loc[:, ['sentence_A', 'sentence_B', 'relatedness_score']] 104 | # Scale relatedness score to to lie ∈ [0, 1] for training of the classifier 105 | df_classify['relatedness_score'] = df_classify['relatedness_score'].apply( 106 | lambda x: "{:.4f}".format(float(x)/5.0)) 107 | 108 | df_noscore = df_origin.loc[:, ['sentence_A', 'sentence_B']] 109 | df_noscore = df_noscore.stack() 110 | 111 | # Write the filtered set to a .csv file 112 | df_classify.to_csv(self.filtered_path, sep='\t', index=False, header=False) 113 | print('Filtered corpus saved to %s.' % self.filtered_path) 114 | 115 | # Write a score-free set to a .csv file to be used in the training of the KN language model 116 | df_noscore.to_csv(self.noscore_path, index=False, header=False) 117 | print('Filtered corpus saved to %s.' % self.noscore_path) 118 | 119 | def line_prep(self, line): 120 | """ Tokenizes and POS-tags a line from the SICK corpus to be compatible with WordNet synset lookup. """ 121 | # Split line into sentences + score 122 | s1, s2, sim_score = line.split('\t') 123 | # Tokenize 124 | s1_tokens = word_tokenize(s1) 125 | s2_tokens = word_tokenize(s2) 126 | # Assign part of speech tags 127 | s1_penn_pos = nltk.pos_tag(s1_tokens) 128 | s2_penn_pos = nltk.pos_tag(s2_tokens) 129 | # Convert to WordNet POS tags and store word position in sentence for replacement 130 | # Each tuple contains (word, WordNet_POS_tag, position) 131 | s1_wn_pos = list() 132 | s2_wn_pos = list() 133 | for idx, item in enumerate(s1_penn_pos): 134 | if self.get_wordnet_pos(item[1]) != 'OTHER': 135 | s1_wn_pos.append((item[0], self.get_wordnet_pos(item[1]), s1_penn_pos.index(item))) 136 | for idx, item in enumerate(s2_penn_pos): 137 | if self.get_wordnet_pos(item[1]) != 'OTHER': 138 | s2_wn_pos.append((item[0], self.get_wordnet_pos(item[1]), s2_penn_pos.index(item))) 139 | 140 | # Each tuple contains (word, WordNet_POS_tag, position); Source sentence provided for use in disambiguation 141 | return [(s1_wn_pos, s1_tokens), (s2_wn_pos, s2_tokens)], sim_score 142 | 143 | def disambiguate_synset(self, sentence_plus_lemmas): 144 | """ Picks the most likely synset for a lemma provided the context sentence and target word. Utilizes 145 | the 'Cosine Lesk' algorithm provided by pywds. """ 146 | # Select the disambiguation algorithm 147 | if self.wsd_algorithm == 'simple': 148 | wsd_function = simple_lesk 149 | elif self.wsd_algorithm == 'cosine': 150 | wsd_function = cosine_lesk 151 | elif self.wsd_algorithm == 'adapted': 152 | wsd_function = adapted_lesk 153 | else: 154 | raise ValueError('Please specify the word sense disambiguation algorithm:\n ' 155 | '\'simple\' for Simple Lesk\n' 156 | '\'cosine\' for Cosine Lesk\n' 157 | '\'adapted\' for Adapted/Extended Lesk') 158 | 159 | lemmas, context = sentence_plus_lemmas 160 | context = ' '.join(context) 161 | disambiguated = list() 162 | for lemma in lemmas: 163 | try: 164 | selection = wsd_function(context, lemma[0], pos=lemma[1]) 165 | # For simple_lesk disambiguation algorithm, in case no synsets can be found 166 | except IndexError: 167 | selection = None 168 | disambiguated.append((lemma[0], selection, lemma[2])) 169 | return disambiguated 170 | 171 | def replace_with_synonyms(self, disambiguated_lemmas): 172 | """ Calculates the distance between a lemma and each of its synonyms and orders them in a list by increasing 173 | distance. 
Uses the """ 174 | all_synonyms = list() 175 | # Obtain WordNet synonyms for each lemma in the sentence list 176 | for idx, lemma in enumerate(disambiguated_lemmas): 177 | if lemma[1] is not None: 178 | if len(lemma[1].lemma_names()) > 1: 179 | synonyms_per_word = ([' '.join(s.split('_')) for s in lemma[1].lemma_names()], idx) 180 | all_synonyms.append(synonyms_per_word) 181 | 182 | # If the sentence cannot be modified, skip it 183 | if len(all_synonyms) == 0: 184 | return None 185 | 186 | # Model a geometric distribution with parameter p, following Zhang, Zhao, and LeCun (2015) 187 | lower_bound = max(min(self.min_substitutions, len(all_synonyms)), 1) 188 | distribution = {i: self.sampling_parameter ** i for i in range(lower_bound, len(all_synonyms) + 1)} 189 | sampling_array = list() 190 | position = 0 191 | for key in distribution.keys(): 192 | occurrences = int(np.round(distribution[key] * 1000)) 193 | while occurrences != 0: 194 | sampling_array.append(key) 195 | position += 1 196 | occurrences -= 1 197 | 198 | # Sample n substitutions 199 | outputs = list() 200 | no_subs = [(l[0], l[2]) for l in disambiguated_lemmas] 201 | for _ in range(self.num_candidates): 202 | syn_list = all_synonyms[:] 203 | candidate = no_subs[:] 204 | # Randomly pick the amount of word to replace with synonyms 205 | pick = np.random.randint(0, len(sampling_array)) 206 | to_replace = sampling_array[pick] 207 | # Perform replacement 208 | for __ in range(to_replace): 209 | # Randomly pick the word to be replaced 210 | j = np.random.randint(0, len(syn_list)) 211 | # Randomly pick the synonym to replace the word with 212 | k = np.random.randint(0, len(syn_list[j][0])) 213 | candidate[syn_list[j][1]] = (syn_list[j][0][k], disambiguated_lemmas[syn_list[j][1]][2]) 214 | # Remove the sampled synonym set 215 | del(syn_list[j]) 216 | outputs.append(candidate) 217 | return outputs 218 | 219 | def pick_candidate(self, tokens, candidate_list): 220 | """ Picks the most probable paraprase candidate according to the provided language model. """ 221 | best_paraphrase = None 222 | best_nll = 0 223 | 224 | # Reconstruct and rate paraphrases 225 | for candidate in candidate_list: 226 | for replacement in candidate: 227 | tokens[replacement[1]] = replacement[0] 228 | paraphrase = ' '.join(tokens) 229 | score = self.language_model.score(paraphrase) 230 | # Keep the most probable one 231 | if abs(score) > best_nll: 232 | best_nll = score 233 | best_paraphrase = paraphrase 234 | 235 | return best_paraphrase 236 | 237 | @staticmethod 238 | def get_wordnet_pos(treebank_tag): 239 | """ Converts a Penn Tree-Bank part of speech tag into a corresponding WordNet-friendly tag. 240 | Borrowed from: http://stackoverflow.com/questions/15586721/wordnet-lemmatization-and-pos-tagging-in-python. """ 241 | if treebank_tag.startswith('J') or treebank_tag.startswith('A'): 242 | return wordnet.ADJ 243 | elif treebank_tag.startswith('V'): 244 | return wordnet.VERB 245 | elif treebank_tag.startswith('N'): 246 | return wordnet.NOUN 247 | elif treebank_tag.startswith('R'): 248 | return wordnet.ADV 249 | else: 250 | return 'OTHER' 251 | -------------------------------------------------------------------------------- /similarity_estimator/sim_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | 5 | class Indexer(object): 6 | """ Translates words to their respective indices and vice versa. 
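Index 0 is reserved for padding and index 1 for rare/out-of-vocabulary words (see DataServer.sent_to_idx in utils/data_server.py); regular words receive indices from 2 upwards in order of first occurrence.

    Illustrative usage (a sketch, not part of the original code):

        vocab = Indexer('toy_corpus')
        vocab.add_sentence('a man is playing a guitar')
        vocab.word_to_index['guitar']   # 6
        vocab.index_to_word[6]          # 'guitar'
        vocab.word_to_count['a']        # 2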
""" 7 | 8 | def __init__(self, name): 9 | self.name = name 10 | self.word_to_index = dict() 11 | self.word_to_count = dict() 12 | # Specify start-and-end-of-sentence tokens 13 | self.index_to_word = {0: '', 1: ''} 14 | self.n_words = 2 15 | 16 | self.target_len = None 17 | 18 | def add_sentence(self, sentence): 19 | """ Adds sentence contents to index dict. """ 20 | for word in sentence.split(): 21 | self.add_word(word) 22 | 23 | def add_word(self, word): 24 | """ Adds words to index dict. """ 25 | if word not in self.word_to_index: 26 | self.word_to_index[word] = self.n_words 27 | self.index_to_word[self.n_words] = word 28 | self.word_to_count[word] = 1 29 | self.n_words += 1 30 | else: 31 | self.word_to_count[word] += 1 32 | 33 | def set_target_len(self, value): 34 | self.target_len = value 35 | 36 | 37 | def perform_bucketing(opt, labeled_pair_list): 38 | """ Groups the provided sentence pairs into the specified number of buckets of similar size based on the length of 39 | their longest member. """ 40 | # Obtain sentence lengths 41 | sentence_pair_lens = [(len(pair[0].split()), len(pair[1].split())) for pair in labeled_pair_list[0]] 42 | 43 | # Calculate bucket size 44 | buckets = [[0, 0] for _ in range(opt.num_buckets)] 45 | avg_bucket = len(labeled_pair_list[0]) // opt.num_buckets 46 | max_lens = [max(pair[0], pair[1]) for pair in sentence_pair_lens] 47 | len_counts = [(sent_len, max_lens.count(sent_len)) for sent_len in set(max_lens)] 48 | len_counts.sort(key=lambda x: x[0]) 49 | 50 | bucket_pointer = 0 51 | len_pointer = 0 52 | 53 | while bucket_pointer < opt.num_buckets and len_pointer < len(len_counts): 54 | target_bucket = buckets[bucket_pointer] 55 | # Set lower limit on the bucket's lengths 56 | target_bucket[0] = len_counts[len_pointer][0] 57 | bucket_load = 0 58 | while True: 59 | try: 60 | len_count_pair = len_counts[len_pointer] 61 | deficit = avg_bucket - bucket_load 62 | surplus = (bucket_load + len_count_pair[1]) - avg_bucket 63 | if deficit >= surplus or bucket_pointer == opt.num_buckets - 1: 64 | bucket_load += len_count_pair[1] 65 | # Update upper limit on the bucket's lengths 66 | target_bucket[1] = len_count_pair[0] 67 | len_pointer += 1 68 | else: 69 | bucket_pointer += 1 70 | break 71 | except IndexError: 72 | break 73 | 74 | # Populate buckets 75 | bucketed = [([], []) for _ in range(opt.num_buckets)] 76 | for k in range(len(labeled_pair_list[0])): 77 | pair_len = max(sentence_pair_lens[k][0], sentence_pair_lens[k][1]) 78 | for l in range(len(buckets)): 79 | if buckets[l][0] <= pair_len <= buckets[l][1]: 80 | bucketed[l][0].append(labeled_pair_list[0][k]) 81 | bucketed[l][1].append(labeled_pair_list[1][k]) 82 | return buckets, bucketed 83 | 84 | 85 | def load_similarity_data(opt, corpus_location, corpus_name): 86 | """ Converts the extended SICK/ combined STM corpus into a list of tuples of the form (sent_a, sent_b, sim_class), 87 | used to train the content similarity estimator used within the tGAN model. 
""" 88 | # Read in the corpus 89 | df_sim = pd.read_table(corpus_location, header=None, names=['sentence_A', 'sentence_B', 'relatedness_score'], 90 | skip_blank_lines=True) 91 | 92 | # Generate corpus list of sentences and labels, and the collections of sentences used for the word to index mapping 93 | sim_data = [[], []] 94 | sim_sents = list() 95 | # Track sentence lengths for max and mean length calculations 96 | sent_lens = list() 97 | for i in range(len(df_sim['relatedness_score'])): 98 | sent_a = df_sim.iloc[i, 0].strip() 99 | sent_b = df_sim.iloc[i, 1].strip() 100 | label = "{:.4f}".format(float(df_sim.iloc[i, 2])) 101 | 102 | # Assemble a list of tuples containing the compared sentences, while tracking the maximum observed length 103 | sim_data[0].append((sent_a, sent_b)) 104 | sim_data[1].append(label) 105 | sim_sents += [sent_a, sent_b] 106 | sent_lens += [len(sent_a.split()), len(sent_b.split())] 107 | 108 | # Filter corpus according to specified sentence length parameters 109 | filtered = [[], []] 110 | filtered_sents = list() 111 | filtered_lens = list() 112 | 113 | # Sent filtering method to truncation by default (in case of anomalous input) 114 | if opt.sent_select == 'drop' or opt.sent_select == 'truncate' or opt.sent_select is None: 115 | sent_select = opt.sent_select 116 | else: 117 | sent_select = 'truncate' 118 | 119 | # Set filtering size to mean_len + (max_len - mean_len) // 2 by default 120 | observed_max_len = max(sent_lens) 121 | if opt.max_sent_len: 122 | target_len = opt.max_sent_len 123 | elif opt.sent_select is None: 124 | target_len = observed_max_len 125 | else: 126 | observed_mean_len = int(np.round(np.mean(sent_lens))) 127 | target_len = observed_mean_len + (observed_max_len - observed_mean_len) // 2 128 | 129 | for i in range(len(sim_data[0])): 130 | pair = sim_data[0][i] 131 | if len(pair[0].split()) > target_len or len(pair[1].split()) > target_len: 132 | if sent_select == 'drop': 133 | continue 134 | elif sent_select is None: 135 | pass 136 | else: 137 | pair_0 = ' '.join(pair[0].split()[:target_len]) 138 | pair_1 = ' '.join(pair[1].split()[:target_len]) 139 | pair = (pair_0, pair_1) 140 | 141 | filtered[0].append(pair) 142 | filtered[1].append(sim_data[1][i]) 143 | filtered_sents += [pair[0], pair[1]] 144 | filtered_lens.append((len(pair[0]), len(pair[1]))) 145 | 146 | # Generate SICK index dictionary and a list of pre-processed 147 | sim_vocab = Indexer(corpus_name) 148 | sim_vocab.set_target_len(target_len) 149 | 150 | print('Assembling index dictionary ...') 151 | for i in range(len(filtered_sents)): 152 | sim_vocab.add_sentence(filtered_sents[i]) 153 | # Summarize the final data 154 | print('Registered %d unique words for the %s corpus.\n' % (sim_vocab.n_words, sim_vocab.name)) 155 | return sim_vocab, filtered 156 | -------------------------------------------------------------------------------- /similarity_estimator/testing.py: -------------------------------------------------------------------------------- 1 | """ Tests the performance of the trained model by checking its predictive accuracy on n randomly sampled items. 
""" 2 | 3 | import os 4 | import pickle 5 | 6 | from similarity_estimator.networks import SiameseClassifier 7 | from similarity_estimator.options import TestingOptions, ClusterOptions 8 | from similarity_estimator.sim_util import load_similarity_data 9 | from utils.data_server import DataServer 10 | from utils.init_and_storage import load_network 11 | 12 | # Initialize training parameters 13 | opt = TestingOptions() 14 | # Obtain data 15 | extended_corpus_path = os.path.join(opt.data_dir, 'extended_sick.txt') 16 | _, corpus_data = load_similarity_data(opt, extended_corpus_path, 'sick_corpus') 17 | # Load extended vocab 18 | vocab_path = os.path.join(opt.save_dir, 'extended_vocab.pkl') 19 | with open(vocab_path, 'rb') as f: 20 | vocab = pickle.load(f) 21 | 22 | # Initialize the similarity classifier 23 | classifier = SiameseClassifier(vocab.n_words, opt, is_train=False) 24 | # Load best available configuration (or modify as needed) 25 | load_network(classifier.encoder_a, 'sim_classifier', 'latest', opt.save_dir) 26 | 27 | # Initialize a data loader from randomly shuffled corpus data; inspection limited to individual items, hence bs=1 28 | shuffled_loader = DataServer(corpus_data, vocab, opt, is_train=False, use_buckets=False, volatile=True) 29 | 30 | # Keep track of performance 31 | total_classification_divergence = 0.0 32 | total_classification_loss = 0.0 33 | 34 | # Test loop 35 | for i, data in enumerate(shuffled_loader): 36 | # Upon completion 37 | if i >= opt.num_test_samples: 38 | average_classification_divergence = total_classification_divergence / opt.num_test_samples 39 | average_classification_loss = total_classification_loss / opt.num_test_samples 40 | print('=================================================\n' 41 | '= Testing concluded after examining %d samples. =\n' 42 | '= Average classification divergence is %.4f. =\n' 43 | '= Average classification loss (MSE) is %.4f. =\n' 44 | '=================================================' % 45 | (opt.num_test_samples, average_classification_divergence, average_classification_loss)) 46 | break 47 | 48 | s1_var, s2_var, label_var = data 49 | # Get predictions and update tracking values 50 | classifier.test_step(s1_var, s2_var, label_var) 51 | prediction = classifier.prediction 52 | loss = classifier.loss.data[0] 53 | divergence = abs((prediction - label_var).data[0]) 54 | total_classification_divergence += divergence 55 | total_classification_loss += loss 56 | 57 | sentence_a = ' '.join([vocab.index_to_word[int(idx[0])] if idx[0] != 0 else '' for idx in 58 | s1_var.data.numpy().tolist()]) 59 | sentence_b = ' '.join([vocab.index_to_word[int(idx[0])] if idx[0] != 0 else '' for idx in 60 | s2_var.data.numpy().tolist()]) 61 | 62 | print('Sample: %d\n' 63 | 'Sentence A: %s\n' 64 | 'Sentence B: %s\n' 65 | 'Prediction: %.4f\n' 66 | 'Ground truth: %.4f\n' 67 | 'Divergence: %.4f\n' 68 | 'Loss: %.4f\n' % 69 | (i, sentence_a, sentence_b, prediction.data[0], label_var.data[0][0], divergence, loss)) 70 | -------------------------------------------------------------------------------- /similarity_estimator/training.py: -------------------------------------------------------------------------------- 1 | """ Pre-trains the similarity estimator network the SemEval corpus and fine-tunes it on the SICK corpus. 
""" 2 | 3 | import os 4 | import pickle 5 | 6 | import numpy as np 7 | import torch 8 | from utils.data_server import DataServer 9 | from sklearn.model_selection import GridSearchCV 10 | from sklearn.model_selection import train_test_split 11 | from sklearn.svm import SVR 12 | 13 | from similarity_estimator.networks import SiameseClassifier 14 | from similarity_estimator.options import TestingOptions, ClusterOptions 15 | from similarity_estimator.sick_extender import SickExtender 16 | from similarity_estimator.sim_util import load_similarity_data 17 | from utils.init_and_storage import add_pretrained_embeddings, extend_embeddings, update_learning_rate, save_network 18 | from utils.parameter_initialization import xavier_normal 19 | 20 | # Initialize training parameters 21 | opt = TestingOptions() 22 | 23 | if opt.pre_training: 24 | save_dir = opt.pretraining_dir 25 | sts_corpus_path = os.path.join(opt.data_dir, 'se100.txt') 26 | vocab, corpus_data = load_similarity_data(opt, sts_corpus_path, 'SemEval13STS_corpus') 27 | # Initialize an embedding table 28 | init_embeddings = xavier_normal(torch.randn([vocab.n_words, 300])).numpy() 29 | # Add FastText embeddings 30 | fasttext_embeddings = add_pretrained_embeddings( 31 | init_embeddings, vocab, os.path.join(opt.data_dir, 'fasttext_embeds.txt')) 32 | # Initialize the similarity estimator network 33 | classifier = SiameseClassifier(vocab.n_words, opt, is_train=True) 34 | # Initialize parameters 35 | classifier.initialize_parameters() 36 | # Inject the pre-trained embedding table 37 | classifier.encoder_a.embedding_table.weight.data.copy_(fasttext_embeddings) 38 | 39 | else: 40 | save_dir = opt.save_dir 41 | # Extend the corpus with synthetic data 42 | source_corpus_path = os.path.join(opt.data_dir, 'SICK.txt') 43 | language_model_path = os.path.join(opt.data_dir, 'sick_lm.klm') 44 | extended_corpus_path = os.path.join(opt.data_dir, 'extended_sick.txt') 45 | extender = SickExtender(source_corpus_path, opt.data_dir, lm_path=language_model_path) 46 | if not os.path.exists(extended_corpus_path): 47 | extender.create_extension() 48 | # Obtain data 49 | target_vocab, corpus_data = load_similarity_data(opt, extended_corpus_path, 'sick_corpus') 50 | # Load pretrained parameters 51 | pretrained_path = os.path.join(opt.save_dir, 'pretraining/pretrained.pkl') 52 | with open(pretrained_path, 'rb') as f: 53 | pretrained_embeddings, pretrained_vocab = pickle.load(f) 54 | # Extend embeddings 55 | vocab, extended_embeddings = extend_embeddings( 56 | pretrained_embeddings, pretrained_vocab, target_vocab, os.path.join(opt.data_dir, 'fasttext_embeds.txt')) 57 | # Save extended embeddings 58 | vocab_path = os.path.join(opt.save_dir, 'extended_vocab.pkl') 59 | with open(vocab_path, 'wb') as f: 60 | pickle.dump(vocab, f) 61 | # Initialize the similarity estimator network 62 | classifier = SiameseClassifier(vocab.n_words, opt, is_train=True) 63 | # Initialize parameters 64 | classifier.initialize_parameters() 65 | # Inject the pre-trained embedding table 66 | classifier.encoder_a.embedding_table.weight.data.copy_(extended_embeddings) 67 | 68 | # Set up training 69 | learning_rate = opt.learning_rate 70 | 71 | # Initialize global tracking variables 72 | best_validation_accuracy = 0 73 | epochs_without_improvement = 0 74 | final_epoch = 0 75 | 76 | # Split the data for training and validation (70/30) 77 | train_data, valid_data, train_labels, valid_labels = train_test_split(corpus_data[0], corpus_data[1], 78 | test_size=0.3, random_state=0) 79 | 80 | # Training loop 81 | 
for epoch in range(opt.num_epochs): 82 | # Declare tracking variables 83 | running_loss = list() 84 | total_train_loss = list() 85 | 86 | # Initiate the training data loader 87 | train_loader = DataServer([train_data, train_labels], vocab, opt, is_train=True, use_buckets=True, volatile=False) 88 | 89 | # Training loop 90 | for i, data in enumerate(train_loader): 91 | # Obtain data 92 | s1_var, s2_var, label_var = data 93 | classifier.train_step(s1_var, s2_var, label_var) 94 | train_batch_loss = classifier.loss.data[0] 95 | 96 | running_loss.append(train_batch_loss) 97 | total_train_loss.append(train_batch_loss) 98 | 99 | if i % opt.report_freq == 0 and i != 0: 100 | running_avg_loss = sum(running_loss) / len(running_loss) 101 | print('Epoch: %d | Training Batch: %d | Average loss since batch %d: %.4f' % 102 | (epoch, i, i - opt.report_freq, running_avg_loss)) 103 | running_loss = list() 104 | 105 | # Report epoch statistics 106 | avg_training_accuracy = sum(total_train_loss) / len(total_train_loss) 107 | print('Average training batch loss at epoch %d: %.4f' % (epoch, avg_training_accuracy)) 108 | 109 | # Validate after each epoch; set tracking variables 110 | if epoch >= opt.start_early_stopping: 111 | total_valid_loss = list() 112 | 113 | # Initiate the training data loader 114 | valid_loader = DataServer([valid_data, valid_labels], vocab, opt, is_train=True, use_buckets=False, 115 | volatile=True) 116 | 117 | # Validation loop (i.e. perform inference on the validation set) 118 | for i, data in enumerate(valid_loader): 119 | s1_var, s2_var, label_var = data 120 | # Get predictions and update tracking values 121 | classifier.test_step(s1_var, s2_var, label_var) 122 | valid_batch_loss = classifier.loss.data[0] 123 | total_valid_loss.append(valid_batch_loss) 124 | 125 | # Report fold statistics 126 | avg_valid_accuracy = sum(total_valid_loss) / len(total_valid_loss) 127 | print('Average validation fold accuracy at epoch %d: %.4f' % (epoch, avg_valid_accuracy)) 128 | # Save network parameters if performance has improved 129 | if avg_valid_accuracy <= best_validation_accuracy: 130 | epochs_without_improvement += 1 131 | else: 132 | best_validation_accuracy = avg_valid_accuracy 133 | epochs_without_improvement = 0 134 | save_network(classifier.encoder_a, 'sim_classifier', 'latest', save_dir) 135 | 136 | # Save network parameters at the end of each nth epoch 137 | if epoch % opt.save_freq == 0 and epoch != 0: 138 | print('Saving model networks after completing epoch %d' % epoch) 139 | save_network(classifier.encoder_a, 'sim_classifier', epoch, save_dir) 140 | 141 | # Anneal learning rate: 142 | if epochs_without_improvement == opt.start_annealing: 143 | old_learning_rate = learning_rate 144 | learning_rate *= opt.annealing_factor 145 | update_learning_rate(classifier.optimizer_a, learning_rate) 146 | print('Learning rate has been updated from %.4f to %.4f' % (old_learning_rate, learning_rate)) 147 | 148 | # Terminate training early, if no improvement has been observed for n epochs 149 | if epochs_without_improvement >= opt.patience: 150 | print('Stopping training early after %d epochs, following %d epochs without performance improvement.' % 151 | (epoch, epochs_without_improvement)) 152 | final_epoch = epoch 153 | break 154 | 155 | print('Training procedure concluded after %d epochs total. Best validated epoch: %d.' 
156 | % (final_epoch, final_epoch - opt.patience)) 157 | 158 | if opt.pre_training: 159 | # Save pretrained embeddings and the vocab object 160 | pretrained_path = os.path.join(save_dir, 'pretrained.pkl') 161 | pretrained_embeddings = classifier.encoder_a.embedding_table.weight.data 162 | with open(pretrained_path, 'wb') as f: 163 | pickle.dump((pretrained_embeddings, vocab), f) 164 | print('Pre-trained parameters saved to %s' % pretrained_path) 165 | 166 | if not opt.pre_training: 167 | ''' Regression step over the training set to improve the predictive power of the model''' 168 | # Obtain similarity score predictions for each item within the training corpus 169 | labels = list() 170 | predictions = list() 171 | 172 | # Initiate the training data loader 173 | train_loader = DataServer([train_data, train_labels], vocab, opt, is_train=True, volatile=True) 174 | 175 | # Obtaining predictions 176 | for i, data in enumerate(train_loader): 177 | # Obtain data 178 | s1_var, s2_var, label_var = data 179 | labels += [l[0] for l in label_var.data.numpy().tolist()] 180 | classifier.test_step(s1_var, s2_var, label_var) 181 | batch_predict = classifier.prediction.data.squeeze().numpy().tolist() 182 | predictions += batch_predict 183 | 184 | labels = np.array(labels) 185 | predictions = np.array(predictions).reshape(-1, 1) 186 | 187 | # Fit an SVR (following the scikit-learn tutorial) 188 | sim_svr = GridSearchCV(SVR(kernel='rbf', gamma=0.1), cv=5, param_grid={"C": [1e0, 1e1, 1e2, 1e3], 189 | "gamma": np.logspace(-2, 2, 5)}) 190 | 191 | sim_svr.fit(predictions, labels) 192 | print('SVR complexity and bandwidth selected and model fitted successfully.') 193 | 194 | # Save trained SVR model 195 | svr_path = os.path.join(save_dir, 'sim_svr.pkl') 196 | with open(svr_path, 'wb') as f: 197 | pickle.dump(sim_svr, f) 198 | print('Trained SVR model saved to %s' % svr_path) 199 | -------------------------------------------------------------------------------- /utils/data_server.py: -------------------------------------------------------------------------------- 1 | """ Word-based, for now. Switch to sub-word eventually, o a combination of character- and word-based input.""" 2 | 3 | import random 4 | import numpy as np 5 | 6 | import torch 7 | from torch.autograd import Variable 8 | 9 | from similarity_estimator.sim_util import perform_bucketing 10 | 11 | 12 | class DataServer(object): 13 | """ Iterates through a data source, i.e. a list of sentences or list of buckets containing sentences of similar length. 14 | Produces batch-major batches, i.e. of shape=[seq_len, batch_size]. 
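Illustrative usage (a sketch, not part of the original code; corpus_data and vocab come from sim_util.load_similarity_data and opt from options.py):

        loader = DataServer(corpus_data, vocab, opt, is_train=True, use_buckets=True, volatile=False)
        for s1_var, s2_var, label_var in loader:
            # s1_var, s2_var: LongTensor Variables of shape [seq_len, batch_size]
            # label_var:      FloatTensor Variable of shape [batch_size, 1]
            pass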
""" 15 | def __init__(self, data, vocab, opt, is_train=False, shuffle=True, use_buckets=True, volatile=False): 16 | self.data = data 17 | self.vocab = vocab 18 | self.opt = opt 19 | self.volatile = volatile 20 | self.use_buckets = use_buckets 21 | self.pair_id = 0 22 | self.buckets = None 23 | # Obtain bucket data 24 | if self.use_buckets: 25 | self.buckets, self.data = perform_bucketing(self.opt, self.data) 26 | self.bucket_id = 0 27 | # Select appropriate batch size 28 | if is_train: 29 | self.batch_size = self.opt.train_batch_size 30 | else: 31 | self.batch_size = self.opt.test_batch_size 32 | # Shuffle data (either batch-wise or as a whole) 33 | if shuffle: 34 | if self.use_buckets: 35 | # Shuffle within buckets 36 | for i in range(len(self.data)): 37 | zipped = list(zip(*self.data[i])) 38 | random.shuffle(zipped) 39 | self.data[i] = list(zip(*zipped)) 40 | # Shuffle buckets, also 41 | bucket_all = list(zip(self.buckets, self.data)) 42 | random.shuffle(bucket_all) 43 | self.buckets, self.data = zip(*bucket_all) 44 | else: 45 | zipped = list(zip(*self.data)) 46 | random.shuffle(zipped) 47 | self.data = list(zip(*zipped)) 48 | 49 | def sent_to_idx(self, sent): 50 | """ Transforms a sequence of strings to the corresponding sequence of indices. """ 51 | idx_list = [self.vocab.word_to_index[word] if self.vocab.word_to_count[word] >= self.opt.freq_bound else 1 for 52 | word in sent.split()] 53 | # Pad to the desired sentence length 54 | if self.opt.pad: 55 | if self.use_buckets: 56 | # Pad to bucket upper length bound 57 | max_len = self.buckets[self.bucket_id][1] 58 | else: 59 | # In case of no bucketing, pad all corpus sentence to a uniform, specified length 60 | max_len = self.vocab.target_len 61 | # Adjust padding for single sentence-pair evalualtion (i.e. no buckets, singleton batches) 62 | if self.batch_size == 1: 63 | max_len = max(len(self.data[0][self.pair_id][0].split()), len(self.data[0][self.pair_id][1].split())) 64 | # Pad items to maximum length 65 | diff = max_len - len(idx_list) 66 | if diff >= 1: 67 | idx_list += [0] * diff 68 | return idx_list 69 | 70 | def __iter__(self): 71 | """ Returns an iterator object. """ 72 | return self 73 | 74 | def __next__(self): 75 | """ Returns the next batch from within the iterator source. """ 76 | try: 77 | if self.use_buckets: 78 | s1_batch, s2_batch, label_batch = self.bucketed_next() 79 | else: 80 | s1_batch, s2_batch, label_batch = self.corpus_next() 81 | except IndexError: 82 | raise StopIteration 83 | 84 | # Covert batches into batch major form 85 | s1_batch = torch.LongTensor(s1_batch).t().contiguous() 86 | s2_batch = torch.LongTensor(s2_batch).t().contiguous() 87 | label_batch = torch.FloatTensor(label_batch).contiguous() 88 | # Convert to variables 89 | s1_var = Variable(s1_batch, volatile=self.volatile) 90 | s2_var = Variable(s2_batch, volatile=self.volatile) 91 | label_var = Variable(label_batch, volatile=self.volatile) 92 | return s1_var, s2_var, label_var 93 | 94 | def bucketed_next(self): 95 | """ Samples the next batch from the current bucket. 
""" 96 | # Assemble batches 97 | s1_batch = list() 98 | s2_batch = list() 99 | label_batch = list() 100 | 101 | if self.bucket_id < self.opt.num_buckets: 102 | # Fill batches 103 | while len(s1_batch) < self.batch_size: 104 | try: 105 | s1 = self.sent_to_idx(self.data[self.bucket_id][0][self.pair_id][0]) 106 | s2 = self.sent_to_idx(self.data[self.bucket_id][0][self.pair_id][1]) 107 | label = [float(self.data[self.bucket_id][1][self.pair_id])] 108 | s1_batch.append(s1) 109 | s2_batch.append(s2) 110 | label_batch.append(label) 111 | self.pair_id += 1 112 | except IndexError: 113 | # Finish batch prematurely if bucket or corpus has been emptied 114 | self.pair_id = 0 115 | self.bucket_id += 1 116 | break 117 | # Check if bucket is empty, to avoid generation of empty batches 118 | try: 119 | if self.pair_id == len(self.data[self.bucket_id][0]): 120 | self.bucket_id += 1 121 | except IndexError: 122 | pass 123 | else: 124 | raise IndexError 125 | 126 | return s1_batch, s2_batch, label_batch 127 | 128 | def corpus_next(self): 129 | """ Samples the next batch from the un-bucketed corpus. """ 130 | # Assemble batches 131 | s1_batch = list() 132 | s2_batch = list() 133 | label_batch = list() 134 | 135 | # Without bucketing 136 | if self.pair_id < self.get_length(): 137 | while len(s1_batch) < self.batch_size: 138 | try: 139 | s1 = self.sent_to_idx(self.data[0][self.pair_id][0]) 140 | s2 = self.sent_to_idx(self.data[0][self.pair_id][1]) 141 | label = [float(self.data[1][self.pair_id])] 142 | s1_batch.append(s1) 143 | s2_batch.append(s2) 144 | label_batch.append(label) 145 | self.pair_id += 1 146 | except IndexError: 147 | break 148 | else: 149 | raise IndexError 150 | 151 | return s1_batch, s2_batch, label_batch 152 | 153 | def get_length(self): 154 | # Return corpus length in sentence pairs 155 | if self.use_buckets: 156 | return sum([len(bucket[0]) for bucket in self.data]) 157 | else: 158 | return len(self.data[0]) 159 | -------------------------------------------------------------------------------- /utils/init_and_storage.py: -------------------------------------------------------------------------------- 1 | """ Various utility and helper functions used throughout the model. """ 2 | import os 3 | import torch 4 | 5 | from utils.parameter_initialization import xavier_normal 6 | 7 | 8 | def add_pretrained_embeddings(embedding_table, target_vocab, pretrained_vec_file): 9 | """ Fills the existing embedding table with pre-trained embeddings. Run after the initialization of the 10 | embedding table. """ 11 | print('Adding pre-trained embeddings ... ') 12 | 13 | # Read in the pretrained vector file 14 | with open(pretrained_vec_file, 'r') as in_file: 15 | for line in in_file: 16 | entries = line.split() 17 | # Check for blank/ incomplete lines 18 | if len(entries) != 301: 19 | continue 20 | word = entries[0] 21 | vec = [float(n) for n in entries[1:]] 22 | # Inject pretrained vectors 23 | try: 24 | word_row = target_vocab.word_to_index[word] 25 | embedding_table[word_row][:] = vec 26 | # Extend the vocabulary 27 | except KeyError: 28 | continue 29 | 30 | return torch.FloatTensor(embedding_table) 31 | 32 | 33 | def extend_embeddings(source_table, source_vocab, target_vocab, pretrained_vec_file): 34 | """ Extends an existing, trained embedding table with new entries corresponding to new words from some target 35 | corpus. """ 36 | print('Extending embedding table with pre-trained embeddings ... 
') 37 | # Consolidate new and old vocabs 38 | added_words = list() 39 | source_vocab_start_words = source_vocab.n_words 40 | 41 | for idx_i in range(4, target_vocab.n_words): 42 | word = target_vocab.index_to_word[idx_i] 43 | try: 44 | source_vocab.word_to_index[word] 45 | except KeyError: 46 | source_vocab.add_word(word) 47 | source_vocab.word_to_count[word] = target_vocab.word_to_count[word] 48 | added_words.append(word) 49 | 50 | # Initialize embedding table extension 51 | added_embeddings = xavier_normal(torch.FloatTensor(len(added_words), 300)).numpy() 52 | 53 | # Update embedding table 54 | with open(pretrained_vec_file, 'r') as in_file: 55 | for line in in_file: 56 | entries = line.split() 57 | # Check for blank/ incomplete lines 58 | if len(entries) != 301: 59 | continue 60 | word = entries[0] 61 | vec = [float(n) for n in entries[1:]] 62 | # Collect new vectors 63 | if word in added_words: 64 | word_row = source_vocab.word_to_index[word] - source_vocab_start_words 65 | added_embeddings[word_row][:] = vec 66 | 67 | # Concatenate source embedding table with the new additions 68 | extended_table = torch.cat([source_table, torch.FloatTensor(added_embeddings)], 0) 69 | return source_vocab, extended_table 70 | 71 | 72 | def add_all_embeddings(embedding_table, vocab_object, pretrained_vec_file): 73 | """ Concatenates all missing pre-trained vectors to an existing embedding table and modifies the lookup object 74 | accordingly. """ 75 | print('Adding pre-trained embeddings ... ') 76 | # Initialize pretrained embedding table 77 | pretrained_table = list() 78 | # Read in the pretrained vector file 79 | with open(pretrained_vec_file, 'r') as in_file: 80 | for line in in_file: 81 | entries = line.split() 82 | # Check for blank/ incomplete lines 83 | if len(entries) != 301: 84 | continue 85 | word = entries[0] 86 | vec = [float(n) for n in entries[1:]] 87 | # Inject pretrained vectors 88 | try: 89 | value = vocab_object.word_to_index[word] 90 | embedding_table[value][:] = vec 91 | # Extend the vocabulary 92 | except KeyError: 93 | vocab_object.add_word(word) 94 | pretrained_table.append(vec) 95 | 96 | # Join embedding tables 97 | pretrained_table = torch.FloatTensor(pretrained_table) 98 | joint_table = torch.cat([embedding_table, pretrained_table], 0) 99 | 100 | return vocab_object, joint_table 101 | 102 | 103 | def initialize_parameters(network): 104 | """ Initializes the parameters of the network's weights following the Xavier initialization scheme. """ 105 | params = network.parameters() 106 | for tensor in params: 107 | if len(tensor.size()) > 1: 108 | tensor = xavier_normal(tensor) 109 | else: 110 | tensor.data.fill_(0.1) 111 | print("Initialized weight parameters of %s with Xavier initialization using the normal distribution." % 112 | network.name) 113 | 114 | 115 | def update_learning_rate(optimizer, new_lr): 116 | """ Decreases the learning rate to promote training gains. """ 117 | for param_group in optimizer.param_groups: 118 | param_group['lr'] = new_lr 119 | 120 | 121 | def save_network(network, network_label, active_epoch, save_directory): 122 | """ Saves the parameters of the specified network under the specified path. 
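Checkpoints are written as '<epoch>_<label>' inside save_directory, e.g. 'latest_sim_classifier' or '20_sim_classifier'.

    Illustrative round trip (a sketch mirroring the calls made in training.py and testing.py):

        save_network(classifier.encoder_a, 'sim_classifier', 'latest', opt.save_dir)
        # ... later, after rebuilding an identically configured encoder:
        load_network(classifier.encoder_a, 'sim_classifier', 'latest', opt.save_dir)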
""" 123 | file_name = '%s_%s' % (str(active_epoch), network_label) 124 | save_path = os.path.join(save_directory, file_name) 125 | torch.save(network.cpu().state_dict(), save_path) 126 | print('Network %s saved following the completion of epoch %s | Location: %s' % 127 | (network_label, str(active_epoch), save_path)) 128 | 129 | 130 | def load_network(network, network_label, target_epoch, load_directory): 131 | """ Helper function for loading network work. """ 132 | load_filename = '%s_%s' % (str(target_epoch), network_label) 133 | load_path = os.path.join(load_directory, load_filename) 134 | network.load_state_dict(torch.load(load_path)) 135 | print('Network %s, version %s loaded from location %s' % (network_label, target_epoch, load_path)) 136 | -------------------------------------------------------------------------------- /utils/parameter_initialization.py: -------------------------------------------------------------------------------- 1 | """ Borrowed from https://github.com/alykhantejani/nninit/blob/master/nninit.py """ 2 | 3 | from torch.autograd import Variable 4 | import numpy as np 5 | 6 | 7 | def _calculate_fan_in_and_fan_out(tensor): 8 | if tensor.ndimension() < 2: 9 | raise ValueError("fan in and fan out can not be computed for tensor of size ", tensor.size()) 10 | 11 | if tensor.ndimension() == 2: # Linear 12 | fan_in = tensor.size(1) 13 | fan_out = tensor.size(0) 14 | else: 15 | num_input_fmaps = tensor.size(1) 16 | num_output_fmaps = tensor.size(0) 17 | receptive_field_size = np.prod(tensor.numpy().shape[2:]) 18 | fan_in = num_input_fmaps * receptive_field_size 19 | fan_out = num_output_fmaps * receptive_field_size 20 | 21 | return fan_in, fan_out 22 | 23 | 24 | def xavier_uniform(tensor, gain=1): 25 | """Fills the input Tensor or Variable with values according to the method described in "Understanding the difficulty of training 26 | deep feedforward neural networks" - Glorot, X. and Bengio, Y., using a uniform distribution. 27 | The resulting tensor will have values sampled from U(-a, a) where a = gain * sqrt(2/(fan_in + fan_out)) 28 | Args: 29 | tensor: a n-dimension torch.Tensor 30 | gain: an optional scaling factor to be applied 31 | Examples: 32 | w = torch.Tensor(3, 5) 33 | xavier_uniform(w, gain=np.sqrt(2.0)) 34 | """ 35 | if isinstance(tensor, Variable): 36 | xavier_uniform(tensor.data, gain=gain) 37 | return tensor 38 | else: 39 | fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) 40 | std = gain * np.sqrt(2.0 / (fan_in + fan_out)) 41 | a = np.sqrt(3.0) * std 42 | return tensor.uniform_(-a, a) 43 | 44 | 45 | def xavier_normal(tensor, gain=1): 46 | """Fills the input Tensor or Variable with values according to the method described in "Understanding the difficulty of training 47 | deep feedforward neural networks" - Glorot, X. and Bengio, Y., using a normal distribution. 48 | The resulting tensor will have values sampled from normal distribution with mean=0 and 49 | std = gain * sqrt(2/(fan_in + fan_out)) 50 | Args: 51 | tensor: a n-dimension torch.Tensor 52 | gain: an optional scaling factor to be applied 53 | Examples: 54 | w = torch.Tensor(3, 5) 55 | xavier_normal(w, gain=np.sqrt(2.0)) 56 | """ 57 | if isinstance(tensor, Variable): 58 | xavier_normal(tensor.data, gain=gain) 59 | return tensor 60 | else: 61 | fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) 62 | std = gain * np.sqrt(2.0 / (fan_in + fan_out)) 63 | return tensor.normal_(0, std) 64 | --------------------------------------------------------------------------------