├── requirements.txt
├── data
│   └── extract_umbc.py
├── evaluate_word2mat.py
├── word2mat.py
├── README.md
├── mutils.py
├── wrap_evaluation.py
├── cbow.py
└── train_cbow.py
/requirements.txt:
--------------------------------------------------------------------------------
1 | networkx==1.11
2 | sklearn
3 | numpy
4 | scipy
5 | torch
6 | torchtext
7 | bayesian-optimization
8 | hyperopt
9 | nltk
10 | pandas
11 |
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/data/extract_umbc.py:
--------------------------------------------------------------------------------
1 | """
2 | This script turns the UMBC news webcorpus into a file where
3 | each line is one sentence from the corpus.
4 | In the resulting file, tokenization has already been performed, and the tokens can be recovered by splitting on the whitespace character.
5 |
6 | The script expects two parameters:
7 | arg[1] : Path to the folder with .txt files
8 | The extracted UMBC webcorpus consists of a set of .txt files with one paragraph per line (with blank
9 | lines in between).
10 |
11 | arg[2] : Output file path
12 | Path to the location where the output file should be stored (i.e., the file containing one sentence per line).
13 | """
14 |
15 | import sys, os
16 | from nltk.tokenize import sent_tokenize, word_tokenize
17 |
18 |
19 | # parse args
20 | input_path = sys.argv[1]
21 | output_file = sys.argv[2]
22 |
23 | with open(output_file, 'w') as output_f:
24 |
25 |
26 | file_names = os.listdir(input_path)
27 | file_names = [f for f in file_names if f.endswith(".txt")]
28 |
29 | for f in file_names:
30 | input_file_path = os.path.join(input_path, f)
31 |
32 | # iterate over the plain text data from the UMBC corpus
33 | with open(input_file_path, 'r', encoding = "utf-8") as input_file:
34 | for line in input_file:
35 |
36 | # skip empty line
37 | if not line.strip():
38 | continue
39 | else:
40 | line = line.strip()
41 |
42 | # get all sentences in paragraph
43 | sentences = sent_tokenize(line)
44 |
45 | # each line contains a sentence
46 | for s in sentences:
47 | words = word_tokenize(s)
48 | sample = " ".join(words)
49 | print(sample, file = output_f)
50 |
--------------------------------------------------------------------------------
/evaluate_word2mat.py:
--------------------------------------------------------------------------------
1 | """
2 | This script is for assessing the performance of models (or their subparts) AFTER training.
3 | """
4 |
5 | from __future__ import absolute_import, division, unicode_literals
6 |
7 | import sys, os, time
8 | import torch
9 | import logging
10 | import pickle
11 | import numpy as np
12 | import pandas as pd
13 | import argparse
14 |
15 | from wrap_evaluation import run_and_evaluate
16 |
17 | from data import get_index_batch
18 |
19 | from torch.autograd import Variable
20 |
21 | # Set PATHs
22 | PATH_SENTEVAL = '/data22/fmai/data/SentEval/SentEval/'
23 | PATH_TO_DATA = '/data22/fmai/data/SentEval/SentEval/data'
24 |
25 | assert os.path.exists(PATH_SENTEVAL) and os.path.exists(PATH_TO_DATA), "Set path to SentEval + data correctly!"
26 |
27 | # import senteval
28 | sys.path.insert(0, PATH_SENTEVAL)
29 | import senteval
30 |
31 | # Set up logger
32 | logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG)
33 |
34 | total_time_encoding = 0.
35 | total_samples_encoded = 0
36 |
37 | if __name__ == "__main__":
38 |
39 | def prepare(params_senteval, samples):
40 |
41 | params = params_senteval["cmd_params"]
42 |
43 | # Load vocabulary
44 | vocabulary = pickle.load(open(params.word_vocab, "rb" ))[0]
45 |
46 | params_senteval['vocabulary'] = vocabulary
47 | params_senteval['inverse_vocab'] = {vocabulary[w] : w for w in vocabulary}
48 | params_senteval['encoders'] = [torch.load(p) for p in params.encoders]
49 |
50 | def _batcher_helper(encoder, vocabulary, batch):
51 | sent, _ = get_index_batch(batch, vocabulary)
52 | sent_cuda = Variable(sent.cuda())
53 | sent_cuda = sent_cuda.t()
54 | encoder.eval() # Deactivate drop-out and such
55 | embeddings = encoder.forward(sent_cuda).data.cpu().numpy()
56 |
57 | return embeddings
58 |
59 | def get_params_parser():
60 | parser = argparse.ArgumentParser(description='Evaluates the performance of a given encoder. Use --included_features to\
61 | evaluate subparts of the model.')
62 |
63 | # paths
64 | parser.add_argument('--word_vocab', type=str, default=None, help= \
65 | "Specify path where to load precomputed word.", required = True)
66 | parser.add_argument('--encoders', type=str, nargs='+', default=None, help= \
67 | "Specify path to load encoder models from.", required = True)
68 | parser.add_argument('--aggregation', type=str, default="concat", help= \
69 | "Specify operation to use for aggregating embeddings from multiple encoders.", required = False,
70 | choices = ['concat', 'add'])
71 | parser.add_argument('--gpu_device', type=int, default=0, help= \
72 | "You need to specify the id of the gpu that you used for training to avoid errors.", required = False)
73 | parser.add_argument('--add_start_end_token', action = "store_true", default=False, help= \
74 | "If activated, the start and end tokens are added to every sample. Used e.g. for NLI trained encoder.", required = False)
75 | parser.add_argument('--included_features', type = int, nargs = '+', default=None, help= \
76 | "If specified, expects two integers a and b, which denote the range (a is inclusive, b is exclusive) of indices to use from\
77 | the embedding. E.g., if '--included_features 0 300' is specified, the embedding that is evaluated consists only of the first\
78 |                         300 dimensions of the actual embedding, i.e., embeddings[:, a:b].", required = False)
79 | return parser
80 |
81 | def batcher(params_senteval, batch):
82 |
83 | start_time = time.time()
84 | params = params_senteval["cmd_params"]
85 | if params.add_start_end_token:
86 | batch = [[''] + s + [''] for s in batch]
87 |
88 | embeddings_list = [_batcher_helper(enc, params_senteval['vocabulary'], batch) for enc in params_senteval['encoders']]
89 |
90 | if params.aggregation == "add":
91 | embeddings = sum(embeddings_list)
92 |
93 | elif params.aggregation == "concat":
94 | embeddings = np.hstack(embeddings_list)
95 |
96 | global total_time_encoding
97 | global total_samples_encoded
98 | total_time_encoding += time.time() - start_time
99 | total_samples_encoded += len(batch)
100 |
101 | if params.included_features:
102 | a = params.included_features[0]
103 | b = params.included_features[1]
104 | embeddings = embeddings[:, a:b]
105 |
106 | return embeddings
107 |
108 | def _load_encoder_and_eval(params):
109 | encoder_for_wordemb_eval = torch.load(params.encoders[0])
110 | return (encoder_for_wordemb_eval, [])
111 | run_and_evaluate(_load_encoder_and_eval, get_params_parser, batcher, prepare)
112 |
113 | print("Encoding speed: {}/s".format(total_samples_encoded / total_time_encoding))
114 |
115 |
--------------------------------------------------------------------------------
/word2mat.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import time, sys, random
3 |
4 | import torch
5 | from torch.autograd import Variable
6 | import torch.nn as nn
7 | from torch.nn import functional as F
8 | import math
9 | import time
10 |
11 | from torch import FloatTensor as FT
12 | from torch import ByteTensor as BT
13 |
14 | TINY = 1e-11
15 |
16 | class Word2MatEncoder(nn.Module):
17 |
18 | def __init__(self, n_words, word_emb_dim = 784, padding_idx = 0, w2m_type = "cbow", initialization_strategy = "identity"):
19 | """
20 |         Encoder that embeds each word as a vector of size 'word_emb_dim', reshapes it into a square matrix, and aggregates the word matrices of a sentence either by summation ("cbow") or by matrix multiplication ("cmow") into a single (flattened) sentence matrix.
21 | """
22 | super(Word2MatEncoder, self).__init__()
23 | self.word_emb_dim = word_emb_dim
24 | self.n_words = n_words
25 | self.w2m_type = w2m_type
26 | self.initialization_strategy = initialization_strategy
27 |
28 | # check that the word embedding size is a square
29 | assert word_emb_dim == int(math.sqrt(word_emb_dim)) ** 2
30 |
31 | # set up word embedding table
32 | self.lookup_table = nn.Embedding(self.n_words + 1,
33 | self.word_emb_dim,
34 | padding_idx=padding_idx,
35 | sparse = False)
36 |
37 | # type of aggregation to use to combine two word matrices
38 | self.w2m_type = w2m_type
39 | if self.w2m_type not in ["cbow", "cmow"]:
40 |             raise NotImplementedError("Operator " + self.w2m_type + " is not yet implemented.")
41 |
42 | # set initial weights of word embeddings depending on the initialization strategy
43 | ## set weights of padding symbol such that it is the neutral element with respect to the operation
44 | if self.w2m_type == "cmow":
45 | neutral_element = np.reshape(np.eye(int(np.sqrt(self.word_emb_dim)), dtype=np.float32), (1, -1))
46 | neutral_element = torch.from_numpy(neutral_element)
47 | elif self.w2m_type == "cbow":
48 |             neutral_element = torch.from_numpy(np.zeros((1, self.word_emb_dim), dtype=np.float32))
49 |
50 | ## set weights of rest others depending on the initialization strategy
51 | if self.w2m_type == "cmow":
52 | if self.initialization_strategy == "identity":
53 | init_weights = self._init_random_identity()
54 |
55 | elif self.initialization_strategy == "normalized":
56 | ### normalized initialization by (Glorot and Bengio, 2010)
57 | init_weights = torch.from_numpy(np.random.uniform(size = (self.n_words,
58 | self.word_emb_dim),
59 | low = -np.sqrt(6 / (2*self.word_emb_dim)),
60 | high = +np.sqrt(6 / (2*self.word_emb_dim))
61 | ).astype(np.float32)
62 | )
63 | elif self.initialization_strategy == "normal":
64 | ### normalized with N(0,0.1), which failed in study by Yessenalina
65 | init_weights = torch.from_numpy(np.random.normal(size = (self.n_words,
66 | self.word_emb_dim),
67 | loc = 0.0,
68 | scale = 0.1
69 | ).astype(np.float32)
70 | )
71 | else:
72 | raise NotImplementedError("Unknown initialization strategy " + self.initialization_strategy)
73 |
74 | elif self.w2m_type == "cbow":
75 | init_weights = self._init_normal()
76 |
77 |
78 | ## concatenate and set weights in the lookup table
79 | weights = torch.cat([neutral_element,
80 | init_weights],
81 | dim=0)
82 | self.lookup_table.weight = nn.Parameter(weights)
83 |
84 | def forward(self, sent):
85 |
86 | sent = self.lookup_table(sent)
87 | seq_length = sent.size()[1]
88 | matrix_dim = self._matrix_dim()
89 |
90 | # reshape vectors to matrices
91 | word_matrices = sent.view(-1, seq_length, matrix_dim, matrix_dim)
92 |
93 | # aggregate matrices
94 | if self.w2m_type == "cmow":
95 | cur_emb = self._continual_multiplication(word_matrices)
96 | elif self.w2m_type == "cbow":
97 | cur_emb = torch.sum(word_matrices, 1)
98 |
99 | # flatten final matrix
100 | emb = self._flatten_matrix(cur_emb)
101 |
102 | return emb
103 |
104 | def _continual_multiplication(self, word_matrices):
105 | cur_emb = word_matrices[:, 0, :]
106 | for i in range(1, word_matrices.size()[1]):
107 | cur_emb = torch.bmm(cur_emb, word_matrices[:, i, :])
108 | return cur_emb
109 |
110 | def _flatten_matrix(self, m):
111 | return m.view(-1, self.word_emb_dim)
112 |
113 | def _unflatten_matrix(self, m):
114 | return m.view(-1, self._matrix_dim(), self._matrix_dim())
115 |
116 | def _matrix_dim(self):
117 | return int(np.sqrt(self.word_emb_dim))
118 |
119 | def _init_random_identity(self):
120 | """Random normal initialization around 0., but add 1. at the diagonal"""
121 | init_weights = np.random.normal(size = (self.n_words, self.word_emb_dim),
122 | loc = 0.,
123 | scale = 0.1
124 | ).astype(np.float32)
125 | for i in range(self.n_words):
126 | init_weights[i, :] += np.reshape(np.eye(int(np.sqrt(self.word_emb_dim)), dtype=np.float32), (-1,))
127 | init_weights = torch.from_numpy(init_weights)
128 | return init_weights
129 |
130 | def _init_normal(self):
131 | ### normal initialization around 0.
132 | init_weights = torch.from_numpy(np.random.normal(size = (self.n_words,
133 | self.word_emb_dim),
134 | loc = 0.0,
135 | scale = 0.1
136 | ).astype(np.float32)
137 | )
138 | return init_weights
139 |
140 |
141 | class HybridEncoder(nn.Module):
142 | def __init__(self, cbow_encoder, cmow_encoder):
143 | super(HybridEncoder, self).__init__()
144 | self.cbow_encoder = cbow_encoder
145 | self.cmow_encoder = cmow_encoder
146 |
147 | def forward(self, sent_tuple):
148 | return torch.cat([self.cbow_encoder(sent_tuple), self.cmow_encoder(sent_tuple)], dim = 1)
149 |
150 |
151 | def get_cmow_encoder(n_words, padding_idx = 0, word_emb_dim = 784, initialization_strategy = "identity"):
152 | encoder = Word2MatEncoder(n_words, word_emb_dim = word_emb_dim,
153 | padding_idx = padding_idx, w2m_type = "cmow",
154 | initialization_strategy = initialization_strategy)
155 | return encoder
156 |
157 | def get_cbow_encoder(n_words, padding_idx = 0, word_emb_dim = 784):
158 | encoder = Word2MatEncoder(n_words, word_emb_dim = word_emb_dim,
159 | padding_idx = padding_idx, w2m_type = "cbow")
160 | return encoder
161 |
162 | def get_cbow_cmow_hybrid_encoder(n_words, padding_idx = 0, word_emb_dim = 400, initialization_strategy = "identity"):
163 | cbow_encoder = get_cbow_encoder(n_words, padding_idx = padding_idx, word_emb_dim = word_emb_dim)
164 |     cmow_encoder = get_cmow_encoder(n_words, padding_idx = padding_idx,
165 | word_emb_dim = word_emb_dim,
166 | initialization_strategy = initialization_strategy)
167 |
168 | encoder = HybridEncoder(cbow_encoder, cmow_encoder)
169 | return encoder
170 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # word2mat
2 |
3 | *Word2Mat* is a framework that learns *sentence embeddings* in the CBOW style of word2vec, but with words and sentences represented as matrices: CBOW sums the word vectors of a sentence, CMOW reshapes each word vector into a square matrix and multiplies the word matrices in order, and the hybrid model concatenates both embeddings.
4 | Details of the method and results can be found in our [ICLR paper](https://openreview.net/forum?id=H1MgjoR9tQ).
5 |
6 | ## Dependencies
7 |
8 | - Python3
9 | - PyTorch >= 0.4 with CUDA support
10 | - NLTK >= 3
11 |
12 | ## Setup python3 environment
13 |
14 | Please install the python3 dependencies in your environment:
15 |
16 | ```
17 | virtualenv -p python3 venv && source venv/bin/activate
18 | pip install -r requirements.txt
19 | python3 -c "import nltk; nltk.download('punkt')"
20 | ```
21 |
22 | ## Download training data
23 |
24 | To reproduce the results from our paper, whose models were trained on the UMBC corpus,
25 | download the [UMBC corpus](https://ebiquity.umbc.edu/resource/html/id/351), extract the tar.gz file,
26 | and run the extract_umbc.py script as follows, where the second argument is the path of the output file to be created:
27 |
28 | ```
29 | python data/extract_umbc.py umbc_corpus/webbase_all <output_file>
30 | ```
31 |
32 | This stores the sentences from the UMBC corpus in a format that is usable by our code: Each line in the resulting file contains a single sentence,
33 | whose (already pre-processed) tokens are separated by a whitespace character.
34 |
35 | ## Running the experiments
36 |
37 | Note: In further experiments, we observed that terminating training based on the validation loss produces unreliable results because of the
38 | relatively high variance of the validation loss. Hence, we recommend using the training loss as the stopping criterion, which is more stable.
39 |
40 | The results below were obtained with this stopping criterion and therefore differ slightly from the results reported in the ICLR paper.
41 | However, the conclusions remain the same: CMOW is much better than CBOW at capturing linguistic properties, with the exception of WordContent.
42 | Consequently, CBOW is superior on almost all downstream tasks except TREC.
43 | The hybrid model retains the capabilities of both components and is therefore very close to, or better than, the better of CBOW and CMOW
44 | on all tasks.
45 |
46 | Probing tasks: All scores denote accuracy.
47 |
48 | | Model | Depth| BigramShift| SubjNumber| Tense| CoordinationInversion| Length| ObjNumber| TopConstituents| OddManOut| WordContent|
49 | |:----------|------:|------------:|-----------:|------:|----------------------:|-------:|----------:|----------------:|----------:|------------:|
50 | | CBOW | 32.73| 49.65| 79.65| 79.46| 53.78| 75.69| 79.00| 72.26| 49.64| 89.11|
51 | | CMOW | 34.40| 72.44| 82.08| 80.32| 62.05| 82.93| 79.70| 74.25| 51.33| 65.15|
52 | | Hybrid | 35.38| 71.22| 81.45| 80.83| 59.17| 87.00| 79.37| 72.88| 50.53| 86.97|
53 |
54 | Supervised downstream tasks: For STSBenchmark and SICKRelatedness, the scores denote the Spearman correlation coefficient. For all others, the scores denote accuracy.
55 |
56 | | Model | SNLI| SUBJ| CR| MR| MPQA| TREC| SICKEntailment| SST2| SST5| MRPC| STSBenchmark| SICKRelatedness|
57 | |:----------|------:|------:|------:|------:|------:|-----:|---------------:|------:|------:|------:|-------------:|----------------:|
58 | | CBOW | 67.76| 90.45| 79.76| 74.32| 87.23| 84.4| 79.58| 78.14| 41.72| 72.17| 0.619| 0.721|
59 | | CMOW | 64.77| 87.11| 74.60| 71.42| 87.55| 88.0| 76.90| 76.77| 40.18| 70.61| 0.576| 0.705|
60 | | Hybrid | 67.59| 90.26| 79.60| 74.10| 87.38| 89.2| 78.69| 77.87| 41.58| 71.94| 0.613| 0.718|
61 |
62 | Unsupervised downstream tasks: The scores denote the Spearman correlation coefficient.
63 |
64 | | Model | STS12| STS13| STS14| STS15| STS16|
65 | |:----------|------:|------:|------:|------:|------:|
66 | | CBOW | 0.458| 0.497| 0.556| 0.637| 0.630|
67 | | CMOW | 0.432| 0.334| 0.403| 0.471| 0.529|
68 | | Hybrid | 0.472| 0.476| 0.530| 0.621| 0.613|
69 |
70 | ### Train CBOW, CMOW, and CBOW-CMOW hybrid model
71 |
72 | To train a 784-dimensional CBOW model, run the following:
73 |
74 | ```
75 | python train_cbow.py --w2m_type cbow --batch_size=1024 --outputdir= --optimizer adam,lr=0.0003 --max_words=30000 --n_epochs=1000 --n_negs=20 --validation_frequency=1000 --mode=random --num_samples_per_item=30 --patience 10 --downstream_eval full --outputmodelname mode w2m_type word_emb_dim --validation_fraction=0.0001 --context_size=5 --word_emb_dim 784 --temp_path --dataset_path= --num_workers 2 --output_file --num_docs 134442680 --stop_criterion train_loss
76 | ```
77 |
78 | For CMOW:
79 |
80 | ```
81 | python train_cbow.py --w2m_type cmow --batch_size=1024 --outputdir= --optimizer adam,lr=0.0003 --max_words=30000 --n_epochs=1000 --n_negs=20 --validation_frequency=1000 --mode=random --num_samples_per_item=30 --patience 10 --downstream_eval full --outputmodelname mode w2m_type word_emb_dim --validation_fraction=0.0001 --context_size=5 --word_emb_dim 784 --temp_path --dataset_path= --num_workers 2 --output_file --num_docs 134442680 --stop_criterion train_loss --initialization identity
82 | ```
83 |
84 | And the CBOW-CMOW hybrid, which concatenates a 400-dimensional CBOW encoder and a 400-dimensional CMOW encoder into an 800-dimensional sentence embedding:
85 |
86 | ```
87 | python train_cbow.py --w2m_type hybrid --batch_size=1024 --outputdir= --optimizer adam,lr=0.0003 --max_words=30000 --n_epochs=1000 --n_negs=20 --validation_frequency=1000 --mode=random --num_samples_per_item=30 --patience 10 --downstream_eval full --outputmodelname mode w2m_type word_emb_dim --validation_fraction=0.0001 --context_size=5 --word_emb_dim 400 --temp_path --dataset_path= --num_workers 2 --output_file --num_docs 134442680 --stop_criterion train_loss --initialization identity
88 | ```
89 |
90 | ### Evaluate components of hybrid model
91 |
92 | In the paper, we have shown that jointly training the individual CBOW/CMOW components emphasizes their individual strengths.
93 | To assess the performance of the CBOW component, restrict the final embedding to the first
94 | half of the representation produced by the HybridEncoder (`--included_features 0 400` in an 800-dimensional hybrid encoder), or restrict it to the second half (`--included_features 400 800`) to evaluate the CMOW component.
95 | For example, to evaluate the CMOW component, run:
96 |
97 | ```
98 | python evaluate_word2mat.py --encoders --word_vocab --included_features 400 800 --outputdir --outputmodelname hybrid_constituent --downstream_eval full
99 | ```
100 |
101 | Here, the values passed to `--encoders` and `--word_vocab` are the encoder and vocabulary files that are saved in `--outputdir` after training the models.
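
For reference, `--included_features a b` simply slices the embedding matrix inside the batcher of `evaluate_word2mat.py` (it keeps `embeddings[:, a:b]`). A small illustration, with random values standing in for real embeddings:

```
import numpy as np

hybrid_embeddings = np.random.rand(32, 800)   # a batch of 800-dimensional hybrid embeddings
cbow_part = hybrid_embeddings[:, 0:400]       # what --included_features 0 400 evaluates
cmow_part = hybrid_embeddings[:, 400:800]     # what --included_features 400 800 evaluates
```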
102 |
103 | ## Files
104 |
105 | - `train_cbow.py` Main training executable. Run `python train_cbow.py --help` to get an overview of the training parameters.
106 | - `cbow.py` Contains the data preparation code as well as the neural architecture for CBOW, except for the encoder.
107 | - `word2mat.py` The code for the word2mat encoders (see the usage sketch below).
108 | - `wrap_evaluation.py` Wrapper around SentEval to automatically evaluate an encoder after training.
109 | - `evaluate_word2mat.py` Script for evaluating sub-components of the hybrid encoder with SentEval.
110 | - `mutils.py` Helpers for saving results, hyperparameter optimization, and optimizer parsing.
111 |
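A minimal sketch of how the encoders in `word2mat.py` can be used directly (freshly initialized, with assumed toy inputs; trained encoders are instead loaded via `torch.load`, as in `evaluate_word2mat.py`):

```
import torch
from word2mat import get_cmow_encoder

# vocabulary of 30,000 words; 784 = 28 x 28, since CMOW reshapes each word vector into a square matrix
encoder = get_cmow_encoder(n_words=30000, padding_idx=0, word_emb_dim=784)

# a batch of two "sentences" given as word indices (index 0 is the padding symbol)
batch = torch.LongTensor([[5, 42, 7, 0], [13, 8, 21, 9]])

embeddings = encoder(batch)   # shape (2, 784): one flattened 28x28 matrix per sentence
print(embeddings.shape)
```
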
112 | ## Reference
113 |
114 | Please cite our ICLR paper [[1]](https://openreview.net/forum?id=H1MgjoR9tQ) to reference our work or code.
115 |
116 | ### CBOW Is Not All You Need: Combining CBOW with the Compositional Matrix Space Model (ICLR 2019)
117 |
118 | [1] Mai, F., Galke, L., & Scherp, A., [*CBOW Is Not All You Need: Combining CBOW with the Compositional Matrix Space Model*](https://openreview.net/forum?id=H1MgjoR9tQ), ICLR 2019.
119 |
120 | ```
121 | @inproceedings{mai2018cbow,
122 | title={{CBOW} Is Not All You Need: Combining {CBOW} with the Compositional Matrix Space Model},
123 | author={Florian Mai and Lukas Galke and Ansgar Scherp},
124 | booktitle={International Conference on Learning Representations},
125 | year={2019},
126 | url={https://openreview.net/forum?id=H1MgjoR9tQ},
127 | }
128 | ```
129 |
--------------------------------------------------------------------------------
/mutils.py:
--------------------------------------------------------------------------------
1 | """
2 | Some auxiliary functions for saving results, doing hyperparameter optimization, and parsing optimizer parameters.
3 | """
4 |
5 | import re
6 | import inspect
7 | from torch import optim
8 | from itertools import product
9 | from bayes_opt import BayesianOptimization
10 | from hyperopt import fmin, rand, hp
11 | from sklearn.gaussian_process.kernels import Matern
12 | import numpy as np
13 | import os, csv
14 |
15 | def _update_options(options, **parameters):
16 | for param_name, param_value in parameters.items():
17 | print("In automatic optimization trying parameter:", param_name, "with value", param_value)
18 |
19 | try:
20 | setattr(options, param_name, param_value)
21 | except AttributeError:
22 | print("Can't find parameter ", param_name, "so we'll not use it.")
23 | continue
24 |
25 | return options
26 |
27 |
28 | def _make_space(options):
29 |
30 | space = {}
31 | inits = {}
32 | with open(options.optimization_spaces) as optimization_file:
33 | for line in optimization_file:
34 |
35 |             # skip comment lines
36 | if line.startswith("#"):
37 | continue
38 |
39 | line = line.strip()
40 | info = line.split(",")
41 | param_name = info[0]
42 |
43 | if options.optimization == "random":
44 | left_bound, right_bound = float(info[1]), float(info[2])
45 |
46 | param_type = info[3]
47 |
48 | try:
49 | param_type = getattr(hp, param_type)
50 | except AttributeError:
51 | print("hyperopt has no attribute", param_type)
52 | continue
53 |
54 | space[param_name] = param_type(param_name, left_bound, right_bound)
55 | elif options.optimization == "bayesian":
56 | left_bound, right_bound = float(info[1]), float(info[2])
57 |
58 | init_values = list(map(float, info[3:]))
59 | num_init_vals = len(init_values)
60 | inits[param_name] = init_values
61 | space[param_name] = (left_bound, right_bound)
62 |
63 | elif options.optimization == "grid":
64 |
65 | param_type = info[1]
66 | def get_cast_func(some_string_type):
67 | cast_func = None
68 | if some_string_type == "int":
69 | cast_func = int
70 | elif some_string_type == "float":
71 | cast_func = float
72 | elif some_string_type == "string":
73 | cast_func = str
74 | elif some_string_type == "bool":
75 | cast_func = bool
76 | return cast_func
77 |
78 | cast_func = get_cast_func(param_type)
79 | if cast_func is None:
80 | if param_type.startswith("list"):
81 | # determine type in list
82 | list_type = get_cast_func(param_type.split("-")[1])
83 |
84 |                         # assume the list items are separated by semicolons
85 | def extract_items(list_string):
86 | return [list_type(x) for x in list_string.split(";")]
87 |
88 | cast_func = extract_items
89 |
90 |
91 | # all possible values
92 | space[param_name] = list(map(cast_func, info[2:]))
93 |
94 | if options.optimization == "bayesian":
95 | return space, inits, num_init_vals
96 | else:
97 | return space
98 |
99 | def _all_option_combinations(space):
100 |
101 | names = [name for name, _ in space.items()]
102 | values = [values for _, values in space.items()]
103 |
104 | val_combinations = product(*values)
105 |
106 | combinations = []
107 | for combi in val_combinations:
108 | new_param_dict = {}
109 | for i, val in enumerate(combi):
110 | new_param_dict[names[i]] = val
111 |
112 | combinations.append(new_param_dict)
113 |
114 | return combinations
115 |
116 | def run_hyperparameter_optimization(options, run_exp):
117 | """
118 | This function performs hyperparameter optimization using bayesian optimization, random search, or gridsearch.
119 |
120 |     It takes an argparse object holding the parameters for configuring an experiment, and a function
121 |     'run_exp' that takes the argparse object, runs an experiment with the respective configuration, and
122 |     returns a score for that configuration.
123 | It then uses the hyperparameter optimization method to adjust the parameters and run the new configuration.
124 |
125 | Parameters:
126 | ================
127 |
128 |     options : argparse.Namespace
129 |         The argparse object holding the parameters. In particular, it must contain the following two parameters:
130 |         'optimization' : str, Specifies the optimization method. Either 'bayesian', 'random', or 'grid'.
131 |         'optimization_spaces' : str, Specifies the path to a file that lists the parameters to search over and
132 |         their possible values (in the case of grid search) or search spaces. See the file 'default_optimization_spaces' for
133 |         details.
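
        Based on the parsing in _make_space, each non-comment line of that file is comma-separated and has
        roughly the following form (the parameter names below are only illustrative):
            grid search   : "<param_name>,<type>,<value_1>,<value_2>,..."   e.g. "word_emb_dim,int,400,784"
            random search : "<param_name>,<low>,<high>,<hyperopt_dist>"     e.g. "validation_fraction,0.0001,0.01,uniform"
            bayesian      : "<param_name>,<low>,<high>,<init_value_1>,..."  e.g. "validation_fraction,0.0001,0.01,0.001,0.005"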
134 |
135 | run_exp : function
136 | A function that takes the argparse object as input and returns a float that is interpreted as the
137 | score of the configuration (higher is better).
138 |
139 | """
140 |
141 | if options.optimization:
142 |
143 | def optimized_experiment(**parameters):
144 |
145 | current_options = _update_options(options, **parameters)
146 | result = run_exp(current_options)
147 |
148 |             # return the score of the experiment so that the optimizer can maximize it
149 | return result
150 |
151 | if options.optimization == "bayesian":
152 |
153 | gp_params = {"alpha": 1e-5, "kernel" : Matern(nu = 5 / 2)}
154 | space, init_vals, num_init_vals = _make_space(options)
155 | bayesian_optimizer = BayesianOptimization(optimized_experiment, space)
156 | bayesian_optimizer.explore(init_vals)
157 | bayesian_optimizer.maximize(n_iter=options.optimization_iterations - num_init_vals,
158 | acq = 'ei',
159 | **gp_params)
160 |
161 | elif options.optimization == "random":
162 |
163 | fmin(lambda parameters : optimized_experiment(**parameters),
164 | _make_space(options),
165 | algo=rand.suggest,
166 | max_evals=options.optimization_iterations,
167 | rstate = np.random.RandomState(1337))
168 |
169 | elif options.optimization == "grid":
170 | # perform grid-search by running every possible parameter combination
171 | combinations = _all_option_combinations(_make_space(options))
172 | for combi in combinations:
173 | optimized_experiment(**combi)
174 |
175 | else:
176 | raise Exception("No hyperparameter method specified!")
177 |
178 | def write_to_csv(score, opt):
179 | """
180 | Writes the scores and configuration to csv file.
181 | """
182 | f = open(opt.output_file, 'a')
183 | if os.stat(opt.output_file).st_size == 0:
184 | for i, (key, _) in enumerate(opt.__dict__.items()):
185 | f.write(key + ";")
186 | for i, (key, _) in enumerate(score.items()):
187 | if i < len(score.items()) - 1:
188 | f.write(key + ";")
189 | else:
190 | f.write(key)
191 | f.write('\n')
192 | f.flush()
193 | f.close()
194 |
195 | f = open(opt.output_file, 'r')
196 | reader = csv.reader(f, delimiter=";")
197 | column_names = next(reader)
198 |     f.close()
199 |
200 | f = open(opt.output_file, 'a')
201 | for i, key in enumerate(column_names):
202 | if i < len(column_names) - 1:
203 | if key in opt.__dict__:
204 | f.write(str(opt.__dict__[key]) + ";")
205 | else:
206 | f.write(str(score[key]) + ";")
207 | else:
208 | if key in opt.__dict__:
209 | f.write(str(opt.__dict__[key]))
210 | else:
211 | f.write(str(score[key]))
212 | f.write('\n')
213 | f.flush()
214 | f.close()
215 |
216 | def get_optimizer(s):
217 | """
218 | Parse optimizer parameters.
219 | Input should be of the form:
220 | - "sgd,lr=0.01"
221 | - "adagrad,lr=0.1,lr_decay=0.05"
222 | """
223 | if "," in s:
224 | method = s[:s.find(',')]
225 | optim_params = {}
226 | for x in s[s.find(',') + 1:].split(','):
227 | split = x.split('=')
228 | assert len(split) == 2
229 | assert re.match("^[+-]?(\d+(\.\d*)?|\.\d+)$", split[1]) is not None
230 | optim_params[split[0]] = float(split[1])
231 | else:
232 | method = s
233 | optim_params = {}
234 |
235 | if method == 'adadelta':
236 | optim_fn = optim.Adadelta
237 | elif method == 'adagrad':
238 | optim_fn = optim.Adagrad
239 | elif method == 'adam':
240 | optim_fn = optim.Adam
241 | elif method == 'sparseadam':
242 | optim_fn = optim.SparseAdam
243 | elif method == 'adamax':
244 | optim_fn = optim.Adamax
245 | elif method == 'asgd':
246 | optim_fn = optim.ASGD
247 | elif method == 'rmsprop':
248 | optim_fn = optim.RMSprop
249 | elif method == 'rprop':
250 | optim_fn = optim.Rprop
251 | elif method == 'sgd':
252 | optim_fn = optim.SGD
253 | assert 'lr' in optim_params
254 | else:
255 | raise Exception('Unknown optimization method: "%s"' % method)
256 |
257 | # check that we give good parameters to the optimizer
258 | expected_args = inspect.getargspec(optim_fn.__init__)[0]
259 | assert expected_args[:2] == ['self', 'params']
260 | if not all(k in expected_args[2:] for k in optim_params.keys()):
261 | raise Exception('Unexpected parameters: expected "%s", got "%s"' % (
262 | str(expected_args[2:]), str(optim_params.keys())))
263 |
264 | return optim_fn, optim_params
265 |
266 |
267 | class dotdict(dict):
268 | """ dot.notation access to dictionary attributes """
269 | __getattr__ = dict.get
270 | __setattr__ = dict.__setitem__
271 | __delattr__ = dict.__delitem__
272 |
--------------------------------------------------------------------------------
/wrap_evaluation.py:
--------------------------------------------------------------------------------
1 | """
2 | This script provides an interface for standardized evaluation of encoders. To use it, implement a set of functions and pass them to the run_and_evaluate function in the following
3 | fashion:
4 |
5 | >>> run_and_evaluate(run_experiment, get_params_parser, batcher, prepare)
6 |
7 | For detailed descriptions of what these functions have to provide, please refer to the documentation of run_and_evaluate.
8 | """
9 |
10 | import torch, os, sys, logging, pickle
11 | import numpy as np
12 | import pandas as pd
13 |
14 | from torch.autograd import Variable
15 | from mutils import run_hyperparameter_optimization, write_to_csv
16 | from warnings import warn
17 |
18 | # Set PATHs to SentEval
19 | PATH_SENTEVAL = '/data22/fmai/data/SentEval/SentEval/'
20 | PATH_TO_DATA = '/data22/fmai/data/SentEval/SentEval/data'
21 | assert os.path.exists(PATH_SENTEVAL) and os.path.exists(PATH_TO_DATA), "Set path to SentEval + data correctly!"
22 |
23 | # import senteval
24 | sys.path.insert(0, PATH_SENTEVAL)
25 | import senteval
26 |
27 | def _get_score_for_name(downstream_results, name):
28 | if name in ["CR", "CoordinationInversion", "SST2", "Length", "OddManOut", "Tense", "SUBJ", "MRPC", "ObjNumber", "SubjNumber", "Depth", "WordContent", "SST5", "SNLI", "MPQA", "BigramShift", "MR", "TREC", "TopConstituents", "SICKEntailment"]:
29 | return downstream_results[name]["acc"]
30 | elif name in ["STS12", "STS13", "STS14", "STS15", "STS16"]:
31 | all_results = downstream_results[name]["all"]
32 | spearman = all_results["spearman"]["mean"]
33 | pearson = all_results["pearson"]["mean"]
34 | return (spearman, pearson)
35 | elif name in ["STSBenchmark", "SICKRelatedness"]:
36 | spearman = downstream_results[name]["spearman"]
37 | pearson = downstream_results[name]["pearson"]
38 | return (spearman, pearson)
39 | else:
40 | warn("Can not extract score from downstream task " + name)
41 | return 0
42 |
43 | def _run_experiment_and_save(run_experiment, params, batcher, prepare):
44 |
45 | encoder, losses = run_experiment(params)
46 |
47 | # Save encoder
48 | outputmodelname = construct_model_name(params.outputmodelname, params)
49 | torch.save(encoder, os.path.join(params.outputdir, outputmodelname + '.encoder'))
50 |
51 | # write training and validation loss to csv file
52 | with open(os.path.join(params.outputdir, outputmodelname + "_losses.csv"), 'w') as loss_csv:
53 | loss_csv.write("train_loss,val_loss\n")
54 | for train_loss, val_loss in losses:
55 | loss_csv.write(",".join([str(train_loss), str(val_loss)]) + "\n")
56 |
57 | scores = {}
58 | # Compute scores on downstream tasks
59 | if params.downstream_eval:
60 | downstream_scores = _evaluate_downstream_and_probing_tasks(encoder, params, batcher, prepare)
61 |
62 | # from each downstream task, only select scores we care about
63 | to_be_saved_scores = {}
64 | for score_name in downstream_scores:
65 | to_be_saved_scores[score_name] = _get_score_for_name(downstream_scores, score_name)
66 | scores.update(to_be_saved_scores)
67 |
68 | # Compute word embedding score
69 | if params.word_embedding_eval:
70 | output_path = _save_embeddings_to_word2vec(encoder, outputmodelname, params)
71 |
72 | # Save results to csv
73 | if params.output_file:
74 | write_to_csv(scores, params)
75 |
76 | return scores
77 |
78 | def _load_embeddings_from_word2vec(fname):
79 | w = load_embedding(fname, format="word2vec", normalize=True, lower=True, clean_words=False)
80 | return w
81 |
82 | def _save_embeddings_to_word2vec(encoder, outputmodelname, params):
83 |
84 | # Load lookup table as numpy
85 | embeddings = encoder.lookup_table
86 | embeddings = embeddings.weight.data.cpu().numpy()
87 |
88 | # Load (inverse) vocabulary to match ids to words
89 | path_to_vocabulary = os.path.join(params.outputdir, outputmodelname + '.vocab')
90 | vocabulary = pickle.load(open(path_to_vocabulary, "rb" ))[0]
91 | inverse_vocab = {vocabulary[w] : w for w in vocabulary}
92 |
93 | # Open file and write values in word2vec format
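    # (word2vec text format: a header line "<num_words> <dim>", followed by one "<word> <val_1> ... <val_dim>" line per word)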
94 | output_path = os.path.join(params.outputdir, outputmodelname + '.emb')
95 | f = open(output_path, 'w')
96 | print(embeddings.shape[0] - 1, embeddings.shape[1], file = f)
97 | for i in range(1, embeddings.shape[0]): # skip the padding token
98 | cur_word = inverse_vocab[i]
99 | f.write(" ".join([cur_word] + [str(embeddings[i, j]) for j in range(embeddings.shape[1])]) + "\n")
100 |
101 | f.close()
102 |
103 | return output_path
104 |
105 | def _evaluate_downstream_and_probing_tasks(encoder, params, batcher, prepare):
106 | # define senteval params
107 | eval_type = sys.argv[1] if len(sys.argv) > 1 else ""
108 | if params.downstream_eval == "full":
109 |
110 | ## for comparable evaluation (as in literature)
111 | params_senteval = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 10}
112 | params_senteval['classifier'] = {'nhid': params.nhid, 'optim': 'adam', 'batch_size': 64,
113 | 'tenacity': 5, 'epoch_size': 4}
114 |
115 | elif params.downstream_eval == "test":
116 | ## for testing purpose
117 | params_senteval = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5}
118 | params_senteval['classifier'] = {'nhid': params.nhid, 'optim': 'rmsprop', 'batch_size': 128,
119 | 'tenacity': 3, 'epoch_size': 2}
120 | else:
121 | sys.exit("Need to specify if you want to run it in full mode ('full', takes a long time," \
122 | "but is comparable to literature) or test mode ('test', fast, but not as accurate and comparable).")
123 |
124 | # Pass encoder and command line parameters
125 | params_senteval['word2mat'] = encoder
126 | params_senteval['cmd_params'] = params
127 |
128 | # evaluate
129 | se = senteval.engine.SE(params_senteval, batcher, prepare)
130 | results = se.eval(params.downstream_tasks)
131 |
132 | return results
133 |
134 | def construct_model_name(names, params):
135 | """Constructs model name from all params in "name", unless name only holds a single argument."""
136 | if len(names) == 1:
137 | return names[0]
138 | else:
139 | # construct model name from configuration
140 | name = ""
141 | params_dict = vars(params)
142 |
143 | for key in names:
144 |
145 | name += str(key) + ":" + str(params_dict[key]) + "-"
146 |
147 | return name
148 |
149 |
150 | def _add_common_arguments(parser):
151 |
152 | group = parser.add_argument_group("common", "Arguments needed for evaluation.")
153 | # paths
154 | parser.add_argument("--outputdir", type=str, default='savedir/', help="Output directory", required = True)
155 | parser.add_argument("--outputmodelname", type=str, nargs = "+", default=["mymodel"], help="If one argument is passed, the model is saved at the respective location. If multiple arguments are passed, these are interpreted of names of parameters from which the modelname is automatically constructed in a fashion.", required = True)
156 | parser.add_argument('--output_file', type=str, default=None, help= \
157 | "Specify the file name to save the result in. Default: [None]")
158 |
159 | # hyperparameter opt
160 | parser.add_argument("--optimization", type=str, default=None)
161 | parser.add_argument("--optimization_spaces", type=str, default="default_optimization_spaces")
162 | parser.add_argument("--optimization_iterations", type=int, default=10)
163 |
164 | # reproducibility
165 | parser.add_argument("--seed", type=int, default=1234, help="seed")
166 |
167 | # evaluation
168 | parser.add_argument("--downstream_eval", type=str, help="Whether to perform 'full'" \
169 | "downstream evaluation (slow), 'test' downstream evaluation (fast).",
170 | choices = ["test", "full"], required = True)
171 | parser.add_argument("--downstream_tasks", type=str, nargs = "+",
172 | default=['CR', 'MR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC', 'SNLI',
173 | 'SICKEntailment', 'SICKRelatedness', 'STSBenchmark',
174 | 'STS12', 'STS13', 'STS14', 'STS15', 'STS16',
175 | 'Length', 'WordContent', 'Depth', 'TopConstituents','BigramShift', 'Tense',
176 | 'SubjNumber', 'ObjNumber', 'OddManOut', 'CoordinationInversion'],
177 | help="Downstream tasks to evaluate on.",
178 | choices = ['CR', 'MR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC', 'SNLI',
179 | 'SICKEntailment', 'SICKRelatedness', 'STSBenchmark',
180 | 'STS12', 'STS13', 'STS14', 'STS15', 'STS16',
181 | 'Length', 'WordContent', 'Depth', 'TopConstituents','BigramShift', 'Tense',
182 | 'SubjNumber', 'ObjNumber', 'OddManOut', 'CoordinationInversion'])
183 | parser.add_argument("--word_embedding_eval", action="store_true", default=False, help="If specified," \
184 | "evaluate the word embeddings and store the results.")
185 | parser.add_argument("--nhid", type=int, default=0, help="Specify the number of hidden units used at test time to train classifiers. If 0 is specified, no hidden layer is employed.")
186 |
187 |
188 | return parser
189 |
190 | def run_and_evaluate(run_experiment, get_params_parser, batcher, prepare):
191 | """
192 | Trains and evaluates one or several configurations, given some functionality to train the encoder
193 | and generate embeddings at test time. In particular, 4 functions have to be provided:
194 |
195 | - get_params_parser: () -> argparse.Argumentparser
196 |         Creates and returns an argument parser with arguments that will be used by the
197 |         run_experiment function. Note that it must not contain any of the arguments that
198 |         are later added by the run_and_evaluate function and are shared by all implementations. Please refer to the _add_common_arguments function for an overview of the common arguments.
199 | - outputdir is not optional and the dir has to exist
200 |
201 | - run_experiment: params -> encoder, losses
202 | The function that trains the encoder. As input, it gets the params dictionary that contains
203 | the arguments parsed from the command line. As output, it has to return the encoder object that will
204 | at test time be passed to the 'batcher' function.
205 | The encoder is required to have an attribute 'lookup_table' which is a torch.nn.Embedding.
206 | 'losses' is a list of successive
207 | (train_loss, val_loss) pairs that will be written to a csv file. If you do not want to track
208 | the loss, simply return an empty list.
209 | - params need to include an 'outputmodelname', else the params need to satisfy construct_model_name
210 |
211 | - batcher: (params, batch) -> nparray of shape (batch_size, embedding_size)
212 | A function as defined by SentEval (see https://github.com/facebookresearch/SentEval).
213 | 'params' is a dictionary that contains the encoder as "params['word2mat']", as well as
214 | further parameters specified by the implementer and added via the 'prepare' function.
215 | - the params namespace from run_experiments is available as params['cmd_params']
216 |
217 | - prepare: (senteval_params, samples) -> ()
218 | A function as defined by SentEval (see https://github.com/facebookresearch/SentEval).
219 | In particular, this function can be used to add further arguments to the senteval_params
220 | dictionary for use in batcher.
221 |         'senteval_params' is a dictionary that in turn contains the 'params' object parsed from
222 | the command line as the argument 'cmd_params'.
223 | """
224 | parser = get_params_parser()
225 | parser = _add_common_arguments(parser)
226 | params = parser.parse_args()
227 |
228 | if params.optimization:
229 | def hyperparameter_optimization_func(opts):
230 | results = _run_experiment_and_save(run_experiment, opts, batcher, prepare)
231 | return results["CR"]
232 | run_hyperparameter_optimization(params, hyperparameter_optimization_func)
233 | else:
234 | _run_experiment_and_save(run_experiment, params, batcher, prepare)
235 |
236 |
--------------------------------------------------------------------------------
/cbow.py:
--------------------------------------------------------------------------------
1 | """
2 | This file handles everything related to the CBOW task, i.e., it creates the training examples (CBOWDataset), and provides the neural architecture (except encoder) and loss computation (see CBOWNet).
3 | """
4 |
5 | import os, pickle, math
6 | import numpy as np
7 | import torch
8 | from torch.utils.data import Dataset
9 | from torch.utils.data.sampler import Sampler
10 | from collections import Counter
11 | import nltk.data
12 | from nltk.tokenize import word_tokenize
13 | from random import shuffle
14 | import random
15 |
16 | import torch.nn as nn
17 | from torch import FloatTensor as FT
18 | from torch import ByteTensor as BT
19 | from torch.autograd import Variable
20 |
21 | sent_tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')
22 |
23 | def recursive_file_list(path):
24 | """
25 | Recursively aggregates all files at the given path, i.e., files in subfolders are also
26 | included.
27 | """
28 | return [os.path.join(dp, f) for dp, dn, fn in os.walk(path) for f in fn]
29 |
30 | def tokenize(sent, sophisticated = False):
31 | """
32 | Tokenizes the sentence. If 'sophisticated' is set to False, the
33 |     tokenization is a simple split on whitespace. Otherwise, the
34 |     TreebankWordTokenizer provided by NLTK is used.
35 | """
36 | return sent.split() if not sophisticated else word_tokenize(sent)
37 |
38 | def sentenize(text):
39 | return sent_tokenizer.tokenize(text)
40 |
41 | def get_wordvec_batch(batch, word_vec):
42 | # sent in batch in decreasing order of lengths (bsize, max_len, word_dim)
43 | lengths = np.array([len(x) for x in batch])
44 | max_len = np.max(lengths)
45 | embed = np.zeros((max_len, len(batch), 300))
46 |
47 | for i in range(len(batch)):
48 | for j in range(len(batch[i])):
49 | embed[j, i, :] = word_vec[batch[i][j]]
50 |
51 | return torch.from_numpy(embed).float(), lengths
52 |
53 | def get_index_batch(batch, word_vec):
54 |
55 | # remove all words that are out of vocabulary
56 | clean_batch = []
57 | for sen in batch:
58 | clean_batch.append([w for w in sen if w in word_vec])
59 | batch = clean_batch
60 |
61 | lengths = np.array([len(x) for x in batch])
62 | max_len = np.max(lengths)
63 | embed = np.zeros((max_len, len(batch)))
64 |
65 | for i in range(len(batch)):
66 | for j in range(len(batch[i])):
67 | embed[j, i] = word_vec[batch[i][j]]
68 |
69 | return torch.from_numpy(embed).long(), lengths
70 |
71 |
72 | def get_word_dict(sentences):
73 |     # create vocab of words and also count occurrences
74 | word_dict = {}
75 | for sent in sentences:
76 | for word in tokenize(sent):
77 | if word not in word_dict:
78 | word_dict[word] = 1
79 | else:
80 | word_dict[word] += 1
81 |
82 | return word_dict
83 |
84 |
85 | def get_wordembedding(word_dict, we_path):
86 | # create word_vec with glove vectors
87 | word_voc = {}
88 | with open(we_path) as f:
89 |
90 | # discard the information in first row
91 | _, emb_size = f.readline().split()
92 |
93 | i = 1
94 | word_embs = []
95 | for line in f:
96 | line = line.strip('\n').split()
97 | word_end = len(line) - int(emb_size)
98 | word = " ".join(line[:word_end])
99 | if word in word_dict:
100 | word_voc[word] = i
101 | word_embs.append(np.asarray(list(map(float, line[word_end:]))))
102 | i += 1
103 | print('Found {0}(/{1}) words with glove vectors'.format(
104 | len(word_voc), len(word_dict)))
105 |
106 | word_embs = np.vstack(word_embs)
107 | word_count = {w : word_dict[w] for w in word_voc}
108 | return word_voc, word_count, word_embs
109 |
110 | def get_index_vocab(word_dict, max_words):
111 | if max_words is not None:
112 | counter = Counter(word_dict)
113 | most_common_words = counter.most_common(max_words)
114 | reduced_word_dict = {}
115 | for w, cnt in most_common_words:
116 | reduced_word_dict[w] = cnt
117 | word_dict = reduced_word_dict
118 | print('Num words in corpus : {:,}'.format(np.sum([word_dict[w] for w in word_dict])))
119 |
120 | # create word_vec with glove vectors
121 | word_vec = {}
122 | idx = 1 # reserve 0 for padding_idx
123 | for word in word_dict:
124 | word_vec[word] = idx
125 | idx += 1
126 | return word_vec, word_dict
127 |
128 |
129 | def build_vocab(sentences, pretrained_embeddings = None, max_words = None):
130 | word_dict = get_word_dict(sentences)
131 | if pretrained_embeddings:
132 | word_to_index, word_to_count, word_embeddings = get_wordembedding(word_dict, pretrained_embeddings)
133 | else:
134 | word_to_index, word_to_count = get_index_vocab(word_dict, max_words) # padding_idx = 0
135 | print('Vocab size : {0}'.format(len(word_to_index)))
136 |
137 | if pretrained_embeddings:
138 | return word_to_index, word_to_count, word_embeddings
139 | else:
140 | return word_to_index, word_to_count
141 |
142 | class CBOWDataset(Dataset):
143 | """
144 | Considers each line of a file to be a text.
145 | Reads all files found at directory 'path' and corresponding subdirectories.
146 | """
147 | def __init__(self, path, num_texts, context_size, num_samples_per_item, mode, precomputed_word_vocab, max_words, pretrained_embeddings, num_texts_per_chunk, precomputed_chunks_dir, temp_path):
148 |
149 | self.context_size = context_size
150 | self.num_samples_per_item = num_samples_per_item
151 | self.mode = mode
152 | self.num_texts_per_chunk = num_texts_per_chunk
153 |
154 | texts_generator = _generate_texts(path, num_texts)
155 |
156 | # load precomputed word vocabulary and counts
157 | if precomputed_word_vocab:
158 | word_vec = pickle.load(open(os.path.join(precomputed_word_vocab), "rb" ))
159 | else:
160 |
161 | word_vec = build_vocab(texts_generator,
162 | pretrained_embeddings = pretrained_embeddings,
163 | max_words = max_words)
164 |
165 | # create chunks
166 | self.num_texts = num_texts
167 | self.num_chunks = math.ceil(num_texts / (1.0*self.num_texts_per_chunk))
168 | self._temp_path = temp_path
169 | if not os.path.exists(self._temp_path):
170 | os.makedirs(self._temp_path)
171 |
172 | if precomputed_chunks_dir is None:
173 | self._create_chunk_files(_generate_texts(path, num_texts))
174 | self._check_chunk_files()
175 | else:
176 | self._temp_path = precomputed_chunks_dir
177 | print("use precomputed chunk files.")
178 |
179 | self._word_vec_count_tuple = word_vec
180 | self.word_vec, self.word_count = word_vec
181 | self.num_training_samples = self.num_texts
182 |
183 | # compute unigram distribution
184 | ## set frequency of padding token to 0 implicitly
185 | unigram_dist = np.zeros((len(self.word_vec) + 1))
186 | for w in self.word_vec:
187 | unigram_dist[self.word_vec[w]] = self.word_count[w]
188 |
189 | self.unigram_dist = unigram_dist
190 |
191 |
192 | def _count_words_per_text(self):
193 | text_lengths = [0] * len(self.texts)
194 |
195 | for i, text in enumerate(self.texts):
196 |
197 | words = tokenize(text)
198 | words = [self.word_vec[w] for w in words if w in self.word_vec]
199 | text_lengths[i] = len(words)
200 |
201 | return text_lengths
202 |
203 | def _check_chunk_files(self):
204 | """Raises an exception if any of the chunks generated
205 | is empty.
206 | """
207 | for i in range(self.num_chunks):
208 | with open(self._get_chunk_file_name(i), "r") as f:
209 | lines = f.readlines()
210 |                 if len(lines) == 0:
211 | raise Exception("Chunk ", i, " is empty\n")
212 |
213 | def _create_chunk_files(self, texts_generator):
214 | cur_chunk_number = 0
215 | cur_chunk_file = open(self._get_chunk_file_name(cur_chunk_number), "w")
216 | cur_idx = 0
217 | last_chunk_size = self.num_texts - (self.num_texts_per_chunk*(self.num_chunks-1))
218 | for text in texts_generator:
219 | print(text, file=cur_chunk_file)
220 | if cur_idx == self.num_texts_per_chunk - 1 or (cur_idx == last_chunk_size-1 and
221 | cur_chunk_number == self.num_chunks-1):
222 | # start next chunk
223 | cur_chunk_file.close()
224 | cur_idx = 0 # index within the chunk
225 | cur_chunk_number += 1
226 | cur_chunk_file = open(self._get_chunk_file_name(cur_chunk_number), "w")
227 | else:
228 | cur_idx += 1
229 | cur_chunk_file.close()
230 |
231 | def _get_chunk_file_name(self, chunk_number):
232 | return os.path.join(self._temp_path, "chunk" + str(chunk_number))
233 |
234 | def __len__(self):
235 | return self.num_texts
236 |
237 | def _load_text(self, idx):
238 | chunk_number = math.floor(idx / (1.0*self.num_texts_per_chunk))
239 | idx_in_chunk = idx % self.num_texts_per_chunk
240 | with open(self._get_chunk_file_name(chunk_number), "r") as f:
241 | for i, line in enumerate(f):
242 | if i == idx_in_chunk:
243 | return line.strip()
244 | raise Exception("Text with idx: ", idx, " in chunk: ", chunk_number,\
245 | " and idx_in_chunk: ", idx_in_chunk, " not found.")
246 |
247 | def _compute_idx_to_text_word_dict(self):
248 | idx_to_text_word_tuple = {}
249 | idx = 0
250 | for i, text in enumerate(self.texts):
251 |
252 | for j in range(self.text_lengths[i]):
253 | idx_to_text_word_tuple.update({idx : (i, j)})
254 | idx += 1
255 |
256 | self.idx_to_text_word_tuple = idx_to_text_word_tuple
257 |
258 | def _create_window_samples(self, words):
259 |
260 | text_len = len(words)
261 | num_samples = min(text_len, self.num_samples_per_item)
262 |
263 | words = [0] * self.context_size + words + [0] * self.context_size
264 |
265 | training_sequences = np.zeros((num_samples, 2 * self.context_size))
266 | missing_words = np.zeros((num_samples))
267 |
268 | # randomly select mid_words to use
269 | mid_words = random.sample(range(text_len), num_samples)
270 | for i, j in enumerate(mid_words):
271 |
272 | middle_word = self.context_size + j
273 |
274 | # choose a word that is removed from the window
275 | if self.mode == 'random':
276 | rand_offset = random.randint(-self.context_size, self.context_size)
277 | missing_word = middle_word + rand_offset
278 | elif self.mode == 'cbow':
279 | missing_word = middle_word
280 | else:
281 | raise NotImplementedError("Unknown training mode " + self.mode)
282 |
283 | # zero is the padding word
284 | training_sequence = [middle_word + context_word for context_word in range(-self.context_size, self.context_size + 1) if middle_word + context_word != missing_word]
285 | training_sequence = [words[w] for w in training_sequence]
286 | training_sequences[i, :] = np.array(training_sequence)
287 | missing_word = words[missing_word]
288 | missing_words[i] = np.array(missing_word)
289 | return training_sequences, missing_words
290 |
291 | def __getitem__(self, idx):
292 |
293 | text = self._load_text(idx)
294 | words = tokenize(text)
295 | words = [self.word_vec[w] for w in words if w in self.word_vec]
296 | text_len = len(words)
297 |
298 | # TODO: is there a better way to handle empty texts?
299 | if text_len == 0:
300 | return None, None
301 |
302 | if self.mode in ['random', 'cbow']:
303 | return self._create_window_samples(words)
304 | else:
305 | raise NotImplementedError("Unknown mode " + str(self.mode))
306 |
307 | ## collate function for cbow
308 | def collate_fn(self, l):
309 | l1, l2 = zip(*l)
310 | l1 = [x for x in l1 if x is not None]
311 | l2 = [x for x in l2 if x is not None]
312 | l1 = np.vstack(l1)
313 | l2 = np.concatenate(l2)
314 | return torch.from_numpy(l1).long(), torch.from_numpy(l2).long()
315 |
316 | def _load_texts(path, num_docs):
317 | texts = []
318 | filename_list = recursive_file_list(path)
319 |
320 | for filename in filename_list:
321 | with open(os.path.realpath(filename), 'r') as f:
322 |
323 | # change encoding to utf8 to be consistent with other datasets
324 | #cur_text.decode("ISO-8859-1").encode("utf-8")
325 | for line in f:
326 | line = line.strip()
327 | texts.append(line)
328 |
329 | if num_docs is not None and len(texts) > num_docs:
330 | break
331 |
332 | return texts
333 |
334 | def _generate_texts(path, num_docs):
335 | filename_list = recursive_file_list(path)
336 |
337 | for filename in filename_list:
338 | with open(os.path.realpath(filename), "r") as f:
339 |
340 | # change encoding to utf8 to be consistent with other datasets
341 | # cur_text.decode("ISO-8859-1").encode("utf-8")
342 | for i, line in enumerate(f):
343 | line = line.strip()
344 | if num_docs is not None and i > num_docs - 1:
345 | break
346 | yield line
347 |
348 | class CBOWNet(nn.Module):
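    """
    Wraps an encoder with a word2vec-style negative-sampling objective: the encoder's output for a context
    window is scored against the output embedding of the missing word (positive sample) and against 'n_negs'
    words sampled from the smoothed unigram distribution (or uniformly, if no weights are given), yielding
    the per-sample loss -log sigmoid(u_o . v) - mean_n(log sigmoid(-u_n . v)), where v is the context embedding.
    """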
349 | def __init__(self, encoder, output_embedding_size, output_vocab_size, weights = None, n_negs = 20, padding_idx = 0):
350 | super(CBOWNet, self).__init__()
351 |
352 | self.encoder = encoder
353 | self.n_negs = n_negs
354 | self.weights = weights
355 | self.output_vocab_size = output_vocab_size
356 | self.output_embedding_size = output_embedding_size
357 |
358 | self.outputembeddings = nn.Embedding(output_vocab_size + 1, output_embedding_size, padding_idx=0)
359 |
360 | if self.weights is not None:
361 | wf = np.power(self.weights, 0.75)
362 | wf = wf / wf.sum()
363 | self.weights = FT(wf)
364 |
365 | def forward(self, input_s, missing_word):
366 |
367 | embedding = self.encoder(input_s)
368 | batch_size = embedding.size()[0]
369 | emb_size = embedding.size()[1]
370 |
371 | # draw negative samples
372 | if self.weights is not None:
373 | nwords = torch.multinomial(self.weights, batch_size * self.n_negs, replacement=True).view(batch_size, -1)
374 | else:
375 |             nwords = FT(batch_size, self.n_negs).uniform_(0, self.output_vocab_size).long()
376 | nwords = Variable(torch.LongTensor(nwords), requires_grad=False).cuda()
377 |
378 | # lookup the embeddings of output words
379 | missing_word_vector = self.outputembeddings(missing_word)
380 |
381 | nvectors = self.outputembeddings(nwords).neg()
382 |
383 | # compute loss for correct word
384 | oloss = torch.bmm(missing_word_vector.view(batch_size, 1, emb_size), embedding.view(batch_size, emb_size, 1))
385 | oloss = oloss.squeeze().sigmoid()
386 |
387 | ## add epsilon to prediction to avoid numerical instabilities
388 | oloss = self._add_epsilon(oloss)
389 | oloss = oloss.log()
390 |
391 | # compute loss for negative samples
392 | nloss = torch.bmm(nvectors, embedding.view(batch_size, -1, 1)).squeeze().sigmoid()
393 |
394 | ## add epsilon to prediction to avoid numerical instabilities
395 | nloss = self._add_epsilon(nloss)
396 | nloss = nloss.log()
397 | nloss = nloss.mean(1)
398 |
399 | # combine losses
400 | return -(oloss + nloss)
401 |
402 | def _add_epsilon(self, pred):
403 | return pred + 0.00001
404 |
405 | def encode(self, s1):
406 | emb = self.encoder(s1)
407 | return emb
408 |
409 |
--------------------------------------------------------------------------------
/train_cbow.py:
--------------------------------------------------------------------------------
1 | """
2 | Main script for training a Word2Mat model.
3 | """
4 |
5 | import os
6 | import sys
7 | import time
8 | import argparse
9 | import pickle
10 | import random
11 |
12 | import numpy as np
13 | from random import shuffle
14 |
15 | import torch
16 | from torch.autograd import Variable
17 | import torch.nn as nn
18 |
19 | # CBOW data
20 | from torch.utils.data import DataLoader
21 | from cbow import CBOWNet, build_vocab, tokenize, CBOWDataset, recursive_file_list, get_index_batch
22 |
23 | # Encoder
24 | from mutils import get_optimizer, run_hyperparameter_optimization, write_to_csv
25 | from word2mat import get_cbow_cmow_hybrid_encoder, get_cbow_encoder, get_cmow_encoder
26 | from torch.utils.data.sampler import SubsetRandomSampler
27 |
28 | from wrap_evaluation import run_and_evaluate, construct_model_name
29 |
30 | import time
31 |
32 | def run_experiment(params):
33 |
34 | # print parameters passed, and all parameters
35 | print('\ntogrep : {0}\n'.format(sys.argv[1:]))
36 | print(params)
37 |
38 | """
39 | SEED
40 | """
41 | np.random.seed(params.seed)
42 | torch.manual_seed(params.seed)
43 | torch.cuda.manual_seed(params.seed)
44 |
45 | """
46 | DATA
47 | """
48 | dataset_path = params.dataset_path
49 |
50 | # build training and test corpus
51 | filename_list = recursive_file_list(dataset_path)
52 | print('Use the following files for training: ', filename_list)
53 | corpus = CBOWDataset(dataset_path, params.num_docs, params.context_size,
54 | params.num_samples_per_item, params.mode,
55 | params.precomputed_word_vocab, params.max_words,
56 | None, 1000, params.precomputed_chunks_dir, params.temp_path)
57 | corpus_len = len(corpus)
58 |
59 | ## split train and test
60 | inds = list(range(corpus_len))
61 | shuffle(inds)
62 |
63 | num_val_samples = int(corpus_len * params.validation_fraction)
64 | train_indices = inds[:-num_val_samples] if num_val_samples > 0 else inds
65 | test_indices = inds[-num_val_samples:] if num_val_samples > 0 else []
66 |
67 | cbow_train_loader = DataLoader(corpus, sampler = SubsetRandomSampler(train_indices), batch_size=params.batch_size, shuffle=False, num_workers = params.num_workers, pin_memory = True, collate_fn = corpus.collate_fn)
68 | cbow_test_loader = DataLoader(corpus, sampler = SubsetRandomSampler(test_indices), batch_size=params.batch_size, shuffle=False, num_workers = params.num_workers, pin_memory = True, collate_fn = corpus.collate_fn)
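    # Each batch yielded by corpus.collate_fn is a pair (X_batch, tgt_batch):
    # X_batch contains the word indices of a context window and tgt_batch the
    # index of the word to be predicted, as consumed by forward_pass() below.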
69 |
70 | ## extract some variables needed for training
71 | num_training_samples = corpus.num_training_samples
72 | word_vec = corpus.word_vec
73 | unigram_dist = corpus.unigram_dist
74 | word_vec_copy = corpus._word_vec_count_tuple
75 |
76 | print("Number of sentences used for training:", str(num_training_samples))
77 |
78 | """
79 | MODEL
80 | """
81 |
82 | # build path where to store the encoder
83 | outputmodelname = construct_model_name(params.outputmodelname, params)
84 |
85 | # build encoder
86 | n_words = len(word_vec)
87 | if params.w2m_type == "cmow":
88 | encoder = get_cmow_encoder(n_words, padding_idx = 0,
89 | word_emb_dim = params.word_emb_dim,
90 | initialization_strategy = params.initialization)
91 | output_embedding_size = params.word_emb_dim
92 | elif params.w2m_type == "cbow":
93 | encoder = get_cbow_encoder(n_words, padding_idx = 0, word_emb_dim = params.word_emb_dim)
94 | output_embedding_size = params.word_emb_dim
95 | elif params.w2m_type == "hybrid":
96 | encoder = get_cbow_cmow_hybrid_encoder(n_words, padding_idx = 0,
97 | word_emb_dim = params.word_emb_dim,
98 | initialization_strategy = params.initialization)
99 | output_embedding_size = 2 * params.word_emb_dim
100 |
101 | # build cbow model
102 | cbow_net = CBOWNet(encoder, output_embedding_size, n_words,
103 | weights = unigram_dist, n_negs = params.n_negs, padding_idx = 0)
104 | if torch.cuda.device_count() > 1:
105 | print("Using", torch.cuda.device_count(), "GPUs for training!")
106 | # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
107 | cbow_net = nn.DataParallel(cbow_net)
108 | use_multiple_gpus = True
109 | else:
110 | use_multiple_gpus = False
111 |
112 | # optimizer
113 | print([x.size() for x in cbow_net.parameters()])
114 | optim_fn, optim_params = get_optimizer(params.optimizer)
115 | optimizer = optim_fn(cbow_net.parameters(), **optim_params)
116 |
117 | # cuda by default
118 | cbow_net.cuda()
119 |
120 | """
121 | TRAIN
122 | """
123 | val_acc_best = -1e10
124 | adam_stop = False
125 | stop_training = False
126 | lr = optim_params['lr'] if 'sgd' in params.optimizer else None
127 |
128 | # compute learning rate schedule
129 | if params.linear_decay:
130 | lr_shrinkage = (lr - params.minlr) / ((float(num_training_samples) / params.batch_size) * params.n_epochs)
131 |
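    # Worked example for the linear schedule (illustrative numbers, SGD only):
    # with lr = 0.1, minlr = 1e-5, 1e8 training samples, batch_size = 64 and
    # n_epochs = 20, lr_shrinkage = (0.1 - 1e-5) / (1e8 / 64 * 20) ~= 3.2e-9,
    # i.e. the learning rate is reduced by ~3.2e-9 after every batch and only
    # reaches minlr at the very end of training.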
132 |
133 | def forward_pass(X_batch, tgt_batch, params, check_size = False):
134 |
135 | X_batch = Variable(X_batch).cuda()
136 | tgt_batch = Variable(torch.LongTensor(tgt_batch)).cuda()
137 | k = X_batch.size(0) # actual batch size
138 |
139 | loss = cbow_net(X_batch, tgt_batch).mean()
140 | return loss, k
141 |
142 |
143 | def validate(data_loader):
144 | cbow_net.eval()
145 |
146 | with torch.no_grad():
147 | all_costs = []
148 | for X_batch, tgt_batch in data_loader:
149 | loss, k = forward_pass(X_batch, tgt_batch, params)
150 | all_costs.append(loss.item())
151 |
152 | cbow_net.train()
153 | return np.mean(all_costs)
154 |
155 | def trainepoch(epoch):
156 | print('\nTRAINING : Epoch ' + str(epoch))
157 | cbow_net.train()
158 | all_costs = []
159 | logs = []
160 | words_count = 0
161 |
162 | last_time = time.time()
163 | correct = 0.
164 |
165 | if not params.linear_decay:
166 | optimizer.param_groups[0]['lr'] = optimizer.param_groups[0]['lr'] * params.decay if epoch>1\
167 | and 'sgd' in params.optimizer else optimizer.param_groups[0]['lr']
168 | print('Learning rate : {0}'.format(optimizer.param_groups[0]['lr']))
169 |
170 | processed_training_samples = 0
171 | start_time = time.time()
172 | total_time = 0
173 | total_batch_generation_time = 0
174 | total_forward_time = 0
175 | total_backward_time = 0
176 | total_step_time = 0
177 | last_processed_training_samples = 0
178 |
179 | nonlocal processed_batches, stop_training, no_improvement, min_val_loss, losses, min_loss_criterion
180 | for i, (X_batch, tgt_batch) in enumerate(cbow_train_loader):
181 |
182 | batch_generation_time = (time.time() - start_time) * 1000000
183 |
184 | # forward pass
185 | forward_start = time.time()
186 | loss, k = forward_pass(X_batch, tgt_batch, params)
187 | all_costs.append(loss.item())
188 | forward_total = (time.time() - forward_start) * 1000000
189 |
190 | # backward
191 | backward_start = time.time()
192 | optimizer.zero_grad()
193 | loss.backward()
194 |
195 | backward_total = (time.time() - backward_start) * 1000000
196 |
197 | # linear learning rate decay
198 | if params.linear_decay:
199 | optimizer.param_groups[0]['lr'] = optimizer.param_groups[0]['lr'] - lr_shrinkage if \
200 | 'sgd' in params.optimizer else optimizer.param_groups[0]['lr']
201 |
202 | # optimizer step
203 | step_time = time.time()
204 | optimizer.step()
205 | total_step_time += (time.time() - step_time) * 1000000
206 |
207 | # log progress
208 | processed_training_samples += params.batch_size
209 | percentage_done = float(processed_training_samples) / num_training_samples
210 | processed_batches += 1
211 | if processed_batches == params.validation_frequency:
212 |
213 | # compute validation loss and train loss
214 | val_loss = round(validate(cbow_test_loader), 5) if num_val_samples > 0 else float('inf')
215 | train_loss = round(np.mean(all_costs), 5)
216 |
217 | # print current loss and processing speed
218 | logs.append('Epoch {3} - {4:.4} ; lr {2:.4} ; train-loss {0} ; val-loss {5} ; sentence/s {1}'.format(train_loss, int((processed_training_samples - last_processed_training_samples) / (time.time() - last_time)), optimizer.param_groups[0]['lr'], epoch, percentage_done, val_loss))
219 | if params.VERBOSE:
220 | print('\n\n\n')
221 | print(logs[-1])
222 | last_time = time.time()
223 | words_count = 0
224 | all_costs = []
225 | last_processed_training_samples = processed_training_samples
226 |
227 | if params.VERBOSE:
228 |                     print("{} batches took {} microseconds".format(params.validation_frequency, total_time))
229 | print("get_batch: {} \nforward: {} \nbackward: {} \nstep: {}".format(total_batch_generation_time / total_time, total_forward_time / total_time, total_backward_time / total_time, total_step_time / total_time))
230 | total_time = 0
231 | total_batch_generation_time = 0
232 | total_forward_time = 0
233 | total_backward_time = 0
234 | total_step_time = 0
235 | processed_batches = 0
236 |
237 | # save losses for logging later
238 | losses.append((train_loss, val_loss))
239 |
240 | # early stopping?
241 | if val_loss < min_val_loss:
242 | min_val_loss = val_loss
243 |
244 | # save best model
245 | torch.save(cbow_net, os.path.join(params.outputdir, outputmodelname + '.cbow_net'))
246 |
247 | if params.stop_criterion is not None:
248 | stop_crit_loss = eval(params.stop_criterion)
249 | if stop_crit_loss < min_loss_criterion:
250 | no_improvement = 0
251 | min_loss_criterion = stop_crit_loss
252 | else:
253 | no_improvement += 1
254 | if no_improvement > params.patience:
255 | stop_training = True
256 |                         print("No improvement in loss criterion", str(params.stop_criterion),
257 |                               "for", str(no_improvement), "steps. Terminating training.")
258 | break
259 |
260 | now = time.time()
261 | batch_time_micro = (now - start_time) * 1000000
262 |
263 | total_time = total_time + batch_time_micro
264 | total_batch_generation_time += batch_generation_time
265 | total_forward_time += forward_total
266 | total_backward_time += backward_total
267 |
268 | start_time = now
269 |
270 |
271 | """
272 | Train model on CBOW objective
273 | """
274 | epoch = 1
275 |
276 | processed_batches = 0
277 | min_val_loss = float('inf')
278 | min_loss_criterion = float('inf')
279 | no_improvement = 0
280 | losses = []
281 | while not stop_training and epoch <= params.n_epochs:
282 | trainepoch(epoch)
283 | epoch += 1
284 |
285 | # load the best model
286 | if min_val_loss < float('inf'):
287 | cbow_net = torch.load(os.path.join(params.outputdir, outputmodelname + '.cbow_net'))
288 | print("Loading model with best validation loss.")
289 | else:
290 |         # fall back to the current (final) model
291 | print("No model with better validation loss has been saved.")
292 |
293 | # save word vocabulary and counts
294 | pickle.dump(word_vec_copy, open( os.path.join(params.outputdir, outputmodelname + '.vocab'), "wb" ))
295 |
296 | if use_multiple_gpus:
297 | cbow_net = cbow_net.module
298 | return cbow_net.encoder, losses
299 |
300 | def get_params_parser():
301 |
302 | parser = argparse.ArgumentParser(description='Training a word2mat model.')
303 |
304 | # paths
305 |     parser.add_argument('--precomputed_word_vocab', type=str, default=None, help= \
306 |         "Specify path from where to load a precomputed word vocabulary.")
307 | parser.add_argument('--precomputed_chunks_dir', type=str, default=None, help= \
308 | "Specify path from where to load the chunkified input text.")
309 | parser.add_argument('--temp_path', type=str, required=True, help= \
310 | "Specify path where to save the chunkified input text.")
311 |
312 | # training parameters
313 | parser.add_argument("--n_epochs", type=int, default=20)
314 | parser.add_argument("--batch_size", type=int, default=64)
315 | parser.add_argument("--optimizer", type=str, default="sgd,lr=0.1", help="adam or sgd,lr=0.1")
316 | parser.add_argument("--decay", type=float, default=0.99, help="lr decay")
317 |     parser.add_argument("--linear_decay", action="store_true", help="If set, the learning rate is decreased linearly after each batch so that it approaches minlr by the end of training.")
318 | parser.add_argument("--minlr", type=float, default=1e-5, help="minimum lr")
319 | parser.add_argument("--validation_frequency", type=int, default=500, help="How many batches to process before evaluating on the validation set (500).")
320 |     parser.add_argument("--validation_fraction", type=float, default=0.0001, help="What fraction of the corpus to use for validation. \
321 |                         Set to 0 to disable validation-based saving of intermediate models.")
322 | parser.add_argument("--stop_criterion", type=str, default=None, help="Which loss to use as stopping criterion.", choices = ['val_loss', 'train_loss'])
323 |     parser.add_argument("--patience", type=int, default=3, help="How many validation steps without improvement to allow before terminating training.")
324 | parser.add_argument("--VERBOSE", action="store_true", default=False, help="Whether to print additional info on speed of processing.")
325 | parser.add_argument("--num_workers", type=int, default=10, help="How many worker threads to use for creating the samples from the dataset.")
326 |
327 | # Word2Mat specific
328 | parser.add_argument("--w2m_type", type=str, default='cmow', choices=['cmow', 'cbow', 'hybrid'], help="Choose the encoder to use.")
329 | parser.add_argument("--word_emb_dim", type=int, default=100, help="Dimensionality of word embeddings.")
330 | parser.add_argument("--initialization", type=str, default='identity', help="Initialization strategy to use.", choices = ['one', 'identity', 'normalized', 'normal'])
331 |
332 | # dataset and vocab
333 | parser.add_argument("--dataset_path", type=str, required=True, help="Path to a directory containing all files to use for training. " \
334 | "One sentence per line in a file is assumed.")
335 |     parser.add_argument("--max_words", type=int, default=None, help="Only produce embeddings for the max_words most common tokens.")
336 | parser.add_argument("--num_docs", type=int, default=None, help="How many documents to consider from the source directory.")
337 |
338 | # CBOW specific
339 | parser.add_argument("--context_size", type=int, default=5, help="Context window size for CBOW.")
340 |     parser.add_argument("--num_samples_per_item", type=int, default=1, help="Specify the number of samples to generate from each sentence (the higher, the faster the training).")
341 | parser.add_argument("--mode", type=str, help="Determines the mode of the prediction task, i.e., which word is to be removed from a given window of words. Options are 'cbow' (remove middle word) and 'random' (a random word from the window is removed).", default='random', choices = ['cbow', 'random'])
342 |     parser.add_argument("--n_negs", type=int, default=5, help="How many negative samples to use for training (the larger the dataset, the fewer are required; default 5).")
343 |
344 | return parser
345 |
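# The two functions below follow the SentEval-style prepare/batcher interface
# that run_and_evaluate passes on to the evaluation: prepare() loads the stored
# vocabulary once, and batcher_cbow() maps a batch of tokenized sentences to a
# numpy array containing one embedding per sentence.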
346 | def prepare(params_senteval, samples):
347 |
348 | params = params_senteval["cmd_params"]
349 | outputmodelname = construct_model_name(params.outputmodelname, params)
350 |
351 | # Load vocabulary
352 | vocabulary = pickle.load(open(os.path.join(params.outputdir, outputmodelname + '.vocab'), "rb" ))[0]
353 |
354 | params_senteval['vocabulary'] = vocabulary
355 | params_senteval['inverse_vocab'] = {vocabulary[w] : w for w in vocabulary}
356 |
357 | def _batcher_helper(params, batch):
358 | sent, _ = get_index_batch(batch, params.vocabulary)
359 | sent_cuda = Variable(sent.cuda())
360 | sent_cuda = sent_cuda.t()
361 | params.word2mat.eval() # Deactivate drop-out and such
362 | embeddings = params.word2mat.forward(sent_cuda).data.cpu().numpy()
363 |
364 | return embeddings
365 |
366 | def batcher_cbow(params_senteval, batch):
367 |
368 | params = params_senteval["cmd_params"]
369 | embeddings = _batcher_helper(params_senteval, batch)
370 | return embeddings
371 |
372 | if __name__ == "__main__":
373 |
374 | run_and_evaluate(run_experiment, get_params_parser, batcher_cbow, prepare)
375 |
--------------------------------------------------------------------------------