├── SentEval ├── senteval │ ├── tools │ │ ├── __init__.py │ │ ├── relatedness.py │ │ ├── classifier.py │ │ └── validation.py │ ├── __init__.py │ ├── utils.py │ ├── trec.py │ ├── binary.py │ ├── sst.py │ ├── mrpc.py │ ├── snli.py │ ├── rank.py │ ├── engine.py │ ├── probing.py │ ├── sick.py │ └── sts.py ├── data │ └── downstream │ │ └── download_dataset.sh ├── setup.py ├── LICENSE ├── examples │ ├── skipthought.py │ ├── googleuse.py │ ├── gensen.py │ ├── infersent.py │ ├── bow.py │ └── models.py └── README.md ├── simcse ├── __init__.py ├── tool.py ├── models.py ├── models_HSCL.py └── models_aug.py ├── data └── .gitignore ├── requirements.txt ├── scripts ├── eval.sh └── sup_train_mp.sh ├── LICENSE ├── simcse_to_huggingface.py ├── evaluation.py └── README.md /SentEval/senteval/tools/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /simcse/__init__.py: -------------------------------------------------------------------------------- 1 | from .tool import SimCSE 2 | -------------------------------------------------------------------------------- /data/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.2.1 2 | scipy 3 | datasets 4 | pandas 5 | scikit-learn 6 | prettytable 7 | gradio 8 | setuptools -------------------------------------------------------------------------------- /SentEval/data/downstream/download_dataset.sh: -------------------------------------------------------------------------------- 1 | wget https://huggingface.co/datasets/princeton-nlp/datasets-for-simcse/resolve/main/senteval.tar 2 | tar xvf 
senteval.tar 3 | -------------------------------------------------------------------------------- /scripts/eval.sh: -------------------------------------------------------------------------------- 1 | path=sjtu-lit/SynCSE-partial-RoBERTa-base 2 | python simcse_to_huggingface.py --path ${path} 3 | CUDA_VISIBLE_DEVICES=0 python evaluation.py \ 4 | --model_name_or_path ${path} \ 5 | --pooler cls \ 6 | --task_set sts \ 7 | --mode test 8 | -------------------------------------------------------------------------------- /SentEval/senteval/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from __future__ import absolute_import 9 | 10 | from senteval.engine import SE 11 | -------------------------------------------------------------------------------- /SentEval/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | import io 9 | from setuptools import setup, find_packages 10 | 11 | with io.open('./README.md', encoding='utf-8') as f: 12 | readme = f.read() 13 | 14 | setup( 15 | name='SentEval', 16 | version='0.1.0', 17 | url='https://github.com/facebookresearch/SentEval', 18 | packages=find_packages(exclude=['examples']), 19 | license='Attribution-NonCommercial 4.0 International', 20 | long_description=readme, 21 | ) 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Language Intelligence and Technology group @ SJTU 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /scripts/sup_train_mp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # In this example, we show how to train SimCSE using multiple GPU cards and PyTorch's distributed data parallel on supervised NLI dataset. 4 | # Set how many GPUs to use 5 | 6 | NUM_GPU=1 7 | 8 | # Randomly set a port number 9 | # If you encounter "address already used" error, just run again or manually set an available port id. 10 | PORT_ID=$(expr $RANDOM + 1000) 11 | 12 | # Allow multiple threads 13 | export OMP_NUM_THREADS=1 14 | 15 | # Use distributed data parallel 16 | # If you only want to use one card, uncomment the following line and comment the line with "torch.distributed.launch" 17 | # python train.py \ 18 | 19 | model=roberta-base 20 | dataset=sjtu-lit/SynCSE-partial-NLI 21 | CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --nproc_per_node $NUM_GPU --master_port $PORT_ID train.py \ 22 | --model_name_or_path ${model} \ 23 | --train_file ${dataset} \ 24 | --output_dir result/my-sup-simcse-${model}_${dataset} \ 25 | --num_train_epochs 3 \ 26 | --per_device_train_batch_size 512 \ 27 | --learning_rate 5e-5 \ 28 | --max_seq_length 32 \ 29 | --evaluation_strategy steps \ 30 | --metric_for_best_model avg_sts \ 31 | --load_best_model_at_end \ 32 | --eval_steps 25 \ 33 | --pooler_type cls \ 34 | --overwrite_output_dir \ 35 | --temp 0.05 \ 36 | --do_train \ 37 | --do_eval \ 38 | --fp16 \ 39 | --seed 42 \ 40 | --do_mlm \ 41 | --hard_negative_weight 0 \ 42 | "$@" 43 | -------------------------------------------------------------------------------- /simcse_to_huggingface.py: -------------------------------------------------------------------------------- 1 | """ 2 | Convert SimCSE's checkpoints to Huggingface style. 
3 | """ 4 | 5 | import argparse 6 | import torch 7 | import os 8 | import json 9 | 10 | 11 | def main(): 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--path", type=str, help="Path of SimCSE checkpoint folder") 14 | args = parser.parse_args() 15 | 16 | print("SimCSE checkpoint -> Huggingface checkpoint for {}".format(args.path)) 17 | 18 | state_dict = torch.load(os.path.join(args.path, "pytorch_model.bin"), map_location=torch.device("cpu")) 19 | new_state_dict = {} 20 | for key, param in state_dict.items(): 21 | # Replace "mlp" to "pooler" 22 | if "mlp" in key: 23 | key = key.replace("mlp", "pooler") 24 | 25 | # Delete "bert" or "roberta" prefix 26 | if "bert." in key: 27 | key = key.replace("bert.", "") 28 | if "roberta." in key: 29 | key = key.replace("roberta.", "") 30 | 31 | new_state_dict[key] = param 32 | 33 | torch.save(new_state_dict, os.path.join(args.path, "pytorch_model.bin")) 34 | 35 | # Change architectures in config.json 36 | config = json.load(open(os.path.join(args.path, "config.json"))) 37 | for i in range(len(config["architectures"])): 38 | config["architectures"][i] = config["architectures"][i].replace("ForCL", "Model") 39 | json.dump(config, open(os.path.join(args.path, "config.json"), "w"), indent=2) 40 | 41 | 42 | if __name__ == "__main__": 43 | main() 44 | -------------------------------------------------------------------------------- /SentEval/LICENSE: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | For SentEval software 4 | 5 | Copyright (c) 2017-present, Facebook, Inc. All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without modification, 8 | are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 
12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name Facebook nor the names of its contributors may be used to 18 | endorse or promote products derived from this software without specific 19 | prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | -------------------------------------------------------------------------------- /SentEval/examples/skipthought.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | from __future__ import absolute_import, division, unicode_literals 9 | 10 | """ 11 | Example of file for SkipThought in SentEval 12 | """ 13 | import logging 14 | import sys 15 | sys.setdefaultencoding('utf8') 16 | 17 | 18 | # Set PATHs 19 | PATH_TO_SENTEVAL = '../' 20 | PATH_TO_DATA = '../data/senteval_data/' 21 | PATH_TO_SKIPTHOUGHT = '' 22 | 23 | assert PATH_TO_SKIPTHOUGHT != '', 'Download skipthought and set correct PATH' 24 | 25 | # import skipthought and Senteval 26 | sys.path.insert(0, PATH_TO_SKIPTHOUGHT) 27 | import skipthoughts 28 | sys.path.insert(0, PATH_TO_SENTEVAL) 29 | import senteval 30 | 31 | 32 | def prepare(params, samples): 33 | return 34 | 35 | def batcher(params, batch): 36 | batch = [str(' '.join(sent), errors="ignore") if sent != [] else '.' for sent in batch] 37 | embeddings = skipthoughts.encode(params['encoder'], batch, 38 | verbose=False, use_eos=True) 39 | return embeddings 40 | 41 | 42 | # Set params for SentEval 43 | params_senteval = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 10, 'batch_size': 512} 44 | params_senteval['classifier'] = {'nhid': 0, 'optim': 'adam', 'batch_size': 64, 45 | 'tenacity': 5, 'epoch_size': 4} 46 | # Set up logger 47 | logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG) 48 | 49 | if __name__ == "__main__": 50 | # Load SkipThought model 51 | params_senteval['encoder'] = skipthoughts.load_model() 52 | 53 | se = senteval.engine.SE(params_senteval, batcher, prepare) 54 | transfer_tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 55 | 'MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC', 56 | 'SICKEntailment', 'SICKRelatedness', 'STSBenchmark', 57 | 'Length', 'WordContent', 'Depth', 'TopConstituents', 58 | 'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber', 59 | 'OddManOut', 'CoordinationInversion'] 60 | results = se.eval(transfer_tasks) 61 | print(results) 62 | -------------------------------------------------------------------------------- 
/SentEval/examples/googleuse.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from __future__ import absolute_import, division 9 | 10 | import os 11 | import sys 12 | import logging 13 | import tensorflow as tf 14 | import tensorflow_hub as hub 15 | tf.logging.set_verbosity(0) 16 | 17 | # Set PATHs 18 | PATH_TO_SENTEVAL = '../' 19 | PATH_TO_DATA = '../data' 20 | 21 | # import SentEval 22 | sys.path.insert(0, PATH_TO_SENTEVAL) 23 | import senteval 24 | 25 | # tensorflow session 26 | session = tf.Session() 27 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 28 | 29 | # SentEval prepare and batcher 30 | def prepare(params, samples): 31 | return 32 | 33 | def batcher(params, batch): 34 | batch = [' '.join(sent) if sent != [] else '.' for sent in batch] 35 | embeddings = params['google_use'](batch) 36 | return embeddings 37 | 38 | def make_embed_fn(module): 39 | with tf.Graph().as_default(): 40 | sentences = tf.placeholder(tf.string) 41 | embed = hub.Module(module) 42 | embeddings = embed(sentences) 43 | session = tf.train.MonitoredSession() 44 | return lambda x: session.run(embeddings, {sentences: x}) 45 | 46 | # Start TF session and load Google Universal Sentence Encoder 47 | encoder = make_embed_fn("https://tfhub.dev/google/universal-sentence-encoder-large/2") 48 | 49 | # Set params for SentEval 50 | params_senteval = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5} 51 | params_senteval['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128, 52 | 'tenacity': 3, 'epoch_size': 2} 53 | params_senteval['google_use'] = encoder 54 | 55 | # Set up logger 56 | logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG) 57 | 58 | if __name__ == "__main__": 59 | se = 
senteval.engine.SE(params_senteval, batcher, prepare) 60 | transfer_tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 61 | 'MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC', 62 | 'SICKEntailment', 'SICKRelatedness', 'STSBenchmark', 63 | 'Length', 'WordContent', 'Depth', 'TopConstituents', 64 | 'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber', 65 | 'OddManOut', 'CoordinationInversion'] 66 | results = se.eval(transfer_tasks) 67 | print(results) 68 | -------------------------------------------------------------------------------- /SentEval/examples/gensen.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | """ 9 | Clone GenSen repo here: https://github.com/Maluuba/gensen.git 10 | And follow instructions for loading the model used in batcher 11 | """ 12 | 13 | from __future__ import absolute_import, division, unicode_literals 14 | 15 | import sys 16 | import logging 17 | # import GenSen package 18 | from gensen import GenSen, GenSenSingle 19 | 20 | # Set PATHs 21 | PATH_TO_SENTEVAL = '../' 22 | PATH_TO_DATA = '../data' 23 | 24 | # import SentEval 25 | sys.path.insert(0, PATH_TO_SENTEVAL) 26 | import senteval 27 | 28 | # SentEval prepare and batcher 29 | def prepare(params, samples): 30 | return 31 | 32 | def batcher(params, batch): 33 | batch = [' '.join(sent) if sent != [] else '.' 
for sent in batch] 34 | _, reps_h_t = gensen.get_representation( 35 | sentences, pool='last', return_numpy=True, tokenize=True 36 | ) 37 | embeddings = reps_h_t 38 | return embeddings 39 | 40 | # Load GenSen model 41 | gensen_1 = GenSenSingle( 42 | model_folder='../data/models', 43 | filename_prefix='nli_large_bothskip', 44 | pretrained_emb='../data/embedding/glove.840B.300d.h5' 45 | ) 46 | gensen_2 = GenSenSingle( 47 | model_folder='../data/models', 48 | filename_prefix='nli_large_bothskip_parse', 49 | pretrained_emb='../data/embedding/glove.840B.300d.h5' 50 | ) 51 | gensen_encoder = GenSen(gensen_1, gensen_2) 52 | reps_h, reps_h_t = gensen.get_representation( 53 | sentences, pool='last', return_numpy=True, tokenize=True 54 | ) 55 | 56 | # Set params for SentEval 57 | params_senteval = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5} 58 | params_senteval['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128, 59 | 'tenacity': 3, 'epoch_size': 2} 60 | params_senteval['gensen'] = gensen_encoder 61 | 62 | # Set up logger 63 | logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG) 64 | 65 | if __name__ == "__main__": 66 | se = senteval.engine.SE(params_senteval, batcher, prepare) 67 | transfer_tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 68 | 'MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC', 69 | 'SICKEntailment', 'SICKRelatedness', 'STSBenchmark', 70 | 'Length', 'WordContent', 'Depth', 'TopConstituents', 71 | 'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber', 72 | 'OddManOut', 'CoordinationInversion'] 73 | results = se.eval(transfer_tasks) 74 | print(results) 75 | -------------------------------------------------------------------------------- /SentEval/examples/infersent.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | """ 9 | InferSent models. See https://github.com/facebookresearch/InferSent. 10 | """ 11 | 12 | from __future__ import absolute_import, division, unicode_literals 13 | 14 | import sys 15 | import os 16 | import torch 17 | import logging 18 | 19 | # get models.py from InferSent repo 20 | from models import InferSent 21 | 22 | # Set PATHs 23 | PATH_SENTEVAL = '../' 24 | PATH_TO_DATA = '../data' 25 | PATH_TO_W2V = 'PATH/TO/glove.840B.300d.txt' # or crawl-300d-2M.vec for V2 26 | MODEL_PATH = 'infersent1.pkl' 27 | V = 1 # version of InferSent 28 | 29 | assert os.path.isfile(MODEL_PATH) and os.path.isfile(PATH_TO_W2V), \ 30 | 'Set MODEL and GloVe PATHs' 31 | 32 | # import senteval 33 | sys.path.insert(0, PATH_SENTEVAL) 34 | import senteval 35 | 36 | 37 | def prepare(params, samples): 38 | params.infersent.build_vocab([' '.join(s) for s in samples], tokenize=False) 39 | 40 | 41 | def batcher(params, batch): 42 | sentences = [' '.join(s) for s in batch] 43 | embeddings = params.infersent.encode(sentences, bsize=params.batch_size, tokenize=False) 44 | return embeddings 45 | 46 | 47 | """ 48 | Evaluation of trained model on Transfer Tasks (SentEval) 49 | """ 50 | 51 | # define senteval params 52 | params_senteval = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5} 53 | params_senteval['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128, 54 | 'tenacity': 3, 'epoch_size': 2} 55 | # Set up logger 56 | logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG) 57 | 58 | if __name__ == "__main__": 59 | # Load InferSent model 60 | params_model = {'bsize': 64, 'word_emb_dim': 300, 'enc_lstm_dim': 2048, 61 | 'pool_type': 'max', 'dpout_model': 0.0, 'version': V} 62 | model = InferSent(params_model) 63 | model.load_state_dict(torch.load(MODEL_PATH)) 64 | model.set_w2v_path(PATH_TO_W2V) 65 | 66 | 
params_senteval['infersent'] = model.cuda() 67 | 68 | se = senteval.engine.SE(params_senteval, batcher, prepare) 69 | transfer_tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 70 | 'MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC', 71 | 'SICKEntailment', 'SICKRelatedness', 'STSBenchmark', 72 | 'Length', 'WordContent', 'Depth', 'TopConstituents', 73 | 'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber', 74 | 'OddManOut', 'CoordinationInversion'] 75 | results = se.eval(transfer_tasks) 76 | print(results) 77 | -------------------------------------------------------------------------------- /SentEval/senteval/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from __future__ import absolute_import, division, unicode_literals 9 | 10 | import numpy as np 11 | import re 12 | import inspect 13 | from torch import optim 14 | 15 | 16 | def create_dictionary(sentences): 17 | words = {} 18 | for s in sentences: 19 | for word in s: 20 | if word in words: 21 | words[word] += 1 22 | else: 23 | words[word] = 1 24 | words[''] = 1e9 + 4 25 | words[''] = 1e9 + 3 26 | words['

'] = 1e9 + 2 27 | # words[''] = 1e9 + 1 28 | sorted_words = sorted(words.items(), key=lambda x: -x[1]) # inverse sort 29 | id2word = [] 30 | word2id = {} 31 | for i, (w, _) in enumerate(sorted_words): 32 | id2word.append(w) 33 | word2id[w] = i 34 | 35 | return id2word, word2id 36 | 37 | 38 | def cosine(u, v): 39 | return np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v)) 40 | 41 | 42 | class dotdict(dict): 43 | """ dot.notation access to dictionary attributes """ 44 | __getattr__ = dict.get 45 | __setattr__ = dict.__setitem__ 46 | __delattr__ = dict.__delitem__ 47 | 48 | 49 | def get_optimizer(s): 50 | """ 51 | Parse optimizer parameters. 52 | Input should be of the form: 53 | - "sgd,lr=0.01" 54 | - "adagrad,lr=0.1,lr_decay=0.05" 55 | """ 56 | if "," in s: 57 | method = s[:s.find(',')] 58 | optim_params = {} 59 | for x in s[s.find(',') + 1:].split(','): 60 | split = x.split('=') 61 | assert len(split) == 2 62 | assert re.match("^[+-]?(\d+(\.\d*)?|\.\d+)$", split[1]) is not None 63 | optim_params[split[0]] = float(split[1]) 64 | else: 65 | method = s 66 | optim_params = {} 67 | 68 | if method == 'adadelta': 69 | optim_fn = optim.Adadelta 70 | elif method == 'adagrad': 71 | optim_fn = optim.Adagrad 72 | elif method == 'adam': 73 | optim_fn = optim.Adam 74 | elif method == 'adamax': 75 | optim_fn = optim.Adamax 76 | elif method == 'asgd': 77 | optim_fn = optim.ASGD 78 | elif method == 'rmsprop': 79 | optim_fn = optim.RMSprop 80 | elif method == 'rprop': 81 | optim_fn = optim.Rprop 82 | elif method == 'sgd': 83 | optim_fn = optim.SGD 84 | assert 'lr' in optim_params 85 | else: 86 | raise Exception('Unknown optimization method: "%s"' % method) 87 | 88 | # check that we give good parameters to the optimizer 89 | expected_args = inspect.getfullargspec(optim_fn.__init__)[0] 90 | assert expected_args[:2] == ['self', 'params'] 91 | if not all(k in expected_args[2:] for k in optim_params.keys()): 92 | raise Exception('Unexpected parameters: expected "%s", got "%s"' % ( 93 
| str(expected_args[2:]), str(optim_params.keys()))) 94 | 95 | return optim_fn, optim_params 96 | -------------------------------------------------------------------------------- /SentEval/senteval/trec.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | ''' 9 | TREC question-type classification 10 | ''' 11 | 12 | from __future__ import absolute_import, division, unicode_literals 13 | 14 | import os 15 | import io 16 | import logging 17 | import numpy as np 18 | 19 | from senteval.tools.validation import KFoldClassifier 20 | 21 | 22 | class TRECEval(object): 23 | def __init__(self, task_path, seed=1111): 24 | logging.info('***** Transfer task : TREC *****\n\n') 25 | self.seed = seed 26 | self.train = self.loadFile(os.path.join(task_path, 'train_5500.label')) 27 | self.test = self.loadFile(os.path.join(task_path, 'TREC_10.label')) 28 | 29 | def do_prepare(self, params, prepare): 30 | samples = self.train['X'] + self.test['X'] 31 | return prepare(params, samples) 32 | 33 | def loadFile(self, fpath): 34 | trec_data = {'X': [], 'y': []} 35 | tgt2idx = {'ABBR': 0, 'DESC': 1, 'ENTY': 2, 36 | 'HUM': 3, 'LOC': 4, 'NUM': 5} 37 | with io.open(fpath, 'r', encoding='latin-1') as f: 38 | for line in f: 39 | target, sample = line.strip().split(':', 1) 40 | sample = sample.split(' ', 1)[1].split() 41 | assert target in tgt2idx, target 42 | trec_data['X'].append(sample) 43 | trec_data['y'].append(tgt2idx[target]) 44 | return trec_data 45 | 46 | def run(self, params, batcher): 47 | train_embeddings, test_embeddings = [], [] 48 | 49 | # Sort to reduce padding 50 | sorted_corpus_train = sorted(zip(self.train['X'], self.train['y']), 51 | key=lambda z: (len(z[0]), z[1])) 52 | train_samples = [x for (x, y) in sorted_corpus_train] 53 | 
train_labels = [y for (x, y) in sorted_corpus_train] 54 | 55 | sorted_corpus_test = sorted(zip(self.test['X'], self.test['y']), 56 | key=lambda z: (len(z[0]), z[1])) 57 | test_samples = [x for (x, y) in sorted_corpus_test] 58 | test_labels = [y for (x, y) in sorted_corpus_test] 59 | 60 | # Get train embeddings 61 | for ii in range(0, len(train_labels), params.batch_size): 62 | batch = train_samples[ii:ii + params.batch_size] 63 | embeddings = batcher(params, batch) 64 | train_embeddings.append(embeddings) 65 | train_embeddings = np.vstack(train_embeddings) 66 | logging.info('Computed train embeddings') 67 | 68 | # Get test embeddings 69 | for ii in range(0, len(test_labels), params.batch_size): 70 | batch = test_samples[ii:ii + params.batch_size] 71 | embeddings = batcher(params, batch) 72 | test_embeddings.append(embeddings) 73 | test_embeddings = np.vstack(test_embeddings) 74 | logging.info('Computed test embeddings') 75 | 76 | config_classifier = {'nclasses': 6, 'seed': self.seed, 77 | 'usepytorch': params.usepytorch, 78 | 'classifier': params.classifier, 79 | 'kfold': params.kfold} 80 | clf = KFoldClassifier({'X': train_embeddings, 81 | 'y': np.array(train_labels)}, 82 | {'X': test_embeddings, 83 | 'y': np.array(test_labels)}, 84 | config_classifier) 85 | devacc, testacc, _ = clf.run() 86 | logging.debug('\nDev acc : {0} Test acc : {1} \ 87 | for TREC\n'.format(devacc, testacc)) 88 | return {'devacc': devacc, 'acc': testacc, 89 | 'ndev': len(self.train['X']), 'ntest': len(self.test['X'])} 90 | -------------------------------------------------------------------------------- /SentEval/examples/bow.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | from __future__ import absolute_import, division, unicode_literals 9 | 10 | import sys 11 | import io 12 | import numpy as np 13 | import logging 14 | 15 | 16 | # Set PATHs 17 | PATH_TO_SENTEVAL = '../' 18 | PATH_TO_DATA = '../data' 19 | # PATH_TO_VEC = 'glove/glove.840B.300d.txt' 20 | PATH_TO_VEC = 'fasttext/crawl-300d-2M.vec' 21 | 22 | # import SentEval 23 | sys.path.insert(0, PATH_TO_SENTEVAL) 24 | import senteval 25 | 26 | 27 | # Create dictionary 28 | def create_dictionary(sentences, threshold=0): 29 | words = {} 30 | for s in sentences: 31 | for word in s: 32 | words[word] = words.get(word, 0) + 1 33 | 34 | if threshold > 0: 35 | newwords = {} 36 | for word in words: 37 | if words[word] >= threshold: 38 | newwords[word] = words[word] 39 | words = newwords 40 | words[''] = 1e9 + 4 41 | words[''] = 1e9 + 3 42 | words['

