8 |
9 | We provide a simple web demo based on [flask](https://github.com/pallets/flask) to show how SimCSE can be directly used for information retrieval. The code is based on the [repo](https://github.com/princeton-nlp/DensePhrases) and [demo](http://densephrases.korea.ac.kr) of [DensePhrases](https://arxiv.org/abs/2012.12624) (many thanks to its authors). To run the flask demo locally, make sure the SimCSE inference interfaces are set up:
10 | ```bash
11 | git clone https://github.com/princeton-nlp/SimCSE
12 | cd SimCSE
13 | python setup.py develop
14 | ```
15 | Then you can use `run_demo_example.sh` to launch the demo. By default, we build the index over 1,000 sentences sampled from the STS-B dataset. Feel free to build an index over your own corpora. You can also install [faiss](https://github.com/facebookresearch/faiss) to speed up the retrieval process.
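16 | 
17 | Under the hood, the retrieval flow amounts to embedding a corpus, indexing it, and searching by cosine similarity. Here is a minimal sketch using the `simcse` package's `SimCSE` wrapper and its `build_index`/`search` helpers (the sentence list is a made-up stand-in for the STS-B sample):
18 | 
19 | ```python
20 | from simcse import SimCSE
21 | 
22 | model = SimCSE("princeton-nlp/sup-simcse-bert-base-uncased")
23 | 
24 | # In the demo, the index is built over 1000 sentences sampled from STS-B
25 | sentences = ["A man is playing a guitar.",
26 |              "A woman is cooking dinner.",
27 |              "Kids are playing soccer."]
28 | model.build_index(sentences)  # uses faiss automatically when it is installed
29 | 
30 | results = model.search("Someone is making food.")  # list of (sentence, similarity) pairs
31 | print(results)
32 | ```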
16 |
17 | ### Gradio Demo
18 | [AK391](https://github.com/AK391) has provided a [Gradio Web Demo](https://gradio.app/g/AK391/SimCSE) of SimCSE to show how the pre-trained models can predict the semantic similarity between two sentences.
19 |
--------------------------------------------------------------------------------
/simcse_to_huggingface.py:
--------------------------------------------------------------------------------
1 | """
2 | Convert SSCL's checkpoints to Huggingface style.
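3 | 
4 | Usage: python simcse_to_huggingface.py --path {PATH_TO_CHECKPOINT_FOLDER}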
3 | """
4 |
5 | import argparse
6 | import torch
7 | import os
8 | import json
9 |
10 |
11 | def main():
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument("--path", type=str, help="Path of SSCL checkpoint folder")
14 | args = parser.parse_args()
15 |
16 | print("SSCL checkpoint -> Huggingface checkpoint for {}".format(args.path))
17 |
18 | state_dict = torch.load(os.path.join(args.path, "pytorch_model.bin"), map_location=torch.device("cpu"))
19 | new_state_dict = {}
20 | for key, param in state_dict.items():
21 |         # Replace "mlp" with "pooler"
22 | if "mlp" in key:
23 | key = key.replace("mlp", "pooler")
24 |
25 | # Delete "bert" or "roberta" prefix
26 | if "bert." in key:
27 | key = key.replace("bert.", "")
28 | if "roberta." in key:
29 | key = key.replace("roberta.", "")
30 |
31 | new_state_dict[key] = param
32 |
33 | torch.save(new_state_dict, os.path.join(args.path, "pytorch_model.bin"))
34 |
35 | # Change architectures in config.json
36 | config = json.load(open(os.path.join(args.path, "config.json")))
37 | for i in range(len(config["architectures"])):
38 | config["architectures"][i] = config["architectures"][i].replace("ForCL", "Model")
39 | json.dump(config, open(os.path.join(args.path, "config.json"), "w"), indent=2)
40 |
41 |
42 | if __name__ == "__main__":
43 | main()
44 |
--------------------------------------------------------------------------------
/sscl_to_huggingface.py:
--------------------------------------------------------------------------------
1 | """
2 | Convert SSCL's checkpoints to Huggingface style.
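3 | 
4 | Usage: python sscl_to_huggingface.py --path {PATH_TO_CHECKPOINT_FOLDER}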
3 | """
4 |
5 | import argparse
6 | import torch
7 | import os
8 | import json
9 |
10 |
11 | def main():
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument("--path", type=str, help="Path of SSCL checkpoint folder")
14 | args = parser.parse_args()
15 |
16 | print("SSCL checkpoint -> Huggingface checkpoint for {}".format(args.path))
17 |
18 | state_dict = torch.load(os.path.join(args.path, "pytorch_model.bin"), map_location=torch.device("cpu"))
19 | new_state_dict = {}
20 | for key, param in state_dict.items():
21 |         # Replace "mlp" with "pooler"
22 | if "mlp" in key:
23 | key = key.replace("mlp", "pooler")
24 |
25 | # Delete "bert" or "roberta" prefix
26 | if "bert." in key:
27 | key = key.replace("bert.", "")
28 | if "roberta." in key:
29 | key = key.replace("roberta.", "")
30 |
31 | new_state_dict[key] = param
32 |
33 | torch.save(new_state_dict, os.path.join(args.path, "pytorch_model.bin"))
34 |
35 | # Change architectures in config.json
36 | config = json.load(open(os.path.join(args.path, "config.json")))
37 | for i in range(len(config["architectures"])):
38 | config["architectures"][i] = config["architectures"][i].replace("ForCL", "Model")
39 | json.dump(config, open(os.path.join(args.path, "config.json"), "w"), indent=2)
40 |
41 |
42 | if __name__ == "__main__":
43 | main()
44 |
--------------------------------------------------------------------------------
/SentEval/LICENSE:
--------------------------------------------------------------------------------
1 | BSD License
2 |
3 | For SentEval software
4 |
5 | Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
6 |
7 | Redistribution and use in source and binary forms, with or without modification,
8 | are permitted provided that the following conditions are met:
9 |
10 | * Redistributions of source code must retain the above copyright notice, this
11 | list of conditions and the following disclaimer.
12 |
13 | * Redistributions in binary form must reproduce the above copyright notice,
14 | this list of conditions and the following disclaimer in the documentation
15 | and/or other materials provided with the distribution.
16 |
17 | * Neither the name Facebook nor the names of its contributors may be used to
18 | endorse or promote products derived from this software without specific
19 | prior written permission.
20 |
21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 |
--------------------------------------------------------------------------------
/demo/gradiodemo.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from scipy.spatial.distance import cosine
3 | from transformers import AutoModel, AutoTokenizer
4 | import gradio as gr
5 |
6 | # Import our models. The package will take care of downloading the models automatically
7 | tokenizer = AutoTokenizer.from_pretrained("princeton-nlp/sup-simcse-bert-base-uncased")
8 | model = AutoModel.from_pretrained("princeton-nlp/sup-simcse-bert-base-uncased")
9 |
10 | def simcse(text1, text2, text3):
11 | # Tokenize input texts
12 | texts = [
13 | text1,
14 | text2,
15 | text3
16 | ]
17 | inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
18 |
19 | # Get the embeddings
20 | with torch.no_grad():
21 | embeddings = model(**inputs, output_hidden_states=True, return_dict=True).pooler_output
22 |
23 | # Calculate cosine similarities
24 | # Cosine similarities are in [-1, 1]. Higher means more similar
25 | cosine_sim_0_1 = 1 - cosine(embeddings[0], embeddings[1])
26 | cosine_sim_0_2 = 1 - cosine(embeddings[0], embeddings[2])
27 | return {"cosine similarity":cosine_sim_0_1}, {"cosine similarity":cosine_sim_0_2}
28 |
29 |
30 | inputs = [
31 | gr.inputs.Textbox(lines=5, label="Input Text One"),
32 | gr.inputs.Textbox(lines=5, label="Input Text Two"),
33 | gr.inputs.Textbox(lines=5, label="Input Text Three")
34 | ]
35 |
36 | outputs = [
37 | gr.outputs.Label(type="confidences",label="Cosine similarity between text one and two"),
38 | gr.outputs.Label(type="confidences", label="Cosine similarity between text one and three")
39 | ]
40 |
41 |
42 | title = "SimCSE"
43 | description = "demo for Princeton-NLP SimCSE. To use it, simply add your text, or click one of the examples to load them. Read more at the links below."
44 | article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2104.08821' target='_blank'>SimCSE: Simple Contrastive Learning of Sentence Embeddings</a> | <a href='https://github.com/princeton-nlp/SimCSE' target='_blank'>Github Repo</a></p>"
45 | 
46 | # Example inputs shown in the demo UI
47 | examples = [
48 |     ["There's a kid on a skateboard.",
49 |      "A kid is skateboarding.",
50 |      "A kid is inside the house."]
51 | ]
52 | 
53 | # Launch the Gradio interface
54 | gr.Interface(fn=simcse, inputs=inputs, outputs=outputs, title=title, description=description, article=article, examples=examples).launch()

[The remaining lines of this section were residue of the flask demo's HTML template: a description of the sentence-retrieval demo noting that retrieved sentences come from the STS-Benchmark dataset, that two hyperparameters are adjustable (Top-K, the maximum number of sentences displayed, default 5; Threshold, the minimum similarity score for retrieval, default 0.6), and that Faiss is used to accelerate retrieval.]
--------------------------------------------------------------------------------
/SentEval/senteval/probing.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | '''
9 | probing tasks
10 | '''
11 |
12 | from __future__ import absolute_import, division, unicode_literals
13 |
14 | import os
15 | import io
16 | import copy
17 | import logging
18 | import numpy as np
19 |
20 | from senteval.tools.validation import SplitClassifier
21 |
22 |
23 | class PROBINGEval(object):
24 | def __init__(self, task, task_path, seed=1111):
25 | self.seed = seed
26 | self.task = task
27 | logging.debug('***** (Probing) Transfer task : %s classification *****', self.task.upper())
28 | self.task_data = {'train': {'X': [], 'y': []},
29 | 'dev': {'X': [], 'y': []},
30 | 'test': {'X': [], 'y': []}}
31 | self.loadFile(task_path)
32 | logging.info('Loaded %s train - %s dev - %s test for %s' %
33 | (len(self.task_data['train']['y']), len(self.task_data['dev']['y']),
34 | len(self.task_data['test']['y']), self.task))
35 |
36 | def do_prepare(self, params, prepare):
37 | samples = self.task_data['train']['X'] + self.task_data['dev']['X'] + \
38 | self.task_data['test']['X']
39 | return prepare(params, samples)
40 |
41 | def loadFile(self, fpath):
42 | self.tok2split = {'tr': 'train', 'va': 'dev', 'te': 'test'}
43 | with io.open(fpath, 'r', encoding='utf-8') as f:
44 | for line in f:
45 | line = line.rstrip().split('\t')
46 | self.task_data[self.tok2split[line[0]]]['X'].append(line[-1].split())
47 | self.task_data[self.tok2split[line[0]]]['y'].append(line[1])
48 |
49 | labels = sorted(np.unique(self.task_data['train']['y']))
50 | self.tok2label = dict(zip(labels, range(len(labels))))
51 | self.nclasses = len(self.tok2label)
52 |
53 | for split in self.task_data:
54 | for i, y in enumerate(self.task_data[split]['y']):
55 | self.task_data[split]['y'][i] = self.tok2label[y]
56 |
57 | def run(self, params, batcher):
58 | task_embed = {'train': {}, 'dev': {}, 'test': {}}
59 | bsize = params.batch_size
60 | logging.info('Computing embeddings for train/dev/test')
61 | for key in self.task_data:
62 | # Sort to reduce padding
63 | sorted_data = sorted(zip(self.task_data[key]['X'],
64 | self.task_data[key]['y']),
65 | key=lambda z: (len(z[0]), z[1]))
66 | self.task_data[key]['X'], self.task_data[key]['y'] = map(list, zip(*sorted_data))
67 |
68 | task_embed[key]['X'] = []
69 | for ii in range(0, len(self.task_data[key]['y']), bsize):
70 | batch = self.task_data[key]['X'][ii:ii + bsize]
71 | embeddings = batcher(params, batch)
72 | task_embed[key]['X'].append(embeddings)
73 | task_embed[key]['X'] = np.vstack(task_embed[key]['X'])
74 | task_embed[key]['y'] = np.array(self.task_data[key]['y'])
75 | logging.info('Computed embeddings')
76 |
77 | config_classifier = {'nclasses': self.nclasses, 'seed': self.seed,
78 | 'usepytorch': params.usepytorch,
79 | 'classifier': params.classifier}
80 |
81 | if self.task == "WordContent" and params.classifier['nhid'] > 0:
82 | config_classifier = copy.deepcopy(config_classifier)
83 | config_classifier['classifier']['nhid'] = 0
84 | print(params.classifier['nhid'])
85 |
86 | clf = SplitClassifier(X={'train': task_embed['train']['X'],
87 | 'valid': task_embed['dev']['X'],
88 | 'test': task_embed['test']['X']},
89 | y={'train': task_embed['train']['y'],
90 | 'valid': task_embed['dev']['y'],
91 | 'test': task_embed['test']['y']},
92 | config=config_classifier)
93 |
94 | devacc, testacc = clf.run()
95 | logging.debug('\nDev acc : %.1f Test acc : %.1f for %s classification\n' % (devacc, testacc, self.task.upper()))
96 |
97 | return {'devacc': devacc, 'acc': testacc,
98 | 'ndev': len(task_embed['dev']['X']),
99 | 'ntest': len(task_embed['test']['X'])}
100 |
101 | """
102 | Surface Information
103 | """
104 | class LengthEval(PROBINGEval):
105 | def __init__(self, task_path, seed=1111):
106 | task_path = os.path.join(task_path, 'sentence_length.txt')
107 | # labels: bins
108 | PROBINGEval.__init__(self, 'Length', task_path, seed)
109 |
110 | class WordContentEval(PROBINGEval):
111 | def __init__(self, task_path, seed=1111):
112 | task_path = os.path.join(task_path, 'word_content.txt')
113 | # labels: 200 target words
114 | PROBINGEval.__init__(self, 'WordContent', task_path, seed)
115 |
116 | """
117 | Latent Structural Information
118 | """
119 | class DepthEval(PROBINGEval):
120 | def __init__(self, task_path, seed=1111):
121 | task_path = os.path.join(task_path, 'tree_depth.txt')
122 | # labels: bins
123 | PROBINGEval.__init__(self, 'Depth', task_path, seed)
124 |
125 | class TopConstituentsEval(PROBINGEval):
126 | def __init__(self, task_path, seed=1111):
127 | task_path = os.path.join(task_path, 'top_constituents.txt')
128 | # labels: 'PP_NP_VP_.' .. (20 classes)
129 | PROBINGEval.__init__(self, 'TopConstituents', task_path, seed)
130 |
131 | class BigramShiftEval(PROBINGEval):
132 | def __init__(self, task_path, seed=1111):
133 | task_path = os.path.join(task_path, 'bigram_shift.txt')
134 | # labels: 0 or 1
135 | PROBINGEval.__init__(self, 'BigramShift', task_path, seed)
136 |
137 | # TODO: Voice?
138 |
139 | """
140 | Latent Semantic Information
141 | """
142 |
143 | class TenseEval(PROBINGEval):
144 | def __init__(self, task_path, seed=1111):
145 | task_path = os.path.join(task_path, 'past_present.txt')
146 | # labels: 'PRES', 'PAST'
147 | PROBINGEval.__init__(self, 'Tense', task_path, seed)
148 |
149 | class SubjNumberEval(PROBINGEval):
150 | def __init__(self, task_path, seed=1111):
151 | task_path = os.path.join(task_path, 'subj_number.txt')
152 | # labels: 'NN', 'NNS'
153 | PROBINGEval.__init__(self, 'SubjNumber', task_path, seed)
154 |
155 | class ObjNumberEval(PROBINGEval):
156 | def __init__(self, task_path, seed=1111):
157 | task_path = os.path.join(task_path, 'obj_number.txt')
158 | # labels: 'NN', 'NNS'
159 | PROBINGEval.__init__(self, 'ObjNumber', task_path, seed)
160 |
161 | class OddManOutEval(PROBINGEval):
162 | def __init__(self, task_path, seed=1111):
163 | task_path = os.path.join(task_path, 'odd_man_out.txt')
164 | # labels: 'O', 'C'
165 | PROBINGEval.__init__(self, 'OddManOut', task_path, seed)
166 |
167 | class CoordinationInversionEval(PROBINGEval):
168 | def __init__(self, task_path, seed=1111):
169 | task_path = os.path.join(task_path, 'coordination_inversion.txt')
170 | # labels: 'O', 'I'
171 | PROBINGEval.__init__(self, 'CoordinationInversion', task_path, seed)
172 |
--------------------------------------------------------------------------------
/SentEval/senteval/tools/classifier.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | """
9 | Pytorch Classifier class in the style of scikit-learn
10 | Classifiers include Logistic Regression and MLP
11 | """
12 |
13 | from __future__ import absolute_import, division, unicode_literals
14 |
15 | import numpy as np
16 | import copy
17 | from senteval import utils
18 |
19 | import torch
20 | from torch import nn
21 | import torch.nn.functional as F
22 |
23 |
24 | class PyTorchClassifier(object):
25 | def __init__(self, inputdim, nclasses, l2reg=0., batch_size=64, seed=1111,
26 | cudaEfficient=False):
27 | # fix seed
28 | np.random.seed(seed)
29 | torch.manual_seed(seed)
30 | torch.cuda.manual_seed(seed)
31 |
32 | self.inputdim = inputdim
33 | self.nclasses = nclasses
34 | self.l2reg = l2reg
35 | self.batch_size = batch_size
36 | self.cudaEfficient = cudaEfficient
37 |
38 | def prepare_split(self, X, y, validation_data=None, validation_split=None):
39 | # Preparing validation data
40 | assert validation_split or validation_data
41 | if validation_data is not None:
42 | trainX, trainy = X, y
43 | devX, devy = validation_data
44 | else:
45 | permutation = np.random.permutation(len(X))
46 | trainidx = permutation[int(validation_split * len(X)):]
47 | devidx = permutation[0:int(validation_split * len(X))]
48 | trainX, trainy = X[trainidx], y[trainidx]
49 | devX, devy = X[devidx], y[devidx]
50 |
51 | device = torch.device('cpu') if self.cudaEfficient else torch.device('cuda')
52 |
53 | trainX = torch.from_numpy(trainX).to(device, dtype=torch.float32)
54 | trainy = torch.from_numpy(trainy).to(device, dtype=torch.int64)
55 | devX = torch.from_numpy(devX).to(device, dtype=torch.float32)
56 | devy = torch.from_numpy(devy).to(device, dtype=torch.int64)
57 |
58 | return trainX, trainy, devX, devy
59 |
60 | def fit(self, X, y, validation_data=None, validation_split=None,
61 | early_stop=True):
62 | self.nepoch = 0
63 | bestaccuracy = -1
64 | stop_train = False
65 | early_stop_count = 0
66 |
67 | # Preparing validation data
68 | trainX, trainy, devX, devy = self.prepare_split(X, y, validation_data,
69 | validation_split)
70 |
71 | # Training
72 | while not stop_train and self.nepoch <= self.max_epoch:
73 | self.trainepoch(trainX, trainy, epoch_size=self.epoch_size)
74 | accuracy = self.score(devX, devy)
75 | if accuracy > bestaccuracy:
76 | bestaccuracy = accuracy
77 | bestmodel = copy.deepcopy(self.model)
78 | elif early_stop:
79 | if early_stop_count >= self.tenacity:
80 | stop_train = True
81 | early_stop_count += 1
82 | self.model = bestmodel
83 | return bestaccuracy
84 |
85 | def trainepoch(self, X, y, epoch_size=1):
86 | self.model.train()
87 | for _ in range(self.nepoch, self.nepoch + epoch_size):
88 | permutation = np.random.permutation(len(X))
89 | all_costs = []
90 | for i in range(0, len(X), self.batch_size):
91 | # forward
92 | idx = torch.from_numpy(permutation[i:i + self.batch_size]).long().to(X.device)
93 |
94 | Xbatch = X[idx]
95 | ybatch = y[idx]
96 |
97 | if self.cudaEfficient:
98 | Xbatch = Xbatch.cuda()
99 | ybatch = ybatch.cuda()
100 | output = self.model(Xbatch)
101 | # loss
102 | loss = self.loss_fn(output, ybatch)
103 | all_costs.append(loss.data.item())
104 | # backward
105 | self.optimizer.zero_grad()
106 | loss.backward()
107 | # Update parameters
108 | self.optimizer.step()
109 | self.nepoch += epoch_size
110 |
111 | def score(self, devX, devy):
112 | self.model.eval()
113 | correct = 0
114 | if not isinstance(devX, torch.cuda.FloatTensor) or self.cudaEfficient:
115 | devX = torch.FloatTensor(devX).cuda()
116 | devy = torch.LongTensor(devy).cuda()
117 | with torch.no_grad():
118 | for i in range(0, len(devX), self.batch_size):
119 | Xbatch = devX[i:i + self.batch_size]
120 | ybatch = devy[i:i + self.batch_size]
121 | if self.cudaEfficient:
122 | Xbatch = Xbatch.cuda()
123 | ybatch = ybatch.cuda()
124 | output = self.model(Xbatch)
125 | pred = output.data.max(1)[1]
126 | correct += pred.long().eq(ybatch.data.long()).sum().item()
127 | accuracy = 1.0 * correct / len(devX)
128 | return accuracy
129 |
130 | def predict(self, devX):
131 | self.model.eval()
132 | if not isinstance(devX, torch.cuda.FloatTensor):
133 | devX = torch.FloatTensor(devX).cuda()
134 | yhat = np.array([])
135 | with torch.no_grad():
136 | for i in range(0, len(devX), self.batch_size):
137 | Xbatch = devX[i:i + self.batch_size]
138 | output = self.model(Xbatch)
139 | yhat = np.append(yhat,
140 | output.data.max(1)[1].cpu().numpy())
141 | yhat = np.vstack(yhat)
142 | return yhat
143 |
144 |     def predict_proba(self, devX):
145 |         self.model.eval()
146 |         probas = []
147 |         with torch.no_grad():
148 |             for i in range(0, len(devX), self.batch_size):
149 |                 Xbatch = devX[i:i + self.batch_size]
150 |                 # softmax needs an explicit dim and must run on the tensor before moving to numpy
151 |                 vals = F.softmax(self.model(Xbatch), dim=-1).data.cpu().numpy()
152 |                 probas.append(vals)
153 |         # np.concatenate takes a sequence of arrays as its first argument
154 |         return np.concatenate(probas, axis=0)
156 |
157 |
158 | """
159 | MLP with Pytorch (nhid=0 --> Logistic Regression)
160 | """
161 |
162 | class MLP(PyTorchClassifier):
163 | def __init__(self, params, inputdim, nclasses, l2reg=0., batch_size=64,
164 | seed=1111, cudaEfficient=False):
165 | super(self.__class__, self).__init__(inputdim, nclasses, l2reg,
166 | batch_size, seed, cudaEfficient)
167 | """
168 | PARAMETERS:
169 | -nhid: number of hidden units (0: Logistic Regression)
170 | -optim: optimizer ("sgd,lr=0.1", "adam", "rmsprop" ..)
171 | -tenacity: how many times dev acc does not increase before stopping
172 | -epoch_size: each epoch corresponds to epoch_size pass on the train set
173 | -max_epoch: max number of epoches
174 | -dropout: dropout for MLP
175 | """
176 |
177 | self.nhid = 0 if "nhid" not in params else params["nhid"]
178 | self.optim = "adam" if "optim" not in params else params["optim"]
179 | self.tenacity = 5 if "tenacity" not in params else params["tenacity"]
180 | self.epoch_size = 4 if "epoch_size" not in params else params["epoch_size"]
181 | self.max_epoch = 200 if "max_epoch" not in params else params["max_epoch"]
182 | self.dropout = 0. if "dropout" not in params else params["dropout"]
183 | self.batch_size = 64 if "batch_size" not in params else params["batch_size"]
184 |
185 | if params["nhid"] == 0:
186 | self.model = nn.Sequential(
187 | nn.Linear(self.inputdim, self.nclasses),
188 | ).cuda()
189 | else:
190 | self.model = nn.Sequential(
191 | nn.Linear(self.inputdim, params["nhid"]),
192 | nn.Dropout(p=self.dropout),
193 | nn.Sigmoid(),
194 | nn.Linear(params["nhid"], self.nclasses),
195 | ).cuda()
196 |
197 | self.loss_fn = nn.CrossEntropyLoss().cuda()
198 | self.loss_fn.size_average = False
199 |
200 | optim_fn, optim_params = utils.get_optimizer(self.optim)
201 | self.optimizer = optim_fn(self.model.parameters(), **optim_params)
202 | self.optimizer.param_groups[0]['weight_decay'] = self.l2reg
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Alleviating Over-smoothing for Unsupervised Sentence Representation
2 | This is the project of our ACL 2023 paper: [Alleviating Over-smoothing for Unsupervised Sentence Representation](https://arxiv.org/pdf/2305.06154).
3 |
4 | Our work is mainly based on the [SimCSE project](https://github.com/princeton-nlp/SimCSE); many thanks to SimCSE!
5 |
6 | ## Quick Links
7 |
8 | - [Overview](#overview)
9 | - [Getting Started](#setup)
12 | - [Evaluation](#evaluation)
13 | - [Training](#training)
14 | - [Citation](#citation)
15 |
16 | ## Overview
17 |
18 | We present a new contrastive-learning training paradigm: a simple contrastive method named Self-Contrastive Learning (SSCL), which significantly improves the quality of learned sentence representations while alleviating the over-smoothing issue. Simply put, we use hidden representations from intermediate PLM layers as negative samples that the final sentence representations should be pushed away from. SSCL has several advantages: (1) it is fairly straightforward and does not require complex data augmentation techniques; (2) it can be seen as a contrastive framework that focuses on mining negatives effectively, and it can easily be extended to different sentence encoders that aim at building positive pairs; (3) it can further be viewed as a plug-and-play framework for enhancing sentence representations.
19 | 
20 |
21 | ## Setup
22 | First, install PyTorch by following the instructions from the [official website](https://pytorch.org/get-started/previous-versions/). To faithfully reproduce our results, please use the 1.9.1 version that matches your platform/CUDA version. You can install PyTorch with the following command,
23 |
24 | ```bash
25 | pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html
26 | ```
27 | Then run the following script to install the remaining dependencies,
28 |
29 | ```bash
30 | pip install -r requirements.txt
31 | ```
32 |
33 | ### Evaluation
34 | Our evaluation code for sentence embeddings is based on a modified version of [SentEval](https://github.com/facebookresearch/SentEval). It evaluates sentence embeddings on semantic textual similarity (STS) tasks and downstream transfer tasks. For STS tasks, our evaluation takes the "all" setting and reports Spearman's correlation.
35 |
36 | Before evaluation, please download the evaluation datasets by running
37 | ```bash
38 | cd SentEval/data/downstream/
39 | bash download_dataset.sh
40 | ```
41 |
42 | Then, back in the root directory, you can evaluate any `transformers`-based pre-trained model using our evaluation code. For example,
43 | ```bash
44 | python evaluation.py \
45 | --model_name_or_path sscl-bert-base-uncased \
46 | --pooler cls \
47 | --task_set sts \
48 | --mode test
49 | ```
50 |
51 | Arguments for the evaluation script are as follows,
52 |
53 | * `--model_name_or_path`: The name or path of a `transformers`-based pre-trained checkpoint. You can directly use the models in the above table.
54 | * `--pooler`: Pooling method. Now we support
55 | * `cls` (default): Use the representation of `[CLS]` token. A linear+activation layer is applied after the representation (it's in the standard BERT implementation).
56 | * `cls_before_pooler`: Use the representation of `[CLS]` token without the extra linear+activation.
57 | * `avg`: Average embeddings of the last layer. If you use checkpoints of SBERT/SRoBERTa ([paper](https://arxiv.org/abs/1908.10084)), you should use this option.
58 | * `avg_top2`: Average embeddings of the last two layers.
59 | * `avg_first_last`: Average embeddings of the first and last layers. If you use vanilla BERT or RoBERTa, this works the best.
60 | * `--mode`: Evaluation mode
61 | * `test` (default): The default test mode. To faithfully reproduce our results, you should use this option.
62 | * `dev`: Report the development set results. Note that in STS tasks, only `STS-B` and `SICK-R` have development sets, so we only report their numbers. It also takes a fast mode for transfer tasks, so the running time is much shorter than the `test` mode (though numbers are slightly lower).
63 |   * `fasttest`: The same as `test`, but with the fast mode for transfer tasks, so the running time is much shorter (the reported transfer numbers may be slightly lower).
64 | * `--task_set`: What set of tasks to evaluate on (if set, it will override `--tasks`)
65 | * `sts` (default): Evaluate on STS tasks, including `STS 12~16`, `STS-B` and `SICK-R`. This is the most commonly-used set of tasks to evaluate the quality of sentence embeddings.
66 | * `transfer`: Evaluate on transfer tasks.
67 | * `full`: Evaluate on both STS and transfer tasks.
68 | * `na`: Manually set tasks by `--tasks`.
69 | * `--tasks`: Specify which dataset(s) to evaluate on. Will be overridden if `--task_set` is not `na`. See the code for a full list of tasks.
70 |
71 | ### Training
72 |
73 | **Data**
74 |
75 | You can run `data/download_wiki.sh` and `data/download_nli.sh` to download the two datasets.
76 |
77 | **Training scripts**
78 | 
79 | In `run_unsup_example.sh`, we provide a single-GPU (or CPU) example for the unsupervised version. We explain the arguments as follows:
80 | * `--train_file`: Training file path. We support "txt" files (one line per sentence) and "csv" files (2-column: sentence pairs with no hard negative; 3-column: sentence pairs each with one corresponding hard negative). You can use our provided Wikipedia or NLI data, or your own data in the same format (see the sample after this list).
80 | * `--model_name_or_path`: Pre-trained checkpoints to start with. For now we support BERT-based models (`bert-base-uncased`, `bert-large-uncased`, etc.).
81 | * `--temp`: Temperature for the contrastive loss.
82 | * `--pooler_type`: Pooling method. It's the same as the `--pooler_type` in the [evaluation part](#evaluation).
83 | * `--mlp_only_train`: You should use this argument when training SSCL models.
84 | * `--hard_negative_weight`: If using hard negatives (i.e., there are 3 columns in the training file), this is the logarithm of the weight. For example, if the weight is 1, then this argument should be set as 0 (default value).
85 | * `--do_mlm`: Whether to use the MLM auxiliary objective. If True:
86 | * `--mlm_weight`: Weight for the MLM objective.
87 | * `--mlm_probability`: Masking rate for the MLM objective.
88 | * `--do_neg`: Whether to use negatives in SSCL.
89 | * `--hard_negative_layers`: The number of preceding layers used to construct negatives.
90 |
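91 | For instance, a 3-column "csv" training file looks like the following (the header names match our provided NLI data; the sentences are made-up):
92 | 
93 | ```csv
94 | sent0,sent1,hard_neg
95 | "Two dogs are running through a field.","A pair of dogs race across the grass.","Two cats are sleeping on a couch."
96 | ```
97 | 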
91 | All the other arguments are standard Huggingface `transformers` training arguments. Some often-used ones are: `--output_dir`, `--learning_rate`, and `--per_device_train_batch_size`. In our example scripts, we also evaluate the model on the STS-B development set (you need to download the dataset following the [evaluation](#evaluation) section) and save the best checkpoint.
92 |
93 | For results in the paper, we use Nvidia A100 GPUs with CUDA 11. Using different types of devices or different versions of CUDA/other software may lead to slightly different performance.
94 |
95 |
96 |
97 | **Convert models**
98 |
99 | Our saved checkpoints are slightly different from Huggingface's pre-trained checkpoints. Run `python sscl_to_huggingface.py --path {PATH_TO_CHECKPOINT_FOLDER}` to convert a checkpoint. After that, you can evaluate it with our [evaluation](#evaluation) code or use it directly out of the box, as sketched below.
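100 | 
101 | After conversion, loading the checkpoint is standard `transformers` usage; here is a minimal sketch (the local path is hypothetical):
102 | 
103 | ```python
104 | import torch
105 | from transformers import AutoModel, AutoTokenizer
106 | 
107 | path = "result/sscl-bert-base-uncased"  # hypothetical converted checkpoint folder
108 | tokenizer = AutoTokenizer.from_pretrained(path)
109 | model = AutoModel.from_pretrained(path)
110 | 
111 | inputs = tokenizer(["A sentence to embed."], padding=True, truncation=True, return_tensors="pt")
112 | with torch.no_grad():
113 |     embeddings = model(**inputs).pooler_output  # [1, hidden_size]
114 | ```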
100 |
101 | ## Citation
102 |
103 | Please cite our paper if you use SSCL in your work:
104 |
105 | ```bibtex
106 | @inproceedings{chen-etal-2023-alleviating,
107 | title = "Alleviating Over-smoothing for Unsupervised Sentence Representation",
108 | author = "Chen, Nuo and
109 | Shou, Linjun and
110 | Pei, Jian and
111 | Gong, Ming and
112 | Cao, Bowen and
113 | Chang, Jianhui and
114 | Li, Jia and
115 | Jiang, Daxin",
116 | booktitle = "Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
117 | month = jul,
118 | year = "2023",
119 | address = "Toronto, Canada",
120 | publisher = "Association for Computational Linguistics",
121 | url = "https://aclanthology.org/2023.acl-long.197",
122 | pages = "3552--3566",
123 | abstract = "Currently, learning better unsupervised sentence representations is the pursuit of many natural language processing communities. Lots of approaches based on pre-trained language models (PLMs) and contrastive learning have achieved promising results on this task. Experimentally, we observe that the over-smoothing problem reduces the capacity of these powerful PLMs, leading to sub-optimal sentence representations. In this paper, we present a Simple method named Self-Contrastive Learning (SSCL) to alleviate this issue, which samples negatives from PLMs intermediate layers, improving the quality of the sentence representation. Our proposed method is quite simple and can be easily extended to various state-of-the-art models for performance boosting, which can be seen as a plug-and-play contrastive framework for learning unsupervised sentence representation. Extensive results prove that SSCL brings the superior performance improvements of different strong baselines (e.g., BERT and SimCSE) on Semantic Textual Similarity and Transfer datasets",
124 | }
125 | ```
126 |
--------------------------------------------------------------------------------
/evaluation.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import io, os
3 | import numpy as np
4 | import logging
5 | import argparse
6 | from prettytable import PrettyTable
7 | import torch
8 | import transformers
9 | from transformers import AutoModel, AutoTokenizer
11 | from openpyxl import Workbook
12 |
13 | # Set up logger
14 | logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG)
15 |
16 | # Set PATHs
17 | PATH_TO_SENTEVAL = './SentEval'
18 | PATH_TO_DATA = './SentEval/data'
19 | workbook = Workbook()
20 |
21 | # Import SentEval
22 | sys.path.insert(0, PATH_TO_SENTEVAL)
23 | import senteval
24 |
25 | def print_table(task_names, scores):
26 | tb = PrettyTable()
27 | tb.field_names = task_names
28 | tb.add_row(scores)
29 | print(tb)
30 |
31 | def main():
32 | parser = argparse.ArgumentParser()
33 | parser.add_argument("--model_name_or_path", type=str,
34 | help="Transformers' model name or path")
35 | parser.add_argument("--pooler", type=str,
36 | choices=['cls', 'cls_before_pooler', 'avg', 'avg_top2', 'avg_first_last'],
37 | default='cls',
38 | help="Which pooler to use")
39 | parser.add_argument("--mode", type=str,
40 | choices=['dev', 'test', 'fasttest'],
41 | default='test',
42 |                         help="What evaluation mode to use (dev: fast mode, dev results; test: full mode, test results; fasttest: fast mode, test results)")
43 | parser.add_argument("--task_set", type=str,
44 | choices=['sts', 'transfer', 'full', 'na'],
45 | default='sts',
46 | help="What set of tasks to evaluate on. If not 'na', this will override '--tasks'")
47 | parser.add_argument("--tasks", type=str, nargs='+',
48 | default=['STS12', 'STS13', 'STS14', 'STS15', 'STS16',
49 | 'MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'TREC', 'MRPC',
50 | 'SICKRelatedness', 'STSBenchmark'],
51 | help="Tasks to evaluate on. If '--task_set' is specified, this will be overridden")
52 |
53 | args = parser.parse_args()
54 |
55 | # Load transformers' model checkpoint
56 | model = AutoModel.from_pretrained(args.model_name_or_path)
57 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
58 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
59 | model = model.to(device)
60 |
61 | # Set up the tasks
62 | if args.task_set == 'sts':
63 | args.tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark', 'SICKRelatedness']
64 | elif args.task_set == 'transfer':
65 | args.tasks = ['MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'TREC', 'MRPC']
66 | elif args.task_set == 'full':
67 | args.tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark', 'SICKRelatedness']
68 | args.tasks += ['MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'TREC', 'MRPC']
69 |
70 | # Set params for SentEval
71 | if args.mode == 'dev' or args.mode == 'fasttest':
72 | # Fast mode
73 | params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5}
74 | params['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128,
75 | 'tenacity': 3, 'epoch_size': 2}
76 | elif args.mode == 'test':
77 | # Full mode
78 | params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 10}
79 | params['classifier'] = {'nhid': 0, 'optim': 'adam', 'batch_size': 64,
80 | 'tenacity': 5, 'epoch_size': 4}
81 | else:
82 | raise NotImplementedError
83 |
84 | # SentEval prepare and batcher
85 | def prepare(params, samples):
86 | return
87 |
88 | def batcher(params, batch, max_length=None):
89 | # Handle rare token encoding issues in the dataset
90 | if len(batch) >= 1 and len(batch[0]) >= 1 and isinstance(batch[0][0], bytes):
91 | batch = [[word.decode('utf-8') for word in s] for s in batch]
92 |
93 | sentences = [' '.join(s) for s in batch]
94 |
95 | # Tokenization
96 | if max_length is not None:
97 | batch = tokenizer.batch_encode_plus(
98 | sentences,
99 | return_tensors='pt',
100 | padding=True,
101 | max_length=max_length,
102 | truncation=True
103 | )
104 | else:
105 | batch = tokenizer.batch_encode_plus(
106 | sentences,
107 | return_tensors='pt',
108 | padding=True,
109 | )
110 |
111 | # Move to the correct device
112 | for k in batch:
113 | batch[k] = batch[k].to(device)
114 |
115 | # Get raw embeddings
116 | with torch.no_grad():
117 | outputs = model(**batch, output_hidden_states=True, return_dict=True)
118 | last_hidden = outputs.last_hidden_state
119 | pooler_output = outputs.pooler_output
120 | hidden_states = outputs.hidden_states
121 |
122 | # Apply different poolers
123 | if args.pooler == 'cls':
124 | # There is a linear+activation layer after CLS representation
125 | return pooler_output.cpu()
126 | elif args.pooler == 'cls_before_pooler':
127 | return last_hidden[:, 0].cpu()
128 | elif args.pooler == "avg":
129 | return ((last_hidden * batch['attention_mask'].unsqueeze(-1)).sum(1) / batch['attention_mask'].sum(-1).unsqueeze(-1)).cpu()
130 | elif args.pooler == "avg_first_last":
131 | first_hidden = hidden_states[0]
132 | last_hidden = hidden_states[-1]
133 | pooled_result = ((first_hidden + last_hidden) / 2.0 * batch['attention_mask'].unsqueeze(-1)).sum(1) / batch['attention_mask'].sum(-1).unsqueeze(-1)
134 | return pooled_result.cpu()
135 | elif args.pooler == "avg_top2":
136 | second_last_hidden = hidden_states[-2]
137 | last_hidden = hidden_states[-1]
138 | pooled_result = ((last_hidden + second_last_hidden) / 2.0 * batch['attention_mask'].unsqueeze(-1)).sum(1) / batch['attention_mask'].sum(-1).unsqueeze(-1)
139 | return pooled_result.cpu()
140 | else:
141 | raise NotImplementedError
142 |
143 | results = {}
144 |
145 | for task in args.tasks:
146 | se = senteval.engine.SE(params, batcher, prepare)
147 | result = se.eval(task)
148 | results[task] = result
149 |
150 | # Print evaluation results
151 | if args.mode == 'dev':
152 | print("------ %s ------" % (args.mode))
153 |
154 | task_names = []
155 | scores = []
156 | for task in ['STSBenchmark', 'SICKRelatedness']:
157 | task_names.append(task)
158 | if task in results:
159 | scores.append("%.2f" % (results[task]['dev']['spearman'][0] * 100))
160 | else:
161 | scores.append("0.00")
162 | print_table(task_names, scores)
163 |
164 | task_names = []
165 | scores = []
166 | for task in ['MR', 'CR', 'SUBJ', 'MPQA', 'SST2', 'TREC', 'MRPC']:
167 | task_names.append(task)
168 | if task in results:
169 | scores.append("%.2f" % (results[task]['devacc']))
170 | else:
171 | scores.append("0.00")
172 | task_names.append("Avg.")
173 | scores.append("%.2f" % (sum([float(score) for score in scores]) / len(scores)))
174 | print_table(task_names, scores)
175 |
176 | elif args.mode == 'test' or args.mode == 'fasttest':
177 | print("------ %s ------" % (args.mode))
178 | save_file = os.path.join(args.model_name_or_path, 'results.xlsx')
179 | task_names = []
180 | scores = []
181 | sheet_name = ''
182 | for task in ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark', 'SICKRelatedness']:
183 | task_names.append(task)
184 | if task in results:
185 | if task in ['STS12', 'STS13', 'STS14', 'STS15', 'STS16']:
186 | scores.append("%.2f" % (results[task]['all']['spearman']['all'] * 100))
187 | else:
188 | scores.append("%.2f" % (results[task]['test']['spearman'].correlation * 100))
189 | else:
190 | scores.append("0.00")
191 | sheet_name = 'STS'
192 | task_names.append("Avg.")
193 | scores.append("%.2f" % (sum([float(score) for score in scores]) / len(scores)))
194 | # print(task_names)
195 | # print(scores)
196 | worksheet = workbook.create_sheet(sheet_name)
197 | worksheet.append(task_names)
198 | worksheet.append(scores)
199 | print_table(task_names, scores)
200 |
201 | task_names = []
202 | scores = []
203 | for task in ['MR', 'CR', 'SUBJ', 'MPQA', 'SST2', 'TREC', 'MRPC']:
204 | task_names.append(task)
205 | if task in results:
206 | scores.append("%.2f" % (results[task]['acc']))
207 | else:
208 | scores.append("0.00")
209 | sheet_name = 'Transfer'
210 | task_names.append("Avg.")
211 | scores.append("%.2f" % (sum([float(score) for score in scores]) / len(scores)))
212 | print_table(task_names, scores)
213 | worksheet = workbook.create_sheet(sheet_name)
214 | worksheet.append(task_names)
215 | worksheet.append(scores)
216 | workbook.save(filename=save_file)
217 |
218 |
219 | if __name__ == "__main__":
220 | main()
221 |
--------------------------------------------------------------------------------
/SentEval/senteval/sick.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | '''
9 | SICK Relatedness and Entailment
10 | '''
11 | from __future__ import absolute_import, division, unicode_literals
12 |
13 | import os
14 | import io
15 | import logging
16 | import numpy as np
17 |
18 | from sklearn.metrics import mean_squared_error
19 | from scipy.stats import pearsonr, spearmanr
20 |
21 | from senteval.tools.relatedness import RelatednessPytorch
22 | from senteval.tools.validation import SplitClassifier
23 |
24 | class SICKEval(object):
25 | def __init__(self, task_path, seed=1111):
26 | logging.debug('***** Transfer task : SICK-Relatedness*****\n\n')
27 | self.seed = seed
28 | train = self.loadFile(os.path.join(task_path, 'SICK_train.txt'))
29 | dev = self.loadFile(os.path.join(task_path, 'SICK_trial.txt'))
30 | test = self.loadFile(os.path.join(task_path, 'SICK_test_annotated.txt'))
31 | self.sick_data = {'train': train, 'dev': dev, 'test': test}
32 |
33 | def do_prepare(self, params, prepare):
34 | samples = self.sick_data['train']['X_A'] + \
35 | self.sick_data['train']['X_B'] + \
36 | self.sick_data['dev']['X_A'] + \
37 | self.sick_data['dev']['X_B'] + \
38 | self.sick_data['test']['X_A'] + self.sick_data['test']['X_B']
39 | return prepare(params, samples)
40 |
41 | def loadFile(self, fpath):
42 | skipFirstLine = True
43 | sick_data = {'X_A': [], 'X_B': [], 'y': []}
44 | with io.open(fpath, 'r', encoding='utf-8') as f:
45 | for line in f:
46 | if skipFirstLine:
47 | skipFirstLine = False
48 | else:
49 | text = line.strip().split('\t')
50 | sick_data['X_A'].append(text[1].split())
51 | sick_data['X_B'].append(text[2].split())
52 | sick_data['y'].append(text[3])
53 |
54 | sick_data['y'] = [float(s) for s in sick_data['y']]
55 | return sick_data
56 |
57 | def run(self, params, batcher):
58 | sick_embed = {'train': {}, 'dev': {}, 'test': {}}
59 | bsize = params.batch_size
60 |
61 | for key in self.sick_data:
62 | logging.info('Computing embedding for {0}'.format(key))
63 | # Sort to reduce padding
64 | sorted_corpus = sorted(zip(self.sick_data[key]['X_A'],
65 | self.sick_data[key]['X_B'],
66 | self.sick_data[key]['y']),
67 | key=lambda z: (len(z[0]), len(z[1]), z[2]))
68 |
69 | self.sick_data[key]['X_A'] = [x for (x, y, z) in sorted_corpus]
70 | self.sick_data[key]['X_B'] = [y for (x, y, z) in sorted_corpus]
71 | self.sick_data[key]['y'] = [z for (x, y, z) in sorted_corpus]
72 |
73 | for txt_type in ['X_A', 'X_B']:
74 | sick_embed[key][txt_type] = []
75 | for ii in range(0, len(self.sick_data[key]['y']), bsize):
76 | batch = self.sick_data[key][txt_type][ii:ii + bsize]
77 | embeddings = batcher(params, batch)
78 | sick_embed[key][txt_type].append(embeddings)
79 | sick_embed[key][txt_type] = np.vstack(sick_embed[key][txt_type])
80 | sick_embed[key]['y'] = np.array(self.sick_data[key]['y'])
81 | logging.info('Computed {0} embeddings'.format(key))
82 |
83 | # Train
84 | trainA = sick_embed['train']['X_A']
85 | trainB = sick_embed['train']['X_B']
86 | trainF = np.c_[np.abs(trainA - trainB), trainA * trainB]
87 | trainY = self.encode_labels(self.sick_data['train']['y'])
88 |
89 | # Dev
90 | devA = sick_embed['dev']['X_A']
91 | devB = sick_embed['dev']['X_B']
92 | devF = np.c_[np.abs(devA - devB), devA * devB]
93 | devY = self.encode_labels(self.sick_data['dev']['y'])
94 |
95 | # Test
96 | testA = sick_embed['test']['X_A']
97 | testB = sick_embed['test']['X_B']
98 | testF = np.c_[np.abs(testA - testB), testA * testB]
99 | testY = self.encode_labels(self.sick_data['test']['y'])
100 |
101 | config = {'seed': self.seed, 'nclasses': 5}
102 | clf = RelatednessPytorch(train={'X': trainF, 'y': trainY},
103 | valid={'X': devF, 'y': devY},
104 | test={'X': testF, 'y': testY},
105 | devscores=self.sick_data['dev']['y'],
106 | config=config)
107 |
108 | devspr, yhat = clf.run()
109 |
110 | pr = pearsonr(yhat, self.sick_data['test']['y'])[0]
111 | sr = spearmanr(yhat, self.sick_data['test']['y'])[0]
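112 |         # pearsonr/spearmanr can return NaN on degenerate inputs; x != x is a NaN check, so fall back to 0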
112 | pr = 0 if pr != pr else pr
113 | sr = 0 if sr != sr else sr
114 | se = mean_squared_error(yhat, self.sick_data['test']['y'])
115 | logging.debug('Dev : Spearman {0}'.format(devspr))
116 | logging.debug('Test : Pearson {0} Spearman {1} MSE {2} \
117 | for SICK Relatedness\n'.format(pr, sr, se))
118 |
119 | return {'devspearman': devspr, 'pearson': pr, 'spearman': sr, 'mse': se,
120 | 'yhat': yhat, 'ndev': len(devA), 'ntest': len(testA)}
121 |
122 | def encode_labels(self, labels, nclass=5):
123 | """
124 | Label encoding from Tree LSTM paper (Tai, Socher, Manning)
125 | """
126 | Y = np.zeros((len(labels), nclass)).astype('float32')
127 | for j, y in enumerate(labels):
128 | for i in range(nclass):
129 | if i+1 == np.floor(y) + 1:
130 | Y[j, i] = y - np.floor(y)
131 | if i+1 == np.floor(y):
132 | Y[j, i] = np.floor(y) - y + 1
133 | return Y
134 |
135 |
136 | class SICKEntailmentEval(SICKEval):
137 | def __init__(self, task_path, seed=1111):
138 | logging.debug('***** Transfer task : SICK-Entailment*****\n\n')
139 | self.seed = seed
140 | train = self.loadFile(os.path.join(task_path, 'SICK_train.txt'))
141 | dev = self.loadFile(os.path.join(task_path, 'SICK_trial.txt'))
142 | test = self.loadFile(os.path.join(task_path, 'SICK_test_annotated.txt'))
143 | self.sick_data = {'train': train, 'dev': dev, 'test': test}
144 |
145 | def loadFile(self, fpath):
146 | label2id = {'CONTRADICTION': 0, 'NEUTRAL': 1, 'ENTAILMENT': 2}
147 | skipFirstLine = True
148 | sick_data = {'X_A': [], 'X_B': [], 'y': []}
149 | with io.open(fpath, 'r', encoding='utf-8') as f:
150 | for line in f:
151 | if skipFirstLine:
152 | skipFirstLine = False
153 | else:
154 | text = line.strip().split('\t')
155 | sick_data['X_A'].append(text[1].split())
156 | sick_data['X_B'].append(text[2].split())
157 | sick_data['y'].append(text[4])
158 | sick_data['y'] = [label2id[s] for s in sick_data['y']]
159 | return sick_data
160 |
161 | def run(self, params, batcher):
162 | sick_embed = {'train': {}, 'dev': {}, 'test': {}}
163 | bsize = params.batch_size
164 |
165 | for key in self.sick_data:
166 | logging.info('Computing embedding for {0}'.format(key))
167 | # Sort to reduce padding
168 | sorted_corpus = sorted(zip(self.sick_data[key]['X_A'],
169 | self.sick_data[key]['X_B'],
170 | self.sick_data[key]['y']),
171 | key=lambda z: (len(z[0]), len(z[1]), z[2]))
172 |
173 | self.sick_data[key]['X_A'] = [x for (x, y, z) in sorted_corpus]
174 | self.sick_data[key]['X_B'] = [y for (x, y, z) in sorted_corpus]
175 | self.sick_data[key]['y'] = [z for (x, y, z) in sorted_corpus]
176 |
177 | for txt_type in ['X_A', 'X_B']:
178 | sick_embed[key][txt_type] = []
179 | for ii in range(0, len(self.sick_data[key]['y']), bsize):
180 | batch = self.sick_data[key][txt_type][ii:ii + bsize]
181 | embeddings = batcher(params, batch)
182 | sick_embed[key][txt_type].append(embeddings)
183 | sick_embed[key][txt_type] = np.vstack(sick_embed[key][txt_type])
184 | logging.info('Computed {0} embeddings'.format(key))
185 |
186 | # Train
187 | trainA = sick_embed['train']['X_A']
188 | trainB = sick_embed['train']['X_B']
189 | trainF = np.c_[np.abs(trainA - trainB), trainA * trainB]
190 | trainY = np.array(self.sick_data['train']['y'])
191 |
192 | # Dev
193 | devA = sick_embed['dev']['X_A']
194 | devB = sick_embed['dev']['X_B']
195 | devF = np.c_[np.abs(devA - devB), devA * devB]
196 | devY = np.array(self.sick_data['dev']['y'])
197 |
198 | # Test
199 | testA = sick_embed['test']['X_A']
200 | testB = sick_embed['test']['X_B']
201 | testF = np.c_[np.abs(testA - testB), testA * testB]
202 | testY = np.array(self.sick_data['test']['y'])
203 |
204 | config = {'nclasses': 3, 'seed': self.seed,
205 | 'usepytorch': params.usepytorch,
206 | 'classifier': params.classifier,
207 | 'nhid': params.nhid}
208 | clf = SplitClassifier(X={'train': trainF, 'valid': devF, 'test': testF},
209 | y={'train': trainY, 'valid': devY, 'test': testY},
210 | config=config)
211 |
212 | devacc, testacc = clf.run()
213 | logging.debug('\nDev acc : {0} Test acc : {1} for \
214 | SICK entailment\n'.format(devacc, testacc))
215 | return {'devacc': devacc, 'acc': testacc,
216 | 'ndev': len(devA), 'ntest': len(testA)}
217 |
--------------------------------------------------------------------------------
/SentEval/senteval/sts.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | '''
9 | STS-{2012,2013,2014,2015,2016} (unsupervised) and
10 | STS-benchmark (supervised) tasks
11 | '''
12 |
13 | from __future__ import absolute_import, division, unicode_literals
14 |
15 | import os
16 | import io
17 | import numpy as np
18 | import logging
19 |
20 | from scipy.stats import spearmanr, pearsonr
21 |
22 | from senteval.utils import cosine
23 | from senteval.sick import SICKEval
24 |
25 |
26 | class STSEval(object):
27 | def loadFile(self, fpath):
28 | self.data = {}
29 | self.samples = []
30 |
31 | for dataset in self.datasets:
32 | sent1, sent2 = zip(*[l.split("\t") for l in
33 | io.open(fpath + '/STS.input.%s.txt' % dataset,
34 | encoding='utf8').read().splitlines()])
35 | raw_scores = np.array([x for x in
36 | io.open(fpath + '/STS.gs.%s.txt' % dataset,
37 | encoding='utf8')
38 | .read().splitlines()])
39 | not_empty_idx = raw_scores != ''
40 |
41 | gs_scores = [float(x) for x in raw_scores[not_empty_idx]]
42 | sent1 = np.array([s.split() for s in sent1])[not_empty_idx]
43 | sent2 = np.array([s.split() for s in sent2])[not_empty_idx]
44 | # sort data by length to minimize padding in batcher
45 | sorted_data = sorted(zip(sent1, sent2, gs_scores),
46 | key=lambda z: (len(z[0]), len(z[1]), z[2]))
47 | sent1, sent2, gs_scores = map(list, zip(*sorted_data))
48 |
49 | self.data[dataset] = (sent1, sent2, gs_scores)
50 | self.samples += sent1 + sent2
51 |
52 | def do_prepare(self, params, prepare):
53 | if 'similarity' in params:
54 | self.similarity = params.similarity
55 | else: # Default similarity is cosine
56 | self.similarity = lambda s1, s2: np.nan_to_num(cosine(np.nan_to_num(s1), np.nan_to_num(s2)))
57 | return prepare(params, self.samples)
58 |
59 | def run(self, params, batcher):
60 | results = {}
61 | all_sys_scores = []
62 | all_gs_scores = []
63 | for dataset in self.datasets:
64 | sys_scores = []
65 | input1, input2, gs_scores = self.data[dataset]
66 | for ii in range(0, len(gs_scores), params.batch_size):
67 | batch1 = input1[ii:ii + params.batch_size]
68 | batch2 = input2[ii:ii + params.batch_size]
69 |
70 | # we assume get_batch already throws out the faulty ones
71 | if len(batch1) == len(batch2) and len(batch1) > 0:
72 | enc1 = batcher(params, batch1)
73 | enc2 = batcher(params, batch2)
74 |
75 | for kk in range(enc2.shape[0]):
76 | sys_score = self.similarity(enc1[kk], enc2[kk])
77 | sys_scores.append(sys_score)
78 | all_sys_scores.extend(sys_scores)
79 | all_gs_scores.extend(gs_scores)
80 |
81 | results[dataset] = {'pearson': pearsonr(sys_scores, gs_scores),
82 | 'spearman': spearmanr(sys_scores, gs_scores),
83 | 'nsamples': len(sys_scores)}
84 | logging.debug('%s : pearson = %.4f, spearman = %.4f' %
85 | (dataset, results[dataset]['pearson'][0],
86 | results[dataset]['spearman'][0]))
87 |
88 | weights = [results[dset]['nsamples'] for dset in results.keys()]
89 | list_prs = np.array([results[dset]['pearson'][0] for
90 | dset in results.keys()])
91 | list_spr = np.array([results[dset]['spearman'][0] for
92 | dset in results.keys()])
93 |
94 | avg_pearson = np.average(list_prs)
95 | avg_spearman = np.average(list_spr)
96 | wavg_pearson = np.average(list_prs, weights=weights)
97 | wavg_spearman = np.average(list_spr, weights=weights)
98 | all_pearson = pearsonr(all_sys_scores, all_gs_scores)
99 | all_spearman = spearmanr(all_sys_scores, all_gs_scores)
100 | results['all'] = {'pearson': {'all': all_pearson[0],
101 | 'mean': avg_pearson,
102 | 'wmean': wavg_pearson},
103 | 'spearman': {'all': all_spearman[0],
104 | 'mean': avg_spearman,
105 | 'wmean': wavg_spearman}}
106 | logging.debug('ALL : Pearson = %.4f, \
107 | Spearman = %.4f' % (all_pearson[0], all_spearman[0]))
108 | logging.debug('ALL (weighted average) : Pearson = %.4f, \
109 | Spearman = %.4f' % (wavg_pearson, wavg_spearman))
110 | logging.debug('ALL (average) : Pearson = %.4f, \
111 | Spearman = %.4f\n' % (avg_pearson, avg_spearman))
112 |
113 | return results
114 |
115 |
116 | class STS12Eval(STSEval):
117 | def __init__(self, taskpath, seed=1111):
118 | logging.debug('***** Transfer task : STS12 *****\n\n')
119 | self.seed = seed
120 | self.datasets = ['MSRpar', 'MSRvid', 'SMTeuroparl',
121 | 'surprise.OnWN', 'surprise.SMTnews']
122 | self.loadFile(taskpath)
123 |
124 |
125 | class STS13Eval(STSEval):
126 | # STS13 here does not contain the "SMT" subtask due to LICENSE issue
127 | def __init__(self, taskpath, seed=1111):
128 | logging.debug('***** Transfer task : STS13 (-SMT) *****\n\n')
129 | self.seed = seed
130 | self.datasets = ['FNWN', 'headlines', 'OnWN']
131 | self.loadFile(taskpath)
132 |
133 |
134 | class STS14Eval(STSEval):
135 | def __init__(self, taskpath, seed=1111):
136 | logging.debug('***** Transfer task : STS14 *****\n\n')
137 | self.seed = seed
138 | self.datasets = ['deft-forum', 'deft-news', 'headlines',
139 | 'images', 'OnWN', 'tweet-news']
140 | self.loadFile(taskpath)
141 |
142 |
143 | class STS15Eval(STSEval):
144 | def __init__(self, taskpath, seed=1111):
145 | logging.debug('***** Transfer task : STS15 *****\n\n')
146 | self.seed = seed
147 | self.datasets = ['answers-forums', 'answers-students',
148 | 'belief', 'headlines', 'images']
149 | self.loadFile(taskpath)
150 |
151 |
152 | class STS16Eval(STSEval):
153 | def __init__(self, taskpath, seed=1111):
154 | logging.debug('***** Transfer task : STS16 *****\n\n')
155 | self.seed = seed
156 | self.datasets = ['answer-answer', 'headlines', 'plagiarism',
157 | 'postediting', 'question-question']
158 | self.loadFile(taskpath)
159 |
160 |
161 | class STSBenchmarkEval(STSEval):
162 | def __init__(self, task_path, seed=1111):
163 | logging.debug('\n\n***** Transfer task : STSBenchmark*****\n\n')
164 | self.seed = seed
165 | self.samples = []
166 | train = self.loadFile(os.path.join(task_path, 'sts-train.csv'))
167 | dev = self.loadFile(os.path.join(task_path, 'sts-dev.csv'))
168 | test = self.loadFile(os.path.join(task_path, 'sts-test.csv'))
169 | self.datasets = ['train', 'dev', 'test']
170 | self.data = {'train': train, 'dev': dev, 'test': test}
171 |
172 | def loadFile(self, fpath):
173 | sick_data = {'X_A': [], 'X_B': [], 'y': []}
174 | with io.open(fpath, 'r', encoding='utf-8') as f:
175 | for line in f:
176 | text = line.strip().split('\t')
177 | sick_data['X_A'].append(text[5].split())
178 | sick_data['X_B'].append(text[6].split())
179 | sick_data['y'].append(text[4])
180 |
181 | sick_data['y'] = [float(s) for s in sick_data['y']]
182 | self.samples += sick_data['X_A'] + sick_data["X_B"]
183 | return (sick_data['X_A'], sick_data["X_B"], sick_data['y'])
184 |
185 | class STSBenchmarkFinetune(SICKEval):
186 | def __init__(self, task_path, seed=1111):
187 | logging.debug('\n\n***** Transfer task : STSBenchmark*****\n\n')
188 | self.seed = seed
189 | train = self.loadFile(os.path.join(task_path, 'sts-train.csv'))
190 | dev = self.loadFile(os.path.join(task_path, 'sts-dev.csv'))
191 | test = self.loadFile(os.path.join(task_path, 'sts-test.csv'))
192 | self.sick_data = {'train': train, 'dev': dev, 'test': test}
193 |
194 | def loadFile(self, fpath):
195 | sick_data = {'X_A': [], 'X_B': [], 'y': []}
196 | with io.open(fpath, 'r', encoding='utf-8') as f:
197 | for line in f:
198 | text = line.strip().split('\t')
199 | sick_data['X_A'].append(text[5].split())
200 | sick_data['X_B'].append(text[6].split())
201 | sick_data['y'].append(text[4])
202 |
203 | sick_data['y'] = [float(s) for s in sick_data['y']]
204 | return sick_data
205 |
206 | class SICKRelatednessEval(STSEval):
207 | def __init__(self, task_path, seed=1111):
208 | logging.debug('\n\n***** Transfer task : SICKRelatedness*****\n\n')
209 | self.seed = seed
210 | self.samples = []
211 | train = self.loadFile(os.path.join(task_path, 'SICK_train.txt'))
212 | dev = self.loadFile(os.path.join(task_path, 'SICK_trial.txt'))
213 | test = self.loadFile(os.path.join(task_path, 'SICK_test_annotated.txt'))
214 | self.datasets = ['train', 'dev', 'test']
215 | self.data = {'train': train, 'dev': dev, 'test': test}
216 |
217 | def loadFile(self, fpath):
218 | skipFirstLine = True
219 | sick_data = {'X_A': [], 'X_B': [], 'y': []}
220 | with io.open(fpath, 'r', encoding='utf-8') as f:
221 | for line in f:
222 | if skipFirstLine:
223 | skipFirstLine = False
224 | else:
225 | text = line.strip().split('\t')
226 | sick_data['X_A'].append(text[1].split())
227 | sick_data['X_B'].append(text[2].split())
228 | sick_data['y'].append(text[3])
229 |
230 | sick_data['y'] = [float(s) for s in sick_data['y']]
231 | self.samples += sick_data['X_A'] + sick_data["X_B"]
232 | return (sick_data['X_A'], sick_data["X_B"], sick_data['y'])
233 |
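
For context, `STSEval.run` above drives evaluation through a user-supplied `batcher(params, batch)` callback: it batches the pre-tokenized sentence pairs by `params.batch_size`, embeds each side, and scores the cosine similarities against the gold labels with Pearson and Spearman. Below is a minimal batcher sketch; the `params.encoder` slot is a hypothetical convention for this illustration, not part of SentEval:

```python
import numpy as np

def batcher(params, batch):
    # SentEval hands over pre-tokenized sentences (lists of words);
    # join them back into strings before embedding them.
    sentences = [' '.join(words) if words else '.' for words in batch]
    # `params.encoder` is assumed to expose encode(list_of_str) -> (batch, dim).
    return np.asarray(params.encoder.encode(sentences))
```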
--------------------------------------------------------------------------------
/SentEval/examples/models.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | """
9 | This file contains the definition of encoders used in https://arxiv.org/pdf/1705.02364.pdf
10 | """
11 |
12 | import numpy as np
13 | import time
14 |
15 | import torch
16 | import torch.nn as nn
17 |
18 |
19 | class InferSent(nn.Module):
20 |
21 | def __init__(self, config):
22 | super(InferSent, self).__init__()
23 | self.bsize = config['bsize']
24 | self.word_emb_dim = config['word_emb_dim']
25 | self.enc_lstm_dim = config['enc_lstm_dim']
26 | self.pool_type = config['pool_type']
27 | self.dpout_model = config['dpout_model']
28 | self.version = 1 if 'version' not in config else config['version']
29 |
30 | self.enc_lstm = nn.LSTM(self.word_emb_dim, self.enc_lstm_dim, 1,
31 | bidirectional=True, dropout=self.dpout_model)
32 |
33 | assert self.version in [1, 2]
34 | if self.version == 1:
35 |             self.bos = '<s>'
36 |             self.eos = '</s>'
37 |             self.max_pad = True
38 |             self.moses_tok = False
39 |         elif self.version == 2:
40 |             self.bos = '<p>'
41 |             self.eos = '</p>'
42 | self.max_pad = False
43 | self.moses_tok = True
44 |
45 | def is_cuda(self):
46 | # either all weights are on cpu or they are on gpu
47 | return self.enc_lstm.bias_hh_l0.data.is_cuda
48 |
49 | def forward(self, sent_tuple):
50 | # sent_len: [max_len, ..., min_len] (bsize)
51 | # sent: (seqlen x bsize x worddim)
52 | sent, sent_len = sent_tuple
53 |
54 | # Sort by length (keep idx)
55 | sent_len_sorted, idx_sort = np.sort(sent_len)[::-1], np.argsort(-sent_len)
56 | sent_len_sorted = sent_len_sorted.copy()
57 | idx_unsort = np.argsort(idx_sort)
58 |
59 | idx_sort = torch.from_numpy(idx_sort).cuda() if self.is_cuda() \
60 | else torch.from_numpy(idx_sort)
61 | sent = sent.index_select(1, idx_sort)
62 |
63 | # Handling padding in Recurrent Networks
64 | sent_packed = nn.utils.rnn.pack_padded_sequence(sent, sent_len_sorted)
65 | sent_output = self.enc_lstm(sent_packed)[0] # seqlen x batch x 2*nhid
66 | sent_output = nn.utils.rnn.pad_packed_sequence(sent_output)[0]
67 |
68 | # Un-sort by length
69 | idx_unsort = torch.from_numpy(idx_unsort).cuda() if self.is_cuda() \
70 | else torch.from_numpy(idx_unsort)
71 | sent_output = sent_output.index_select(1, idx_unsort)
72 |
73 | # Pooling
74 | if self.pool_type == "mean":
75 | sent_len = torch.FloatTensor(sent_len.copy()).unsqueeze(1).cuda()
76 | emb = torch.sum(sent_output, 0).squeeze(0)
77 | emb = emb / sent_len.expand_as(emb)
78 | elif self.pool_type == "max":
79 | if not self.max_pad:
80 | sent_output[sent_output == 0] = -1e9
81 | emb = torch.max(sent_output, 0)[0]
82 | if emb.ndimension() == 3:
83 | emb = emb.squeeze(0)
84 | assert emb.ndimension() == 2
85 |
86 | return emb
87 |
88 | def set_w2v_path(self, w2v_path):
89 | self.w2v_path = w2v_path
90 |
91 | def get_word_dict(self, sentences, tokenize=True):
92 | # create vocab of words
93 | word_dict = {}
94 | sentences = [s.split() if not tokenize else self.tokenize(s) for s in sentences]
95 | for sent in sentences:
96 | for word in sent:
97 | if word not in word_dict:
98 | word_dict[word] = ''
99 | word_dict[self.bos] = ''
100 | word_dict[self.eos] = ''
101 | return word_dict
102 |
103 | def get_w2v(self, word_dict):
104 | assert hasattr(self, 'w2v_path'), 'w2v path not set'
105 | # create word_vec with w2v vectors
106 | word_vec = {}
107 | with open(self.w2v_path, encoding='utf-8') as f:
108 | for line in f:
109 | word, vec = line.split(' ', 1)
110 | if word in word_dict:
111 | word_vec[word] = np.fromstring(vec, sep=' ')
112 | print('Found %s(/%s) words with w2v vectors' % (len(word_vec), len(word_dict)))
113 | return word_vec
114 |
115 | def get_w2v_k(self, K):
116 | assert hasattr(self, 'w2v_path'), 'w2v path not set'
117 | # create word_vec with k first w2v vectors
118 | k = 0
119 | word_vec = {}
120 | with open(self.w2v_path, encoding='utf-8') as f:
121 | for line in f:
122 | word, vec = line.split(' ', 1)
123 | if k <= K:
124 | word_vec[word] = np.fromstring(vec, sep=' ')
125 | k += 1
126 | if k > K:
127 | if word in [self.bos, self.eos]:
128 | word_vec[word] = np.fromstring(vec, sep=' ')
129 |
130 | if k > K and all([w in word_vec for w in [self.bos, self.eos]]):
131 | break
132 | return word_vec
133 |
134 | def build_vocab(self, sentences, tokenize=True):
135 | assert hasattr(self, 'w2v_path'), 'w2v path not set'
136 | word_dict = self.get_word_dict(sentences, tokenize)
137 | self.word_vec = self.get_w2v(word_dict)
138 | print('Vocab size : %s' % (len(self.word_vec)))
139 |
140 | # build w2v vocab with k most frequent words
141 | def build_vocab_k_words(self, K):
142 | assert hasattr(self, 'w2v_path'), 'w2v path not set'
143 | self.word_vec = self.get_w2v_k(K)
144 | print('Vocab size : %s' % (K))
145 |
146 | def update_vocab(self, sentences, tokenize=True):
147 | assert hasattr(self, 'w2v_path'), 'warning : w2v path not set'
148 | assert hasattr(self, 'word_vec'), 'build_vocab before updating it'
149 | word_dict = self.get_word_dict(sentences, tokenize)
150 |
151 | # keep only new words
152 | for word in self.word_vec:
153 | if word in word_dict:
154 | del word_dict[word]
155 |
156 |         # update vocabulary
157 | if word_dict:
158 | new_word_vec = self.get_w2v(word_dict)
159 | self.word_vec.update(new_word_vec)
160 | else:
161 | new_word_vec = []
162 | print('New vocab size : %s (added %s words)'% (len(self.word_vec), len(new_word_vec)))
163 |
164 | def get_batch(self, batch):
165 | # sent in batch in decreasing order of lengths
166 | # batch: (bsize, max_len, word_dim)
167 | embed = np.zeros((len(batch[0]), len(batch), self.word_emb_dim))
168 |
169 | for i in range(len(batch)):
170 | for j in range(len(batch[i])):
171 | embed[j, i, :] = self.word_vec[batch[i][j]]
172 |
173 | return torch.FloatTensor(embed)
174 |
175 | def tokenize(self, s):
176 | from nltk.tokenize import word_tokenize
177 | if self.moses_tok:
178 | s = ' '.join(word_tokenize(s))
179 | s = s.replace(" n't ", "n 't ") # HACK to get ~MOSES tokenization
180 | return s.split()
181 | else:
182 | return word_tokenize(s)
183 |
184 | def prepare_samples(self, sentences, bsize, tokenize, verbose):
185 | sentences = [[self.bos] + s.split() + [self.eos] if not tokenize else
186 | [self.bos] + self.tokenize(s) + [self.eos] for s in sentences]
187 | n_w = np.sum([len(x) for x in sentences])
188 |
189 | # filters words without w2v vectors
190 | for i in range(len(sentences)):
191 | s_f = [word for word in sentences[i] if word in self.word_vec]
192 | if not s_f:
193 | import warnings
194 | warnings.warn('No words in "%s" (idx=%s) have w2v vectors. \
195 |                                Replacing by "</s>"..' % (sentences[i], i))
196 | s_f = [self.eos]
197 | sentences[i] = s_f
198 |
199 | lengths = np.array([len(s) for s in sentences])
200 | n_wk = np.sum(lengths)
201 | if verbose:
202 | print('Nb words kept : %s/%s (%.1f%s)' % (
203 | n_wk, n_w, 100.0 * n_wk / n_w, '%'))
204 |
205 | # sort by decreasing length
206 | lengths, idx_sort = np.sort(lengths)[::-1], np.argsort(-lengths)
207 | sentences = np.array(sentences)[idx_sort]
208 |
209 | return sentences, lengths, idx_sort
210 |
211 | def encode(self, sentences, bsize=64, tokenize=True, verbose=False):
212 | tic = time.time()
213 | sentences, lengths, idx_sort = self.prepare_samples(
214 | sentences, bsize, tokenize, verbose)
215 |
216 | embeddings = []
217 | for stidx in range(0, len(sentences), bsize):
218 | batch = self.get_batch(sentences[stidx:stidx + bsize])
219 | if self.is_cuda():
220 | batch = batch.cuda()
221 | with torch.no_grad():
222 | batch = self.forward((batch, lengths[stidx:stidx + bsize])).data.cpu().numpy()
223 | embeddings.append(batch)
224 | embeddings = np.vstack(embeddings)
225 |
226 | # unsort
227 | idx_unsort = np.argsort(idx_sort)
228 | embeddings = embeddings[idx_unsort]
229 |
230 | if verbose:
231 | print('Speed : %.1f sentences/s (%s mode, bsize=%s)' % (
232 | len(embeddings)/(time.time()-tic),
233 | 'gpu' if self.is_cuda() else 'cpu', bsize))
234 | return embeddings
235 |
236 | def visualize(self, sent, tokenize=True):
237 |
238 | sent = sent.split() if not tokenize else self.tokenize(sent)
239 | sent = [[self.bos] + [word for word in sent if word in self.word_vec] + [self.eos]]
240 |
241 | if ' '.join(sent[0]) == '%s %s' % (self.bos, self.eos):
242 | import warnings
243 | warnings.warn('No words in "%s" have w2v vectors. Replacing \
244 | by "%s %s"..' % (sent, self.bos, self.eos))
245 | batch = self.get_batch(sent)
246 |
247 | if self.is_cuda():
248 | batch = batch.cuda()
249 | output = self.enc_lstm(batch)[0]
250 | output, idxs = torch.max(output, 0)
251 | # output, idxs = output.squeeze(), idxs.squeeze()
252 | idxs = idxs.data.cpu().numpy()
253 | argmaxs = [np.sum((idxs == k)) for k in range(len(sent[0]))]
254 |
255 | # visualize model
256 | import matplotlib.pyplot as plt
257 | x = range(len(sent[0]))
258 | y = [100.0 * n / np.sum(argmaxs) for n in argmaxs]
259 | plt.xticks(x, sent[0], rotation=45)
260 | plt.bar(x, y)
261 | plt.ylabel('%')
262 | plt.title('Visualisation of words importance')
263 | plt.show()
264 |
265 | return output, idxs
266 |
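
A minimal usage sketch for the `InferSent` encoder above. The config values match the public InferSent v1 setup, but the checkpoint and word-vector paths are illustrative assumptions:

```python
import torch

config = {'bsize': 64, 'word_emb_dim': 300, 'enc_lstm_dim': 2048,
          'pool_type': 'max', 'dpout_model': 0.0, 'version': 1}
model = InferSent(config)
model.load_state_dict(torch.load('infersent1.pkl'))  # hypothetical checkpoint path
model.set_w2v_path('glove.840B.300d.txt')            # hypothetical GloVe path
model.build_vocab_k_words(K=100000)                  # index the 100K most frequent words
embeddings = model.encode(['A man plays the piano.'], bsize=64, tokenize=True)
print(embeddings.shape)  # (1, 4096): 2 directions x 2048 LSTM units, max-pooled
```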
--------------------------------------------------------------------------------
/SentEval/senteval/tools/validation.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | """
9 | Validation and classification
10 | (train) : inner-kfold classifier
11 | (train, test) : kfold classifier
12 | (train, dev, test) : split classifier
13 |
14 | """
15 | from __future__ import absolute_import, division, unicode_literals
16 |
17 | import logging
18 | import numpy as np
19 | from senteval.tools.classifier import MLP
20 |
21 | import sklearn
22 | assert(sklearn.__version__ >= "0.18.0"), \
23 | "need to update sklearn to version >= 0.18.0"
24 | from sklearn.linear_model import LogisticRegression
25 | from sklearn.model_selection import StratifiedKFold
26 |
27 |
28 | def get_classif_name(classifier_config, usepytorch):
29 | if not usepytorch:
30 | modelname = 'sklearn-LogReg'
31 | else:
32 | nhid = classifier_config['nhid']
33 | optim = 'adam' if 'optim' not in classifier_config else classifier_config['optim']
34 | bs = 64 if 'batch_size' not in classifier_config else classifier_config['batch_size']
35 | modelname = 'pytorch-MLP-nhid%s-%s-bs%s' % (nhid, optim, bs)
36 | return modelname
37 |
38 | # Pytorch version
39 | class InnerKFoldClassifier(object):
40 | """
41 | (train) split classifier : InnerKfold.
42 | """
43 | def __init__(self, X, y, config):
44 | self.X = X
45 | self.y = y
46 | self.featdim = X.shape[1]
47 | self.nclasses = config['nclasses']
48 | self.seed = config['seed']
49 | self.devresults = []
50 | self.testresults = []
51 | self.usepytorch = config['usepytorch']
52 | self.classifier_config = config['classifier']
53 | self.modelname = get_classif_name(self.classifier_config, self.usepytorch)
54 |
55 | self.k = 5 if 'kfold' not in config else config['kfold']
56 |
57 | def run(self):
58 | logging.info('Training {0} with (inner) {1}-fold cross-validation'
59 | .format(self.modelname, self.k))
60 |
61 | regs = [10**t for t in range(-5, -1)] if self.usepytorch else \
62 | [2**t for t in range(-2, 4, 1)]
63 | skf = StratifiedKFold(n_splits=self.k, shuffle=True, random_state=1111)
64 | innerskf = StratifiedKFold(n_splits=self.k, shuffle=True,
65 | random_state=1111)
66 | count = 0
67 | for train_idx, test_idx in skf.split(self.X, self.y):
68 | count += 1
69 | X_train, X_test = self.X[train_idx], self.X[test_idx]
70 | y_train, y_test = self.y[train_idx], self.y[test_idx]
71 | scores = []
72 | for reg in regs:
73 | regscores = []
74 | for inner_train_idx, inner_test_idx in innerskf.split(X_train, y_train):
75 | X_in_train, X_in_test = X_train[inner_train_idx], X_train[inner_test_idx]
76 | y_in_train, y_in_test = y_train[inner_train_idx], y_train[inner_test_idx]
77 | if self.usepytorch:
78 | clf = MLP(self.classifier_config, inputdim=self.featdim,
79 | nclasses=self.nclasses, l2reg=reg,
80 | seed=self.seed)
81 | clf.fit(X_in_train, y_in_train,
82 | validation_data=(X_in_test, y_in_test))
83 | else:
84 | clf = LogisticRegression(C=reg, random_state=self.seed)
85 | clf.fit(X_in_train, y_in_train)
86 | regscores.append(clf.score(X_in_test, y_in_test))
87 | scores.append(round(100*np.mean(regscores), 2))
88 | optreg = regs[np.argmax(scores)]
89 | logging.info('Best param found at split {0}: l2reg = {1} \
90 | with score {2}'.format(count, optreg, np.max(scores)))
91 | self.devresults.append(np.max(scores))
92 |
93 | if self.usepytorch:
94 | clf = MLP(self.classifier_config, inputdim=self.featdim,
95 | nclasses=self.nclasses, l2reg=optreg,
96 | seed=self.seed)
97 |
98 | clf.fit(X_train, y_train, validation_split=0.05)
99 | else:
100 | clf = LogisticRegression(C=optreg, random_state=self.seed)
101 | clf.fit(X_train, y_train)
102 |
103 | self.testresults.append(round(100*clf.score(X_test, y_test), 2))
104 |
105 | devaccuracy = round(np.mean(self.devresults), 2)
106 | testaccuracy = round(np.mean(self.testresults), 2)
107 | return devaccuracy, testaccuracy
108 |
109 |
110 | class KFoldClassifier(object):
111 | """
112 | (train, test) split classifier : cross-validation on train.
113 | """
114 | def __init__(self, train, test, config):
115 | self.train = train
116 | self.test = test
117 | self.featdim = self.train['X'].shape[1]
118 | self.nclasses = config['nclasses']
119 | self.seed = config['seed']
120 | self.usepytorch = config['usepytorch']
121 | self.classifier_config = config['classifier']
122 | self.modelname = get_classif_name(self.classifier_config, self.usepytorch)
123 |
124 | self.k = 5 if 'kfold' not in config else config['kfold']
125 |
126 | def run(self):
127 | # cross-validation
128 | logging.info('Training {0} with {1}-fold cross-validation'
129 | .format(self.modelname, self.k))
130 | regs = [10**t for t in range(-5, -1)] if self.usepytorch else \
131 | [2**t for t in range(-1, 6, 1)]
132 | skf = StratifiedKFold(n_splits=self.k, shuffle=True,
133 | random_state=self.seed)
134 | scores = []
135 |
136 | for reg in regs:
137 | scanscores = []
138 | for train_idx, test_idx in skf.split(self.train['X'],
139 | self.train['y']):
140 | # Split data
141 | X_train, y_train = self.train['X'][train_idx], self.train['y'][train_idx]
142 |
143 | X_test, y_test = self.train['X'][test_idx], self.train['y'][test_idx]
144 |
145 | # Train classifier
146 | if self.usepytorch:
147 | clf = MLP(self.classifier_config, inputdim=self.featdim,
148 | nclasses=self.nclasses, l2reg=reg,
149 | seed=self.seed)
150 | clf.fit(X_train, y_train, validation_data=(X_test, y_test))
151 | else:
152 | clf = LogisticRegression(C=reg, random_state=self.seed)
153 | clf.fit(X_train, y_train)
154 | score = clf.score(X_test, y_test)
155 | scanscores.append(score)
156 | # Append mean score
157 | scores.append(round(100*np.mean(scanscores), 2))
158 |
159 | # evaluation
160 | logging.info([('reg:' + str(regs[idx]), scores[idx])
161 | for idx in range(len(scores))])
162 | optreg = regs[np.argmax(scores)]
163 | devaccuracy = np.max(scores)
164 | logging.info('Cross-validation : best param found is reg = {0} \
165 | with score {1}'.format(optreg, devaccuracy))
166 |
167 | logging.info('Evaluating...')
168 | if self.usepytorch:
169 | clf = MLP(self.classifier_config, inputdim=self.featdim,
170 | nclasses=self.nclasses, l2reg=optreg,
171 | seed=self.seed)
172 | clf.fit(self.train['X'], self.train['y'], validation_split=0.05)
173 | else:
174 | clf = LogisticRegression(C=optreg, random_state=self.seed)
175 | clf.fit(self.train['X'], self.train['y'])
176 | yhat = clf.predict(self.test['X'])
177 |
178 | testaccuracy = clf.score(self.test['X'], self.test['y'])
179 | testaccuracy = round(100*testaccuracy, 2)
180 |
181 | return devaccuracy, testaccuracy, yhat
182 |
183 |
184 | class SplitClassifier(object):
185 | """
186 | (train, valid, test) split classifier.
187 | """
188 | def __init__(self, X, y, config):
189 | self.X = X
190 | self.y = y
191 | self.nclasses = config['nclasses']
192 | self.featdim = self.X['train'].shape[1]
193 | self.seed = config['seed']
194 | self.usepytorch = config['usepytorch']
195 | self.classifier_config = config['classifier']
196 | self.cudaEfficient = False if 'cudaEfficient' not in config else \
197 | config['cudaEfficient']
198 | self.modelname = get_classif_name(self.classifier_config, self.usepytorch)
199 | self.noreg = False if 'noreg' not in config else config['noreg']
200 | self.config = config
201 |
202 | def run(self):
203 | logging.info('Training {0} with standard validation..'
204 | .format(self.modelname))
205 | regs = [10**t for t in range(-5, -1)] if self.usepytorch else \
206 | [2**t for t in range(-2, 4, 1)]
207 | if self.noreg:
208 | regs = [1e-9 if self.usepytorch else 1e9]
209 | scores = []
210 | for reg in regs:
211 | if self.usepytorch:
212 | clf = MLP(self.classifier_config, inputdim=self.featdim,
213 | nclasses=self.nclasses, l2reg=reg,
214 | seed=self.seed, cudaEfficient=self.cudaEfficient)
215 |
216 | # TODO: Find a hack for reducing nb epoches in SNLI
217 | clf.fit(self.X['train'], self.y['train'],
218 | validation_data=(self.X['valid'], self.y['valid']))
219 | else:
220 | clf = LogisticRegression(C=reg, random_state=self.seed)
221 | clf.fit(self.X['train'], self.y['train'])
222 | scores.append(round(100*clf.score(self.X['valid'],
223 | self.y['valid']), 2))
224 | logging.info([('reg:'+str(regs[idx]), scores[idx])
225 | for idx in range(len(scores))])
226 | optreg = regs[np.argmax(scores)]
227 | devaccuracy = np.max(scores)
228 | logging.info('Validation : best param found is reg = {0} with score \
229 | {1}'.format(optreg, devaccuracy))
230 | clf = LogisticRegression(C=optreg, random_state=self.seed)
231 | logging.info('Evaluating...')
232 | if self.usepytorch:
233 | clf = MLP(self.classifier_config, inputdim=self.featdim,
234 | nclasses=self.nclasses, l2reg=optreg,
235 | seed=self.seed, cudaEfficient=self.cudaEfficient)
236 |
237 | # TODO: Find a hack for reducing nb epoches in SNLI
238 | clf.fit(self.X['train'], self.y['train'],
239 | validation_data=(self.X['valid'], self.y['valid']))
240 | else:
241 | clf = LogisticRegression(C=optreg, random_state=self.seed)
242 | clf.fit(self.X['train'], self.y['train'])
243 |
244 | testaccuracy = clf.score(self.X['test'], self.y['test'])
245 | testaccuracy = round(100*testaccuracy, 2)
246 | return devaccuracy, testaccuracy
247 |
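
To make the expected inputs of `SplitClassifier` concrete, here is a minimal sketch with random features standing in for sentence embeddings; the shapes and config keys follow the constructor above, and the sklearn path (`usepytorch: False`) avoids the MLP dependency:

```python
import numpy as np

rng = np.random.RandomState(1111)
X = {split: rng.randn(n, 768).astype(np.float32)
     for split, n in [('train', 800), ('valid', 100), ('test', 100)]}
y = {split: rng.randint(0, 2, size=X[split].shape[0]) for split in X}
config = {'nclasses': 2, 'seed': 1111, 'usepytorch': False,
          'classifier': {'nhid': 0}}
clf = SplitClassifier(X, y, config)
dev_acc, test_acc = clf.run()  # accuracies in percent; ~50 on random features
```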
--------------------------------------------------------------------------------
/sscl/tool.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from tqdm import tqdm
3 | import numpy as np
4 | from numpy import ndarray
5 | import torch
6 | from torch import Tensor, device
7 | import transformers
8 | from transformers import AutoModel, AutoTokenizer
9 | from sklearn.metrics.pairwise import cosine_similarity
10 | from sklearn.preprocessing import normalize
11 | from typing import List, Dict, Tuple, Type, Union
12 |
13 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S',
14 | level=logging.INFO)
15 | logger = logging.getLogger(__name__)
16 |
17 | class SSCL(object):
18 | """
19 |     A class for embedding sentences, calculating similarities, and retrieving sentences by SimCSE.
20 | """
21 | def __init__(self, model_name_or_path: str,
22 | device: str = None,
23 | num_cells: int = 100,
24 | num_cells_in_search: int = 10,
25 | pooler = None):
26 |
27 | self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
28 | self.model = AutoModel.from_pretrained(model_name_or_path)
29 | if device is None:
30 | device = "cuda" if torch.cuda.is_available() else "cpu"
31 | self.device = device
32 |
33 | self.index = None
34 | self.is_faiss_index = False
35 | self.num_cells = num_cells
36 | self.num_cells_in_search = num_cells_in_search
37 |
38 | if pooler is not None:
39 | self.pooler = pooler
40 | elif "unsup" in model_name_or_path:
41 |             logger.info("Using `cls_before_pooler` for unsupervised models. To use another pooling policy, specify the `pooler` argument.")
42 | self.pooler = "cls_before_pooler"
43 | else:
44 | self.pooler = "cls"
45 |
46 | def encode(self, sentence: Union[str, List[str]],
47 | device: str = None,
48 | return_numpy: bool = False,
49 | normalize_to_unit: bool = True,
50 | keepdim: bool = False,
51 | batch_size: int = 64,
52 | max_length: int = 128) -> Union[ndarray, Tensor]:
53 |
54 | target_device = self.device if device is None else device
55 | self.model = self.model.to(target_device)
56 |
57 | single_sentence = False
58 | if isinstance(sentence, str):
59 | sentence = [sentence]
60 | single_sentence = True
61 |
62 | embedding_list = []
63 | with torch.no_grad():
64 | total_batch = len(sentence) // batch_size + (1 if len(sentence) % batch_size > 0 else 0)
65 | for batch_id in tqdm(range(total_batch)):
66 | inputs = self.tokenizer(
67 | sentence[batch_id*batch_size:(batch_id+1)*batch_size],
68 | padding=True,
69 | truncation=True,
70 | max_length=max_length,
71 | return_tensors="pt"
72 | )
73 | inputs = {k: v.to(target_device) for k, v in inputs.items()}
74 | outputs = self.model(**inputs, return_dict=True)
75 | if self.pooler == "cls":
76 | embeddings = outputs.pooler_output
77 | elif self.pooler == "cls_before_pooler":
78 | embeddings = outputs.last_hidden_state[:, 0]
79 | else:
80 | raise NotImplementedError
81 | if normalize_to_unit:
82 | embeddings = embeddings / embeddings.norm(dim=1, keepdim=True)
83 | embedding_list.append(embeddings.cpu())
84 | embeddings = torch.cat(embedding_list, 0)
85 |
86 | if single_sentence and not keepdim:
87 | embeddings = embeddings[0]
88 |
89 | if return_numpy and not isinstance(embeddings, ndarray):
90 | return embeddings.numpy()
91 | return embeddings
92 |
93 | def similarity(self, queries: Union[str, List[str]],
94 | keys: Union[str, List[str], ndarray],
95 | device: str = None) -> Union[float, ndarray]:
96 |
97 | query_vecs = self.encode(queries, device=device, return_numpy=True) # suppose N queries
98 |
99 | if not isinstance(keys, ndarray):
100 | key_vecs = self.encode(keys, device=device, return_numpy=True) # suppose M keys
101 | else:
102 | key_vecs = keys
103 |
104 | # check whether N == 1 or M == 1
105 | single_query, single_key = len(query_vecs.shape) == 1, len(key_vecs.shape) == 1
106 | if single_query:
107 | query_vecs = query_vecs.reshape(1, -1)
108 | if single_key:
109 | key_vecs = key_vecs.reshape(1, -1)
110 |
111 | # returns an N*M similarity array
112 | similarities = cosine_similarity(query_vecs, key_vecs)
113 |
114 | if single_query:
115 | similarities = similarities[0]
116 | if single_key:
117 | similarities = float(similarities[0])
118 |
119 | return similarities
120 |
121 | def build_index(self, sentences_or_file_path: Union[str, List[str]],
122 | use_faiss: bool = None,
123 | faiss_fast: bool = False,
124 | device: str = None,
125 | batch_size: int = 64):
126 |
127 | if use_faiss is None or use_faiss:
128 | try:
129 | import faiss
130 | assert hasattr(faiss, "IndexFlatIP")
131 | use_faiss = True
132 | except:
133 |                 logger.warning("Failed to import faiss. If you want to use faiss, install it via PyPI. For now, the program continues with brute-force search.")
134 | use_faiss = False
135 |
136 |         # if the input is a string, we assume it's the path of a file containing sentences
137 | if isinstance(sentences_or_file_path, str):
138 | sentences = []
139 | with open(sentences_or_file_path, "r") as f:
140 |                 logger.info("Loading sentences from %s ..." % (sentences_or_file_path))
141 | for line in tqdm(f):
142 | sentences.append(line.rstrip())
143 | sentences_or_file_path = sentences
144 |
145 | logger.info("Encoding embeddings for sentences...")
146 | embeddings = self.encode(sentences_or_file_path, device=device, batch_size=batch_size, normalize_to_unit=True, return_numpy=True)
147 |
148 | logger.info("Building index...")
149 | self.index = {"sentences": sentences_or_file_path}
150 |
151 | if use_faiss:
152 | quantizer = faiss.IndexFlatIP(embeddings.shape[1])
153 | if faiss_fast:
154 | index = faiss.IndexIVFFlat(quantizer, embeddings.shape[1], min(self.num_cells, len(sentences_or_file_path)))
155 | else:
156 | index = quantizer
157 |
158 | if (self.device == "cuda" and device != "cpu") or device == "cuda":
159 | if hasattr(faiss, "StandardGpuResources"):
160 | logger.info("Use GPU-version faiss")
161 | res = faiss.StandardGpuResources()
162 | res.setTempMemory(20 * 1024 * 1024 * 1024)
163 | index = faiss.index_cpu_to_gpu(res, 0, index)
164 | else:
165 | logger.info("Use CPU-version faiss")
166 | else:
167 | logger.info("Use CPU-version faiss")
168 |
169 | if faiss_fast:
170 | index.train(embeddings.astype(np.float32))
171 | index.add(embeddings.astype(np.float32))
172 | index.nprobe = min(self.num_cells_in_search, len(sentences_or_file_path))
173 | self.is_faiss_index = True
174 | else:
175 | index = embeddings
176 | self.is_faiss_index = False
177 | self.index["index"] = index
178 | logger.info("Finished")
179 |
180 | def search(self, queries: Union[str, List[str]],
181 | device: str = None,
182 | threshold: float = 0.6,
183 | top_k: int = 5) -> Union[List[Tuple[str, float]], List[List[Tuple[str, float]]]]:
184 |
185 | if not self.is_faiss_index:
186 | if isinstance(queries, list):
187 | combined_results = []
188 | for query in queries:
189 |                 results = self.search(query, device, threshold, top_k)
190 | combined_results.append(results)
191 | return combined_results
192 |
193 | similarities = self.similarity(queries, self.index["index"]).tolist()
194 | id_and_score = []
195 | for i, s in enumerate(similarities):
196 | if s >= threshold:
197 | id_and_score.append((i, s))
198 | id_and_score = sorted(id_and_score, key=lambda x: x[1], reverse=True)[:top_k]
199 | results = [(self.index["sentences"][idx], score) for idx, score in id_and_score]
200 | return results
201 | else:
202 | query_vecs = self.encode(queries, device=device, normalize_to_unit=True, keepdim=True, return_numpy=True)
203 |
204 | distance, idx = self.index["index"].search(query_vecs.astype(np.float32), top_k)
205 |
206 | def pack_single_result(dist, idx):
207 | results = [(self.index["sentences"][i], s) for i, s in zip(idx, dist) if s >= threshold]
208 | return results
209 |
210 | if isinstance(queries, list):
211 | combined_results = []
212 | for i in range(len(queries)):
213 | results = pack_single_result(distance[i], idx[i])
214 | combined_results.append(results)
215 | return combined_results
216 | else:
217 | return pack_single_result(distance[0], idx[0])
218 |
219 | if __name__=="__main__":
220 | example_sentences = [
221 | 'An animal is biting a persons finger.',
222 | 'A woman is reading.',
223 | 'A man is lifting weights in a garage.',
224 | 'A man plays the violin.',
225 | 'A man is eating food.',
226 | 'A man plays the piano.',
227 | 'A panda is climbing.',
228 | 'A man plays a guitar.',
229 | 'A woman is slicing a meat.',
230 | 'A woman is taking a picture.'
231 | ]
232 | example_queries = [
233 | 'A man is playing music.',
234 | 'A woman is making a photo.'
235 | ]
236 |
237 | model_name = "princeton-nlp/sup-simcse-bert-base-uncased"
238 | simcse = SSCL(model_name)
239 |
240 | print("\n=========Calculate cosine similarities between queries and sentences============\n")
241 | similarities = simcse.similarity(example_queries, example_sentences)
242 | print(similarities)
243 |
244 | print("\n=========Naive brute force search============\n")
245 | simcse.build_index(example_sentences, use_faiss=False)
246 | results = simcse.search(example_queries)
247 | for i, result in enumerate(results):
248 | print("Retrieval results for query: {}".format(example_queries[i]))
249 | for sentence, score in result:
250 | print(" {} (cosine similarity: {:.4f})".format(sentence, score))
251 | print("")
252 |
253 | print("\n=========Search with Faiss backend============\n")
254 | simcse.build_index(example_sentences, use_faiss=True)
255 | results = simcse.search(example_queries)
256 | for i, result in enumerate(results):
257 | print("Retrieval results for query: {}".format(example_queries[i]))
258 | for sentence, score in result:
259 | print(" {} (cosine similarity: {:.4f})".format(sentence, score))
260 | print("")
261 |
262 |
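
One behavior of the `similarity` method above worth spelling out: string inputs collapse the corresponding dimension of the returned N*M array, so two strings yield a plain float. A quick illustration, with the model name reused from the `__main__` example:

```python
sscl = SSCL("princeton-nlp/sup-simcse-bert-base-uncased")

# str vs. str -> float
print(sscl.similarity("A man plays guitar.", "A man plays a guitar."))
# str vs. list -> 1-D array of shape (M,)
print(sscl.similarity("A man plays guitar.",
                      ["A woman is reading.", "A man plays a guitar."]))
# list vs. list -> 2-D array of shape (N, M)
print(sscl.similarity(["A man is playing music.", "A woman is making a photo."],
                      ["A man plays the piano.", "A woman is taking a picture."]))
```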
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/SentEval/README.md:
--------------------------------------------------------------------------------
1 | Our modifications to SentEval:
2 |
3 | 1. Add the `all` setting to all STS tasks.
4 | 2. Change STS-B and SICK-R to not use an additional regressor.
5 |
6 | # SentEval: evaluation toolkit for sentence embeddings
7 |
8 | SentEval is a library for evaluating the quality of sentence embeddings. We assess their generalization power by using them as features on a broad and diverse set of "transfer" tasks. **SentEval currently includes 17 downstream tasks**. We also include a suite of **10 probing tasks** which evaluate what linguistic properties are encoded in sentence embeddings. Our goal is to ease the study and the development of general-purpose fixed-size sentence representations.
9 |
10 |
11 | **(04/22) SentEval new tasks: Added probing tasks for evaluating what linguistic properties are encoded in sentence embeddings**
12 |
13 | **(10/04) SentEval example scripts for three sentence encoders: [SkipThought-LN](https://github.com/ryankiros/layer-norm#skip-thoughts)/[GenSen](https://github.com/Maluuba/gensen)/[Google-USE](https://tfhub.dev/google/universal-sentence-encoder/1)**
14 |
15 | ## Dependencies
16 |
17 | This code is written in python. The dependencies are:
18 |
19 | * Python 2/3 with [NumPy](http://www.numpy.org/)/[SciPy](http://www.scipy.org/)
20 | * [Pytorch](http://pytorch.org/)>=0.4
21 | * [scikit-learn](http://scikit-learn.org/stable/index.html)>=0.18.0
22 |
23 | ## Transfer tasks
24 |
25 | ### Downstream tasks
26 | SentEval allows you to evaluate your sentence embeddings as features for the following *downstream* tasks:
27 |
28 | | Task | Type | #train | #test | needs_train | set_classifier |
29 | |---------- |------------------------------ |-----------:|----------:|:-----------:|:----------:|
30 | | [MR](https://nlp.stanford.edu/~sidaw/home/projects:nbsvm) | movie review | 11k | 11k | 1 | 1 |
31 | | [CR](https://nlp.stanford.edu/~sidaw/home/projects:nbsvm) | product review | 4k | 4k | 1 | 1 |
32 | | [SUBJ](https://nlp.stanford.edu/~sidaw/home/projects:nbsvm) | subjectivity status | 10k | 10k | 1 | 1 |
33 | | [MPQA](https://nlp.stanford.edu/~sidaw/home/projects:nbsvm) | opinion-polarity | 11k | 11k | 1 | 1 |
34 | | [SST](https://nlp.stanford.edu/sentiment/index.html) | binary sentiment analysis | 67k | 1.8k | 1 | 1 |
35 | | **[SST](https://nlp.stanford.edu/sentiment/index.html)** | **fine-grained sentiment analysis** | 8.5k | 2.2k | 1 | 1 |
36 | | [TREC](http://cogcomp.cs.illinois.edu/Data/QA/QC/) | question-type classification | 6k | 0.5k | 1 | 1 |
37 | | [SICK-E](http://clic.cimec.unitn.it/composes/sick.html) | natural language inference | 4.5k | 4.9k | 1 | 1 |
38 | | [SNLI](https://nlp.stanford.edu/projects/snli/) | natural language inference | 550k | 9.8k | 1 | 1 |
39 | | [MRPC](https://aclweb.org/aclwiki/Paraphrase_Identification_(State_of_the_art)) | paraphrase detection | 4.1k | 1.7k | 1 | 1 |
40 | | [STS 2012](https://www.cs.york.ac.uk/semeval-2012/task6/) | semantic textual similarity | N/A | 3.1k | 0 | 0 |
41 | | [STS 2013](http://ixa2.si.ehu.es/sts/) | semantic textual similarity | N/A | 1.5k | 0 | 0 |
42 | | [STS 2014](http://alt.qcri.org/semeval2014/task10/) | semantic textual similarity | N/A | 3.7k | 0 | 0 |
43 | | [STS 2015](http://alt.qcri.org/semeval2015/task2/) | semantic textual similarity | N/A | 8.5k | 0 | 0 |
44 | | [STS 2016](http://alt.qcri.org/semeval2016/task1/) | semantic textual similarity | N/A | 9.2k | 0 | 0 |
45 | | [STS B](http://ixa2.si.ehu.es/stswiki/index.php/STSbenchmark#Results) | semantic textual similarity | 5.7k | 1.4k | 1 | 0 |
46 | | [SICK-R](http://clic.cimec.unitn.it/composes/sick.html) | semantic textual similarity | 4.5k | 4.9k | 1 | 0 |
47 | | [COCO](http://mscoco.org/) | image-caption retrieval | 567k | 5*1k | 1 | 0 |
48 |
49 | where **needs_train** means a model with parameters is learned on top of the sentence embeddings, and **set_classifier** means you can define the parameters of the classifier in the case of a classification task (see below).
50 |
51 | Note: COCO comes with ResNet-101 2048d image embeddings. [More details on the tasks.](https://arxiv.org/pdf/1705.02364.pdf)
52 |
53 | ### Probing tasks
54 | SentEval also includes a series of [*probing* tasks](https://github.com/facebookresearch/SentEval/tree/master/data/probing) to evaluate what linguistic properties are encoded in your sentence embeddings:
55 |
56 | | Task | Type | #train | #test | needs_train | set_classifier |
57 | |---------- |------------------------------ |-----------:|----------:|:-----------:|:----------:|
58 | | [SentLen](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Length prediction | 100k | 10k | 1 | 1 |
59 | | [WC](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Word Content analysis | 100k | 10k | 1 | 1 |
60 | | [TreeDepth](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Tree depth prediction | 100k | 10k | 1 | 1 |
61 | | [TopConst](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Top Constituents prediction | 100k | 10k | 1 | 1 |
62 | | [BShift](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Word order analysis | 100k | 10k | 1 | 1 |
63 | | [Tense](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Verb tense prediction | 100k | 10k | 1 | 1 |
64 | | [SubjNum](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Subject number prediction | 100k | 10k | 1 | 1 |
65 | | [ObjNum](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Object number prediction | 100k | 10k | 1 | 1 |
66 | | [SOMO](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Semantic odd man out | 100k | 10k | 1 | 1 |
67 | | [CoordInv](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Coordination Inversion | 100k | 10k | 1 | 1 |
68 |
69 | ## Download datasets
70 | To get all the transfer tasks datasets, run (in data/downstream/):
71 | ```bash
72 | ./get_transfer_data.bash
73 | ```
74 | This will automatically download and preprocess the downstream datasets, and store them in data/downstream (warning: macOS users may have to use p7zip instead of unzip). The probing tasks are already in data/probing by default.
75 |
76 | ## How to use SentEval: examples
77 |
78 | ### examples/bow.py
79 |
80 | In examples/bow.py, we evaluate the quality of the average of word embeddings.
81 |
82 | To download pretrained GloVe and fastText embeddings:
83 |
84 | ```bash
85 | curl -Lo glove.840B.300d.zip http://nlp.stanford.edu/data/glove.840B.300d.zip
86 | curl -Lo crawl-300d-2M.vec.zip https://dl.fbaipublicfiles.com/fasttext/vectors-english/crawl-300d-2M.vec.zip
87 | ```
88 |
89 | To reproduce the results for bag-of-vectors, run (in examples/):
90 | ```bash
91 | python bow.py
92 | ```
93 |
94 | As required by SentEval, this script implements two functions: **prepare** (optional) and **batcher** (required) that turn text sentences into sentence embeddings. Then SentEval takes care of the evaluation on the transfer tasks using the embeddings as features.
95 |
96 | ### examples/infersent.py
97 |
98 | To get the **[InferSent](https://www.github.com/facebookresearch/InferSent)** model and reproduce our results, download our best models and run infersent.py (in examples/):
99 | ```bash
100 | curl -Lo examples/infersent1.pkl https://dl.fbaipublicfiles.com/senteval/infersent/infersent1.pkl
101 | curl -Lo examples/infersent2.pkl https://dl.fbaipublicfiles.com/senteval/infersent/infersent2.pkl
102 | ```
103 |
104 | ### examples/skipthought.py - examples/gensen.py - examples/googleuse.py
105 |
106 | We also provide example scripts for three other encoders:
107 |
108 | * [SkipThought with Layer-Normalization](https://github.com/ryankiros/layer-norm#skip-thoughts) in Theano
109 | * [GenSen encoder](https://github.com/Maluuba/gensen) in Pytorch
110 | * [Google encoder](https://tfhub.dev/google/universal-sentence-encoder/1) in TensorFlow
111 |
112 | Note that for SkipThought and GenSen, you need to follow the setup steps in the corresponding GitHub repositories.
113 | The Google encoder script should work as-is.
114 |
115 | ## How to use SentEval
116 |
117 | To evaluate your sentence embeddings, SentEval requires that you implement two functions:
118 |
119 | 1. **prepare** (sees the whole dataset of each task and can thus construct the word vocabulary, the dictionary of word vectors, etc.)
120 | 2. **batcher** (transforms a batch of text sentences into sentence embeddings)
121 |
122 |
123 | ### 1.) prepare(params, samples) (optional)
124 |
125 | *batcher* only sees one batch at a time while the *samples* argument of *prepare* contains all the sentences of a task.
126 |
127 | ```
128 | prepare(params, samples)
129 | ```
130 | * *params*: senteval parameters.
131 | * *samples*: list of all sentences from the transfer task.
132 | * *output*: No output. Arguments stored in "params" can further be used by *batcher*.
133 |
134 | *Example*: in bow.py, prepare is used to build the vocabulary of words and construct the *params.word_vec* dictionary of word vectors.
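
A minimal sketch of such a prepare function, assuming GloVe-style vectors stored in a plain-text file (the file path and the 300-dimension are illustrative, not bow.py's exact code):

```python
import numpy as np

def prepare(params, samples):
    # samples is a list of tokenized sentences covering the whole task.
    vocab = set(word for sentence in samples for word in sentence)
    # Load only the pretrained vectors we actually need (path is illustrative).
    params.word_vec = {}
    with open('glove.840B.300d.txt', encoding='utf-8') as f:
        for line in f:
            word, vec = line.rstrip().split(' ', 1)
            if word in vocab:
                params.word_vec[word] = np.asarray(vec.split(), dtype=np.float32)
    params.wvec_dim = 300  # dimensionality of the vectors above
```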
135 |
136 |
137 | ### 2.) batcher(params, batch)
138 | ```
139 | batcher(params, batch)
140 | ```
141 | * *params*: senteval parameters.
142 | * *batch*: numpy array of text sentences (of size params.batch_size)
143 | * *output*: numpy array of sentence embeddings (of size params.batch_size)
144 |
145 | *Example*: in bow.py, batcher is used to compute the mean of the word vectors for each sentence in the batch using params.word_vec. Use your own encoder in that function to encode sentences.
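A matching sketch of a mean-of-word-vectors batcher, assuming prepare has populated params.word_vec and params.wvec_dim as above:

```python
import numpy as np

def batcher(params, batch):
    embeddings = []
    for sentence in batch:
        # Average the vectors of in-vocabulary words; back off to zeros
        # for sentences containing no known words.
        vecs = [params.word_vec[w] for w in sentence if w in params.word_vec]
        if vecs:
            embeddings.append(np.mean(vecs, axis=0))
        else:
            embeddings.append(np.zeros(params.wvec_dim, dtype=np.float32))
    return np.vstack(embeddings)
```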
146 |
147 | ### 3.) evaluation on transfer tasks
148 |
149 | After implementing the *batcher* and *prepare* functions for your own sentence encoder:
150 |
151 | 1) To perform the actual evaluation, first import senteval and set its parameters:
152 | ```python
153 | import senteval
154 | params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 10}
155 | ```
156 |
157 | 2) (Optional) Set the parameters of the classifier (when applicable):
158 | ```python
159 | params['classifier'] = {'nhid': 0, 'optim': 'adam', 'batch_size': 64,
160 | 'tenacity': 5, 'epoch_size': 4}
161 | ```
162 | You can choose **nhid=0** (Logistic Regression) or **nhid>0** (MLP) and define the parameters for training.
163 |
164 | 3) Create an instance of the class SE:
165 | ```python
166 | se = senteval.engine.SE(params, batcher, prepare)
167 | ```
168 |
169 | 4) Define the set of transfer tasks and run the evaluation:
170 | ```python
171 | transfer_tasks = ['MR', 'SICKEntailment', 'STS14', 'STSBenchmark']
172 | results = se.eval(transfer_tasks)
173 | ```
174 | The current list of available tasks is:
175 | ```python
176 | ['CR', 'MR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC', 'SNLI',
177 | 'SICKEntailment', 'SICKRelatedness', 'STSBenchmark', 'ImageCaptionRetrieval',
178 | 'STS12', 'STS13', 'STS14', 'STS15', 'STS16',
179 | 'Length', 'WordContent', 'Depth', 'TopConstituents','BigramShift', 'Tense',
180 | 'SubjNumber', 'ObjNumber', 'OddManOut', 'CoordinationInversion']
181 | ```
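
Putting steps 1-4 together, a minimal end-to-end script looks like this (PATH_TO_DATA and the prepare/batcher sketches above are placeholders for your own paths and encoder):

```python
import senteval

PATH_TO_DATA = './data'  # adjust to wherever the SentEval datasets live

params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 10}
params['classifier'] = {'nhid': 0, 'optim': 'adam', 'batch_size': 64,
                        'tenacity': 5, 'epoch_size': 4}

se = senteval.engine.SE(params, batcher, prepare)  # batcher/prepare as above
transfer_tasks = ['MR', 'SICKEntailment', 'STS14', 'STSBenchmark']
results = se.eval(transfer_tasks)  # dict: task name -> evaluation metrics
print(results)
```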
182 |
183 | ## SentEval parameters
184 | Global parameters of SentEval:
185 | ```bash
186 | # senteval parameters
187 | task_path # path to SentEval datasets (required)
188 | seed # seed
189 | usepytorch # use cuda-pytorch (else scikit-learn) where possible
190 | kfold # k-fold validation for MR/CR/SUBJ/MPQA.
191 | ```
192 |
193 | Parameters of the classifier:
194 | ```bash
195 | nhid: # number of hidden units (0: Logistic Regression, >0: MLP); Default nonlinearity: Tanh
196 | optim: # optimizer ("sgd,lr=0.1", "adam", "rmsprop" ..)
197 | tenacity: # number of times dev accuracy can fail to improve before training stops
198 | epoch_size: # each epoch corresponds to epoch_size passes over the training set
199 | max_epoch: # maximum number of epochs
200 | dropout: # dropout for MLP
201 | ```
202 |
203 | Note that to get a proxy of the results while **dramatically reducing computation time**,
204 | we suggest the **prototyping config**:
205 | ```python
206 | params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5}
207 | params['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128,
208 | 'tenacity': 3, 'epoch_size': 2}
209 | ```
210 | which results in roughly a 5x speedup for classification tasks.
211 |
212 | To produce results that are **comparable to the literature**, use the **default config**:
213 | ```python
214 | params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 10}
215 | params['classifier'] = {'nhid': 0, 'optim': 'adam', 'batch_size': 64,
216 | 'tenacity': 5, 'epoch_size': 4}
217 | ```
218 | which takes longer but produces results comparable to those reported in the literature.
219 |
220 | For probing tasks, we used an MLP with a Sigmoid nonlinearity and tuned nhid (in [50, 100, 200]) and dropout (in [0.0, 0.1, 0.2]) on the dev set.
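
In that spirit, a probing-task configuration might look like the following; the specific nhid and dropout values are placeholders, since the tuned value differs per task:

```python
# MLP probe with dev-set-tuned capacity and regularization (values illustrative).
params['classifier'] = {'nhid': 100, 'optim': 'adam', 'batch_size': 64,
                        'tenacity': 5, 'epoch_size': 4, 'dropout': 0.1}
```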
221 |
222 | ## References
223 |
224 | Please consider citing [[1]](https://arxiv.org/abs/1803.05449) if you use this code for evaluating sentence embedding methods.
225 |
226 | ### SentEval: An Evaluation Toolkit for Universal Sentence Representations
227 |
228 | [1] A. Conneau, D. Kiela, [*SentEval: An Evaluation Toolkit for Universal Sentence Representations*](https://arxiv.org/abs/1803.05449)
229 |
230 | ```
231 | @article{conneau2018senteval,
232 | title={SentEval: An Evaluation Toolkit for Universal Sentence Representations},
233 | author={Conneau, Alexis and Kiela, Douwe},
234 | journal={arXiv preprint arXiv:1803.05449},
235 | year={2018}
236 | }
237 | ```
238 |
239 | Contact: [aconneau@fb.com](mailto:aconneau@fb.com), [dkiela@fb.com](mailto:dkiela@fb.com)
240 |
241 | ### Related work
242 | * [J. R Kiros, Y. Zhu, R. Salakhutdinov, R. S. Zemel, A. Torralba, R. Urtasun, S. Fidler - SkipThought Vectors, NIPS 2015](https://arxiv.org/abs/1506.06726)
243 | * [S. Arora, Y. Liang, T. Ma - A Simple but Tough-to-Beat Baseline for Sentence Embeddings, ICLR 2017](https://openreview.net/pdf?id=SyK00v5xx)
244 | * [Y. Adi, E. Kermany, Y. Belinkov, O. Lavi, Y. Goldberg - Fine-grained analysis of sentence embeddings using auxiliary prediction tasks, ICLR 2017](https://arxiv.org/abs/1608.04207)
245 | * [A. Conneau, D. Kiela, L. Barrault, H. Schwenk, A. Bordes - Supervised Learning of Universal Sentence Representations from Natural Language Inference Data, EMNLP 2017](https://arxiv.org/abs/1705.02364)
246 | * [S. Subramanian, A. Trischler, Y. Bengio, C. J Pal - Learning General Purpose Distributed Sentence Representations via Large Scale Multi-task Learning, ICLR 2018](https://arxiv.org/abs/1804.00079)
247 | * [A. Nie, E. D. Bennett, N. D. Goodman - DisSent: Sentence Representation Learning from Explicit Discourse Relations, 2018](https://arxiv.org/abs/1710.04334)
248 | * [D. Cer, Y. Yang, S. Kong, N. Hua, N. Limtiaco, R. St. John, N. Constant, M. Guajardo-Cespedes, S. Yuan, C. Tar, Y. Sung, B. Strope, R. Kurzweil - Universal Sentence Encoder, 2018](https://arxiv.org/abs/1803.11175)
249 | * [A. Conneau, G. Kruszewski, G. Lample, L. Barrault, M. Baroni - What you can cram into a single vector: Probing sentence embeddings for linguistic properties, ACL 2018](https://arxiv.org/abs/1805.01070)
250 |
--------------------------------------------------------------------------------