├── SentEval ├── senteval │ ├── tools │ │ ├── __init__.py │ │ ├── relatedness.py │ │ ├── classifier.py │ │ └── validation.py │ ├── __init__.py │ ├── utils.py │ ├── trec.py │ ├── binary.py │ ├── sst.py │ ├── mrpc.py │ ├── snli.py │ ├── rank.py │ ├── engine.py │ ├── probing.py │ ├── sick.py │ ├── sts.py │ └── .ipynb_checkpoints │ │ └── sts-checkpoint.py ├── .gitignore ├── setup.py ├── LICENSE ├── examples │ ├── skipthought.py │ ├── googleuse.py │ ├── gensen.py │ ├── infersent.py │ ├── bow.py │ └── models.py └── README.md ├── sscl ├── __init__.py ├── __pycache__ │ ├── tool.cpython-36.pyc │ ├── tool.cpython-39.pyc │ ├── models.cpython-36.pyc │ ├── models.cpython-39.pyc │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-39.pyc │ ├── trainers.cpython-36.pyc │ └── trainers.cpython-39.pyc ├── tool.py └── .ipynb_checkpoints │ └── tool-checkpoint.py ├── transfer.sh ├── figure └── overview.png ├── demo ├── static │ ├── files │ │ ├── plogo.png │ │ ├── favicon.ico │ │ └── style.css │ ├── example_query.txt │ └── index.html ├── run_demo_example.sh ├── README.md ├── gradiodemo.py └── flaskdemo.py ├── data ├── download_nli.sh └── download_wiki.sh ├── test.sh ├── requirements.txt ├── run_unsup_example.sh ├── simcse_to_huggingface.py ├── sscl_to_huggingface.py ├── README.md ├── evaluation.py └── LICENSE /SentEval/senteval/tools/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sscl/__init__.py: -------------------------------------------------------------------------------- 1 | from .tool import SSCL 2 | -------------------------------------------------------------------------------- /transfer.sh: -------------------------------------------------------------------------------- 1 | python sscl_to_huggingface.py --path result/sscl-bert-base-ckpt 2 | -------------------------------------------------------------------------------- /figure/overview.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/nuochenpku/SSCL/HEAD/figure/overview.png -------------------------------------------------------------------------------- /demo/static/files/plogo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nuochenpku/SSCL/HEAD/demo/static/files/plogo.png -------------------------------------------------------------------------------- /demo/static/example_query.txt: -------------------------------------------------------------------------------- 1 | a man is playing music 2 | a woman is making a photo 3 | a woman is taking some food -------------------------------------------------------------------------------- /demo/static/files/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nuochenpku/SSCL/HEAD/demo/static/files/favicon.ico -------------------------------------------------------------------------------- /sscl/__pycache__/tool.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nuochenpku/SSCL/HEAD/sscl/__pycache__/tool.cpython-36.pyc -------------------------------------------------------------------------------- /sscl/__pycache__/tool.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nuochenpku/SSCL/HEAD/sscl/__pycache__/tool.cpython-39.pyc -------------------------------------------------------------------------------- /data/download_nli.sh: -------------------------------------------------------------------------------- 1 | wget https://huggingface.co/datasets/princeton-nlp/datasets-for-simcse/resolve/main/nli_for_simcse.csv 2 | -------------------------------------------------------------------------------- /sscl/__pycache__/models.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/nuochenpku/SSCL/HEAD/sscl/__pycache__/models.cpython-36.pyc -------------------------------------------------------------------------------- /sscl/__pycache__/models.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nuochenpku/SSCL/HEAD/sscl/__pycache__/models.cpython-39.pyc -------------------------------------------------------------------------------- /data/download_wiki.sh: -------------------------------------------------------------------------------- 1 | wget https://huggingface.co/datasets/princeton-nlp/datasets-for-simcse/resolve/main/wiki1m_for_simcse.txt 2 | -------------------------------------------------------------------------------- /sscl/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nuochenpku/SSCL/HEAD/sscl/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /sscl/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nuochenpku/SSCL/HEAD/sscl/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /sscl/__pycache__/trainers.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nuochenpku/SSCL/HEAD/sscl/__pycache__/trainers.cpython-36.pyc -------------------------------------------------------------------------------- /sscl/__pycache__/trainers.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nuochenpku/SSCL/HEAD/sscl/__pycache__/trainers.cpython-39.pyc 
-------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | python evaluation.py \ 2 | --model_name_or_path result/bert-base_avg_neg2_sts \ 3 | --pooler avg \ 4 | --task_set full \ 5 | --mode test 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.2.1 2 | scipy==1.5.4 3 | datasets==1.2.1 4 | pandas==1.1.5 5 | scikit-learn==0.24.0 6 | prettytable==2.1.0 7 | gradio 8 | setuptools==49.3.0 -------------------------------------------------------------------------------- /SentEval/.gitignore: -------------------------------------------------------------------------------- 1 | # SentEval data and .pyc files 2 | 3 | 4 | 5 | # python 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # log files 11 | *.log 12 | *.txt 13 | 14 | # data files 15 | data/senteval_data* 16 | data/downstream/ 17 | -------------------------------------------------------------------------------- /demo/run_demo_example.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This example shows how to run the flask demo of SimCSE 4 | 5 | python flaskdemo.py \ 6 | --model_name_or_path princeton-nlp/sup-simcse-bert-base-uncased \ 7 | --sentences_dir ./static/ \ 8 | --example_query example_query.txt \ 9 | --example_sentences example_sentence.txt -------------------------------------------------------------------------------- /SentEval/senteval/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | from __future__ import absolute_import 9 | 10 | from senteval.engine import SE 11 | -------------------------------------------------------------------------------- /SentEval/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import io 9 | from setuptools import setup, find_packages 10 | 11 | with io.open('./README.md', encoding='utf-8') as f: 12 | readme = f.read() 13 | 14 | setup( 15 | name='SentEval', 16 | version='0.1.0', 17 | url='https://github.com/facebookresearch/SentEval', 18 | packages=find_packages(exclude=['examples']), 19 | license='Attribution-NonCommercial 4.0 International', 20 | long_description=readme, 21 | ) 22 | -------------------------------------------------------------------------------- /run_unsup_example.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export CUDA_VISIBLE_DEVICES=0 3 | # In this example, we show how to train SimCSE on unsupervised Wikipedia data. 4 | # about how to use PyTorch's distributed data parallel. 
5 | 6 | python train.py \ 7 | --model_name_or_path bert-base-uncased\ 8 | --train_file data/wiki1m_for_simcse.txt \ 9 | --output_dir result/bert-base_avg_neg2_trans \ 10 | --do_neg \ 11 | --num_train_epochs 1 \ 12 | --per_device_train_batch_size 64\ 13 | --learning_rate 3e-5 \ 14 | --max_seq_length 32 \ 15 | --eval_transfer \ 16 | --evaluation_strategy steps \ 17 | --metric_for_best_model eval_avg_transfer \ 18 | --load_best_model_at_end \ 19 | --eval_steps 125 \ 20 | --pooler_type avg \ 21 | --mlp_only_train \ 22 | --overwrite_output_dir \ 23 | --temp 0.05 \ 24 | --do_train \ 25 | --do_eval \ 26 | --fp16 \ 27 | "$@" 28 | # --eval_transfer \ 29 | -------------------------------------------------------------------------------- /demo/static/files/style.css: -------------------------------------------------------------------------------- 1 | html { position: relative; min-height: 100%; } 2 | body { margin-bottom: 60px; font-family: Verdana, sans-serif;} 3 | .footer { position: absolute; bottom: 0; width: 100%; height: 40px; line-height: 15px; background-color: #f5f5f5; padding-top: 5px; font-size: 12px; text-align: center;} 4 | label, footer { user-select: none; } 5 | .score { position:absolute; bottom:0; right:15px;} 6 | 7 | .list-group-item:first-of-type { 8 | background-color: #fff; 9 | border-left-color: #fff; 10 | border-right-color: #fff; 11 | } 12 | .list-group-mine .list-group-item { 13 | background-color: #fff; 14 | border-left-color: #fff; 15 | border-right-color: #fff; 16 | } 17 | 18 | .paper_title { 19 | margin-top: 15px; 20 | margin-left: auto; 21 | margin-right: auto; 22 | margin-bottom: auto; 23 | width: 70%; 24 | text-align: center; 25 | } 26 | .detail { 27 | margin: auto; 28 | width: 50%; 29 | } 30 | .detail2 { 31 | margin-top: 8px; 32 | margin-left: auto; 33 | margin-right: auto; 34 | margin-bottom: auto; 35 | width: 50%; 36 | } 37 | .card { 38 | margin-top: -15px; 39 | } 40 | 
-------------------------------------------------------------------------------- /demo/README.md: -------------------------------------------------------------------------------- 1 | ## Demo of SimCSE 2 | Several demos are available for trying out our pre-trained SimCSE models. 3 | 4 | ### Flask Demo 5 | 6 | 7 | 8 | 9 | We provide a simple web demo based on [flask](https://github.com/pallets/flask) to show how SimCSE can be used directly for information retrieval. The code is based on the [DensePhrases](https://arxiv.org/abs/2012.12624) [repo](https://github.com/princeton-nlp/DensePhrases) and [demo](http://densephrases.korea.ac.kr) (many thanks to the authors of DensePhrases). To run this flask demo locally, make sure the SimCSE inference interfaces are set up: 10 | ```bash 11 | git clone https://github.com/princeton-nlp/SimCSE 12 | cd SimCSE 13 | python setup.py develop 14 | ``` 15 | Then you can use `run_demo_example.sh` to launch the demo. By default, we build the index from 1000 sentences sampled from the STS-B dataset. Feel free to build an index from your own corpora. You can also install [faiss](https://github.com/facebookresearch/faiss) to speed up retrieval. 16 | 17 | ### Gradio Demo 18 | [AK391](https://github.com/AK391) has provided a [Gradio Web Demo](https://gradio.app/g/AK391/SimCSE) of SimCSE to show how the pre-trained models can predict the semantic similarity between two sentences. 19 | -------------------------------------------------------------------------------- /simcse_to_huggingface.py: -------------------------------------------------------------------------------- 1 | """ 2 | Convert SSCL's checkpoints to Huggingface style.
3 | """ 4 | 5 | import argparse 6 | import torch 7 | import os 8 | import json 9 | 10 | 11 | def main(): 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--path", type=str, help="Path of SSCL checkpoint folder") 14 | args = parser.parse_args() 15 | 16 | print("SSCL checkpoint -> Huggingface checkpoint for {}".format(args.path)) 17 | 18 | state_dict = torch.load(os.path.join(args.path, "pytorch_model.bin"), map_location=torch.device("cpu")) 19 | new_state_dict = {} 20 | for key, param in state_dict.items(): 21 | # Replace "mlp" to "pooler" 22 | if "mlp" in key: 23 | key = key.replace("mlp", "pooler") 24 | 25 | # Delete "bert" or "roberta" prefix 26 | if "bert." in key: 27 | key = key.replace("bert.", "") 28 | if "roberta." in key: 29 | key = key.replace("roberta.", "") 30 | 31 | new_state_dict[key] = param 32 | 33 | torch.save(new_state_dict, os.path.join(args.path, "pytorch_model.bin")) 34 | 35 | # Change architectures in config.json 36 | config = json.load(open(os.path.join(args.path, "config.json"))) 37 | for i in range(len(config["architectures"])): 38 | config["architectures"][i] = config["architectures"][i].replace("ForCL", "Model") 39 | json.dump(config, open(os.path.join(args.path, "config.json"), "w"), indent=2) 40 | 41 | 42 | if __name__ == "__main__": 43 | main() 44 | -------------------------------------------------------------------------------- /sscl_to_huggingface.py: -------------------------------------------------------------------------------- 1 | """ 2 | Convert SSCL's checkpoints to Huggingface style. 
3 | """ 4 | 5 | import argparse 6 | import torch 7 | import os 8 | import json 9 | 10 | 11 | def main(): 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--path", type=str, help="Path of SSCL checkpoint folder") 14 | args = parser.parse_args() 15 | 16 | print("SSCL checkpoint -> Huggingface checkpoint for {}".format(args.path)) 17 | 18 | state_dict = torch.load(os.path.join(args.path, "pytorch_model.bin"), map_location=torch.device("cpu")) 19 | new_state_dict = {} 20 | for key, param in state_dict.items(): 21 | # Replace "mlp" to "pooler" 22 | if "mlp" in key: 23 | key = key.replace("mlp", "pooler") 24 | 25 | # Delete "bert" or "roberta" prefix 26 | if "bert." in key: 27 | key = key.replace("bert.", "") 28 | if "roberta." in key: 29 | key = key.replace("roberta.", "") 30 | 31 | new_state_dict[key] = param 32 | 33 | torch.save(new_state_dict, os.path.join(args.path, "pytorch_model.bin")) 34 | 35 | # Change architectures in config.json 36 | config = json.load(open(os.path.join(args.path, "config.json"))) 37 | for i in range(len(config["architectures"])): 38 | config["architectures"][i] = config["architectures"][i].replace("ForCL", "Model") 39 | json.dump(config, open(os.path.join(args.path, "config.json"), "w"), indent=2) 40 | 41 | 42 | if __name__ == "__main__": 43 | main() 44 | -------------------------------------------------------------------------------- /SentEval/LICENSE: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | For SentEval software 4 | 5 | Copyright (c) 2017-present, Facebook, Inc. All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without modification, 8 | are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 
12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name Facebook nor the names of its contributors may be used to 18 | endorse or promote products derived from this software without specific 19 | prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | -------------------------------------------------------------------------------- /demo/gradiodemo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from scipy.spatial.distance import cosine 3 | from transformers import AutoModel, AutoTokenizer 4 | import gradio as gr 5 | 6 | # Import our models. 
The package will take care of downloading the models automatically 7 | tokenizer = AutoTokenizer.from_pretrained("princeton-nlp/sup-simcse-bert-base-uncased") 8 | model = AutoModel.from_pretrained("princeton-nlp/sup-simcse-bert-base-uncased") 9 | 10 | def simcse(text1, text2, text3): 11 | # Tokenize input texts 12 | texts = [ 13 | text1, 14 | text2, 15 | text3 16 | ] 17 | inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt") 18 | 19 | # Get the embeddings 20 | with torch.no_grad(): 21 | embeddings = model(**inputs, output_hidden_states=True, return_dict=True).pooler_output 22 | 23 | # Calculate cosine similarities 24 | # Cosine similarities are in [-1, 1]. Higher means more similar 25 | cosine_sim_0_1 = 1 - cosine(embeddings[0], embeddings[1]) 26 | cosine_sim_0_2 = 1 - cosine(embeddings[0], embeddings[2]) 27 | return {"cosine similarity":cosine_sim_0_1}, {"cosine similarity":cosine_sim_0_2} 28 | 29 | 30 | inputs = [ 31 | gr.inputs.Textbox(lines=5, label="Input Text One"), 32 | gr.inputs.Textbox(lines=5, label="Input Text Two"), 33 | gr.inputs.Textbox(lines=5, label="Input Text Three") 34 | ] 35 | 36 | outputs = [ 37 | gr.outputs.Label(type="confidences",label="Cosine similarity between text one and two"), 38 | gr.outputs.Label(type="confidences", label="Cosine similarity between text one and three") 39 | ] 40 | 41 | 42 | title = "SimCSE" 43 | description = "demo for Princeton-NLP SimCSE. To use it, simply add your text, or click one of the examples to load them. Read more at the links below." 44 | article = "SimCSE: Simple Contrastive Learning of Sentence Embeddings | Github Repo" 45 | examples = [ 46 | ["There's a kid on a skateboard.", 47 | "A kid is skateboarding.", 48 | "A kid is inside the house."] 49 | ] 50 | 51 | gr.Interface(simcse, inputs, outputs, title=title, description=description, article=article, examples=examples).launch() -------------------------------------------------------------------------------- /SentEval/examples/skipthought.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from __future__ import absolute_import, division, unicode_literals 9 | 10 | """ 11 | Example of file for SkipThought in SentEval 12 | """ 13 | import logging 14 | import sys 15 | sys.setdefaultencoding('utf8') 16 | 17 | 18 | # Set PATHs 19 | PATH_TO_SENTEVAL = '../' 20 | PATH_TO_DATA = '../data/senteval_data/' 21 | PATH_TO_SKIPTHOUGHT = '' 22 | 23 | assert PATH_TO_SKIPTHOUGHT != '', 'Download skipthought and set correct PATH' 24 | 25 | # import skipthought and Senteval 26 | sys.path.insert(0, PATH_TO_SKIPTHOUGHT) 27 | import skipthoughts 28 | sys.path.insert(0, PATH_TO_SENTEVAL) 29 | import senteval 30 | 31 | 32 | def prepare(params, samples): 33 | return 34 | 35 | def batcher(params, batch): 36 | batch = [str(' '.join(sent), errors="ignore") if sent != [] else '.'
for sent in batch] 37 | embeddings = skipthoughts.encode(params['encoder'], batch, 38 | verbose=False, use_eos=True) 39 | return embeddings 40 | 41 | 42 | # Set params for SentEval 43 | params_senteval = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 10, 'batch_size': 512} 44 | params_senteval['classifier'] = {'nhid': 0, 'optim': 'adam', 'batch_size': 64, 45 | 'tenacity': 5, 'epoch_size': 4} 46 | # Set up logger 47 | logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG) 48 | 49 | if __name__ == "__main__": 50 | # Load SkipThought model 51 | params_senteval['encoder'] = skipthoughts.load_model() 52 | 53 | se = senteval.engine.SE(params_senteval, batcher, prepare) 54 | transfer_tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 55 | 'MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC', 56 | 'SICKEntailment', 'SICKRelatedness', 'STSBenchmark', 57 | 'Length', 'WordContent', 'Depth', 'TopConstituents', 58 | 'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber', 59 | 'OddManOut', 'CoordinationInversion'] 60 | results = se.eval(transfer_tasks) 61 | print(results) 62 | -------------------------------------------------------------------------------- /SentEval/examples/googleuse.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | from __future__ import absolute_import, division 9 | 10 | import os 11 | import sys 12 | import logging 13 | import tensorflow as tf 14 | import tensorflow_hub as hub 15 | tf.logging.set_verbosity(0) 16 | 17 | # Set PATHs 18 | PATH_TO_SENTEVAL = '../' 19 | PATH_TO_DATA = '../data' 20 | 21 | # import SentEval 22 | sys.path.insert(0, PATH_TO_SENTEVAL) 23 | import senteval 24 | 25 | # tensorflow session 26 | session = tf.Session() 27 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 28 | 29 | # SentEval prepare and batcher 30 | def prepare(params, samples): 31 | return 32 | 33 | def batcher(params, batch): 34 | batch = [' '.join(sent) if sent != [] else '.' for sent in batch] 35 | embeddings = params['google_use'](batch) 36 | return embeddings 37 | 38 | def make_embed_fn(module): 39 | with tf.Graph().as_default(): 40 | sentences = tf.placeholder(tf.string) 41 | embed = hub.Module(module) 42 | embeddings = embed(sentences) 43 | session = tf.train.MonitoredSession() 44 | return lambda x: session.run(embeddings, {sentences: x}) 45 | 46 | # Start TF session and load Google Universal Sentence Encoder 47 | encoder = make_embed_fn("https://tfhub.dev/google/universal-sentence-encoder-large/2") 48 | 49 | # Set params for SentEval 50 | params_senteval = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5} 51 | params_senteval['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128, 52 | 'tenacity': 3, 'epoch_size': 2} 53 | params_senteval['google_use'] = encoder 54 | 55 | # Set up logger 56 | logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG) 57 | 58 | if __name__ == "__main__": 59 | se = senteval.engine.SE(params_senteval, batcher, prepare) 60 | transfer_tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 61 | 'MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC', 62 | 'SICKEntailment', 'SICKRelatedness', 'STSBenchmark', 63 | 'Length', 'WordContent', 'Depth', 'TopConstituents', 64 | 'BigramShift', 'Tense', 'SubjNumber', 
'ObjNumber', 65 | 'OddManOut', 'CoordinationInversion'] 66 | results = se.eval(transfer_tasks) 67 | print(results) 68 | -------------------------------------------------------------------------------- /SentEval/examples/gensen.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | """ 9 | Clone GenSen repo here: https://github.com/Maluuba/gensen.git 10 | And follow instructions for loading the model used in batcher 11 | """ 12 | 13 | from __future__ import absolute_import, division, unicode_literals 14 | 15 | import sys 16 | import logging 17 | # import GenSen package 18 | from gensen import GenSen, GenSenSingle 19 | 20 | # Set PATHs 21 | PATH_TO_SENTEVAL = '../' 22 | PATH_TO_DATA = '../data' 23 | 24 | # import SentEval 25 | sys.path.insert(0, PATH_TO_SENTEVAL) 26 | import senteval 27 | 28 | # SentEval prepare and batcher 29 | def prepare(params, samples): 30 | return 31 | 32 | def batcher(params, batch): 33 | batch = [' '.join(sent) if sent != [] else '.' 
for sent in batch] 34 | _, reps_h_t = params['gensen'].get_representation( 35 | batch, pool='last', return_numpy=True, tokenize=True 36 | ) 37 | embeddings = reps_h_t 38 | return embeddings 39 | 40 | # Load GenSen model 41 | gensen_1 = GenSenSingle( 42 | model_folder='../data/models', 43 | filename_prefix='nli_large_bothskip', 44 | pretrained_emb='../data/embedding/glove.840B.300d.h5' 45 | ) 46 | gensen_2 = GenSenSingle( 47 | model_folder='../data/models', 48 | filename_prefix='nli_large_bothskip_parse', 49 | pretrained_emb='../data/embedding/glove.840B.300d.h5' 50 | ) 51 | gensen_encoder = GenSen(gensen_1, gensen_2) 52 | # Reference call; `sentences` is not defined at module level, so it stays commented out: 53 | # reps_h, reps_h_t = gensen_encoder.get_representation( 54 | # sentences, pool='last', return_numpy=True, tokenize=True) 55 | 56 | # Set params for SentEval 57 | params_senteval = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5} 58 | params_senteval['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128, 59 | 'tenacity': 3, 'epoch_size': 2} 60 | params_senteval['gensen'] = gensen_encoder 61 | 62 | # Set up logger 63 | logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG) 64 | 65 | if __name__ == "__main__": 66 | se = senteval.engine.SE(params_senteval, batcher, prepare) 67 | transfer_tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 68 | 'MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC', 69 | 'SICKEntailment', 'SICKRelatedness', 'STSBenchmark', 70 | 'Length', 'WordContent', 'Depth', 'TopConstituents', 71 | 'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber', 72 | 'OddManOut', 'CoordinationInversion'] 73 | results = se.eval(transfer_tasks) 74 | print(results) 75 | -------------------------------------------------------------------------------- /SentEval/examples/infersent.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved.
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | """ 9 | InferSent models. See https://github.com/facebookresearch/InferSent. 10 | """ 11 | 12 | from __future__ import absolute_import, division, unicode_literals 13 | 14 | import sys 15 | import os 16 | import torch 17 | import logging 18 | 19 | # get models.py from InferSent repo 20 | from models import InferSent 21 | 22 | # Set PATHs 23 | PATH_SENTEVAL = '../' 24 | PATH_TO_DATA = '../data' 25 | PATH_TO_W2V = 'PATH/TO/glove.840B.300d.txt' # or crawl-300d-2M.vec for V2 26 | MODEL_PATH = 'infersent1.pkl' 27 | V = 1 # version of InferSent 28 | 29 | assert os.path.isfile(MODEL_PATH) and os.path.isfile(PATH_TO_W2V), \ 30 | 'Set MODEL and GloVe PATHs' 31 | 32 | # import senteval 33 | sys.path.insert(0, PATH_SENTEVAL) 34 | import senteval 35 | 36 | 37 | def prepare(params, samples): 38 | params.infersent.build_vocab([' '.join(s) for s in samples], tokenize=False) 39 | 40 | 41 | def batcher(params, batch): 42 | sentences = [' '.join(s) for s in batch] 43 | embeddings = params.infersent.encode(sentences, bsize=params.batch_size, tokenize=False) 44 | return embeddings 45 | 46 | 47 | """ 48 | Evaluation of trained model on Transfer Tasks (SentEval) 49 | """ 50 | 51 | # define senteval params 52 | params_senteval = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5} 53 | params_senteval['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128, 54 | 'tenacity': 3, 'epoch_size': 2} 55 | # Set up logger 56 | logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG) 57 | 58 | if __name__ == "__main__": 59 | # Load InferSent model 60 | params_model = {'bsize': 64, 'word_emb_dim': 300, 'enc_lstm_dim': 2048, 61 | 'pool_type': 'max', 'dpout_model': 0.0, 'version': V} 62 | model = InferSent(params_model) 63 | model.load_state_dict(torch.load(MODEL_PATH)) 64 | model.set_w2v_path(PATH_TO_W2V) 65 | 66 | 
params_senteval['infersent'] = model.cuda() 67 | 68 | se = senteval.engine.SE(params_senteval, batcher, prepare) 69 | transfer_tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 70 | 'MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC', 71 | 'SICKEntailment', 'SICKRelatedness', 'STSBenchmark', 72 | 'Length', 'WordContent', 'Depth', 'TopConstituents', 73 | 'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber', 74 | 'OddManOut', 'CoordinationInversion'] 75 | results = se.eval(transfer_tasks) 76 | print(results) 77 | -------------------------------------------------------------------------------- /SentEval/senteval/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from __future__ import absolute_import, division, unicode_literals 9 | 10 | import numpy as np 11 | import re 12 | import inspect 13 | from torch import optim 14 | 15 | 16 | def create_dictionary(sentences): 17 | words = {} 18 | for s in sentences: 19 | for word in s: 20 | if word in words: 21 | words[word] += 1 22 | else: 23 | words[word] = 1 24 | words['<s>'] = 1e9 + 4 25 | words['</s>'] = 1e9 + 3 26 | words['<p>'] = 1e9 + 2 27 | # words['<UNK>'] = 1e9 + 1 28 | sorted_words = sorted(words.items(), key=lambda x: -x[1]) # inverse sort 29 | id2word = [] 30 | word2id = {} 31 | for i, (w, _) in enumerate(sorted_words): 32 | id2word.append(w) 33 | word2id[w] = i 34 | 35 | return id2word, word2id 36 | 37 | 38 | def cosine(u, v): 39 | return np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v)) 40 | 41 | 42 | class dotdict(dict): 43 | """ dot.notation access to dictionary attributes """ 44 | __getattr__ = dict.get 45 | __setattr__ = dict.__setitem__ 46 | __delattr__ = dict.__delitem__ 47 | 48 | 49 | def get_optimizer(s): 50 | """ 51 | Parse optimizer parameters. 52 | Input should be of the form: 53 | - "sgd,lr=0.01" 54 | - "adagrad,lr=0.1,lr_decay=0.05" 55 | """ 56 | if "," in s: 57 | method = s[:s.find(',')] 58 | optim_params = {} 59 | for x in s[s.find(',') + 1:].split(','): 60 | split = x.split('=') 61 | assert len(split) == 2 62 | assert re.match(r"^[+-]?(\d+(\.\d*)?|\.\d+)$", split[1]) is not None 63 | optim_params[split[0]] = float(split[1]) 64 | else: 65 | method = s 66 | optim_params = {} 67 | 68 | if method == 'adadelta': 69 | optim_fn = optim.Adadelta 70 | elif method == 'adagrad': 71 | optim_fn = optim.Adagrad 72 | elif method == 'adam': 73 | optim_fn = optim.Adam 74 | elif method == 'adamax': 75 | optim_fn = optim.Adamax 76 | elif method == 'asgd': 77 | optim_fn = optim.ASGD 78 | elif method == 'rmsprop': 79 | optim_fn = optim.RMSprop 80 | elif method == 'rprop': 81 | optim_fn = optim.Rprop 82 | elif method == 'sgd': 83 | optim_fn = optim.SGD 84 | assert 'lr' in optim_params 85 | else: 86 | raise Exception('Unknown optimization method: "%s"' % method) 87 | 88 | # check that we give good parameters to the optimizer 89 | expected_args = inspect.getfullargspec(optim_fn.__init__)[0] 90 | assert expected_args[:2] == ['self', 'params'] 91 | if not all(k in expected_args[2:] for k in optim_params.keys()): 92 | raise Exception('Unexpected parameters: expected "%s", got "%s"' % ( 93 |
str(expected_args[2:]), str(optim_params.keys()))) 94 | 95 | return optim_fn, optim_params 96 | -------------------------------------------------------------------------------- /demo/flaskdemo.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import torch 4 | import os 5 | import random 6 | import numpy as np 7 | import requests 8 | import logging 9 | import math 10 | import copy 11 | import string 12 | 13 | from tqdm import tqdm 14 | from time import time 15 | from flask import Flask, request, jsonify 16 | from flask_cors import CORS 17 | from tornado.wsgi import WSGIContainer 18 | from tornado.httpserver import HTTPServer 19 | from tornado.ioloop import IOLoop 20 | 21 | from simcse import SimCSE 22 | 23 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S', 24 | level=logging.INFO) 25 | logger = logging.getLogger(__name__) 26 | 27 | def run_simcse_demo(port, args): 28 | app = Flask(__name__, static_folder='./static') 29 | app.config['JSONIFY_PRETTYPRINT_REGULAR'] = False 30 | CORS(app) 31 | 32 | sentence_path = os.path.join(args.sentences_dir, args.example_sentences) 33 | query_path = os.path.join(args.sentences_dir, args.example_query) 34 | embedder = SimCSE(args.model_name_or_path) 35 | embedder.build_index(sentence_path) 36 | @app.route('/') 37 | def index(): 38 | return app.send_static_file('index.html') 39 | 40 | @app.route('/api', methods=['GET']) 41 | def api(): 42 | query = request.args['query'] 43 | top_k = int(request.args['topk']) 44 | threshold = float(request.args['threshold']) 45 | start = time() 46 | results = embedder.search(query, top_k=top_k, threshold=threshold) 47 | ret = [] 48 | out = {} 49 | for sentence, score in results: 50 | ret.append({"sentence": sentence, "score": score}) 51 | span = time() - start 52 | out['ret'] = ret 53 | out['time'] = "{:.4f}".format(span) 54 | return jsonify(out) 55 | 56 | 
@app.route('/files/<path:path>') 57 | def static_files(path): 58 | return app.send_static_file('files/' + path) 59 | 60 | @app.route('/get_examples', methods=['GET']) 61 | def get_examples(): 62 | with open(query_path, 'r') as fp: 63 | examples = [line.strip() for line in fp.readlines()] 64 | return jsonify(examples) 65 | 66 | addr = args.ip + ":" + args.port 67 | logger.info(f'Starting Index server at {addr}') 68 | http_server = HTTPServer(WSGIContainer(app)) 69 | http_server.listen(port) 70 | IOLoop.instance().start() 71 | 72 | if __name__=="__main__": 73 | parser = argparse.ArgumentParser() 74 | parser.add_argument('--model_name_or_path', default=None, type=str) 75 | parser.add_argument('--device', default='cpu', type=str) 76 | parser.add_argument('--sentences_dir', default=None, type=str) 77 | parser.add_argument('--example_query', default=None, type=str) 78 | parser.add_argument('--example_sentences', default=None, type=str) 79 | parser.add_argument('--port', default='8888', type=str) 80 | parser.add_argument('--ip', default='http://127.0.0.1') 81 | parser.add_argument('--load_light', default=False, action='store_true') 82 | args = parser.parse_args() 83 | 84 | run_simcse_demo(args.port, args) -------------------------------------------------------------------------------- /SentEval/senteval/trec.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | ''' 9 | TREC question-type classification 10 | ''' 11 | 12 | from __future__ import absolute_import, division, unicode_literals 13 | 14 | import os 15 | import io 16 | import logging 17 | import numpy as np 18 | 19 | from senteval.tools.validation import KFoldClassifier 20 | 21 | 22 | class TRECEval(object): 23 | def __init__(self, task_path, seed=1111): 24 | logging.info('***** Transfer task : TREC *****\n\n') 25 | self.seed = seed 26 | self.train = self.loadFile(os.path.join(task_path, 'train_5500.label')) 27 | self.test = self.loadFile(os.path.join(task_path, 'TREC_10.label')) 28 | 29 | def do_prepare(self, params, prepare): 30 | samples = self.train['X'] + self.test['X'] 31 | return prepare(params, samples) 32 | 33 | def loadFile(self, fpath): 34 | trec_data = {'X': [], 'y': []} 35 | tgt2idx = {'ABBR': 0, 'DESC': 1, 'ENTY': 2, 36 | 'HUM': 3, 'LOC': 4, 'NUM': 5} 37 | with io.open(fpath, 'r', encoding='latin-1') as f: 38 | for line in f: 39 | target, sample = line.strip().split(':', 1) 40 | sample = sample.split(' ', 1)[1].split() 41 | assert target in tgt2idx, target 42 | trec_data['X'].append(sample) 43 | trec_data['y'].append(tgt2idx[target]) 44 | return trec_data 45 | 46 | def run(self, params, batcher): 47 | train_embeddings, test_embeddings = [], [] 48 | 49 | # Sort to reduce padding 50 | sorted_corpus_train = sorted(zip(self.train['X'], self.train['y']), 51 | key=lambda z: (len(z[0]), z[1])) 52 | train_samples = [x for (x, y) in sorted_corpus_train] 53 | train_labels = [y for (x, y) in sorted_corpus_train] 54 | 55 | sorted_corpus_test = sorted(zip(self.test['X'], self.test['y']), 56 | key=lambda z: (len(z[0]), z[1])) 57 | test_samples = [x for (x, y) in sorted_corpus_test] 58 | test_labels = [y for (x, y) in sorted_corpus_test] 59 | 60 | # Get train embeddings 61 | for ii in range(0, len(train_labels), params.batch_size): 62 | batch = train_samples[ii:ii + params.batch_size] 63 | embeddings = batcher(params, batch) 64 | 
train_embeddings.append(embeddings) 65 | train_embeddings = np.vstack(train_embeddings) 66 | logging.info('Computed train embeddings') 67 | 68 | # Get test embeddings 69 | for ii in range(0, len(test_labels), params.batch_size): 70 | batch = test_samples[ii:ii + params.batch_size] 71 | embeddings = batcher(params, batch) 72 | test_embeddings.append(embeddings) 73 | test_embeddings = np.vstack(test_embeddings) 74 | logging.info('Computed test embeddings') 75 | 76 | config_classifier = {'nclasses': 6, 'seed': self.seed, 77 | 'usepytorch': params.usepytorch, 78 | 'classifier': params.classifier, 79 | 'kfold': params.kfold} 80 | clf = KFoldClassifier({'X': train_embeddings, 81 | 'y': np.array(train_labels)}, 82 | {'X': test_embeddings, 83 | 'y': np.array(test_labels)}, 84 | config_classifier) 85 | devacc, testacc, _ = clf.run() 86 | logging.debug('\nDev acc : {0} Test acc : {1} \ 87 | for TREC\n'.format(devacc, testacc)) 88 | return {'devacc': devacc, 'acc': testacc, 89 | 'ndev': len(self.train['X']), 'ntest': len(self.test['X'])} 90 | -------------------------------------------------------------------------------- /SentEval/examples/bow.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | from __future__ import absolute_import, division, unicode_literals 9 | 10 | import sys 11 | import io 12 | import numpy as np 13 | import logging 14 | 15 | 16 | # Set PATHs 17 | PATH_TO_SENTEVAL = '../' 18 | PATH_TO_DATA = '../data' 19 | # PATH_TO_VEC = 'glove/glove.840B.300d.txt' 20 | PATH_TO_VEC = 'fasttext/crawl-300d-2M.vec' 21 | 22 | # import SentEval 23 | sys.path.insert(0, PATH_TO_SENTEVAL) 24 | import senteval 25 | 26 | 27 | # Create dictionary 28 | def create_dictionary(sentences, threshold=0): 29 | words = {} 30 | for s in sentences: 31 | for word in s: 32 | words[word] = words.get(word, 0) + 1 33 | 34 | if threshold > 0: 35 | newwords = {} 36 | for word in words: 37 | if words[word] >= threshold: 38 | newwords[word] = words[word] 39 | words = newwords 40 | words['<s>'] = 1e9 + 4 41 | words['</s>'] = 1e9 + 3 42 | words['<p>'] = 1e9 + 2 43 | 44 | sorted_words = sorted(words.items(), key=lambda x: -x[1]) # inverse sort 45 | id2word = [] 46 | word2id = {} 47 | for i, (w, _) in enumerate(sorted_words): 48 | id2word.append(w) 49 | word2id[w] = i 50 | 51 | return id2word, word2id 52 | 53 | # Get word vectors from vocabulary (glove, word2vec, fasttext ..) 54 | def get_wordvec(path_to_vec, word2id): 55 | word_vec = {} 56 | 57 | with io.open(path_to_vec, 'r', encoding='utf-8') as f: 58 | # if word2vec or fasttext file : skip first line "next(f)" 59 | for line in f: 60 | word, vec = line.split(' ', 1) 61 | if word in word2id: 62 | word_vec[word] = np.fromstring(vec, sep=' ') 63 | 64 | logging.info('Found {0} words with word vectors, out of \ 65 | {1} words'.format(len(word_vec), len(word2id))) 66 | return word_vec 67 | 68 | 69 | # SentEval prepare and batcher 70 | def prepare(params, samples): 71 | _, params.word2id = create_dictionary(samples) 72 | params.word_vec = get_wordvec(PATH_TO_VEC, params.word2id) 73 | params.wvec_dim = 300 74 | return 75 | 76 | def batcher(params, batch): 77 | batch = [sent if sent != [] else ['.'] for sent in batch] 78 | embeddings = [] 79 | 80 | for sent in batch: 81 | sentvec = [] 82 | for word in sent: 83 | if word in params.word_vec: 84 | sentvec.append(params.word_vec[word]) 85 | if not sentvec: 86 | vec = np.zeros(params.wvec_dim) 87 | sentvec.append(vec) 88 | sentvec = np.mean(sentvec, 0) 89 | embeddings.append(sentvec) 90 | 91 | embeddings = np.vstack(embeddings) 92 | return embeddings 93 | 94 | 95 | # Set params for SentEval 96 | params_senteval = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5} 97 | params_senteval['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128, 98 | 'tenacity': 3, 'epoch_size': 2} 99 | 100 | # Set up logger 101 | logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG) 102 | 103 | if __name__ == "__main__": 104 | se = senteval.engine.SE(params_senteval, batcher, prepare) 105 | 
transfer_tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 106 | 'MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC', 107 | 'SICKEntailment', 'SICKRelatedness', 'STSBenchmark', 108 | 'Length', 'WordContent', 'Depth', 'TopConstituents', 109 | 'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber', 110 | 'OddManOut', 'CoordinationInversion'] 111 | results = se.eval(transfer_tasks) 112 | print(results) 113 | -------------------------------------------------------------------------------- /SentEval/senteval/binary.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | ''' 9 | Binary classifier and corresponding datasets : MR, CR, SUBJ, MPQA 10 | ''' 11 | from __future__ import absolute_import, division, unicode_literals 12 | 13 | import io 14 | import os 15 | import numpy as np 16 | import logging 17 | 18 | from senteval.tools.validation import InnerKFoldClassifier 19 | 20 | 21 | class BinaryClassifierEval(object): 22 | def __init__(self, pos, neg, seed=1111): 23 | self.seed = seed 24 | self.samples, self.labels = pos + neg, [1] * len(pos) + [0] * len(neg) 25 | self.n_samples = len(self.samples) 26 | 27 | def do_prepare(self, params, prepare): 28 | # prepare is given the whole text 29 | return prepare(params, self.samples) 30 | # prepare puts everything it outputs in "params" : params.word2id etc 31 | # Those output will be further used by "batcher". 
32 | 33 | def loadFile(self, fpath): 34 | with io.open(fpath, 'r', encoding='latin-1') as f: 35 | return [line.split() for line in f.read().splitlines()] 36 | 37 | def run(self, params, batcher): 38 | enc_input = [] 39 | # Sort to reduce padding 40 | sorted_corpus = sorted(zip(self.samples, self.labels), 41 | key=lambda z: (len(z[0]), z[1])) 42 | sorted_samples = [x for (x, y) in sorted_corpus] 43 | sorted_labels = [y for (x, y) in sorted_corpus] 44 | logging.info('Generating sentence embeddings') 45 | for ii in range(0, self.n_samples, params.batch_size): 46 | batch = sorted_samples[ii:ii + params.batch_size] 47 | embeddings = batcher(params, batch) 48 | enc_input.append(embeddings) 49 | enc_input = np.vstack(enc_input) 50 | logging.info('Generated sentence embeddings') 51 | 52 | config = {'nclasses': 2, 'seed': self.seed, 53 | 'usepytorch': params.usepytorch, 54 | 'classifier': params.classifier, 55 | 'nhid': params.nhid, 'kfold': params.kfold} 56 | clf = InnerKFoldClassifier(enc_input, np.array(sorted_labels), config) 57 | devacc, testacc = clf.run() 58 | logging.debug('Dev acc : {0} Test acc : {1}\n'.format(devacc, testacc)) 59 | return {'devacc': devacc, 'acc': testacc, 'ndev': self.n_samples, 60 | 'ntest': self.n_samples} 61 | 62 | 63 | class CREval(BinaryClassifierEval): 64 | def __init__(self, task_path, seed=1111): 65 | logging.debug('***** Transfer task : CR *****\n\n') 66 | pos = self.loadFile(os.path.join(task_path, 'custrev.pos')) 67 | neg = self.loadFile(os.path.join(task_path, 'custrev.neg')) 68 | super(self.__class__, self).__init__(pos, neg, seed) 69 | 70 | 71 | class MREval(BinaryClassifierEval): 72 | def __init__(self, task_path, seed=1111): 73 | logging.debug('***** Transfer task : MR *****\n\n') 74 | pos = self.loadFile(os.path.join(task_path, 'rt-polarity.pos')) 75 | neg = self.loadFile(os.path.join(task_path, 'rt-polarity.neg')) 76 | super(self.__class__, self).__init__(pos, neg, seed) 77 | 78 | 79 | class SUBJEval(BinaryClassifierEval): 80 | 
def __init__(self, task_path, seed=1111): 81 | logging.debug('***** Transfer task : SUBJ *****\n\n') 82 | obj = self.loadFile(os.path.join(task_path, 'subj.objective')) 83 | subj = self.loadFile(os.path.join(task_path, 'subj.subjective')) 84 | super(self.__class__, self).__init__(obj, subj, seed) 85 | 86 | 87 | class MPQAEval(BinaryClassifierEval): 88 | def __init__(self, task_path, seed=1111): 89 | logging.debug('***** Transfer task : MPQA *****\n\n') 90 | pos = self.loadFile(os.path.join(task_path, 'mpqa.pos')) 91 | neg = self.loadFile(os.path.join(task_path, 'mpqa.neg')) 92 | super(self.__class__, self).__init__(pos, neg, seed) 93 | -------------------------------------------------------------------------------- /SentEval/senteval/sst.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | ''' 9 | SST - binary classification 10 | ''' 11 | 12 | from __future__ import absolute_import, division, unicode_literals 13 | 14 | import os 15 | import io 16 | import logging 17 | import numpy as np 18 | 19 | from senteval.tools.validation import SplitClassifier 20 | 21 | 22 | class SSTEval(object): 23 | def __init__(self, task_path, nclasses=2, seed=1111): 24 | self.seed = seed 25 | 26 | # binary of fine-grained 27 | assert nclasses in [2, 5] 28 | self.nclasses = nclasses 29 | self.task_name = 'Binary' if self.nclasses == 2 else 'Fine-Grained' 30 | logging.debug('***** Transfer task : SST %s classification *****\n\n', self.task_name) 31 | 32 | train = self.loadFile(os.path.join(task_path, 'sentiment-train')) 33 | dev = self.loadFile(os.path.join(task_path, 'sentiment-dev')) 34 | test = self.loadFile(os.path.join(task_path, 'sentiment-test')) 35 | self.sst_data = {'train': train, 'dev': dev, 'test': test} 36 | 37 | def do_prepare(self, params, prepare): 38 | samples = self.sst_data['train']['X'] + self.sst_data['dev']['X'] + \ 39 | self.sst_data['test']['X'] 40 | return prepare(params, samples) 41 | 42 | def loadFile(self, fpath): 43 | sst_data = {'X': [], 'y': []} 44 | with io.open(fpath, 'r', encoding='utf-8') as f: 45 | for line in f: 46 | if self.nclasses == 2: 47 | sample = line.strip().split('\t') 48 | sst_data['y'].append(int(sample[1])) 49 | sst_data['X'].append(sample[0].split()) 50 | elif self.nclasses == 5: 51 | sample = line.strip().split(' ', 1) 52 | sst_data['y'].append(int(sample[0])) 53 | sst_data['X'].append(sample[1].split()) 54 | assert max(sst_data['y']) == self.nclasses - 1 55 | return sst_data 56 | 57 | def run(self, params, batcher): 58 | sst_embed = {'train': {}, 'dev': {}, 'test': {}} 59 | bsize = params.batch_size 60 | 61 | for key in self.sst_data: 62 | logging.info('Computing embedding for {0}'.format(key)) 63 | # Sort to reduce padding 64 | sorted_data = sorted(zip(self.sst_data[key]['X'], 65 | self.sst_data[key]['y']), 
66 | key=lambda z: (len(z[0]), z[1])) 67 | self.sst_data[key]['X'], self.sst_data[key]['y'] = map(list, zip(*sorted_data)) 68 | 69 | sst_embed[key]['X'] = [] 70 | for ii in range(0, len(self.sst_data[key]['y']), bsize): 71 | batch = self.sst_data[key]['X'][ii:ii + bsize] 72 | embeddings = batcher(params, batch) 73 | sst_embed[key]['X'].append(embeddings) 74 | sst_embed[key]['X'] = np.vstack(sst_embed[key]['X']) 75 | sst_embed[key]['y'] = np.array(self.sst_data[key]['y']) 76 | logging.info('Computed {0} embeddings'.format(key)) 77 | 78 | config_classifier = {'nclasses': self.nclasses, 'seed': self.seed, 79 | 'usepytorch': params.usepytorch, 80 | 'classifier': params.classifier} 81 | 82 | clf = SplitClassifier(X={'train': sst_embed['train']['X'], 83 | 'valid': sst_embed['dev']['X'], 84 | 'test': sst_embed['test']['X']}, 85 | y={'train': sst_embed['train']['y'], 86 | 'valid': sst_embed['dev']['y'], 87 | 'test': sst_embed['test']['y']}, 88 | config=config_classifier) 89 | 90 | devacc, testacc = clf.run() 91 | logging.debug('\nDev acc : {0} Test acc : {1} for \ 92 | SST {2} classification\n'.format(devacc, testacc, self.task_name)) 93 | 94 | return {'devacc': devacc, 'acc': testacc, 95 | 'ndev': len(sst_embed['dev']['X']), 96 | 'ntest': len(sst_embed['test']['X'])} 97 | -------------------------------------------------------------------------------- /SentEval/senteval/mrpc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | ''' 9 | MRPC : Microsoft Research Paraphrase (detection) Corpus 10 | ''' 11 | from __future__ import absolute_import, division, unicode_literals 12 | 13 | import os 14 | import logging 15 | import numpy as np 16 | import io 17 | 18 | from senteval.tools.validation import KFoldClassifier 19 | 20 | from sklearn.metrics import f1_score 21 | 22 | 23 | class MRPCEval(object): 24 | def __init__(self, task_path, seed=1111): 25 | logging.info('***** Transfer task : MRPC *****\n\n') 26 | self.seed = seed 27 | train = self.loadFile(os.path.join(task_path, 28 | 'msr_paraphrase_train.txt')) 29 | test = self.loadFile(os.path.join(task_path, 30 | 'msr_paraphrase_test.txt')) 31 | self.mrpc_data = {'train': train, 'test': test} 32 | 33 | def do_prepare(self, params, prepare): 34 | # TODO : Should we separate samples in "train, test"? 35 | samples = self.mrpc_data['train']['X_A'] + \ 36 | self.mrpc_data['train']['X_B'] + \ 37 | self.mrpc_data['test']['X_A'] + self.mrpc_data['test']['X_B'] 38 | return prepare(params, samples) 39 | 40 | def loadFile(self, fpath): 41 | mrpc_data = {'X_A': [], 'X_B': [], 'y': []} 42 | with io.open(fpath, 'r', encoding='utf-8') as f: 43 | for line in f: 44 | text = line.strip().split('\t') 45 | mrpc_data['X_A'].append(text[3].split()) 46 | mrpc_data['X_B'].append(text[4].split()) 47 | mrpc_data['y'].append(text[0]) 48 | 49 | mrpc_data['X_A'] = mrpc_data['X_A'][1:] 50 | mrpc_data['X_B'] = mrpc_data['X_B'][1:] 51 | mrpc_data['y'] = [int(s) for s in mrpc_data['y'][1:]] 52 | return mrpc_data 53 | 54 | def run(self, params, batcher): 55 | mrpc_embed = {'train': {}, 'test': {}} 56 | 57 | for key in self.mrpc_data: 58 | logging.info('Computing embedding for {0}'.format(key)) 59 | # Sort to reduce padding 60 | text_data = {} 61 | sorted_corpus = sorted(zip(self.mrpc_data[key]['X_A'], 62 | self.mrpc_data[key]['X_B'], 63 | self.mrpc_data[key]['y']), 64 | key=lambda z: (len(z[0]), len(z[1]), z[2])) 65 | 66 | text_data['A'] = [x for (x, y, z) in 
sorted_corpus] 67 | text_data['B'] = [y for (x, y, z) in sorted_corpus] 68 | text_data['y'] = [z for (x, y, z) in sorted_corpus] 69 | 70 | for txt_type in ['A', 'B']: 71 | mrpc_embed[key][txt_type] = [] 72 | for ii in range(0, len(text_data['y']), params.batch_size): 73 | batch = text_data[txt_type][ii:ii + params.batch_size] 74 | embeddings = batcher(params, batch) 75 | mrpc_embed[key][txt_type].append(embeddings) 76 | mrpc_embed[key][txt_type] = np.vstack(mrpc_embed[key][txt_type]) 77 | mrpc_embed[key]['y'] = np.array(text_data['y']) 78 | logging.info('Computed {0} embeddings'.format(key)) 79 | 80 | # Train 81 | trainA = mrpc_embed['train']['A'] 82 | trainB = mrpc_embed['train']['B'] 83 | trainF = np.c_[np.abs(trainA - trainB), trainA * trainB] 84 | trainY = mrpc_embed['train']['y'] 85 | 86 | # Test 87 | testA = mrpc_embed['test']['A'] 88 | testB = mrpc_embed['test']['B'] 89 | testF = np.c_[np.abs(testA - testB), testA * testB] 90 | testY = mrpc_embed['test']['y'] 91 | 92 | config = {'nclasses': 2, 'seed': self.seed, 93 | 'usepytorch': params.usepytorch, 94 | 'classifier': params.classifier, 95 | 'nhid': params.nhid, 'kfold': params.kfold} 96 | clf = KFoldClassifier(train={'X': trainF, 'y': trainY}, 97 | test={'X': testF, 'y': testY}, config=config) 98 | 99 | devacc, testacc, yhat = clf.run() 100 | testf1 = round(100*f1_score(testY, yhat), 2) 101 | logging.debug('Dev acc : {0} Test acc {1}; Test F1 {2} for MRPC.\n' 102 | .format(devacc, testacc, testf1)) 103 | return {'devacc': devacc, 'acc': testacc, 'f1': testf1, 104 | 'ndev': len(trainA), 'ntest': len(testA)} 105 | -------------------------------------------------------------------------------- /SentEval/senteval/snli.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | ''' 9 | SNLI - Entailment 10 | ''' 11 | from __future__ import absolute_import, division, unicode_literals 12 | 13 | import codecs 14 | import os 15 | import io 16 | import copy 17 | import logging 18 | import numpy as np 19 | 20 | from senteval.tools.validation import SplitClassifier 21 | 22 | 23 | class SNLIEval(object): 24 | def __init__(self, taskpath, seed=1111): 25 | logging.debug('***** Transfer task : SNLI Entailment*****\n\n') 26 | self.seed = seed 27 | train1 = self.loadFile(os.path.join(taskpath, 's1.train')) 28 | train2 = self.loadFile(os.path.join(taskpath, 's2.train')) 29 | 30 | trainlabels = io.open(os.path.join(taskpath, 'labels.train'), 31 | encoding='utf-8').read().splitlines() 32 | 33 | valid1 = self.loadFile(os.path.join(taskpath, 's1.dev')) 34 | valid2 = self.loadFile(os.path.join(taskpath, 's2.dev')) 35 | validlabels = io.open(os.path.join(taskpath, 'labels.dev'), 36 | encoding='utf-8').read().splitlines() 37 | 38 | test1 = self.loadFile(os.path.join(taskpath, 's1.test')) 39 | test2 = self.loadFile(os.path.join(taskpath, 's2.test')) 40 | testlabels = io.open(os.path.join(taskpath, 'labels.test'), 41 | encoding='utf-8').read().splitlines() 42 | 43 | # sort data (by s2 first) to reduce padding 44 | sorted_train = sorted(zip(train2, train1, trainlabels), 45 | key=lambda z: (len(z[0]), len(z[1]), z[2])) 46 | train2, train1, trainlabels = map(list, zip(*sorted_train)) 47 | 48 | sorted_valid = sorted(zip(valid2, valid1, validlabels), 49 | key=lambda z: (len(z[0]), len(z[1]), z[2])) 50 | valid2, valid1, validlabels = map(list, zip(*sorted_valid)) 51 | 52 | sorted_test = sorted(zip(test2, test1, testlabels), 53 | key=lambda z: (len(z[0]), len(z[1]), z[2])) 54 | test2, test1, testlabels = map(list, zip(*sorted_test)) 55 | 56 | self.samples = train1 + train2 + valid1 + valid2 + test1 + test2 57 | self.data = {'train': (train1, train2, trainlabels), 58 | 'valid': (valid1, valid2, validlabels), 59 | 'test': (test1, test2, testlabels) 60 | } 
61 | 62 | def do_prepare(self, params, prepare): 63 | return prepare(params, self.samples) 64 | 65 | def loadFile(self, fpath): 66 | with codecs.open(fpath, 'rb', 'latin-1') as f: 67 | return [line.split() for line in 68 | f.read().splitlines()] 69 | 70 | def run(self, params, batcher): 71 | self.X, self.y = {}, {} 72 | dico_label = {'entailment': 0, 'neutral': 1, 'contradiction': 2} 73 | for key in self.data: 74 | if key not in self.X: 75 | self.X[key] = [] 76 | if key not in self.y: 77 | self.y[key] = [] 78 | 79 | input1, input2, mylabels = self.data[key] 80 | enc_input = [] 81 | n_labels = len(mylabels) 82 | for ii in range(0, n_labels, params.batch_size): 83 | batch1 = input1[ii:ii + params.batch_size] 84 | batch2 = input2[ii:ii + params.batch_size] 85 | 86 | if len(batch1) == len(batch2) and len(batch1) > 0: 87 | enc1 = batcher(params, batch1) 88 | enc2 = batcher(params, batch2) 89 | enc_input.append(np.hstack((enc1, enc2, enc1 * enc2, 90 | np.abs(enc1 - enc2)))) 91 | if (ii*params.batch_size) % (20000*params.batch_size) == 0: 92 | logging.info("PROGRESS (encoding): %.2f%%" % 93 | (100 * ii / n_labels)) 94 | self.X[key] = np.vstack(enc_input) 95 | self.y[key] = [dico_label[y] for y in mylabels] 96 | 97 | config = {'nclasses': 3, 'seed': self.seed, 98 | 'usepytorch': params.usepytorch, 99 | 'cudaEfficient': True, 100 | 'nhid': params.nhid, 'noreg': True} 101 | 102 | config_classifier = copy.deepcopy(params.classifier) 103 | config_classifier['max_epoch'] = 15 104 | config_classifier['epoch_size'] = 1 105 | config['classifier'] = config_classifier 106 | 107 | clf = SplitClassifier(self.X, self.y, config) 108 | devacc, testacc = clf.run() 109 | logging.debug('Dev acc : {0} Test acc : {1} for SNLI\n' 110 | .format(devacc, testacc)) 111 | return {'devacc': devacc, 'acc': testacc, 112 | 'ndev': len(self.data['valid'][0]), 113 | 'ntest': len(self.data['test'][0])} 114 | -------------------------------------------------------------------------------- 
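A minimal sketch of the pair-feature construction that SNLIEval.run performs with np.hstack above: for two batches of sentence embeddings u and v it concatenates u, v, u * v and |u - v|, so the classifier input is four times the embedding dimension. The helper name pair_features and the toy embeddings are assumptions made for illustration; they are not part of SentEval.

```python
import numpy as np


def pair_features(enc1, enc2):
    """Build classifier inputs from two batches of sentence embeddings.

    Hypothetical helper: it mirrors the np.hstack call in SNLIEval.run,
    producing [u, v, u * v, |u - v|] for each sentence pair.
    """
    enc1 = np.asarray(enc1, dtype=float)
    enc2 = np.asarray(enc2, dtype=float)
    return np.hstack((enc1, enc2, enc1 * enc2, np.abs(enc1 - enc2)))


# Toy embeddings (made up): 2 sentence pairs, embedding dimension 3.
u = np.array([[1.0, 2.0, 3.0], [0.0, -1.0, 2.0]])
v = np.array([[1.0, 0.0, -3.0], [4.0, -1.0, 2.0]])

features = pair_features(u, v)
print(features.shape)  # (2, 12): four blocks of width 3 per pair
```

MRPCEval.run in mrpc.py uses a shorter variant of the same idea, keeping only the |u - v| and u * v blocks.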
/SentEval/senteval/rank.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | ''' 9 | Image-Caption Retrieval with COCO dataset 10 | ''' 11 | from __future__ import absolute_import, division, unicode_literals 12 | 13 | import os 14 | import sys 15 | import logging 16 | import numpy as np 17 | 18 | try: 19 | import cPickle as pickle 20 | except ImportError: 21 | import pickle 22 | 23 | from senteval.tools.ranking import ImageSentenceRankingPytorch 24 | 25 | 26 | class ImageCaptionRetrievalEval(object): 27 | def __init__(self, task_path, seed=1111): 28 | logging.debug('***** Transfer task: Image Caption Retrieval *****\n\n') 29 | 30 | # Get captions and image features 31 | self.seed = seed 32 | train, dev, test = self.loadFile(task_path) 33 | self.coco_data = {'train': train, 'dev': dev, 'test': test} 34 | 35 | def do_prepare(self, params, prepare): 36 | samples = self.coco_data['train']['sent'] + \ 37 | self.coco_data['dev']['sent'] + \ 38 | self.coco_data['test']['sent'] 39 | prepare(params, samples) 40 | 41 | def loadFile(self, fpath): 42 | coco = {} 43 | 44 | for split in ['train', 'valid', 'test']: 45 | list_sent = [] 46 | list_img_feat = [] 47 | if sys.version_info < (3, 0): 48 | with open(os.path.join(fpath, split + '.pkl')) as f: 49 | cocodata = pickle.load(f) 50 | else: 51 | with open(os.path.join(fpath, split + '.pkl'), 'rb') as f: 52 | cocodata = pickle.load(f, encoding='latin1') 53 | 54 | for imgkey in range(len(cocodata['features'])): 55 | assert len(cocodata['image_to_caption_ids'][imgkey]) >= 5, \ 56 | cocodata['image_to_caption_ids'][imgkey] 57 | for captkey in cocodata['image_to_caption_ids'][imgkey][0:5]: 58 | sent = cocodata['captions'][captkey]['cleaned_caption'] 59 | sent += ' .' 
# add punctuation to end of sentence in COCO 60 | list_sent.append(sent.encode('utf-8').split()) 61 | list_img_feat.append(cocodata['features'][imgkey]) 62 | assert len(list_sent) == len(list_img_feat) and \ 63 | len(list_sent) % 5 == 0 64 | list_img_feat = np.array(list_img_feat).astype('float32') 65 | coco[split] = {'sent': list_sent, 'imgfeat': list_img_feat} 66 | return coco['train'], coco['valid'], coco['test'] 67 | 68 | def run(self, params, batcher): 69 | coco_embed = {'train': {'sentfeat': [], 'imgfeat': []}, 70 | 'dev': {'sentfeat': [], 'imgfeat': []}, 71 | 'test': {'sentfeat': [], 'imgfeat': []}} 72 | 73 | for key in self.coco_data: 74 | logging.info('Computing embedding for {0}'.format(key)) 75 | # Sort to reduce padding 76 | self.coco_data[key]['sent'] = np.array(self.coco_data[key]['sent']) 77 | self.coco_data[key]['sent'], idx_sort = np.sort(self.coco_data[key]['sent']), np.argsort(self.coco_data[key]['sent']) 78 | idx_unsort = np.argsort(idx_sort) 79 | 80 | coco_embed[key]['X'] = [] 81 | nsent = len(self.coco_data[key]['sent']) 82 | for ii in range(0, nsent, params.batch_size): 83 | batch = self.coco_data[key]['sent'][ii:ii + params.batch_size] 84 | embeddings = batcher(params, batch) 85 | coco_embed[key]['sentfeat'].append(embeddings) 86 | coco_embed[key]['sentfeat'] = np.vstack(coco_embed[key]['sentfeat'])[idx_unsort] 87 | coco_embed[key]['imgfeat'] = np.array(self.coco_data[key]['imgfeat']) 88 | logging.info('Computed {0} embeddings'.format(key)) 89 | 90 | config = {'seed': self.seed, 'projdim': 1000, 'margin': 0.2} 91 | clf = ImageSentenceRankingPytorch(train=coco_embed['train'], 92 | valid=coco_embed['dev'], 93 | test=coco_embed['test'], 94 | config=config) 95 | 96 | bestdevscore, r1_i2t, r5_i2t, r10_i2t, medr_i2t, \ 97 | r1_t2i, r5_t2i, r10_t2i, medr_t2i = clf.run() 98 | 99 | logging.debug("\nTest scores | Image to text: \ 100 | {0}, {1}, {2}, {3}".format(r1_i2t, r5_i2t, r10_i2t, medr_i2t)) 101 | logging.debug("Test scores | Text to image: \ 
102 | {0}, {1}, {2}, {3}\n".format(r1_t2i, r5_t2i, r10_t2i, medr_t2i)) 103 | 104 | return {'devacc': bestdevscore, 105 | 'acc': [(r1_i2t, r5_i2t, r10_i2t, medr_i2t), 106 | (r1_t2i, r5_t2i, r10_t2i, medr_t2i)], 107 | 'ndev': len(coco_embed['dev']['sentfeat']), 108 | 'ntest': len(coco_embed['test']['sentfeat'])} 109 | -------------------------------------------------------------------------------- /SentEval/senteval/tools/relatedness.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | """ 9 | Semantic Relatedness (supervised) with Pytorch 10 | """ 11 | from __future__ import absolute_import, division, unicode_literals 12 | 13 | import copy 14 | import numpy as np 15 | 16 | import torch 17 | from torch import nn 18 | import torch.optim as optim 19 | 20 | from scipy.stats import pearsonr, spearmanr 21 | 22 | 23 | class RelatednessPytorch(object): 24 | # Can be used for SICK-Relatedness, and STS14 25 | def __init__(self, train, valid, test, devscores, config): 26 | # fix seed 27 | np.random.seed(config['seed']) 28 | torch.manual_seed(config['seed']) 29 | assert torch.cuda.is_available(), 'torch.cuda required for Relatedness' 30 | torch.cuda.manual_seed(config['seed']) 31 | 32 | self.train = train 33 | self.valid = valid 34 | self.test = test 35 | self.devscores = devscores 36 | 37 | self.inputdim = train['X'].shape[1] 38 | self.nclasses = config['nclasses'] 39 | self.seed = config['seed'] 40 | self.l2reg = 0. 
41 | self.batch_size = 64 42 | self.maxepoch = 1000 43 | self.early_stop = True 44 | 45 | self.model = nn.Sequential( 46 | nn.Linear(self.inputdim, self.nclasses), 47 | nn.Softmax(dim=-1), 48 | ) 49 | self.loss_fn = nn.MSELoss() 50 | 51 | if torch.cuda.is_available(): 52 | self.model = self.model.cuda() 53 | self.loss_fn = self.loss_fn.cuda() 54 | 55 | self.loss_fn.size_average = False 56 | self.optimizer = optim.Adam(self.model.parameters(), 57 | weight_decay=self.l2reg) 58 | 59 | def prepare_data(self, trainX, trainy, devX, devy, testX, testy): 60 | # Convert numpy arrays to float tensors on the GPU 61 | trainX = torch.from_numpy(trainX).float().cuda() 62 | trainy = torch.from_numpy(trainy).float().cuda() 63 | devX = torch.from_numpy(devX).float().cuda() 64 | devy = torch.from_numpy(devy).float().cuda() 65 | testX = torch.from_numpy(testX).float().cuda() 66 | testy = torch.from_numpy(testy).float().cuda() 67 | 68 | return trainX, trainy, devX, devy, testX, testy 69 | 70 | def run(self): 71 | self.nepoch = 0 72 | bestpr = -1 73 | early_stop_count = 0 74 | r = np.arange(1, 6) 75 | stop_train = False 76 | 77 | # Preparing data 78 | trainX, trainy, devX, devy, testX, testy = self.prepare_data( 79 | self.train['X'], self.train['y'], 80 | self.valid['X'], self.valid['y'], 81 | self.test['X'], self.test['y']) 82 | 83 | # Training 84 | while not stop_train and self.nepoch <= self.maxepoch: 85 | self.trainepoch(trainX, trainy, nepoches=50) 86 | yhat = np.dot(self.predict_proba(devX), r) 87 | pr = spearmanr(yhat, self.devscores)[0] 88 | pr = 0 if pr != pr else pr # if NaN bc std=0 89 | # early stop on Spearman 90 | if pr > bestpr: 91 | bestpr = pr 92 | bestmodel = copy.deepcopy(self.model) 93 | elif self.early_stop: 94 | if early_stop_count >= 3: 95 | stop_train = True 96 | early_stop_count += 1 97 | self.model = bestmodel 98 | 99 | yhat = np.dot(self.predict_proba(testX), r) 100 | 101 | return bestpr, yhat 102 | 103 | def trainepoch(self, X, y, nepoches=1): 104 | 
self.model.train() 105 | for _ in range(self.nepoch, self.nepoch + nepoches): 106 | permutation = np.random.permutation(len(X)) 107 | all_costs = [] 108 | for i in range(0, len(X), self.batch_size): 109 | # forward 110 | idx = torch.from_numpy(permutation[i:i + self.batch_size]).long().cuda() 111 | Xbatch = X[idx] 112 | ybatch = y[idx] 113 | output = self.model(Xbatch) 114 | # loss 115 | loss = self.loss_fn(output, ybatch) 116 | all_costs.append(loss.item()) 117 | # backward 118 | self.optimizer.zero_grad() 119 | loss.backward() 120 | # Update parameters 121 | self.optimizer.step() 122 | self.nepoch += nepoches 123 | 124 | def predict_proba(self, devX): 125 | self.model.eval() 126 | probas = [] 127 | with torch.no_grad(): 128 | for i in range(0, len(devX), self.batch_size): 129 | Xbatch = devX[i:i + self.batch_size] 130 | if len(probas) == 0: 131 | probas = self.model(Xbatch).data.cpu().numpy() 132 | else: 133 | probas = np.concatenate((probas, self.model(Xbatch).data.cpu().numpy()), axis=0) 134 | return probas 135 | -------------------------------------------------------------------------------- /SentEval/senteval/engine.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
#

'''

Generic sentence evaluation scripts wrapper

'''
from __future__ import absolute_import, division, unicode_literals

from senteval import utils
from senteval.binary import CREval, MREval, MPQAEval, SUBJEval
from senteval.snli import SNLIEval
from senteval.trec import TRECEval
from senteval.sick import SICKEntailmentEval, SICKEval
from senteval.mrpc import MRPCEval
from senteval.sts import STS12Eval, STS13Eval, STS14Eval, STS15Eval, STS16Eval, STSBenchmarkEval, SICKRelatednessEval, STSBenchmarkFinetune
from senteval.sst import SSTEval
from senteval.rank import ImageCaptionRetrievalEval
from senteval.probing import *

class SE(object):
    def __init__(self, params, batcher, prepare=None):
        # parameters
        params = utils.dotdict(params)
        params.usepytorch = True if 'usepytorch' not in params else params.usepytorch
        params.seed = 1111 if 'seed' not in params else params.seed

        params.batch_size = 128 if 'batch_size' not in params else params.batch_size
        params.nhid = 0 if 'nhid' not in params else params.nhid
        params.kfold = 5 if 'kfold' not in params else params.kfold

        if 'classifier' not in params or not params['classifier']:
            params.classifier = {'nhid': 0}

        assert 'nhid' in params.classifier, 'Set number of hidden units in classifier config!!'

        self.params = params

        # batcher and prepare
        self.batcher = batcher
        self.prepare = prepare if prepare else lambda x, y: None

        self.list_tasks = ['CR', 'MR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC',
                           'SICKRelatedness', 'SICKEntailment', 'STSBenchmark',
                           'SNLI', 'ImageCaptionRetrieval', 'STS12', 'STS13',
                           'STS14', 'STS15', 'STS16',
                           'Length', 'WordContent', 'Depth', 'TopConstituents',
                           'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber',
                           'OddManOut', 'CoordinationInversion', 'SICKRelatedness-finetune', 'STSBenchmark-finetune', 'STSBenchmark-fix']

    def eval(self, name):
        # evaluate on evaluation [name], either takes string or list of strings
        if (isinstance(name, list)):
            self.results = {x: self.eval(x) for x in name}
            return self.results

        tpath = self.params.task_path
        assert name in self.list_tasks, str(name) + ' not in ' + str(self.list_tasks)

        # Original SentEval tasks
        if name == 'CR':
            self.evaluation = CREval(tpath + '/downstream/CR', seed=self.params.seed)
        elif name == 'MR':
            self.evaluation = MREval(tpath + '/downstream/MR', seed=self.params.seed)
        elif name == 'MPQA':
            self.evaluation = MPQAEval(tpath + '/downstream/MPQA', seed=self.params.seed)
        elif name == 'SUBJ':
            self.evaluation = SUBJEval(tpath + '/downstream/SUBJ', seed=self.params.seed)
        elif name == 'SST2':
            self.evaluation = SSTEval(tpath + '/downstream/SST/binary', nclasses=2, seed=self.params.seed)
        elif name == 'SST5':
            self.evaluation = SSTEval(tpath + '/downstream/SST/fine', nclasses=5, seed=self.params.seed)
        elif name == 'TREC':
            self.evaluation = TRECEval(tpath + '/downstream/TREC', seed=self.params.seed)
        elif name == 'MRPC':
            self.evaluation = MRPCEval(tpath + '/downstream/MRPC', seed=self.params.seed)
        elif name == 'SICKRelatedness':
            self.evaluation = SICKRelatednessEval(tpath + '/downstream/SICK', seed=self.params.seed)
        elif name == 'STSBenchmark':
            self.evaluation = STSBenchmarkEval(tpath + '/downstream/STS/STSBenchmark', seed=self.params.seed)
        elif name == 'STSBenchmark-fix':
            self.evaluation = STSBenchmarkEval(tpath + '/downstream/STS/STSBenchmark-fix', seed=self.params.seed)
        elif name == 'STSBenchmark-finetune':
            self.evaluation = STSBenchmarkFinetune(tpath + '/downstream/STS/STSBenchmark', seed=self.params.seed)
        elif name == 'SICKRelatedness-finetune':
            self.evaluation = SICKEval(tpath + '/downstream/SICK', seed=self.params.seed)
        elif name == 'SICKEntailment':
            self.evaluation = SICKEntailmentEval(tpath + '/downstream/SICK', seed=self.params.seed)
        elif name == 'SNLI':
            self.evaluation = SNLIEval(tpath + '/downstream/SNLI', seed=self.params.seed)
        elif name in ['STS12', 'STS13', 'STS14', 'STS15', 'STS16']:
            fpath = name + '-en-test'
            self.evaluation = eval(name + 'Eval')(tpath + '/downstream/STS/' + fpath, seed=self.params.seed)
        elif name == 'ImageCaptionRetrieval':
            self.evaluation = ImageCaptionRetrievalEval(tpath + '/downstream/COCO', seed=self.params.seed)

        # Probing Tasks
        elif name == 'Length':
            self.evaluation = LengthEval(tpath + '/probing', seed=self.params.seed)
        elif name == 'WordContent':
            self.evaluation = WordContentEval(tpath + '/probing', seed=self.params.seed)
        elif name == 'Depth':
            self.evaluation = DepthEval(tpath + '/probing', seed=self.params.seed)
        elif name == 'TopConstituents':
            self.evaluation = TopConstituentsEval(tpath + '/probing', seed=self.params.seed)
        elif name == 'BigramShift':
            self.evaluation = BigramShiftEval(tpath + '/probing', seed=self.params.seed)
        elif name == 'Tense':
            self.evaluation = TenseEval(tpath + '/probing', seed=self.params.seed)
        elif name == 'SubjNumber':
            self.evaluation = SubjNumberEval(tpath + '/probing', seed=self.params.seed)
        elif name == 'ObjNumber':
            self.evaluation = ObjNumberEval(tpath + '/probing', seed=self.params.seed)
        elif name == 'OddManOut':
            self.evaluation = OddManOutEval(tpath + '/probing', seed=self.params.seed)
        elif name == 'CoordinationInversion':
            self.evaluation = CoordinationInversionEval(tpath + '/probing', seed=self.params.seed)

        self.params.current_task = name
        self.evaluation.do_prepare(self.params, self.prepare)

        self.results = self.evaluation.run(self.params, self.batcher)

        return self.results

--------------------------------------------------------------------------------
/demo/static/index.html:
--------------------------------------------------------------------------------
SimCSE

Simple Contrastive Learning of Sentence Embeddings

Tianyu Gao  Xingcheng Yao  Danqi Chen
Princeton University  Tsinghua University

SimCSE is a novel framework for contrastive learning of sentence embeddings. This demo shows how our pre-trained sentence embeddings can be directly applied to sentence retrieval tasks. You can type any natural language sentences and click the search button to see which sentences in the example database are semantically similar to the provided sentence. Here are some details about this demo:

Top K: 5
Threshold: 0.6

--------------------------------------------------------------------------------
/SentEval/senteval/probing.py:
--------------------------------------------------------------------------------
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

'''
probing tasks
'''

from __future__ import absolute_import, division, unicode_literals

import os
import io
import copy
import logging
import numpy as np

from senteval.tools.validation import SplitClassifier


class PROBINGEval(object):
    def __init__(self, task, task_path, seed=1111):
        self.seed = seed
        self.task = task
        logging.debug('***** (Probing) Transfer task : %s classification *****', self.task.upper())
        self.task_data = {'train': {'X': [], 'y': []},
                          'dev': {'X': [], 'y': []},
                          'test': {'X': [], 'y': []}}
        self.loadFile(task_path)
        logging.info('Loaded %s train - %s dev - %s test for %s' %
                     (len(self.task_data['train']['y']), len(self.task_data['dev']['y']),
                      len(self.task_data['test']['y']), self.task))

    def do_prepare(self, params, prepare):
        samples = self.task_data['train']['X'] + self.task_data['dev']['X'] + \
                  self.task_data['test']['X']
        return prepare(params, samples)

    def loadFile(self, fpath):
        self.tok2split = {'tr': 'train', 'va': 'dev', 'te': 'test'}
        with io.open(fpath, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.rstrip().split('\t')
                self.task_data[self.tok2split[line[0]]]['X'].append(line[-1].split())
                self.task_data[self.tok2split[line[0]]]['y'].append(line[1])

        labels = sorted(np.unique(self.task_data['train']['y']))
        self.tok2label = dict(zip(labels, range(len(labels))))
        self.nclasses = len(self.tok2label)

        for split in self.task_data:
            for i, y in enumerate(self.task_data[split]['y']):
                self.task_data[split]['y'][i] = self.tok2label[y]

    def run(self, params, batcher):
        task_embed = {'train': {}, 'dev': {}, 'test': {}}
        bsize = params.batch_size
        logging.info('Computing embeddings for train/dev/test')
        for key in self.task_data:
            # Sort to reduce padding
            sorted_data = sorted(zip(self.task_data[key]['X'],
                                     self.task_data[key]['y']),
                                 key=lambda z: (len(z[0]), z[1]))
            self.task_data[key]['X'], self.task_data[key]['y'] = map(list, zip(*sorted_data))

            task_embed[key]['X'] = []
            for ii in range(0, len(self.task_data[key]['y']), bsize):
                batch = self.task_data[key]['X'][ii:ii + bsize]
                embeddings = batcher(params, batch)
                task_embed[key]['X'].append(embeddings)
            task_embed[key]['X'] = np.vstack(task_embed[key]['X'])
            task_embed[key]['y'] = np.array(self.task_data[key]['y'])
        logging.info('Computed embeddings')

        config_classifier = {'nclasses': self.nclasses, 'seed': self.seed,
                             'usepytorch': params.usepytorch,
                             'classifier': params.classifier}

        if self.task == "WordContent" and params.classifier['nhid'] > 0:
            config_classifier = copy.deepcopy(config_classifier)
            config_classifier['classifier']['nhid'] = 0
            print(params.classifier['nhid'])

        clf = SplitClassifier(X={'train': task_embed['train']['X'],
                                 'valid': task_embed['dev']['X'],
                                 'test': task_embed['test']['X']},
                              y={'train': task_embed['train']['y'],
                                 'valid': task_embed['dev']['y'],
                                 'test': task_embed['test']['y']},
                              config=config_classifier)

        devacc, testacc = clf.run()
        logging.debug('\nDev acc : %.1f Test acc : %.1f for %s classification\n' % (devacc, testacc, self.task.upper()))

        return {'devacc': devacc, 'acc': testacc,
                'ndev': len(task_embed['dev']['X']),
                'ntest': len(task_embed['test']['X'])}

"""
Surface Information
"""
class LengthEval(PROBINGEval):
    def __init__(self, task_path, seed=1111):
        task_path = os.path.join(task_path, 'sentence_length.txt')
        # labels: bins
        PROBINGEval.__init__(self, 'Length', task_path, seed)

class WordContentEval(PROBINGEval):
    def __init__(self, task_path, seed=1111):
        task_path = os.path.join(task_path, 'word_content.txt')
        # labels: 200 target words
        PROBINGEval.__init__(self, 'WordContent', task_path, seed)

"""
Latent Structural Information
"""
class DepthEval(PROBINGEval):
    def __init__(self, task_path, seed=1111):
        task_path = os.path.join(task_path, 'tree_depth.txt')
        # labels: bins
        PROBINGEval.__init__(self, 'Depth', task_path, seed)

class TopConstituentsEval(PROBINGEval):
    def __init__(self, task_path, seed=1111):
        task_path = os.path.join(task_path, 'top_constituents.txt')
        # labels: 'PP_NP_VP_.' .. (20 classes)
        PROBINGEval.__init__(self, 'TopConstituents', task_path, seed)

class BigramShiftEval(PROBINGEval):
    def __init__(self, task_path, seed=1111):
        task_path = os.path.join(task_path, 'bigram_shift.txt')
        # labels: 0 or 1
        PROBINGEval.__init__(self, 'BigramShift', task_path, seed)

# TODO: Voice?

"""
Latent Semantic Information
"""

class TenseEval(PROBINGEval):
    def __init__(self, task_path, seed=1111):
        task_path = os.path.join(task_path, 'past_present.txt')
        # labels: 'PRES', 'PAST'
        PROBINGEval.__init__(self, 'Tense', task_path, seed)

class SubjNumberEval(PROBINGEval):
    def __init__(self, task_path, seed=1111):
        task_path = os.path.join(task_path, 'subj_number.txt')
        # labels: 'NN', 'NNS'
        PROBINGEval.__init__(self, 'SubjNumber', task_path, seed)

class ObjNumberEval(PROBINGEval):
    def __init__(self, task_path, seed=1111):
        task_path = os.path.join(task_path, 'obj_number.txt')
        # labels: 'NN', 'NNS'
        PROBINGEval.__init__(self, 'ObjNumber', task_path, seed)

class OddManOutEval(PROBINGEval):
    def __init__(self, task_path, seed=1111):
        task_path = os.path.join(task_path, 'odd_man_out.txt')
        # labels: 'O', 'C'
        PROBINGEval.__init__(self, 'OddManOut', task_path, seed)

class CoordinationInversionEval(PROBINGEval):
    def __init__(self, task_path, seed=1111):
        task_path = os.path.join(task_path, 'coordination_inversion.txt')
        # labels: 'O', 'I'
        PROBINGEval.__init__(self, 'CoordinationInversion', task_path, seed)

--------------------------------------------------------------------------------
/SentEval/senteval/tools/classifier.py:
--------------------------------------------------------------------------------
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

"""
Pytorch Classifier class in the style of scikit-learn
Classifiers include Logistic Regression and MLP
"""

from __future__ import absolute_import, division, unicode_literals

import numpy as np
import copy
from senteval import utils

import torch
from torch import nn
import torch.nn.functional as F


class PyTorchClassifier(object):
    def __init__(self, inputdim, nclasses, l2reg=0., batch_size=64, seed=1111,
                 cudaEfficient=False):
        # fix seed
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        self.inputdim = inputdim
        self.nclasses = nclasses
        self.l2reg = l2reg
        self.batch_size = batch_size
        self.cudaEfficient = cudaEfficient

    def prepare_split(self, X, y, validation_data=None, validation_split=None):
        # Preparing validation data
        assert validation_split or validation_data
        if validation_data is not None:
            trainX, trainy = X, y
            devX, devy = validation_data
        else:
            permutation = np.random.permutation(len(X))
            trainidx = permutation[int(validation_split * len(X)):]
            devidx = permutation[0:int(validation_split * len(X))]
            trainX, trainy = X[trainidx], y[trainidx]
            devX, devy = X[devidx], y[devidx]

        # with cudaEfficient, data stays on the CPU and is moved batch by batch
        device = torch.device('cpu') if self.cudaEfficient else torch.device('cuda')

        trainX = torch.from_numpy(trainX).to(device, dtype=torch.float32)
        trainy = torch.from_numpy(trainy).to(device, dtype=torch.int64)
        devX = torch.from_numpy(devX).to(device, dtype=torch.float32)
        devy = torch.from_numpy(devy).to(device, dtype=torch.int64)

        return trainX, trainy, devX, devy

    def fit(self, X, y, validation_data=None, validation_split=None,
            early_stop=True):
        self.nepoch = 0
        bestaccuracy = -1
        stop_train = False
        early_stop_count = 0

        # Preparing validation data
        trainX, trainy, devX, devy = self.prepare_split(X, y, validation_data,
                                                        validation_split)

        # Training
        while not stop_train and self.nepoch <= self.max_epoch:
            self.trainepoch(trainX, trainy, epoch_size=self.epoch_size)
            accuracy = self.score(devX, devy)
            if accuracy > bestaccuracy:
                bestaccuracy = accuracy
                bestmodel = copy.deepcopy(self.model)
            elif early_stop:
                if early_stop_count >= self.tenacity:
                    stop_train = True
                early_stop_count += 1
        self.model = bestmodel
        return bestaccuracy

    def trainepoch(self, X, y, epoch_size=1):
        self.model.train()
        for _ in range(self.nepoch, self.nepoch + epoch_size):
            permutation = np.random.permutation(len(X))
            all_costs = []
            for i in range(0, len(X), self.batch_size):
                # forward
                idx = torch.from_numpy(permutation[i:i + self.batch_size]).long().to(X.device)

                Xbatch = X[idx]
                ybatch = y[idx]

                if self.cudaEfficient:
                    Xbatch = Xbatch.cuda()
                    ybatch = ybatch.cuda()
                output = self.model(Xbatch)
                # loss
                loss = self.loss_fn(output, ybatch)
                all_costs.append(loss.data.item())
                # backward
                self.optimizer.zero_grad()
                loss.backward()
                # Update parameters
                self.optimizer.step()
        self.nepoch += epoch_size

    def score(self, devX, devy):
        self.model.eval()
        correct = 0
        if not isinstance(devX, torch.cuda.FloatTensor) or self.cudaEfficient:
            devX = torch.FloatTensor(devX).cuda()
            devy = torch.LongTensor(devy).cuda()
        with torch.no_grad():
            for i in range(0, len(devX), self.batch_size):
                Xbatch = devX[i:i + self.batch_size]
                ybatch = devy[i:i + self.batch_size]
                if self.cudaEfficient:
                    Xbatch = Xbatch.cuda()
                    ybatch = ybatch.cuda()
                output = self.model(Xbatch)
                pred = output.data.max(1)[1]
                correct += pred.long().eq(ybatch.data.long()).sum().item()
            accuracy = 1.0 * correct / len(devX)
        return accuracy

    def predict(self, devX):
        self.model.eval()
        if not isinstance(devX, torch.cuda.FloatTensor):
            devX = torch.FloatTensor(devX).cuda()
        yhat = np.array([])
        with torch.no_grad():
            for i in range(0, len(devX), self.batch_size):
                Xbatch = devX[i:i + self.batch_size]
                output = self.model(Xbatch)
                yhat = np.append(yhat,
                                 output.data.max(1)[1].cpu().numpy())
        yhat = np.vstack(yhat)
        return yhat

    def predict_proba(self, devX):
        self.model.eval()
        probas = []
        with torch.no_grad():
            for i in range(0, len(devX), self.batch_size):
                Xbatch = devX[i:i + self.batch_size]
                # softmax over the class dimension on the tensor, then convert to numpy
                vals = F.softmax(self.model(Xbatch), dim=-1).data.cpu().numpy()
                if len(probas) == 0:
                    probas = vals
                else:
                    probas = np.concatenate((probas, vals), axis=0)
        return probas


"""
MLP with Pytorch (nhid=0 --> Logistic Regression)
"""

class MLP(PyTorchClassifier):
    def __init__(self, params, inputdim, nclasses, l2reg=0., batch_size=64,
                 seed=1111, cudaEfficient=False):
        super(MLP, self).__init__(inputdim, nclasses, l2reg,
                                  batch_size, seed, cudaEfficient)
        """
        PARAMETERS:
        -nhid:       number of hidden units (0: Logistic Regression)
        -optim:      optimizer ("sgd,lr=0.1", "adam", "rmsprop" ..)
        -tenacity:   how many times dev acc does not increase before stopping
        -epoch_size: each epoch corresponds to epoch_size pass on the train set
        -max_epoch:  max number of epoches
        -dropout:    dropout for MLP
        """

        self.nhid = 0 if "nhid" not in params else params["nhid"]
        self.optim = "adam" if "optim" not in params else params["optim"]
        self.tenacity = 5 if "tenacity" not in params else params["tenacity"]
        self.epoch_size = 4 if "epoch_size" not in params else params["epoch_size"]
        self.max_epoch = 200 if "max_epoch" not in params else params["max_epoch"]
        self.dropout = 0. if "dropout" not in params else params["dropout"]
        self.batch_size = 64 if "batch_size" not in params else params["batch_size"]

        # use self.nhid so a missing "nhid" key falls back to logistic regression
        if self.nhid == 0:
            self.model = nn.Sequential(
                nn.Linear(self.inputdim, self.nclasses),
            ).cuda()
        else:
            self.model = nn.Sequential(
                nn.Linear(self.inputdim, self.nhid),
                nn.Dropout(p=self.dropout),
                nn.Sigmoid(),
                nn.Linear(self.nhid, self.nclasses),
            ).cuda()

        self.loss_fn = nn.CrossEntropyLoss().cuda()
        self.loss_fn.size_average = False

        optim_fn, optim_params = utils.get_optimizer(self.optim)
        self.optimizer = optim_fn(self.model.parameters(), **optim_params)
        self.optimizer.param_groups[0]['weight_decay'] = self.l2reg

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Alleviating Over-smoothing for Unsupervised Sentence Representation
This is the code for our ACL 2023 paper: [Alleviating Over-smoothing for Unsupervised Sentence Representation](https://arxiv.org/pdf/2305.06154).

Our work is mainly based on the [SimCSE project](https://github.com/princeton-nlp/SimCSE); thanks to SimCSE!

## Quick Links

- [Overview](#overview)
- [Getting Started](#setup)
- [Train SSCL](#training)
- [Requirements](#requirements)
- [Evaluation](#evaluation)
- [Training](#training)
- [Citation](#citation)

## Overview

We present a new training paradigm based on contrastive learning: a simple method named Self-Contrastive Learning (SSCL), which can significantly improve the performance of learned sentence representations while alleviating the over-smoothing issue. Simply put, we use hidden representations from intermediate PLM layers as negative samples that the final sentence representations should be pushed away from. Generally, SSCL has several advantages: (1) It is fairly straightforward and does not require complex data augmentation techniques; (2) It can be seen as a contrastive framework that focuses on mining negatives effectively, and can be easily extended to different sentence encoders that aim at building positive pairs; (3) It can further be viewed as a plug-and-play framework for enhancing sentence representations.
![](figure/overview.png)

## Setup
First, install PyTorch by following the instructions on the [official website](https://pytorch.org/get-started/previous-versions/). To faithfully reproduce our results, please use the 1.9.1 build that matches your platform and CUDA version. For example, for CUDA 11.1:

```bash
pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html
```
Then run the following command to install the remaining dependencies:

```bash
pip install -r requirements.txt
```

### Evaluation
Our evaluation code for sentence embeddings is based on a modified version of [SentEval](https://github.com/facebookresearch/SentEval). It evaluates sentence embeddings on semantic textual similarity (STS) tasks and downstream transfer tasks. For STS tasks, our evaluation takes the "all" setting and reports Spearman's correlation.

Before evaluation, please download the evaluation datasets by running
```bash
cd SentEval/data/downstream/
bash download_dataset.sh
```

Then return to the root directory. You can evaluate any `transformers`-based pre-trained model using our evaluation code. For example,
```bash
python evaluation.py \
    --model_name_or_path sscl-bert-base-uncased \
    --pooler cls \
    --task_set sts \
    --mode test
```

Arguments for the evaluation script are as follows:

* `--model_name_or_path`: The name or path of a `transformers`-based pre-trained checkpoint. You can directly use the models in the above table.
* `--pooler`: Pooling method. We currently support
    * `cls` (default): Use the representation of the `[CLS]` token. A linear+activation layer is applied after the representation (as in the standard BERT implementation).
    * `cls_before_pooler`: Use the representation of the `[CLS]` token without the extra linear+activation.
    * `avg`: Average embeddings of the last layer. If you use checkpoints of SBERT/SRoBERTa ([paper](https://arxiv.org/abs/1908.10084)), you should use this option.
    * `avg_top2`: Average embeddings of the last two layers.
    * `avg_first_last`: Average embeddings of the first and last layers. If you use vanilla BERT or RoBERTa, this works the best.
* `--mode`: Evaluation mode
    * `test` (default): The default test mode. To faithfully reproduce our results, you should use this option.
    * `dev`: Report the development set results. Note that in STS tasks, only `STS-B` and `SICK-R` have development sets, so we only report their numbers. It also uses a fast mode for transfer tasks, so the running time is much shorter than in `test` mode (though numbers are slightly lower).
    * `fasttest`: The same as `test`, but in fast mode, so the running time is much shorter and the reported numbers may be lower (only for transfer tasks).
* `--task_set`: What set of tasks to evaluate on (if set, it will override `--tasks`)
    * `sts` (default): Evaluate on STS tasks, including `STS 12~16`, `STS-B` and `SICK-R`. This is the most commonly used set of tasks to evaluate the quality of sentence embeddings.
    * `transfer`: Evaluate on transfer tasks.
    * `full`: Evaluate on both STS and transfer tasks.
    * `na`: Manually set tasks by `--tasks`.
* `--tasks`: Specify which dataset(s) to evaluate on. Will be overridden if `--task_set` is not `na`. See the code for a full list of tasks.

### Training

**Data**

You can run `data/download_wiki.sh` and `data/download_nli.sh` to download the two datasets.

**Training scripts**

In `run_unsup_example.sh`, we provide a single-GPU (or CPU) example for the unsupervised version. The arguments are explained below:
* `--train_file`: Training file path. We support "txt" files (one sentence per line) and "csv" files (2-column: pair data with no hard negative; 3-column: pair data with one corresponding hard negative instance). You can use our provided Wikipedia or NLI data, or your own data in the same format.
* `--model_name_or_path`: Pre-trained checkpoint to start from. For now we support BERT-based models (`bert-base-uncased`, `bert-large-uncased`, etc.).
* `--temp`: Temperature for the contrastive loss.
* `--pooler_type`: Pooling method. It's the same as `--pooler` in the [evaluation part](#evaluation).
* `--mlp_only_train`: You should use this argument when training SSCL models.
* `--hard_negative_weight`: If using hard negatives (i.e., there are 3 columns in the training file), this is the logarithm of the weight. For example, if the weight is 1, then this argument should be set to 0 (default value).
* `--do_mlm`: Whether to use the MLM auxiliary objective. If True:
    * `--mlm_weight`: Weight for the MLM objective.
    * `--mlm_probability`: Masking rate for the MLM objective.
* `--do_neg`: Whether to use negatives in SSCL.
* `--hard_negative_layers`: How many preceding layers are used to construct negatives.

All other arguments are standard Hugging Face `transformers` training arguments. Some of the often-used ones are `--output_dir`, `--learning_rate`, and `--per_device_train_batch_size`. Our example scripts also evaluate the model on the STS-B development set (you need to download the dataset following the [evaluation](#evaluation) section) and save the best checkpoint.

For results in the paper, we use Nvidia A100 GPUs with CUDA 11. Using different types of devices or different versions of CUDA or other software may lead to slightly different performance.

**Convert models**

Our saved checkpoints are slightly different from Hugging Face's pre-trained checkpoints. Run `python sscl_to_huggingface.py --path {PATH_TO_CHECKPOINT_FOLDER}` to convert a checkpoint. After that, you can evaluate it with our [evaluation](#evaluation) code or directly use it [out of the box](#use-our-models-out-of-the-box).
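
The note that `--hard_negative_weight` is the *logarithm* of the weight can be illustrated with a small sketch. This is not the repository's training code: the function name `contrastive_loss`, its arguments, and the default `temp=0.05` are hypothetical, and the sketch assumes the weight enters as an additive term on the hard negative's logit, as in SimCSE-style losses.

```python
import math


def contrastive_loss(sim_pos, sim_negs, sim_hard, temp=0.05, hard_negative_weight=0.0):
    """InfoNCE-style loss for a single anchor sentence (illustrative only).

    sim_pos:  cosine similarity between the anchor and its positive
    sim_negs: cosine similarities to ordinary in-batch negatives
    sim_hard: cosine similarity to the hard negative
    hard_negative_weight: log of the multiplicative weight on the hard negative
    """
    logits = [sim_pos / temp] + [s / temp for s in sim_negs]
    # Adding log(w) to a logit multiplies its exp() term in the denominator by w.
    logits.append(sim_hard / temp + hard_negative_weight)
    # Numerically stable log-sum-exp over all candidates.
    m = max(logits)
    log_denominator = m + math.log(sum(math.exp(z - m) for z in logits))
    # Negative log-probability of picking the positive.
    return log_denominator - logits[0]


if __name__ == "__main__":
    plain = contrastive_loss(0.8, [0.1, 0.2], 0.5)
    doubled = contrastive_loss(0.8, [0.1, 0.2], 0.5, hard_negative_weight=math.log(2.0))
    print(plain, doubled)
```

Under this assumption, a value of 0 makes the hard negative count exactly like any other negative, while `math.log(2.0)` makes it count as if it appeared twice in the denominator.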

## Citation

Please cite our paper if you use SSCL in your work:

```bibtex
@inproceedings{chen-etal-2023-alleviating,
    title = "Alleviating Over-smoothing for Unsupervised Sentence Representation",
    author = "Chen, Nuo and
      Shou, Linjun and
      Pei, Jian and
      Gong, Ming and
      Cao, Bowen and
      Chang, Jianhui and
      Li, Jia and
      Jiang, Daxin",
    booktitle = "Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
    month = jul,
    year = "2023",
    address = "Toronto, Canada",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/2023.acl-long.197",
    pages = "3552--3566",
    abstract = "Currently, learning better unsupervised sentence representations is the pursuit of many natural language processing communities. Lots of approaches based on pre-trained language models (PLMs) and contrastive learning have achieved promising results on this task. Experimentally, we observe that the over-smoothing problem reduces the capacity of these powerful PLMs, leading to sub-optimal sentence representations. In this paper, we present a Simple method named Self-Contrastive Learning (SSCL) to alleviate this issue, which samples negatives from PLMs intermediate layers, improving the quality of the sentence representation. Our proposed method is quite simple and can be easily extended to various state-of-the-art models for performance boosting, which can be seen as a plug-and-play contrastive framework for learning unsupervised sentence representation. Extensive results prove that SSCL brings the superior performance improvements of different strong baselines (e.g., BERT and SimCSE) on Semantic Textual Similarity and Transfer datasets",
}
```

--------------------------------------------------------------------------------
/evaluation.py:
--------------------------------------------------------------------------------
import sys
import io, os
import numpy as np
import logging
import argparse
from prettytable import PrettyTable
import torch
import transformers
from transformers import AutoModel, AutoTokenizer
import xlrd
from openpyxl import Workbook

# Set up logger
logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG)

# Set PATHs
PATH_TO_SENTEVAL = './SentEval'
PATH_TO_DATA = './SentEval/data'
workbook = Workbook()

# Import SentEval
sys.path.insert(0, PATH_TO_SENTEVAL)
import senteval

def print_table(task_names, scores):
    tb = PrettyTable()
    tb.field_names = task_names
    tb.add_row(scores)
    print(tb)

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_or_path", type=str,
                        help="Transformers' model name or path")
    parser.add_argument("--pooler", type=str,
                        choices=['cls', 'cls_before_pooler', 'avg', 'avg_top2', 'avg_first_last'],
                        default='cls',
                        help="Which pooler to use")
    parser.add_argument("--mode", type=str,
                        choices=['dev', 'test', 'fasttest'],
                        default='test',
                        help="What evaluation mode to use (dev: fast mode, dev results; test: full mode, test results); fasttest: fast mode, test results")
    parser.add_argument("--task_set", type=str,
                        choices=['sts', 'transfer', 'full', 'na'],
                        default='sts',
                        help="What set of tasks to evaluate on. If not 'na', this will override '--tasks'")
    parser.add_argument("--tasks", type=str, nargs='+',
                        default=['STS12', 'STS13', 'STS14', 'STS15', 'STS16',
                                 'MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'TREC', 'MRPC',
                                 'SICKRelatedness', 'STSBenchmark'],
                        help="Tasks to evaluate on. If '--task_set' is specified, this will be overridden")

    args = parser.parse_args()

    # Load transformers' model checkpoint
    model = AutoModel.from_pretrained(args.model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Set up the tasks
    if args.task_set == 'sts':
        args.tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark', 'SICKRelatedness']
    elif args.task_set == 'transfer':
        args.tasks = ['MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'TREC', 'MRPC']
    elif args.task_set == 'full':
        args.tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark', 'SICKRelatedness']
        args.tasks += ['MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'TREC', 'MRPC']

    # Set params for SentEval
    if args.mode == 'dev' or args.mode == 'fasttest':
        # Fast mode
        params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5}
        params['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128,
                                'tenacity': 3, 'epoch_size': 2}
    elif args.mode == 'test':
        # Full mode
        params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 10}
        params['classifier'] = {'nhid': 0, 'optim': 'adam', 'batch_size': 64,
                                'tenacity': 5, 'epoch_size': 4}
    else:
        raise NotImplementedError

    # SentEval prepare and batcher
    def prepare(params, samples):
        return

    def batcher(params, batch, max_length=None):
        # Handle rare token encoding issues in the dataset
        if len(batch) >= 1 and len(batch[0]) >= 1 and isinstance(batch[0][0],
bytes): 91 | batch = [[word.decode('utf-8') for word in s] for s in batch] 92 | 93 | sentences = [' '.join(s) for s in batch] 94 | 95 | # Tokenization 96 | if max_length is not None: 97 | batch = tokenizer.batch_encode_plus( 98 | sentences, 99 | return_tensors='pt', 100 | padding=True, 101 | max_length=max_length, 102 | truncation=True 103 | ) 104 | else: 105 | batch = tokenizer.batch_encode_plus( 106 | sentences, 107 | return_tensors='pt', 108 | padding=True, 109 | ) 110 | 111 | # Move to the correct device 112 | for k in batch: 113 | batch[k] = batch[k].to(device) 114 | 115 | # Get raw embeddings 116 | with torch.no_grad(): 117 | outputs = model(**batch, output_hidden_states=True, return_dict=True) 118 | last_hidden = outputs.last_hidden_state 119 | pooler_output = outputs.pooler_output 120 | hidden_states = outputs.hidden_states 121 | 122 | # Apply different poolers 123 | if args.pooler == 'cls': 124 | # There is a linear+activation layer after CLS representation 125 | return pooler_output.cpu() 126 | elif args.pooler == 'cls_before_pooler': 127 | return last_hidden[:, 0].cpu() 128 | elif args.pooler == "avg": 129 | return ((last_hidden * batch['attention_mask'].unsqueeze(-1)).sum(1) / batch['attention_mask'].sum(-1).unsqueeze(-1)).cpu() 130 | elif args.pooler == "avg_first_last": 131 | first_hidden = hidden_states[0] 132 | last_hidden = hidden_states[-1] 133 | pooled_result = ((first_hidden + last_hidden) / 2.0 * batch['attention_mask'].unsqueeze(-1)).sum(1) / batch['attention_mask'].sum(-1).unsqueeze(-1) 134 | return pooled_result.cpu() 135 | elif args.pooler == "avg_top2": 136 | second_last_hidden = hidden_states[-2] 137 | last_hidden = hidden_states[-1] 138 | pooled_result = ((last_hidden + second_last_hidden) / 2.0 * batch['attention_mask'].unsqueeze(-1)).sum(1) / batch['attention_mask'].sum(-1).unsqueeze(-1) 139 | return pooled_result.cpu() 140 | else: 141 | raise NotImplementedError 142 | 143 | results = {} 144 | 145 | for task in args.tasks: 146 | se = 
senteval.engine.SE(params, batcher, prepare) 147 | result = se.eval(task) 148 | results[task] = result 149 | 150 | # Print evaluation results 151 | if args.mode == 'dev': 152 | print("------ %s ------" % (args.mode)) 153 | 154 | task_names = [] 155 | scores = [] 156 | for task in ['STSBenchmark', 'SICKRelatedness']: 157 | task_names.append(task) 158 | if task in results: 159 | scores.append("%.2f" % (results[task]['dev']['spearman'][0] * 100)) 160 | else: 161 | scores.append("0.00") 162 | print_table(task_names, scores) 163 | 164 | task_names = [] 165 | scores = [] 166 | for task in ['MR', 'CR', 'SUBJ', 'MPQA', 'SST2', 'TREC', 'MRPC']: 167 | task_names.append(task) 168 | if task in results: 169 | scores.append("%.2f" % (results[task]['devacc'])) 170 | else: 171 | scores.append("0.00") 172 | task_names.append("Avg.") 173 | scores.append("%.2f" % (sum([float(score) for score in scores]) / len(scores))) 174 | print_table(task_names, scores) 175 | 176 | elif args.mode == 'test' or args.mode == 'fasttest': 177 | print("------ %s ------" % (args.mode)) 178 | save_file = os.path.join(args.model_name_or_path, 'results.xlsx') 179 | task_names = [] 180 | scores = [] 181 | sheet_name = '' 182 | for task in ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark', 'SICKRelatedness']: 183 | task_names.append(task) 184 | if task in results: 185 | if task in ['STS12', 'STS13', 'STS14', 'STS15', 'STS16']: 186 | scores.append("%.2f" % (results[task]['all']['spearman']['all'] * 100)) 187 | else: 188 | scores.append("%.2f" % (results[task]['test']['spearman'].correlation * 100)) 189 | else: 190 | scores.append("0.00") 191 | sheet_name = 'STS' 192 | task_names.append("Avg.") 193 | scores.append("%.2f" % (sum([float(score) for score in scores]) / len(scores))) 194 | # print(task_names) 195 | # print(scores) 196 | worksheet = workbook.create_sheet(sheet_name) 197 | worksheet.append(task_names) 198 | worksheet.append(scores) 199 | print_table(task_names, scores) 200 | 201 | 
task_names = [] 202 | scores = [] 203 | for task in ['MR', 'CR', 'SUBJ', 'MPQA', 'SST2', 'TREC', 'MRPC']: 204 | task_names.append(task) 205 | if task in results: 206 | scores.append("%.2f" % (results[task]['acc'])) 207 | else: 208 | scores.append("0.00") 209 | sheet_name = 'Transfer' 210 | task_names.append("Avg.") 211 | scores.append("%.2f" % (sum([float(score) for score in scores]) / len(scores))) 212 | print_table(task_names, scores) 213 | worksheet = workbook.create_sheet(sheet_name) 214 | worksheet.append(task_names) 215 | worksheet.append(scores) 216 | workbook.save(filename=save_file) 217 | 218 | 219 | if __name__ == "__main__": 220 | main() 221 | -------------------------------------------------------------------------------- /SentEval/senteval/sick.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | ''' 9 | SICK Relatedness and Entailment 10 | ''' 11 | from __future__ import absolute_import, division, unicode_literals 12 | 13 | import os 14 | import io 15 | import logging 16 | import numpy as np 17 | 18 | from sklearn.metrics import mean_squared_error 19 | from scipy.stats import pearsonr, spearmanr 20 | 21 | from senteval.tools.relatedness import RelatednessPytorch 22 | from senteval.tools.validation import SplitClassifier 23 | 24 | class SICKEval(object): 25 | def __init__(self, task_path, seed=1111): 26 | logging.debug('***** Transfer task : SICK-Relatedness*****\n\n') 27 | self.seed = seed 28 | train = self.loadFile(os.path.join(task_path, 'SICK_train.txt')) 29 | dev = self.loadFile(os.path.join(task_path, 'SICK_trial.txt')) 30 | test = self.loadFile(os.path.join(task_path, 'SICK_test_annotated.txt')) 31 | self.sick_data = {'train': train, 'dev': dev, 'test': test} 32 | 33 | def do_prepare(self, params, prepare): 34 | samples = self.sick_data['train']['X_A'] + \ 35 | self.sick_data['train']['X_B'] + \ 36 | self.sick_data['dev']['X_A'] + \ 37 | self.sick_data['dev']['X_B'] + \ 38 | self.sick_data['test']['X_A'] + self.sick_data['test']['X_B'] 39 | return prepare(params, samples) 40 | 41 | def loadFile(self, fpath): 42 | skipFirstLine = True 43 | sick_data = {'X_A': [], 'X_B': [], 'y': []} 44 | with io.open(fpath, 'r', encoding='utf-8') as f: 45 | for line in f: 46 | if skipFirstLine: 47 | skipFirstLine = False 48 | else: 49 | text = line.strip().split('\t') 50 | sick_data['X_A'].append(text[1].split()) 51 | sick_data['X_B'].append(text[2].split()) 52 | sick_data['y'].append(text[3]) 53 | 54 | sick_data['y'] = [float(s) for s in sick_data['y']] 55 | return sick_data 56 | 57 | def run(self, params, batcher): 58 | sick_embed = {'train': {}, 'dev': {}, 'test': {}} 59 | bsize = params.batch_size 60 | 61 | for key in self.sick_data: 62 | logging.info('Computing embedding for {0}'.format(key)) 63 | # Sort to reduce padding 64 | sorted_corpus = 
sorted(zip(self.sick_data[key]['X_A'], 65 | self.sick_data[key]['X_B'], 66 | self.sick_data[key]['y']), 67 | key=lambda z: (len(z[0]), len(z[1]), z[2])) 68 | 69 | self.sick_data[key]['X_A'] = [x for (x, y, z) in sorted_corpus] 70 | self.sick_data[key]['X_B'] = [y for (x, y, z) in sorted_corpus] 71 | self.sick_data[key]['y'] = [z for (x, y, z) in sorted_corpus] 72 | 73 | for txt_type in ['X_A', 'X_B']: 74 | sick_embed[key][txt_type] = [] 75 | for ii in range(0, len(self.sick_data[key]['y']), bsize): 76 | batch = self.sick_data[key][txt_type][ii:ii + bsize] 77 | embeddings = batcher(params, batch) 78 | sick_embed[key][txt_type].append(embeddings) 79 | sick_embed[key][txt_type] = np.vstack(sick_embed[key][txt_type]) 80 | sick_embed[key]['y'] = np.array(self.sick_data[key]['y']) 81 | logging.info('Computed {0} embeddings'.format(key)) 82 | 83 | # Train 84 | trainA = sick_embed['train']['X_A'] 85 | trainB = sick_embed['train']['X_B'] 86 | trainF = np.c_[np.abs(trainA - trainB), trainA * trainB] 87 | trainY = self.encode_labels(self.sick_data['train']['y']) 88 | 89 | # Dev 90 | devA = sick_embed['dev']['X_A'] 91 | devB = sick_embed['dev']['X_B'] 92 | devF = np.c_[np.abs(devA - devB), devA * devB] 93 | devY = self.encode_labels(self.sick_data['dev']['y']) 94 | 95 | # Test 96 | testA = sick_embed['test']['X_A'] 97 | testB = sick_embed['test']['X_B'] 98 | testF = np.c_[np.abs(testA - testB), testA * testB] 99 | testY = self.encode_labels(self.sick_data['test']['y']) 100 | 101 | config = {'seed': self.seed, 'nclasses': 5} 102 | clf = RelatednessPytorch(train={'X': trainF, 'y': trainY}, 103 | valid={'X': devF, 'y': devY}, 104 | test={'X': testF, 'y': testY}, 105 | devscores=self.sick_data['dev']['y'], 106 | config=config) 107 | 108 | devspr, yhat = clf.run() 109 | 110 | pr = pearsonr(yhat, self.sick_data['test']['y'])[0] 111 | sr = spearmanr(yhat, self.sick_data['test']['y'])[0] 112 | pr = 0 if pr != pr else pr 113 | sr = 0 if sr != sr else sr 114 | se = 
mean_squared_error(yhat, self.sick_data['test']['y']) 115 | logging.debug('Dev : Spearman {0}'.format(devspr)) 116 | logging.debug('Test : Pearson {0} Spearman {1} MSE {2} \ 117 | for SICK Relatedness\n'.format(pr, sr, se)) 118 | 119 | return {'devspearman': devspr, 'pearson': pr, 'spearman': sr, 'mse': se, 120 | 'yhat': yhat, 'ndev': len(devA), 'ntest': len(testA)} 121 | 122 | def encode_labels(self, labels, nclass=5): 123 | """ 124 | Label encoding from Tree LSTM paper (Tai, Socher, Manning) 125 | """ 126 | Y = np.zeros((len(labels), nclass)).astype('float32') 127 | for j, y in enumerate(labels): 128 | for i in range(nclass): 129 | if i+1 == np.floor(y) + 1: 130 | Y[j, i] = y - np.floor(y) 131 | if i+1 == np.floor(y): 132 | Y[j, i] = np.floor(y) - y + 1 133 | return Y 134 | 135 | 136 | class SICKEntailmentEval(SICKEval): 137 | def __init__(self, task_path, seed=1111): 138 | logging.debug('***** Transfer task : SICK-Entailment*****\n\n') 139 | self.seed = seed 140 | train = self.loadFile(os.path.join(task_path, 'SICK_train.txt')) 141 | dev = self.loadFile(os.path.join(task_path, 'SICK_trial.txt')) 142 | test = self.loadFile(os.path.join(task_path, 'SICK_test_annotated.txt')) 143 | self.sick_data = {'train': train, 'dev': dev, 'test': test} 144 | 145 | def loadFile(self, fpath): 146 | label2id = {'CONTRADICTION': 0, 'NEUTRAL': 1, 'ENTAILMENT': 2} 147 | skipFirstLine = True 148 | sick_data = {'X_A': [], 'X_B': [], 'y': []} 149 | with io.open(fpath, 'r', encoding='utf-8') as f: 150 | for line in f: 151 | if skipFirstLine: 152 | skipFirstLine = False 153 | else: 154 | text = line.strip().split('\t') 155 | sick_data['X_A'].append(text[1].split()) 156 | sick_data['X_B'].append(text[2].split()) 157 | sick_data['y'].append(text[4]) 158 | sick_data['y'] = [label2id[s] for s in sick_data['y']] 159 | return sick_data 160 | 161 | def run(self, params, batcher): 162 | sick_embed = {'train': {}, 'dev': {}, 'test': {}} 163 | bsize = params.batch_size 164 | 165 | for key in 
self.sick_data: 166 | logging.info('Computing embedding for {0}'.format(key)) 167 | # Sort to reduce padding 168 | sorted_corpus = sorted(zip(self.sick_data[key]['X_A'], 169 | self.sick_data[key]['X_B'], 170 | self.sick_data[key]['y']), 171 | key=lambda z: (len(z[0]), len(z[1]), z[2])) 172 | 173 | self.sick_data[key]['X_A'] = [x for (x, y, z) in sorted_corpus] 174 | self.sick_data[key]['X_B'] = [y for (x, y, z) in sorted_corpus] 175 | self.sick_data[key]['y'] = [z for (x, y, z) in sorted_corpus] 176 | 177 | for txt_type in ['X_A', 'X_B']: 178 | sick_embed[key][txt_type] = [] 179 | for ii in range(0, len(self.sick_data[key]['y']), bsize): 180 | batch = self.sick_data[key][txt_type][ii:ii + bsize] 181 | embeddings = batcher(params, batch) 182 | sick_embed[key][txt_type].append(embeddings) 183 | sick_embed[key][txt_type] = np.vstack(sick_embed[key][txt_type]) 184 | logging.info('Computed {0} embeddings'.format(key)) 185 | 186 | # Train 187 | trainA = sick_embed['train']['X_A'] 188 | trainB = sick_embed['train']['X_B'] 189 | trainF = np.c_[np.abs(trainA - trainB), trainA * trainB] 190 | trainY = np.array(self.sick_data['train']['y']) 191 | 192 | # Dev 193 | devA = sick_embed['dev']['X_A'] 194 | devB = sick_embed['dev']['X_B'] 195 | devF = np.c_[np.abs(devA - devB), devA * devB] 196 | devY = np.array(self.sick_data['dev']['y']) 197 | 198 | # Test 199 | testA = sick_embed['test']['X_A'] 200 | testB = sick_embed['test']['X_B'] 201 | testF = np.c_[np.abs(testA - testB), testA * testB] 202 | testY = np.array(self.sick_data['test']['y']) 203 | 204 | config = {'nclasses': 3, 'seed': self.seed, 205 | 'usepytorch': params.usepytorch, 206 | 'classifier': params.classifier, 207 | 'nhid': params.nhid} 208 | clf = SplitClassifier(X={'train': trainF, 'valid': devF, 'test': testF}, 209 | y={'train': trainY, 'valid': devY, 'test': testY}, 210 | config=config) 211 | 212 | devacc, testacc = clf.run() 213 | logging.debug('\nDev acc : {0} Test acc : {1} for \ 214 | SICK 
entailment\n'.format(devacc, testacc)) 215 | return {'devacc': devacc, 'acc': testacc, 216 | 'ndev': len(devA), 'ntest': len(testA)} 217 | -------------------------------------------------------------------------------- /SentEval/senteval/sts.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | ''' 9 | STS-{2012,2013,2014,2015,2016} (unsupervised) and 10 | STS-benchmark (supervised) tasks 11 | ''' 12 | 13 | from __future__ import absolute_import, division, unicode_literals 14 | 15 | import os 16 | import io 17 | import numpy as np 18 | import logging 19 | 20 | from scipy.stats import spearmanr, pearsonr 21 | 22 | from senteval.utils import cosine 23 | from senteval.sick import SICKEval 24 | 25 | 26 | class STSEval(object): 27 | def loadFile(self, fpath): 28 | self.data = {} 29 | self.samples = [] 30 | 31 | for dataset in self.datasets: 32 | sent1, sent2 = zip(*[l.split("\t") for l in 33 | io.open(fpath + '/STS.input.%s.txt' % dataset, 34 | encoding='utf8').read().splitlines()]) 35 | raw_scores = np.array([x for x in 36 | io.open(fpath + '/STS.gs.%s.txt' % dataset, 37 | encoding='utf8') 38 | .read().splitlines()]) 39 | not_empty_idx = raw_scores != '' 40 | 41 | gs_scores = [float(x) for x in raw_scores[not_empty_idx]] 42 | sent1 = np.array([s.split() for s in sent1])[not_empty_idx] 43 | sent2 = np.array([s.split() for s in sent2])[not_empty_idx] 44 | # sort data by length to minimize padding in batcher 45 | sorted_data = sorted(zip(sent1, sent2, gs_scores), 46 | key=lambda z: (len(z[0]), len(z[1]), z[2])) 47 | sent1, sent2, gs_scores = map(list, zip(*sorted_data)) 48 | 49 | self.data[dataset] = (sent1, sent2, gs_scores) 50 | self.samples += sent1 + sent2 51 | 52 | def do_prepare(self, params, prepare): 53 | if 
'similarity' in params: 54 | self.similarity = params.similarity 55 | else: # Default similarity is cosine 56 | self.similarity = lambda s1, s2: np.nan_to_num(cosine(np.nan_to_num(s1), np.nan_to_num(s2))) 57 | return prepare(params, self.samples) 58 | 59 | def run(self, params, batcher): 60 | results = {} 61 | all_sys_scores = [] 62 | all_gs_scores = [] 63 | for dataset in self.datasets: 64 | sys_scores = [] 65 | input1, input2, gs_scores = self.data[dataset] 66 | for ii in range(0, len(gs_scores), params.batch_size): 67 | batch1 = input1[ii:ii + params.batch_size] 68 | batch2 = input2[ii:ii + params.batch_size] 69 | 70 | # we assume get_batch already throws out the faulty ones 71 | if len(batch1) == len(batch2) and len(batch1) > 0: 72 | enc1 = batcher(params, batch1) 73 | enc2 = batcher(params, batch2) 74 | 75 | for kk in range(enc2.shape[0]): 76 | sys_score = self.similarity(enc1[kk], enc2[kk]) 77 | sys_scores.append(sys_score) 78 | all_sys_scores.extend(sys_scores) 79 | all_gs_scores.extend(gs_scores) 80 | 81 | results[dataset] = {'pearson': pearsonr(sys_scores, gs_scores), 82 | 'spearman': spearmanr(sys_scores, gs_scores), 83 | 'nsamples': len(sys_scores)} 84 | logging.debug('%s : pearson = %.4f, spearman = %.4f' % 85 | (dataset, results[dataset]['pearson'][0], 86 | results[dataset]['spearman'][0])) 87 | 88 | weights = [results[dset]['nsamples'] for dset in results.keys()] 89 | list_prs = np.array([results[dset]['pearson'][0] for 90 | dset in results.keys()]) 91 | list_spr = np.array([results[dset]['spearman'][0] for 92 | dset in results.keys()]) 93 | 94 | avg_pearson = np.average(list_prs) 95 | avg_spearman = np.average(list_spr) 96 | wavg_pearson = np.average(list_prs, weights=weights) 97 | wavg_spearman = np.average(list_spr, weights=weights) 98 | all_pearson = pearsonr(all_sys_scores, all_gs_scores) 99 | all_spearman = spearmanr(all_sys_scores, all_gs_scores) 100 | results['all'] = {'pearson': {'all': all_pearson[0], 101 | 'mean': avg_pearson, 102 | 
'wmean': wavg_pearson}, 103 | 'spearman': {'all': all_spearman[0], 104 | 'mean': avg_spearman, 105 | 'wmean': wavg_spearman}} 106 | logging.debug('ALL : Pearson = %.4f, \ 107 | Spearman = %.4f' % (all_pearson[0], all_spearman[0])) 108 | logging.debug('ALL (weighted average) : Pearson = %.4f, \ 109 | Spearman = %.4f' % (wavg_pearson, wavg_spearman)) 110 | logging.debug('ALL (average) : Pearson = %.4f, \ 111 | Spearman = %.4f\n' % (avg_pearson, avg_spearman)) 112 | 113 | return results 114 | 115 | 116 | class STS12Eval(STSEval): 117 | def __init__(self, taskpath, seed=1111): 118 | logging.debug('***** Transfer task : STS12 *****\n\n') 119 | self.seed = seed 120 | self.datasets = ['MSRpar', 'MSRvid', 'SMTeuroparl', 121 | 'surprise.OnWN', 'surprise.SMTnews'] 122 | self.loadFile(taskpath) 123 | 124 | 125 | class STS13Eval(STSEval): 126 | # STS13 here does not contain the "SMT" subtask due to LICENSE issue 127 | def __init__(self, taskpath, seed=1111): 128 | logging.debug('***** Transfer task : STS13 (-SMT) *****\n\n') 129 | self.seed = seed 130 | self.datasets = ['FNWN', 'headlines', 'OnWN'] 131 | self.loadFile(taskpath) 132 | 133 | 134 | class STS14Eval(STSEval): 135 | def __init__(self, taskpath, seed=1111): 136 | logging.debug('***** Transfer task : STS14 *****\n\n') 137 | self.seed = seed 138 | self.datasets = ['deft-forum', 'deft-news', 'headlines', 139 | 'images', 'OnWN', 'tweet-news'] 140 | self.loadFile(taskpath) 141 | 142 | 143 | class STS15Eval(STSEval): 144 | def __init__(self, taskpath, seed=1111): 145 | logging.debug('***** Transfer task : STS15 *****\n\n') 146 | self.seed = seed 147 | self.datasets = ['answers-forums', 'answers-students', 148 | 'belief', 'headlines', 'images'] 149 | self.loadFile(taskpath) 150 | 151 | 152 | class STS16Eval(STSEval): 153 | def __init__(self, taskpath, seed=1111): 154 | logging.debug('***** Transfer task : STS16 *****\n\n') 155 | self.seed = seed 156 | self.datasets = ['answer-answer', 'headlines', 'plagiarism', 157 | 
'postediting', 'question-question'] 158 | self.loadFile(taskpath) 159 | 160 | 161 | class STSBenchmarkEval(STSEval): 162 | def __init__(self, task_path, seed=1111): 163 | logging.debug('\n\n***** Transfer task : STSBenchmark*****\n\n') 164 | self.seed = seed 165 | self.samples = [] 166 | train = self.loadFile(os.path.join(task_path, 'sts-train.csv')) 167 | dev = self.loadFile(os.path.join(task_path, 'sts-dev.csv')) 168 | test = self.loadFile(os.path.join(task_path, 'sts-test.csv')) 169 | self.datasets = ['train', 'dev', 'test'] 170 | self.data = {'train': train, 'dev': dev, 'test': test} 171 | 172 | def loadFile(self, fpath): 173 | sick_data = {'X_A': [], 'X_B': [], 'y': []} 174 | with io.open(fpath, 'r', encoding='utf-8') as f: 175 | for line in f: 176 | text = line.strip().split('\t') 177 | sick_data['X_A'].append(text[5].split()) 178 | sick_data['X_B'].append(text[6].split()) 179 | sick_data['y'].append(text[4]) 180 | 181 | sick_data['y'] = [float(s) for s in sick_data['y']] 182 | self.samples += sick_data['X_A'] + sick_data["X_B"] 183 | return (sick_data['X_A'], sick_data["X_B"], sick_data['y']) 184 | 185 | class STSBenchmarkFinetune(SICKEval): 186 | def __init__(self, task_path, seed=1111): 187 | logging.debug('\n\n***** Transfer task : STSBenchmark*****\n\n') 188 | self.seed = seed 189 | train = self.loadFile(os.path.join(task_path, 'sts-train.csv')) 190 | dev = self.loadFile(os.path.join(task_path, 'sts-dev.csv')) 191 | test = self.loadFile(os.path.join(task_path, 'sts-test.csv')) 192 | self.sick_data = {'train': train, 'dev': dev, 'test': test} 193 | 194 | def loadFile(self, fpath): 195 | sick_data = {'X_A': [], 'X_B': [], 'y': []} 196 | with io.open(fpath, 'r', encoding='utf-8') as f: 197 | for line in f: 198 | text = line.strip().split('\t') 199 | sick_data['X_A'].append(text[5].split()) 200 | sick_data['X_B'].append(text[6].split()) 201 | sick_data['y'].append(text[4]) 202 | 203 | sick_data['y'] = [float(s) for s in sick_data['y']] 204 | return sick_data 
205 | 206 | class SICKRelatednessEval(STSEval): 207 | def __init__(self, task_path, seed=1111): 208 | logging.debug('\n\n***** Transfer task : SICKRelatedness*****\n\n') 209 | self.seed = seed 210 | self.samples = [] 211 | train = self.loadFile(os.path.join(task_path, 'SICK_train.txt')) 212 | dev = self.loadFile(os.path.join(task_path, 'SICK_trial.txt')) 213 | test = self.loadFile(os.path.join(task_path, 'SICK_test_annotated.txt')) 214 | self.datasets = ['train', 'dev', 'test'] 215 | self.data = {'train': train, 'dev': dev, 'test': test} 216 | 217 | def loadFile(self, fpath): 218 | skipFirstLine = True 219 | sick_data = {'X_A': [], 'X_B': [], 'y': []} 220 | with io.open(fpath, 'r', encoding='utf-8') as f: 221 | for line in f: 222 | if skipFirstLine: 223 | skipFirstLine = False 224 | else: 225 | text = line.strip().split('\t') 226 | sick_data['X_A'].append(text[1].split()) 227 | sick_data['X_B'].append(text[2].split()) 228 | sick_data['y'].append(text[3]) 229 | 230 | sick_data['y'] = [float(s) for s in sick_data['y']] 231 | self.samples += sick_data['X_A'] + sick_data["X_B"] 232 | return (sick_data['X_A'], sick_data["X_B"], sick_data['y']) 233 | -------------------------------------------------------------------------------- /SentEval/senteval/.ipynb_checkpoints/sts-checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | ''' 9 | STS-{2012,2013,2014,2015,2016} (unsupervised) and 10 | STS-benchmark (supervised) tasks 11 | ''' 12 | 13 | from __future__ import absolute_import, division, unicode_literals 14 | 15 | import os 16 | import io 17 | import numpy as np 18 | import logging 19 | 20 | from scipy.stats import spearmanr, pearsonr 21 | 22 | from senteval.utils import cosine 23 | from senteval.sick import SICKEval 24 | 25 | 26 | class STSEval(object): 27 | def loadFile(self, fpath): 28 | self.data = {} 29 | self.samples = [] 30 | 31 | for dataset in self.datasets: 32 | sent1, sent2 = zip(*[l.split("\t") for l in 33 | io.open(fpath + '/STS.input.%s.txt' % dataset, 34 | encoding='utf8').read().splitlines()]) 35 | raw_scores = np.array([x for x in 36 | io.open(fpath + '/STS.gs.%s.txt' % dataset, 37 | encoding='utf8') 38 | .read().splitlines()]) 39 | not_empty_idx = raw_scores != '' 40 | 41 | gs_scores = [float(x) for x in raw_scores[not_empty_idx]] 42 | sent1 = np.array([s.split() for s in sent1])[not_empty_idx] 43 | sent2 = np.array([s.split() for s in sent2])[not_empty_idx] 44 | # sort data by length to minimize padding in batcher 45 | sorted_data = sorted(zip(sent1, sent2, gs_scores), 46 | key=lambda z: (len(z[0]), len(z[1]), z[2])) 47 | sent1, sent2, gs_scores = map(list, zip(*sorted_data)) 48 | 49 | self.data[dataset] = (sent1, sent2, gs_scores) 50 | self.samples += sent1 + sent2 51 | 52 | def do_prepare(self, params, prepare): 53 | if 'similarity' in params: 54 | self.similarity = params.similarity 55 | else: # Default similarity is cosine 56 | self.similarity = lambda s1, s2: np.nan_to_num(cosine(np.nan_to_num(s1), np.nan_to_num(s2))) 57 | return prepare(params, self.samples) 58 | 59 | def run(self, params, batcher): 60 | results = {} 61 | all_sys_scores = [] 62 | all_gs_scores = [] 63 | for dataset in self.datasets: 64 | sys_scores = [] 65 | input1, input2, gs_scores = self.data[dataset] 66 | for ii in range(0, len(gs_scores), params.batch_size): 67 | batch1 = 
input1[ii:ii + params.batch_size] 68 | batch2 = input2[ii:ii + params.batch_size] 69 | 70 | # we assume get_batch already throws out the faulty ones 71 | if len(batch1) == len(batch2) and len(batch1) > 0: 72 | enc1 = batcher(params, batch1) 73 | enc2 = batcher(params, batch2) 74 | 75 | for kk in range(enc2.shape[0]): 76 | sys_score = self.similarity(enc1[kk], enc2[kk]) 77 | sys_scores.append(sys_score) 78 | all_sys_scores.extend(sys_scores) 79 | all_gs_scores.extend(gs_scores) 80 | 81 | results[dataset] = {'pearson': pearsonr(sys_scores, gs_scores), 82 | 'spearman': spearmanr(sys_scores, gs_scores), 83 | 'nsamples': len(sys_scores)} 84 | logging.debug('%s : pearson = %.4f, spearman = %.4f' % 85 | (dataset, results[dataset]['pearson'][0], 86 | results[dataset]['spearman'][0])) 87 | 88 | weights = [results[dset]['nsamples'] for dset in results.keys()] 89 | list_prs = np.array([results[dset]['pearson'][0] for 90 | dset in results.keys()]) 91 | list_spr = np.array([results[dset]['spearman'][0] for 92 | dset in results.keys()]) 93 | 94 | avg_pearson = np.average(list_prs) 95 | avg_spearman = np.average(list_spr) 96 | wavg_pearson = np.average(list_prs, weights=weights) 97 | wavg_spearman = np.average(list_spr, weights=weights) 98 | all_pearson = pearsonr(all_sys_scores, all_gs_scores) 99 | all_spearman = spearmanr(all_sys_scores, all_gs_scores) 100 | results['all'] = {'pearson': {'all': all_pearson[0], 101 | 'mean': avg_pearson, 102 | 'wmean': wavg_pearson}, 103 | 'spearman': {'all': all_spearman[0], 104 | 'mean': avg_spearman, 105 | 'wmean': wavg_spearman}} 106 | logging.debug('ALL : Pearson = %.4f, \ 107 | Spearman = %.4f' % (all_pearson[0], all_spearman[0])) 108 | logging.debug('ALL (weighted average) : Pearson = %.4f, \ 109 | Spearman = %.4f' % (wavg_pearson, wavg_spearman)) 110 | logging.debug('ALL (average) : Pearson = %.4f, \ 111 | Spearman = %.4f\n' % (avg_pearson, avg_spearman)) 112 | 113 | return results 114 | 115 | 116 | class STS12Eval(STSEval): 117 | def 
__init__(self, taskpath, seed=1111): 118 | logging.debug('***** Transfer task : STS12 *****\n\n') 119 | self.seed = seed 120 | self.datasets = ['MSRpar', 'MSRvid', 'SMTeuroparl', 121 | 'surprise.OnWN', 'surprise.SMTnews'] 122 | self.loadFile(taskpath) 123 | 124 | 125 | class STS13Eval(STSEval): 126 | # STS13 here does not contain the "SMT" subtask due to LICENSE issue 127 | def __init__(self, taskpath, seed=1111): 128 | logging.debug('***** Transfer task : STS13 (-SMT) *****\n\n') 129 | self.seed = seed 130 | self.datasets = ['FNWN', 'headlines', 'OnWN'] 131 | self.loadFile(taskpath) 132 | 133 | 134 | class STS14Eval(STSEval): 135 | def __init__(self, taskpath, seed=1111): 136 | logging.debug('***** Transfer task : STS14 *****\n\n') 137 | self.seed = seed 138 | self.datasets = ['deft-forum', 'deft-news', 'headlines', 139 | 'images', 'OnWN', 'tweet-news'] 140 | self.loadFile(taskpath) 141 | 142 | 143 | class STS15Eval(STSEval): 144 | def __init__(self, taskpath, seed=1111): 145 | logging.debug('***** Transfer task : STS15 *****\n\n') 146 | self.seed = seed 147 | self.datasets = ['answers-forums', 'answers-students', 148 | 'belief', 'headlines', 'images'] 149 | self.loadFile(taskpath) 150 | 151 | 152 | class STS16Eval(STSEval): 153 | def __init__(self, taskpath, seed=1111): 154 | logging.debug('***** Transfer task : STS16 *****\n\n') 155 | self.seed = seed 156 | self.datasets = ['answer-answer', 'headlines', 'plagiarism', 157 | 'postediting', 'question-question'] 158 | self.loadFile(taskpath) 159 | 160 | 161 | class STSBenchmarkEval(STSEval): 162 | def __init__(self, task_path, seed=1111): 163 | logging.debug('\n\n***** Transfer task : STSBenchmark*****\n\n') 164 | self.seed = seed 165 | self.samples = [] 166 | train = self.loadFile(os.path.join(task_path, 'sts-train.csv')) 167 | dev = self.loadFile(os.path.join(task_path, 'sts-dev.csv')) 168 | test = self.loadFile(os.path.join(task_path, 'sts-test.csv')) 169 | self.datasets = ['train', 'dev', 'test'] 170 | self.data 
= {'train': train, 'dev': dev, 'test': test} 171 | 172 | def loadFile(self, fpath): 173 | sick_data = {'X_A': [], 'X_B': [], 'y': []} 174 | with io.open(fpath, 'r', encoding='utf-8') as f: 175 | for line in f: 176 | text = line.strip().split('\t') 177 | sick_data['X_A'].append(text[5].split()) 178 | sick_data['X_B'].append(text[6].split()) 179 | sick_data['y'].append(text[4]) 180 | 181 | sick_data['y'] = [float(s) for s in sick_data['y']] 182 | self.samples += sick_data['X_A'] + sick_data["X_B"] 183 | return (sick_data['X_A'], sick_data["X_B"], sick_data['y']) 184 | 185 | class STSBenchmarkFinetune(SICKEval): 186 | def __init__(self, task_path, seed=1111): 187 | logging.debug('\n\n***** Transfer task : STSBenchmark*****\n\n') 188 | self.seed = seed 189 | train = self.loadFile(os.path.join(task_path, 'sts-train.csv')) 190 | dev = self.loadFile(os.path.join(task_path, 'sts-dev.csv')) 191 | test = self.loadFile(os.path.join(task_path, 'sts-test.csv')) 192 | self.sick_data = {'train': train, 'dev': dev, 'test': test} 193 | 194 | def loadFile(self, fpath): 195 | sick_data = {'X_A': [], 'X_B': [], 'y': []} 196 | with io.open(fpath, 'r', encoding='utf-8') as f: 197 | for line in f: 198 | text = line.strip().split('\t') 199 | sick_data['X_A'].append(text[5].split()) 200 | sick_data['X_B'].append(text[6].split()) 201 | sick_data['y'].append(text[4]) 202 | 203 | sick_data['y'] = [float(s) for s in sick_data['y']] 204 | return sick_data 205 | 206 | class SICKRelatednessEval(STSEval): 207 | def __init__(self, task_path, seed=1111): 208 | logging.debug('\n\n***** Transfer task : SICKRelatedness*****\n\n') 209 | self.seed = seed 210 | self.samples = [] 211 | train = self.loadFile(os.path.join(task_path, 'SICK_train.txt')) 212 | dev = self.loadFile(os.path.join(task_path, 'SICK_trial.txt')) 213 | test = self.loadFile(os.path.join(task_path, 'SICK_test_annotated.txt')) 214 | self.datasets = ['train', 'dev', 'test'] 215 | self.data = {'train': train, 'dev': dev, 'test': test} 216 | 
217 | def loadFile(self, fpath): 218 | skipFirstLine = True 219 | sick_data = {'X_A': [], 'X_B': [], 'y': []} 220 | with io.open(fpath, 'r', encoding='utf-8') as f: 221 | for line in f: 222 | if skipFirstLine: 223 | skipFirstLine = False 224 | else: 225 | text = line.strip().split('\t') 226 | sick_data['X_A'].append(text[1].split()) 227 | sick_data['X_B'].append(text[2].split()) 228 | sick_data['y'].append(text[3]) 229 | 230 | sick_data['y'] = [float(s) for s in sick_data['y']] 231 | self.samples += sick_data['X_A'] + sick_data["X_B"] 232 | return (sick_data['X_A'], sick_data["X_B"], sick_data['y']) 233 | -------------------------------------------------------------------------------- /SentEval/examples/models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | """ 9 | This file contains the definition of encoders used in https://arxiv.org/pdf/1705.02364.pdf 10 | """ 11 | 12 | import numpy as np 13 | import time 14 | 15 | import torch 16 | import torch.nn as nn 17 | 18 | 19 | class InferSent(nn.Module): 20 | 21 | def __init__(self, config): 22 | super(InferSent, self).__init__() 23 | self.bsize = config['bsize'] 24 | self.word_emb_dim = config['word_emb_dim'] 25 | self.enc_lstm_dim = config['enc_lstm_dim'] 26 | self.pool_type = config['pool_type'] 27 | self.dpout_model = config['dpout_model'] 28 | self.version = 1 if 'version' not in config else config['version'] 29 | 30 | self.enc_lstm = nn.LSTM(self.word_emb_dim, self.enc_lstm_dim, 1, 31 | bidirectional=True, dropout=self.dpout_model) 32 | 33 | assert self.version in [1, 2] 34 | if self.version == 1: 35 | self.bos = '<s>' 36 | self.eos = '</s>' 37 | self.max_pad = True 38 | self.moses_tok = False 39 | elif self.version == 2: 40 | self.bos = '<p>' 41 | self.eos = '</p>' 42 | self.max_pad = False 43 | self.moses_tok = True 44 | 45 | def is_cuda(self): 46 | # either all weights are on cpu or they are on gpu 47 | return self.enc_lstm.bias_hh_l0.data.is_cuda 48 | 49 | def forward(self, sent_tuple): 50 | # sent_len: [max_len, ..., min_len] (bsize) 51 | # sent: (seqlen x bsize x worddim) 52 | sent, sent_len = sent_tuple 53 | 54 | # Sort by length (keep idx) 55 | sent_len_sorted, idx_sort = np.sort(sent_len)[::-1], np.argsort(-sent_len) 56 | sent_len_sorted = sent_len_sorted.copy() 57 | idx_unsort = np.argsort(idx_sort) 58 | 59 | idx_sort = torch.from_numpy(idx_sort).cuda() if self.is_cuda() \ 60 | else torch.from_numpy(idx_sort) 61 | sent = sent.index_select(1, idx_sort) 62 | 63 | # Handling padding in Recurrent Networks 64 | sent_packed = nn.utils.rnn.pack_padded_sequence(sent, sent_len_sorted) 65 | sent_output = self.enc_lstm(sent_packed)[0] # seqlen x batch x 2*nhid 66 | sent_output = nn.utils.rnn.pad_packed_sequence(sent_output)[0] 67 | 68 | # Un-sort by length 69 | idx_unsort = torch.from_numpy(idx_unsort).cuda() if self.is_cuda() \ 70 | else torch.from_numpy(idx_unsort) 71 | sent_output = sent_output.index_select(1, idx_unsort) 72 | 73 | # Pooling 74 | if self.pool_type == "mean": 75 | sent_len = torch.FloatTensor(sent_len.copy()).unsqueeze(1).cuda() 76 | emb = torch.sum(sent_output, 0).squeeze(0) 77 | emb = emb / sent_len.expand_as(emb) 78 | elif self.pool_type == "max": 79 | if not self.max_pad: 80 | sent_output[sent_output == 0] = -1e9 81 | emb = torch.max(sent_output, 0)[0] 82 | if emb.ndimension() == 3: 83 | emb = emb.squeeze(0) 84 | assert emb.ndimension() == 2 85 | 86 | return emb 87 | 88 | def set_w2v_path(self, w2v_path): 89 | self.w2v_path = w2v_path 90 | 91 | def get_word_dict(self, sentences, tokenize=True): 92 | # create vocab of words 93 | word_dict = {} 94 | sentences = [s.split() if not tokenize else self.tokenize(s) for s in sentences] 95 | for sent in sentences: 96 | for word in sent: 97 | if word not in
word_dict: 98 | word_dict[word] = '' 99 | word_dict[self.bos] = '' 100 | word_dict[self.eos] = '' 101 | return word_dict 102 | 103 | def get_w2v(self, word_dict): 104 | assert hasattr(self, 'w2v_path'), 'w2v path not set' 105 | # create word_vec with w2v vectors 106 | word_vec = {} 107 | with open(self.w2v_path, encoding='utf-8') as f: 108 | for line in f: 109 | word, vec = line.split(' ', 1) 110 | if word in word_dict: 111 | word_vec[word] = np.fromstring(vec, sep=' ') 112 | print('Found %s(/%s) words with w2v vectors' % (len(word_vec), len(word_dict))) 113 | return word_vec 114 | 115 | def get_w2v_k(self, K): 116 | assert hasattr(self, 'w2v_path'), 'w2v path not set' 117 | # create word_vec with k first w2v vectors 118 | k = 0 119 | word_vec = {} 120 | with open(self.w2v_path, encoding='utf-8') as f: 121 | for line in f: 122 | word, vec = line.split(' ', 1) 123 | if k <= K: 124 | word_vec[word] = np.fromstring(vec, sep=' ') 125 | k += 1 126 | if k > K: 127 | if word in [self.bos, self.eos]: 128 | word_vec[word] = np.fromstring(vec, sep=' ') 129 | 130 | if k > K and all([w in word_vec for w in [self.bos, self.eos]]): 131 | break 132 | return word_vec 133 | 134 | def build_vocab(self, sentences, tokenize=True): 135 | assert hasattr(self, 'w2v_path'), 'w2v path not set' 136 | word_dict = self.get_word_dict(sentences, tokenize) 137 | self.word_vec = self.get_w2v(word_dict) 138 | print('Vocab size : %s' % (len(self.word_vec))) 139 | 140 | # build w2v vocab with k most frequent words 141 | def build_vocab_k_words(self, K): 142 | assert hasattr(self, 'w2v_path'), 'w2v path not set' 143 | self.word_vec = self.get_w2v_k(K) 144 | print('Vocab size : %s' % (K)) 145 | 146 | def update_vocab(self, sentences, tokenize=True): 147 | assert hasattr(self, 'w2v_path'), 'warning : w2v path not set' 148 | assert hasattr(self, 'word_vec'), 'build_vocab before updating it' 149 | word_dict = self.get_word_dict(sentences, tokenize) 150 | 151 | # keep only new words 152 | for word in 
self.word_vec: 153 | if word in word_dict: 154 | del word_dict[word] 155 | 156 | # update vocabulary 157 | if word_dict: 158 | new_word_vec = self.get_w2v(word_dict) 159 | self.word_vec.update(new_word_vec) 160 | else: 161 | new_word_vec = [] 162 | print('New vocab size : %s (added %s words)'% (len(self.word_vec), len(new_word_vec))) 163 | 164 | def get_batch(self, batch): 165 | # sent in batch in decreasing order of lengths 166 | # batch: (bsize, max_len, word_dim) 167 | embed = np.zeros((len(batch[0]), len(batch), self.word_emb_dim)) 168 | 169 | for i in range(len(batch)): 170 | for j in range(len(batch[i])): 171 | embed[j, i, :] = self.word_vec[batch[i][j]] 172 | 173 | return torch.FloatTensor(embed) 174 | 175 | def tokenize(self, s): 176 | from nltk.tokenize import word_tokenize 177 | if self.moses_tok: 178 | s = ' '.join(word_tokenize(s)) 179 | s = s.replace(" n't ", "n 't ") # HACK to get ~MOSES tokenization 180 | return s.split() 181 | else: 182 | return word_tokenize(s) 183 | 184 | def prepare_samples(self, sentences, bsize, tokenize, verbose): 185 | sentences = [[self.bos] + s.split() + [self.eos] if not tokenize else 186 | [self.bos] + self.tokenize(s) + [self.eos] for s in sentences] 187 | n_w = np.sum([len(x) for x in sentences]) 188 | 189 | # filters words without w2v vectors 190 | for i in range(len(sentences)): 191 | s_f = [word for word in sentences[i] if word in self.word_vec] 192 | if not s_f: 193 | import warnings 194 | warnings.warn('No words in "%s" (idx=%s) have w2v vectors. \ 195 | Replacing by "</s>"..' 
% (sentences[i], i)) 196 | s_f = [self.eos] 197 | sentences[i] = s_f 198 | 199 | lengths = np.array([len(s) for s in sentences]) 200 | n_wk = np.sum(lengths) 201 | if verbose: 202 | print('Nb words kept : %s/%s (%.1f%s)' % ( 203 | n_wk, n_w, 100.0 * n_wk / n_w, '%')) 204 | 205 | # sort by decreasing length 206 | lengths, idx_sort = np.sort(lengths)[::-1], np.argsort(-lengths) 207 | sentences = np.array(sentences)[idx_sort] 208 | 209 | return sentences, lengths, idx_sort 210 | 211 | def encode(self, sentences, bsize=64, tokenize=True, verbose=False): 212 | tic = time.time() 213 | sentences, lengths, idx_sort = self.prepare_samples( 214 | sentences, bsize, tokenize, verbose) 215 | 216 | embeddings = [] 217 | for stidx in range(0, len(sentences), bsize): 218 | batch = self.get_batch(sentences[stidx:stidx + bsize]) 219 | if self.is_cuda(): 220 | batch = batch.cuda() 221 | with torch.no_grad(): 222 | batch = self.forward((batch, lengths[stidx:stidx + bsize])).data.cpu().numpy() 223 | embeddings.append(batch) 224 | embeddings = np.vstack(embeddings) 225 | 226 | # unsort 227 | idx_unsort = np.argsort(idx_sort) 228 | embeddings = embeddings[idx_unsort] 229 | 230 | if verbose: 231 | print('Speed : %.1f sentences/s (%s mode, bsize=%s)' % ( 232 | len(embeddings)/(time.time()-tic), 233 | 'gpu' if self.is_cuda() else 'cpu', bsize)) 234 | return embeddings 235 | 236 | def visualize(self, sent, tokenize=True): 237 | 238 | sent = sent.split() if not tokenize else self.tokenize(sent) 239 | sent = [[self.bos] + [word for word in sent if word in self.word_vec] + [self.eos]] 240 | 241 | if ' '.join(sent[0]) == '%s %s' % (self.bos, self.eos): 242 | import warnings 243 | warnings.warn('No words in "%s" have w2v vectors. Replacing \ 244 | by "%s %s"..' 
% (sent, self.bos, self.eos)) 245 | batch = self.get_batch(sent) 246 | 247 | if self.is_cuda(): 248 | batch = batch.cuda() 249 | output = self.enc_lstm(batch)[0] 250 | output, idxs = torch.max(output, 0) 251 | # output, idxs = output.squeeze(), idxs.squeeze() 252 | idxs = idxs.data.cpu().numpy() 253 | argmaxs = [np.sum((idxs == k)) for k in range(len(sent[0]))] 254 | 255 | # visualize model 256 | import matplotlib.pyplot as plt 257 | x = range(len(sent[0])) 258 | y = [100.0 * n / np.sum(argmaxs) for n in argmaxs] 259 | plt.xticks(x, sent[0], rotation=45) 260 | plt.bar(x, y) 261 | plt.ylabel('%') 262 | plt.title('Visualisation of words importance') 263 | plt.show() 264 | 265 | return output, idxs 266 | -------------------------------------------------------------------------------- /SentEval/senteval/tools/validation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | """ 9 | Validation and classification 10 | (train) : inner-kfold classifier 11 | (train, test) : kfold classifier 12 | (train, dev, test) : split classifier 13 | 14 | """ 15 | from __future__ import absolute_import, division, unicode_literals 16 | 17 | import logging 18 | import numpy as np 19 | from senteval.tools.classifier import MLP 20 | 21 | import sklearn 22 | assert(sklearn.__version__ >= "0.18.0"), \ 23 | "need to update sklearn to version >= 0.18.0" 24 | from sklearn.linear_model import LogisticRegression 25 | from sklearn.model_selection import StratifiedKFold 26 | 27 | 28 | def get_classif_name(classifier_config, usepytorch): 29 | if not usepytorch: 30 | modelname = 'sklearn-LogReg' 31 | else: 32 | nhid = classifier_config['nhid'] 33 | optim = 'adam' if 'optim' not in classifier_config else classifier_config['optim'] 34 | bs = 64 if 'batch_size' not in classifier_config else classifier_config['batch_size'] 35 | modelname = 'pytorch-MLP-nhid%s-%s-bs%s' % (nhid, optim, bs) 36 | return modelname 37 | 38 | # Pytorch version 39 | class InnerKFoldClassifier(object): 40 | """ 41 | (train) split classifier : InnerKfold. 
42 | """ 43 | def __init__(self, X, y, config): 44 | self.X = X 45 | self.y = y 46 | self.featdim = X.shape[1] 47 | self.nclasses = config['nclasses'] 48 | self.seed = config['seed'] 49 | self.devresults = [] 50 | self.testresults = [] 51 | self.usepytorch = config['usepytorch'] 52 | self.classifier_config = config['classifier'] 53 | self.modelname = get_classif_name(self.classifier_config, self.usepytorch) 54 | 55 | self.k = 5 if 'kfold' not in config else config['kfold'] 56 | 57 | def run(self): 58 | logging.info('Training {0} with (inner) {1}-fold cross-validation' 59 | .format(self.modelname, self.k)) 60 | 61 | regs = [10**t for t in range(-5, -1)] if self.usepytorch else \ 62 | [2**t for t in range(-2, 4, 1)] 63 | skf = StratifiedKFold(n_splits=self.k, shuffle=True, random_state=1111) 64 | innerskf = StratifiedKFold(n_splits=self.k, shuffle=True, 65 | random_state=1111) 66 | count = 0 67 | for train_idx, test_idx in skf.split(self.X, self.y): 68 | count += 1 69 | X_train, X_test = self.X[train_idx], self.X[test_idx] 70 | y_train, y_test = self.y[train_idx], self.y[test_idx] 71 | scores = [] 72 | for reg in regs: 73 | regscores = [] 74 | for inner_train_idx, inner_test_idx in innerskf.split(X_train, y_train): 75 | X_in_train, X_in_test = X_train[inner_train_idx], X_train[inner_test_idx] 76 | y_in_train, y_in_test = y_train[inner_train_idx], y_train[inner_test_idx] 77 | if self.usepytorch: 78 | clf = MLP(self.classifier_config, inputdim=self.featdim, 79 | nclasses=self.nclasses, l2reg=reg, 80 | seed=self.seed) 81 | clf.fit(X_in_train, y_in_train, 82 | validation_data=(X_in_test, y_in_test)) 83 | else: 84 | clf = LogisticRegression(C=reg, random_state=self.seed) 85 | clf.fit(X_in_train, y_in_train) 86 | regscores.append(clf.score(X_in_test, y_in_test)) 87 | scores.append(round(100*np.mean(regscores), 2)) 88 | optreg = regs[np.argmax(scores)] 89 | logging.info('Best param found at split {0}: l2reg = {1} \ 90 | with score {2}'.format(count, optreg, np.max(scores))) 
91 | self.devresults.append(np.max(scores)) 92 | 93 | if self.usepytorch: 94 | clf = MLP(self.classifier_config, inputdim=self.featdim, 95 | nclasses=self.nclasses, l2reg=optreg, 96 | seed=self.seed) 97 | 98 | clf.fit(X_train, y_train, validation_split=0.05) 99 | else: 100 | clf = LogisticRegression(C=optreg, random_state=self.seed) 101 | clf.fit(X_train, y_train) 102 | 103 | self.testresults.append(round(100*clf.score(X_test, y_test), 2)) 104 | 105 | devaccuracy = round(np.mean(self.devresults), 2) 106 | testaccuracy = round(np.mean(self.testresults), 2) 107 | return devaccuracy, testaccuracy 108 | 109 | 110 | class KFoldClassifier(object): 111 | """ 112 | (train, test) split classifier : cross-validation on train. 113 | """ 114 | def __init__(self, train, test, config): 115 | self.train = train 116 | self.test = test 117 | self.featdim = self.train['X'].shape[1] 118 | self.nclasses = config['nclasses'] 119 | self.seed = config['seed'] 120 | self.usepytorch = config['usepytorch'] 121 | self.classifier_config = config['classifier'] 122 | self.modelname = get_classif_name(self.classifier_config, self.usepytorch) 123 | 124 | self.k = 5 if 'kfold' not in config else config['kfold'] 125 | 126 | def run(self): 127 | # cross-validation 128 | logging.info('Training {0} with {1}-fold cross-validation' 129 | .format(self.modelname, self.k)) 130 | regs = [10**t for t in range(-5, -1)] if self.usepytorch else \ 131 | [2**t for t in range(-1, 6, 1)] 132 | skf = StratifiedKFold(n_splits=self.k, shuffle=True, 133 | random_state=self.seed) 134 | scores = [] 135 | 136 | for reg in regs: 137 | scanscores = [] 138 | for train_idx, test_idx in skf.split(self.train['X'], 139 | self.train['y']): 140 | # Split data 141 | X_train, y_train = self.train['X'][train_idx], self.train['y'][train_idx] 142 | 143 | X_test, y_test = self.train['X'][test_idx], self.train['y'][test_idx] 144 | 145 | # Train classifier 146 | if self.usepytorch: 147 | clf = MLP(self.classifier_config, 
inputdim=self.featdim, 148 | nclasses=self.nclasses, l2reg=reg, 149 | seed=self.seed) 150 | clf.fit(X_train, y_train, validation_data=(X_test, y_test)) 151 | else: 152 | clf = LogisticRegression(C=reg, random_state=self.seed) 153 | clf.fit(X_train, y_train) 154 | score = clf.score(X_test, y_test) 155 | scanscores.append(score) 156 | # Append mean score 157 | scores.append(round(100*np.mean(scanscores), 2)) 158 | 159 | # evaluation 160 | logging.info([('reg:' + str(regs[idx]), scores[idx]) 161 | for idx in range(len(scores))]) 162 | optreg = regs[np.argmax(scores)] 163 | devaccuracy = np.max(scores) 164 | logging.info('Cross-validation : best param found is reg = {0} \ 165 | with score {1}'.format(optreg, devaccuracy)) 166 | 167 | logging.info('Evaluating...') 168 | if self.usepytorch: 169 | clf = MLP(self.classifier_config, inputdim=self.featdim, 170 | nclasses=self.nclasses, l2reg=optreg, 171 | seed=self.seed) 172 | clf.fit(self.train['X'], self.train['y'], validation_split=0.05) 173 | else: 174 | clf = LogisticRegression(C=optreg, random_state=self.seed) 175 | clf.fit(self.train['X'], self.train['y']) 176 | yhat = clf.predict(self.test['X']) 177 | 178 | testaccuracy = clf.score(self.test['X'], self.test['y']) 179 | testaccuracy = round(100*testaccuracy, 2) 180 | 181 | return devaccuracy, testaccuracy, yhat 182 | 183 | 184 | class SplitClassifier(object): 185 | """ 186 | (train, valid, test) split classifier. 
187 | """ 188 | def __init__(self, X, y, config): 189 | self.X = X 190 | self.y = y 191 | self.nclasses = config['nclasses'] 192 | self.featdim = self.X['train'].shape[1] 193 | self.seed = config['seed'] 194 | self.usepytorch = config['usepytorch'] 195 | self.classifier_config = config['classifier'] 196 | self.cudaEfficient = False if 'cudaEfficient' not in config else \ 197 | config['cudaEfficient'] 198 | self.modelname = get_classif_name(self.classifier_config, self.usepytorch) 199 | self.noreg = False if 'noreg' not in config else config['noreg'] 200 | self.config = config 201 | 202 | def run(self): 203 | logging.info('Training {0} with standard validation..' 204 | .format(self.modelname)) 205 | regs = [10**t for t in range(-5, -1)] if self.usepytorch else \ 206 | [2**t for t in range(-2, 4, 1)] 207 | if self.noreg: 208 | regs = [1e-9 if self.usepytorch else 1e9] 209 | scores = [] 210 | for reg in regs: 211 | if self.usepytorch: 212 | clf = MLP(self.classifier_config, inputdim=self.featdim, 213 | nclasses=self.nclasses, l2reg=reg, 214 | seed=self.seed, cudaEfficient=self.cudaEfficient) 215 | 216 | # TODO: Find a hack for reducing nb epoches in SNLI 217 | clf.fit(self.X['train'], self.y['train'], 218 | validation_data=(self.X['valid'], self.y['valid'])) 219 | else: 220 | clf = LogisticRegression(C=reg, random_state=self.seed) 221 | clf.fit(self.X['train'], self.y['train']) 222 | scores.append(round(100*clf.score(self.X['valid'], 223 | self.y['valid']), 2)) 224 | logging.info([('reg:'+str(regs[idx]), scores[idx]) 225 | for idx in range(len(scores))]) 226 | optreg = regs[np.argmax(scores)] 227 | devaccuracy = np.max(scores) 228 | logging.info('Validation : best param found is reg = {0} with score \ 229 | {1}'.format(optreg, devaccuracy)) 230 | clf = LogisticRegression(C=optreg, random_state=self.seed) 231 | logging.info('Evaluating...') 232 | if self.usepytorch: 233 | clf = MLP(self.classifier_config, inputdim=self.featdim, 234 | nclasses=self.nclasses, 
l2reg=optreg, 235 | seed=self.seed, cudaEfficient=self.cudaEfficient) 236 | 237 | # TODO: Find a hack for reducing nb epoches in SNLI 238 | clf.fit(self.X['train'], self.y['train'], 239 | validation_data=(self.X['valid'], self.y['valid'])) 240 | else: 241 | clf = LogisticRegression(C=optreg, random_state=self.seed) 242 | clf.fit(self.X['train'], self.y['train']) 243 | 244 | testaccuracy = clf.score(self.X['test'], self.y['test']) 245 | testaccuracy = round(100*testaccuracy, 2) 246 | return devaccuracy, testaccuracy 247 | -------------------------------------------------------------------------------- /sscl/tool.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from tqdm import tqdm 3 | import numpy as np 4 | from numpy import ndarray 5 | import torch 6 | from torch import Tensor, device 7 | import transformers 8 | from transformers import AutoModel, AutoTokenizer 9 | from sklearn.metrics.pairwise import cosine_similarity 10 | from sklearn.preprocessing import normalize 11 | from typing import List, Dict, Tuple, Type, Union 12 | 13 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S', 14 | level=logging.INFO) 15 | logger = logging.getLogger(__name__) 16 | 17 | class SSCL(object): 18 | """ 19 | A class for embedding sentences, calculating similarities, and retrieving sentences by SimCSE. 
20 | """ 21 | def __init__(self, model_name_or_path: str, 22 | device: str = None, 23 | num_cells: int = 100, 24 | num_cells_in_search: int = 10, 25 | pooler = None): 26 | 27 | self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) 28 | self.model = AutoModel.from_pretrained(model_name_or_path) 29 | if device is None: 30 | device = "cuda" if torch.cuda.is_available() else "cpu" 31 | self.device = device 32 | 33 | self.index = None 34 | self.is_faiss_index = False 35 | self.num_cells = num_cells 36 | self.num_cells_in_search = num_cells_in_search 37 | 38 | if pooler is not None: 39 | self.pooler = pooler 40 | elif "unsup" in model_name_or_path: 41 | logger.info("Use `cls_before_pooler` for unsupervised models. If you want to use other pooling policy, specify `pooler` argument.") 42 | self.pooler = "cls_before_pooler" 43 | else: 44 | self.pooler = "cls" 45 | 46 | def encode(self, sentence: Union[str, List[str]], 47 | device: str = None, 48 | return_numpy: bool = False, 49 | normalize_to_unit: bool = True, 50 | keepdim: bool = False, 51 | batch_size: int = 64, 52 | max_length: int = 128) -> Union[ndarray, Tensor]: 53 | 54 | target_device = self.device if device is None else device 55 | self.model = self.model.to(target_device) 56 | 57 | single_sentence = False 58 | if isinstance(sentence, str): 59 | sentence = [sentence] 60 | single_sentence = True 61 | 62 | embedding_list = [] 63 | with torch.no_grad(): 64 | total_batch = len(sentence) // batch_size + (1 if len(sentence) % batch_size > 0 else 0) 65 | for batch_id in tqdm(range(total_batch)): 66 | inputs = self.tokenizer( 67 | sentence[batch_id*batch_size:(batch_id+1)*batch_size], 68 | padding=True, 69 | truncation=True, 70 | max_length=max_length, 71 | return_tensors="pt" 72 | ) 73 | inputs = {k: v.to(target_device) for k, v in inputs.items()} 74 | outputs = self.model(**inputs, return_dict=True) 75 | if self.pooler == "cls": 76 | embeddings = outputs.pooler_output 77 | elif self.pooler == 
"cls_before_pooler": 78 | embeddings = outputs.last_hidden_state[:, 0] 79 | else: 80 | raise NotImplementedError 81 | if normalize_to_unit: 82 | embeddings = embeddings / embeddings.norm(dim=1, keepdim=True) 83 | embedding_list.append(embeddings.cpu()) 84 | embeddings = torch.cat(embedding_list, 0) 85 | 86 | if single_sentence and not keepdim: 87 | embeddings = embeddings[0] 88 | 89 | if return_numpy and not isinstance(embeddings, ndarray): 90 | return embeddings.numpy() 91 | return embeddings 92 | 93 | def similarity(self, queries: Union[str, List[str]], 94 | keys: Union[str, List[str], ndarray], 95 | device: str = None) -> Union[float, ndarray]: 96 | 97 | query_vecs = self.encode(queries, device=device, return_numpy=True) # suppose N queries 98 | 99 | if not isinstance(keys, ndarray): 100 | key_vecs = self.encode(keys, device=device, return_numpy=True) # suppose M keys 101 | else: 102 | key_vecs = keys 103 | 104 | # check whether N == 1 or M == 1 105 | single_query, single_key = len(query_vecs.shape) == 1, len(key_vecs.shape) == 1 106 | if single_query: 107 | query_vecs = query_vecs.reshape(1, -1) 108 | if single_key: 109 | key_vecs = key_vecs.reshape(1, -1) 110 | 111 | # returns an N*M similarity array 112 | similarities = cosine_similarity(query_vecs, key_vecs) 113 | 114 | if single_query: 115 | similarities = similarities[0] 116 | if single_key: 117 | similarities = float(similarities[0]) 118 | 119 | return similarities 120 | 121 | def build_index(self, sentences_or_file_path: Union[str, List[str]], 122 | use_faiss: bool = None, 123 | faiss_fast: bool = False, 124 | device: str = None, 125 | batch_size: int = 64): 126 | 127 | if use_faiss is None or use_faiss: 128 | try: 129 | import faiss 130 | assert hasattr(faiss, "IndexFlatIP") 131 | use_faiss = True 132 | except: 133 | logger.warning("Fail to import faiss. If you want to use faiss, install faiss through PyPI. 
Now the program continues with brute force search.") 134 | use_faiss = False 135 | 136 | # if the input sentence is a string, we assume it's the path of file that stores various sentences 137 | if isinstance(sentences_or_file_path, str): 138 | sentences = [] 139 | with open(sentences_or_file_path, "r") as f: 140 | logging.info("Loading sentences from %s ..." % (sentences_or_file_path)) 141 | for line in tqdm(f): 142 | sentences.append(line.rstrip()) 143 | sentences_or_file_path = sentences 144 | 145 | logger.info("Encoding embeddings for sentences...") 146 | embeddings = self.encode(sentences_or_file_path, device=device, batch_size=batch_size, normalize_to_unit=True, return_numpy=True) 147 | 148 | logger.info("Building index...") 149 | self.index = {"sentences": sentences_or_file_path} 150 | 151 | if use_faiss: 152 | quantizer = faiss.IndexFlatIP(embeddings.shape[1]) 153 | if faiss_fast: 154 | index = faiss.IndexIVFFlat(quantizer, embeddings.shape[1], min(self.num_cells, len(sentences_or_file_path))) 155 | else: 156 | index = quantizer 157 | 158 | if (self.device == "cuda" and device != "cpu") or device == "cuda": 159 | if hasattr(faiss, "StandardGpuResources"): 160 | logger.info("Use GPU-version faiss") 161 | res = faiss.StandardGpuResources() 162 | res.setTempMemory(20 * 1024 * 1024 * 1024) 163 | index = faiss.index_cpu_to_gpu(res, 0, index) 164 | else: 165 | logger.info("Use CPU-version faiss") 166 | else: 167 | logger.info("Use CPU-version faiss") 168 | 169 | if faiss_fast: 170 | index.train(embeddings.astype(np.float32)) 171 | index.add(embeddings.astype(np.float32)) 172 | index.nprobe = min(self.num_cells_in_search, len(sentences_or_file_path)) 173 | self.is_faiss_index = True 174 | else: 175 | index = embeddings 176 | self.is_faiss_index = False 177 | self.index["index"] = index 178 | logger.info("Finished") 179 | 180 | def search(self, queries: Union[str, List[str]], 181 | device: str = None, 182 | threshold: float = 0.6, 183 | top_k: int = 5) -> 
Union[List[Tuple[str, float]], List[List[Tuple[str, float]]]]: 184 | 185 | if not self.is_faiss_index: 186 | if isinstance(queries, list): 187 | combined_results = [] 188 | for query in queries: 189 | results = self.search(query, device, threshold, top_k) 190 | combined_results.append(results) 191 | return combined_results 192 | 193 | similarities = self.similarity(queries, self.index["index"]).tolist() 194 | id_and_score = [] 195 | for i, s in enumerate(similarities): 196 | if s >= threshold: 197 | id_and_score.append((i, s)) 198 | id_and_score = sorted(id_and_score, key=lambda x: x[1], reverse=True)[:top_k] 199 | results = [(self.index["sentences"][idx], score) for idx, score in id_and_score] 200 | return results 201 | else: 202 | query_vecs = self.encode(queries, device=device, normalize_to_unit=True, keepdim=True, return_numpy=True) 203 | 204 | distance, idx = self.index["index"].search(query_vecs.astype(np.float32), top_k) 205 | 206 | def pack_single_result(dist, idx): 207 | results = [(self.index["sentences"][i], s) for i, s in zip(idx, dist) if s >= threshold] 208 | return results 209 | 210 | if isinstance(queries, list): 211 | combined_results = [] 212 | for i in range(len(queries)): 213 | results = pack_single_result(distance[i], idx[i]) 214 | combined_results.append(results) 215 | return combined_results 216 | else: 217 | return pack_single_result(distance[0], idx[0]) 218 | 219 | if __name__=="__main__": 220 | example_sentences = [ 221 | 'An animal is biting a persons finger.', 222 | 'A woman is reading.', 223 | 'A man is lifting weights in a garage.', 224 | 'A man plays the violin.', 225 | 'A man is eating food.', 226 | 'A man plays the piano.', 227 | 'A panda is climbing.', 228 | 'A man plays a guitar.', 229 | 'A woman is slicing a meat.', 230 | 'A woman is taking a picture.' 231 | ] 232 | example_queries = [ 233 | 'A man is playing music.', 234 | 'A woman is making a photo.' 
235 | ] 236 | 237 | model_name = "princeton-nlp/sup-simcse-bert-base-uncased" 238 | simcse = SSCL(model_name) 239 | 240 | print("\n=========Calculate cosine similarities between queries and sentences============\n") 241 | similarities = simcse.similarity(example_queries, example_sentences) 242 | print(similarities) 243 | 244 | print("\n=========Naive brute force search============\n") 245 | simcse.build_index(example_sentences, use_faiss=False) 246 | results = simcse.search(example_queries) 247 | for i, result in enumerate(results): 248 | print("Retrieval results for query: {}".format(example_queries[i])) 249 | for sentence, score in result: 250 | print(" {} (cosine similarity: {:.4f})".format(sentence, score)) 251 | print("") 252 | 253 | print("\n=========Search with Faiss backend============\n") 254 | simcse.build_index(example_sentences, use_faiss=True) 255 | results = simcse.search(example_queries) 256 | for i, result in enumerate(results): 257 | print("Retrieval results for query: {}".format(example_queries[i])) 258 | for sentence, score in result: 259 | print(" {} (cosine similarity: {:.4f})".format(sentence, score)) 260 | print("") 261 | 262 | -------------------------------------------------------------------------------- /sscl/.ipynb_checkpoints/tool-checkpoint.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from tqdm import tqdm 3 | import numpy as np 4 | from numpy import ndarray 5 | import torch 6 | from torch import Tensor, device 7 | import transformers 8 | from transformers import AutoModel, AutoTokenizer 9 | from sklearn.metrics.pairwise import cosine_similarity 10 | from sklearn.preprocessing import normalize 11 | from typing import List, Dict, Tuple, Type, Union 12 | 13 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S', 14 | level=logging.INFO) 15 | logger = logging.getLogger(__name__) 16 | 17 | class SimCSE(object): 18 | """ 
19 | A class for embedding sentences, calculating similarities, and retriving sentences by SimCSE. 20 | """ 21 | def __init__(self, model_name_or_path: str, 22 | device: str = None, 23 | num_cells: int = 100, 24 | num_cells_in_search: int = 10, 25 | pooler = None): 26 | 27 | self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) 28 | self.model = AutoModel.from_pretrained(model_name_or_path) 29 | if device is None: 30 | device = "cuda" if torch.cuda.is_available() else "cpu" 31 | self.device = device 32 | 33 | self.index = None 34 | self.is_faiss_index = False 35 | self.num_cells = num_cells 36 | self.num_cells_in_search = num_cells_in_search 37 | 38 | if pooler is not None: 39 | self.pooler = pooler 40 | elif "unsup" in model_name_or_path: 41 | logger.info("Use `cls_before_pooler` for unsupervised models. If you want to use other pooling policy, specify `pooler` argument.") 42 | self.pooler = "cls_before_pooler" 43 | else: 44 | self.pooler = "cls" 45 | 46 | def encode(self, sentence: Union[str, List[str]], 47 | device: str = None, 48 | return_numpy: bool = False, 49 | normalize_to_unit: bool = True, 50 | keepdim: bool = False, 51 | batch_size: int = 64, 52 | max_length: int = 128) -> Union[ndarray, Tensor]: 53 | 54 | target_device = self.device if device is None else device 55 | self.model = self.model.to(target_device) 56 | 57 | single_sentence = False 58 | if isinstance(sentence, str): 59 | sentence = [sentence] 60 | single_sentence = True 61 | 62 | embedding_list = [] 63 | with torch.no_grad(): 64 | total_batch = len(sentence) // batch_size + (1 if len(sentence) % batch_size > 0 else 0) 65 | for batch_id in tqdm(range(total_batch)): 66 | inputs = self.tokenizer( 67 | sentence[batch_id*batch_size:(batch_id+1)*batch_size], 68 | padding=True, 69 | truncation=True, 70 | max_length=max_length, 71 | return_tensors="pt" 72 | ) 73 | inputs = {k: v.to(target_device) for k, v in inputs.items()} 74 | outputs = self.model(**inputs, return_dict=True) 75 | if 
self.pooler == "cls": 76 | embeddings = outputs.pooler_output 77 | elif self.pooler == "cls_before_pooler": 78 | embeddings = outputs.last_hidden_state[:, 0] 79 | else: 80 | raise NotImplementedError 81 | if normalize_to_unit: 82 | embeddings = embeddings / embeddings.norm(dim=1, keepdim=True) 83 | embedding_list.append(embeddings.cpu()) 84 | embeddings = torch.cat(embedding_list, 0) 85 | 86 | if single_sentence and not keepdim: 87 | embeddings = embeddings[0] 88 | 89 | if return_numpy and not isinstance(embeddings, ndarray): 90 | return embeddings.numpy() 91 | return embeddings 92 | 93 | def similarity(self, queries: Union[str, List[str]], 94 | keys: Union[str, List[str], ndarray], 95 | device: str = None) -> Union[float, ndarray]: 96 | 97 | query_vecs = self.encode(queries, device=device, return_numpy=True) # suppose N queries 98 | 99 | if not isinstance(keys, ndarray): 100 | key_vecs = self.encode(keys, device=device, return_numpy=True) # suppose M keys 101 | else: 102 | key_vecs = keys 103 | 104 | # check whether N == 1 or M == 1 105 | single_query, single_key = len(query_vecs.shape) == 1, len(key_vecs.shape) == 1 106 | if single_query: 107 | query_vecs = query_vecs.reshape(1, -1) 108 | if single_key: 109 | key_vecs = key_vecs.reshape(1, -1) 110 | 111 | # returns an N*M similarity array 112 | similarities = cosine_similarity(query_vecs, key_vecs) 113 | 114 | if single_query: 115 | similarities = similarities[0] 116 | if single_key: 117 | similarities = float(similarities[0]) 118 | 119 | return similarities 120 | 121 | def build_index(self, sentences_or_file_path: Union[str, List[str]], 122 | use_faiss: bool = None, 123 | faiss_fast: bool = False, 124 | device: str = None, 125 | batch_size: int = 64): 126 | 127 | if use_faiss is None or use_faiss: 128 | try: 129 | import faiss 130 | assert hasattr(faiss, "IndexFlatIP") 131 | use_faiss = True 132 | except: 133 | logger.warning("Fail to import faiss. If you want to use faiss, install faiss through PyPI. 
Now the program continues with brute force search.") 134 | use_faiss = False 135 | 136 | # if the input sentence is a string, we assume it's the path of file that stores various sentences 137 | if isinstance(sentences_or_file_path, str): 138 | sentences = [] 139 | with open(sentences_or_file_path, "r") as f: 140 | logging.info("Loading sentences from %s ..." % (sentences_or_file_path)) 141 | for line in tqdm(f): 142 | sentences.append(line.rstrip()) 143 | sentences_or_file_path = sentences 144 | 145 | logger.info("Encoding embeddings for sentences...") 146 | embeddings = self.encode(sentences_or_file_path, device=device, batch_size=batch_size, normalize_to_unit=True, return_numpy=True) 147 | 148 | logger.info("Building index...") 149 | self.index = {"sentences": sentences_or_file_path} 150 | 151 | if use_faiss: 152 | quantizer = faiss.IndexFlatIP(embeddings.shape[1]) 153 | if faiss_fast: 154 | index = faiss.IndexIVFFlat(quantizer, embeddings.shape[1], min(self.num_cells, len(sentences_or_file_path))) 155 | else: 156 | index = quantizer 157 | 158 | if (self.device == "cuda" and device != "cpu") or device == "cuda": 159 | if hasattr(faiss, "StandardGpuResources"): 160 | logger.info("Use GPU-version faiss") 161 | res = faiss.StandardGpuResources() 162 | res.setTempMemory(20 * 1024 * 1024 * 1024) 163 | index = faiss.index_cpu_to_gpu(res, 0, index) 164 | else: 165 | logger.info("Use CPU-version faiss") 166 | else: 167 | logger.info("Use CPU-version faiss") 168 | 169 | if faiss_fast: 170 | index.train(embeddings.astype(np.float32)) 171 | index.add(embeddings.astype(np.float32)) 172 | index.nprobe = min(self.num_cells_in_search, len(sentences_or_file_path)) 173 | self.is_faiss_index = True 174 | else: 175 | index = embeddings 176 | self.is_faiss_index = False 177 | self.index["index"] = index 178 | logger.info("Finished") 179 | 180 | def search(self, queries: Union[str, List[str]], 181 | device: str = None, 182 | threshold: float = 0.6, 183 | top_k: int = 5) -> 
Union[List[Tuple[str, float]], List[List[Tuple[str, float]]]]: 184 | 185 | if not self.is_faiss_index: 186 | if isinstance(queries, list): 187 | combined_results = [] 188 | for query in queries: 189 | results = self.search(query, device) 190 | combined_results.append(results) 191 | return combined_results 192 | 193 | similarities = self.similarity(queries, self.index["index"]).tolist() 194 | id_and_score = [] 195 | for i, s in enumerate(similarities): 196 | if s >= threshold: 197 | id_and_score.append((i, s)) 198 | id_and_score = sorted(id_and_score, key=lambda x: x[1], reverse=True)[:top_k] 199 | results = [(self.index["sentences"][idx], score) for idx, score in id_and_score] 200 | return results 201 | else: 202 | query_vecs = self.encode(queries, device=device, normalize_to_unit=True, keepdim=True, return_numpy=True) 203 | 204 | distance, idx = self.index["index"].search(query_vecs.astype(np.float32), top_k) 205 | 206 | def pack_single_result(dist, idx): 207 | results = [(self.index["sentences"][i], s) for i, s in zip(idx, dist) if s >= threshold] 208 | return results 209 | 210 | if isinstance(queries, list): 211 | combined_results = [] 212 | for i in range(len(queries)): 213 | results = pack_single_result(distance[i], idx[i]) 214 | combined_results.append(results) 215 | return combined_results 216 | else: 217 | return pack_single_result(distance[0], idx[0]) 218 | 219 | if __name__=="__main__": 220 | example_sentences = [ 221 | 'An animal is biting a persons finger.', 222 | 'A woman is reading.', 223 | 'A man is lifting weights in a garage.', 224 | 'A man plays the violin.', 225 | 'A man is eating food.', 226 | 'A man plays the piano.', 227 | 'A panda is climbing.', 228 | 'A man plays a guitar.', 229 | 'A woman is slicing a meat.', 230 | 'A woman is taking a picture.' 231 | ] 232 | example_queries = [ 233 | 'A man is playing music.', 234 | 'A woman is making a photo.' 
235 | ] 236 | 237 | model_name = "princeton-nlp/sup-simcse-bert-base-uncased" 238 | simcse = SimCSE(model_name) 239 | 240 | print("\n=========Calculate cosine similarities between queries and sentences============\n") 241 | similarities = simcse.similarity(example_queries, example_sentences) 242 | print(similarities) 243 | 244 | print("\n=========Naive brute force search============\n") 245 | simcse.build_index(example_sentences, use_faiss=False) 246 | results = simcse.search(example_queries) 247 | for i, result in enumerate(results): 248 | print("Retrieval results for query: {}".format(example_queries[i])) 249 | for sentence, score in result: 250 | print(" {} (cosine similarity: {:.4f})".format(sentence, score)) 251 | print("") 252 | 253 | print("\n=========Search with Faiss backend============\n") 254 | simcse.build_index(example_sentences, use_faiss=True) 255 | results = simcse.search(example_queries) 256 | for i, result in enumerate(results): 257 | print("Retrieval results for query: {}".format(example_queries[i])) 258 | for sentence, score in result: 259 | print(" {} (cosine similarity: {:.4f})".format(sentence, score)) 260 | print("") 261 | 262 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /SentEval/README.md: -------------------------------------------------------------------------------- 1 | Our modification to SentEval: 2 | 3 | 1. Add the `all` setting to all STS tasks. 4 | 2. Change STS-B and SICK-R to not use an additional regressor. 5 | 6 | # SentEval: evaluation toolkit for sentence embeddings 7 | 8 | SentEval is a library for evaluating the quality of sentence embeddings. We assess their generalization power by using them as features on a broad and diverse set of "transfer" tasks. **SentEval currently includes 17 downstream tasks**. We also include a suite of **10 probing tasks** which evaluate what linguistic properties are encoded in sentence embeddings. Our goal is to ease the study and the development of general-purpose fixed-size sentence representations. 
9 | 10 | 11 | **(04/22) SentEval new tasks: Added probing tasks for evaluating what linguistic properties are encoded in sentence embeddings** 12 | 13 | **(10/04) SentEval example scripts for three sentence encoders: [SkipThought-LN](https://github.com/ryankiros/layer-norm#skip-thoughts)/[GenSen](https://github.com/Maluuba/gensen)/[Google-USE](https://tfhub.dev/google/universal-sentence-encoder/1)** 14 | 15 | ## Dependencies 16 | 17 | This code is written in python. The dependencies are: 18 | 19 | * Python 2/3 with [NumPy](http://www.numpy.org/)/[SciPy](http://www.scipy.org/) 20 | * [Pytorch](http://pytorch.org/)>=0.4 21 | * [scikit-learn](http://scikit-learn.org/stable/index.html)>=0.18.0 22 | 23 | ## Transfer tasks 24 | 25 | ### Downstream tasks 26 | SentEval allows you to evaluate your sentence embeddings as features for the following *downstream* tasks: 27 | 28 | | Task | Type | #train | #test | needs_train | set_classifier | 29 | |---------- |------------------------------ |-----------:|----------:|:-----------:|:----------:| 30 | | [MR](https://nlp.stanford.edu/~sidaw/home/projects:nbsvm) | movie review | 11k | 11k | 1 | 1 | 31 | | [CR](https://nlp.stanford.edu/~sidaw/home/projects:nbsvm) | product review | 4k | 4k | 1 | 1 | 32 | | [SUBJ](https://nlp.stanford.edu/~sidaw/home/projects:nbsvm) | subjectivity status | 10k | 10k | 1 | 1 | 33 | | [MPQA](https://nlp.stanford.edu/~sidaw/home/projects:nbsvm) | opinion-polarity | 11k | 11k | 1 | 1 | 34 | | [SST](https://nlp.stanford.edu/sentiment/index.html) | binary sentiment analysis | 67k | 1.8k | 1 | 1 | 35 | | **[SST](https://nlp.stanford.edu/sentiment/index.html)** | **fine-grained sentiment analysis** | 8.5k | 2.2k | 1 | 1 | 36 | | [TREC](http://cogcomp.cs.illinois.edu/Data/QA/QC/) | question-type classification | 6k | 0.5k | 1 | 1 | 37 | | [SICK-E](http://clic.cimec.unitn.it/composes/sick.html) | natural language inference | 4.5k | 4.9k | 1 | 1 | 38 | | [SNLI](https://nlp.stanford.edu/projects/snli/) | natural 
language inference | 550k | 9.8k | 1 | 1 | 39 | | [MRPC](https://aclweb.org/aclwiki/Paraphrase_Identification_(State_of_the_art)) | paraphrase detection | 4.1k | 1.7k | 1 | 1 | 40 | | [STS 2012](https://www.cs.york.ac.uk/semeval-2012/task6/) | semantic textual similarity | N/A | 3.1k | 0 | 0 | 41 | | [STS 2013](http://ixa2.si.ehu.es/sts/) | semantic textual similarity | N/A | 1.5k | 0 | 0 | 42 | | [STS 2014](http://alt.qcri.org/semeval2014/task10/) | semantic textual similarity | N/A | 3.7k | 0 | 0 | 43 | | [STS 2015](http://alt.qcri.org/semeval2015/task2/) | semantic textual similarity | N/A | 8.5k | 0 | 0 | 44 | | [STS 2016](http://alt.qcri.org/semeval2016/task1/) | semantic textual similarity | N/A | 9.2k | 0 | 0 | 45 | | [STS B](http://ixa2.si.ehu.es/stswiki/index.php/STSbenchmark#Results) | semantic textual similarity | 5.7k | 1.4k | 1 | 0 | 46 | | [SICK-R](http://clic.cimec.unitn.it/composes/sick.html) | semantic textual similarity | 4.5k | 4.9k | 1 | 0 | 47 | | [COCO](http://mscoco.org/) | image-caption retrieval | 567k | 5*1k | 1 | 0 | 48 | 49 | where **needs_train** means a model with parameters is learned on top of the sentence embeddings, and **set_classifier** means you can define the parameters of the classifier in the case of a classification task (see below). 50 | 51 | Note: COCO comes with ResNet-101 2048d image embeddings. 
[More details on the tasks.](https://arxiv.org/pdf/1705.02364.pdf) 52 | 53 | ### Probing tasks 54 | SentEval also includes a series of [*probing* tasks](https://github.com/facebookresearch/SentEval/tree/master/data/probing) to evaluate what linguistic properties are encoded in your sentence embeddings: 55 | 56 | | Task | Type | #train | #test | needs_train | set_classifier | 57 | |---------- |------------------------------ |-----------:|----------:|:-----------:|:----------:| 58 | | [SentLen](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Length prediction | 100k | 10k | 1 | 1 | 59 | | [WC](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Word Content analysis | 100k | 10k | 1 | 1 | 60 | | [TreeDepth](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Tree depth prediction | 100k | 10k | 1 | 1 | 61 | | [TopConst](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Top Constituents prediction | 100k | 10k | 1 | 1 | 62 | | [BShift](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Word order analysis | 100k | 10k | 1 | 1 | 63 | | [Tense](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Verb tense prediction | 100k | 10k | 1 | 1 | 64 | | [SubjNum](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Subject number prediction | 100k | 10k | 1 | 1 | 65 | | [ObjNum](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Object number prediction | 100k | 10k | 1 | 1 | 66 | | [SOMO](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Semantic odd man out | 100k | 10k | 1 | 1 | 67 | | [CoordInv](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Coordination Inversion | 100k | 10k | 1 | 1 | 68 | 69 | ## Download datasets 70 | To get all the transfer tasks datasets, run (in data/downstream/): 71 | ```bash 72 | ./get_transfer_data.bash 73 | ``` 74 | This will 
automatically download and preprocess the downstream datasets, and store them in data/downstream (warning: for MacOS users, you may have to use p7zip instead of unzip). The probing tasks are already in data/probing by default. 75 | 76 | ## How to use SentEval: examples 77 | 78 | ### examples/bow.py 79 | 80 | In examples/bow.py, we evaluate the quality of the average of word embeddings. 81 | 82 | To download state-of-the-art fastText embeddings: 83 | 84 | ```bash 85 | curl -Lo glove.840B.300d.zip http://nlp.stanford.edu/data/glove.840B.300d.zip 86 | curl -Lo crawl-300d-2M.vec.zip https://dl.fbaipublicfiles.com/fasttext/vectors-english/crawl-300d-2M.vec.zip 87 | ``` 88 | 89 | To reproduce the results for bag-of-vectors, run (in examples/): 90 | ```bash 91 | python bow.py 92 | ``` 93 | 94 | As required by SentEval, this script implements two functions: **prepare** (optional) and **batcher** (required) that turn text sentences into sentence embeddings. Then SentEval takes care of the evaluation on the transfer tasks using the embeddings as features. 
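The two functions described above can be sketched as follows. This is a minimal illustration and not the code in bow.py: it assumes sentences arrive as lists of tokens, uses random vectors in place of pretrained embeddings, and uses a `SimpleNamespace` as a hypothetical stand-in for the params object that SentEval would normally supply. The names `DIM` and `samples` are likewise made up for the example.

```python
from types import SimpleNamespace

import numpy as np

DIM = 8  # toy embedding size, chosen for illustration only


def prepare(params, samples):
    """Sees every sentence of a task once and builds a toy word-vector table."""
    rng = np.random.default_rng(0)
    vocab = sorted({word for sent in samples for word in sent})
    # Assumption: random vectors stand in for real pretrained embeddings.
    params.word_vec = {word: rng.standard_normal(DIM) for word in vocab}


def batcher(params, batch):
    """Turns a batch of tokenized sentences into one embedding per sentence."""
    embeddings = []
    for sent in batch:
        vecs = [params.word_vec[w] for w in sent if w in params.word_vec]
        # Fall back to a zero vector for empty or fully out-of-vocabulary sentences.
        embeddings.append(np.mean(vecs, axis=0) if vecs else np.zeros(DIM))
    return np.vstack(embeddings)


# Hypothetical stand-in for the params object SentEval passes to both functions.
params = SimpleNamespace()
samples = [["a", "man", "plays", "guitar"], ["a", "woman", "reads"], []]
prepare(params, samples)
embeddings = batcher(params, samples)
print(embeddings.shape)
```

Averaging keeps the embedding size fixed regardless of sentence length, which is why a single array with one row per sentence can be returned for the whole batch.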
95 | 96 | ### examples/infersent.py 97 | 98 | To get the **[InferSent](https://www.github.com/facebookresearch/InferSent)** model and reproduce our results, download our best models and run infersent.py (in examples/): 99 | ```bash 100 | curl -Lo examples/infersent1.pkl https://dl.fbaipublicfiles.com/senteval/infersent/infersent1.pkl 101 | curl -Lo examples/infersent2.pkl https://dl.fbaipublicfiles.com/senteval/infersent/infersent2.pkl 102 | ``` 103 | 104 | ### examples/skipthought.py - examples/gensen.py - examples/googleuse.py 105 | 106 | We also provide example scripts for three other encoders: 107 | 108 | * [SkipThought with Layer-Normalization](https://github.com/ryankiros/layer-norm#skip-thoughts) in Theano 109 | * [GenSen encoder](https://github.com/Maluuba/gensen) in Pytorch 110 | * [Google encoder](https://tfhub.dev/google/universal-sentence-encoder/1) in TensorFlow 111 | 112 | Note that for SkipThought and GenSen, you need to follow the setup steps in the associated GitHub repositories. 113 | The Google encoder script should work as-is. 114 | 115 | ## How to use SentEval 116 | 117 | To evaluate your sentence embeddings, SentEval requires that you implement two functions: 118 | 119 | 1. **prepare** (sees the whole dataset of each task and can thus construct the word vocabulary, the dictionary of word vectors, etc.) 120 | 2. **batcher** (transforms a batch of text sentences into sentence embeddings) 121 | 122 | 123 | ### 1.) prepare(params, samples) (optional) 124 | 125 | *batcher* only sees one batch at a time while the *samples* argument of *prepare* contains all the sentences of a task. 126 | 127 | ``` 128 | prepare(params, samples) 129 | ``` 130 | * *params*: senteval parameters. 131 | * *samples*: list of all sentences from the transfer task. 132 | * *output*: No output. Arguments stored in "params" can further be used by *batcher*.
133 | 134 | *Example*: in bow.py, prepare is used to build the vocabulary of words and construct the *params.word_vec* dictionary of word vectors. 135 | 136 | 137 | ### 2.) batcher(params, batch) 138 | ``` 139 | batcher(params, batch) 140 | ``` 141 | * *params*: senteval parameters. 142 | * *batch*: numpy array of text sentences (of size params.batch_size) 143 | * *output*: numpy array of sentence embeddings (of size params.batch_size) 144 | 145 | *Example*: in bow.py, batcher is used to compute the mean of the word vectors for each sentence in the batch using params.word_vec. Use your own encoder in that function to encode sentences. 146 | 147 | ### 3.) evaluation on transfer tasks 148 | 149 | After having implemented the batcher and prepare functions for your own sentence encoder, 150 | 151 | 1) to perform the actual evaluation, first import senteval and set its parameters: 152 | ```python 153 | import senteval 154 | params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 10} 155 | ``` 156 | 157 | 2) (optional) set the parameters of the classifier (when applicable): 158 | ```python 159 | params['classifier'] = {'nhid': 0, 'optim': 'adam', 'batch_size': 64, 160 | 'tenacity': 5, 'epoch_size': 4} 161 | ``` 162 | You can choose **nhid=0** (Logistic Regression) or **nhid>0** (MLP) and define the parameters for training.
163 | 164 | 3) Create an instance of the class SE: 165 | ```python 166 | se = senteval.engine.SE(params, batcher, prepare) 167 | ``` 168 | 169 | 4) define the set of transfer tasks and run the evaluation: 170 | ```python 171 | transfer_tasks = ['MR', 'SICKEntailment', 'STS14', 'STSBenchmark'] 172 | results = se.eval(transfer_tasks) 173 | ``` 174 | The current list of available tasks is: 175 | ```python 176 | ['CR', 'MR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC', 'SNLI', 177 | 'SICKEntailment', 'SICKRelatedness', 'STSBenchmark', 'ImageCaptionRetrieval', 178 | 'STS12', 'STS13', 'STS14', 'STS15', 'STS16', 179 | 'Length', 'WordContent', 'Depth', 'TopConstituents','BigramShift', 'Tense', 180 | 'SubjNumber', 'ObjNumber', 'OddManOut', 'CoordinationInversion'] 181 | ``` 182 | 183 | ## SentEval parameters 184 | Global parameters of SentEval: 185 | ```bash 186 | # senteval parameters 187 | task_path # path to SentEval datasets (required) 188 | seed # seed 189 | usepytorch # use cuda-pytorch (else scikit-learn) where possible 190 | kfold # k-fold validation for MR/CR/SUBJ/MPQA. 191 | ``` 192 | 193 | Parameters of the classifier: 194 | ```bash 195 | nhid: # number of hidden units (0: Logistic Regression, >0: MLP); Default nonlinearity: Tanh 196 | optim: # optimizer ("sgd,lr=0.1", "adam", "rmsprop" ..) 197 | tenacity: # how many times dev acc does not increase before training stops 198 | epoch_size: # each epoch corresponds to epoch_size passes on the train set 199 | max_epoch: # max number of epochs 200 | dropout: # dropout for MLP 201 | ``` 202 | 203 | Note that to get a proxy of the results while **dramatically reducing computation time**, 204 | we suggest the **prototyping config**: 205 | ```python 206 | params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5} 207 | params['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128, 208 | 'tenacity': 3, 'epoch_size': 2} 209 | ``` 210 | which will result in a 5 times speedup for classification tasks.
211 | 212 | To produce results that are **comparable to the literature**, use the **default config**: 213 | ```python 214 | params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 10} 215 | params['classifier'] = {'nhid': 0, 'optim': 'adam', 'batch_size': 64, 216 | 'tenacity': 5, 'epoch_size': 4} 217 | ``` 218 | which takes longer but will produce better and comparable results. 219 | 220 | For probing tasks, we used an MLP with a Sigmoid nonlinearity and tuned the nhid (in [50, 100, 200]) and dropout (in [0.0, 0.1, 0.2]) on the dev set. 221 | 222 | ## References 223 | 224 | Please consider citing [[1]](https://arxiv.org/abs/1803.05449) if using this code for evaluating sentence embedding methods. 225 | 226 | ### SentEval: An Evaluation Toolkit for Universal Sentence Representations 227 | 228 | [1] A. Conneau, D. Kiela, [*SentEval: An Evaluation Toolkit for Universal Sentence Representations*](https://arxiv.org/abs/1803.05449) 229 | 230 | ``` 231 | @article{conneau2018senteval, 232 | title={SentEval: An Evaluation Toolkit for Universal Sentence Representations}, 233 | author={Conneau, Alexis and Kiela, Douwe}, 234 | journal={arXiv preprint arXiv:1803.05449}, 235 | year={2018} 236 | } 237 | ``` 238 | 239 | Contact: [aconneau@fb.com](mailto:aconneau@fb.com), [dkiela@fb.com](mailto:dkiela@fb.com) 240 | 241 | ### Related work 242 | * [J. R Kiros, Y. Zhu, R. Salakhutdinov, R. S. Zemel, A. Torralba, R. Urtasun, S. Fidler - SkipThought Vectors, NIPS 2015](https://arxiv.org/abs/1506.06726) 243 | * [S. Arora, Y. Liang, T. Ma - A Simple but Tough-to-Beat Baseline for Sentence Embeddings, ICLR 2017](https://openreview.net/pdf?id=SyK00v5xx) 244 | * [Y. Adi, E. Kermany, Y. Belinkov, O. Lavi, Y. Goldberg - Fine-grained analysis of sentence embeddings using auxiliary prediction tasks, ICLR 2017](https://arxiv.org/abs/1608.04207) 245 | * [A. Conneau, D. Kiela, L. Barrault, H. Schwenk, A.
Bordes - Supervised Learning of Universal Sentence Representations from Natural Language Inference Data, EMNLP 2017](https://arxiv.org/abs/1705.02364) 246 | * [S. Subramanian, A. Trischler, Y. Bengio, C. J Pal - Learning General Purpose Distributed Sentence Representations via Large Scale Multi-task Learning, ICLR 2018](https://arxiv.org/abs/1804.00079) 247 | * [A. Nie, E. D. Bennett, N. D. Goodman - DisSent: Sentence Representation Learning from Explicit Discourse Relations, 2018](https://arxiv.org/abs/1710.04334) 248 | * [D. Cer, Y. Yang, S. Kong, N. Hua, N. Limtiaco, R. St. John, N. Constant, M. Guajardo-Cespedes, S. Yuan, C. Tar, Y. Sung, B. Strope, R. Kurzweil - Universal Sentence Encoder, 2018](https://arxiv.org/abs/1803.11175) 249 | * [A. Conneau, G. Kruszewski, G. Lample, L. Barrault, M. Baroni - What you can cram into a single vector: Probing sentence embeddings for linguistic properties, ACL 2018](https://arxiv.org/abs/1805.01070) 250 | --------------------------------------------------------------------------------