'] = 1e9 + 2 43 | 44 | sorted_words = sorted(words.items(), key=lambda x: -x[1]) # inverse sort 45 | id2word = [] 46 | word2id = {} 47 | for i, (w, _) in enumerate(sorted_words): 48 | id2word.append(w) 49 | word2id[w] = i 50 | 51 | return id2word, word2id 52 | 53 | # Get word vectors from vocabulary (glove, word2vec, fasttext ..) 54 | def get_wordvec(path_to_vec, word2id): 55 | word_vec = {} 56 | 57 | with io.open(path_to_vec, 'r', encoding='utf-8') as f: 58 | # if word2vec or fasttext file : skip first line "next(f)" 59 | for line in f: 60 | word, vec = line.split(' ', 1) 61 | if word in word2id: 62 | word_vec[word] = np.fromstring(vec, sep=' ') 63 | 64 | logging.info('Found {0} words with word vectors, out of \ 65 | {1} words'.format(len(word_vec), len(word2id))) 66 | return word_vec 67 | 68 | 69 | # SentEval prepare and batcher 70 | def prepare(params, samples): 71 | _, params.word2id = create_dictionary(samples) 72 | params.word_vec = get_wordvec(PATH_TO_VEC, params.word2id) 73 | params.wvec_dim = 300 74 | return 75 | 76 | def batcher(params, batch): 77 | batch = [sent if sent != [] else ['.'] for sent in batch] 78 | embeddings = [] 79 | 80 | for sent in batch: 81 | sentvec = [] 82 | for word in sent: 83 | if word in params.word_vec: 84 | sentvec.append(params.word_vec[word]) 85 | if not sentvec: 86 | vec = np.zeros(params.wvec_dim) 87 | sentvec.append(vec) 88 | sentvec = np.mean(sentvec, 0) 89 | embeddings.append(sentvec) 90 | 91 | embeddings = np.vstack(embeddings) 92 | return embeddings 93 | 94 | 95 | # Set params for SentEval 96 | params_senteval = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5} 97 | params_senteval['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128, 98 | 'tenacity': 3, 'epoch_size': 2} 99 | 100 | # Set up logger 101 | logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG) 102 | 103 | if __name__ == "__main__": 104 | se = senteval.engine.SE(params_senteval, batcher, prepare) 105 | 
transfer_tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 106 | 'MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC', 107 | 'SICKEntailment', 'SICKRelatedness', 'STSBenchmark', 108 | 'Length', 'WordContent', 'Depth', 'TopConstituents', 109 | 'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber', 110 | 'OddManOut', 'CoordinationInversion'] 111 | results = se.eval(transfer_tasks) 112 | print(results) 113 | -------------------------------------------------------------------------------- /SentEval/senteval/binary.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | ''' 9 | Binary classifier and corresponding datasets : MR, CR, SUBJ, MPQA 10 | ''' 11 | from __future__ import absolute_import, division, unicode_literals 12 | 13 | import io 14 | import os 15 | import numpy as np 16 | import logging 17 | 18 | from senteval.tools.validation import InnerKFoldClassifier 19 | 20 | 21 | class BinaryClassifierEval(object): 22 | def __init__(self, pos, neg, seed=1111): 23 | self.seed = seed 24 | self.samples, self.labels = pos + neg, [1] * len(pos) + [0] * len(neg) 25 | self.n_samples = len(self.samples) 26 | 27 | def do_prepare(self, params, prepare): 28 | # prepare is given the whole text 29 | return prepare(params, self.samples) 30 | # prepare puts everything it outputs in "params" : params.word2id etc 31 | # Those output will be further used by "batcher". 
32 | 33 | def loadFile(self, fpath): 34 | with io.open(fpath, 'r', encoding='latin-1') as f: 35 | return [line.split() for line in f.read().splitlines()] 36 | 37 | def run(self, params, batcher): 38 | enc_input = [] 39 | # Sort to reduce padding 40 | sorted_corpus = sorted(zip(self.samples, self.labels), 41 | key=lambda z: (len(z[0]), z[1])) 42 | sorted_samples = [x for (x, y) in sorted_corpus] 43 | sorted_labels = [y for (x, y) in sorted_corpus] 44 | logging.info('Generating sentence embeddings') 45 | for ii in range(0, self.n_samples, params.batch_size): 46 | batch = sorted_samples[ii:ii + params.batch_size] 47 | embeddings = batcher(params, batch) 48 | enc_input.append(embeddings) 49 | enc_input = np.vstack(enc_input) 50 | logging.info('Generated sentence embeddings') 51 | 52 | config = {'nclasses': 2, 'seed': self.seed, 53 | 'usepytorch': params.usepytorch, 54 | 'classifier': params.classifier, 55 | 'nhid': params.nhid, 'kfold': params.kfold} 56 | clf = InnerKFoldClassifier(enc_input, np.array(sorted_labels), config) 57 | devacc, testacc = clf.run() 58 | logging.debug('Dev acc : {0} Test acc : {1}\n'.format(devacc, testacc)) 59 | return {'devacc': devacc, 'acc': testacc, 'ndev': self.n_samples, 60 | 'ntest': self.n_samples} 61 | 62 | 63 | class CREval(BinaryClassifierEval): 64 | def __init__(self, task_path, seed=1111): 65 | logging.debug('***** Transfer task : CR *****\n\n') 66 | pos = self.loadFile(os.path.join(task_path, 'custrev.pos')) 67 | neg = self.loadFile(os.path.join(task_path, 'custrev.neg')) 68 | super(self.__class__, self).__init__(pos, neg, seed) 69 | 70 | 71 | class MREval(BinaryClassifierEval): 72 | def __init__(self, task_path, seed=1111): 73 | logging.debug('***** Transfer task : MR *****\n\n') 74 | pos = self.loadFile(os.path.join(task_path, 'rt-polarity.pos')) 75 | neg = self.loadFile(os.path.join(task_path, 'rt-polarity.neg')) 76 | super(self.__class__, self).__init__(pos, neg, seed) 77 | 78 | 79 | class SUBJEval(BinaryClassifierEval): 80 | 
def __init__(self, task_path, seed=1111): 81 | logging.debug('***** Transfer task : SUBJ *****\n\n') 82 | obj = self.loadFile(os.path.join(task_path, 'subj.objective')) 83 | subj = self.loadFile(os.path.join(task_path, 'subj.subjective')) 84 | super(self.__class__, self).__init__(obj, subj, seed) 85 | 86 | 87 | class MPQAEval(BinaryClassifierEval): 88 | def __init__(self, task_path, seed=1111): 89 | logging.debug('***** Transfer task : MPQA *****\n\n') 90 | pos = self.loadFile(os.path.join(task_path, 'mpqa.pos')) 91 | neg = self.loadFile(os.path.join(task_path, 'mpqa.neg')) 92 | super(self.__class__, self).__init__(pos, neg, seed) 93 | -------------------------------------------------------------------------------- /SentEval/senteval/sst.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
#

'''
SST - binary classification
'''

from __future__ import absolute_import, division, unicode_literals

import os
import io
import logging
import numpy as np

from senteval.tools.validation import SplitClassifier


