├── requirements.txt ├── data └── extract_umbc.py ├── evaluate_word2mat.py ├── word2mat.py ├── README.md ├── mutils.py ├── wrap_evaluation.py ├── cbow.py └── train_cbow.py /requirements.txt: -------------------------------------------------------------------------------- 1 | networkx==1.11 2 | sklearn 3 | numpy 4 | scipy 5 | torch 6 | torchtext 7 | bayesian-optimization 8 | hyperopt 9 | nltk 10 | pandas 11 | 12 | 13 | 14 | 15 | -------------------------------------------------------------------------------- /data/extract_umbc.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script turns the UMBC News webcorpus into a file where 3 | each line is a sentence from the corpus. 4 | In the resulting file, tokenization has already been performed and can be recreated by splitting on the whitespace character. 5 | 6 | The script expects the following parameters: 7 | arg[1] : Path to folder with .txt files 8 | The extracted UMBC webcorpus consists of a bunch of .txt files with a paragraph in each line (with blank 9 | lines in between). 10 | 11 | arg[2] : Output file path 12 | Path to the location where the output file should be stored (i.e., the file containing one sentence per line). 13 | """ 14 | 15 | import sys, os 16 | from nltk.tokenize import sent_tokenize, word_tokenize 17 | 18 | 19 | # parse args 20 | input_path = sys.argv[1] 21 | output_file = sys.argv[2] 22 | 23 | with open(output_file, 'w') as output_f: 24 | 25 | 26 | file_names = os.listdir(input_path) 27 | file_names = [f for f in file_names if f.endswith(".txt")] 28 | 29 | for f in file_names: 30 | input_file_path = os.path.join(input_path, f) 31 | 32 | # iterate over the plain text data from the UMBC corpus 33 | with open(input_file_path, 'r', encoding = "utf-8") as input_file: 34 | for line in input_file: 35 | 36 | # skip empty lines 37 | if not line.strip(): 38 | continue 39 | else: 40 | line = line.strip() 41 | 42 | # get all sentences in the paragraph 43 | sentences = sent_tokenize(line) 44 | 45 | # write one sentence per line 46 | for s in sentences: 47 | words = word_tokenize(s) 48 | sample = " ".join(words) 49 | print(sample, file = output_f) 50 | -------------------------------------------------------------------------------- /evaluate_word2mat.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script is for assessing the performance of models (or their subparts) AFTER training. 3 | """ 4 | 5 | from __future__ import absolute_import, division, unicode_literals 6 | 7 | import sys, os, time 8 | import torch 9 | import logging 10 | import pickle 11 | import numpy as np 12 | import pandas as pd 13 | import argparse 14 | 15 | from wrap_evaluation import run_and_evaluate 16 | 17 | from data import get_index_batch 18 | 19 | from torch.autograd import Variable 20 | 21 | # Set PATHs 22 | PATH_SENTEVAL = '/data22/fmai/data/SentEval/SentEval/' 23 | PATH_TO_DATA = '/data22/fmai/data/SentEval/SentEval/data' 24 | 25 | assert os.path.exists(PATH_SENTEVAL) and os.path.exists(PATH_TO_DATA), "Set path to SentEval + data correctly!" 26 | 27 | # import senteval 28 | sys.path.insert(0, PATH_SENTEVAL) 29 | import senteval 30 | 31 | # Set up logger 32 | logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG) 33 | 34 | total_time_encoding = 0.
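# total_time_encoding and total_samples_encoded (below) accumulate over all batcher() calls
# so that the encoding speed (sentences per second) can be reported at the end of the run.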
35 | total_samples_encoded = 0 36 | 37 | if __name__ == "__main__": 38 | 39 | def prepare(params_senteval, samples): 40 | 41 | params = params_senteval["cmd_params"] 42 | 43 | # Load vocabulary 44 | vocabulary = pickle.load(open(params.word_vocab, "rb" ))[0] 45 | 46 | params_senteval['vocabulary'] = vocabulary 47 | params_senteval['inverse_vocab'] = {vocabulary[w] : w for w in vocabulary} 48 | params_senteval['encoders'] = [torch.load(p) for p in params.encoders] 49 | 50 | def _batcher_helper(encoder, vocabulary, batch): 51 | sent, _ = get_index_batch(batch, vocabulary) 52 | sent_cuda = Variable(sent.cuda()) 53 | sent_cuda = sent_cuda.t() 54 | encoder.eval() # Deactivate drop-out and such 55 | embeddings = encoder.forward(sent_cuda).data.cpu().numpy() 56 | 57 | return embeddings 58 | 59 | def get_params_parser(): 60 | parser = argparse.ArgumentParser(description='Evaluates the performance of a given encoder. Use --included_features to\ 61 | evaluate subparts of the model.') 62 | 63 | # paths 64 | parser.add_argument('--word_vocab', type=str, default=None, help= \ 65 | "Specify path where to load precomputed word.", required = True) 66 | parser.add_argument('--encoders', type=str, nargs='+', default=None, help= \ 67 | "Specify path to load encoder models from.", required = True) 68 | parser.add_argument('--aggregation', type=str, default="concat", help= \ 69 | "Specify operation to use for aggregating embeddings from multiple encoders.", required = False, 70 | choices = ['concat', 'add']) 71 | parser.add_argument('--gpu_device', type=int, default=0, help= \ 72 | "You need to specify the id of the gpu that you used for training to avoid errors.", required = False) 73 | parser.add_argument('--add_start_end_token', action = "store_true", default=False, help= \ 74 | "If activated, the start and end tokens are added to every sample. Used e.g. for NLI trained encoder.", required = False) 75 | parser.add_argument('--included_features', type = int, nargs = '+', default=None, help= \ 76 | "If specified, expects two integers a and b, which denote the range (a is inclusive, b is exclusive) of indices to use from\ 77 | the embedding. 
E.g., if '--included_features 0 300' is specified, the embedding that is evaluated consists only of the first\ 78 | 300 dimensions of the actual embedding: embeddings[a,b].", required = False) 79 | return parser 80 | 81 | def batcher(params_senteval, batch): 82 | 83 | start_time = time.time() 84 | params = params_senteval["cmd_params"] 85 | if params.add_start_end_token: 86 | batch = [[''] + s + [''] for s in batch] 87 | 88 | embeddings_list = [_batcher_helper(enc, params_senteval['vocabulary'], batch) for enc in params_senteval['encoders']] 89 | 90 | if params.aggregation == "add": 91 | embeddings = sum(embeddings_list) 92 | 93 | elif params.aggregation == "concat": 94 | embeddings = np.hstack(embeddings_list) 95 | 96 | global total_time_encoding 97 | global total_samples_encoded 98 | total_time_encoding += time.time() - start_time 99 | total_samples_encoded += len(batch) 100 | 101 | if params.included_features: 102 | a = params.included_features[0] 103 | b = params.included_features[1] 104 | embeddings = embeddings[:, a:b] 105 | 106 | return embeddings 107 | 108 | def _load_encoder_and_eval(params): 109 | encoder_for_wordemb_eval = torch.load(params.encoders[0]) 110 | return (encoder_for_wordemb_eval, []) 111 | run_and_evaluate(_load_encoder_and_eval, get_params_parser, batcher, prepare) 112 | 113 | print("Encoding speed: {}/s".format(total_samples_encoded / total_time_encoding)) 114 | 115 | -------------------------------------------------------------------------------- /word2mat.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time, sys, random 3 | 4 | import torch 5 | from torch.autograd import Variable 6 | import torch.nn as nn 7 | from torch.nn import functional as F 8 | import math 9 | import time 10 | 11 | from torch import FloatTensor as FT 12 | from torch import ByteTensor as BT 13 | 14 | TINY = 1e-11 15 | 16 | class Word2MatEncoder(nn.Module): 17 | 18 | def __init__(self, n_words, word_emb_dim = 784, padding_idx = 0, w2m_type = "cbow", initialization_strategy = "identity"): 19 | """ 20 | TODO: Method description for w2m encoder. 
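        Encodes a batch of word index sequences by representing each word as a (flattened) square matrix.
        With w2m_type = "cbow", the word matrices of a sentence are summed; with w2m_type = "cmow",
        they are multiplied in order (continual multiplication). word_emb_dim must be a perfect square
        so that each embedding vector can be reshaped into a sqrt(d) x sqrt(d) matrix. Row 0 (the padding
        symbol) is initialized to the neutral element of the chosen operation (the zero matrix for CBOW,
        the identity matrix for CMOW). initialization_strategy is only used for CMOW and can be
        "identity", "normalized", or "normal".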
21 | """ 22 | super(Word2MatEncoder, self).__init__() 23 | self.word_emb_dim = word_emb_dim 24 | self.n_words = n_words 25 | self.w2m_type = w2m_type 26 | self.initialization_strategy = initialization_strategy 27 | 28 | # check that the word embedding size is a square 29 | assert word_emb_dim == int(math.sqrt(word_emb_dim)) ** 2 30 | 31 | # set up word embedding table 32 | self.lookup_table = nn.Embedding(self.n_words + 1, 33 | self.word_emb_dim, 34 | padding_idx=padding_idx, 35 | sparse = False) 36 | 37 | # type of aggregation to use to combine two word matrices 38 | self.w2m_type = w2m_type 39 | if self.w2m_type not in ["cbow", "cmow"]: 40 | raise NotImplementedError("Operator " + self.operator + " is not yet implemented.") 41 | 42 | # set initial weights of word embeddings depending on the initialization strategy 43 | ## set weights of padding symbol such that it is the neutral element with respect to the operation 44 | if self.w2m_type == "cmow": 45 | neutral_element = np.reshape(np.eye(int(np.sqrt(self.word_emb_dim)), dtype=np.float32), (1, -1)) 46 | neutral_element = torch.from_numpy(neutral_element) 47 | elif self.w2m_type == "cbow": 48 | neutral_element = np.reshape(torch.from_numpy(np.zeros((self.word_emb_dim), dtype=np.float32)), (1, -1)) 49 | 50 | ## set weights of rest others depending on the initialization strategy 51 | if self.w2m_type == "cmow": 52 | if self.initialization_strategy == "identity": 53 | init_weights = self._init_random_identity() 54 | 55 | elif self.initialization_strategy == "normalized": 56 | ### normalized initialization by (Glorot and Bengio, 2010) 57 | init_weights = torch.from_numpy(np.random.uniform(size = (self.n_words, 58 | self.word_emb_dim), 59 | low = -np.sqrt(6 / (2*self.word_emb_dim)), 60 | high = +np.sqrt(6 / (2*self.word_emb_dim)) 61 | ).astype(np.float32) 62 | ) 63 | elif self.initialization_strategy == "normal": 64 | ### normalized with N(0,0.1), which failed in study by Yessenalina 65 | init_weights = torch.from_numpy(np.random.normal(size = (self.n_words, 66 | self.word_emb_dim), 67 | loc = 0.0, 68 | scale = 0.1 69 | ).astype(np.float32) 70 | ) 71 | else: 72 | raise NotImplementedError("Unknown initialization strategy " + self.initialization_strategy) 73 | 74 | elif self.w2m_type == "cbow": 75 | init_weights = self._init_normal() 76 | 77 | 78 | ## concatenate and set weights in the lookup table 79 | weights = torch.cat([neutral_element, 80 | init_weights], 81 | dim=0) 82 | self.lookup_table.weight = nn.Parameter(weights) 83 | 84 | def forward(self, sent): 85 | 86 | sent = self.lookup_table(sent) 87 | seq_length = sent.size()[1] 88 | matrix_dim = self._matrix_dim() 89 | 90 | # reshape vectors to matrices 91 | word_matrices = sent.view(-1, seq_length, matrix_dim, matrix_dim) 92 | 93 | # aggregate matrices 94 | if self.w2m_type == "cmow": 95 | cur_emb = self._continual_multiplication(word_matrices) 96 | elif self.w2m_type == "cbow": 97 | cur_emb = torch.sum(word_matrices, 1) 98 | 99 | # flatten final matrix 100 | emb = self._flatten_matrix(cur_emb) 101 | 102 | return emb 103 | 104 | def _continual_multiplication(self, word_matrices): 105 | cur_emb = word_matrices[:, 0, :] 106 | for i in range(1, word_matrices.size()[1]): 107 | cur_emb = torch.bmm(cur_emb, word_matrices[:, i, :]) 108 | return cur_emb 109 | 110 | def _flatten_matrix(self, m): 111 | return m.view(-1, self.word_emb_dim) 112 | 113 | def _unflatten_matrix(self, m): 114 | return m.view(-1, self._matrix_dim(), self._matrix_dim()) 115 | 116 | def _matrix_dim(self): 117 | return 
int(np.sqrt(self.word_emb_dim)) 118 | 119 | def _init_random_identity(self): 120 | """Random normal initialization around 0., with 1. added on the diagonal.""" 121 | init_weights = np.random.normal(size = (self.n_words, self.word_emb_dim), 122 | loc = 0., 123 | scale = 0.1 124 | ).astype(np.float32) 125 | for i in range(self.n_words): 126 | init_weights[i, :] += np.reshape(np.eye(int(np.sqrt(self.word_emb_dim)), dtype=np.float32), (-1,)) 127 | init_weights = torch.from_numpy(init_weights) 128 | return init_weights 129 | 130 | def _init_normal(self): 131 | ### normal initialization around 0. 132 | init_weights = torch.from_numpy(np.random.normal(size = (self.n_words, 133 | self.word_emb_dim), 134 | loc = 0.0, 135 | scale = 0.1 136 | ).astype(np.float32) 137 | ) 138 | return init_weights 139 | 140 | 141 | class HybridEncoder(nn.Module): 142 | def __init__(self, cbow_encoder, cmow_encoder): 143 | super(HybridEncoder, self).__init__() 144 | self.cbow_encoder = cbow_encoder 145 | self.cmow_encoder = cmow_encoder 146 | 147 | def forward(self, sent_tuple): 148 | return torch.cat([self.cbow_encoder(sent_tuple), self.cmow_encoder(sent_tuple)], dim = 1) 149 | 150 | 151 | def get_cmow_encoder(n_words, padding_idx = 0, word_emb_dim = 784, initialization_strategy = "identity"): 152 | encoder = Word2MatEncoder(n_words, word_emb_dim = word_emb_dim, 153 | padding_idx = padding_idx, w2m_type = "cmow", 154 | initialization_strategy = initialization_strategy) 155 | return encoder 156 | 157 | def get_cbow_encoder(n_words, padding_idx = 0, word_emb_dim = 784): 158 | encoder = Word2MatEncoder(n_words, word_emb_dim = word_emb_dim, 159 | padding_idx = padding_idx, w2m_type = "cbow") 160 | return encoder 161 | 162 | def get_cbow_cmow_hybrid_encoder(n_words, padding_idx = 0, word_emb_dim = 400, initialization_strategy = "identity"): 163 | cbow_encoder = get_cbow_encoder(n_words, padding_idx = padding_idx, word_emb_dim = word_emb_dim) 164 | cmow_encoder = get_cmow_encoder(n_words, padding_idx = padding_idx, 165 | word_emb_dim = word_emb_dim, 166 | initialization_strategy = initialization_strategy) 167 | 168 | encoder = HybridEncoder(cbow_encoder, cmow_encoder) 169 | return encoder 170 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # word2mat 2 | 3 | *Word2Mat* is a framework that learns *sentence embeddings* in a CBOW/word2vec style, but where words and sentences are represented as matrices. 4 | Details of this method and results can be found in our [ICLR paper](https://openreview.net/forum?id=H1MgjoR9tQ).
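For intuition about the matrix-based composition (CMOW), here is a minimal, self-contained sketch of the idea: each word embedding is reshaped into a square matrix, the matrices of a sentence are multiplied from left to right, and the product is flattened into the sentence embedding. The function name `cmow_compose` is only for illustration; the actual implementation is `Word2MatEncoder` in `word2mat.py`.

```
import torch

def cmow_compose(word_vectors):
    """Compose word vectors (each of length d*d) by matrix multiplication."""
    seq_len, emb_dim = word_vectors.shape
    d = int(emb_dim ** 0.5)
    matrices = word_vectors.view(seq_len, d, d)   # reshape each vector to a d x d matrix
    sentence_matrix = matrices[0]
    for i in range(1, seq_len):                   # continual multiplication
        sentence_matrix = sentence_matrix @ matrices[i]
    return sentence_matrix.reshape(-1)            # flatten back into a vector

# toy example: 3 "words" with 784-dimensional (28 x 28) embeddings, initialized near the identity
words = torch.randn(3, 784) * 0.1 + torch.eye(28).reshape(-1)
print(cmow_compose(words).shape)  # torch.Size([784])
```

Because matrix multiplication is not commutative, the resulting embedding is sensitive to word order, which is the main difference to CBOW's commutative summation.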
5 | 6 | ## Dependencies 7 | 8 | - Python3 9 | - PyTorch >= 0.4 with CUDA support 10 | - NLTK >= 3 11 | 12 | ## Setup python3 environment 13 | 14 | Please install the python3 dependencies in your environment: 15 | 16 | ``` 17 | virtualenv -p python3 venv && source venv/bin/activate 18 | pip install -r requirements.txt 19 | python3 -c "import nltk; nltk.download('punkt')" 20 | ``` 21 | 22 | ## Download training data 23 | 24 | In order to reproduce the results from our paper, which were trained on the UMBC corpus, 25 | download the [UMBC corpus](https://ebiquity.umbc.edu/resource/html/id/351), extract the tar.gz file, 26 | and run the extract_umbc.py script in the following way: 27 | 28 | ``` 29 | python extract_umbc.py umbc_corpus/webbase_all 30 | ``` 31 | 32 | This stores the sentences from the UMBC corpus in a format that is usable by our code: Each line in the resulting file contains a single sentence, 33 | whose (already pre-processed) tokens are separated by a whitespace character. 34 | 35 | ## Running the experiments 36 | 37 | Note: After further experiments, we observed that terminating training based on the validation loss produces unreliable results because of 38 | relatively high variance in the validation loss. Hence, we recommend using training loss as stopping criterion, which is more stable. 39 | 40 | The results below are trained with this stopping criterion, and therefore slightly differ from the results reported in the ICLR paper. 41 | However, the conclusions remain the same: CMOW is much better than CBOW at capturing linguistic properties except WordContent. 42 | Therefore, CBOW is superior in almost all downstream tasks except TREC. 43 | The Hybrid model retains the capabilities of both models and therefore is extremely close to the better model among CBOW and CMOW, or better 44 | on all tasks. 45 | 46 | Probing tasks: All scores denote accuracy. 47 | 48 | | Model | Depth| BigramShift| SubjNumber| Tense| CoordinationInversion| Length| ObjNumber| TopConstituents| OddManOut| WordContent| 49 | |:----------|------:|------------:|-----------:|------:|----------------------:|-------:|----------:|----------------:|----------:|------------:| 50 | | CBOW | 32.73| 49.65| 79.65| 79.46| 53.78| 75.69| 79.00| 72.26| 49.64| 89.11| 51 | | CMOW | 34.40| 72.44| 82.08| 80.32| 62.05| 82.93| 79.70| 74.25| 51.33| 65.15| 52 | | Hybrid | 35.38| 71.22| 81.45| 80.83| 59.17| 87.00| 79.37| 72.88| 50.53| 86.97| 53 | 54 | Supervised downstream tasks: For STS-Benchmark and Sick-Relatedness, the results denote Spearman correlation coefficient. For all others the score denotes accuracy. 55 | 56 | | Model | SNLI| SUBJ| CR| MR| MPQA| TREC| SICKEntailment| SST2| SST5| MRPC| STSBenchmark| SICKRelatedness| 57 | |:----------|------:|------:|------:|------:|------:|-----:|---------------:|------:|------:|------:|-------------:|----------------:| 58 | | CBOW | 67.76| 90.45| 79.76| 74.32| 87.23| 84.4| 79.58| 78.14| 41.72| 72.17| 0.619| 0.721| 59 | | CMOW | 64.77| 87.11| 74.60| 71.42| 87.55| 88.0| 76.90| 76.77| 40.18| 70.61| 0.576| 0.705| 60 | | Hybrid | 67.59| 90.26| 79.60| 74.10| 87.38| 89.2| 78.69| 77.87| 41.58| 71.94| 0.613| 0.718| 61 | 62 | Unsupervised downstream tasks: The score denotes Spearman correlation coefficient. 
63 | 64 | | Model | STS12| STS13| STS14| STS15| STS16| 65 | |:----------|------:|------:|------:|------:|------:| 66 | | CBOW | 0.458| 0.497| 0.556| 0.637| 0.630| 67 | | CMOW | 0.432| 0.334| 0.403| 0.471| 0.529| 68 | | Hybrid | 0.472| 0.476| 0.530| 0.621| 0.613| 69 | 70 | ### Train CBOW, CMOW, and CBOW-CMOW hybrid model 71 | 72 | To train a 784-dimensional CBOW model, run the following: 73 | 74 | ``` 75 | python train_cbow.py --w2m_type cbow --batch_size=1024 --outputdir= --optimizer adam,lr=0.0003 --max_words=30000 --n_epochs=1000 --n_negs=20 --validation_frequency=1000 --mode=random --num_samples_per_item=30 --patience 10 --downstream_eval full --outputmodelname mode w2m_type word_emb_dim --validation_fraction=0.0001 --context_size=5 --word_emb_dim 784 --temp_path --dataset_path= --num_workers 2 --output_file --num_docs 134442680 --stop_criterion train_loss 76 | ``` 77 | 78 | For CMOW: 79 | 80 | ``` 81 | python train_cbow.py --w2m_type cmow --batch_size=1024 --outputdir= --optimizer adam,lr=0.0003 --max_words=30000 --n_epochs=1000 --n_negs=20 --validation_frequency=1000 --mode=random --num_samples_per_item=30 --patience 10 --downstream_eval full --outputmodelname mode w2m_type word_emb_dim --validation_fraction=0.0001 --context_size=5 --word_emb_dim 784 --temp_path --dataset_path= --num_workers 2 --output_file --num_docs 134442680 --stop_criterion train_loss --initialization identity 82 | ``` 83 | 84 | And the CBOW-CMOW Hybrid: 85 | 86 | ``` 87 | python train_cbow.py --w2m_type hybrid --batch_size=1024 --outputdir= --optimizer adam,lr=0.0003 --max_words=30000 --n_epochs=1000 --n_negs=20 --validation_frequency=1000 --mode=random --num_samples_per_item=30 --patience 10 --downstream_eval full --outputmodelname mode w2m_type word_emb_dim --validation_fraction=0.0001 --context_size=5 --word_emb_dim 400 --temp_path --dataset_path= --num_workers 2 --output_file --num_docs 134442680 --stop_criterion train_loss --initialization identity 88 | ``` 89 | 90 | ### Evaluate components of hybrid model 91 | 92 | In the paper, we have shown that jointly training the individual CBOW/CMOW components emphasizes their individual strengths. 93 | To assess the performance of the CBOW component, restrict the final embedding representation to include only the first 94 | half of the representations from the HybridEncoder (--included_features 0 400 in an 800-dimensional Hybrid encoder), or restrict it to the second half (--included_features 400 800) to evaluate the CMOW component. 95 | E.g., to evaluate the CMOW component, run: 96 | 97 | ``` 98 | python evaluate_word2mat.py --encoders --word_vocab --included_features 400 800 --outputdir --outputmodelname hybrid_constituent --downstream_eval full 99 | ``` 100 | 101 | Here, the 'encoder' and 'word_vocab' files are saved in 'outputdir' after training the models. 102 | 103 | ## Files 104 | 105 | - `train_cbow.py` Main training executable. Type `python train_cbow.py --help` to get an overview of the training parameters. 106 | - `cbow.py` Contains the data preparation code as well as the neural architecture for CBOW, except for the encoder. 107 | - `word2mat.py` The code for the word2mat encoder. 108 | - `wrap_evaluation.py` Wrapper script for SentEval that automatically evaluates the encoder after training. 109 | - `evaluate_word2mat.py` Script for evaluating sub-components of the hybrid encoder with SentEval (see the usage sketch below). 110 | - `mutils.py` Helpers for saving results, hyperparameter optimization, and parsing optimizer parameters.
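If you want to embed sentences with a trained encoder outside of SentEval, the following is a rough sketch of the steps that `evaluate_word2mat.py` performs internally (loading the encoder, looking up word indices with `get_index_batch` from `cbow.py`, and running the forward pass). The file names below are placeholders; use the `.encoder` and vocabulary files written to your `--outputdir`, whose names depend on `--outputmodelname`.

```
import pickle
import torch
from cbow import get_index_batch

# placeholder paths -- substitute the files produced by your training run
encoder = torch.load("savedir/mymodel.encoder")
vocabulary = pickle.load(open("savedir/mymodel.vocab", "rb"))[0]

sentences = [["a", "small", "example"], ["another", "sentence"]]
indices, _ = get_index_batch(sentences, vocabulary)  # shape: (max_len, batch_size)

encoder.eval()
with torch.no_grad():
    embeddings = encoder(indices.cuda().t())  # shape: (batch_size, embedding_dim)
```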
111 | 112 | ## Reference 113 | 114 | Please cite our ICLR paper [[1]](https://openreview.net/forum?id=H1MgjoR9tQ) to reference our work or code. 115 | 116 | ### CBOW Is Not All You Need: Combining CBOW with the Compositional Matrix Space Model (ICLR 2019) 117 | 118 | [1] Mai, F., Galke, L & Scherp, A., [*CBOW Is Not All You Need: Combining CBOW with the Compositional Matrix Space Model*](https://openreview.net/forum?id=H1MgjoR9tQ) 119 | 120 | ``` 121 | @inproceedings{mai2018cbow, 122 | title={{CBOW} Is Not All You Need: Combining {CBOW} with the Compositional Matrix Space Model}, 123 | author={Florian Mai and Lukas Galke and Ansgar Scherp}, 124 | booktitle={International Conference on Learning Representations}, 125 | year={2019}, 126 | url={https://openreview.net/forum?id=H1MgjoR9tQ}, 127 | } 128 | ``` 129 | -------------------------------------------------------------------------------- /mutils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Some auxiliary functions for saving results, doing hyperparameter optimization, parsing optimizer parameters. 3 | """ 4 | 5 | import re 6 | import inspect 7 | from torch import optim 8 | from itertools import product 9 | from bayes_opt import BayesianOptimization 10 | from hyperopt import fmin, rand, hp 11 | from sklearn.gaussian_process.kernels import Matern 12 | import numpy as np 13 | import os, csv 14 | 15 | def _update_options(options, **parameters): 16 | for param_name, param_value in parameters.items(): 17 | print("In automatic optimization trying parameter:", param_name, "with value", param_value) 18 | 19 | try: 20 | setattr(options, param_name, param_value) 21 | except AttributeError: 22 | print("Can't find parameter ", param_name, "so we'll not use it.") 23 | continue 24 | 25 | return options 26 | 27 | 28 | def _make_space(options): 29 | 30 | space = {} 31 | inits = {} 32 | with open(options.optimization_spaces) as optimization_file: 33 | for line in optimization_file: 34 | 35 | # escape comments 36 | if line.startswith("#"): 37 | continue 38 | 39 | line = line.strip() 40 | info = line.split(",") 41 | param_name = info[0] 42 | 43 | if options.optimization == "random": 44 | left_bound, right_bound = float(info[1]), float(info[2]) 45 | 46 | param_type = info[3] 47 | 48 | try: 49 | param_type = getattr(hp, param_type) 50 | except AttributeError: 51 | print("hyperopt has no attribute", param_type) 52 | continue 53 | 54 | space[param_name] = param_type(param_name, left_bound, right_bound) 55 | elif options.optimization == "bayesian": 56 | left_bound, right_bound = float(info[1]), float(info[2]) 57 | 58 | init_values = list(map(float, info[3:])) 59 | num_init_vals = len(init_values) 60 | inits[param_name] = init_values 61 | space[param_name] = (left_bound, right_bound) 62 | 63 | elif options.optimization == "grid": 64 | 65 | param_type = info[1] 66 | def get_cast_func(some_string_type): 67 | cast_func = None 68 | if some_string_type == "int": 69 | cast_func = int 70 | elif some_string_type == "float": 71 | cast_func = float 72 | elif some_string_type == "string": 73 | cast_func = str 74 | elif some_string_type == "bool": 75 | cast_func = bool 76 | return cast_func 77 | 78 | cast_func = get_cast_func(param_type) 79 | if cast_func is None: 80 | if param_type.startswith("list"): 81 | # determine type in list 82 | list_type = get_cast_func(param_type.split("-")[1]) 83 | 84 | # assume they are seperated by semicolon 85 | def extract_items(list_string): 86 | return [list_type(x) for x in 
list_string.split(";")] 87 | 88 | cast_func = extract_items 89 | 90 | 91 | # all possible values 92 | space[param_name] = list(map(cast_func, info[2:])) 93 | 94 | if options.optimization == "bayesian": 95 | return space, inits, num_init_vals 96 | else: 97 | return space 98 | 99 | def _all_option_combinations(space): 100 | 101 | names = [name for name, _ in space.items()] 102 | values = [values for _, values in space.items()] 103 | 104 | val_combinations = product(*values) 105 | 106 | combinations = [] 107 | for combi in val_combinations: 108 | new_param_dict = {} 109 | for i, val in enumerate(combi): 110 | new_param_dict[names[i]] = val 111 | 112 | combinations.append(new_param_dict) 113 | 114 | return combinations 115 | 116 | def run_hyperparameter_optimization(options, run_exp): 117 | """ 118 | This function performs hyperparameter optimization using bayesian optimization, random search, or gridsearch. 119 | 120 | It takes an argparse object holding the parameters for configuring an experiments, and a function 121 | 'run_exp' that takes the argparse object, runs an experiments with the respective configuration, and 122 | returns a score from that configuration. 123 | It then uses the hyperparameter optimization method to adjust the parameters and run the new configuration. 124 | 125 | Parameters: 126 | ================ 127 | 128 | argparse : 129 | The argparse object holding the parameters. In particular, it must contain the following two parameters. 130 | 'optimization' : str, Specifies the optimization method. Either 'bayesian', 'random', or 'grid'. 131 | 'optimization_spaces' : str, Specifies the path to a file that denotes the parameters to do search over and 132 | their possible values (in case of grid search) or possible spaces. See file 'default_optimization_space' for 133 | details. 134 | 135 | run_exp : function 136 | A function that takes the argparse object as input and returns a float that is interpreted as the 137 | score of the configuration (higher is better). 138 | 139 | """ 140 | 141 | if options.optimization: 142 | 143 | def optimized_experiment(**parameters): 144 | 145 | current_options = _update_options(options, **parameters) 146 | result = run_exp(current_options) 147 | 148 | # return the f1 score of the previous experiment 149 | return result 150 | 151 | if options.optimization == "bayesian": 152 | 153 | gp_params = {"alpha": 1e-5, "kernel" : Matern(nu = 5 / 2)} 154 | space, init_vals, num_init_vals = _make_space(options) 155 | bayesian_optimizer = BayesianOptimization(optimized_experiment, space) 156 | bayesian_optimizer.explore(init_vals) 157 | bayesian_optimizer.maximize(n_iter=options.optimization_iterations - num_init_vals, 158 | acq = 'ei', 159 | **gp_params) 160 | 161 | elif options.optimization == "random": 162 | 163 | fmin(lambda parameters : optimized_experiment(**parameters), 164 | _make_space(options), 165 | algo=rand.suggest, 166 | max_evals=options.optimization_iterations, 167 | rstate = np.random.RandomState(1337)) 168 | 169 | elif options.optimization == "grid": 170 | # perform grid-search by running every possible parameter combination 171 | combinations = _all_option_combinations(_make_space(options)) 172 | for combi in combinations: 173 | optimized_experiment(**combi) 174 | 175 | else: 176 | raise Exception("No hyperparameter method specified!") 177 | 178 | def write_to_csv(score, opt): 179 | """ 180 | Writes the scores and configuration to csv file. 
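    If the file is still empty, a header row is written first, consisting of all option names from 'opt' followed by all score names. Each call then appends one row with the corresponding values, separated by ';'.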
181 | """ 182 | f = open(opt.output_file, 'a') 183 | if os.stat(opt.output_file).st_size == 0: 184 | for i, (key, _) in enumerate(opt.__dict__.items()): 185 | f.write(key + ";") 186 | for i, (key, _) in enumerate(score.items()): 187 | if i < len(score.items()) - 1: 188 | f.write(key + ";") 189 | else: 190 | f.write(key) 191 | f.write('\n') 192 | f.flush() 193 | f.close() 194 | 195 | f = open(opt.output_file, 'r') 196 | reader = csv.reader(f, delimiter=";") 197 | column_names = next(reader) 198 | f.close(); 199 | 200 | f = open(opt.output_file, 'a') 201 | for i, key in enumerate(column_names): 202 | if i < len(column_names) - 1: 203 | if key in opt.__dict__: 204 | f.write(str(opt.__dict__[key]) + ";") 205 | else: 206 | f.write(str(score[key]) + ";") 207 | else: 208 | if key in opt.__dict__: 209 | f.write(str(opt.__dict__[key])) 210 | else: 211 | f.write(str(score[key])) 212 | f.write('\n') 213 | f.flush() 214 | f.close() 215 | 216 | def get_optimizer(s): 217 | """ 218 | Parse optimizer parameters. 219 | Input should be of the form: 220 | - "sgd,lr=0.01" 221 | - "adagrad,lr=0.1,lr_decay=0.05" 222 | """ 223 | if "," in s: 224 | method = s[:s.find(',')] 225 | optim_params = {} 226 | for x in s[s.find(',') + 1:].split(','): 227 | split = x.split('=') 228 | assert len(split) == 2 229 | assert re.match("^[+-]?(\d+(\.\d*)?|\.\d+)$", split[1]) is not None 230 | optim_params[split[0]] = float(split[1]) 231 | else: 232 | method = s 233 | optim_params = {} 234 | 235 | if method == 'adadelta': 236 | optim_fn = optim.Adadelta 237 | elif method == 'adagrad': 238 | optim_fn = optim.Adagrad 239 | elif method == 'adam': 240 | optim_fn = optim.Adam 241 | elif method == 'sparseadam': 242 | optim_fn = optim.SparseAdam 243 | elif method == 'adamax': 244 | optim_fn = optim.Adamax 245 | elif method == 'asgd': 246 | optim_fn = optim.ASGD 247 | elif method == 'rmsprop': 248 | optim_fn = optim.RMSprop 249 | elif method == 'rprop': 250 | optim_fn = optim.Rprop 251 | elif method == 'sgd': 252 | optim_fn = optim.SGD 253 | assert 'lr' in optim_params 254 | else: 255 | raise Exception('Unknown optimization method: "%s"' % method) 256 | 257 | # check that we give good parameters to the optimizer 258 | expected_args = inspect.getargspec(optim_fn.__init__)[0] 259 | assert expected_args[:2] == ['self', 'params'] 260 | if not all(k in expected_args[2:] for k in optim_params.keys()): 261 | raise Exception('Unexpected parameters: expected "%s", got "%s"' % ( 262 | str(expected_args[2:]), str(optim_params.keys()))) 263 | 264 | return optim_fn, optim_params 265 | 266 | 267 | class dotdict(dict): 268 | """ dot.notation access to dictionary attributes """ 269 | __getattr__ = dict.get 270 | __setattr__ = dict.__setitem__ 271 | __delattr__ = dict.__delitem__ 272 | -------------------------------------------------------------------------------- /wrap_evaluation.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script is intended to provide an interface for standardized evaluation of encoders. For implementation, you need to provide the interface with a bunch of functions that you then need to pass to the run_and_evaluate function in the following 3 | fashion: 4 | 5 | >>> run_and_evaluate(run_experiment, get_params_parser, batcher, prepare). 6 | 7 | For detailed descriptions of what these methods have to provide, please refer to the functions documentation. 
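Note that the common command line arguments (output directory, model name, downstream evaluation mode, etc.) are added automatically by _add_common_arguments, so the parser returned by get_params_parser should only define the experiment-specific options.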
8 | """ 9 | 10 | import torch, os, sys, logging, pickle 11 | import numpy as np 12 | import pandas as pd 13 | 14 | from torch.autograd import Variable 15 | from mutils import run_hyperparameter_optimization, write_to_csv 16 | from warnings import warn 17 | 18 | # Set PATHs to SentEval 19 | PATH_SENTEVAL = '/data22/fmai/data/SentEval/SentEval/' 20 | PATH_TO_DATA = '/data22/fmai/data/SentEval/SentEval/data' 21 | assert os.path.exists(PATH_SENTEVAL) and os.path.exists(PATH_TO_DATA), "Set path to SentEval + data correctly!" 22 | 23 | # import senteval 24 | sys.path.insert(0, PATH_SENTEVAL) 25 | import senteval 26 | 27 | def _get_score_for_name(downstream_results, name): 28 | if name in ["CR", "CoordinationInversion", "SST2", "Length", "OddManOut", "Tense", "SUBJ", "MRPC", "ObjNumber", "SubjNumber", "Depth", "WordContent", "SST5", "SNLI", "MPQA", "BigramShift", "MR", "TREC", "TopConstituents", "SICKEntailment"]: 29 | return downstream_results[name]["acc"] 30 | elif name in ["STS12", "STS13", "STS14", "STS15", "STS16"]: 31 | all_results = downstream_results[name]["all"] 32 | spearman = all_results["spearman"]["mean"] 33 | pearson = all_results["pearson"]["mean"] 34 | return (spearman, pearson) 35 | elif name in ["STSBenchmark", "SICKRelatedness"]: 36 | spearman = downstream_results[name]["spearman"] 37 | pearson = downstream_results[name]["pearson"] 38 | return (spearman, pearson) 39 | else: 40 | warn("Can not extract score from downstream task " + name) 41 | return 0 42 | 43 | def _run_experiment_and_save(run_experiment, params, batcher, prepare): 44 | 45 | encoder, losses = run_experiment(params) 46 | 47 | # Save encoder 48 | outputmodelname = construct_model_name(params.outputmodelname, params) 49 | torch.save(encoder, os.path.join(params.outputdir, outputmodelname + '.encoder')) 50 | 51 | # write training and validation loss to csv file 52 | with open(os.path.join(params.outputdir, outputmodelname + "_losses.csv"), 'w') as loss_csv: 53 | loss_csv.write("train_loss,val_loss\n") 54 | for train_loss, val_loss in losses: 55 | loss_csv.write(",".join([str(train_loss), str(val_loss)]) + "\n") 56 | 57 | scores = {} 58 | # Compute scores on downstream tasks 59 | if params.downstream_eval: 60 | downstream_scores = _evaluate_downstream_and_probing_tasks(encoder, params, batcher, prepare) 61 | 62 | # from each downstream task, only select scores we care about 63 | to_be_saved_scores = {} 64 | for score_name in downstream_scores: 65 | to_be_saved_scores[score_name] = _get_score_for_name(downstream_scores, score_name) 66 | scores.update(to_be_saved_scores) 67 | 68 | # Compute word embedding score 69 | if params.word_embedding_eval: 70 | output_path = _save_embeddings_to_word2vec(encoder, outputmodelname, params) 71 | 72 | # Save results to csv 73 | if params.output_file: 74 | write_to_csv(scores, params) 75 | 76 | return scores 77 | 78 | def _load_embeddings_from_word2vec(fname): 79 | w = load_embedding(fname, format="word2vec", normalize=True, lower=True, clean_words=False) 80 | return w 81 | 82 | def _save_embeddings_to_word2vec(encoder, outputmodelname, params): 83 | 84 | # Load lookup table as numpy 85 | embeddings = encoder.lookup_table 86 | embeddings = embeddings.weight.data.cpu().numpy() 87 | 88 | # Load (inverse) vocabulary to match ids to words 89 | path_to_vocabulary = os.path.join(params.outputdir, outputmodelname + '.vocab') 90 | vocabulary = pickle.load(open(path_to_vocabulary, "rb" ))[0] 91 | inverse_vocab = {vocabulary[w] : w for w in vocabulary} 92 | 93 | # Open file and write values in 
word2vec format 94 | output_path = os.path.join(params.outputdir, outputmodelname + '.emb') 95 | f = open(output_path, 'w') 96 | print(embeddings.shape[0] - 1, embeddings.shape[1], file = f) 97 | for i in range(1, embeddings.shape[0]): # skip the padding token 98 | cur_word = inverse_vocab[i] 99 | f.write(" ".join([cur_word] + [str(embeddings[i, j]) for j in range(embeddings.shape[1])]) + "\n") 100 | 101 | f.close() 102 | 103 | return output_path 104 | 105 | def _evaluate_downstream_and_probing_tasks(encoder, params, batcher, prepare): 106 | # define senteval params 107 | eval_type = sys.argv[1] if len(sys.argv) > 1 else "" 108 | if params.downstream_eval == "full": 109 | 110 | ## for comparable evaluation (as in literature) 111 | params_senteval = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 10} 112 | params_senteval['classifier'] = {'nhid': params.nhid, 'optim': 'adam', 'batch_size': 64, 113 | 'tenacity': 5, 'epoch_size': 4} 114 | 115 | elif params.downstream_eval == "test": 116 | ## for testing purpose 117 | params_senteval = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5} 118 | params_senteval['classifier'] = {'nhid': params.nhid, 'optim': 'rmsprop', 'batch_size': 128, 119 | 'tenacity': 3, 'epoch_size': 2} 120 | else: 121 | sys.exit("Need to specify if you want to run it in full mode ('full', takes a long time," \ 122 | "but is comparable to literature) or test mode ('test', fast, but not as accurate and comparable).") 123 | 124 | # Pass encoder and command line parameters 125 | params_senteval['word2mat'] = encoder 126 | params_senteval['cmd_params'] = params 127 | 128 | # evaluate 129 | se = senteval.engine.SE(params_senteval, batcher, prepare) 130 | results = se.eval(params.downstream_tasks) 131 | 132 | return results 133 | 134 | def construct_model_name(names, params): 135 | """Constructs model name from all params in "name", unless name only holds a single argument.""" 136 | if len(names) == 1: 137 | return names[0] 138 | else: 139 | # construct model name from configuration 140 | name = "" 141 | params_dict = vars(params) 142 | 143 | for key in names: 144 | 145 | name += str(key) + ":" + str(params_dict[key]) + "-" 146 | 147 | return name 148 | 149 | 150 | def _add_common_arguments(parser): 151 | 152 | group = parser.add_argument_group("common", "Arguments needed for evaluation.") 153 | # paths 154 | parser.add_argument("--outputdir", type=str, default='savedir/', help="Output directory", required = True) 155 | parser.add_argument("--outputmodelname", type=str, nargs = "+", default=["mymodel"], help="If one argument is passed, the model is saved at the respective location. If multiple arguments are passed, these are interpreted of names of parameters from which the modelname is automatically constructed in a fashion.", required = True) 156 | parser.add_argument('--output_file', type=str, default=None, help= \ 157 | "Specify the file name to save the result in. 
Default: [None]") 158 | 159 | # hyperparameter opt 160 | parser.add_argument("--optimization", type=str, default=None) 161 | parser.add_argument("--optimization_spaces", type=str, default="default_optimization_spaces") 162 | parser.add_argument("--optimization_iterations", type=int, default=10) 163 | 164 | # reproducibility 165 | parser.add_argument("--seed", type=int, default=1234, help="seed") 166 | 167 | # evaluation 168 | parser.add_argument("--downstream_eval", type=str, help="Whether to perform 'full'" \ 169 | "downstream evaluation (slow), 'test' downstream evaluation (fast).", 170 | choices = ["test", "full"], required = True) 171 | parser.add_argument("--downstream_tasks", type=str, nargs = "+", 172 | default=['CR', 'MR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC', 'SNLI', 173 | 'SICKEntailment', 'SICKRelatedness', 'STSBenchmark', 174 | 'STS12', 'STS13', 'STS14', 'STS15', 'STS16', 175 | 'Length', 'WordContent', 'Depth', 'TopConstituents','BigramShift', 'Tense', 176 | 'SubjNumber', 'ObjNumber', 'OddManOut', 'CoordinationInversion'], 177 | help="Downstream tasks to evaluate on.", 178 | choices = ['CR', 'MR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC', 'SNLI', 179 | 'SICKEntailment', 'SICKRelatedness', 'STSBenchmark', 180 | 'STS12', 'STS13', 'STS14', 'STS15', 'STS16', 181 | 'Length', 'WordContent', 'Depth', 'TopConstituents','BigramShift', 'Tense', 182 | 'SubjNumber', 'ObjNumber', 'OddManOut', 'CoordinationInversion']) 183 | parser.add_argument("--word_embedding_eval", action="store_true", default=False, help="If specified," \ 184 | "evaluate the word embeddings and store the results.") 185 | parser.add_argument("--nhid", type=int, default=0, help="Specify the number of hidden units used at test time to train classifiers. If 0 is specified, no hidden layer is employed.") 186 | 187 | 188 | return parser 189 | 190 | def run_and_evaluate(run_experiment, get_params_parser, batcher, prepare): 191 | """ 192 | Trains and evaluates one or several configurations, given some functionality to train the encoder 193 | and generate embeddings at test time. In particular, 4 functions have to be provided: 194 | 195 | - get_params_parser: () -> argparse.Argumentparser 196 | Creates and returns an argument parser with arguments that will be used by the 197 | run_experiment parser. Note that it must not contain any of the arguments that are 198 | are later added by the run_and_evaluate function, and are shared by all implementations. Please refer to the _add_common_arguments function for an overview of common arguments. 199 | - outputdir is not optional and the dir has to exist 200 | 201 | - run_experiment: params -> encoder, losses 202 | The function that trains the encoder. As input, it gets the params dictionary that contains 203 | the arguments parsed from the command line. As output, it has to return the encoder object that will 204 | at test time be passed to the 'batcher' function. 205 | The encoder is required to have an attribute 'lookup_table' which is a torch.nn.Embedding. 206 | 'losses' is a list of successive 207 | (train_loss, val_loss) pairs that will be written to a csv file. If you do not want to track 208 | the loss, simply return an empty list. 209 | - params need to include an 'outputmodelname', else the params need to satisfy construct_model_name 210 | 211 | - batcher: (params, batch) -> nparray of shape (batch_size, embedding_size) 212 | A function as defined by SentEval (see https://github.com/facebookresearch/SentEval). 
213 | 'params' is a dictionary that contains the encoder as "params['word2mat']", as well as 214 | further parameters specified by the implementer and added via the 'prepare' function. 215 | - the params namespace from run_experiments is available as params['cmd_params'] 216 | 217 | - prepare: (senteval_params, samples) -> () 218 | A function as defined by SentEval (see https://github.com/facebookresearch/SentEval). 219 | In particular, this function can be used to add further arguments to the senteval_params 220 | dictionary for use in batcher. 221 | 'senteval_params' is a dictionary that in turn contains the 'params' objective parsed from 222 | the command line as the argument 'cmd_params'. 223 | """ 224 | parser = get_params_parser() 225 | parser = _add_common_arguments(parser) 226 | params = parser.parse_args() 227 | 228 | if params.optimization: 229 | def hyperparameter_optimization_func(opts): 230 | results = _run_experiment_and_save(run_experiment, opts, batcher, prepare) 231 | return results["CR"] 232 | run_hyperparameter_optimization(params, hyperparameter_optimization_func) 233 | else: 234 | _run_experiment_and_save(run_experiment, params, batcher, prepare) 235 | 236 | -------------------------------------------------------------------------------- /cbow.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file handles everything related to the CBOW task, i.e., it creates the training examples (CBOWDataset), and provides the neural architecture (except encoder) and loss computation (see CBOWNet). 3 | """ 4 | 5 | import os, pickle, math 6 | import numpy as np 7 | import torch 8 | from torch.utils.data import Dataset 9 | from torch.utils.data.sampler import Sampler 10 | from collections import Counter 11 | import nltk.data 12 | from nltk.tokenize import word_tokenize 13 | from random import shuffle 14 | import random 15 | 16 | import torch.nn as nn 17 | from torch import FloatTensor as FT 18 | from torch import ByteTensor as BT 19 | from torch.autograd import Variable 20 | 21 | sent_tokenizer = nltk.data.load('tokenizers/punkt/english.pickle') 22 | 23 | def recursive_file_list(path): 24 | """ 25 | Recursively aggregates all files at the given path, i.e., files in subfolders are also 26 | included. 27 | """ 28 | return [os.path.join(dp, f) for dp, dn, fn in os.walk(path) for f in fn] 29 | 30 | def tokenize(sent, sophisticated = False): 31 | """ 32 | Tokenizes the sentence. If 'sophisticated' is set to False, the 33 | tokenization is a simple split by the blank character. Otherwise the 34 | TreebankWordTokenizer provided by NLTK. 
35 | """ 36 | return sent.split() if not sophisticated else word_tokenize(sent) 37 | 38 | def sentenize(text): 39 | return sent_tokenizer.tokenize(text) 40 | 41 | def get_wordvec_batch(batch, word_vec): 42 | # sent in batch in decreasing order of lengths (bsize, max_len, word_dim) 43 | lengths = np.array([len(x) for x in batch]) 44 | max_len = np.max(lengths) 45 | embed = np.zeros((max_len, len(batch), 300)) 46 | 47 | for i in range(len(batch)): 48 | for j in range(len(batch[i])): 49 | embed[j, i, :] = word_vec[batch[i][j]] 50 | 51 | return torch.from_numpy(embed).float(), lengths 52 | 53 | def get_index_batch(batch, word_vec): 54 | 55 | # remove all words that are out of vocabulary 56 | clean_batch = [] 57 | for sen in batch: 58 | clean_batch.append([w for w in sen if w in word_vec]) 59 | batch = clean_batch 60 | 61 | lengths = np.array([len(x) for x in batch]) 62 | max_len = np.max(lengths) 63 | embed = np.zeros((max_len, len(batch))) 64 | 65 | for i in range(len(batch)): 66 | for j in range(len(batch[i])): 67 | embed[j, i] = word_vec[batch[i][j]] 68 | 69 | return torch.from_numpy(embed).long(), lengths 70 | 71 | 72 | def get_word_dict(sentences): 73 | # create vocab of words and also count occurences 74 | word_dict = {} 75 | for sent in sentences: 76 | for word in tokenize(sent): 77 | if word not in word_dict: 78 | word_dict[word] = 1 79 | else: 80 | word_dict[word] += 1 81 | 82 | return word_dict 83 | 84 | 85 | def get_wordembedding(word_dict, we_path): 86 | # create word_vec with glove vectors 87 | word_voc = {} 88 | with open(we_path) as f: 89 | 90 | # discard the information in first row 91 | _, emb_size = f.readline().split() 92 | 93 | i = 1 94 | word_embs = [] 95 | for line in f: 96 | line = line.strip('\n').split() 97 | word_end = len(line) - int(emb_size) 98 | word = " ".join(line[:word_end]) 99 | if word in word_dict: 100 | word_voc[word] = i 101 | word_embs.append(np.asarray(list(map(float, line[word_end:])))) 102 | i += 1 103 | print('Found {0}(/{1}) words with glove vectors'.format( 104 | len(word_voc), len(word_dict))) 105 | 106 | word_embs = np.vstack(word_embs) 107 | word_count = {w : word_dict[w] for w in word_voc} 108 | return word_voc, word_count, word_embs 109 | 110 | def get_index_vocab(word_dict, max_words): 111 | if max_words is not None: 112 | counter = Counter(word_dict) 113 | most_common_words = counter.most_common(max_words) 114 | reduced_word_dict = {} 115 | for w, cnt in most_common_words: 116 | reduced_word_dict[w] = cnt 117 | word_dict = reduced_word_dict 118 | print('Num words in corpus : {:,}'.format(np.sum([word_dict[w] for w in word_dict]))) 119 | 120 | # create word_vec with glove vectors 121 | word_vec = {} 122 | idx = 1 # reserve 0 for padding_idx 123 | for word in word_dict: 124 | word_vec[word] = idx 125 | idx += 1 126 | return word_vec, word_dict 127 | 128 | 129 | def build_vocab(sentences, pretrained_embeddings = None, max_words = None): 130 | word_dict = get_word_dict(sentences) 131 | if pretrained_embeddings: 132 | word_to_index, word_to_count, word_embeddings = get_wordembedding(word_dict, pretrained_embeddings) 133 | else: 134 | word_to_index, word_to_count = get_index_vocab(word_dict, max_words) # padding_idx = 0 135 | print('Vocab size : {0}'.format(len(word_to_index))) 136 | 137 | if pretrained_embeddings: 138 | return word_to_index, word_to_count, word_embeddings 139 | else: 140 | return word_to_index, word_to_count 141 | 142 | class CBOWDataset(Dataset): 143 | """ 144 | Considers each line of a file to be a text. 
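    For each text, up to 'num_samples_per_item' training windows of 2 * 'context_size' context words (plus one target word) are sampled on the fly, and the corpus is split into chunk files under 'temp_path' so that it never has to be held in memory at once.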
145 | Reads all files found at directory 'path' and corresponding subdirectories. 146 | """ 147 | def __init__(self, path, num_texts, context_size, num_samples_per_item, mode, precomputed_word_vocab, max_words, pretrained_embeddings, num_texts_per_chunk, precomputed_chunks_dir, temp_path): 148 | 149 | self.context_size = context_size 150 | self.num_samples_per_item = num_samples_per_item 151 | self.mode = mode 152 | self.num_texts_per_chunk = num_texts_per_chunk 153 | 154 | texts_generator = _generate_texts(path, num_texts) 155 | 156 | # load precomputed word vocabulary and counts 157 | if precomputed_word_vocab: 158 | word_vec = pickle.load(open(os.path.join(precomputed_word_vocab), "rb" )) 159 | else: 160 | 161 | word_vec = build_vocab(texts_generator, 162 | pretrained_embeddings = pretrained_embeddings, 163 | max_words = max_words) 164 | 165 | # create chunks 166 | self.num_texts = num_texts 167 | self.num_chunks = math.ceil(num_texts / (1.0*self.num_texts_per_chunk)) 168 | self._temp_path = temp_path 169 | if not os.path.exists(self._temp_path): 170 | os.makedirs(self._temp_path) 171 | 172 | if precomputed_chunks_dir is None: 173 | self._create_chunk_files(_generate_texts(path, num_texts)) 174 | self._check_chunk_files() 175 | else: 176 | self._temp_path = precomputed_chunks_dir 177 | print("use precomputed chunk files.") 178 | 179 | self._word_vec_count_tuple = word_vec 180 | self.word_vec, self.word_count = word_vec 181 | self.num_training_samples = self.num_texts 182 | 183 | # compute unigram distribution 184 | ## set frequency of padding token to 0 implicitly 185 | unigram_dist = np.zeros((len(self.word_vec) + 1)) 186 | for w in self.word_vec: 187 | unigram_dist[self.word_vec[w]] = self.word_count[w] 188 | 189 | self.unigram_dist = unigram_dist 190 | 191 | 192 | def _count_words_per_text(self): 193 | text_lengths = [0] * len(self.texts) 194 | 195 | for i, text in enumerate(self.texts): 196 | 197 | words = tokenize(text) 198 | words = [self.word_vec[w] for w in words if w in self.word_vec] 199 | text_lengths[i] = len(words) 200 | 201 | return text_lengths 202 | 203 | def _check_chunk_files(self): 204 | """Raises an exception if any of the chunks generated 205 | is empty. 
206 | """ 207 | for i in range(self.num_chunks): 208 | with open(self._get_chunk_file_name(i), "r") as f: 209 | lines = f.readlines() 210 | if(len(lines) == 0): 211 | raise Exception("Chunk ", i, " is empty\n") 212 | 213 | def _create_chunk_files(self, texts_generator): 214 | cur_chunk_number = 0 215 | cur_chunk_file = open(self._get_chunk_file_name(cur_chunk_number), "w") 216 | cur_idx = 0 217 | last_chunk_size = self.num_texts - (self.num_texts_per_chunk*(self.num_chunks-1)) 218 | for text in texts_generator: 219 | print(text, file=cur_chunk_file) 220 | if cur_idx == self.num_texts_per_chunk - 1 or (cur_idx == last_chunk_size-1 and 221 | cur_chunk_number == self.num_chunks-1): 222 | # start next chunk 223 | cur_chunk_file.close() 224 | cur_idx = 0 # index within the chunk 225 | cur_chunk_number += 1 226 | cur_chunk_file = open(self._get_chunk_file_name(cur_chunk_number), "w") 227 | else: 228 | cur_idx += 1 229 | cur_chunk_file.close() 230 | 231 | def _get_chunk_file_name(self, chunk_number): 232 | return os.path.join(self._temp_path, "chunk" + str(chunk_number)) 233 | 234 | def __len__(self): 235 | return self.num_texts 236 | 237 | def _load_text(self, idx): 238 | chunk_number = math.floor(idx / (1.0*self.num_texts_per_chunk)) 239 | idx_in_chunk = idx % self.num_texts_per_chunk 240 | with open(self._get_chunk_file_name(chunk_number), "r") as f: 241 | for i, line in enumerate(f): 242 | if i == idx_in_chunk: 243 | return line.strip() 244 | raise Exception("Text with idx: ", idx, " in chunk: ", chunk_number,\ 245 | " and idx_in_chunk: ", idx_in_chunk, " not found.") 246 | 247 | def _compute_idx_to_text_word_dict(self): 248 | idx_to_text_word_tuple = {} 249 | idx = 0 250 | for i, text in enumerate(self.texts): 251 | 252 | for j in range(self.text_lengths[i]): 253 | idx_to_text_word_tuple.update({idx : (i, j)}) 254 | idx += 1 255 | 256 | self.idx_to_text_word_tuple = idx_to_text_word_tuple 257 | 258 | def _create_window_samples(self, words): 259 | 260 | text_len = len(words) 261 | num_samples = min(text_len, self.num_samples_per_item) 262 | 263 | words = [0] * self.context_size + words + [0] * self.context_size 264 | 265 | training_sequences = np.zeros((num_samples, 2 * self.context_size)) 266 | missing_words = np.zeros((num_samples)) 267 | 268 | # randomly select mid_words to use 269 | mid_words = random.sample(range(text_len), num_samples) 270 | for i, j in enumerate(mid_words): 271 | 272 | middle_word = self.context_size + j 273 | 274 | # choose a word that is removed from the window 275 | if self.mode == 'random': 276 | rand_offset = random.randint(-self.context_size, self.context_size) 277 | missing_word = middle_word + rand_offset 278 | elif self.mode == 'cbow': 279 | missing_word = middle_word 280 | else: 281 | raise NotImplementedError("Unknown training mode " + self.mode) 282 | 283 | # zero is the padding word 284 | training_sequence = [middle_word + context_word for context_word in range(-self.context_size, self.context_size + 1) if middle_word + context_word != missing_word] 285 | training_sequence = [words[w] for w in training_sequence] 286 | training_sequences[i, :] = np.array(training_sequence) 287 | missing_word = words[missing_word] 288 | missing_words[i] = np.array(missing_word) 289 | return training_sequences, missing_words 290 | 291 | def __getitem__(self, idx): 292 | 293 | text = self._load_text(idx) 294 | words = tokenize(text) 295 | words = [self.word_vec[w] for w in words if w in self.word_vec] 296 | text_len = len(words) 297 | 298 | # TODO: is there a better way to 
handle empty texts? 299 | if text_len == 0: 300 | return None, None 301 | 302 | if self.mode in ['random', 'cbow']: 303 | return self._create_window_samples(words) 304 | else: 305 | raise NotImplementedError("Unknown mode " + str(self.mode)) 306 | 307 | ## collate function for cbow 308 | def collate_fn(self, l): 309 | l1, l2 = zip(*l) 310 | l1 = [x for x in l1 if x is not None] 311 | l2 = [x for x in l2 if x is not None] 312 | l1 = np.vstack(l1) 313 | l2 = np.concatenate(l2) 314 | return torch.from_numpy(l1).long(), torch.from_numpy(l2).long() 315 | 316 | def _load_texts(path, num_docs): 317 | texts = [] 318 | filename_list = recursive_file_list(path) 319 | 320 | for filename in filename_list: 321 | with open(os.path.realpath(filename), 'r') as f: 322 | 323 | # change encoding to utf8 to be consistent with other datasets 324 | #cur_text.decode("ISO-8859-1").encode("utf-8") 325 | for line in f: 326 | line = line.strip() 327 | texts.append(line) 328 | 329 | if num_docs is not None and len(texts) > num_docs: 330 | break 331 | 332 | return texts 333 | 334 | def _generate_texts(path, num_docs): 335 | filename_list = recursive_file_list(path) 336 | 337 | for filename in filename_list: 338 | with open(os.path.realpath(filename), "r") as f: 339 | 340 | # change encoding to utf8 to be consistent with other datasets 341 | # cur_text.decode("ISO-8859-1").encode("utf-8") 342 | for i, line in enumerate(f): 343 | line = line.strip() 344 | if num_docs is not None and i > num_docs - 1: 345 | break 346 | yield line 347 | 348 | class CBOWNet(nn.Module): 349 | def __init__(self, encoder, output_embedding_size, output_vocab_size, weights = None, n_negs = 20, padding_idx = 0): 350 | super(CBOWNet, self).__init__() 351 | 352 | self.encoder = encoder 353 | self.n_negs = n_negs 354 | self.weights = weights 355 | self.output_vocab_size = output_vocab_size 356 | self.output_embedding_size = output_embedding_size 357 | 358 | self.outputembeddings = nn.Embedding(output_vocab_size + 1, output_embedding_size, padding_idx=0) 359 | 360 | if self.weights is not None: 361 | wf = np.power(self.weights, 0.75) 362 | wf = wf / wf.sum() 363 | self.weights = FT(wf) 364 | 365 | def forward(self, input_s, missing_word): 366 | 367 | embedding = self.encoder(input_s) 368 | batch_size = embedding.size()[0] 369 | emb_size = embedding.size()[1] 370 | 371 | # draw negative samples 372 | if self.weights is not None: 373 | nwords = torch.multinomial(self.weights, batch_size * self.n_negs, replacement=True).view(batch_size, -1) 374 | else: 375 | nwords = FT(batch_size, self.n_negs).uniform_(0, self.vocab_size).long() 376 | nwords = Variable(torch.LongTensor(nwords), requires_grad=False).cuda() 377 | 378 | # lookup the embeddings of output words 379 | missing_word_vector = self.outputembeddings(missing_word) 380 | 381 | nvectors = self.outputembeddings(nwords).neg() 382 | 383 | # compute loss for correct word 384 | oloss = torch.bmm(missing_word_vector.view(batch_size, 1, emb_size), embedding.view(batch_size, emb_size, 1)) 385 | oloss = oloss.squeeze().sigmoid() 386 | 387 | ## add epsilon to prediction to avoid numerical instabilities 388 | oloss = self._add_epsilon(oloss) 389 | oloss = oloss.log() 390 | 391 | # compute loss for negative samples 392 | nloss = torch.bmm(nvectors, embedding.view(batch_size, -1, 1)).squeeze().sigmoid() 393 | 394 | ## add epsilon to prediction to avoid numerical instabilities 395 | nloss = self._add_epsilon(nloss) 396 | nloss = nloss.log() 397 | nloss = nloss.mean(1) 398 | 399 | # combine losses 400 | return 
-(oloss + nloss) 401 | 402 | def _add_epsilon(self, pred): 403 | return pred + 0.00001 404 | 405 | def encode(self, s1): 406 | emb = self.encoder(s1) 407 | return emb 408 | 409 | -------------------------------------------------------------------------------- /train_cbow.py: -------------------------------------------------------------------------------- 1 | """ 2 | Main script for training a Word2Mat model. 3 | """ 4 | 5 | import os 6 | import sys 7 | import time 8 | import argparse 9 | import pickle 10 | import random 11 | 12 | import numpy as np 13 | from random import shuffle 14 | 15 | import torch 16 | from torch.autograd import Variable 17 | import torch.nn as nn 18 | 19 | # CBOW data 20 | from torch.utils.data import DataLoader 21 | from cbow import CBOWNet, build_vocab, tokenize, CBOWDataset, recursive_file_list, get_index_batch 22 | 23 | # Encoder 24 | from mutils import get_optimizer, run_hyperparameter_optimization, write_to_csv 25 | from word2mat import get_cbow_cmow_hybrid_encoder, get_cbow_encoder, get_cmow_encoder 26 | from torch.utils.data.sampler import SubsetRandomSampler 27 | 28 | from wrap_evaluation import run_and_evaluate, construct_model_name 29 | 30 | import time 31 | 32 | def run_experiment(params): 33 | 34 | # print parameters passed, and all parameters 35 | print('\ntogrep : {0}\n'.format(sys.argv[1:])) 36 | print(params) 37 | 38 | """ 39 | SEED 40 | """ 41 | np.random.seed(params.seed) 42 | torch.manual_seed(params.seed) 43 | torch.cuda.manual_seed(params.seed) 44 | 45 | """ 46 | DATA 47 | """ 48 | dataset_path = params.dataset_path 49 | 50 | # build training and test corpus 51 | filename_list = recursive_file_list(dataset_path) 52 | print('Use the following files for training: ', filename_list) 53 | corpus = CBOWDataset(dataset_path, params.num_docs, params.context_size, 54 | params.num_samples_per_item, params.mode, 55 | params.precomputed_word_vocab, params.max_words, 56 | None, 1000, params.precomputed_chunks_dir, params.temp_path) 57 | corpus_len = len(corpus) 58 | 59 | ## split train and test 60 | inds = list(range(corpus_len)) 61 | shuffle(inds) 62 | 63 | num_val_samples = int(corpus_len * params.validation_fraction) 64 | train_indices = inds[:-num_val_samples] if num_val_samples > 0 else inds 65 | test_indices = inds[-num_val_samples:] if num_val_samples > 0 else [] 66 | 67 | cbow_train_loader = DataLoader(corpus, sampler = SubsetRandomSampler(train_indices), batch_size=params.batch_size, shuffle=False, num_workers = params.num_workers, pin_memory = True, collate_fn = corpus.collate_fn) 68 | cbow_test_loader = DataLoader(corpus, sampler = SubsetRandomSampler(test_indices), batch_size=params.batch_size, shuffle=False, num_workers = params.num_workers, pin_memory = True, collate_fn = corpus.collate_fn) 69 | 70 | ## extract some variables needed for training 71 | num_training_samples = corpus.num_training_samples 72 | word_vec = corpus.word_vec 73 | unigram_dist = corpus.unigram_dist 74 | word_vec_copy = corpus._word_vec_count_tuple 75 | 76 | print("Number of sentences used for training:", str(num_training_samples)) 77 | 78 | """ 79 | MODEL 80 | """ 81 | 82 | # build path where to store the encoder 83 | outputmodelname = construct_model_name(params.outputmodelname, params) 84 | 85 | # build encoder 86 | n_words = len(word_vec) 87 | if params.w2m_type == "cmow": 88 | encoder = get_cmow_encoder(n_words, padding_idx = 0, 89 | word_emb_dim = params.word_emb_dim, 90 | initialization_strategy = params.initialization) 91 | output_embedding_size = 
92 |     elif params.w2m_type == "cbow":
93 |         encoder = get_cbow_encoder(n_words, padding_idx = 0, word_emb_dim = params.word_emb_dim)
94 |         output_embedding_size = params.word_emb_dim
95 |     elif params.w2m_type == "hybrid":
96 |         encoder = get_cbow_cmow_hybrid_encoder(n_words, padding_idx = 0,
97 |                                                word_emb_dim = params.word_emb_dim,
98 |                                                initialization_strategy = params.initialization)
99 |         output_embedding_size = 2 * params.word_emb_dim
100 | 
101 |     # build cbow model
102 |     cbow_net = CBOWNet(encoder, output_embedding_size, n_words,
103 |                        weights = unigram_dist, n_negs = params.n_negs, padding_idx = 0)
104 |     if torch.cuda.device_count() > 1:
105 |         print("Using", torch.cuda.device_count(), "GPUs for training!")
106 |         # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
107 |         cbow_net = nn.DataParallel(cbow_net)
108 |         use_multiple_gpus = True
109 |     else:
110 |         use_multiple_gpus = False
111 | 
112 |     # optimizer
113 |     print([x.size() for x in cbow_net.parameters()])
114 |     optim_fn, optim_params = get_optimizer(params.optimizer)
115 |     optimizer = optim_fn(cbow_net.parameters(), **optim_params)
116 | 
117 |     # cuda by default
118 |     cbow_net.cuda()
119 | 
120 |     """
121 |     TRAIN
122 |     """
123 |     val_acc_best = -1e10
124 |     adam_stop = False
125 |     stop_training = False
126 |     lr = optim_params['lr'] if 'sgd' in params.optimizer else None
127 | 
128 |     # compute learning rate schedule
129 |     if params.linear_decay:
130 |         lr_shrinkage = (lr - params.minlr) / ((float(num_training_samples) / params.batch_size) * params.n_epochs)
131 | 
132 | 
133 |     def forward_pass(X_batch, tgt_batch, params, check_size = False):
134 | 
135 |         X_batch = Variable(X_batch).cuda()
136 |         tgt_batch = Variable(torch.LongTensor(tgt_batch)).cuda()
137 |         k = X_batch.size(0)  # actual batch size
138 | 
139 |         loss = cbow_net(X_batch, tgt_batch).mean()
140 |         return loss, k
141 | 
142 | 
143 |     def validate(data_loader):
144 |         cbow_net.eval()
145 | 
146 |         with torch.no_grad():
147 |             all_costs = []
148 |             for X_batch, tgt_batch in data_loader:
149 |                 loss, k = forward_pass(X_batch, tgt_batch, params)
150 |                 all_costs.append(loss.item())
151 | 
152 |         cbow_net.train()
153 |         return np.mean(all_costs)
154 | 
155 |     def trainepoch(epoch):
156 |         print('\nTRAINING : Epoch ' + str(epoch))
157 |         cbow_net.train()
158 |         all_costs = []
159 |         logs = []
160 |         words_count = 0
161 | 
162 |         last_time = time.time()
163 |         correct = 0.
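        # The two learning-rate schedules below (both only applied when an SGD optimizer is used):
        #   * step decay (default): at the start of every epoch after the first, lr is multiplied
        #     by params.decay, i.e. lr_epoch = initial_lr * params.decay ** (epoch - 1);
        #   * linear decay (--linear_decay): lr is reduced by lr_shrinkage after every batch,
        #     reaching roughly params.minlr at the end of the last epoch.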
164 | 
165 |         if not params.linear_decay:
166 |             optimizer.param_groups[0]['lr'] = optimizer.param_groups[0]['lr'] * params.decay if epoch>1\
167 |                 and 'sgd' in params.optimizer else optimizer.param_groups[0]['lr']
168 |         print('Learning rate : {0}'.format(optimizer.param_groups[0]['lr']))
169 | 
170 |         processed_training_samples = 0
171 |         start_time = time.time()
172 |         total_time = 0
173 |         total_batch_generation_time = 0
174 |         total_forward_time = 0
175 |         total_backward_time = 0
176 |         total_step_time = 0
177 |         last_processed_training_samples = 0
178 | 
179 |         nonlocal processed_batches, stop_training, no_improvement, min_val_loss, losses, min_loss_criterion
180 |         for i, (X_batch, tgt_batch) in enumerate(cbow_train_loader):
181 | 
182 |             batch_generation_time = (time.time() - start_time) * 1000000
183 | 
184 |             # forward pass
185 |             forward_start = time.time()
186 |             loss, k = forward_pass(X_batch, tgt_batch, params)
187 |             all_costs.append(loss.item())
188 |             forward_total = (time.time() - forward_start) * 1000000
189 | 
190 |             # backward
191 |             backward_start = time.time()
192 |             optimizer.zero_grad()
193 |             loss.backward()
194 | 
195 |             backward_total = (time.time() - backward_start) * 1000000
196 | 
197 |             # linear learning rate decay
198 |             if params.linear_decay:
199 |                 optimizer.param_groups[0]['lr'] = optimizer.param_groups[0]['lr'] - lr_shrinkage if \
200 |                     'sgd' in params.optimizer else optimizer.param_groups[0]['lr']
201 | 
202 |             # optimizer step
203 |             step_time = time.time()
204 |             optimizer.step()
205 |             total_step_time += (time.time() - step_time) * 1000000
206 | 
207 |             # log progress
208 |             processed_training_samples += params.batch_size
209 |             percentage_done = float(processed_training_samples) / num_training_samples
210 |             processed_batches += 1
211 |             if processed_batches == params.validation_frequency:
212 | 
213 |                 # compute validation loss and train loss
214 |                 val_loss = round(validate(cbow_test_loader), 5) if num_val_samples > 0 else float('inf')
215 |                 train_loss = round(np.mean(all_costs), 5)
216 | 
217 |                 # print current loss and processing speed
218 |                 logs.append('Epoch {3} - {4:.4} ; lr {2:.4} ; train-loss {0} ; val-loss {5} ; sentence/s {1}'.format(train_loss, int((processed_training_samples - last_processed_training_samples) / (time.time() - last_time)), optimizer.param_groups[0]['lr'], epoch, percentage_done, val_loss))
219 |                 if params.VERBOSE:
220 |                     print('\n\n\n')
221 |                 print(logs[-1])
222 |                 last_time = time.time()
223 |                 words_count = 0
224 |                 all_costs = []
225 |                 last_processed_training_samples = processed_training_samples
226 | 
227 |                 if params.VERBOSE:
228 |                     print("100 Batches took {} microseconds".format(total_time))
229 |                     print("get_batch: {} \nforward: {} \nbackward: {} \nstep: {}".format(total_batch_generation_time / total_time, total_forward_time / total_time, total_backward_time / total_time, total_step_time / total_time))
230 |                 total_time = 0
231 |                 total_batch_generation_time = 0
232 |                 total_forward_time = 0
233 |                 total_backward_time = 0
234 |                 total_step_time = 0
235 |                 processed_batches = 0
236 | 
237 |                 # save losses for logging later
238 |                 losses.append((train_loss, val_loss))
239 | 
240 |                 # early stopping?
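                # The criterion named by params.stop_criterion ('val_loss' or 'train_loss') is
                # looked up via eval() on the local variable of the same name; training stops once
                # it has not improved for more than params.patience validation steps, while the
                # checkpoint with the best validation loss so far is kept on disk.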
241 |                 if val_loss < min_val_loss:
242 |                     min_val_loss = val_loss
243 | 
244 |                     # save best model
245 |                     torch.save(cbow_net, os.path.join(params.outputdir, outputmodelname + '.cbow_net'))
246 | 
247 |                 if params.stop_criterion is not None:
248 |                     stop_crit_loss = eval(params.stop_criterion)
249 |                     if stop_crit_loss < min_loss_criterion:
250 |                         no_improvement = 0
251 |                         min_loss_criterion = stop_crit_loss
252 |                     else:
253 |                         no_improvement += 1
254 |                         if no_improvement > params.patience:
255 |                             stop_training = True
256 |                             print("No improvement in loss criterion", str(params.stop_criterion),
257 |                                   "for", str(no_improvement), "steps. Terminate training.")
258 |                             break
259 | 
260 |             now = time.time()
261 |             batch_time_micro = (now - start_time) * 1000000
262 | 
263 |             total_time = total_time + batch_time_micro
264 |             total_batch_generation_time += batch_generation_time
265 |             total_forward_time += forward_total
266 |             total_backward_time += backward_total
267 | 
268 |             start_time = now
269 | 
270 | 
271 |     """
272 |     Train model on CBOW objective
273 |     """
274 |     epoch = 1
275 | 
276 |     processed_batches = 0
277 |     min_val_loss = float('inf')
278 |     min_loss_criterion = float('inf')
279 |     no_improvement = 0
280 |     losses = []
281 |     while not stop_training and epoch <= params.n_epochs:
282 |         trainepoch(epoch)
283 |         epoch += 1
284 | 
285 |     # load the best model
286 |     if min_val_loss < float('inf'):
287 |         cbow_net = torch.load(os.path.join(params.outputdir, outputmodelname + '.cbow_net'))
288 |         print("Loading model with best validation loss.")
289 |     else:
290 |         # we use the current model;
291 |         print("No model with better validation loss has been saved.")
292 | 
293 |     # save word vocabulary and counts
294 |     pickle.dump(word_vec_copy, open( os.path.join(params.outputdir, outputmodelname + '.vocab'), "wb" ))
295 | 
296 |     if use_multiple_gpus:
297 |         cbow_net = cbow_net.module
298 |     return cbow_net.encoder, losses
299 | 
300 | def get_params_parser():
301 | 
302 |     parser = argparse.ArgumentParser(description='Training a word2mat model.')
303 | 
304 |     # paths
305 |     parser.add_argument('--precomputed_word_vocab', type=str, default=None, help= \
306 |         "Specify path from where to load a precomputed word vocabulary.")
307 |     parser.add_argument('--precomputed_chunks_dir', type=str, default=None, help= \
308 |         "Specify path from where to load the chunkified input text.")
309 |     parser.add_argument('--temp_path', type=str, required=True, help= \
310 |         "Specify path where to save the chunkified input text.")
311 | 
312 |     # training parameters
313 |     parser.add_argument("--n_epochs", type=int, default=20)
314 |     parser.add_argument("--batch_size", type=int, default=64)
315 |     parser.add_argument("--optimizer", type=str, default="sgd,lr=0.1", help="adam or sgd,lr=0.1")
316 |     parser.add_argument("--decay", type=float, default=0.99, help="lr decay")
317 |     parser.add_argument("--linear_decay", action="store_true", help="If set, the learning rate is shrunk linearly after each batch so as to approach minlr.")
318 |     parser.add_argument("--minlr", type=float, default=1e-5, help="minimum lr")
319 |     parser.add_argument("--validation_frequency", type=int, default=500, help="How many batches to process before evaluating on the validation set (500).")
320 |     parser.add_argument("--validation_fraction", type=float, default=0.0001, help="What fraction of the corpus to use for validation.\
321 |         Set to 0 to not use validation set based saving of intermediate models.")
322 |     parser.add_argument("--stop_criterion", type=str, default=None, help="Which loss to use as stopping criterion.", choices = ['val_loss', 'train_loss'])
323 |     parser.add_argument("--patience", type=int, default=3, help="How many validation steps without improvement to allow before terminating training.")
324 |     parser.add_argument("--VERBOSE", action="store_true", default=False, help="Whether to print additional info on speed of processing.")
325 |     parser.add_argument("--num_workers", type=int, default=10, help="How many worker threads to use for creating the samples from the dataset.")
326 | 
327 |     # Word2Mat specific
328 |     parser.add_argument("--w2m_type", type=str, default='cmow', choices=['cmow', 'cbow', 'hybrid'], help="Choose the encoder to use.")
329 |     parser.add_argument("--word_emb_dim", type=int, default=100, help="Dimensionality of word embeddings.")
330 |     parser.add_argument("--initialization", type=str, default='identity', help="Initialization strategy to use.", choices = ['one', 'identity', 'normalized', 'normal'])
331 | 
332 |     # dataset and vocab
333 |     parser.add_argument("--dataset_path", type=str, required=True, help="Path to a directory containing all files to use for training. " \
334 |         "One sentence per line in a file is assumed.")
335 |     parser.add_argument("--max_words", type=int, default=None, help="Only produce embeddings for the most common tokens.")
336 |     parser.add_argument("--num_docs", type=int, default=None, help="How many documents to consider from the source directory.")
337 | 
338 |     # CBOW specific
339 |     parser.add_argument("--context_size", type=int, default=5, help="Context window size for CBOW.")
340 |     parser.add_argument("--num_samples_per_item", type=int, default=1, help="Specify the number of samples to generate from each sentence (the higher, the faster training).")
341 |     parser.add_argument("--mode", type=str, help="Determines the mode of the prediction task, i.e., which word is to be removed from a given window of words. Options are 'cbow' (remove the middle word) and 'random' (remove a random word from the window).", default='random', choices = ['cbow', 'random'])
342 |     parser.add_argument("--n_negs", type=int, default=5, help="How many negative samples to use for training (the larger the dataset, the fewer are required; default 5).")
343 | 
344 |     return parser
345 | 
346 | def prepare(params_senteval, samples):
347 | 
348 |     params = params_senteval["cmd_params"]
349 |     outputmodelname = construct_model_name(params.outputmodelname, params)
350 | 
351 |     # Load vocabulary
352 |     vocabulary = pickle.load(open(os.path.join(params.outputdir, outputmodelname + '.vocab'), "rb" ))[0]
353 | 
354 |     params_senteval['vocabulary'] = vocabulary
355 |     params_senteval['inverse_vocab'] = {vocabulary[w] : w for w in vocabulary}
356 | 
357 | def _batcher_helper(params, batch):
358 |     sent, _ = get_index_batch(batch, params.vocabulary)
359 |     sent_cuda = Variable(sent.cuda())
360 |     sent_cuda = sent_cuda.t()
361 |     params.word2mat.eval()  # Deactivate drop-out and such
362 |     embeddings = params.word2mat.forward(sent_cuda).data.cpu().numpy()
363 | 
364 |     return embeddings
365 | 
366 | def batcher_cbow(params_senteval, batch):
367 | 
368 |     params = params_senteval["cmd_params"]
369 |     embeddings = _batcher_helper(params_senteval, batch)
370 |     return embeddings
371 | 
372 | if __name__ == "__main__":
373 | 
374 |     run_and_evaluate(run_experiment, get_params_parser, batcher_cbow, prepare)
375 | 
--------------------------------------------------------------------------------
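
Usage sketch (not part of the repository): the snippet below illustrates how the artifacts written by train_cbow.py, the '<outputmodelname>.cbow_net' checkpoint and the matching '.vocab' pickle, can be loaded to embed new sentences, mirroring the prepare/_batcher_helper logic above. The concrete paths and example sentences are placeholders, a CUDA-capable GPU is assumed (as in the training script), and single-GPU training is assumed; with multiple GPUs the checkpoint is an nn.DataParallel wrapper and the encoder sits under .module.encoder.

import pickle
import torch
from cbow import get_index_batch

MODEL_PATH = "output/model.cbow_net"   # placeholder: params.outputdir / outputmodelname + '.cbow_net'
VOCAB_PATH = "output/model.vocab"      # placeholder: the corresponding '.vocab' pickle

cbow_net = torch.load(MODEL_PATH)                      # full CBOWNet saved during training
encoder = cbow_net.encoder                             # only the sentence encoder is needed
vocabulary = pickle.load(open(VOCAB_PATH, "rb"))[0]    # word -> index mapping

sentences = [["the", "cat", "sat"], ["a", "short", "example"]]
sent, _ = get_index_batch(sentences, vocabulary)       # (seq_len, batch) index tensor
sent = sent.cuda().t()                                 # transpose to (batch, seq_len), as in _batcher_helper

encoder.eval()                                         # deactivate dropout and such
with torch.no_grad():
    embeddings = encoder(sent).cpu().numpy()           # one embedding per sentence
print(embeddings.shape)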