class SSTEval(object):
    """SST sentiment evaluation with 2 (binary) or 5 (fine-grained) classes."""

    def __init__(self, task_path, nclasses=2, seed=1111):
        """Load the ``sentiment-train`` / ``-dev`` / ``-test`` files under ``task_path``."""
        self.seed = seed

        # binary or fine-grained
        assert nclasses in [2, 5]
        self.nclasses = nclasses
        self.task_name = 'Binary' if self.nclasses == 2 else 'Fine-Grained'
        logging.debug('***** Transfer task : SST %s classification *****\n\n', self.task_name)

        train = self.loadFile(os.path.join(task_path, 'sentiment-train'))
        dev = self.loadFile(os.path.join(task_path, 'sentiment-dev'))
        test = self.loadFile(os.path.join(task_path, 'sentiment-test'))
        self.sst_data = {'train': train, 'dev': dev, 'test': test}

    def do_prepare(self, params, prepare):
        """Call ``prepare`` on the tokenized sentences of all three splits."""
        samples = self.sst_data['train']['X'] + self.sst_data['dev']['X'] + \
            self.sst_data['test']['X']
        return prepare(params, samples)

    def loadFile(self, fpath):
        """Parse one split file into ``{'X': token lists, 'y': int labels}``.

        Binary files hold ``sentence<TAB>label`` per line; fine-grained files
        hold ``label sentence``.  Blank lines are ignored.
        """
        sst_data = {'X': [], 'y': []}
        with io.open(fpath, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # A blank (e.g. trailing) line carries no sample.
                    continue
                if self.nclasses == 2:
                    sample = line.split('\t')
                    sst_data['y'].append(int(sample[1]))
                    sst_data['X'].append(sample[0].split())
                elif self.nclasses == 5:
                    sample = line.split(' ', 1)
                    sst_data['y'].append(int(sample[0]))
                    sst_data['X'].append(sample[1].split())
        assert max(sst_data['y']) == self.nclasses - 1
        return sst_data

    def run(self, params, batcher):
        """Embed every split with ``batcher`` and train/evaluate a SplitClassifier.

        Each split in ``self.sst_data`` is re-ordered in place by sentence
        length.  Returns a dict with keys ``devacc``, ``acc``, ``ndev``, ``ntest``.
        """
        sst_embed = {'train': {}, 'dev': {}, 'test': {}}
        bsize = params.batch_size

        for key in self.sst_data:
            logging.info('Computing embedding for {0}'.format(key))
            # Sort to reduce padding
            sorted_data = sorted(zip(self.sst_data[key]['X'],
                                     self.sst_data[key]['y']),
                                 key=lambda z: (len(z[0]), z[1]))
            self.sst_data[key]['X'], self.sst_data[key]['y'] = map(list, zip(*sorted_data))

            sst_embed[key]['X'] = []
            for ii in range(0, len(self.sst_data[key]['y']), bsize):
                batch = self.sst_data[key]['X'][ii:ii + bsize]
                embeddings = batcher(params, batch)
                sst_embed[key]['X'].append(embeddings)
            sst_embed[key]['X'] = np.vstack(sst_embed[key]['X'])
            sst_embed[key]['y'] = np.array(self.sst_data[key]['y'])
            logging.info('Computed {0} embeddings'.format(key))

        config_classifier = {'nclasses': self.nclasses, 'seed': self.seed,
                             'usepytorch': params.usepytorch,
                             'classifier': params.classifier}

        clf = SplitClassifier(X={'train': sst_embed['train']['X'],
                                 'valid': sst_embed['dev']['X'],
                                 'test': sst_embed['test']['X']},
                              y={'train': sst_embed['train']['y'],
                                 'valid': sst_embed['dev']['y'],
                                 'test': sst_embed['test']['y']},
                              config=config_classifier)

        devacc, testacc = clf.run()
        # Adjacent literals instead of a backslash inside the string, which
        # used to embed the source indentation in the logged message.
        logging.debug('\nDev acc : {0} Test acc : {1} for '
                      'SST {2} classification\n'.format(devacc, testacc, self.task_name))

        return {'devacc': devacc, 'acc': testacc,
                'ndev': len(sst_embed['dev']['X']),
                'ntest': len(sst_embed['test']['X'])}

# --------------------------------------------------------------------------------
# /SentEval/senteval/mrpc.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

'''
MRPC : Microsoft Research Paraphrase (detection) Corpus
'''
from __future__ import absolute_import, division, unicode_literals

import os
import logging
import numpy as np
import io

from senteval.tools.validation import KFoldClassifier

from sklearn.metrics import f1_score


class MRPCEval(object):
    """Paraphrase detection on sentence pairs, scored with accuracy and F1."""

    def __init__(self, task_path, seed=1111):
        """Load ``msr_paraphrase_train.txt`` and ``msr_paraphrase_test.txt``."""
        logging.info('***** Transfer task : MRPC *****\n\n')
        self.seed = seed
        train = self.loadFile(os.path.join(task_path,
                                           'msr_paraphrase_train.txt'))
        test = self.loadFile(os.path.join(task_path,
                                          'msr_paraphrase_test.txt'))
        self.mrpc_data = {'train': train, 'test': test}

    def do_prepare(self, params, prepare):
        """Call ``prepare`` on every sentence (both sides) of train and test."""
        # TODO : Should we separate samples in "train, test"?
        samples = self.mrpc_data['train']['X_A'] + \
            self.mrpc_data['train']['X_B'] + \
            self.mrpc_data['test']['X_A'] + self.mrpc_data['test']['X_B']
        return prepare(params, samples)

    def loadFile(self, fpath):
        """Parse a tab-separated MRPC file.

        The first line is a header and is skipped; blank lines are ignored.
        Column 0 is the integer label, columns 3 and 4 the two sentences.
        Returns ``{'X_A': token lists, 'X_B': token lists, 'y': int labels}``.
        """
        mrpc_data = {'X_A': [], 'X_B': [], 'y': []}
        with io.open(fpath, 'r', encoding='utf-8') as f:
            # Drop the header row up front instead of parsing it and slicing
            # it off afterwards.
            next(f, None)
            for line in f:
                line = line.strip()
                if not line:
                    continue
                text = line.split('\t')
                mrpc_data['X_A'].append(text[3].split())
                mrpc_data['X_B'].append(text[4].split())
                mrpc_data['y'].append(int(text[0]))
        return mrpc_data

    def run(self, params, batcher):
        """Embed both sentences of every pair and run a k-fold classifier.

        Pair features are ``[|a - b|, a * b]``.  Returns a dict with keys
        ``devacc``, ``acc``, ``f1``, ``ndev`` and ``ntest``.
        """
        mrpc_embed = {'train': {}, 'test': {}}

        for key in self.mrpc_data:
            logging.info('Computing embedding for {0}'.format(key))
            # Sort to reduce padding
            text_data = {}
            sorted_corpus = sorted(zip(self.mrpc_data[key]['X_A'],
                                       self.mrpc_data[key]['X_B'],
                                       self.mrpc_data[key]['y']),
                                   key=lambda z: (len(z[0]), len(z[1]), z[2]))

            text_data['A'] = [x for (x, y, z) in sorted_corpus]
            text_data['B'] = [y for (x, y, z) in sorted_corpus]
            text_data['y'] = [z for (x, y, z) in sorted_corpus]

            for txt_type in ['A', 'B']:
                mrpc_embed[key][txt_type] = []
                for ii in range(0, len(text_data['y']), params.batch_size):
                    batch = text_data[txt_type][ii:ii + params.batch_size]
                    embeddings = batcher(params, batch)
                    mrpc_embed[key][txt_type].append(embeddings)
                mrpc_embed[key][txt_type] = np.vstack(mrpc_embed[key][txt_type])
            mrpc_embed[key]['y'] = np.array(text_data['y'])
            logging.info('Computed {0} embeddings'.format(key))

        # Train
        trainA = mrpc_embed['train']['A']
        trainB = mrpc_embed['train']['B']
        trainF = np.c_[np.abs(trainA - trainB), trainA * trainB]
        trainY = mrpc_embed['train']['y']

        # Test
        testA = mrpc_embed['test']['A']
        testB = mrpc_embed['test']['B']
        testF = np.c_[np.abs(testA - testB), testA * testB]
        testY = mrpc_embed['test']['y']

        config = {'nclasses': 2, 'seed': self.seed,
                  'usepytorch': params.usepytorch,
                  'classifier': params.classifier,
                  'nhid': params.nhid, 'kfold': params.kfold}
        clf = KFoldClassifier(train={'X': trainF, 'y': trainY},
                              test={'X': testF, 'y': testY}, config=config)

        devacc, testacc, yhat = clf.run()
        testf1 = round(100*f1_score(testY, yhat), 2)
        logging.debug('Dev acc : {0} Test acc {1}; Test F1 {2} for MRPC.\n'
                      .format(devacc, testacc, testf1))
        return {'devacc': devacc, 'acc': testacc, 'f1': testf1,
                'ndev': len(trainA), 'ntest': len(testA)}

# --------------------------------------------------------------------------------
# /SentEval/senteval/snli.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

'''
SNLI - Entailment
'''
from __future__ import absolute_import, division, unicode_literals

import codecs
import os
import io
import copy
import logging
import numpy as np

from senteval.tools.validation import SplitClassifier


class SNLIEval(object):
    """Three-way entailment classification on sentence pairs."""

    def __init__(self, taskpath, seed=1111):
        """Load ``s1.*``, ``s2.*`` and ``labels.*`` for train, dev and test."""
        logging.debug('***** Transfer task : SNLI Entailment*****\n\n')
        self.seed = seed

        self.samples = []
        self.data = {}
        # (key used in self.data, file suffix on disk)
        for split, suffix in (('train', 'train'), ('valid', 'dev'),
                              ('test', 'test')):
            sents1 = self.loadFile(os.path.join(taskpath, 's1.' + suffix))
            sents2 = self.loadFile(os.path.join(taskpath, 's2.' + suffix))
            labels = self._loadLabels(os.path.join(taskpath,
                                                   'labels.' + suffix))

            # sort data (by s2 first) to reduce padding
            sorted_split = sorted(zip(sents2, sents1, labels),
                                  key=lambda z: (len(z[0]), len(z[1]), z[2]))
            sents2, sents1, labels = map(list, zip(*sorted_split))

            self.samples += sents1 + sents2
            self.data[split] = (sents1, sents2, labels)

    def do_prepare(self, params, prepare):
        """Call ``prepare`` on every sentence of every split."""
        return prepare(params, self.samples)

    def loadFile(self, fpath):
        """Read a latin-1 file; return one whitespace-split token list per line."""
        with codecs.open(fpath, 'rb', 'latin-1') as f:
            return [line.split() for line in
                    f.read().splitlines()]

    def _loadLabels(self, fpath):
        """Return the lines of a utf-8 label file, closing the handle afterwards."""
        with io.open(fpath, encoding='utf-8') as f:
            return f.read().splitlines()

    def run(self, params, batcher):
        """Embed all pairs and train/evaluate a SplitClassifier.

        Pair features are ``[u, v, u * v, |u - v|]``.  Fills ``self.X`` and
        ``self.y`` and returns a dict with keys ``devacc``, ``acc``, ``ndev``
        and ``ntest``.
        """
        self.X, self.y = {}, {}
        dico_label = {'entailment': 0, 'neutral': 1, 'contradiction': 2}
        for key in self.data:
            if key not in self.X:
                self.X[key] = []
            if key not in self.y:
                self.y[key] = []

            input1, input2, mylabels = self.data[key]
            enc_input = []
            n_labels = len(mylabels)
            for ii in range(0, n_labels, params.batch_size):
                batch1 = input1[ii:ii + params.batch_size]
                batch2 = input2[ii:ii + params.batch_size]

                if len(batch1) == len(batch2) and len(batch1) > 0:
                    enc1 = batcher(params, batch1)
                    enc2 = batcher(params, batch2)
                    enc_input.append(np.hstack((enc1, enc2, enc1 * enc2,
                                                np.abs(enc1 - enc2))))
                # Same condition as the former
                # (ii*batch_size) % (20000*batch_size) == 0.
                if ii % 20000 == 0:
                    logging.info("PROGRESS (encoding): %.2f%%" %
                                 (100 * ii / n_labels))
            self.X[key] = np.vstack(enc_input)
            self.y[key] = [dico_label[y] for y in mylabels]

        config = {'nclasses': 3, 'seed': self.seed,
                  'usepytorch': params.usepytorch,
                  'cudaEfficient': True,
                  'nhid': params.nhid, 'noreg': True}

        # Copy so the shared params.classifier dict is not modified.
        config_classifier = copy.deepcopy(params.classifier)
        config_classifier['max_epoch'] = 15
        config_classifier['epoch_size'] = 1
        config['classifier'] = config_classifier

        clf = SplitClassifier(self.X, self.y, config)
        devacc, testacc = clf.run()
        logging.debug('Dev acc : {0} Test acc : {1} for SNLI\n'
                      .format(devacc, testacc))
        return {'devacc': devacc, 'acc': testacc,
                'ndev': len(self.data['valid'][0]),
                'ntest': len(self.data['test'][0])}

# --------------------------------------------------------------------------------
/SentEval/senteval/rank.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | ''' 9 | Image-Caption Retrieval with COCO dataset 10 | ''' 11 | from __future__ import absolute_import, division, unicode_literals 12 | 13 | import os 14 | import sys 15 | import logging 16 | import numpy as np 17 | 18 | try: 19 | import cPickle as pickle 20 | except ImportError: 21 | import pickle 22 | 23 | from senteval.tools.ranking import ImageSentenceRankingPytorch 24 | 25 | 26 | class ImageCaptionRetrievalEval(object): 27 | def __init__(self, task_path, seed=1111): 28 | logging.debug('***** Transfer task: Image Caption Retrieval *****\n\n') 29 | 30 | # Get captions and image features 31 | self.seed = seed 32 | train, dev, test = self.loadFile(task_path) 33 | self.coco_data = {'train': train, 'dev': dev, 'test': test} 34 | 35 | def do_prepare(self, params, prepare): 36 | samples = self.coco_data['train']['sent'] + \ 37 | self.coco_data['dev']['sent'] + \ 38 | self.coco_data['test']['sent'] 39 | prepare(params, samples) 40 | 41 | def loadFile(self, fpath): 42 | coco = {} 43 | 44 | for split in ['train', 'valid', 'test']: 45 | list_sent = [] 46 | list_img_feat = [] 47 | if sys.version_info < (3, 0): 48 | with open(os.path.join(fpath, split + '.pkl')) as f: 49 | cocodata = pickle.load(f) 50 | else: 51 | with open(os.path.join(fpath, split + '.pkl'), 'rb') as f: 52 | cocodata = pickle.load(f, encoding='latin1') 53 | 54 | for imgkey in range(len(cocodata['features'])): 55 | assert len(cocodata['image_to_caption_ids'][imgkey]) >= 5, \ 56 | cocodata['image_to_caption_ids'][imgkey] 57 | for captkey in cocodata['image_to_caption_ids'][imgkey][0:5]: 58 | sent = cocodata['captions'][captkey]['cleaned_caption'] 59 | sent += ' .' 
# add punctuation to end of sentence in COCO 60 | list_sent.append(sent.encode('utf-8').split()) 61 | list_img_feat.append(cocodata['features'][imgkey]) 62 | assert len(list_sent) == len(list_img_feat) and \ 63 | len(list_sent) % 5 == 0 64 | list_img_feat = np.array(list_img_feat).astype('float32') 65 | coco[split] = {'sent': list_sent, 'imgfeat': list_img_feat} 66 | return coco['train'], coco['valid'], coco['test'] 67 | 68 | def run(self, params, batcher): 69 | coco_embed = {'train': {'sentfeat': [], 'imgfeat': []}, 70 | 'dev': {'sentfeat': [], 'imgfeat': []}, 71 | 'test': {'sentfeat': [], 'imgfeat': []}} 72 | 73 | for key in self.coco_data: 74 | logging.info('Computing embedding for {0}'.format(key)) 75 | # Sort to reduce padding 76 | self.coco_data[key]['sent'] = np.array(self.coco_data[key]['sent']) 77 | self.coco_data[key]['sent'], idx_sort = np.sort(self.coco_data[key]['sent']), np.argsort(self.coco_data[key]['sent']) 78 | idx_unsort = np.argsort(idx_sort) 79 | 80 | coco_embed[key]['X'] = [] 81 | nsent = len(self.coco_data[key]['sent']) 82 | for ii in range(0, nsent, params.batch_size): 83 | batch = self.coco_data[key]['sent'][ii:ii + params.batch_size] 84 | embeddings = batcher(params, batch) 85 | coco_embed[key]['sentfeat'].append(embeddings) 86 | coco_embed[key]['sentfeat'] = np.vstack(coco_embed[key]['sentfeat'])[idx_unsort] 87 | coco_embed[key]['imgfeat'] = np.array(self.coco_data[key]['imgfeat']) 88 | logging.info('Computed {0} embeddings'.format(key)) 89 | 90 | config = {'seed': self.seed, 'projdim': 1000, 'margin': 0.2} 91 | clf = ImageSentenceRankingPytorch(train=coco_embed['train'], 92 | valid=coco_embed['dev'], 93 | test=coco_embed['test'], 94 | config=config) 95 | 96 | bestdevscore, r1_i2t, r5_i2t, r10_i2t, medr_i2t, \ 97 | r1_t2i, r5_t2i, r10_t2i, medr_t2i = clf.run() 98 | 99 | logging.debug("\nTest scores | Image to text: \ 100 | {0}, {1}, {2}, {3}".format(r1_i2t, r5_i2t, r10_i2t, medr_i2t)) 101 | logging.debug("Test scores | Text to image: \ 
102 | {0}, {1}, {2}, {3}\n".format(r1_t2i, r5_t2i, r10_t2i, medr_t2i)) 103 | 104 | return {'devacc': bestdevscore, 105 | 'acc': [(r1_i2t, r5_i2t, r10_i2t, medr_i2t), 106 | (r1_t2i, r5_t2i, r10_t2i, medr_t2i)], 107 | 'ndev': len(coco_embed['dev']['sentfeat']), 108 | 'ntest': len(coco_embed['test']['sentfeat'])} 109 | -------------------------------------------------------------------------------- /SentEval/senteval/tools/relatedness.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | """ 9 | Semantic Relatedness (supervised) with Pytorch 10 | """ 11 | from __future__ import absolute_import, division, unicode_literals 12 | 13 | import copy 14 | import numpy as np 15 | 16 | import torch 17 | from torch import nn 18 | import torch.optim as optim 19 | 20 | from scipy.stats import pearsonr, spearmanr 21 | 22 | 23 | class RelatednessPytorch(object): 24 | # Can be used for SICK-Relatedness, and STS14 25 | def __init__(self, train, valid, test, devscores, config): 26 | # fix seed 27 | np.random.seed(config['seed']) 28 | torch.manual_seed(config['seed']) 29 | assert torch.cuda.is_available(), 'torch.cuda required for Relatedness' 30 | torch.cuda.manual_seed(config['seed']) 31 | 32 | self.train = train 33 | self.valid = valid 34 | self.test = test 35 | self.devscores = devscores 36 | 37 | self.inputdim = train['X'].shape[1] 38 | self.nclasses = config['nclasses'] 39 | self.seed = config['seed'] 40 | self.l2reg = 0. 
41 | self.batch_size = 64 42 | self.maxepoch = 1000 43 | self.early_stop = True 44 | 45 | self.model = nn.Sequential( 46 | nn.Linear(self.inputdim, self.nclasses), 47 | nn.Softmax(dim=-1), 48 | ) 49 | self.loss_fn = nn.MSELoss() 50 | 51 | if torch.cuda.is_available(): 52 | self.model = self.model.cuda() 53 | self.loss_fn = self.loss_fn.cuda() 54 | 55 | self.loss_fn.size_average = False 56 | self.optimizer = optim.Adam(self.model.parameters(), 57 | weight_decay=self.l2reg) 58 | 59 | def prepare_data(self, trainX, trainy, devX, devy, testX, testy): 60 | # Transform probs to log-probs for KL-divergence 61 | trainX = torch.from_numpy(trainX).float().cuda() 62 | trainy = torch.from_numpy(trainy).float().cuda() 63 | devX = torch.from_numpy(devX).float().cuda() 64 | devy = torch.from_numpy(devy).float().cuda() 65 | testX = torch.from_numpy(testX).float().cuda() 66 | testY = torch.from_numpy(testy).float().cuda() 67 | 68 | return trainX, trainy, devX, devy, testX, testy 69 | 70 | def run(self): 71 | self.nepoch = 0 72 | bestpr = -1 73 | early_stop_count = 0 74 | r = np.arange(1, 6) 75 | stop_train = False 76 | 77 | # Preparing data 78 | trainX, trainy, devX, devy, testX, testy = self.prepare_data( 79 | self.train['X'], self.train['y'], 80 | self.valid['X'], self.valid['y'], 81 | self.test['X'], self.test['y']) 82 | 83 | # Training 84 | while not stop_train and self.nepoch <= self.maxepoch: 85 | self.trainepoch(trainX, trainy, nepoches=50) 86 | yhat = np.dot(self.predict_proba(devX), r) 87 | pr = spearmanr(yhat, self.devscores)[0] 88 | pr = 0 if pr != pr else pr # if NaN bc std=0 89 | # early stop on Pearson 90 | if pr > bestpr: 91 | bestpr = pr 92 | bestmodel = copy.deepcopy(self.model) 93 | elif self.early_stop: 94 | if early_stop_count >= 3: 95 | stop_train = True 96 | early_stop_count += 1 97 | self.model = bestmodel 98 | 99 | yhat = np.dot(self.predict_proba(testX), r) 100 | 101 | return bestpr, yhat 102 | 103 | def trainepoch(self, X, y, nepoches=1): 104 | 
self.model.train() 105 | for _ in range(self.nepoch, self.nepoch + nepoches): 106 | permutation = np.random.permutation(len(X)) 107 | all_costs = [] 108 | for i in range(0, len(X), self.batch_size): 109 | # forward 110 | idx = torch.from_numpy(permutation[i:i + self.batch_size]).long().cuda() 111 | Xbatch = X[idx] 112 | ybatch = y[idx] 113 | output = self.model(Xbatch) 114 | # loss 115 | loss = self.loss_fn(output, ybatch) 116 | all_costs.append(loss.item()) 117 | # backward 118 | self.optimizer.zero_grad() 119 | loss.backward() 120 | # Update parameters 121 | self.optimizer.step() 122 | self.nepoch += nepoches 123 | 124 | def predict_proba(self, devX): 125 | self.model.eval() 126 | probas = [] 127 | with torch.no_grad(): 128 | for i in range(0, len(devX), self.batch_size): 129 | Xbatch = devX[i:i + self.batch_size] 130 | if len(probas) == 0: 131 | probas = self.model(Xbatch).data.cpu().numpy() 132 | else: 133 | probas = np.concatenate((probas, self.model(Xbatch).data.cpu().numpy()), axis=0) 134 | return probas 135 | -------------------------------------------------------------------------------- /SentEval/senteval/engine.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | ''' 9 | 10 | Generic sentence evaluation scripts wrapper 11 | 12 | ''' 13 | from __future__ import absolute_import, division, unicode_literals 14 | 15 | from senteval import utils 16 | from senteval.binary import CREval, MREval, MPQAEval, SUBJEval 17 | from senteval.snli import SNLIEval 18 | from senteval.trec import TRECEval 19 | from senteval.sick import SICKEntailmentEval, SICKEval 20 | from senteval.mrpc import MRPCEval 21 | from senteval.sts import STS12Eval, STS13Eval, STS14Eval, STS15Eval, STS16Eval, STSBenchmarkEval, SICKRelatednessEval, STSBenchmarkFinetune 22 | from senteval.sst import SSTEval 23 | from senteval.rank import ImageCaptionRetrievalEval 24 | from senteval.probing import * 25 | 26 | class SE(object): 27 | def __init__(self, params, batcher, prepare=None): 28 | # parameters 29 | params = utils.dotdict(params) 30 | params.usepytorch = True if 'usepytorch' not in params else params.usepytorch 31 | params.seed = 1111 if 'seed' not in params else params.seed 32 | 33 | params.batch_size = 128 if 'batch_size' not in params else params.batch_size 34 | params.nhid = 0 if 'nhid' not in params else params.nhid 35 | params.kfold = 5 if 'kfold' not in params else params.kfold 36 | 37 | if 'classifier' not in params or not params['classifier']: 38 | params.classifier = {'nhid': 0} 39 | 40 | assert 'nhid' in params.classifier, 'Set number of hidden units in classifier config!!' 
41 | 42 | self.params = params 43 | 44 | # batcher and prepare 45 | self.batcher = batcher 46 | self.prepare = prepare if prepare else lambda x, y: None 47 | 48 | self.list_tasks = ['CR', 'MR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC', 49 | 'SICKRelatedness', 'SICKEntailment', 'STSBenchmark', 50 | 'SNLI', 'ImageCaptionRetrieval', 'STS12', 'STS13', 51 | 'STS14', 'STS15', 'STS16', 52 | 'Length', 'WordContent', 'Depth', 'TopConstituents', 53 | 'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber', 54 | 'OddManOut', 'CoordinationInversion', 'SICKRelatedness-finetune', 'STSBenchmark-finetune', 'STSBenchmark-fix'] 55 | 56 | def eval(self, name): 57 | # evaluate on evaluation [name], either takes string or list of strings 58 | if (isinstance(name, list)): 59 | self.results = {x: self.eval(x) for x in name} 60 | return self.results 61 | 62 | tpath = self.params.task_path 63 | assert name in self.list_tasks, str(name) + ' not in ' + str(self.list_tasks) 64 | 65 | # Original SentEval tasks 66 | if name == 'CR': 67 | self.evaluation = CREval(tpath + '/downstream/CR', seed=self.params.seed) 68 | elif name == 'MR': 69 | self.evaluation = MREval(tpath + '/downstream/MR', seed=self.params.seed) 70 | elif name == 'MPQA': 71 | self.evaluation = MPQAEval(tpath + '/downstream/MPQA', seed=self.params.seed) 72 | elif name == 'SUBJ': 73 | self.evaluation = SUBJEval(tpath + '/downstream/SUBJ', seed=self.params.seed) 74 | elif name == 'SST2': 75 | self.evaluation = SSTEval(tpath + '/downstream/SST/binary', nclasses=2, seed=self.params.seed) 76 | elif name == 'SST5': 77 | self.evaluation = SSTEval(tpath + '/downstream/SST/fine', nclasses=5, seed=self.params.seed) 78 | elif name == 'TREC': 79 | self.evaluation = TRECEval(tpath + '/downstream/TREC', seed=self.params.seed) 80 | elif name == 'MRPC': 81 | self.evaluation = MRPCEval(tpath + '/downstream/MRPC', seed=self.params.seed) 82 | elif name == 'SICKRelatedness': 83 | self.evaluation = SICKRelatednessEval(tpath + '/downstream/SICK', 
seed=self.params.seed) 84 | elif name == 'STSBenchmark': 85 | self.evaluation = STSBenchmarkEval(tpath + '/downstream/STS/STSBenchmark', seed=self.params.seed) 86 | elif name == 'STSBenchmark-fix': 87 | self.evaluation = STSBenchmarkEval(tpath + '/downstream/STS/STSBenchmark-fix', seed=self.params.seed) 88 | elif name == 'STSBenchmark-finetune': 89 | self.evaluation = STSBenchmarkFinetune(tpath + '/downstream/STS/STSBenchmark', seed=self.params.seed) 90 | elif name == 'SICKRelatedness-finetune': 91 | self.evaluation = SICKEval(tpath + '/downstream/SICK', seed=self.params.seed) 92 | elif name == 'SICKEntailment': 93 | self.evaluation = SICKEntailmentEval(tpath + '/downstream/SICK', seed=self.params.seed) 94 | elif name == 'SNLI': 95 | self.evaluation = SNLIEval(tpath + '/downstream/SNLI', seed=self.params.seed) 96 | elif name in ['STS12', 'STS13', 'STS14', 'STS15', 'STS16']: 97 | fpath = name + '-en-test' 98 | self.evaluation = eval(name + 'Eval')(tpath + '/downstream/STS/' + fpath, seed=self.params.seed) 99 | elif name == 'ImageCaptionRetrieval': 100 | self.evaluation = ImageCaptionRetrievalEval(tpath + '/downstream/COCO', seed=self.params.seed) 101 | 102 | # Probing Tasks 103 | elif name == 'Length': 104 | self.evaluation = LengthEval(tpath + '/probing', seed=self.params.seed) 105 | elif name == 'WordContent': 106 | self.evaluation = WordContentEval(tpath + '/probing', seed=self.params.seed) 107 | elif name == 'Depth': 108 | self.evaluation = DepthEval(tpath + '/probing', seed=self.params.seed) 109 | elif name == 'TopConstituents': 110 | self.evaluation = TopConstituentsEval(tpath + '/probing', seed=self.params.seed) 111 | elif name == 'BigramShift': 112 | self.evaluation = BigramShiftEval(tpath + '/probing', seed=self.params.seed) 113 | elif name == 'Tense': 114 | self.evaluation = TenseEval(tpath + '/probing', seed=self.params.seed) 115 | elif name == 'SubjNumber': 116 | self.evaluation = SubjNumberEval(tpath + '/probing', seed=self.params.seed) 117 | elif name 
== 'ObjNumber': 118 | self.evaluation = ObjNumberEval(tpath + '/probing', seed=self.params.seed) 119 | elif name == 'OddManOut': 120 | self.evaluation = OddManOutEval(tpath + '/probing', seed=self.params.seed) 121 | elif name == 'CoordinationInversion': 122 | self.evaluation = CoordinationInversionEval(tpath + '/probing', seed=self.params.seed) 123 | 124 | self.params.current_task = name 125 | self.evaluation.do_prepare(self.params, self.prepare) 126 | 127 | self.results = self.evaluation.run(self.params, self.batcher) 128 | 129 | return self.results 130 | -------------------------------------------------------------------------------- /SentEval/senteval/probing.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | ''' 9 | probing tasks 10 | ''' 11 | 12 | from __future__ import absolute_import, division, unicode_literals 13 | 14 | import os 15 | import io 16 | import copy 17 | import logging 18 | import numpy as np 19 | 20 | from senteval.tools.validation import SplitClassifier 21 | 22 | 23 | class PROBINGEval(object): 24 | def __init__(self, task, task_path, seed=1111): 25 | self.seed = seed 26 | self.task = task 27 | logging.debug('***** (Probing) Transfer task : %s classification *****', self.task.upper()) 28 | self.task_data = {'train': {'X': [], 'y': []}, 29 | 'dev': {'X': [], 'y': []}, 30 | 'test': {'X': [], 'y': []}} 31 | self.loadFile(task_path) 32 | logging.info('Loaded %s train - %s dev - %s test for %s' % 33 | (len(self.task_data['train']['y']), len(self.task_data['dev']['y']), 34 | len(self.task_data['test']['y']), self.task)) 35 | 36 | def do_prepare(self, params, prepare): 37 | samples = self.task_data['train']['X'] + self.task_data['dev']['X'] + \ 38 | self.task_data['test']['X'] 39 | return 
prepare(params, samples) 40 | 41 | def loadFile(self, fpath): 42 | self.tok2split = {'tr': 'train', 'va': 'dev', 'te': 'test'} 43 | with io.open(fpath, 'r', encoding='utf-8') as f: 44 | for line in f: 45 | line = line.rstrip().split('\t') 46 | self.task_data[self.tok2split[line[0]]]['X'].append(line[-1].split()) 47 | self.task_data[self.tok2split[line[0]]]['y'].append(line[1]) 48 | 49 | labels = sorted(np.unique(self.task_data['train']['y'])) 50 | self.tok2label = dict(zip(labels, range(len(labels)))) 51 | self.nclasses = len(self.tok2label) 52 | 53 | for split in self.task_data: 54 | for i, y in enumerate(self.task_data[split]['y']): 55 | self.task_data[split]['y'][i] = self.tok2label[y] 56 | 57 | def run(self, params, batcher): 58 | task_embed = {'train': {}, 'dev': {}, 'test': {}} 59 | bsize = params.batch_size 60 | logging.info('Computing embeddings for train/dev/test') 61 | for key in self.task_data: 62 | # Sort to reduce padding 63 | sorted_data = sorted(zip(self.task_data[key]['X'], 64 | self.task_data[key]['y']), 65 | key=lambda z: (len(z[0]), z[1])) 66 | self.task_data[key]['X'], self.task_data[key]['y'] = map(list, zip(*sorted_data)) 67 | 68 | task_embed[key]['X'] = [] 69 | for ii in range(0, len(self.task_data[key]['y']), bsize): 70 | batch = self.task_data[key]['X'][ii:ii + bsize] 71 | embeddings = batcher(params, batch) 72 | task_embed[key]['X'].append(embeddings) 73 | task_embed[key]['X'] = np.vstack(task_embed[key]['X']) 74 | task_embed[key]['y'] = np.array(self.task_data[key]['y']) 75 | logging.info('Computed embeddings') 76 | 77 | config_classifier = {'nclasses': self.nclasses, 'seed': self.seed, 78 | 'usepytorch': params.usepytorch, 79 | 'classifier': params.classifier} 80 | 81 | if self.task == "WordContent" and params.classifier['nhid'] > 0: 82 | config_classifier = copy.deepcopy(config_classifier) 83 | config_classifier['classifier']['nhid'] = 0 84 | print(params.classifier['nhid']) 85 | 86 | clf = SplitClassifier(X={'train': 
task_embed['train']['X'], 87 | 'valid': task_embed['dev']['X'], 88 | 'test': task_embed['test']['X']}, 89 | y={'train': task_embed['train']['y'], 90 | 'valid': task_embed['dev']['y'], 91 | 'test': task_embed['test']['y']}, 92 | config=config_classifier) 93 | 94 | devacc, testacc = clf.run() 95 | logging.debug('\nDev acc : %.1f Test acc : %.1f for %s classification\n' % (devacc, testacc, self.task.upper())) 96 | 97 | return {'devacc': devacc, 'acc': testacc, 98 | 'ndev': len(task_embed['dev']['X']), 99 | 'ntest': len(task_embed['test']['X'])} 100 | 101 | """ 102 | Surface Information 103 | """ 104 | class LengthEval(PROBINGEval): 105 | def __init__(self, task_path, seed=1111): 106 | task_path = os.path.join(task_path, 'sentence_length.txt') 107 | # labels: bins 108 | PROBINGEval.__init__(self, 'Length', task_path, seed) 109 | 110 | class WordContentEval(PROBINGEval): 111 | def __init__(self, task_path, seed=1111): 112 | task_path = os.path.join(task_path, 'word_content.txt') 113 | # labels: 200 target words 114 | PROBINGEval.__init__(self, 'WordContent', task_path, seed) 115 | 116 | """ 117 | Latent Structural Information 118 | """ 119 | class DepthEval(PROBINGEval): 120 | def __init__(self, task_path, seed=1111): 121 | task_path = os.path.join(task_path, 'tree_depth.txt') 122 | # labels: bins 123 | PROBINGEval.__init__(self, 'Depth', task_path, seed) 124 | 125 | class TopConstituentsEval(PROBINGEval): 126 | def __init__(self, task_path, seed=1111): 127 | task_path = os.path.join(task_path, 'top_constituents.txt') 128 | # labels: 'PP_NP_VP_.' .. (20 classes) 129 | PROBINGEval.__init__(self, 'TopConstituents', task_path, seed) 130 | 131 | class BigramShiftEval(PROBINGEval): 132 | def __init__(self, task_path, seed=1111): 133 | task_path = os.path.join(task_path, 'bigram_shift.txt') 134 | # labels: 0 or 1 135 | PROBINGEval.__init__(self, 'BigramShift', task_path, seed) 136 | 137 | # TODO: Voice? 
138 | 139 | """ 140 | Latent Semantic Information 141 | """ 142 | 143 | class TenseEval(PROBINGEval): 144 | def __init__(self, task_path, seed=1111): 145 | task_path = os.path.join(task_path, 'past_present.txt') 146 | # labels: 'PRES', 'PAST' 147 | PROBINGEval.__init__(self, 'Tense', task_path, seed) 148 | 149 | class SubjNumberEval(PROBINGEval): 150 | def __init__(self, task_path, seed=1111): 151 | task_path = os.path.join(task_path, 'subj_number.txt') 152 | # labels: 'NN', 'NNS' 153 | PROBINGEval.__init__(self, 'SubjNumber', task_path, seed) 154 | 155 | class ObjNumberEval(PROBINGEval): 156 | def __init__(self, task_path, seed=1111): 157 | task_path = os.path.join(task_path, 'obj_number.txt') 158 | # labels: 'NN', 'NNS' 159 | PROBINGEval.__init__(self, 'ObjNumber', task_path, seed) 160 | 161 | class OddManOutEval(PROBINGEval): 162 | def __init__(self, task_path, seed=1111): 163 | task_path = os.path.join(task_path, 'odd_man_out.txt') 164 | # labels: 'O', 'C' 165 | PROBINGEval.__init__(self, 'OddManOut', task_path, seed) 166 | 167 | class CoordinationInversionEval(PROBINGEval): 168 | def __init__(self, task_path, seed=1111): 169 | task_path = os.path.join(task_path, 'coordination_inversion.txt') 170 | # labels: 'O', 'I' 171 | PROBINGEval.__init__(self, 'CoordinationInversion', task_path, seed) 172 | -------------------------------------------------------------------------------- /SentEval/senteval/tools/classifier.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | """ 9 | Pytorch Classifier class in the style of scikit-learn 10 | Classifiers include Logistic Regression and MLP 11 | """ 12 | 13 | from __future__ import absolute_import, division, unicode_literals 14 | 15 | import numpy as np 16 | import copy 17 | from senteval import utils 18 | 19 | import torch 20 | from torch import nn 21 | import torch.nn.functional as F 22 | 23 | 24 | class PyTorchClassifier(object): 25 | def __init__(self, inputdim, nclasses, l2reg=0., batch_size=64, seed=1111, 26 | cudaEfficient=False): 27 | # fix seed 28 | np.random.seed(seed) 29 | torch.manual_seed(seed) 30 | torch.cuda.manual_seed(seed) 31 | 32 | self.inputdim = inputdim 33 | self.nclasses = nclasses 34 | self.l2reg = l2reg 35 | self.batch_size = batch_size 36 | self.cudaEfficient = cudaEfficient 37 | 38 | def prepare_split(self, X, y, validation_data=None, validation_split=None): 39 | # Preparing validation data 40 | assert validation_split or validation_data 41 | if validation_data is not None: 42 | trainX, trainy = X, y 43 | devX, devy = validation_data 44 | else: 45 | permutation = np.random.permutation(len(X)) 46 | trainidx = permutation[int(validation_split * len(X)):] 47 | devidx = permutation[0:int(validation_split * len(X))] 48 | trainX, trainy = X[trainidx], y[trainidx] 49 | devX, devy = X[devidx], y[devidx] 50 | 51 | device = torch.device('cpu') if self.cudaEfficient else torch.device('cuda') 52 | 53 | trainX = torch.from_numpy(trainX).to(device, dtype=torch.float32) 54 | trainy = torch.from_numpy(trainy).to(device, dtype=torch.int64) 55 | devX = torch.from_numpy(devX).to(device, dtype=torch.float32) 56 | devy = torch.from_numpy(devy).to(device, dtype=torch.int64) 57 | 58 | return trainX, trainy, devX, devy 59 | 60 | def fit(self, X, y, validation_data=None, validation_split=None, 61 | early_stop=True): 62 | self.nepoch = 0 63 | bestaccuracy = -1 64 | stop_train = False 65 | early_stop_count = 0 66 | 67 | # Preparing validation data 68 | trainX, trainy, devX, 
devy = self.prepare_split(X, y, validation_data, 69 | validation_split) 70 | 71 | # Training 72 | while not stop_train and self.nepoch <= self.max_epoch: 73 | self.trainepoch(trainX, trainy, epoch_size=self.epoch_size) 74 | accuracy = self.score(devX, devy) 75 | if accuracy > bestaccuracy: 76 | bestaccuracy = accuracy 77 | bestmodel = copy.deepcopy(self.model) 78 | elif early_stop: 79 | if early_stop_count >= self.tenacity: 80 | stop_train = True 81 | early_stop_count += 1 82 | self.model = bestmodel 83 | return bestaccuracy 84 | 85 | def trainepoch(self, X, y, epoch_size=1): 86 | self.model.train() 87 | for _ in range(self.nepoch, self.nepoch + epoch_size): 88 | permutation = np.random.permutation(len(X)) 89 | all_costs = [] 90 | for i in range(0, len(X), self.batch_size): 91 | # forward 92 | idx = torch.from_numpy(permutation[i:i + self.batch_size]).long().to(X.device) 93 | 94 | Xbatch = X[idx] 95 | ybatch = y[idx] 96 | 97 | if self.cudaEfficient: 98 | Xbatch = Xbatch.cuda() 99 | ybatch = ybatch.cuda() 100 | output = self.model(Xbatch) 101 | # loss 102 | loss = self.loss_fn(output, ybatch) 103 | all_costs.append(loss.data.item()) 104 | # backward 105 | self.optimizer.zero_grad() 106 | loss.backward() 107 | # Update parameters 108 | self.optimizer.step() 109 | self.nepoch += epoch_size 110 | 111 | def score(self, devX, devy): 112 | self.model.eval() 113 | correct = 0 114 | if not isinstance(devX, torch.cuda.FloatTensor) or self.cudaEfficient: 115 | devX = torch.FloatTensor(devX).cuda() 116 | devy = torch.LongTensor(devy).cuda() 117 | with torch.no_grad(): 118 | for i in range(0, len(devX), self.batch_size): 119 | Xbatch = devX[i:i + self.batch_size] 120 | ybatch = devy[i:i + self.batch_size] 121 | if self.cudaEfficient: 122 | Xbatch = Xbatch.cuda() 123 | ybatch = ybatch.cuda() 124 | output = self.model(Xbatch) 125 | pred = output.data.max(1)[1] 126 | correct += pred.long().eq(ybatch.data.long()).sum().item() 127 | accuracy = 1.0 * correct / len(devX) 128 | return 
accuracy 129 | 130 | def predict(self, devX): 131 | self.model.eval() 132 | if not isinstance(devX, torch.cuda.FloatTensor): 133 | devX = torch.FloatTensor(devX).cuda() 134 | yhat = np.array([]) 135 | with torch.no_grad(): 136 | for i in range(0, len(devX), self.batch_size): 137 | Xbatch = devX[i:i + self.batch_size] 138 | output = self.model(Xbatch) 139 | yhat = np.append(yhat, 140 | output.data.max(1)[1].cpu().numpy()) 141 | yhat = np.vstack(yhat) 142 | return yhat 143 | 144 | def predict_proba(self, devX): 145 | self.model.eval() 146 | probas = [] 147 | with torch.no_grad(): 148 | for i in range(0, len(devX), self.batch_size): 149 | Xbatch = devX[i:i + self.batch_size] 150 | vals = F.softmax(self.model(Xbatch).data.cpu().numpy()) 151 | if not probas: 152 | probas = vals 153 | else: 154 | probas = np.concatenate(probas, vals, axis=0) 155 | return probas 156 | 157 | 158 | """ 159 | MLP with Pytorch (nhid=0 --> Logistic Regression) 160 | """ 161 | 162 | class MLP(PyTorchClassifier): 163 | def __init__(self, params, inputdim, nclasses, l2reg=0., batch_size=64, 164 | seed=1111, cudaEfficient=False): 165 | super(self.__class__, self).__init__(inputdim, nclasses, l2reg, 166 | batch_size, seed, cudaEfficient) 167 | """ 168 | PARAMETERS: 169 | -nhid: number of hidden units (0: Logistic Regression) 170 | -optim: optimizer ("sgd,lr=0.1", "adam", "rmsprop" ..) 
171 | -tenacity: how many times dev acc does not increase before stopping 172 | -epoch_size: each epoch corresponds to epoch_size pass on the train set 173 | -max_epoch: max number of epoches 174 | -dropout: dropout for MLP 175 | """ 176 | 177 | self.nhid = 0 if "nhid" not in params else params["nhid"] 178 | self.optim = "adam" if "optim" not in params else params["optim"] 179 | self.tenacity = 5 if "tenacity" not in params else params["tenacity"] 180 | self.epoch_size = 4 if "epoch_size" not in params else params["epoch_size"] 181 | self.max_epoch = 200 if "max_epoch" not in params else params["max_epoch"] 182 | self.dropout = 0. if "dropout" not in params else params["dropout"] 183 | self.batch_size = 64 if "batch_size" not in params else params["batch_size"] 184 | 185 | if params["nhid"] == 0: 186 | self.model = nn.Sequential( 187 | nn.Linear(self.inputdim, self.nclasses), 188 | ).cuda() 189 | else: 190 | self.model = nn.Sequential( 191 | nn.Linear(self.inputdim, params["nhid"]), 192 | nn.Dropout(p=self.dropout), 193 | nn.Sigmoid(), 194 | nn.Linear(params["nhid"], self.nclasses), 195 | ).cuda() 196 | 197 | self.loss_fn = nn.CrossEntropyLoss().cuda() 198 | self.loss_fn.size_average = False 199 | 200 | optim_fn, optim_params = utils.get_optimizer(self.optim) 201 | self.optimizer = optim_fn(self.model.parameters(), **optim_params) 202 | self.optimizer.param_groups[0]['weight_decay'] = self.l2reg 203 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import io, os 3 | import numpy as np 4 | import logging 5 | import argparse 6 | from prettytable import PrettyTable 7 | import torch 8 | import transformers 9 | from transformers import AutoModel, AutoTokenizer 10 | 11 | # Set up logger 12 | logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG) 13 | 14 | # Set PATHs 15 | PATH_TO_SENTEVAL = './SentEval' 
16 | PATH_TO_DATA = './SentEval/data' 17 | 18 | # Import SentEval 19 | sys.path.insert(0, PATH_TO_SENTEVAL) 20 | import senteval 21 | 22 | 23 | def print_table(task_names, scores): 24 | tb = PrettyTable() 25 | tb.field_names = task_names 26 | tb.add_row(scores) 27 | print(tb) 28 | 29 | 30 | def main(): 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument("--model_name_or_path", type=str, 33 | help="Transformers' model name or path") 34 | parser.add_argument("--pooler", type=str, 35 | choices=['cls', 'cls_before_pooler', 'avg', 'avg_top2', 'avg_first_last'], 36 | default='cls', 37 | help="Which pooler to use") 38 | parser.add_argument("--mode", type=str, 39 | choices=['dev', 'test', 'fasttest'], 40 | default='test', 41 | help="What evaluation mode to use (dev: fast mode, dev results; test: full mode, test results); fasttest: fast mode, test results") 42 | parser.add_argument("--task_set", type=str, 43 | choices=['sts', 'transfer', 'full', 'na'], 44 | default='sts', 45 | help="What set of tasks to evaluate on. If not 'na', this will override '--tasks'") 46 | parser.add_argument("--tasks", type=str, nargs='+', 47 | default=['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 48 | 'MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'TREC', 'MRPC', 49 | 'SICKRelatedness', 'STSBenchmark'], 50 | help="Tasks to evaluate on. 
If '--task_set' is specified, this will be overridden") 51 | 52 | args = parser.parse_args() 53 | 54 | # Load transformers' model checkpoint 55 | model = AutoModel.from_pretrained(args.model_name_or_path) 56 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) 57 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 58 | model = model.to(device) 59 | 60 | # Set up the tasks 61 | if args.task_set == 'sts': 62 | args.tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark', 'SICKRelatedness'] 63 | elif args.task_set == 'transfer': 64 | args.tasks = ['MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'TREC', 'MRPC'] 65 | elif args.task_set == 'full': 66 | args.tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark', 'SICKRelatedness'] 67 | args.tasks += ['MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'TREC', 'MRPC'] 68 | 69 | # Set params for SentEval 70 | if args.mode == 'dev' or args.mode == 'fasttest': 71 | # Fast mode 72 | params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5} 73 | params['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128, 74 | 'tenacity': 3, 'epoch_size': 2} 75 | elif args.mode == 'test': 76 | # Full mode 77 | params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 10} 78 | params['classifier'] = {'nhid': 0, 'optim': 'adam', 'batch_size': 64, 79 | 'tenacity': 5, 'epoch_size': 4} 80 | else: 81 | raise NotImplementedError 82 | 83 | # SentEval prepare and batcher 84 | def prepare(params, samples): 85 | return 86 | 87 | def batcher(params, batch, max_length=None): 88 | # Handle rare token encoding issues in the dataset 89 | if len(batch) >= 1 and len(batch[0]) >= 1 and isinstance(batch[0][0], bytes): 90 | batch = [[word.decode('utf-8') for word in s] for s in batch] 91 | 92 | sentences = [' '.join(s) for s in batch] 93 | 94 | # Tokenization 95 | if max_length is not None: 96 | batch = tokenizer.batch_encode_plus( 97 | sentences, 98 | return_tensors='pt', 99 | padding=True, 100 
| max_length=max_length, 101 | truncation=True 102 | ) 103 | else: 104 | batch = tokenizer.batch_encode_plus( 105 | sentences, 106 | return_tensors='pt', 107 | padding=True, 108 | ) 109 | 110 | # Move to the correct device 111 | for k in batch: 112 | batch[k] = batch[k].to(device) 113 | 114 | # Get raw embeddings 115 | with torch.no_grad(): 116 | outputs = model(**batch, output_hidden_states=True, return_dict=True) 117 | last_hidden = outputs.last_hidden_state 118 | pooler_output = outputs.pooler_output 119 | hidden_states = outputs.hidden_states 120 | 121 | # Apply different poolers 122 | if args.pooler == 'cls': 123 | # There is a linear+activation layer after CLS representation 124 | return pooler_output.cpu() 125 | elif args.pooler == 'cls_before_pooler': 126 | return last_hidden[:, 0].cpu() 127 | elif args.pooler == "avg": 128 | return ((last_hidden * batch['attention_mask'].unsqueeze(-1)).sum(1) / batch['attention_mask'].sum( 129 | -1).unsqueeze(-1)).cpu() 130 | elif args.pooler == "avg_first_last": 131 | first_hidden = hidden_states[1] 132 | last_hidden = hidden_states[-1] 133 | pooled_result = ((first_hidden + last_hidden) / 2.0 * batch['attention_mask'].unsqueeze(-1)).sum(1) / batch[ 134 | 'attention_mask'].sum(-1).unsqueeze(-1) 135 | return pooled_result.cpu() 136 | elif args.pooler == "avg_top2": 137 | second_last_hidden = hidden_states[-2] 138 | last_hidden = hidden_states[-1] 139 | pooled_result = ((last_hidden + second_last_hidden) / 2.0 * batch['attention_mask'].unsqueeze(-1)).sum(1) / \ 140 | batch['attention_mask'].sum(-1).unsqueeze(-1) 141 | return pooled_result.cpu() 142 | else: 143 | raise NotImplementedError 144 | 145 | results = {} 146 | 147 | for task in args.tasks: 148 | se = senteval.engine.SE(params, batcher, prepare) 149 | result = se.eval(task) 150 | results[task] = result 151 | 152 | # Print evaluation results 153 | if args.mode == 'dev': 154 | print("------ %s ------" % (args.mode)) 155 | 156 | task_names = [] 157 | scores = [] 158 | 
for task in ['STSBenchmark', 'SICKRelatedness']: 159 | task_names.append(task) 160 | if task in results: 161 | scores.append("%.2f" % (results[task]['dev']['spearman'][0] * 100)) 162 | else: 163 | scores.append("0.00") 164 | print_table(task_names, scores) 165 | 166 | task_names = [] 167 | scores = [] 168 | for task in ['MR', 'CR', 'SUBJ', 'MPQA', 'SST2', 'TREC', 'MRPC']: 169 | task_names.append(task) 170 | if task in results: 171 | scores.append("%.2f" % (results[task]['devacc'])) 172 | else: 173 | scores.append("0.00") 174 | task_names.append("Avg.") 175 | scores.append("%.2f" % (sum([float(score) for score in scores]) / len(scores))) 176 | print_table(task_names, scores) 177 | 178 | elif args.mode == 'test' or args.mode == 'fasttest': 179 | print("------ %s ------" % (args.mode)) 180 | 181 | task_names = [] 182 | scores = [] 183 | for task in ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark', 'SICKRelatedness']: 184 | task_names.append(task) 185 | if task in results: 186 | if task in ['STS12', 'STS13', 'STS14', 'STS15', 'STS16']: 187 | scores.append("%.2f" % (results[task]['all']['spearman']['all'] * 100)) 188 | else: 189 | scores.append("%.2f" % (results[task]['test']['spearman'].correlation * 100)) 190 | else: 191 | scores.append("0.00") 192 | task_names.append("Avg.") 193 | scores.append("%.2f" % (sum([float(score) for score in scores]) / len(scores))) 194 | print_table(task_names, scores) 195 | 196 | task_names = [] 197 | scores = [] 198 | for task in ['MR', 'CR', 'SUBJ', 'MPQA', 'SST2', 'TREC', 'MRPC']: 199 | task_names.append(task) 200 | if task in results: 201 | scores.append("%.2f" % (results[task]['acc'])) 202 | else: 203 | scores.append("0.00") 204 | task_names.append("Avg.") 205 | scores.append("%.2f" % (sum([float(score) for score in scores]) / len(scores))) 206 | print_table(task_names, scores) 207 | 208 | 209 | if __name__ == "__main__": 210 | main() -------------------------------------------------------------------------------- 
/SentEval/senteval/sick.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | ''' 9 | SICK Relatedness and Entailment 10 | ''' 11 | from __future__ import absolute_import, division, unicode_literals 12 | 13 | import os 14 | import io 15 | import logging 16 | import numpy as np 17 | 18 | from sklearn.metrics import mean_squared_error 19 | from scipy.stats import pearsonr, spearmanr 20 | 21 | from senteval.tools.relatedness import RelatednessPytorch 22 | from senteval.tools.validation import SplitClassifier 23 | 24 | class SICKEval(object): 25 | def __init__(self, task_path, seed=1111): 26 | logging.debug('***** Transfer task : SICK-Relatedness*****\n\n') 27 | self.seed = seed 28 | train = self.loadFile(os.path.join(task_path, 'SICK_train.txt')) 29 | dev = self.loadFile(os.path.join(task_path, 'SICK_trial.txt')) 30 | test = self.loadFile(os.path.join(task_path, 'SICK_test_annotated.txt')) 31 | self.sick_data = {'train': train, 'dev': dev, 'test': test} 32 | 33 | def do_prepare(self, params, prepare): 34 | samples = self.sick_data['train']['X_A'] + \ 35 | self.sick_data['train']['X_B'] + \ 36 | self.sick_data['dev']['X_A'] + \ 37 | self.sick_data['dev']['X_B'] + \ 38 | self.sick_data['test']['X_A'] + self.sick_data['test']['X_B'] 39 | return prepare(params, samples) 40 | 41 | def loadFile(self, fpath): 42 | skipFirstLine = True 43 | sick_data = {'X_A': [], 'X_B': [], 'y': []} 44 | with io.open(fpath, 'r', encoding='utf-8') as f: 45 | for line in f: 46 | if skipFirstLine: 47 | skipFirstLine = False 48 | else: 49 | text = line.strip().split('\t') 50 | sick_data['X_A'].append(text[1].split()) 51 | sick_data['X_B'].append(text[2].split()) 52 | sick_data['y'].append(text[3]) 53 | 54 | sick_data['y'] = [float(s) for s in 
sick_data['y']] 55 | return sick_data 56 | 57 | def run(self, params, batcher): 58 | sick_embed = {'train': {}, 'dev': {}, 'test': {}} 59 | bsize = params.batch_size 60 | 61 | for key in self.sick_data: 62 | logging.info('Computing embedding for {0}'.format(key)) 63 | # Sort to reduce padding 64 | sorted_corpus = sorted(zip(self.sick_data[key]['X_A'], 65 | self.sick_data[key]['X_B'], 66 | self.sick_data[key]['y']), 67 | key=lambda z: (len(z[0]), len(z[1]), z[2])) 68 | 69 | self.sick_data[key]['X_A'] = [x for (x, y, z) in sorted_corpus] 70 | self.sick_data[key]['X_B'] = [y for (x, y, z) in sorted_corpus] 71 | self.sick_data[key]['y'] = [z for (x, y, z) in sorted_corpus] 72 | 73 | for txt_type in ['X_A', 'X_B']: 74 | sick_embed[key][txt_type] = [] 75 | for ii in range(0, len(self.sick_data[key]['y']), bsize): 76 | batch = self.sick_data[key][txt_type][ii:ii + bsize] 77 | embeddings = batcher(params, batch) 78 | sick_embed[key][txt_type].append(embeddings) 79 | sick_embed[key][txt_type] = np.vstack(sick_embed[key][txt_type]) 80 | sick_embed[key]['y'] = np.array(self.sick_data[key]['y']) 81 | logging.info('Computed {0} embeddings'.format(key)) 82 | 83 | # Train 84 | trainA = sick_embed['train']['X_A'] 85 | trainB = sick_embed['train']['X_B'] 86 | trainF = np.c_[np.abs(trainA - trainB), trainA * trainB] 87 | trainY = self.encode_labels(self.sick_data['train']['y']) 88 | 89 | # Dev 90 | devA = sick_embed['dev']['X_A'] 91 | devB = sick_embed['dev']['X_B'] 92 | devF = np.c_[np.abs(devA - devB), devA * devB] 93 | devY = self.encode_labels(self.sick_data['dev']['y']) 94 | 95 | # Test 96 | testA = sick_embed['test']['X_A'] 97 | testB = sick_embed['test']['X_B'] 98 | testF = np.c_[np.abs(testA - testB), testA * testB] 99 | testY = self.encode_labels(self.sick_data['test']['y']) 100 | 101 | config = {'seed': self.seed, 'nclasses': 5} 102 | clf = RelatednessPytorch(train={'X': trainF, 'y': trainY}, 103 | valid={'X': devF, 'y': devY}, 104 | test={'X': testF, 'y': testY}, 105 | 
devscores=self.sick_data['dev']['y'], 106 | config=config) 107 | 108 | devspr, yhat = clf.run() 109 | 110 | pr = pearsonr(yhat, self.sick_data['test']['y'])[0] 111 | sr = spearmanr(yhat, self.sick_data['test']['y'])[0] 112 | pr = 0 if pr != pr else pr 113 | sr = 0 if sr != sr else sr 114 | se = mean_squared_error(yhat, self.sick_data['test']['y']) 115 | logging.debug('Dev : Spearman {0}'.format(devspr)) 116 | logging.debug('Test : Pearson {0} Spearman {1} MSE {2} \ 117 | for SICK Relatedness\n'.format(pr, sr, se)) 118 | 119 | return {'devspearman': devspr, 'pearson': pr, 'spearman': sr, 'mse': se, 120 | 'yhat': yhat, 'ndev': len(devA), 'ntest': len(testA)} 121 | 122 | def encode_labels(self, labels, nclass=5): 123 | """ 124 | Label encoding from Tree LSTM paper (Tai, Socher, Manning) 125 | """ 126 | Y = np.zeros((len(labels), nclass)).astype('float32') 127 | for j, y in enumerate(labels): 128 | for i in range(nclass): 129 | if i+1 == np.floor(y) + 1: 130 | Y[j, i] = y - np.floor(y) 131 | if i+1 == np.floor(y): 132 | Y[j, i] = np.floor(y) - y + 1 133 | return Y 134 | 135 | 136 | class SICKEntailmentEval(SICKEval): 137 | def __init__(self, task_path, seed=1111): 138 | logging.debug('***** Transfer task : SICK-Entailment*****\n\n') 139 | self.seed = seed 140 | train = self.loadFile(os.path.join(task_path, 'SICK_train.txt')) 141 | dev = self.loadFile(os.path.join(task_path, 'SICK_trial.txt')) 142 | test = self.loadFile(os.path.join(task_path, 'SICK_test_annotated.txt')) 143 | self.sick_data = {'train': train, 'dev': dev, 'test': test} 144 | 145 | def loadFile(self, fpath): 146 | label2id = {'CONTRADICTION': 0, 'NEUTRAL': 1, 'ENTAILMENT': 2} 147 | skipFirstLine = True 148 | sick_data = {'X_A': [], 'X_B': [], 'y': []} 149 | with io.open(fpath, 'r', encoding='utf-8') as f: 150 | for line in f: 151 | if skipFirstLine: 152 | skipFirstLine = False 153 | else: 154 | text = line.strip().split('\t') 155 | sick_data['X_A'].append(text[1].split()) 156 | 
sick_data['X_B'].append(text[2].split()) 157 | sick_data['y'].append(text[4]) 158 | sick_data['y'] = [label2id[s] for s in sick_data['y']] 159 | return sick_data 160 | 161 | def run(self, params, batcher): 162 | sick_embed = {'train': {}, 'dev': {}, 'test': {}} 163 | bsize = params.batch_size 164 | for key in self.sick_data: 165 | logging.info('Computing embedding for {0}'.format(key)) 166 | # Sort to reduce padding 167 | sorted_corpus = sorted(zip(self.sick_data[key]['X_A'], 168 | self.sick_data[key]['X_B'], 169 | self.sick_data[key]['y']), 170 | key=lambda z: (len(z[0]), len(z[1]), z[2])) 171 | 172 | self.sick_data[key]['X_A'] = [x for (x, y, z) in sorted_corpus] 173 | self.sick_data[key]['X_B'] = [y for (x, y, z) in sorted_corpus] 174 | self.sick_data[key]['y'] = [z for (x, y, z) in sorted_corpus] 175 | 176 | for txt_type in ['X_A', 'X_B']: 177 | sick_embed[key][txt_type] = [] 178 | for ii in range(0, len(self.sick_data[key]['y']), bsize): 179 | batch = self.sick_data[key][txt_type][ii:ii + bsize] 180 | embeddings = batcher(params, batch) 181 | sick_embed[key][txt_type].append(embeddings) 182 | sick_embed[key][txt_type] = np.vstack(sick_embed[key][txt_type]) 183 | logging.info('Computed {0} embeddings'.format(key)) 184 | 185 | # Train 186 | trainA = sick_embed['train']['X_A'] 187 | trainB = sick_embed['train']['X_B'] 188 | trainF = np.c_[np.abs(trainA - trainB), trainA * trainB] 189 | trainY = np.array(self.sick_data['train']['y']) 190 | 191 | # Dev 192 | devA = sick_embed['dev']['X_A'] 193 | devB = sick_embed['dev']['X_B'] 194 | devF = np.c_[np.abs(devA - devB), devA * devB] 195 | devY = np.array(self.sick_data['dev']['y']) 196 | 197 | # Test 198 | testA = sick_embed['test']['X_A'] 199 | testB = sick_embed['test']['X_B'] 200 | testF = np.c_[np.abs(testA - testB), testA * testB] 201 | testY = np.array(self.sick_data['test']['y']) 202 | 203 | config = {'nclasses': 3, 'seed': self.seed, 204 | 'usepytorch': params.usepytorch, 205 | 'classifier': params.classifier, 
206 | 'nhid': params.nhid} 207 | clf = SplitClassifier(X={'train': trainF, 'valid': devF, 'test': testF}, 208 | y={'train': trainY, 'valid': devY, 'test': testY}, 209 | config=config) 210 | 211 | devacc, testacc = clf.run() 212 | logging.debug('\nDev acc : {0} Test acc : {1} for \ 213 | SICK entailment\n'.format(devacc, testacc)) 214 | return {'devacc': devacc, 'acc': testacc, 215 | 'ndev': len(devA), 'ntest': len(testA)} 216 | -------------------------------------------------------------------------------- /SentEval/senteval/sts.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | ''' 9 | STS-{2012,2013,2014,2015,2016} (unsupervised) and 10 | STS-benchmark (supervised) tasks 11 | ''' 12 | 13 | from __future__ import absolute_import, division, unicode_literals 14 | 15 | import os 16 | import io 17 | import numpy as np 18 | import logging 19 | 20 | from scipy.stats import spearmanr, pearsonr 21 | 22 | from senteval.utils import cosine 23 | from senteval.sick import SICKEval 24 | 25 | 26 | class STSEval(object): 27 | def loadFile(self, fpath): 28 | self.data = {} 29 | self.samples = [] 30 | 31 | for dataset in self.datasets: 32 | sent1, sent2 = zip(*[l.split("\t") for l in 33 | io.open(fpath + '/STS.input.%s.txt' % dataset, 34 | encoding='utf8').read().splitlines()]) 35 | raw_scores = np.array([x for x in 36 | io.open(fpath + '/STS.gs.%s.txt' % dataset, 37 | encoding='utf8') 38 | .read().splitlines()]) 39 | not_empty_idx = raw_scores != '' 40 | 41 | gs_scores = [float(x) for x in raw_scores[not_empty_idx]] 42 | sent1 = np.array([s.split() for s in sent1])[not_empty_idx] 43 | sent2 = np.array([s.split() for s in sent2])[not_empty_idx] 44 | # sort data by length to minimize padding in batcher 45 | sorted_data = 
sorted(zip(sent1, sent2, gs_scores), 46 | key=lambda z: (len(z[0]), len(z[1]), z[2])) 47 | sent1, sent2, gs_scores = map(list, zip(*sorted_data)) 48 | 49 | self.data[dataset] = (sent1, sent2, gs_scores) 50 | self.samples += sent1 + sent2 51 | 52 | def do_prepare(self, params, prepare): 53 | if 'similarity' in params: 54 | self.similarity = params.similarity 55 | else: # Default similarity is cosine 56 | self.similarity = lambda s1, s2: np.nan_to_num(cosine(np.nan_to_num(s1), np.nan_to_num(s2))) 57 | return prepare(params, self.samples) 58 | 59 | def run(self, params, batcher): 60 | results = {} 61 | all_sys_scores = [] 62 | all_gs_scores = [] 63 | for dataset in self.datasets: 64 | sys_scores = [] 65 | input1, input2, gs_scores = self.data[dataset] 66 | for ii in range(0, len(gs_scores), params.batch_size): 67 | batch1 = input1[ii:ii + params.batch_size] 68 | batch2 = input2[ii:ii + params.batch_size] 69 | 70 | # we assume get_batch already throws out the faulty ones 71 | if len(batch1) == len(batch2) and len(batch1) > 0: 72 | enc1 = batcher(params, batch1) 73 | enc2 = batcher(params, batch2) 74 | 75 | for kk in range(enc2.shape[0]): 76 | sys_score = self.similarity(enc1[kk], enc2[kk]) 77 | sys_scores.append(sys_score) 78 | all_sys_scores.extend(sys_scores) 79 | all_gs_scores.extend(gs_scores) 80 | results[dataset] = {'pearson': pearsonr(sys_scores, gs_scores), 81 | 'spearman': spearmanr(sys_scores, gs_scores), 82 | 'nsamples': len(sys_scores)} 83 | logging.debug('%s : pearson = %.4f, spearman = %.4f' % 84 | (dataset, results[dataset]['pearson'][0], 85 | results[dataset]['spearman'][0])) 86 | 87 | weights = [results[dset]['nsamples'] for dset in results.keys()] 88 | list_prs = np.array([results[dset]['pearson'][0] for 89 | dset in results.keys()]) 90 | list_spr = np.array([results[dset]['spearman'][0] for 91 | dset in results.keys()]) 92 | 93 | avg_pearson = np.average(list_prs) 94 | avg_spearman = np.average(list_spr) 95 | wavg_pearson = np.average(list_prs, 
weights=weights) 96 | wavg_spearman = np.average(list_spr, weights=weights) 97 | all_pearson = pearsonr(all_sys_scores, all_gs_scores) 98 | all_spearman = spearmanr(all_sys_scores, all_gs_scores) 99 | results['all'] = {'pearson': {'all': all_pearson[0], 100 | 'mean': avg_pearson, 101 | 'wmean': wavg_pearson}, 102 | 'spearman': {'all': all_spearman[0], 103 | 'mean': avg_spearman, 104 | 'wmean': wavg_spearman}} 105 | logging.debug('ALL : Pearson = %.4f, \ 106 | Spearman = %.4f' % (all_pearson[0], all_spearman[0])) 107 | logging.debug('ALL (weighted average) : Pearson = %.4f, \ 108 | Spearman = %.4f' % (wavg_pearson, wavg_spearman)) 109 | logging.debug('ALL (average) : Pearson = %.4f, \ 110 | Spearman = %.4f\n' % (avg_pearson, avg_spearman)) 111 | 112 | return results 113 | 114 | 115 | class STS12Eval(STSEval): 116 | def __init__(self, taskpath, seed=1111): 117 | logging.debug('***** Transfer task : STS12 *****\n\n') 118 | self.seed = seed 119 | self.datasets = ['MSRpar', 'MSRvid', 'SMTeuroparl', 120 | 'surprise.OnWN', 'surprise.SMTnews'] 121 | self.loadFile(taskpath) 122 | 123 | 124 | class STS13Eval(STSEval): 125 | # STS13 here does not contain the "SMT" subtask due to LICENSE issue 126 | def __init__(self, taskpath, seed=1111): 127 | logging.debug('***** Transfer task : STS13 (-SMT) *****\n\n') 128 | self.seed = seed 129 | self.datasets = ['FNWN', 'headlines', 'OnWN'] 130 | self.loadFile(taskpath) 131 | 132 | 133 | class STS14Eval(STSEval): 134 | def __init__(self, taskpath, seed=1111): 135 | logging.debug('***** Transfer task : STS14 *****\n\n') 136 | self.seed = seed 137 | self.datasets = ['deft-forum', 'deft-news', 'headlines', 138 | 'images', 'OnWN', 'tweet-news'] 139 | self.loadFile(taskpath) 140 | 141 | 142 | class STS15Eval(STSEval): 143 | def __init__(self, taskpath, seed=1111): 144 | logging.debug('***** Transfer task : STS15 *****\n\n') 145 | self.seed = seed 146 | self.datasets = ['answers-forums', 'answers-students', 147 | 'belief', 'headlines', 
'images'] 148 | self.loadFile(taskpath) 149 | 150 | 151 | class STS16Eval(STSEval): 152 | def __init__(self, taskpath, seed=1111): 153 | logging.debug('***** Transfer task : STS16 *****\n\n') 154 | self.seed = seed 155 | self.datasets = ['answer-answer', 'headlines', 'plagiarism', 156 | 'postediting', 'question-question'] 157 | self.loadFile(taskpath) 158 | 159 | 160 | class STSBenchmarkEval(STSEval): 161 | def __init__(self, task_path, seed=1111): 162 | logging.debug('\n\n***** Transfer task : STSBenchmark*****\n\n') 163 | self.seed = seed 164 | self.samples = [] 165 | train = self.loadFile(os.path.join(task_path, 'sts-train.csv')) 166 | dev = self.loadFile(os.path.join(task_path, 'sts-dev.csv')) 167 | test = self.loadFile(os.path.join(task_path, 'sts-test.csv')) 168 | self.datasets = ['train', 'dev', 'test'] 169 | self.data = {'train': train, 'dev': dev, 'test': test} 170 | 171 | def loadFile(self, fpath): 172 | sick_data = {'X_A': [], 'X_B': [], 'y': []} 173 | with io.open(fpath, 'r', encoding='utf-8') as f: 174 | for line in f: 175 | text = line.strip().split('\t') 176 | sick_data['X_A'].append(text[5].split()) 177 | sick_data['X_B'].append(text[6].split()) 178 | sick_data['y'].append(text[4]) 179 | 180 | sick_data['y'] = [float(s) for s in sick_data['y']] 181 | self.samples += sick_data['X_A'] + sick_data["X_B"] 182 | return (sick_data['X_A'], sick_data["X_B"], sick_data['y']) 183 | 184 | class STSBenchmarkFinetune(SICKEval): 185 | def __init__(self, task_path, seed=1111): 186 | logging.debug('\n\n***** Transfer task : STSBenchmark*****\n\n') 187 | self.seed = seed 188 | train = self.loadFile(os.path.join(task_path, 'sts-train.csv')) 189 | dev = self.loadFile(os.path.join(task_path, 'sts-dev.csv')) 190 | test = self.loadFile(os.path.join(task_path, 'sts-test.csv')) 191 | self.sick_data = {'train': train, 'dev': dev, 'test': test} 192 | 193 | def loadFile(self, fpath): 194 | sick_data = {'X_A': [], 'X_B': [], 'y': []} 195 | with io.open(fpath, 'r', 
encoding='utf-8') as f: 196 | for line in f: 197 | text = line.strip().split('\t') 198 | sick_data['X_A'].append(text[5].split()) 199 | sick_data['X_B'].append(text[6].split()) 200 | sick_data['y'].append(text[4]) 201 | 202 | sick_data['y'] = [float(s) for s in sick_data['y']] 203 | return sick_data 204 | 205 | class SICKRelatednessEval(STSEval): 206 | def __init__(self, task_path, seed=1111): 207 | logging.debug('\n\n***** Transfer task : SICKRelatedness*****\n\n') 208 | self.seed = seed 209 | self.samples = [] 210 | train = self.loadFile(os.path.join(task_path, 'SICK_train.txt')) 211 | dev = self.loadFile(os.path.join(task_path, 'SICK_trial.txt')) 212 | test = self.loadFile(os.path.join(task_path, 'SICK_test_annotated.txt')) 213 | self.datasets = ['train', 'dev', 'test'] 214 | self.data = {'train': train, 'dev': dev, 'test': test} 215 | 216 | def loadFile(self, fpath): 217 | skipFirstLine = True 218 | sick_data = {'X_A': [], 'X_B': [], 'y': []} 219 | with io.open(fpath, 'r', encoding='utf-8') as f: 220 | for line in f: 221 | if skipFirstLine: 222 | skipFirstLine = False 223 | else: 224 | text = line.strip().split('\t') 225 | sick_data['X_A'].append(text[1].split()) 226 | sick_data['X_B'].append(text[2].split()) 227 | sick_data['y'].append(text[3]) 228 | 229 | sick_data['y'] = [float(s) for s in sick_data['y']] 230 | self.samples += sick_data['X_A'] + sick_data["X_B"] 231 | return (sick_data['X_A'], sick_data["X_B"], sick_data['y']) 232 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Contrastive Learning of Sentence Embeddings from Scratch 2 | This is the official repo for the [paper](https://arxiv.org/abs/2305.15077) 3 | 4 | ``` 5 | Contrastive Learning of Sentence embeddings from scratch 6 | Junlei Zhang, Zhenzhong Lan, Junxian He 7 | Preprint 2023 8 | ``` 9 | 10 | We propose SynCSE, an unsupervised sentence embedding learning 
approach that trains sentence embeddings from scratch, without any (unlabeled) data samples. Specifically, we use ChatGPT to synthesize the positive and hard negative samples (SynCSE-partial) given unlabeled sentences, or synthesize the unlabeled sentences, positive, and hard negative samples altogether (SynCSE-scratch). We release the synthetic SynCSE-partial and SynCSE-scratch datasets along with the model checkpoints. 11 | 12 | ## Updates 13 | 14 | * [2023-06-02]: We released our model checkpoints and datasets 15 | * [2023-05-23]: We released [our paper](https://arxiv.org/abs/2305.15077). Check it out! 16 | 17 | 18 | ## Quick Links 19 | 20 | - [Model Checkpoints](#model-checkpoints) 21 | - [Datasets](#datasets) 22 | - [Train SynCSE](#train-SynCSE) 23 | - [Requirements](#requirements) 24 | - [Training](#training) 25 | - [Evaluation](#evaluation) 26 | - [Acknowledgement](#acknowledgement) 27 | - [Citation](#citation) 28 | 29 | ## Model Checkpoints 30 | 31 | We release our model checkpoints in huggingface as listed below: 32 | | Model | Avg. STS | 33 | |:-------------------------------|:--------:| 34 | | [sjtu-lit/SynCSE-partial-RoBERTa-base](https://huggingface.co/sjtu-lit/SynCSE-partial-RoBERTa-base) | 81.84 | 35 | | [sjtu-lit/SynCSE-partial-RoBERTa-large](https://huggingface.co/sjtu-lit/SynCSE-partial-RoBERTa-large) | 82.66 | 36 | | [sjtu-lit/SynCSE-scratch-RoBERTa-base](https://huggingface.co/sjtu-lit/SynCSE-scratch-RoBERTa-base) | 80.66 | 37 | | [sjtu-lit/SynCSE-partial-RoBERTa-large](https://huggingface.co/sjtu-lit/SynCSE-partial-RoBERTa-large) |81.84| 38 | 39 | The results slightly differ from what we report in the paper, because we clean the dataset to remove failure generations such as: "I can not generate a paraphrased sentence because the input is ambiguous." 
40 | 41 | ### Load and Use the checkpoints 42 | #### encoding sentences into embeddings 43 | ```python 44 | from transformers import AutoTokenizer, AutoModel 45 | tokenizer = AutoTokenizer.from_pretrained("sjtu-lit/SynCSE-partial-RoBERTa-large") 46 | model = AutoModel.from_pretrained("sjtu-lit/SynCSE-partial-RoBERTa-large") 47 | embeddings = model.encode("A woman is reading.") 48 | ``` 49 | 50 | #### Compute the cosine similarities between two groups of sentences 51 | ```python 52 | sentences_a = ['A woman is reading.', 'A man is playing a guitar.'] 53 | sentences_b = ['He plays guitar.', 'A woman is making a photo.'] 54 | similarities = model.similarity(sentences_a, sentences_b) 55 | ``` 56 | 57 | #### Build index for a group of sentences and search among them 58 | ```python 59 | sentences = ['A woman is reading.', 'A man is playing a guitar.'] 60 | model.build_index(sentences) 61 | results = model.search("He plays guitar.") 62 | ``` 63 | If you encounter any problem when directly loading the models by HuggingFace's API, you can also download the models manually from the above table and use `model = AutoModel.from_pretrained({PATH TO THE DOWNLOAD MODEL})`. 64 | 65 | 66 | ## Datasets 67 | | Dataset | 68 | |:-------------------------------| 69 | | [sjtu-lit/SynCSE-partial-NLI](https://huggingface.co/datasets/sjtu-lit/SynCSE-partial-NLI) | 70 | | [sjtu-lit/SynCSE-scratch-NLI](https://huggingface.co/datasets/sjtu-lit/SynCSE-scratch-NLI) | 71 | 72 | These two synthetic datasets are respectively used for the SynCSE-partial and SynCSE-scratch experimental setups. For SynCSE-partial, we use the unlabeled data from the NLI dataset used by SimCSE and generate labels for them. For SynCSE-scratch, we generate unlabeled data and their corresponding labels. 
73 | 74 | To download the data, taking SynCSE-partial as an example: 75 | ``` 76 | wget https://huggingface.co/datasets/sjtu-lit/SynCSE-partial-NLI/resolve/main/train.csv 77 | ``` 78 | 79 | ## Train SynCSE 80 | 81 | ### Requirements 82 | 83 | First, install PyTorch by following the instructions from [the official website](https://pytorch.org). We use the `1.13.0+cu116` PyTorch version. We train our model on a single A100-80G card. 84 | 85 | Then run the following script to install the remaining dependencies, 86 | 87 | ```bash 88 | pip install -r requirements.txt 89 | ``` 90 | 91 | ### Training 92 | 93 | #### Data 94 | You can specify `sjtu-lit/SynCSE-partial-NLI` or `sjtu-lit/SynCSE-scratch-NLI` in `scripts/sup_train_mp.sh`. It will download the dataset automatically. You can also download the SynCSE-partial-NLI and the SynCSE-scratch-NLI [datasets](#datasets), and put them into the data folder. 95 | 96 | #### Training scripts 97 | 98 | We provide an example script for training SynCSE in `scripts/sup_train_mp.sh`. Below are explanations of some arguments: 99 | * `--model_name_or_path`: Pre-trained checkpoints to start with. For now we support BERT-based models (`bert-base-uncased`, `bert-large-uncased`, etc.) and RoBERTa-based models (`RoBERTa-base`, `RoBERTa-large`, etc.). 100 | * `--temp`: Temperature for the contrastive loss. 101 | * `--pooler_type`: Pooling method. It takes the same values as the `--pooler` option in the [evaluation part](#evaluation). 102 | * `--hard_negative_weight`: If using hard negatives (i.e., there are 3 columns in the training file), this is the logarithm of the weight. For example, if the weight is 1, then this argument should be set as 0 (default value). 103 | * `--do_mlm`: Whether to use the MLM auxiliary objective. If True: 104 | * `--mlm_weight`: Weight for the MLM objective. 105 | * `--mlm_probability`: Masking rate for the MLM objective. 
106 | 107 | All the other arguments are standard Huggingface's `transformers` training arguments. Some often-used arguments are: `--output_dir`, `--learning_rate`, `--per_device_train_batch_size`. 108 | 109 | For results in the paper, we use Nvidia A100 (80G) GPUs with CUDA 11.6. Using different types of devices or different versions of CUDA/other software may lead to slightly different performance. 110 | 111 | #### Hyperparameters 112 | 113 | We use the following hyperparameters for training SynCSE: 114 | - Batch size: 512 115 | - Learning rate (base): 5e-5 116 | - Learning rate (large): 1e-5 117 | 118 | #### Convert models 119 | 120 | Our saved checkpoints are slightly different from Huggingface's pre-trained checkpoints. Run `python simcse_to_huggingface.py --path {PATH_TO_CHECKPOINT_FOLDER}` to convert them. 121 | 122 | ### Evaluation 123 | Our evaluation code for sentence embeddings is based on a modified version of [SentEval](https://github.com/facebookresearch/SentEval). It evaluates sentence embeddings on semantic textual similarity (STS) tasks and downstream transfer tasks. For STS tasks, our evaluation takes the "all" setting, and reports Spearman's correlation. 124 | 125 | Before evaluation, please download the evaluation datasets by running 126 | ```bash 127 | cd SentEval/data/downstream/ 128 | bash download_dataset.sh 129 | ``` 130 | Then come back to the root directory; you can evaluate any `transformers`-based pre-trained models using our evaluation code. For example, 131 | ``` 132 | bash ./scripts/eval.sh 133 | ``` 134 | which is expected to output the results in a tabular format: 135 | ``` 136 | ------ test ------ 137 | +-------+-------+-------+-------+-------+--------------+-----------------+-------+ 138 | | STS12 | STS13 | STS14 | STS15 | STS16 | STSBenchmark | SICKRelatedness | Avg. 
| 139 | +-------+-------+-------+-------+-------+--------------+-----------------+-------+ 140 | | 76.14 | 84.41 | 79.23 | 84.85 | 82.87 | 83.95 | 81.41 | 81.84 | 141 | +-------+-------+-------+-------+-------+--------------+-----------------+-------+ 142 | ``` 143 | 144 | Arguments for the evaluation script are as follows, 145 | 146 | * `--model_name_or_path`: The name or path of a `transformers`-based pre-trained checkpoint. You can directly use the models in the above table, e.g., `sjtu-lit/SynCSE-scratch-RoBERTa-base`. 147 | * `--pooler`: Pooling method. Now we support 148 | * `cls` (default): Use the representation of `[CLS]` token. 149 | * `avg`: Average embeddings of the last layer. If you use checkpoints of SBERT/SRoBERTa ([paper](https://arxiv.org/abs/1908.10084)), you should use this option. 150 | * `avg_top2`: Average embeddings of the last two layers. 151 | * `avg_first_last`: Average embeddings of the first and last layers. If you use vanilla BERT or RoBERTa, this works the best. 152 | * `--mode`: Evaluation mode 153 | * `test` (default): The default test mode. To faithfully reproduce our results, you should use this option. 154 | * `dev`: Report the development set results. Note that in STS tasks, only `STS-B` and `SICK-R` have development sets, so we only report their numbers. It also takes a fast mode for transfer tasks, so the running time is much shorter than the `test` mode (though numbers are slightly lower). 155 | * `fasttest`: It is the same as `test`, but with a fast mode so the running time is much shorter, but the reported numbers may be lower (only for transfer tasks). 156 | * `--task_set`: What set of tasks to evaluate on (if set, it will override `--tasks`) 157 | * `sts` (default): Evaluate on STS tasks, including `STS 12~16`, `STS-B` and `SICK-R`. This is the most commonly-used set of tasks to evaluate the quality of sentence embeddings. 158 | * `transfer`: Evaluate on transfer tasks. 
159 | * `full`: Evaluate on both STS and transfer tasks. 160 | * `na`: Manually set tasks by `--tasks`. 161 | * `--tasks`: Specify which dataset(s) to evaluate on. Will be overridden if `--task_set` is not `na`. See the code for a full list of tasks. 162 | 163 | ## Acknowledgement 164 | Our training code is based on the [SimCSE repo](https://github.com/princeton-nlp/SimCSE), and the evaluation code is based on the [SentEval repo](https://github.com/facebookresearch/SentEval). 165 | 166 | ## Bugs or questions? 167 | 168 | If you have any questions related to the code or the paper, feel free to email Junlei (`zhangjunlei@westlake.edu.cn`). If you encounter any problems when using the code, or want to report a bug, you can open an issue. Please try to specify the problem with details so we can help you better and quicker! 169 | 170 | ## Citation 171 | 172 | Please cite our paper if you use SynCSE: 173 | 174 | ```bibtex 175 | @article{zhang2023contrastive, 176 | title={Contrastive Learning of Sentence Embeddings from Scratch}, 177 | author={Zhang, Junlei and Lan, Zhenzhong and He, Junxian}, 178 | journal={arXiv preprint arXiv:2305.15077}, 179 | year={2023} 180 | } 181 | ``` 182 | -------------------------------------------------------------------------------- /SentEval/examples/models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | """ 9 | This file contains the definition of encoders used in https://arxiv.org/pdf/1705.02364.pdf 10 | """ 11 | 12 | import numpy as np 13 | import time 14 | 15 | import torch 16 | import torch.nn as nn 17 | 18 | 19 | class InferSent(nn.Module): 20 | 21 | def __init__(self, config): 22 | super(InferSent, self).__init__() 23 | self.bsize = config['bsize'] 24 | self.word_emb_dim = config['word_emb_dim'] 25 | self.enc_lstm_dim = config['enc_lstm_dim'] 26 | self.pool_type = config['pool_type'] 27 | self.dpout_model = config['dpout_model'] 28 | self.version = 1 if 'version' not in config else config['version'] 29 | 30 | self.enc_lstm = nn.LSTM(self.word_emb_dim, self.enc_lstm_dim, 1, 31 | bidirectional=True, dropout=self.dpout_model) 32 | 33 | assert self.version in [1, 2] 34 | if self.version == 1: 35 | self.bos = '' 36 | self.eos = '' 37 | self.max_pad = True 38 | self.moses_tok = False 39 | elif self.version == 2: 40 | self.bos = '

' 41 | self.eos = '

' 42 | self.max_pad = False 43 | self.moses_tok = True 44 | 45 | def is_cuda(self): 46 | # either all weights are on cpu or they are on gpu 47 | return self.enc_lstm.bias_hh_l0.data.is_cuda 48 | 49 | def forward(self, sent_tuple): 50 | # sent_len: [max_len, ..., min_len] (bsize) 51 | # sent: (seqlen x bsize x worddim) 52 | sent, sent_len = sent_tuple 53 | 54 | # Sort by length (keep idx) 55 | sent_len_sorted, idx_sort = np.sort(sent_len)[::-1], np.argsort(-sent_len) 56 | sent_len_sorted = sent_len_sorted.copy() 57 | idx_unsort = np.argsort(idx_sort) 58 | 59 | idx_sort = torch.from_numpy(idx_sort).cuda() if self.is_cuda() \ 60 | else torch.from_numpy(idx_sort) 61 | sent = sent.index_select(1, idx_sort) 62 | 63 | # Handling padding in Recurrent Networks 64 | sent_packed = nn.utils.rnn.pack_padded_sequence(sent, sent_len_sorted) 65 | sent_output = self.enc_lstm(sent_packed)[0] # seqlen x batch x 2*nhid 66 | sent_output = nn.utils.rnn.pad_packed_sequence(sent_output)[0] 67 | 68 | # Un-sort by length 69 | idx_unsort = torch.from_numpy(idx_unsort).cuda() if self.is_cuda() \ 70 | else torch.from_numpy(idx_unsort) 71 | sent_output = sent_output.index_select(1, idx_unsort) 72 | 73 | # Pooling 74 | if self.pool_type == "mean": 75 | sent_len = torch.FloatTensor(sent_len.copy()).unsqueeze(1).cuda() 76 | emb = torch.sum(sent_output, 0).squeeze(0) 77 | emb = emb / sent_len.expand_as(emb) 78 | elif self.pool_type == "max": 79 | if not self.max_pad: 80 | sent_output[sent_output == 0] = -1e9 81 | emb = torch.max(sent_output, 0)[0] 82 | if emb.ndimension() == 3: 83 | emb = emb.squeeze(0) 84 | assert emb.ndimension() == 2 85 | 86 | return emb 87 | 88 | def set_w2v_path(self, w2v_path): 89 | self.w2v_path = w2v_path 90 | 91 | def get_word_dict(self, sentences, tokenize=True): 92 | # create vocab of words 93 | word_dict = {} 94 | sentences = [s.split() if not tokenize else self.tokenize(s) for s in sentences] 95 | for sent in sentences: 96 | for word in sent: 97 | if word not in 
word_dict: 98 | word_dict[word] = '' 99 | word_dict[self.bos] = '' 100 | word_dict[self.eos] = '' 101 | return word_dict 102 | 103 | def get_w2v(self, word_dict): 104 | assert hasattr(self, 'w2v_path'), 'w2v path not set' 105 | # create word_vec with w2v vectors 106 | word_vec = {} 107 | with open(self.w2v_path, encoding='utf-8') as f: 108 | for line in f: 109 | word, vec = line.split(' ', 1) 110 | if word in word_dict: 111 | word_vec[word] = np.fromstring(vec, sep=' ') 112 | print('Found %s(/%s) words with w2v vectors' % (len(word_vec), len(word_dict))) 113 | return word_vec 114 | 115 | def get_w2v_k(self, K): 116 | assert hasattr(self, 'w2v_path'), 'w2v path not set' 117 | # create word_vec with k first w2v vectors 118 | k = 0 119 | word_vec = {} 120 | with open(self.w2v_path, encoding='utf-8') as f: 121 | for line in f: 122 | word, vec = line.split(' ', 1) 123 | if k <= K: 124 | word_vec[word] = np.fromstring(vec, sep=' ') 125 | k += 1 126 | if k > K: 127 | if word in [self.bos, self.eos]: 128 | word_vec[word] = np.fromstring(vec, sep=' ') 129 | 130 | if k > K and all([w in word_vec for w in [self.bos, self.eos]]): 131 | break 132 | return word_vec 133 | 134 | def build_vocab(self, sentences, tokenize=True): 135 | assert hasattr(self, 'w2v_path'), 'w2v path not set' 136 | word_dict = self.get_word_dict(sentences, tokenize) 137 | self.word_vec = self.get_w2v(word_dict) 138 | print('Vocab size : %s' % (len(self.word_vec))) 139 | 140 | # build w2v vocab with k most frequent words 141 | def build_vocab_k_words(self, K): 142 | assert hasattr(self, 'w2v_path'), 'w2v path not set' 143 | self.word_vec = self.get_w2v_k(K) 144 | print('Vocab size : %s' % (K)) 145 | 146 | def update_vocab(self, sentences, tokenize=True): 147 | assert hasattr(self, 'w2v_path'), 'warning : w2v path not set' 148 | assert hasattr(self, 'word_vec'), 'build_vocab before updating it' 149 | word_dict = self.get_word_dict(sentences, tokenize) 150 | 151 | # keep only new words 152 | for word in 
self.word_vec: 153 | if word in word_dict: 154 | del word_dict[word] 155 | 156 | # udpate vocabulary 157 | if word_dict: 158 | new_word_vec = self.get_w2v(word_dict) 159 | self.word_vec.update(new_word_vec) 160 | else: 161 | new_word_vec = [] 162 | print('New vocab size : %s (added %s words)'% (len(self.word_vec), len(new_word_vec))) 163 | 164 | def get_batch(self, batch): 165 | # sent in batch in decreasing order of lengths 166 | # batch: (bsize, max_len, word_dim) 167 | embed = np.zeros((len(batch[0]), len(batch), self.word_emb_dim)) 168 | 169 | for i in range(len(batch)): 170 | for j in range(len(batch[i])): 171 | embed[j, i, :] = self.word_vec[batch[i][j]] 172 | 173 | return torch.FloatTensor(embed) 174 | 175 | def tokenize(self, s): 176 | from nltk.tokenize import word_tokenize 177 | if self.moses_tok: 178 | s = ' '.join(word_tokenize(s)) 179 | s = s.replace(" n't ", "n 't ") # HACK to get ~MOSES tokenization 180 | return s.split() 181 | else: 182 | return word_tokenize(s) 183 | 184 | def prepare_samples(self, sentences, bsize, tokenize, verbose): 185 | sentences = [[self.bos] + s.split() + [self.eos] if not tokenize else 186 | [self.bos] + self.tokenize(s) + [self.eos] for s in sentences] 187 | n_w = np.sum([len(x) for x in sentences]) 188 | 189 | # filters words without w2v vectors 190 | for i in range(len(sentences)): 191 | s_f = [word for word in sentences[i] if word in self.word_vec] 192 | if not s_f: 193 | import warnings 194 | warnings.warn('No words in "%s" (idx=%s) have w2v vectors. \ 195 | Replacing by ""..' 
% (sentences[i], i)) 196 | s_f = [self.eos] 197 | sentences[i] = s_f 198 | 199 | lengths = np.array([len(s) for s in sentences]) 200 | n_wk = np.sum(lengths) 201 | if verbose: 202 | print('Nb words kept : %s/%s (%.1f%s)' % ( 203 | n_wk, n_w, 100.0 * n_wk / n_w, '%')) 204 | 205 | # sort by decreasing length 206 | lengths, idx_sort = np.sort(lengths)[::-1], np.argsort(-lengths) 207 | sentences = np.array(sentences)[idx_sort] 208 | 209 | return sentences, lengths, idx_sort 210 | 211 | def encode(self, sentences, bsize=64, tokenize=True, verbose=False): 212 | tic = time.time() 213 | sentences, lengths, idx_sort = self.prepare_samples( 214 | sentences, bsize, tokenize, verbose) 215 | 216 | embeddings = [] 217 | for stidx in range(0, len(sentences), bsize): 218 | batch = self.get_batch(sentences[stidx:stidx + bsize]) 219 | if self.is_cuda(): 220 | batch = batch.cuda() 221 | with torch.no_grad(): 222 | batch = self.forward((batch, lengths[stidx:stidx + bsize])).data.cpu().numpy() 223 | embeddings.append(batch) 224 | embeddings = np.vstack(embeddings) 225 | 226 | # unsort 227 | idx_unsort = np.argsort(idx_sort) 228 | embeddings = embeddings[idx_unsort] 229 | 230 | if verbose: 231 | print('Speed : %.1f sentences/s (%s mode, bsize=%s)' % ( 232 | len(embeddings)/(time.time()-tic), 233 | 'gpu' if self.is_cuda() else 'cpu', bsize)) 234 | return embeddings 235 | 236 | def visualize(self, sent, tokenize=True): 237 | 238 | sent = sent.split() if not tokenize else self.tokenize(sent) 239 | sent = [[self.bos] + [word for word in sent if word in self.word_vec] + [self.eos]] 240 | 241 | if ' '.join(sent[0]) == '%s %s' % (self.bos, self.eos): 242 | import warnings 243 | warnings.warn('No words in "%s" have w2v vectors. Replacing \ 244 | by "%s %s"..' 
% (sent, self.bos, self.eos)) 245 | batch = self.get_batch(sent) 246 | 247 | if self.is_cuda(): 248 | batch = batch.cuda() 249 | output = self.enc_lstm(batch)[0] 250 | output, idxs = torch.max(output, 0) 251 | # output, idxs = output.squeeze(), idxs.squeeze() 252 | idxs = idxs.data.cpu().numpy() 253 | argmaxs = [np.sum((idxs == k)) for k in range(len(sent[0]))] 254 | 255 | # visualize model 256 | import matplotlib.pyplot as plt 257 | x = range(len(sent[0])) 258 | y = [100.0 * n / np.sum(argmaxs) for n in argmaxs] 259 | plt.xticks(x, sent[0], rotation=45) 260 | plt.bar(x, y) 261 | plt.ylabel('%') 262 | plt.title('Visualisation of words importance') 263 | plt.show() 264 | 265 | return output, idxs 266 | -------------------------------------------------------------------------------- /SentEval/senteval/tools/validation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | """ 9 | Validation and classification 10 | (train) : inner-kfold classifier 11 | (train, test) : kfold classifier 12 | (train, dev, test) : split classifier 13 | 14 | """ 15 | from __future__ import absolute_import, division, unicode_literals 16 | 17 | import logging 18 | import numpy as np 19 | from senteval.tools.classifier import MLP 20 | 21 | import sklearn 22 | assert(sklearn.__version__ >= "0.18.0"), \ 23 | "need to update sklearn to version >= 0.18.0" 24 | from sklearn.linear_model import LogisticRegression 25 | from sklearn.model_selection import StratifiedKFold 26 | 27 | 28 | def get_classif_name(classifier_config, usepytorch): 29 | if not usepytorch: 30 | modelname = 'sklearn-LogReg' 31 | else: 32 | nhid = classifier_config['nhid'] 33 | optim = 'adam' if 'optim' not in classifier_config else classifier_config['optim'] 34 | bs = 64 if 'batch_size' not in classifier_config else classifier_config['batch_size'] 35 | modelname = 'pytorch-MLP-nhid%s-%s-bs%s' % (nhid, optim, bs) 36 | return modelname 37 | 38 | # Pytorch version 39 | class InnerKFoldClassifier(object): 40 | """ 41 | (train) split classifier : InnerKfold. 
42 | """ 43 | def __init__(self, X, y, config): 44 | self.X = X 45 | self.y = y 46 | self.featdim = X.shape[1] 47 | self.nclasses = config['nclasses'] 48 | self.seed = config['seed'] 49 | self.devresults = [] 50 | self.testresults = [] 51 | self.usepytorch = config['usepytorch'] 52 | self.classifier_config = config['classifier'] 53 | self.modelname = get_classif_name(self.classifier_config, self.usepytorch) 54 | 55 | self.k = 5 if 'kfold' not in config else config['kfold'] 56 | 57 | def run(self): 58 | logging.info('Training {0} with (inner) {1}-fold cross-validation' 59 | .format(self.modelname, self.k)) 60 | 61 | regs = [10**t for t in range(-5, -1)] if self.usepytorch else \ 62 | [2**t for t in range(-2, 4, 1)] 63 | skf = StratifiedKFold(n_splits=self.k, shuffle=True, random_state=1111) 64 | innerskf = StratifiedKFold(n_splits=self.k, shuffle=True, 65 | random_state=1111) 66 | count = 0 67 | for train_idx, test_idx in skf.split(self.X, self.y): 68 | count += 1 69 | X_train, X_test = self.X[train_idx], self.X[test_idx] 70 | y_train, y_test = self.y[train_idx], self.y[test_idx] 71 | scores = [] 72 | for reg in regs: 73 | regscores = [] 74 | for inner_train_idx, inner_test_idx in innerskf.split(X_train, y_train): 75 | X_in_train, X_in_test = X_train[inner_train_idx], X_train[inner_test_idx] 76 | y_in_train, y_in_test = y_train[inner_train_idx], y_train[inner_test_idx] 77 | if self.usepytorch: 78 | clf = MLP(self.classifier_config, inputdim=self.featdim, 79 | nclasses=self.nclasses, l2reg=reg, 80 | seed=self.seed) 81 | clf.fit(X_in_train, y_in_train, 82 | validation_data=(X_in_test, y_in_test)) 83 | else: 84 | clf = LogisticRegression(C=reg, random_state=self.seed) 85 | clf.fit(X_in_train, y_in_train) 86 | regscores.append(clf.score(X_in_test, y_in_test)) 87 | scores.append(round(100*np.mean(regscores), 2)) 88 | optreg = regs[np.argmax(scores)] 89 | logging.info('Best param found at split {0}: l2reg = {1} \ 90 | with score {2}'.format(count, optreg, np.max(scores))) 
91 | self.devresults.append(np.max(scores)) 92 | 93 | if self.usepytorch: 94 | clf = MLP(self.classifier_config, inputdim=self.featdim, 95 | nclasses=self.nclasses, l2reg=optreg, 96 | seed=self.seed) 97 | 98 | clf.fit(X_train, y_train, validation_split=0.05) 99 | else: 100 | clf = LogisticRegression(C=optreg, random_state=self.seed) 101 | clf.fit(X_train, y_train) 102 | 103 | self.testresults.append(round(100*clf.score(X_test, y_test), 2)) 104 | 105 | devaccuracy = round(np.mean(self.devresults), 2) 106 | testaccuracy = round(np.mean(self.testresults), 2) 107 | return devaccuracy, testaccuracy 108 | 109 | 110 | class KFoldClassifier(object): 111 | """ 112 | (train, test) split classifier : cross-validation on train. 113 | """ 114 | def __init__(self, train, test, config): 115 | self.train = train 116 | self.test = test 117 | self.featdim = self.train['X'].shape[1] 118 | self.nclasses = config['nclasses'] 119 | self.seed = config['seed'] 120 | self.usepytorch = config['usepytorch'] 121 | self.classifier_config = config['classifier'] 122 | self.modelname = get_classif_name(self.classifier_config, self.usepytorch) 123 | 124 | self.k = 5 if 'kfold' not in config else config['kfold'] 125 | 126 | def run(self): 127 | # cross-validation 128 | logging.info('Training {0} with {1}-fold cross-validation' 129 | .format(self.modelname, self.k)) 130 | regs = [10**t for t in range(-5, -1)] if self.usepytorch else \ 131 | [2**t for t in range(-1, 6, 1)] 132 | skf = StratifiedKFold(n_splits=self.k, shuffle=True, 133 | random_state=self.seed) 134 | scores = [] 135 | 136 | for reg in regs: 137 | scanscores = [] 138 | for train_idx, test_idx in skf.split(self.train['X'], 139 | self.train['y']): 140 | # Split data 141 | X_train, y_train = self.train['X'][train_idx], self.train['y'][train_idx] 142 | 143 | X_test, y_test = self.train['X'][test_idx], self.train['y'][test_idx] 144 | 145 | # Train classifier 146 | if self.usepytorch: 147 | clf = MLP(self.classifier_config, 
inputdim=self.featdim, 148 | nclasses=self.nclasses, l2reg=reg, 149 | seed=self.seed) 150 | clf.fit(X_train, y_train, validation_data=(X_test, y_test)) 151 | else: 152 | clf = LogisticRegression(C=reg, random_state=self.seed) 153 | clf.fit(X_train, y_train) 154 | score = clf.score(X_test, y_test) 155 | scanscores.append(score) 156 | # Append mean score 157 | scores.append(round(100*np.mean(scanscores), 2)) 158 | 159 | # evaluation 160 | logging.info([('reg:' + str(regs[idx]), scores[idx]) 161 | for idx in range(len(scores))]) 162 | optreg = regs[np.argmax(scores)] 163 | devaccuracy = np.max(scores) 164 | logging.info('Cross-validation : best param found is reg = {0} \ 165 | with score {1}'.format(optreg, devaccuracy)) 166 | 167 | logging.info('Evaluating...') 168 | if self.usepytorch: 169 | clf = MLP(self.classifier_config, inputdim=self.featdim, 170 | nclasses=self.nclasses, l2reg=optreg, 171 | seed=self.seed) 172 | clf.fit(self.train['X'], self.train['y'], validation_split=0.05) 173 | else: 174 | clf = LogisticRegression(C=optreg, random_state=self.seed) 175 | clf.fit(self.train['X'], self.train['y']) 176 | yhat = clf.predict(self.test['X']) 177 | 178 | testaccuracy = clf.score(self.test['X'], self.test['y']) 179 | testaccuracy = round(100*testaccuracy, 2) 180 | 181 | return devaccuracy, testaccuracy, yhat 182 | 183 | 184 | class SplitClassifier(object): 185 | """ 186 | (train, valid, test) split classifier. 
187 | """ 188 | def __init__(self, X, y, config): 189 | self.X = X 190 | self.y = y 191 | self.nclasses = config['nclasses'] 192 | self.featdim = self.X['train'].shape[1] 193 | self.seed = config['seed'] 194 | self.usepytorch = config['usepytorch'] 195 | self.classifier_config = config['classifier'] 196 | self.cudaEfficient = False if 'cudaEfficient' not in config else \ 197 | config['cudaEfficient'] 198 | self.modelname = get_classif_name(self.classifier_config, self.usepytorch) 199 | self.noreg = False if 'noreg' not in config else config['noreg'] 200 | self.config = config 201 | 202 | def run(self): 203 | logging.info('Training {0} with standard validation..' 204 | .format(self.modelname)) 205 | regs = [10**t for t in range(-5, -1)] if self.usepytorch else \ 206 | [2**t for t in range(-2, 4, 1)] 207 | if self.noreg: 208 | regs = [1e-9 if self.usepytorch else 1e9] 209 | scores = [] 210 | for reg in regs: 211 | if self.usepytorch: 212 | clf = MLP(self.classifier_config, inputdim=self.featdim, 213 | nclasses=self.nclasses, l2reg=reg, 214 | seed=self.seed, cudaEfficient=self.cudaEfficient) 215 | 216 | # TODO: Find a hack for reducing nb epoches in SNLI 217 | clf.fit(self.X['train'], self.y['train'], 218 | validation_data=(self.X['valid'], self.y['valid'])) 219 | else: 220 | clf = LogisticRegression(C=reg, random_state=self.seed) 221 | clf.fit(self.X['train'], self.y['train']) 222 | scores.append(round(100*clf.score(self.X['valid'], 223 | self.y['valid']), 2)) 224 | logging.info([('reg:'+str(regs[idx]), scores[idx]) 225 | for idx in range(len(scores))]) 226 | optreg = regs[np.argmax(scores)] 227 | devaccuracy = np.max(scores) 228 | logging.info('Validation : best param found is reg = {0} with score \ 229 | {1}'.format(optreg, devaccuracy)) 230 | clf = LogisticRegression(C=optreg, random_state=self.seed) 231 | logging.info('Evaluating...') 232 | if self.usepytorch: 233 | clf = MLP(self.classifier_config, inputdim=self.featdim, 234 | nclasses=self.nclasses, 
class SimCSE(object):
    """
    A class for embedding sentences, calculating similarities, and retrieving sentences by SimCSE.

    Typical usage: call `encode` / `similarity` directly, or call
    `build_index` (optionally followed by `add_to_index`) and then `search`.
    """
    def __init__(self, model_name_or_path: str,
                 device: str = None,
                 num_cells: int = 100,
                 num_cells_in_search: int = 10,
                 pooler = None):
        """
        Load the tokenizer and model and pick the default device and pooler.

        If `device` is None, "cuda" is used when available, otherwise "cpu".
        If `pooler` is None, "cls_before_pooler" is used when the model name
        contains "unsup", otherwise "cls".
        `num_cells` / `num_cells_in_search` configure the faiss IVF index
        built by `build_index(..., faiss_fast=True)`.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModel.from_pretrained(model_name_or_path)
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        self.index = None
        self.is_faiss_index = False
        self.num_cells = num_cells
        self.num_cells_in_search = num_cells_in_search

        if pooler is not None:
            self.pooler = pooler
        elif "unsup" in model_name_or_path:
            logger.info("Use `cls_before_pooler` for unsupervised models. If you want to use other pooling policy, specify `pooler` argument.")
            self.pooler = "cls_before_pooler"
        else:
            self.pooler = "cls"

    @staticmethod
    def _load_sentences(file_path: str) -> List[str]:
        """Read one sentence per line from `file_path` (trailing whitespace stripped)."""
        logger.info("Loading sentences from %s ...", file_path)
        with open(file_path, "r", encoding="utf-8") as f:
            return [line.rstrip() for line in tqdm(f)]

    def encode(self, sentence: Union[str, List[str]],
               device: str = None,
               return_numpy: bool = False,
               normalize_to_unit: bool = True,
               keepdim: bool = False,
               batch_size: int = 64,
               max_length: int = 128) -> Union[ndarray, Tensor]:
        """
        Embed one sentence or a list of sentences.

        Returns a CPU tensor (or a numpy array when `return_numpy` is True).
        A list input yields one row per sentence; a single string yields a
        1-D vector unless `keepdim` is True. Rows are L2-normalized when
        `normalize_to_unit` is True. The model is moved to `device`
        (defaults to `self.device`) as a side effect.

        Raises NotImplementedError for a pooler other than "cls" or
        "cls_before_pooler".
        """
        target_device = self.device if device is None else device
        self.model = self.model.to(target_device)

        single_sentence = False
        if isinstance(sentence, str):
            sentence = [sentence]
            single_sentence = True

        embedding_list = []
        with torch.no_grad():
            for start in tqdm(range(0, len(sentence), batch_size)):
                inputs = self.tokenizer(
                    sentence[start:start + batch_size],
                    padding=True,
                    truncation=True,
                    max_length=max_length,
                    return_tensors="pt"
                )
                inputs = {k: v.to(target_device) for k, v in inputs.items()}
                outputs = self.model(**inputs, return_dict=True)
                if self.pooler == "cls":
                    embeddings = outputs.pooler_output
                elif self.pooler == "cls_before_pooler":
                    embeddings = outputs.last_hidden_state[:, 0]
                else:
                    raise NotImplementedError
                if normalize_to_unit:
                    embeddings = embeddings / embeddings.norm(dim=1, keepdim=True)
                embedding_list.append(embeddings.cpu())
        embeddings = torch.cat(embedding_list, 0)

        if single_sentence and not keepdim:
            embeddings = embeddings[0]

        if return_numpy and not isinstance(embeddings, ndarray):
            return embeddings.numpy()
        return embeddings

    def similarity(self, queries: Union[str, List[str]],
                   keys: Union[str, List[str], ndarray],
                   device: str = None) -> Union[float, ndarray]:
        """
        Cosine similarity between `queries` (N) and `keys` (M).

        `keys` may be raw text or pre-computed embeddings (ndarray).
        Returns an (N, M) array; the axis of a single (non-list) query or a
        single (1-D / non-list) key is dropped, and a result holding exactly
        one value is returned as a Python float.
        """
        query_vecs = self.encode(queries, device=device, return_numpy=True)  # suppose N queries

        if not isinstance(keys, ndarray):
            key_vecs = self.encode(keys, device=device, return_numpy=True)  # suppose M keys
        else:
            key_vecs = keys

        # check whether N == 1 or M == 1
        single_query, single_key = len(query_vecs.shape) == 1, len(key_vecs.shape) == 1
        if single_query:
            query_vecs = query_vecs.reshape(1, -1)
        if single_key:
            key_vecs = key_vecs.reshape(1, -1)

        # returns an N*M similarity array
        similarities = cosine_similarity(query_vecs, key_vecs)

        if single_query:
            similarities = similarities[0]
        if single_key:
            # Drop the key axis. Taking only element 0 here (as a float) would
            # silently discard every query but the first when N > 1.
            similarities = similarities[..., 0]
            if similarities.size == 1:
                similarities = similarities.item()

        return similarities

    def build_index(self, sentences_or_file_path: Union[str, List[str]],
                    use_faiss: bool = None,
                    faiss_fast: bool = False,
                    device: str = None,
                    batch_size: int = 64):
        """
        Encode the sentences and build a search index over them.

        `sentences_or_file_path` is either a list of sentences or the path of
        a text file with one sentence per line. With `use_faiss` None or True
        faiss is used when it can be imported; otherwise (or when False) the
        raw embedding matrix is kept for brute-force search. `faiss_fast`
        selects an IVF index instead of a flat inner-product index.
        """
        if use_faiss is None or use_faiss:
            try:
                import faiss
                if not hasattr(faiss, "IndexFlatIP"):
                    raise ImportError("the imported faiss module has no IndexFlatIP")
                use_faiss = True
            except Exception:
                # Deliberate best-effort: any import problem falls back to
                # brute force, without swallowing KeyboardInterrupt/SystemExit.
                logger.warning("Fail to import faiss. If you want to use faiss, install faiss through PyPI. Now the program continues with brute force search.")
                use_faiss = False

        # if the input sentence is a string, we assume it's the path of file that stores various sentences
        if isinstance(sentences_or_file_path, str):
            sentences = self._load_sentences(sentences_or_file_path)
        else:
            # Copy so that add_to_index() never mutates the caller's list.
            sentences = list(sentences_or_file_path)

        logger.info("Encoding embeddings for sentences...")
        embeddings = self.encode(sentences, device=device, batch_size=batch_size, normalize_to_unit=True, return_numpy=True)

        logger.info("Building index...")
        self.index = {"sentences": sentences}

        if use_faiss:
            quantizer = faiss.IndexFlatIP(embeddings.shape[1])
            if faiss_fast:
                index = faiss.IndexIVFFlat(quantizer, embeddings.shape[1], min(self.num_cells, len(sentences)), faiss.METRIC_INNER_PRODUCT)
            else:
                index = quantizer

            if (self.device == "cuda" and device != "cpu") or device == "cuda":
                if hasattr(faiss, "StandardGpuResources"):
                    logger.info("Use GPU-version faiss")
                    res = faiss.StandardGpuResources()
                    res.setTempMemory(20 * 1024 * 1024 * 1024)
                    index = faiss.index_cpu_to_gpu(res, 0, index)
                else:
                    logger.info("Use CPU-version faiss")
            else:
                logger.info("Use CPU-version faiss")

            if faiss_fast:
                index.train(embeddings.astype(np.float32))
            index.add(embeddings.astype(np.float32))
            index.nprobe = min(self.num_cells_in_search, len(sentences))
            self.is_faiss_index = True
        else:
            index = embeddings
            self.is_faiss_index = False
        self.index["index"] = index
        logger.info("Finished")

    def add_to_index(self, sentences_or_file_path: Union[str, List[str]],
                     device: str = None,
                     batch_size: int = 64):
        """
        Encode more sentences and append them to the index built by `build_index`.

        Accepts the same list-or-file-path input as `build_index`.
        """
        # if the input sentence is a string, we assume it's the path of file that stores various sentences
        if isinstance(sentences_or_file_path, str):
            sentences = self._load_sentences(sentences_or_file_path)
        else:
            sentences = sentences_or_file_path

        logger.info("Encoding embeddings for sentences...")
        embeddings = self.encode(sentences, device=device, batch_size=batch_size, normalize_to_unit=True, return_numpy=True)

        if self.is_faiss_index:
            self.index["index"].add(embeddings.astype(np.float32))
        else:
            self.index["index"] = np.concatenate((self.index["index"], embeddings))
        self.index["sentences"].extend(sentences)
        logger.info("Finished")

    def search(self, queries: Union[str, List[str]],
               device: str = None,
               threshold: float = 0.6,
               top_k: int = 5) -> Union[List[Tuple[str, float]], List[List[Tuple[str, float]]]]:
        """
        Retrieve up to `top_k` indexed sentences scoring at least `threshold`.

        Returns a list of (sentence, score) pairs sorted by decreasing score
        for a single query, or one such list per query for a list of queries.
        """
        if not self.is_faiss_index:
            if isinstance(queries, list):
                return [self.search(query, device, threshold, top_k) for query in queries]

            similarities = self.similarity(queries, self.index["index"], device=device).tolist()
            id_and_score = [(i, s) for i, s in enumerate(similarities) if s >= threshold]
            id_and_score = sorted(id_and_score, key=lambda x: x[1], reverse=True)[:top_k]
            return [(self.index["sentences"][idx], score) for idx, score in id_and_score]

        query_vecs = self.encode(queries, device=device, normalize_to_unit=True, keepdim=True, return_numpy=True)

        distance, idx = self.index["index"].search(query_vecs.astype(np.float32), top_k)

        def pack_single_result(dist, idx):
            # faiss pads missing neighbours with id -1; skip them so they are
            # not mistaken for the last sentence through negative indexing.
            return [(self.index["sentences"][i], s) for i, s in zip(idx, dist) if i >= 0 and s >= threshold]

        if isinstance(queries, list):
            return [pack_single_result(distance[i], idx[i]) for i in range(len(queries))]
        return pack_single_result(distance[0], idx[0])
260 | ] 261 | 262 | model_name = "princeton-nlp/sup-simcse-bert-base-uncased" 263 | simcse = SimCSE(model_name) 264 | 265 | print("\n=========Calculate cosine similarities between queries and sentences============\n") 266 | similarities = simcse.similarity(example_queries, example_sentences) 267 | print(similarities) 268 | 269 | print("\n=========Naive brute force search============\n") 270 | simcse.build_index(example_sentences, use_faiss=False) 271 | results = simcse.search(example_queries) 272 | for i, result in enumerate(results): 273 | print("Retrieval results for query: {}".format(example_queries[i])) 274 | for sentence, score in result: 275 | print(" {} (cosine similarity: {:.4f})".format(sentence, score)) 276 | print("") 277 | 278 | print("\n=========Search with Faiss backend============\n") 279 | simcse.build_index(example_sentences, use_faiss=True) 280 | results = simcse.search(example_queries) 281 | for i, result in enumerate(results): 282 | print("Retrieval results for query: {}".format(example_queries[i])) 283 | for sentence, score in result: 284 | print(" {} (cosine similarity: {:.4f})".format(sentence, score)) 285 | print("") 286 | 287 | -------------------------------------------------------------------------------- /SentEval/README.md: -------------------------------------------------------------------------------- 1 | Our modification to SentEval: 2 | 3 | 1. Add the `all` setting to all STS tasks. 4 | 2. Change STS-B and SICK-R to not use an additional regressor. 5 | 6 | # SentEval: evaluation toolkit for sentence embeddings 7 | 8 | SentEval is a library for evaluating the quality of sentence embeddings. We assess their generalization power by using them as features on a broad and diverse set of "transfer" tasks. **SentEval currently includes 17 downstream tasks**. We also include a suite of **10 probing tasks** which evaluate what linguistic properties are encoded in sentence embeddings. 
Our goal is to ease the study and the development of general-purpose fixed-size sentence representations. 9 | 10 | 11 | **(04/22) SentEval new tasks: Added probing tasks for evaluating what linguistic properties are encoded in sentence embeddings** 12 | 13 | **(10/04) SentEval example scripts for three sentence encoders: [SkipThought-LN](https://github.com/ryankiros/layer-norm#skip-thoughts)/[GenSen](https://github.com/Maluuba/gensen)/[Google-USE](https://tfhub.dev/google/universal-sentence-encoder/1)** 14 | 15 | ## Dependencies 16 | 17 | This code is written in python. The dependencies are: 18 | 19 | * Python 2/3 with [NumPy](http://www.numpy.org/)/[SciPy](http://www.scipy.org/) 20 | * [Pytorch](http://pytorch.org/)>=0.4 21 | * [scikit-learn](http://scikit-learn.org/stable/index.html)>=0.18.0 22 | 23 | ## Transfer tasks 24 | 25 | ### Downstream tasks 26 | SentEval allows you to evaluate your sentence embeddings as features for the following *downstream* tasks: 27 | 28 | | Task | Type | #train | #test | needs_train | set_classifier | 29 | |---------- |------------------------------ |-----------:|----------:|:-----------:|:----------:| 30 | | [MR](https://nlp.stanford.edu/~sidaw/home/projects:nbsvm) | movie review | 11k | 11k | 1 | 1 | 31 | | [CR](https://nlp.stanford.edu/~sidaw/home/projects:nbsvm) | product review | 4k | 4k | 1 | 1 | 32 | | [SUBJ](https://nlp.stanford.edu/~sidaw/home/projects:nbsvm) | subjectivity status | 10k | 10k | 1 | 1 | 33 | | [MPQA](https://nlp.stanford.edu/~sidaw/home/projects:nbsvm) | opinion-polarity | 11k | 11k | 1 | 1 | 34 | | [SST](https://nlp.stanford.edu/sentiment/index.html) | binary sentiment analysis | 67k | 1.8k | 1 | 1 | 35 | | **[SST](https://nlp.stanford.edu/sentiment/index.html)** | **fine-grained sentiment analysis** | 8.5k | 2.2k | 1 | 1 | 36 | | [TREC](http://cogcomp.cs.illinois.edu/Data/QA/QC/) | question-type classification | 6k | 0.5k | 1 | 1 | 37 | | [SICK-E](http://clic.cimec.unitn.it/composes/sick.html) | natural 
language inference | 4.5k | 4.9k | 1 | 1 | 38 | | [SNLI](https://nlp.stanford.edu/projects/snli/) | natural language inference | 550k | 9.8k | 1 | 1 | 39 | | [MRPC](https://aclweb.org/aclwiki/Paraphrase_Identification_(State_of_the_art)) | paraphrase detection | 4.1k | 1.7k | 1 | 1 | 40 | | [STS 2012](https://www.cs.york.ac.uk/semeval-2012/task6/) | semantic textual similarity | N/A | 3.1k | 0 | 0 | 41 | | [STS 2013](http://ixa2.si.ehu.es/sts/) | semantic textual similarity | N/A | 1.5k | 0 | 0 | 42 | | [STS 2014](http://alt.qcri.org/semeval2014/task10/) | semantic textual similarity | N/A | 3.7k | 0 | 0 | 43 | | [STS 2015](http://alt.qcri.org/semeval2015/task2/) | semantic textual similarity | N/A | 8.5k | 0 | 0 | 44 | | [STS 2016](http://alt.qcri.org/semeval2016/task1/) | semantic textual similarity | N/A | 9.2k | 0 | 0 | 45 | | [STS B](http://ixa2.si.ehu.es/stswiki/index.php/STSbenchmark#Results) | semantic textual similarity | 5.7k | 1.4k | 1 | 0 | 46 | | [SICK-R](http://clic.cimec.unitn.it/composes/sick.html) | semantic textual similarity | 4.5k | 4.9k | 1 | 0 | 47 | | [COCO](http://mscoco.org/) | image-caption retrieval | 567k | 5*1k | 1 | 0 | 48 | 49 | where **needs_train** means a model with parameters is learned on top of the sentence embeddings, and **set_classifier** means you can define the parameters of the classifier in the case of a classification task (see below). 50 | 51 | Note: COCO comes with ResNet-101 2048d image embeddings. 
[More details on the tasks.](https://arxiv.org/pdf/1705.02364.pdf) 52 | 53 | ### Probing tasks 54 | SentEval also includes a series of [*probing* tasks](https://github.com/facebookresearch/SentEval/tree/master/data/probing) to evaluate what linguistic properties are encoded in your sentence embeddings: 55 | 56 | | Task | Type | #train | #test | needs_train | set_classifier | 57 | |---------- |------------------------------ |-----------:|----------:|:-----------:|:----------:| 58 | | [SentLen](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Length prediction | 100k | 10k | 1 | 1 | 59 | | [WC](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Word Content analysis | 100k | 10k | 1 | 1 | 60 | | [TreeDepth](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Tree depth prediction | 100k | 10k | 1 | 1 | 61 | | [TopConst](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Top Constituents prediction | 100k | 10k | 1 | 1 | 62 | | [BShift](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Word order analysis | 100k | 10k | 1 | 1 | 63 | | [Tense](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Verb tense prediction | 100k | 10k | 1 | 1 | 64 | | [SubjNum](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Subject number prediction | 100k | 10k | 1 | 1 | 65 | | [ObjNum](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Object number prediction | 100k | 10k | 1 | 1 | 66 | | [SOMO](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Semantic odd man out | 100k | 10k | 1 | 1 | 67 | | [CoordInv](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Coordination Inversion | 100k | 10k | 1 | 1 | 68 | 69 | ## Download datasets 70 | To get all the transfer tasks datasets, run (in data/downstream/): 71 | ```bash 72 | ./get_transfer_data.bash 73 | ``` 74 | This will 
automatically download and preprocess the downstream datasets, and store them in data/downstream (warning: for MacOS users, you may have to use p7zip instead of unzip). The probing tasks are already in data/probing by default. 75 | 76 | ## How to use SentEval: examples 77 | 78 | ### examples/bow.py 79 | 80 | In examples/bow.py, we evaluate the quality of the average of word embeddings. 81 | 82 | To download state-of-the-art fastText embeddings: 83 | 84 | ```bash 85 | curl -Lo glove.840B.300d.zip http://nlp.stanford.edu/data/glove.840B.300d.zip 86 | curl -Lo crawl-300d-2M.vec.zip https://dl.fbaipublicfiles.com/fasttext/vectors-english/crawl-300d-2M.vec.zip 87 | ``` 88 | 89 | To reproduce the results for bag-of-vectors, run (in examples/): 90 | ```bash 91 | python bow.py 92 | ``` 93 | 94 | As required by SentEval, this script implements two functions: **prepare** (optional) and **batcher** (required) that turn text sentences into sentence embeddings. Then SentEval takes care of the evaluation on the transfer tasks using the embeddings as features. 
95 | 96 | ### examples/infersent.py 97 | 98 | To get the **[InferSent](https://www.github.com/facebookresearch/InferSent)** model and reproduce our results, download our best models and run infersent.py (in examples/): 99 | ```bash 100 | curl -Lo examples/infersent1.pkl https://dl.fbaipublicfiles.com/senteval/infersent/infersent1.pkl 101 | curl -Lo examples/infersent2.pkl https://dl.fbaipublicfiles.com/senteval/infersent/infersent2.pkl 102 | ``` 103 | 104 | ### examples/skipthought.py - examples/gensen.py - examples/googleuse.py 105 | 106 | We also provide example scripts for three other encoders: 107 | 108 | * [SkipThought with Layer-Normalization](https://github.com/ryankiros/layer-norm#skip-thoughts) in Theano 109 | * [GenSen encoder](https://github.com/Maluuba/gensen) in Pytorch 110 | * [Google encoder](https://tfhub.dev/google/universal-sentence-encoder/1) in TensorFlow 111 | 112 | Note that for SkipThought and GenSen, following the steps of the associated githubs is necessary. 113 | The Google encoder script should work as-is. 114 | 115 | ## How to use SentEval 116 | 117 | To evaluate your sentence embeddings, SentEval requires that you implement two functions: 118 | 119 | 1. **prepare** (sees the whole dataset of each task and can thus construct the word vocabulary, the dictionary of word vectors etc) 120 | 2. **batcher** (transforms a batch of text sentences into sentence embeddings) 121 | 122 | 123 | ### 1.) prepare(params, samples) (optional) 124 | 125 | *batcher* only sees one batch at a time while the *samples* argument of *prepare* contains all the sentences of a task. 126 | 127 | ``` 128 | prepare(params, samples) 129 | ``` 130 | * *params*: senteval parameters. 131 | * *samples*: list of all sentences from the transfer task. 132 | * *output*: No output. Arguments stored in "params" can further be used by *batcher*. 
133 | 134 | *Example*: in bow.py, prepare is used to build the vocabulary of words and construct the *params.word_vec* dictionary of word vectors. 135 | 136 | 137 | ### 2.) batcher(params, batch) 138 | ``` 139 | batcher(params, batch) 140 | ``` 141 | * *params*: senteval parameters. 142 | * *batch*: numpy array of text sentences (of size params.batch_size) 143 | * *output*: numpy array of sentence embeddings (of size params.batch_size) 144 | 145 | *Example*: in bow.py, batcher is used to compute the mean of the word vectors for each sentence in the batch using params.word_vec. Use your own encoder in that function to encode sentences. 146 | 147 | ### 3.) evaluation on transfer tasks 148 | 149 | After having implemented the batcher and prepare functions for your own sentence encoder, 150 | 151 | 1) to perform the actual evaluation, first import senteval and set its parameters: 152 | ```python 153 | import senteval 154 | params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 10} 155 | ``` 156 | 157 | 2) (optional) set the parameters of the classifier (when applicable): 158 | ```python 159 | params['classifier'] = {'nhid': 0, 'optim': 'adam', 'batch_size': 64, 160 | 'tenacity': 5, 'epoch_size': 4} 161 | ``` 162 | You can choose **nhid=0** (Logistic Regression) or **nhid>0** (MLP) and define the parameters for training. 
163 | 164 | 3) Create an instance of the class SE: 165 | ```python 166 | se = senteval.engine.SE(params, batcher, prepare) 167 | ``` 168 | 169 | 4) define the set of transfer tasks and run the evaluation: 170 | ```python 171 | transfer_tasks = ['MR', 'SICKEntailment', 'STS14', 'STSBenchmark'] 172 | results = se.eval(transfer_tasks) 173 | ``` 174 | The current list of available tasks is: 175 | ```python 176 | ['CR', 'MR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC', 'SNLI', 177 | 'SICKEntailment', 'SICKRelatedness', 'STSBenchmark', 'ImageCaptionRetrieval', 178 | 'STS12', 'STS13', 'STS14', 'STS15', 'STS16', 179 | 'Length', 'WordContent', 'Depth', 'TopConstituents','BigramShift', 'Tense', 180 | 'SubjNumber', 'ObjNumber', 'OddManOut', 'CoordinationInversion'] 181 | ``` 182 | 183 | ## SentEval parameters 184 | Global parameters of SentEval: 185 | ```bash 186 | # senteval parameters 187 | task_path # path to SentEval datasets (required) 188 | seed # seed 189 | usepytorch # use cuda-pytorch (else scikit-learn) where possible 190 | kfold # k-fold validation for MR/CR/SUBJ/MPQA. 191 | ``` 192 | 193 | Parameters of the classifier: 194 | ```bash 195 | nhid: # number of hidden units (0: Logistic Regression, >0: MLP); Default nonlinearity: Tanh 196 | optim: # optimizer ("sgd,lr=0.1", "adam", "rmsprop" ..) 197 | tenacity: # how many times dev acc does not increase before training stops 198 | epoch_size: # each epoch corresponds to epoch_size passes on the train set 199 | max_epoch: # max number of epochs 200 | dropout: # dropout for MLP 201 | ``` 202 | 203 | Note that to get a proxy of the results while **dramatically reducing computation time**, 204 | we suggest the **prototyping config**: 205 | ```python 206 | params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5} 207 | params['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128, 208 | 'tenacity': 3, 'epoch_size': 2} 209 | ``` 210 | which will result in a 5 times speedup for classification tasks. 
211 | 212 | To produce results that are **comparable to the literature**, use the **default config**: 213 | ```python 214 | params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 10} 215 | params['classifier'] = {'nhid': 0, 'optim': 'adam', 'batch_size': 64, 216 | 'tenacity': 5, 'epoch_size': 4} 217 | ``` 218 | which takes longer but will produce better and comparable results. 219 | 220 | For probing tasks, we used an MLP with a Sigmoid nonlinearity and tuned the nhid (in [50, 100, 200]) and dropout (in [0.0, 0.1, 0.2]) on the dev set. 221 | 222 | ## References 223 | 224 | Please consider citing [[1]](https://arxiv.org/abs/1803.05449) if using this code for evaluating sentence embedding methods. 225 | 226 | ### SentEval: An Evaluation Toolkit for Universal Sentence Representations 227 | 228 | [1] A. Conneau, D. Kiela, [*SentEval: An Evaluation Toolkit for Universal Sentence Representations*](https://arxiv.org/abs/1803.05449) 229 | 230 | ``` 231 | @article{conneau2018senteval, 232 | title={SentEval: An Evaluation Toolkit for Universal Sentence Representations}, 233 | author={Conneau, Alexis and Kiela, Douwe}, 234 | journal={arXiv preprint arXiv:1803.05449}, 235 | year={2018} 236 | } 237 | ``` 238 | 239 | Contact: [aconneau@fb.com](mailto:aconneau@fb.com), [dkiela@fb.com](mailto:dkiela@fb.com) 240 | 241 | ### Related work 242 | * [J. R Kiros, Y. Zhu, R. Salakhutdinov, R. S. Zemel, A. Torralba, R. Urtasun, S. Fidler - SkipThought Vectors, NIPS 2015](https://arxiv.org/abs/1506.06726) 243 | * [S. Arora, Y. Liang, T. Ma - A Simple but Tough-to-Beat Baseline for Sentence Embeddings, ICLR 2017](https://openreview.net/pdf?id=SyK00v5xx) 244 | * [Y. Adi, E. Kermany, Y. Belinkov, O. Lavi, Y. Goldberg - Fine-grained analysis of sentence embeddings using auxiliary prediction tasks, ICLR 2017](https://arxiv.org/abs/1608.04207) 245 | * [A. Conneau, D. Kiela, L. Barrault, H. Schwenk, A. 
Bordes - Supervised Learning of Universal Sentence Representations from Natural Language Inference Data, EMNLP 2017](https://arxiv.org/abs/1705.02364) 246 | * [S. Subramanian, A. Trischler, Y. Bengio, C. J Pal - Learning General Purpose Distributed Sentence Representations via Large Scale Multi-task Learning, ICLR 2018](https://arxiv.org/abs/1804.00079) 247 | * [A. Nie, E. D. Bennett, N. D. Goodman - DisSent: Sentence Representation Learning from Explicit Discourse Relations, 2018](https://arxiv.org/abs/1710.04334) 248 | * [D. Cer, Y. Yang, S. Kong, N. Hua, N. Limtiaco, R. St. John, N. Constant, M. Guajardo-Cespedes, S. Yuan, C. Tar, Y. Sung, B. Strope, R. Kurzweil - Universal Sentence Encoder, 2018](https://arxiv.org/abs/1803.11175) 249 | * [A. Conneau, G. Kruszewski, G. Lample, L. Barrault, M. Baroni - What you can cram into a single vector: Probing sentence embeddings for linguistic properties, ACL 2018](https://arxiv.org/abs/1805.01070) 250 | -------------------------------------------------------------------------------- /simcse/models.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.distributed as dist 7 | 8 | import transformers 9 | from transformers import RobertaTokenizer 10 | from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel, RobertaModel, RobertaLMHead 11 | from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertModel, BertLMPredictionHead 12 | from transformers.activations import gelu 13 | from transformers.file_utils import ( 14 | add_code_sample_docstrings, 15 | add_start_docstrings, 16 | add_start_docstrings_to_model_forward, 17 | replace_return_docstrings, 18 | ) 19 | from transformers.modeling_outputs import SequenceClassifierOutput, BaseModelOutputWithPoolingAndCrossAttentions 20 | 21 | class MLPLayer(nn.Module): 22 | """ 23 | Head for getting 
class Pooler(nn.Module):
    """
    Parameter-free poolers to get the sentence embedding
    'cls': [CLS] representation; this module returns the raw [CLS] hidden
        state, any MLP head is applied by the caller.
    'cls_before_pooler': [CLS] representation without the original MLP pooler.
    'avg': average of the last layers' hidden states at each token.
    'avg_top2': average of the last two layers.
    'avg_first_last': average of the first and the last layers.
    """
    def __init__(self, pooler_type):
        super().__init__()
        self.pooler_type = pooler_type
        assert self.pooler_type in ["cls", "cls_before_pooler", "avg", "avg_top2", "avg_first_last"], "unrecognized pooling type %s" % self.pooler_type

    @staticmethod
    def _masked_mean(hidden, attention_mask):
        """Average `hidden` over the token axis, counting only positions where the mask is set."""
        masked_sum = (hidden * attention_mask.unsqueeze(-1)).sum(1)
        return masked_sum / attention_mask.sum(-1).unsqueeze(-1)

    def forward(self, attention_mask, outputs):
        """
        Pool encoder `outputs` into one vector per sequence.

        Only the fields required by the configured pooling type are read:
        `last_hidden_state` for 'cls', 'cls_before_pooler' and 'avg';
        `hidden_states` for 'avg_first_last' and 'avg_top2'. Encoder outputs
        that lack the other fields are therefore accepted.
        """
        if self.pooler_type in ('cls_before_pooler', 'cls'):
            return outputs.last_hidden_state[:, 0]
        if self.pooler_type == "avg":
            return self._masked_mean(outputs.last_hidden_state, attention_mask)
        if self.pooler_type == "avg_first_last":
            hidden_states = outputs.hidden_states
            # Index 1 rather than 0: entry 0 is skipped
            # (presumably the embedding output -- confirm against the encoder).
            return self._masked_mean((hidden_states[1] + hidden_states[-1]) / 2.0, attention_mask)
        if self.pooler_type == "avg_top2":
            hidden_states = outputs.hidden_states
            return self._masked_mean((hidden_states[-1] + hidden_states[-2]) / 2.0, attention_mask)
        raise NotImplementedError
def cl_forward(cls,
    encoder,
    input_ids=None,
    attention_mask=None,
    token_type_ids=None,
    position_ids=None,
    head_mask=None,
    inputs_embeds=None,
    labels=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    mlm_input_ids=None,
    mlm_labels=None,
):
    """
    Contrastive-learning forward pass.

    `input_ids` and `attention_mask` (and `token_type_ids` when given) are
    indexed as (batch, num_sent, len). Along the second axis, index 0 and 1
    form the positive pair and index 2, when present, is the hard negative.
    Any further sentences are encoded but do not enter the loss.

    `labels` and `output_hidden_states` are accepted for signature
    compatibility only: the labels are rebuilt from the logits' first
    dimension and hidden-state output is decided by the pooler type.

    When `mlm_input_ids` and `mlm_labels` are both given, a masked-LM loss
    scaled by `model_args.mlm_weight` is added to the contrastive loss.

    Returns a `SequenceClassifierOutput` whose logits are the similarity
    matrix produced by `cls.sim` (plus the hard-negative bias), or the
    equivalent tuple starting with the loss when `return_dict` is false.
    """

    def _gather_across_ranks(local):
        # all_gather hands back copies without gradients, so the slot that
        # belongs to this rank is swapped for the original tensor.
        gathered = [torch.zeros_like(local) for _ in range(dist.get_world_size())]
        dist.all_gather(tensor_list=gathered, tensor=local.contiguous())
        gathered[dist.get_rank()] = local
        return torch.cat(gathered, 0)

    return_dict = return_dict if return_dict is not None else cls.config.use_return_dict
    batch_size = input_ids.size(0)
    # Number of sentences in one instance
    # 2: pair instance; 3: pair instance with a hard negative
    num_sent = input_ids.size(1)
    # One flag drives every hard-negative step, so `z3` can never be read
    # without having been bound (previously `== 3` vs `>= 3` disagreed).
    has_hard_negative = num_sent >= 3
    need_all_layers = cls.model_args.pooler_type in ['avg_top2', 'avg_first_last']

    mlm_outputs = None
    # Flatten input for encoding
    input_ids = input_ids.view((-1, input_ids.size(-1)))  # (bs * num_sent, len)
    attention_mask = attention_mask.view((-1, attention_mask.size(-1)))  # (bs * num_sent, len)
    if token_type_ids is not None:
        token_type_ids = token_type_ids.view((-1, token_type_ids.size(-1)))  # (bs * num_sent, len)

    encoder_kwargs = dict(
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        output_attentions=output_attentions,
        output_hidden_states=need_all_layers,
        return_dict=True,
    )

    # Get raw embeddings
    outputs = encoder(input_ids, **encoder_kwargs)

    # MLM auxiliary objective
    if mlm_input_ids is not None:
        mlm_input_ids = mlm_input_ids.view((-1, mlm_input_ids.size(-1)))
        mlm_outputs = encoder(mlm_input_ids, **encoder_kwargs)

    # Pooling
    pooler_output = cls.pooler(attention_mask, outputs)
    pooler_output = pooler_output.view((batch_size, num_sent, pooler_output.size(-1)))  # (bs, num_sent, hidden)

    # If using "cls", we add an extra MLP layer
    # (same as BERT's original implementation) over the representation.
    if cls.pooler_type == "cls":
        pooler_output = cls.mlp(pooler_output)

    # Separate representation
    z1, z2 = pooler_output[:, 0], pooler_output[:, 1]
    z3 = pooler_output[:, 2] if has_hard_negative else None

    # Gather all embeddings if using distributed training
    if dist.is_initialized() and cls.training:
        if has_hard_negative:
            z3 = _gather_across_ranks(z3)
        # Full batch embeddings: (bs x N, hidden)
        z1 = _gather_across_ranks(z1)
        z2 = _gather_across_ranks(z2)

    cos_sim = cls.sim(z1.unsqueeze(1), z2.unsqueeze(0))
    # Hard negative
    if has_hard_negative:
        z1_z3_cos = cls.sim(z1.unsqueeze(1), z3.unsqueeze(0))
        cos_sim = torch.cat([cos_sim, z1_z3_cos], 1)

    labels = torch.arange(cos_sim.size(0)).long().to(cls.device)
    loss_fct = nn.CrossEntropyLoss()

    # Calculate loss with hard negatives
    if has_hard_negative:
        # Note that weights are actually logits of weights: row i gets
        # `hard_negative_weight` added to the logit of its own hard negative.
        z3_weight = cls.model_args.hard_negative_weight
        num_neg = z1_z3_cos.size(-1)
        weights = torch.cat(
            [
                torch.zeros(num_neg, cos_sim.size(-1) - num_neg),
                z3_weight * torch.eye(num_neg),
            ],
            dim=1,
        ).to(cls.device)
        cos_sim = cos_sim + weights

    loss = loss_fct(cos_sim, labels)

    # Calculate loss for MLM
    if mlm_outputs is not None and mlm_labels is not None:
        mlm_labels = mlm_labels.view(-1, mlm_labels.size(-1))
        prediction_scores = cls.lm_head(mlm_outputs.last_hidden_state)
        masked_lm_loss = loss_fct(prediction_scores.view(-1, cls.config.vocab_size), mlm_labels.view(-1))
        loss = loss + cls.model_args.mlm_weight * masked_lm_loss

    if not return_dict:
        output = (cos_sim,) + outputs[2:]
        return ((loss,) + output) if loss is not None else output
    return SequenceClassifierOutput(
        loss=loss,
        logits=cos_sim,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
class BertForCL(BertPreTrainedModel):
    """
    BERT encoder wired for contrastive learning.

    The encoder is built without its own pooling layer; `cl_init` attaches
    the pooler, the optional MLP head and the similarity module. An LM
    prediction head is created only when `model_args.do_mlm` is set.
    """
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config, *model_args, **model_kargs):
        super().__init__(config)
        # `model_args` must be supplied as a keyword argument.
        self.model_args = model_kargs["model_args"]
        self.bert = BertModel(config, add_pooling_layer=False)

        if self.model_args.do_mlm:
            self.lm_head = BertLMPredictionHead(config)

        cl_init(self, config)

    def forward(self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        sent_emb=False,
        mlm_input_ids=None,
        mlm_labels=None,
    ):
        """
        Dispatch to `sentemb_forward` when `sent_emb` is true, otherwise to
        `cl_forward`. The MLM arguments are forwarded only on the
        contrastive path.
        """
        shared = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        if sent_emb:
            return sentemb_forward(self, self.bert, **shared)
        return cl_forward(
            self,
            self.bert,
            mlm_input_ids=mlm_input_ids,
            mlm_labels=mlm_labels,
            **shared,
        )
output_attentions=output_attentions, 389 | output_hidden_states=output_hidden_states, 390 | return_dict=return_dict, 391 | mlm_input_ids=mlm_input_ids, 392 | mlm_labels=mlm_labels, 393 | ) 394 | -------------------------------------------------------------------------------- /simcse/models_HSCL.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.distributed as dist 7 | 8 | import transformers 9 | from transformers import RobertaTokenizer 10 | from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel, RobertaModel, RobertaLMHead 11 | from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertModel, BertLMPredictionHead 12 | from transformers.activations import gelu 13 | from transformers.file_utils import ( 14 | add_code_sample_docstrings, 15 | add_start_docstrings, 16 | add_start_docstrings_to_model_forward, 17 | replace_return_docstrings, 18 | ) 19 | from transformers.modeling_outputs import SequenceClassifierOutput, BaseModelOutputWithPoolingAndCrossAttentions 20 | 21 | class MLPLayer(nn.Module): 22 | """ 23 | Head for getting sentence representations over RoBERTa/BERT's CLS representation. 
class Pooler(nn.Module):
    """
    Parameter-free poolers to get the sentence embedding
    'cls': [CLS] representation; the extra MLP head is applied by the caller, not here.
    'cls_before_pooler': [CLS] representation, returned the same way as 'cls'.
    'avg': masked average of the last layer's hidden states over the tokens.
    'avg_top2': masked average of the mean of the last two layers.
    'avg_first_last': masked average of the mean of hidden_states[1] and the last layer.
    """
    def __init__(self, pooler_type):
        super().__init__()
        self.pooler_type = pooler_type
        assert self.pooler_type in ["cls", "cls_before_pooler", "avg", "avg_top2", "avg_first_last"], "unrecognized pooling type %s" % self.pooler_type

    @staticmethod
    def _masked_mean(hidden, attention_mask):
        """Sum `hidden` over dim 1 weighted by the mask, divided by the mask's row sum."""
        return (hidden * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(-1).unsqueeze(-1)

    def forward(self, attention_mask, outputs):
        """
        Pool `outputs` into one vector per sequence.

        Only the attributes the selected pooler needs are read from
        `outputs`: `last_hidden_state` for the CLS and 'avg' poolers,
        `hidden_states` for the two-layer poolers.
        """
        if self.pooler_type in ['cls_before_pooler', 'cls']:
            return outputs.last_hidden_state[:, 0]
        if self.pooler_type == "avg":
            return self._masked_mean(outputs.last_hidden_state, attention_mask)
        if self.pooler_type == "avg_first_last":
            hidden_states = outputs.hidden_states
            return self._masked_mean((hidden_states[1] + hidden_states[-1]) / 2.0, attention_mask)
        if self.pooler_type == "avg_top2":
            hidden_states = outputs.hidden_states
            return self._masked_mean((hidden_states[-1] + hidden_states[-2]) / 2.0, attention_mask)
        raise NotImplementedError
def cl_forward(cls,
    encoder,
    input_ids=None,
    attention_mask=None,
    token_type_ids=None,
    position_ids=None,
    head_mask=None,
    inputs_embeds=None,
    labels=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    mlm_input_ids=None,
    mlm_labels=None,
):
    """
    Contrastive-learning forward pass with a per-row hard negative.

    `input_ids` and `attention_mask` (and `token_type_ids` when given) are
    indexed as (batch, num_sent, len). Along the second axis, index 0 and 1
    form the positive pair and index 2, when present, is the hard negative.
    Unlike the in-batch variant, each anchor is compared only with its own
    hard negative, so the hard-negative similarity is one extra logit
    column. Any further sentences are encoded but do not enter the loss.

    `labels` and `output_hidden_states` are accepted for signature
    compatibility only: the labels are rebuilt from the logits' first
    dimension and hidden-state output is decided by the pooler type.

    When `mlm_input_ids` and `mlm_labels` are both given, a masked-LM loss
    scaled by `model_args.mlm_weight` is added to the contrastive loss.

    Returns a `SequenceClassifierOutput` whose logits are the similarity
    matrix (plus the hard-negative bias), or the equivalent tuple starting
    with the loss when `return_dict` is false.
    """

    def _gather_across_ranks(local):
        # all_gather hands back copies without gradients, so the slot that
        # belongs to this rank is swapped for the original tensor.
        gathered = [torch.zeros_like(local) for _ in range(dist.get_world_size())]
        dist.all_gather(tensor_list=gathered, tensor=local.contiguous())
        gathered[dist.get_rank()] = local
        return torch.cat(gathered, 0)

    return_dict = return_dict if return_dict is not None else cls.config.use_return_dict
    batch_size = input_ids.size(0)
    # Number of sentences in one instance
    # 2: pair instance; 3: pair instance with a hard negative
    num_sent = input_ids.size(1)
    # One flag drives every hard-negative step, so `z3` can never be read
    # without having been bound (previously `== 3` vs `>= 3` disagreed).
    has_hard_negative = num_sent >= 3
    need_all_layers = cls.model_args.pooler_type in ['avg_top2', 'avg_first_last']

    mlm_outputs = None
    # Flatten input for encoding
    input_ids = input_ids.view((-1, input_ids.size(-1)))  # (bs * num_sent, len)
    attention_mask = attention_mask.view((-1, attention_mask.size(-1)))  # (bs * num_sent, len)
    if token_type_ids is not None:
        token_type_ids = token_type_ids.view((-1, token_type_ids.size(-1)))  # (bs * num_sent, len)

    encoder_kwargs = dict(
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        output_attentions=output_attentions,
        output_hidden_states=need_all_layers,
        return_dict=True,
    )

    # Get raw embeddings
    outputs = encoder(input_ids, **encoder_kwargs)

    # MLM auxiliary objective
    if mlm_input_ids is not None:
        mlm_input_ids = mlm_input_ids.view((-1, mlm_input_ids.size(-1)))
        mlm_outputs = encoder(mlm_input_ids, **encoder_kwargs)

    # Pooling
    pooler_output = cls.pooler(attention_mask, outputs)
    pooler_output = pooler_output.view((batch_size, num_sent, pooler_output.size(-1)))  # (bs, num_sent, hidden)

    # If using "cls", we add an extra MLP layer
    # (same as BERT's original implementation) over the representation.
    if cls.pooler_type == "cls":
        pooler_output = cls.mlp(pooler_output)

    # Separate representation
    z1, z2 = pooler_output[:, 0], pooler_output[:, 1]
    z3 = pooler_output[:, 2] if has_hard_negative else None

    # Gather all embeddings if using distributed training
    if dist.is_initialized() and cls.training:
        if has_hard_negative:
            z3 = _gather_across_ranks(z3)
        # Full batch embeddings: (bs x N, hidden)
        z1 = _gather_across_ranks(z1)
        z2 = _gather_across_ranks(z2)

    cos_sim = cls.sim(z1.unsqueeze(1), z2.unsqueeze(0))
    # Hard negative: row-wise similarity of each anchor with its own hard
    # negative, appended as a single trailing column.
    if has_hard_negative:
        z1_z3_cos = cls.sim(z1, z3)
        cos_sim = torch.cat([cos_sim, z1_z3_cos.unsqueeze(1)], 1)

    labels = torch.arange(cos_sim.size(0)).long().to(cls.device)
    loss_fct = nn.CrossEntropyLoss()

    # Calculate loss with hard negatives
    if has_hard_negative:
        # Note that weights are actually logits of weights. The hard
        # negative lives in the last column for every row, so that is the
        # only column biased. The previous diagonal layout put the weight
        # on column i + 1, i.e. on an in-batch negative for most rows.
        z3_weight = cls.model_args.hard_negative_weight
        weights = torch.zeros(cos_sim.size())
        weights[:, -1] = z3_weight
        cos_sim = cos_sim + weights.to(cls.device)

    loss = loss_fct(cos_sim, labels)

    # Calculate loss for MLM
    if mlm_outputs is not None and mlm_labels is not None:
        mlm_labels = mlm_labels.view(-1, mlm_labels.size(-1))
        prediction_scores = cls.lm_head(mlm_outputs.last_hidden_state)
        masked_lm_loss = loss_fct(prediction_scores.view(-1, cls.config.vocab_size), mlm_labels.view(-1))
        loss = loss + cls.model_args.mlm_weight * masked_lm_loss

    if not return_dict:
        output = (cos_sim,) + outputs[2:]
        return ((loss,) + output) if loss is not None else output
    return SequenceClassifierOutput(
        loss=loss,
        logits=cos_sim,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
class BertForCL(BertPreTrainedModel):
    """
    BERT encoder wired for contrastive learning.

    The encoder is built without its own pooling layer; `cl_init` attaches
    the pooler, the optional MLP head and the similarity module. An LM
    prediction head is created only when `model_args.do_mlm` is set.
    """
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config, *model_args, **model_kargs):
        super().__init__(config)
        # `model_args` must be supplied as a keyword argument.
        self.model_args = model_kargs["model_args"]
        self.bert = BertModel(config, add_pooling_layer=False)

        if self.model_args.do_mlm:
            self.lm_head = BertLMPredictionHead(config)

        cl_init(self, config)

    def forward(self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        sent_emb=False,
        mlm_input_ids=None,
        mlm_labels=None,
    ):
        """
        Run `sentemb_forward` when `sent_emb` is true and `cl_forward`
        otherwise; only the latter receives the MLM inputs.
        """
        common_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
            "position_ids": position_ids,
            "head_mask": head_mask,
            "inputs_embeds": inputs_embeds,
            "labels": labels,
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            "return_dict": return_dict,
        }
        if not sent_emb:
            return cl_forward(
                self,
                self.bert,
                mlm_input_ids=mlm_input_ids,
                mlm_labels=mlm_labels,
                **common_kwargs,
            )
        return sentemb_forward(self, self.bert, **common_kwargs)
output_hidden_states=output_hidden_states, 388 | return_dict=return_dict, 389 | mlm_input_ids=mlm_input_ids, 390 | mlm_labels=mlm_labels, 391 | ) 392 | -------------------------------------------------------------------------------- /simcse/models_aug.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.distributed as dist 7 | 8 | import transformers 9 | from transformers import RobertaTokenizer 10 | from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel, RobertaModel, RobertaLMHead 11 | from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertModel, BertLMPredictionHead 12 | from transformers.activations import gelu 13 | from transformers.file_utils import ( 14 | add_code_sample_docstrings, 15 | add_start_docstrings, 16 | add_start_docstrings_to_model_forward, 17 | replace_return_docstrings, 18 | ) 19 | from transformers.modeling_outputs import SequenceClassifierOutput, BaseModelOutputWithPoolingAndCrossAttentions 20 | 21 | class MLPLayer(nn.Module): 22 | """ 23 | Head for getting sentence representations over RoBERTa/BERT's CLS representation. 
class Pooler(nn.Module):
    """
    Parameter-free poolers to get the sentence embedding
    'cls': [CLS] representation; the extra MLP head is applied by the caller, not here.
    'cls_before_pooler': [CLS] representation, returned the same way as 'cls'.
    'avg': masked average of the last layer's hidden states over the tokens.
    'avg_top2': masked average of the mean of the last two layers.
    'avg_first_last': masked average of the mean of hidden_states[1] and the last layer.
    """
    def __init__(self, pooler_type):
        super().__init__()
        self.pooler_type = pooler_type
        assert self.pooler_type in ["cls", "cls_before_pooler", "avg", "avg_top2", "avg_first_last"], "unrecognized pooling type %s" % self.pooler_type

    @staticmethod
    def _masked_mean(hidden, attention_mask):
        """Sum `hidden` over dim 1 weighted by the mask, divided by the mask's row sum."""
        token_weights = attention_mask.unsqueeze(-1)
        return (hidden * token_weights).sum(1) / attention_mask.sum(-1).unsqueeze(-1)

    def forward(self, attention_mask, outputs):
        """
        Pool `outputs` into one vector per sequence.

        Only the attributes the selected pooler needs are read from
        `outputs`: `last_hidden_state` for the CLS and 'avg' poolers,
        `hidden_states` for the two-layer poolers.
        """
        kind = self.pooler_type
        if kind in ['cls_before_pooler', 'cls']:
            return outputs.last_hidden_state[:, 0]
        if kind == "avg":
            return self._masked_mean(outputs.last_hidden_state, attention_mask)
        if kind == "avg_first_last":
            layers = outputs.hidden_states
            return self._masked_mean((layers[1] + layers[-1]) / 2.0, attention_mask)
        if kind == "avg_top2":
            layers = outputs.hidden_states
            return self._masked_mean((layers[-1] + layers[-2]) / 2.0, attention_mask)
        raise NotImplementedError
def cl_forward(cls,
    encoder,
    input_ids=None,
    attention_mask=None,
    token_type_ids=None,
    position_ids=None,
    head_mask=None,
    inputs_embeds=None,
    labels=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    mlm_input_ids=None,
    mlm_labels=None,
):
    """
    Contrastive-learning forward pass for augmented instances.

    `input_ids` and `attention_mask` (and `token_type_ids` when given) are
    indexed as (batch, num_sent, len). An instance holds either 2 sentences
    (plain pair) or at least 5. With 5, along the second axis:

    * index 0 and 1 form the positive pair;
    * index 4 is a hard negative compared against every anchor in the batch;
    * index 3 is a hard negative compared only with its own anchor;
    * index 2 is encoded but does not enter the loss.

    `labels` and `output_hidden_states` are accepted for signature
    compatibility only: the labels are rebuilt from the logits' first
    dimension and hidden-state output is decided by the pooler type.

    When `mlm_input_ids` and `mlm_labels` are both given, a masked-LM loss
    scaled by `model_args.mlm_weight` is added to the contrastive loss.

    Returns a `SequenceClassifierOutput` whose logits are the similarity
    matrix (plus the hard-negative bias), or the equivalent tuple starting
    with the loss when `return_dict` is false.

    Raises `IndexError` when an instance holds 3 or 4 sentences.
    """

    def _gather_across_ranks(local):
        # all_gather hands back copies without gradients, so the slot that
        # belongs to this rank is swapped for the original tensor.
        gathered = [torch.zeros_like(local) for _ in range(dist.get_world_size())]
        dist.all_gather(tensor_list=gathered, tensor=local.contiguous())
        gathered[dist.get_rank()] = local
        return torch.cat(gathered, 0)

    return_dict = return_dict if return_dict is not None else cls.config.use_return_dict
    batch_size = input_ids.size(0)
    # Number of sentences in one instance
    # 2: pair instance; 5: pair instance with hard negatives
    num_sent = input_ids.size(1)
    has_hard_negative = num_sent >= 3
    if has_hard_negative and num_sent < 5:
        # Fail before the encoder runs; the slicing below would raise the
        # same exception type with a less helpful message.
        raise IndexError(
            "hard-negative instances need 5 sentences per example, got %d" % num_sent
        )
    need_all_layers = cls.model_args.pooler_type in ['avg_top2', 'avg_first_last']

    mlm_outputs = None
    # Flatten input for encoding
    input_ids = input_ids.view((-1, input_ids.size(-1)))  # (bs * num_sent, len)
    attention_mask = attention_mask.view((-1, attention_mask.size(-1)))  # (bs * num_sent, len)
    if token_type_ids is not None:
        token_type_ids = token_type_ids.view((-1, token_type_ids.size(-1)))  # (bs * num_sent, len)

    encoder_kwargs = dict(
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        output_attentions=output_attentions,
        output_hidden_states=need_all_layers,
        return_dict=True,
    )

    # Get raw embeddings
    outputs = encoder(input_ids, **encoder_kwargs)

    # MLM auxiliary objective
    if mlm_input_ids is not None:
        mlm_input_ids = mlm_input_ids.view((-1, mlm_input_ids.size(-1)))
        mlm_outputs = encoder(mlm_input_ids, **encoder_kwargs)

    # Pooling
    pooler_output = cls.pooler(attention_mask, outputs)
    pooler_output = pooler_output.view((batch_size, num_sent, pooler_output.size(-1)))  # (bs, num_sent, hidden)

    # If using "cls", we add an extra MLP layer
    # (same as BERT's original implementation) over the representation.
    if cls.pooler_type == "cls":
        pooler_output = cls.mlp(pooler_output)

    # Separate representation
    z1, z2 = pooler_output[:, 0], pooler_output[:, 1]

    # Hard negatives. Index 2 is deliberately not sliced or gathered: it
    # never reached the loss, so gathering it only cost communication.
    z4 = z5 = None
    if has_hard_negative:
        z4 = pooler_output[:, 3]
        z5 = pooler_output[:, 4]

    # Gather all embeddings if using distributed training
    if dist.is_initialized() and cls.training:
        if has_hard_negative:
            z4 = _gather_across_ranks(z4)
            z5 = _gather_across_ranks(z5)
        # Full batch embeddings: (bs x N, hidden)
        z1 = _gather_across_ranks(z1)
        z2 = _gather_across_ranks(z2)

    cos_sim = cls.sim(z1.unsqueeze(1), z2.unsqueeze(0))

    # Hard negatives: an all-pairs block against z5, then one column with
    # each anchor's similarity to its own z4.
    if has_hard_negative:
        z1_z5_cos = cls.sim(z1.unsqueeze(1), z5.unsqueeze(0))
        z1_z4_cos = cls.sim(z1, z4)
        cos_sim = torch.cat([cos_sim, z1_z5_cos, z1_z4_cos.unsqueeze(1)], 1)

    labels = torch.arange(cos_sim.size(0)).long().to(cls.device)
    loss_fct = nn.CrossEntropyLoss()

    # Calculate loss with hard negatives
    if has_hard_negative:
        # Note that weights are actually logits of weights: row i is biased
        # at its own z5 logit (diagonal of the z5 block) and at the final
        # z4 column.
        z3_weight = cls.model_args.hard_negative_weight
        num_neg = z1_z5_cos.size(-1)
        weights = torch.cat(
            [
                torch.zeros(num_neg, cos_sim.size(-1) - num_neg - 1),
                z3_weight * torch.eye(num_neg),
                z3_weight * torch.ones(num_neg, 1),
            ],
            dim=1,
        ).to(cls.device)
        cos_sim = cos_sim + weights
    loss = loss_fct(cos_sim, labels)

    # Calculate loss for MLM
    if mlm_outputs is not None and mlm_labels is not None:
        mlm_labels = mlm_labels.view(-1, mlm_labels.size(-1))
        prediction_scores = cls.lm_head(mlm_outputs.last_hidden_state)
        masked_lm_loss = loss_fct(prediction_scores.view(-1, cls.config.vocab_size), mlm_labels.view(-1))
        loss = loss + cls.model_args.mlm_weight * masked_lm_loss

    if not return_dict:
        output = (cos_sim,) + outputs[2:]
        return ((loss,) + output) if loss is not None else output
    return SequenceClassifierOutput(
        loss=loss,
        logits=cos_sim,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
class RobertaForCL(RobertaPreTrainedModel):
    """
    RoBERTa encoder wired for contrastive learning.

    The encoder is built without its own pooling layer; `cl_init` attaches
    the pooler, the optional MLP head and the similarity module. An LM head
    is created only when `model_args.do_mlm` is set.
    """
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config, *model_args, **model_kargs):
        super().__init__(config)
        # `model_args` must be supplied as a keyword argument.
        self.model_args = model_kargs["model_args"]
        self.roberta = RobertaModel(config, add_pooling_layer=False)

        if self.model_args.do_mlm:
            self.lm_head = RobertaLMHead(config)

        cl_init(self, config)

    def forward(self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        sent_emb=False,
        mlm_input_ids=None,
        mlm_labels=None,
    ):
        """
        Dispatch to `sentemb_forward` when `sent_emb` is true, otherwise to
        `cl_forward`. The MLM arguments are forwarded only on the
        contrastive path.
        """
        shared = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        if sent_emb:
            return sentemb_forward(self, self.roberta, **shared)
        return cl_forward(
            self,
            self.roberta,
            mlm_input_ids=mlm_input_ids,
            mlm_labels=mlm_labels,
            **shared,
        )