├── simcse ├── __init__.py ├── tool.py ├── train.py └── trainers.py ├── figure └── model.png ├── data └── download_wiki.sh ├── requirements.txt ├── senteval ├── binary.py ├── sst.py ├── mrpc.py ├── snli.py ├── rank.py ├── tools │ ├── relatedness.py │ ├── classifier.py │ ├── validation.py │ └── ranking.py ├── engine.py ├── probing.py ├── sick.py └── sts.py ├── README.md ├── run.sh ├── evaluation.py └── train.py /simcse/__init__.py: -------------------------------------------------------------------------------- 1 | from .tool import SimCSE 2 | -------------------------------------------------------------------------------- /figure/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RUCAIBox/DCLR/HEAD/figure/model.png -------------------------------------------------------------------------------- /data/download_wiki.sh: -------------------------------------------------------------------------------- 1 | wget https://huggingface.co/datasets/princeton-nlp/datasets-for-simcse/resolve/main/wiki1m_for_simcse.txt 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.2.1 2 | scipy==1.5.4 3 | datasets==1.2.1 4 | pandas==1.1.5 5 | scikit-learn==0.24.0 6 | prettytable==2.1.0 7 | gradio 8 | torch 9 | setuptools==49.3.0 -------------------------------------------------------------------------------- /senteval/binary.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | ''' 9 | Binary classifier and corresponding datasets : MR, CR, SUBJ, MPQA 10 | ''' 11 | from __future__ import absolute_import, division, unicode_literals 12 | 13 | import io 14 | import os 15 | import numpy as np 16 | import logging 17 | 18 | from senteval.tools.validation import InnerKFoldClassifier 19 | 20 | 21 | class BinaryClassifierEval(object): 22 | def __init__(self, pos, neg, seed=1111): 23 | self.seed = seed 24 | self.samples, self.labels = pos + neg, [1] * len(pos) + [0] * len(neg) 25 | self.n_samples = len(self.samples) 26 | 27 | def do_prepare(self, params, prepare): 28 | # prepare is given the whole text 29 | return prepare(params, self.samples) 30 | # prepare puts everything it outputs in "params" : params.word2id etc 31 | # Those output will be further used by "batcher". 32 | 33 | def loadFile(self, fpath): 34 | with io.open(fpath, 'r', encoding='latin-1') as f: 35 | return [line.split() for line in f.read().splitlines()] 36 | 37 | def run(self, params, batcher): 38 | enc_input = [] 39 | # Sort to reduce padding 40 | sorted_corpus = sorted(zip(self.samples, self.labels), 41 | key=lambda z: (len(z[0]), z[1])) 42 | sorted_samples = [x for (x, y) in sorted_corpus] 43 | sorted_labels = [y for (x, y) in sorted_corpus] 44 | logging.info('Generating sentence embeddings') 45 | for ii in range(0, self.n_samples, params.batch_size): 46 | batch = sorted_samples[ii:ii + params.batch_size] 47 | embeddings = batcher(params, batch) 48 | enc_input.append(embeddings) 49 | enc_input = np.vstack(enc_input) 50 | logging.info('Generated sentence embeddings') 51 | 52 | config = {'nclasses': 2, 'seed': self.seed, 53 | 'usepytorch': params.usepytorch, 54 | 'classifier': params.classifier, 55 | 'nhid': params.nhid, 'kfold': params.kfold} 56 | clf = InnerKFoldClassifier(enc_input, np.array(sorted_labels), config) 57 | devacc, testacc = clf.run() 58 | logging.debug('Dev acc : {0} Test acc : {1}\n'.format(devacc, testacc)) 59 | return {'devacc': devacc, 'acc': 
testacc, 'ndev': self.n_samples, 60 | 'ntest': self.n_samples} 61 | 62 | 63 | class CREval(BinaryClassifierEval): 64 | def __init__(self, task_path, seed=1111): 65 | logging.debug('***** Transfer task : CR *****\n\n') 66 | pos = self.loadFile(os.path.join(task_path, 'custrev.pos')) 67 | neg = self.loadFile(os.path.join(task_path, 'custrev.neg')) 68 | super(self.__class__, self).__init__(pos, neg, seed) 69 | 70 | 71 | class MREval(BinaryClassifierEval): 72 | def __init__(self, task_path, seed=1111): 73 | logging.debug('***** Transfer task : MR *****\n\n') 74 | pos = self.loadFile(os.path.join(task_path, 'rt-polarity.pos')) 75 | neg = self.loadFile(os.path.join(task_path, 'rt-polarity.neg')) 76 | super(self.__class__, self).__init__(pos, neg, seed) 77 | 78 | 79 | class SUBJEval(BinaryClassifierEval): 80 | def __init__(self, task_path, seed=1111): 81 | logging.debug('***** Transfer task : SUBJ *****\n\n') 82 | obj = self.loadFile(os.path.join(task_path, 'subj.objective')) 83 | subj = self.loadFile(os.path.join(task_path, 'subj.subjective')) 84 | super(self.__class__, self).__init__(obj, subj, seed) 85 | 86 | 87 | class MPQAEval(BinaryClassifierEval): 88 | def __init__(self, task_path, seed=1111): 89 | logging.debug('***** Transfer task : MPQA *****\n\n') 90 | pos = self.loadFile(os.path.join(task_path, 'mpqa.pos')) 91 | neg = self.loadFile(os.path.join(task_path, 'mpqa.neg')) 92 | super(self.__class__, self).__init__(pos, neg, seed) 93 | -------------------------------------------------------------------------------- /senteval/sst.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | ''' 9 | SST - binary classification 10 | ''' 11 | 12 | from __future__ import absolute_import, division, unicode_literals 13 | 14 | import os 15 | import io 16 | import logging 17 | import numpy as np 18 | 19 | from senteval.tools.validation import SplitClassifier 20 | 21 | 22 | class SSTEval(object): 23 | def __init__(self, task_path, nclasses=2, seed=1111): 24 | self.seed = seed 25 | 26 | # binary of fine-grained 27 | assert nclasses in [2, 5] 28 | self.nclasses = nclasses 29 | self.task_name = 'Binary' if self.nclasses == 2 else 'Fine-Grained' 30 | logging.debug('***** Transfer task : SST %s classification *****\n\n', self.task_name) 31 | 32 | train = self.loadFile(os.path.join(task_path, 'sentiment-train')) 33 | dev = self.loadFile(os.path.join(task_path, 'sentiment-dev')) 34 | test = self.loadFile(os.path.join(task_path, 'sentiment-test')) 35 | self.sst_data = {'train': train, 'dev': dev, 'test': test} 36 | 37 | def do_prepare(self, params, prepare): 38 | samples = self.sst_data['train']['X'] + self.sst_data['dev']['X'] + \ 39 | self.sst_data['test']['X'] 40 | return prepare(params, samples) 41 | 42 | def loadFile(self, fpath): 43 | sst_data = {'X': [], 'y': []} 44 | with io.open(fpath, 'r', encoding='utf-8') as f: 45 | for line in f: 46 | if self.nclasses == 2: 47 | sample = line.strip().split('\t') 48 | sst_data['y'].append(int(sample[1])) 49 | sst_data['X'].append(sample[0].split()) 50 | elif self.nclasses == 5: 51 | sample = line.strip().split(' ', 1) 52 | sst_data['y'].append(int(sample[0])) 53 | sst_data['X'].append(sample[1].split()) 54 | assert max(sst_data['y']) == self.nclasses - 1 55 | return sst_data 56 | 57 | def run(self, params, batcher): 58 | sst_embed = {'train': {}, 'dev': {}, 'test': {}} 59 | bsize = params.batch_size 60 | 61 | for key in self.sst_data: 62 | logging.info('Computing embedding for {0}'.format(key)) 63 | # Sort to reduce padding 64 | sorted_data = sorted(zip(self.sst_data[key]['X'], 65 | self.sst_data[key]['y']), 
66 | key=lambda z: (len(z[0]), z[1])) 67 | self.sst_data[key]['X'], self.sst_data[key]['y'] = map(list, zip(*sorted_data)) 68 | 69 | sst_embed[key]['X'] = [] 70 | for ii in range(0, len(self.sst_data[key]['y']), bsize): 71 | batch = self.sst_data[key]['X'][ii:ii + bsize] 72 | embeddings = batcher(params, batch) 73 | sst_embed[key]['X'].append(embeddings) 74 | sst_embed[key]['X'] = np.vstack(sst_embed[key]['X']) 75 | sst_embed[key]['y'] = np.array(self.sst_data[key]['y']) 76 | logging.info('Computed {0} embeddings'.format(key)) 77 | 78 | config_classifier = {'nclasses': self.nclasses, 'seed': self.seed, 79 | 'usepytorch': params.usepytorch, 80 | 'classifier': params.classifier} 81 | 82 | clf = SplitClassifier(X={'train': sst_embed['train']['X'], 83 | 'valid': sst_embed['dev']['X'], 84 | 'test': sst_embed['test']['X']}, 85 | y={'train': sst_embed['train']['y'], 86 | 'valid': sst_embed['dev']['y'], 87 | 'test': sst_embed['test']['y']}, 88 | config=config_classifier) 89 | 90 | devacc, testacc = clf.run() 91 | logging.debug('\nDev acc : {0} Test acc : {1} for \ 92 | SST {2} classification\n'.format(devacc, testacc, self.task_name)) 93 | 94 | return {'devacc': devacc, 'acc': testacc, 95 | 'ndev': len(sst_embed['dev']['X']), 96 | 'ntest': len(sst_embed['test']['X'])} 97 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Debiased Contrastive Learning of Unsupervised Sentence Representations 2 | 3 | This repository contains the code for our paper ***Debiased Contrastive Learning of Unsupervised Sentence Representations***. 4 | 5 | ## Overview 6 | 7 | We propose ***DCLR***, a debiased contrastive learning framework for unsupervised sentence representation learning. Based on SimCSE, we mainly consider two biases caused by the randomly negative sampling, namely the false negatives and the anistropy representation problem. 
To address these two problems, we incorporate an instance weighting method and noise-based negatives that alleviate their influence during contrastive learning.

![](figure/model.png)

## Train DCLR

In the following section, we describe how to train a DCLR model using our code.

### Evaluation

Our evaluation code for sentence embeddings follows the released code of [SimCSE](https://github.com/princeton-nlp/SimCSE) and is based on a modified version of [SentEval](https://github.com/facebookresearch/SentEval). It evaluates sentence embeddings on semantic textual similarity (STS) tasks and downstream transfer tasks. For STS tasks, our evaluation uses the "all" setting and reports Spearman's correlation.

Before evaluation, please download the evaluation datasets by running:
```bash
cd SentEval/data/downstream/
bash download_dataset.sh
```

### Training

**Environment**

To faithfully reproduce our results, please install PyTorch `1.8.1` built for your platform/CUDA version. For example, for CUDA 11.1:

```bash
pip install torch==1.8.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html
```


Then run the following command to install the remaining dependencies:

```bash
pip install -r requirements.txt
```

**Data**

We use the data released by SimCSE, which consists of 1 million sentences sampled from English Wikipedia. You can run `data/download_wiki.sh` to download it.

**Required Checkpoints from SimCSE**

Our approach uses a fixed SimCSE model on BERT-base or RoBERTa-base (matching the backbone family) as the complementary model for instance weighting. You can download their checkpoints from these links: [SimCSE-BERT-base](https://huggingface.co/princeton-nlp/unsup-simcse-bert-base-uncased) and [SimCSE-RoBERTa-base](https://huggingface.co/princeton-nlp/unsup-simcse-roberta-base).

Besides, we also need the SimCSE checkpoints on BERT-large and RoBERTa-large to initialize our model, which stabilizes the training process. You can download them from these links: [SimCSE-BERT-large](https://huggingface.co/princeton-nlp/unsup-simcse-bert-large-uncased) and [SimCSE-RoBERTa-large](https://huggingface.co/princeton-nlp/unsup-simcse-roberta-large).

**Training scripts**

We provide training scripts for BERT/RoBERTa-base/large with the best hyperparameters already set. Running the script trains on all four backbone models (BERT/RoBERTa-base/large) in sequence:
```bash
bash run.sh
```

For BERT/RoBERTa-base models, we provide a single-GPU (or CPU) example, and for BERT/RoBERTa-large models we give a **multiple-GPU** example. Some important arguments are explained below:
* `--model_name_or_path`: Pre-trained checkpoint to start from. We support BERT-based models (`bert-base-uncased`, `bert-large-uncased`) and RoBERTa-based models (`roberta-base`, `roberta-large`).
* `--c_model_name_or_path`: Checkpoint of the complementary model. We support the SimCSE-BERT/RoBERTa-base models (`unsup-simcse-bert-base-uncased`, `unsup-simcse-roberta-base`).

For the results in the paper, we use 8 Nvidia 3090 GPUs with CUDA 11. Using different types of devices or different versions of CUDA/other software may lead to slightly different performance.

**Hyperparameter Sensitivity**

Note that the performance of DCLR is also sensitive to the environment and hyperparameter settings. If you get different performance, we suggest searching over the hyperparameters ***phi*** and ***noise_times*** around our provided values.
68 | 69 | ## Citation 70 | 71 | Please cite our paper if you use DCLR in your work: 72 | 73 | ```bibtex 74 | @article{zhou2021dclr, 75 | title={Debiased Contrastive Learning of Unsupervised Sentence Representations}, 76 | author={Zhou, Kun and Zhang, Beichen and Zhao, Xin and Wen, Ji-Rong}, 77 | booktitle = {{ACL}}, 78 | year={2022} 79 | } 80 | ``` 81 | -------------------------------------------------------------------------------- /senteval/mrpc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | ''' 9 | MRPC : Microsoft Research Paraphrase (detection) Corpus 10 | ''' 11 | from __future__ import absolute_import, division, unicode_literals 12 | 13 | import os 14 | import logging 15 | import numpy as np 16 | import io 17 | 18 | from senteval.tools.validation import KFoldClassifier 19 | 20 | from sklearn.metrics import f1_score 21 | 22 | 23 | class MRPCEval(object): 24 | def __init__(self, task_path, seed=1111): 25 | logging.info('***** Transfer task : MRPC *****\n\n') 26 | self.seed = seed 27 | train = self.loadFile(os.path.join(task_path, 28 | 'msr_paraphrase_train.txt')) 29 | test = self.loadFile(os.path.join(task_path, 30 | 'msr_paraphrase_test.txt')) 31 | self.mrpc_data = {'train': train, 'test': test} 32 | 33 | def do_prepare(self, params, prepare): 34 | # TODO : Should we separate samples in "train, test"? 
35 | samples = self.mrpc_data['train']['X_A'] + \ 36 | self.mrpc_data['train']['X_B'] + \ 37 | self.mrpc_data['test']['X_A'] + self.mrpc_data['test']['X_B'] 38 | return prepare(params, samples) 39 | 40 | def loadFile(self, fpath): 41 | mrpc_data = {'X_A': [], 'X_B': [], 'y': []} 42 | with io.open(fpath, 'r', encoding='utf-8') as f: 43 | for line in f: 44 | text = line.strip().split('\t') 45 | mrpc_data['X_A'].append(text[3].split()) 46 | mrpc_data['X_B'].append(text[4].split()) 47 | mrpc_data['y'].append(text[0]) 48 | 49 | mrpc_data['X_A'] = mrpc_data['X_A'][1:] 50 | mrpc_data['X_B'] = mrpc_data['X_B'][1:] 51 | mrpc_data['y'] = [int(s) for s in mrpc_data['y'][1:]] 52 | return mrpc_data 53 | 54 | def run(self, params, batcher): 55 | mrpc_embed = {'train': {}, 'test': {}} 56 | 57 | for key in self.mrpc_data: 58 | logging.info('Computing embedding for {0}'.format(key)) 59 | # Sort to reduce padding 60 | text_data = {} 61 | sorted_corpus = sorted(zip(self.mrpc_data[key]['X_A'], 62 | self.mrpc_data[key]['X_B'], 63 | self.mrpc_data[key]['y']), 64 | key=lambda z: (len(z[0]), len(z[1]), z[2])) 65 | 66 | text_data['A'] = [x for (x, y, z) in sorted_corpus] 67 | text_data['B'] = [y for (x, y, z) in sorted_corpus] 68 | text_data['y'] = [z for (x, y, z) in sorted_corpus] 69 | 70 | for txt_type in ['A', 'B']: 71 | mrpc_embed[key][txt_type] = [] 72 | for ii in range(0, len(text_data['y']), params.batch_size): 73 | batch = text_data[txt_type][ii:ii + params.batch_size] 74 | embeddings = batcher(params, batch) 75 | mrpc_embed[key][txt_type].append(embeddings) 76 | mrpc_embed[key][txt_type] = np.vstack(mrpc_embed[key][txt_type]) 77 | mrpc_embed[key]['y'] = np.array(text_data['y']) 78 | logging.info('Computed {0} embeddings'.format(key)) 79 | 80 | # Train 81 | trainA = mrpc_embed['train']['A'] 82 | trainB = mrpc_embed['train']['B'] 83 | trainF = np.c_[np.abs(trainA - trainB), trainA * trainB] 84 | trainY = mrpc_embed['train']['y'] 85 | 86 | # Test 87 | testA = mrpc_embed['test']['A'] 
88 | testB = mrpc_embed['test']['B'] 89 | testF = np.c_[np.abs(testA - testB), testA * testB] 90 | testY = mrpc_embed['test']['y'] 91 | 92 | config = {'nclasses': 2, 'seed': self.seed, 93 | 'usepytorch': params.usepytorch, 94 | 'classifier': params.classifier, 95 | 'nhid': params.nhid, 'kfold': params.kfold} 96 | clf = KFoldClassifier(train={'X': trainF, 'y': trainY}, 97 | test={'X': testF, 'y': testY}, config=config) 98 | 99 | devacc, testacc, yhat = clf.run() 100 | testf1 = round(100*f1_score(testY, yhat), 2) 101 | logging.debug('Dev acc : {0} Test acc {1}; Test F1 {2} for MRPC.\n' 102 | .format(devacc, testacc, testf1)) 103 | return {'devacc': devacc, 'acc': testacc, 'f1': testf1, 104 | 'ndev': len(trainA), 'ntest': len(testA)} 105 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo "Start to training on BERT-base" 4 | CUDA_VISIBLE_DEVICES=0 python train.py \ 5 | --model_name_or_path bert-base-uncased \ 6 | --c_model_name_or_path unsup-simcse-bert-base-uncased \ 7 | --train_file data/wiki1m_for_simcse.txt \ 8 | --output_dir result/my-unsup-simcse-bert-base-uncased \ 9 | --num_train_epochs 3 \ 10 | --per_device_train_batch_size 128 \ 11 | --learning_rate 3e-5 \ 12 | --max_seq_length 32 \ 13 | --evaluation_strategy steps \ 14 | --metric_for_best_model avg_sts \ 15 | --load_best_model_at_end \ 16 | --eval_steps 125 \ 17 | --pooler_type cls \ 18 | --mlp_only_train \ 19 | --overwrite_output_dir \ 20 | --temp 0.05 \ 21 | --phi 0.9 \ 22 | --noise_times 1 \ 23 | --gradient_accumulation_steps 1 \ 24 | --do_train \ 25 | --do_eval \ 26 | --is_base True \ 27 | #--fp16 \ 28 | "$@" 29 | 30 | CUDA_VISIBLE_DEVICES=0 python evaluation.py \ 31 | --model_name_or_path result/my-unsup-simcse-bert-base-uncased \ 32 | --pooler cls_before_pooler \ 33 | --task_set sts \ 34 | --mode test 35 | 36 | echo "Start to training on RoBERTa-base" 37 
# Extra arguments given to this script (e.g. --fp16) are forwarded to train.py
# via "$@". Never leave a commented-out option inside a backslash-continued
# command: the comment ends the command and "$@" then runs on its own.
CUDA_VISIBLE_DEVICES=0 python train.py \
    --model_name_or_path roberta-base \
    --c_model_name_or_path unsup-simcse-roberta-base \
    --train_file data/wiki1m_for_simcse.txt \
    --output_dir result/my-unsup-simcse-roberta-base \
    --num_train_epochs 3 \
    --per_device_train_batch_size 128 \
    --learning_rate 3e-5 \
    --max_seq_length 32 \
    --evaluation_strategy steps \
    --metric_for_best_model avg_sts \
    --load_best_model_at_end \
    --eval_steps 125 \
    --pooler_type cls \
    --mlp_only_train \
    --overwrite_output_dir \
    --temp 0.05 \
    --phi 0.85 \
    --noise_times 2.5 \
    --gradient_accumulation_steps 1 \
    --do_train \
    --do_eval \
    --is_base True \
    "$@"

CUDA_VISIBLE_DEVICES=0 python evaluation.py \
    --model_name_or_path result/my-unsup-simcse-roberta-base \
    --pooler cls_before_pooler \
    --task_set sts \
    --mode test


echo "Start to training on BERT-large"
# Same rule as above: options such as --fp16 go through "$@", not through a
# commented-out continuation line.
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node 4 --master_port 55555 train.py \
    --model_name_or_path unsup-simcse-bert-large-uncased \
    --c_model_name_or_path unsup-simcse-bert-base-uncased \
    --train_file data/wiki1m_for_simcse.txt \
    --output_dir result/my-unsup-simcse-bert-large-uncased \
    --num_train_epochs 1 \
    --per_device_train_batch_size 64 \
    --learning_rate 3e-5 \
    --max_seq_length 32 \
    --evaluation_strategy steps \
    --metric_for_best_model avg_sts \
    --load_best_model_at_end \
    --eval_steps 125 \
    --pooler_type cls \
    --mlp_only_train \
    --overwrite_output_dir \
    --temp 0.05 \
    --phi 0.9 \
    --noise_times 3 \
    --gradient_accumulation_steps 1 \
    --do_train \
    --do_eval \
    --is_base False \
    "$@"

CUDA_VISIBLE_DEVICES=0 python evaluation.py \
    --model_name_or_path result/my-unsup-simcse-bert-large-uncased \
--pooler cls_before_pooler \ 100 | --task_set sts \ 101 | --mode test 102 | 103 | echo "Start to training on RoBERTa-large" 104 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node 4 --master_port 55555 train.py \ 105 | --model_name_or_path unsup-simcse-roberta-large \ 106 | --c_model_name_or_path unsup-simcse-roberta-base \ 107 | --train_file data/wiki1m_for_simcse.txt \ 108 | --output_dir result/my-unsup-simcse-roberta-large \ 109 | --num_train_epochs 1 \ 110 | --per_device_train_batch_size 64 \ 111 | --learning_rate 1e-5 \ 112 | --max_seq_length 32 \ 113 | --evaluation_strategy steps \ 114 | --metric_for_best_model avg_sts \ 115 | --load_best_model_at_end \ 116 | --eval_steps 125 \ 117 | --pooler_type cls \ 118 | --mlp_only_train \ 119 | --overwrite_output_dir \ 120 | --temp 0.05 \ 121 | --phi 0.85 \ 122 | --noise_times 5 \ 123 | --gradient_accumulation_steps 1 \ 124 | --do_train \ 125 | --do_eval \ 126 | --is_base False \ 127 | #--fp16 \ 128 | "$@" 129 | 130 | CUDA_VISIBLE_DEVICES=0 python evaluation.py \ 131 | --model_name_or_path result/my-unsup-simcse-roberta-large \ 132 | --pooler cls_before_pooler \ 133 | --task_set sts \ 134 | --mode test 135 | -------------------------------------------------------------------------------- /senteval/snli.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | ''' 9 | SNLI - Entailment 10 | ''' 11 | from __future__ import absolute_import, division, unicode_literals 12 | 13 | import codecs 14 | import os 15 | import io 16 | import copy 17 | import logging 18 | import numpy as np 19 | 20 | from senteval.tools.validation import SplitClassifier 21 | 22 | 23 | class SNLIEval(object): 24 | def __init__(self, taskpath, seed=1111): 25 | logging.debug('***** Transfer task : SNLI Entailment*****\n\n') 26 | self.seed = seed 27 | train1 = self.loadFile(os.path.join(taskpath, 's1.train')) 28 | train2 = self.loadFile(os.path.join(taskpath, 's2.train')) 29 | 30 | trainlabels = io.open(os.path.join(taskpath, 'labels.train'), 31 | encoding='utf-8').read().splitlines() 32 | 33 | valid1 = self.loadFile(os.path.join(taskpath, 's1.dev')) 34 | valid2 = self.loadFile(os.path.join(taskpath, 's2.dev')) 35 | validlabels = io.open(os.path.join(taskpath, 'labels.dev'), 36 | encoding='utf-8').read().splitlines() 37 | 38 | test1 = self.loadFile(os.path.join(taskpath, 's1.test')) 39 | test2 = self.loadFile(os.path.join(taskpath, 's2.test')) 40 | testlabels = io.open(os.path.join(taskpath, 'labels.test'), 41 | encoding='utf-8').read().splitlines() 42 | 43 | # sort data (by s2 first) to reduce padding 44 | sorted_train = sorted(zip(train2, train1, trainlabels), 45 | key=lambda z: (len(z[0]), len(z[1]), z[2])) 46 | train2, train1, trainlabels = map(list, zip(*sorted_train)) 47 | 48 | sorted_valid = sorted(zip(valid2, valid1, validlabels), 49 | key=lambda z: (len(z[0]), len(z[1]), z[2])) 50 | valid2, valid1, validlabels = map(list, zip(*sorted_valid)) 51 | 52 | sorted_test = sorted(zip(test2, test1, testlabels), 53 | key=lambda z: (len(z[0]), len(z[1]), z[2])) 54 | test2, test1, testlabels = map(list, zip(*sorted_test)) 55 | 56 | self.samples = train1 + train2 + valid1 + valid2 + test1 + test2 57 | self.data = {'train': (train1, train2, trainlabels), 58 | 'valid': (valid1, valid2, validlabels), 59 | 'test': (test1, test2, testlabels) 60 | } 
61 | 62 | def do_prepare(self, params, prepare): 63 | return prepare(params, self.samples) 64 | 65 | def loadFile(self, fpath): 66 | with codecs.open(fpath, 'rb', 'latin-1') as f: 67 | return [line.split() for line in 68 | f.read().splitlines()] 69 | 70 | def run(self, params, batcher): 71 | self.X, self.y = {}, {} 72 | dico_label = {'entailment': 0, 'neutral': 1, 'contradiction': 2} 73 | for key in self.data: 74 | if key not in self.X: 75 | self.X[key] = [] 76 | if key not in self.y: 77 | self.y[key] = [] 78 | 79 | input1, input2, mylabels = self.data[key] 80 | enc_input = [] 81 | n_labels = len(mylabels) 82 | for ii in range(0, n_labels, params.batch_size): 83 | batch1 = input1[ii:ii + params.batch_size] 84 | batch2 = input2[ii:ii + params.batch_size] 85 | 86 | if len(batch1) == len(batch2) and len(batch1) > 0: 87 | enc1 = batcher(params, batch1) 88 | enc2 = batcher(params, batch2) 89 | enc_input.append(np.hstack((enc1, enc2, enc1 * enc2, 90 | np.abs(enc1 - enc2)))) 91 | if (ii*params.batch_size) % (20000*params.batch_size) == 0: 92 | logging.info("PROGRESS (encoding): %.2f%%" % 93 | (100 * ii / n_labels)) 94 | self.X[key] = np.vstack(enc_input) 95 | self.y[key] = [dico_label[y] for y in mylabels] 96 | 97 | config = {'nclasses': 3, 'seed': self.seed, 98 | 'usepytorch': params.usepytorch, 99 | 'cudaEfficient': True, 100 | 'nhid': params.nhid, 'noreg': True} 101 | 102 | config_classifier = copy.deepcopy(params.classifier) 103 | config_classifier['max_epoch'] = 15 104 | config_classifier['epoch_size'] = 1 105 | config['classifier'] = config_classifier 106 | 107 | clf = SplitClassifier(self.X, self.y, config) 108 | devacc, testacc = clf.run() 109 | logging.debug('Dev acc : {0} Test acc : {1} for SNLI\n' 110 | .format(devacc, testacc)) 111 | return {'devacc': devacc, 'acc': testacc, 112 | 'ndev': len(self.data['valid'][0]), 113 | 'ntest': len(self.data['test'][0])} 114 | -------------------------------------------------------------------------------- 
/senteval/rank.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | ''' 9 | Image-Caption Retrieval with COCO dataset 10 | ''' 11 | from __future__ import absolute_import, division, unicode_literals 12 | 13 | import os 14 | import sys 15 | import logging 16 | import numpy as np 17 | 18 | try: 19 | import cPickle as pickle 20 | except ImportError: 21 | import pickle 22 | 23 | from senteval.tools.ranking import ImageSentenceRankingPytorch 24 | 25 | 26 | class ImageCaptionRetrievalEval(object): 27 | def __init__(self, task_path, seed=1111): 28 | logging.debug('***** Transfer task: Image Caption Retrieval *****\n\n') 29 | 30 | # Get captions and image features 31 | self.seed = seed 32 | train, dev, test = self.loadFile(task_path) 33 | self.coco_data = {'train': train, 'dev': dev, 'test': test} 34 | 35 | def do_prepare(self, params, prepare): 36 | samples = self.coco_data['train']['sent'] + \ 37 | self.coco_data['dev']['sent'] + \ 38 | self.coco_data['test']['sent'] 39 | prepare(params, samples) 40 | 41 | def loadFile(self, fpath): 42 | coco = {} 43 | 44 | for split in ['train', 'valid', 'test']: 45 | list_sent = [] 46 | list_img_feat = [] 47 | if sys.version_info < (3, 0): 48 | with open(os.path.join(fpath, split + '.pkl')) as f: 49 | cocodata = pickle.load(f) 50 | else: 51 | with open(os.path.join(fpath, split + '.pkl'), 'rb') as f: 52 | cocodata = pickle.load(f, encoding='latin1') 53 | 54 | for imgkey in range(len(cocodata['features'])): 55 | assert len(cocodata['image_to_caption_ids'][imgkey]) >= 5, \ 56 | cocodata['image_to_caption_ids'][imgkey] 57 | for captkey in cocodata['image_to_caption_ids'][imgkey][0:5]: 58 | sent = cocodata['captions'][captkey]['cleaned_caption'] 59 | sent += ' .' 
# add punctuation to end of sentence in COCO 60 | list_sent.append(sent.encode('utf-8').split()) 61 | list_img_feat.append(cocodata['features'][imgkey]) 62 | assert len(list_sent) == len(list_img_feat) and \ 63 | len(list_sent) % 5 == 0 64 | list_img_feat = np.array(list_img_feat).astype('float32') 65 | coco[split] = {'sent': list_sent, 'imgfeat': list_img_feat} 66 | return coco['train'], coco['valid'], coco['test'] 67 | 68 | def run(self, params, batcher): 69 | coco_embed = {'train': {'sentfeat': [], 'imgfeat': []}, 70 | 'dev': {'sentfeat': [], 'imgfeat': []}, 71 | 'test': {'sentfeat': [], 'imgfeat': []}} 72 | 73 | for key in self.coco_data: 74 | logging.info('Computing embedding for {0}'.format(key)) 75 | # Sort to reduce padding 76 | self.coco_data[key]['sent'] = np.array(self.coco_data[key]['sent']) 77 | self.coco_data[key]['sent'], idx_sort = np.sort(self.coco_data[key]['sent']), np.argsort(self.coco_data[key]['sent']) 78 | idx_unsort = np.argsort(idx_sort) 79 | 80 | coco_embed[key]['X'] = [] 81 | nsent = len(self.coco_data[key]['sent']) 82 | for ii in range(0, nsent, params.batch_size): 83 | batch = self.coco_data[key]['sent'][ii:ii + params.batch_size] 84 | embeddings = batcher(params, batch) 85 | coco_embed[key]['sentfeat'].append(embeddings) 86 | coco_embed[key]['sentfeat'] = np.vstack(coco_embed[key]['sentfeat'])[idx_unsort] 87 | coco_embed[key]['imgfeat'] = np.array(self.coco_data[key]['imgfeat']) 88 | logging.info('Computed {0} embeddings'.format(key)) 89 | 90 | config = {'seed': self.seed, 'projdim': 1000, 'margin': 0.2} 91 | clf = ImageSentenceRankingPytorch(train=coco_embed['train'], 92 | valid=coco_embed['dev'], 93 | test=coco_embed['test'], 94 | config=config) 95 | 96 | bestdevscore, r1_i2t, r5_i2t, r10_i2t, medr_i2t, \ 97 | r1_t2i, r5_t2i, r10_t2i, medr_t2i = clf.run() 98 | 99 | logging.debug("\nTest scores | Image to text: \ 100 | {0}, {1}, {2}, {3}".format(r1_i2t, r5_i2t, r10_i2t, medr_i2t)) 101 | logging.debug("Test scores | Text to image: \ 
102 | {0}, {1}, {2}, {3}\n".format(r1_t2i, r5_t2i, r10_t2i, medr_t2i)) 103 | 104 | return {'devacc': bestdevscore, 105 | 'acc': [(r1_i2t, r5_i2t, r10_i2t, medr_i2t), 106 | (r1_t2i, r5_t2i, r10_t2i, medr_t2i)], 107 | 'ndev': len(coco_embed['dev']['sentfeat']), 108 | 'ntest': len(coco_embed['test']['sentfeat'])} 109 | -------------------------------------------------------------------------------- /senteval/tools/relatedness.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | """ 9 | Semantic Relatedness (supervised) with Pytorch 10 | """ 11 | from __future__ import absolute_import, division, unicode_literals 12 | 13 | import copy 14 | import numpy as np 15 | 16 | import torch 17 | from torch import nn 18 | import torch.optim as optim 19 | 20 | from scipy.stats import pearsonr, spearmanr 21 | 22 | 23 | class RelatednessPytorch(object): 24 | # Can be used for SICK-Relatedness, and STS14 25 | def __init__(self, train, valid, test, devscores, config): 26 | # fix seed 27 | np.random.seed(config['seed']) 28 | torch.manual_seed(config['seed']) 29 | assert torch.cuda.is_available(), 'torch.cuda required for Relatedness' 30 | torch.cuda.manual_seed(config['seed']) 31 | 32 | self.train = train 33 | self.valid = valid 34 | self.test = test 35 | self.devscores = devscores 36 | 37 | self.inputdim = train['X'].shape[1] 38 | self.nclasses = config['nclasses'] 39 | self.seed = config['seed'] 40 | self.l2reg = 0. 
41 | self.batch_size = 64 42 | self.maxepoch = 1000 43 | self.early_stop = True 44 | 45 | self.model = nn.Sequential( 46 | nn.Linear(self.inputdim, self.nclasses), 47 | nn.Softmax(dim=-1), 48 | ) 49 | self.loss_fn = nn.MSELoss() 50 | 51 | if torch.cuda.is_available(): 52 | self.model = self.model.cuda() 53 | self.loss_fn = self.loss_fn.cuda() 54 | 55 | self.loss_fn.size_average = False 56 | self.optimizer = optim.Adam(self.model.parameters(), 57 | weight_decay=self.l2reg) 58 | 59 | def prepare_data(self, trainX, trainy, devX, devy, testX, testy): 60 | # Transform probs to log-probs for KL-divergence 61 | trainX = torch.from_numpy(trainX).float().cuda() 62 | trainy = torch.from_numpy(trainy).float().cuda() 63 | devX = torch.from_numpy(devX).float().cuda() 64 | devy = torch.from_numpy(devy).float().cuda() 65 | testX = torch.from_numpy(testX).float().cuda() 66 | testY = torch.from_numpy(testy).float().cuda() 67 | 68 | return trainX, trainy, devX, devy, testX, testy 69 | 70 | def run(self): 71 | self.nepoch = 0 72 | bestpr = -1 73 | early_stop_count = 0 74 | r = np.arange(1, 6) 75 | stop_train = False 76 | 77 | # Preparing data 78 | trainX, trainy, devX, devy, testX, testy = self.prepare_data( 79 | self.train['X'], self.train['y'], 80 | self.valid['X'], self.valid['y'], 81 | self.test['X'], self.test['y']) 82 | 83 | # Training 84 | while not stop_train and self.nepoch <= self.maxepoch: 85 | self.trainepoch(trainX, trainy, nepoches=50) 86 | yhat = np.dot(self.predict_proba(devX), r) 87 | pr = spearmanr(yhat, self.devscores)[0] 88 | pr = 0 if pr != pr else pr # if NaN bc std=0 89 | # early stop on Pearson 90 | if pr > bestpr: 91 | bestpr = pr 92 | bestmodel = copy.deepcopy(self.model) 93 | elif self.early_stop: 94 | if early_stop_count >= 3: 95 | stop_train = True 96 | early_stop_count += 1 97 | self.model = bestmodel 98 | 99 | yhat = np.dot(self.predict_proba(testX), r) 100 | 101 | return bestpr, yhat 102 | 103 | def trainepoch(self, X, y, nepoches=1): 104 | 
self.model.train() 105 | for _ in range(self.nepoch, self.nepoch + nepoches): 106 | permutation = np.random.permutation(len(X)) 107 | all_costs = [] 108 | for i in range(0, len(X), self.batch_size): 109 | # forward 110 | idx = torch.from_numpy(permutation[i:i + self.batch_size]).long().cuda() 111 | Xbatch = X[idx] 112 | ybatch = y[idx] 113 | output = self.model(Xbatch) 114 | # loss 115 | loss = self.loss_fn(output, ybatch) 116 | all_costs.append(loss.item()) 117 | # backward 118 | self.optimizer.zero_grad() 119 | loss.backward() 120 | # Update parameters 121 | self.optimizer.step() 122 | self.nepoch += nepoches 123 | 124 | def predict_proba(self, devX): 125 | self.model.eval() 126 | probas = [] 127 | with torch.no_grad(): 128 | for i in range(0, len(devX), self.batch_size): 129 | Xbatch = devX[i:i + self.batch_size] 130 | if len(probas) == 0: 131 | probas = self.model(Xbatch).data.cpu().numpy() 132 | else: 133 | probas = np.concatenate((probas, self.model(Xbatch).data.cpu().numpy()), axis=0) 134 | return probas 135 | -------------------------------------------------------------------------------- /senteval/engine.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | ''' 9 | 10 | Generic sentence evaluation scripts wrapper 11 | 12 | ''' 13 | from __future__ import absolute_import, division, unicode_literals 14 | 15 | from senteval import utils 16 | from senteval.binary import CREval, MREval, MPQAEval, SUBJEval 17 | from senteval.snli import SNLIEval 18 | from senteval.trec import TRECEval 19 | from senteval.sick import SICKEntailmentEval, SICKEval 20 | from senteval.mrpc import MRPCEval 21 | from senteval.sts import STS12Eval, STS13Eval, STS14Eval, STS15Eval, STS16Eval, STSBenchmarkEval, SICKRelatednessEval, STSBenchmarkFinetune 22 | from senteval.sst import SSTEval 23 | from senteval.rank import ImageCaptionRetrievalEval 24 | from senteval.probing import * 25 | 26 | class SE(object): 27 | def __init__(self, params, batcher, prepare=None): 28 | # parameters 29 | params = utils.dotdict(params) 30 | params.usepytorch = True if 'usepytorch' not in params else params.usepytorch 31 | params.seed = 1111 if 'seed' not in params else params.seed 32 | 33 | params.batch_size = 128 if 'batch_size' not in params else params.batch_size 34 | params.nhid = 0 if 'nhid' not in params else params.nhid 35 | params.kfold = 5 if 'kfold' not in params else params.kfold 36 | 37 | if 'classifier' not in params or not params['classifier']: 38 | params.classifier = {'nhid': 0} 39 | 40 | assert 'nhid' in params.classifier, 'Set number of hidden units in classifier config!!' 
41 | 42 | self.params = params 43 | 44 | # batcher and prepare 45 | self.batcher = batcher 46 | self.prepare = prepare if prepare else lambda x, y: None 47 | 48 | self.list_tasks = ['CR', 'MR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC', 49 | 'SICKRelatedness', 'SICKEntailment', 'STSBenchmark', 50 | 'SNLI', 'ImageCaptionRetrieval', 'STS12', 'STS13', 51 | 'STS14', 'STS15', 'STS16', 52 | 'Length', 'WordContent', 'Depth', 'TopConstituents', 53 | 'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber', 54 | 'OddManOut', 'CoordinationInversion', 'SICKRelatedness-finetune', 'STSBenchmark-finetune', 'STSBenchmark-fix'] 55 | 56 | def eval(self, name): 57 | # evaluate on evaluation [name], either takes string or list of strings 58 | if (isinstance(name, list)): 59 | self.results = {x: self.eval(x) for x in name} 60 | return self.results 61 | 62 | tpath = self.params.task_path 63 | assert name in self.list_tasks, str(name) + ' not in ' + str(self.list_tasks) 64 | 65 | # Original SentEval tasks 66 | if name == 'CR': 67 | self.evaluation = CREval(tpath + '/downstream/CR', seed=self.params.seed) 68 | elif name == 'MR': 69 | self.evaluation = MREval(tpath + '/downstream/MR', seed=self.params.seed) 70 | elif name == 'MPQA': 71 | self.evaluation = MPQAEval(tpath + '/downstream/MPQA', seed=self.params.seed) 72 | elif name == 'SUBJ': 73 | self.evaluation = SUBJEval(tpath + '/downstream/SUBJ', seed=self.params.seed) 74 | elif name == 'SST2': 75 | self.evaluation = SSTEval(tpath + '/downstream/SST/binary', nclasses=2, seed=self.params.seed) 76 | elif name == 'SST5': 77 | self.evaluation = SSTEval(tpath + '/downstream/SST/fine', nclasses=5, seed=self.params.seed) 78 | elif name == 'TREC': 79 | self.evaluation = TRECEval(tpath + '/downstream/TREC', seed=self.params.seed) 80 | elif name == 'MRPC': 81 | self.evaluation = MRPCEval(tpath + '/downstream/MRPC', seed=self.params.seed) 82 | elif name == 'SICKRelatedness': 83 | self.evaluation = SICKRelatednessEval(tpath + '/downstream/SICK', 
seed=self.params.seed) 84 | elif name == 'STSBenchmark': 85 | self.evaluation = STSBenchmarkEval(tpath + '/downstream/STS/STSBenchmark', seed=self.params.seed) 86 | elif name == 'STSBenchmark-fix': 87 | self.evaluation = STSBenchmarkEval(tpath + '/downstream/STS/STSBenchmark-fix', seed=self.params.seed) 88 | elif name == 'STSBenchmark-finetune': 89 | self.evaluation = STSBenchmarkFinetune(tpath + '/downstream/STS/STSBenchmark', seed=self.params.seed) 90 | elif name == 'SICKRelatedness-finetune': 91 | self.evaluation = SICKEval(tpath + '/downstream/SICK', seed=self.params.seed) 92 | elif name == 'SICKEntailment': 93 | self.evaluation = SICKEntailmentEval(tpath + '/downstream/SICK', seed=self.params.seed) 94 | elif name == 'SNLI': 95 | self.evaluation = SNLIEval(tpath + '/downstream/SNLI', seed=self.params.seed) 96 | elif name in ['STS12', 'STS13', 'STS14', 'STS15', 'STS16']: 97 | fpath = name + '-en-test' 98 | self.evaluation = eval(name + 'Eval')(tpath + '/downstream/STS/' + fpath, seed=self.params.seed) 99 | elif name == 'ImageCaptionRetrieval': 100 | self.evaluation = ImageCaptionRetrievalEval(tpath + '/downstream/COCO', seed=self.params.seed) 101 | 102 | # Probing Tasks 103 | elif name == 'Length': 104 | self.evaluation = LengthEval(tpath + '/probing', seed=self.params.seed) 105 | elif name == 'WordContent': 106 | self.evaluation = WordContentEval(tpath + '/probing', seed=self.params.seed) 107 | elif name == 'Depth': 108 | self.evaluation = DepthEval(tpath + '/probing', seed=self.params.seed) 109 | elif name == 'TopConstituents': 110 | self.evaluation = TopConstituentsEval(tpath + '/probing', seed=self.params.seed) 111 | elif name == 'BigramShift': 112 | self.evaluation = BigramShiftEval(tpath + '/probing', seed=self.params.seed) 113 | elif name == 'Tense': 114 | self.evaluation = TenseEval(tpath + '/probing', seed=self.params.seed) 115 | elif name == 'SubjNumber': 116 | self.evaluation = SubjNumberEval(tpath + '/probing', seed=self.params.seed) 117 | elif name 
== 'ObjNumber': 118 | self.evaluation = ObjNumberEval(tpath + '/probing', seed=self.params.seed) 119 | elif name == 'OddManOut': 120 | self.evaluation = OddManOutEval(tpath + '/probing', seed=self.params.seed) 121 | elif name == 'CoordinationInversion': 122 | self.evaluation = CoordinationInversionEval(tpath + '/probing', seed=self.params.seed) 123 | 124 | self.params.current_task = name 125 | self.evaluation.do_prepare(self.params, self.prepare) 126 | 127 | self.results = self.evaluation.run(self.params, self.batcher) 128 | 129 | return self.results 130 | -------------------------------------------------------------------------------- /senteval/probing.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | ''' 9 | probing tasks 10 | ''' 11 | 12 | from __future__ import absolute_import, division, unicode_literals 13 | 14 | import os 15 | import io 16 | import copy 17 | import logging 18 | import numpy as np 19 | 20 | from senteval.tools.validation import SplitClassifier 21 | 22 | 23 | class PROBINGEval(object): 24 | def __init__(self, task, task_path, seed=1111): 25 | self.seed = seed 26 | self.task = task 27 | logging.debug('***** (Probing) Transfer task : %s classification *****', self.task.upper()) 28 | self.task_data = {'train': {'X': [], 'y': []}, 29 | 'dev': {'X': [], 'y': []}, 30 | 'test': {'X': [], 'y': []}} 31 | self.loadFile(task_path) 32 | logging.info('Loaded %s train - %s dev - %s test for %s' % 33 | (len(self.task_data['train']['y']), len(self.task_data['dev']['y']), 34 | len(self.task_data['test']['y']), self.task)) 35 | 36 | def do_prepare(self, params, prepare): 37 | samples = self.task_data['train']['X'] + self.task_data['dev']['X'] + \ 38 | self.task_data['test']['X'] 39 | return prepare(params, 
samples) 40 | 41 | def loadFile(self, fpath): 42 | self.tok2split = {'tr': 'train', 'va': 'dev', 'te': 'test'} 43 | with io.open(fpath, 'r', encoding='utf-8') as f: 44 | for line in f: 45 | line = line.rstrip().split('\t') 46 | self.task_data[self.tok2split[line[0]]]['X'].append(line[-1].split()) 47 | self.task_data[self.tok2split[line[0]]]['y'].append(line[1]) 48 | 49 | labels = sorted(np.unique(self.task_data['train']['y'])) 50 | self.tok2label = dict(zip(labels, range(len(labels)))) 51 | self.nclasses = len(self.tok2label) 52 | 53 | for split in self.task_data: 54 | for i, y in enumerate(self.task_data[split]['y']): 55 | self.task_data[split]['y'][i] = self.tok2label[y] 56 | 57 | def run(self, params, batcher): 58 | task_embed = {'train': {}, 'dev': {}, 'test': {}} 59 | bsize = params.batch_size 60 | logging.info('Computing embeddings for train/dev/test') 61 | for key in self.task_data: 62 | # Sort to reduce padding 63 | sorted_data = sorted(zip(self.task_data[key]['X'], 64 | self.task_data[key]['y']), 65 | key=lambda z: (len(z[0]), z[1])) 66 | self.task_data[key]['X'], self.task_data[key]['y'] = map(list, zip(*sorted_data)) 67 | 68 | task_embed[key]['X'] = [] 69 | for ii in range(0, len(self.task_data[key]['y']), bsize): 70 | batch = self.task_data[key]['X'][ii:ii + bsize] 71 | embeddings = batcher(params, batch) 72 | task_embed[key]['X'].append(embeddings) 73 | task_embed[key]['X'] = np.vstack(task_embed[key]['X']) 74 | task_embed[key]['y'] = np.array(self.task_data[key]['y']) 75 | logging.info('Computed embeddings') 76 | 77 | config_classifier = {'nclasses': self.nclasses, 'seed': self.seed, 78 | 'usepytorch': params.usepytorch, 79 | 'classifier': params.classifier} 80 | 81 | if self.task == "WordContent" and params.classifier['nhid'] > 0: 82 | config_classifier = copy.deepcopy(config_classifier) 83 | config_classifier['classifier']['nhid'] = 0 84 | print(params.classifier['nhid']) 85 | 86 | clf = SplitClassifier(X={'train': task_embed['train']['X'], 87 | 
'valid': task_embed['dev']['X'], 88 | 'test': task_embed['test']['X']}, 89 | y={'train': task_embed['train']['y'], 90 | 'valid': task_embed['dev']['y'], 91 | 'test': task_embed['test']['y']}, 92 | config=config_classifier) 93 | 94 | devacc, testacc = clf.run() 95 | logging.debug('\nDev acc : %.1f Test acc : %.1f for %s classification\n' % (devacc, testacc, self.task.upper())) 96 | 97 | return {'devacc': devacc, 'acc': testacc, 98 | 'ndev': len(task_embed['dev']['X']), 99 | 'ntest': len(task_embed['test']['X'])} 100 | 101 | """ 102 | Surface Information 103 | """ 104 | class LengthEval(PROBINGEval): 105 | def __init__(self, task_path, seed=1111): 106 | task_path = os.path.join(task_path, 'sentence_length.txt') 107 | # labels: bins 108 | PROBINGEval.__init__(self, 'Length', task_path, seed) 109 | 110 | class WordContentEval(PROBINGEval): 111 | def __init__(self, task_path, seed=1111): 112 | task_path = os.path.join(task_path, 'word_content.txt') 113 | # labels: 200 target words 114 | PROBINGEval.__init__(self, 'WordContent', task_path, seed) 115 | 116 | """ 117 | Latent Structural Information 118 | """ 119 | class DepthEval(PROBINGEval): 120 | def __init__(self, task_path, seed=1111): 121 | task_path = os.path.join(task_path, 'tree_depth.txt') 122 | # labels: bins 123 | PROBINGEval.__init__(self, 'Depth', task_path, seed) 124 | 125 | class TopConstituentsEval(PROBINGEval): 126 | def __init__(self, task_path, seed=1111): 127 | task_path = os.path.join(task_path, 'top_constituents.txt') 128 | # labels: 'PP_NP_VP_.' .. (20 classes) 129 | PROBINGEval.__init__(self, 'TopConstituents', task_path, seed) 130 | 131 | class BigramShiftEval(PROBINGEval): 132 | def __init__(self, task_path, seed=1111): 133 | task_path = os.path.join(task_path, 'bigram_shift.txt') 134 | # labels: 0 or 1 135 | PROBINGEval.__init__(self, 'BigramShift', task_path, seed) 136 | 137 | # TODO: Voice? 
138 | 139 | """ 140 | Latent Semantic Information 141 | """ 142 | 143 | class TenseEval(PROBINGEval): 144 | def __init__(self, task_path, seed=1111): 145 | task_path = os.path.join(task_path, 'past_present.txt') 146 | # labels: 'PRES', 'PAST' 147 | PROBINGEval.__init__(self, 'Tense', task_path, seed) 148 | 149 | class SubjNumberEval(PROBINGEval): 150 | def __init__(self, task_path, seed=1111): 151 | task_path = os.path.join(task_path, 'subj_number.txt') 152 | # labels: 'NN', 'NNS' 153 | PROBINGEval.__init__(self, 'SubjNumber', task_path, seed) 154 | 155 | class ObjNumberEval(PROBINGEval): 156 | def __init__(self, task_path, seed=1111): 157 | task_path = os.path.join(task_path, 'obj_number.txt') 158 | # labels: 'NN', 'NNS' 159 | PROBINGEval.__init__(self, 'ObjNumber', task_path, seed) 160 | 161 | class OddManOutEval(PROBINGEval): 162 | def __init__(self, task_path, seed=1111): 163 | task_path = os.path.join(task_path, 'odd_man_out.txt') 164 | # labels: 'O', 'C' 165 | PROBINGEval.__init__(self, 'OddManOut', task_path, seed) 166 | 167 | class CoordinationInversionEval(PROBINGEval): 168 | def __init__(self, task_path, seed=1111): 169 | task_path = os.path.join(task_path, 'coordination_inversion.txt') 170 | # labels: 'O', 'I' 171 | PROBINGEval.__init__(self, 'CoordinationInversion', task_path, seed) 172 | -------------------------------------------------------------------------------- /senteval/tools/classifier.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | """ 9 | Pytorch Classifier class in the style of scikit-learn 10 | Classifiers include Logistic Regression and MLP 11 | """ 12 | 13 | from __future__ import absolute_import, division, unicode_literals 14 | 15 | import numpy as np 16 | import copy 17 | from senteval import utils 18 | 19 | import torch 20 | from torch import nn 21 | import torch.nn.functional as F 22 | 23 | 24 | class PyTorchClassifier(object): 25 | def __init__(self, inputdim, nclasses, l2reg=0., batch_size=64, seed=1111, 26 | cudaEfficient=False): 27 | # fix seed 28 | np.random.seed(seed) 29 | torch.manual_seed(seed) 30 | torch.cuda.manual_seed(seed) 31 | 32 | self.inputdim = inputdim 33 | self.nclasses = nclasses 34 | self.l2reg = l2reg 35 | self.batch_size = batch_size 36 | self.cudaEfficient = cudaEfficient 37 | 38 | def prepare_split(self, X, y, validation_data=None, validation_split=None): 39 | # Preparing validation data 40 | assert validation_split or validation_data 41 | if validation_data is not None: 42 | trainX, trainy = X, y 43 | devX, devy = validation_data 44 | else: 45 | permutation = np.random.permutation(len(X)) 46 | trainidx = permutation[int(validation_split * len(X)):] 47 | devidx = permutation[0:int(validation_split * len(X))] 48 | trainX, trainy = X[trainidx], y[trainidx] 49 | devX, devy = X[devidx], y[devidx] 50 | 51 | device = torch.device('cpu') if self.cudaEfficient else torch.device('cuda') 52 | 53 | trainX = torch.from_numpy(trainX).to(device, dtype=torch.float32) 54 | trainy = torch.from_numpy(trainy).to(device, dtype=torch.int64) 55 | devX = torch.from_numpy(devX).to(device, dtype=torch.float32) 56 | devy = torch.from_numpy(devy).to(device, dtype=torch.int64) 57 | 58 | return trainX, trainy, devX, devy 59 | 60 | def fit(self, X, y, validation_data=None, validation_split=None, 61 | early_stop=True): 62 | self.nepoch = 0 63 | bestaccuracy = -1 64 | stop_train = False 65 | early_stop_count = 0 66 | 67 | # Preparing validation data 68 | trainX, trainy, devX, 
devy = self.prepare_split(X, y, validation_data, 69 | validation_split) 70 | 71 | # Training 72 | while not stop_train and self.nepoch <= self.max_epoch: 73 | self.trainepoch(trainX, trainy, epoch_size=self.epoch_size) 74 | accuracy = self.score(devX, devy) 75 | if accuracy > bestaccuracy: 76 | bestaccuracy = accuracy 77 | bestmodel = copy.deepcopy(self.model) 78 | elif early_stop: 79 | if early_stop_count >= self.tenacity: 80 | stop_train = True 81 | early_stop_count += 1 82 | self.model = bestmodel 83 | return bestaccuracy 84 | 85 | def trainepoch(self, X, y, epoch_size=1): 86 | self.model.train() 87 | for _ in range(self.nepoch, self.nepoch + epoch_size): 88 | permutation = np.random.permutation(len(X)) 89 | all_costs = [] 90 | for i in range(0, len(X), self.batch_size): 91 | # forward 92 | idx = torch.from_numpy(permutation[i:i + self.batch_size]).long().to(X.device) 93 | 94 | Xbatch = X[idx] 95 | ybatch = y[idx] 96 | 97 | if self.cudaEfficient: 98 | Xbatch = Xbatch.cuda() 99 | ybatch = ybatch.cuda() 100 | output = self.model(Xbatch) 101 | # loss 102 | loss = self.loss_fn(output, ybatch) 103 | all_costs.append(loss.data.item()) 104 | # backward 105 | self.optimizer.zero_grad() 106 | loss.backward() 107 | # Update parameters 108 | self.optimizer.step() 109 | self.nepoch += epoch_size 110 | 111 | def score(self, devX, devy): 112 | self.model.eval() 113 | correct = 0 114 | if not isinstance(devX, torch.cuda.FloatTensor) or self.cudaEfficient: 115 | devX = torch.FloatTensor(devX).cuda() 116 | devy = torch.LongTensor(devy).cuda() 117 | with torch.no_grad(): 118 | for i in range(0, len(devX), self.batch_size): 119 | Xbatch = devX[i:i + self.batch_size] 120 | ybatch = devy[i:i + self.batch_size] 121 | if self.cudaEfficient: 122 | Xbatch = Xbatch.cuda() 123 | ybatch = ybatch.cuda() 124 | output = self.model(Xbatch) 125 | pred = output.data.max(1)[1] 126 | correct += pred.long().eq(ybatch.data.long()).sum().item() 127 | accuracy = 1.0 * correct / len(devX) 128 | return 
accuracy 129 | 130 | def predict(self, devX): 131 | self.model.eval() 132 | if not isinstance(devX, torch.cuda.FloatTensor): 133 | devX = torch.FloatTensor(devX).cuda() 134 | yhat = np.array([]) 135 | with torch.no_grad(): 136 | for i in range(0, len(devX), self.batch_size): 137 | Xbatch = devX[i:i + self.batch_size] 138 | output = self.model(Xbatch) 139 | yhat = np.append(yhat, 140 | output.data.max(1)[1].cpu().numpy()) 141 | yhat = np.vstack(yhat) 142 | return yhat 143 | 144 | def predict_proba(self, devX): 145 | self.model.eval() 146 | probas = [] 147 | with torch.no_grad(): 148 | for i in range(0, len(devX), self.batch_size): 149 | Xbatch = devX[i:i + self.batch_size] 150 | vals = F.softmax(self.model(Xbatch).data.cpu().numpy()) 151 | if not probas: 152 | probas = vals 153 | else: 154 | probas = np.concatenate(probas, vals, axis=0) 155 | return probas 156 | 157 | 158 | """ 159 | MLP with Pytorch (nhid=0 --> Logistic Regression) 160 | """ 161 | 162 | class MLP(PyTorchClassifier): 163 | def __init__(self, params, inputdim, nclasses, l2reg=0., batch_size=64, 164 | seed=1111, cudaEfficient=False): 165 | super(self.__class__, self).__init__(inputdim, nclasses, l2reg, 166 | batch_size, seed, cudaEfficient) 167 | """ 168 | PARAMETERS: 169 | -nhid: number of hidden units (0: Logistic Regression) 170 | -optim: optimizer ("sgd,lr=0.1", "adam", "rmsprop" ..) 
171 | -tenacity: how many times dev acc does not increase before stopping 172 | -epoch_size: each epoch corresponds to epoch_size pass on the train set 173 | -max_epoch: max number of epoches 174 | -dropout: dropout for MLP 175 | """ 176 | 177 | self.nhid = 0 if "nhid" not in params else params["nhid"] 178 | self.optim = "adam" if "optim" not in params else params["optim"] 179 | self.tenacity = 5 if "tenacity" not in params else params["tenacity"] 180 | self.epoch_size = 4 if "epoch_size" not in params else params["epoch_size"] 181 | self.max_epoch = 200 if "max_epoch" not in params else params["max_epoch"] 182 | self.dropout = 0. if "dropout" not in params else params["dropout"] 183 | self.batch_size = 64 if "batch_size" not in params else params["batch_size"] 184 | 185 | if params["nhid"] == 0: 186 | self.model = nn.Sequential( 187 | nn.Linear(self.inputdim, self.nclasses), 188 | ).cuda() 189 | else: 190 | self.model = nn.Sequential( 191 | nn.Linear(self.inputdim, params["nhid"]), 192 | nn.Dropout(p=self.dropout), 193 | nn.Sigmoid(), 194 | nn.Linear(params["nhid"], self.nclasses), 195 | ).cuda() 196 | 197 | self.loss_fn = nn.CrossEntropyLoss().cuda() 198 | self.loss_fn.size_average = False 199 | 200 | optim_fn, optim_params = utils.get_optimizer(self.optim) 201 | self.optimizer = optim_fn(self.model.parameters(), **optim_params) 202 | self.optimizer.param_groups[0]['weight_decay'] = self.l2reg 203 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import io, os 3 | import numpy as np 4 | import logging 5 | import argparse 6 | from prettytable import PrettyTable 7 | import torch 8 | import transformers 9 | from transformers import AutoModel, AutoTokenizer 10 | 11 | # Set up logger 12 | logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG) 13 | 14 | # Set PATHs 15 | PATH_TO_SENTEVAL = './SentEval' 
16 | PATH_TO_DATA = './SentEval/data' 17 | 18 | # Import SentEval 19 | sys.path.insert(0, PATH_TO_SENTEVAL) 20 | import senteval 21 | 22 | def print_table(task_names, scores): 23 | tb = PrettyTable() 24 | tb.field_names = task_names 25 | tb.add_row(scores) 26 | print(tb) 27 | 28 | def main(): 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument("--model_name_or_path", type=str, 31 | help="Transformers' model name or path") 32 | parser.add_argument("--pooler", type=str, 33 | choices=['cls', 'cls_before_pooler', 'avg', 'avg_top2', 'avg_first_last'], 34 | default='cls', 35 | help="Which pooler to use") 36 | parser.add_argument("--mode", type=str, 37 | choices=['dev', 'test', 'fasttest'], 38 | default='test', 39 | help="What evaluation mode to use (dev: fast mode, dev results; test: full mode, test results); fasttest: fast mode, test results") 40 | parser.add_argument("--task_set", type=str, 41 | choices=['sts', 'transfer', 'full', 'na'], 42 | default='sts', 43 | help="What set of tasks to evaluate on. If not 'na', this will override '--tasks'") 44 | parser.add_argument("--tasks", type=str, nargs='+', 45 | default=['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 46 | 'MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'TREC', 'MRPC', 47 | 'SICKRelatedness', 'STSBenchmark'], 48 | help="Tasks to evaluate on. 
If '--task_set' is specified, this will be overridden") 49 | 50 | args = parser.parse_args() 51 | 52 | # Load transformers' model checkpoint 53 | model = AutoModel.from_pretrained(args.model_name_or_path) 54 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) 55 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 56 | model = model.to(device) 57 | 58 | # Set up the tasks 59 | if args.task_set == 'sts': 60 | args.tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark', 'SICKRelatedness'] 61 | elif args.task_set == 'transfer': 62 | args.tasks = ['MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'TREC', 'MRPC'] 63 | elif args.task_set == 'full': 64 | args.tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark', 'SICKRelatedness'] 65 | args.tasks += ['MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'TREC', 'MRPC'] 66 | 67 | # Set params for SentEval 68 | if args.mode == 'dev' or args.mode == 'fasttest': 69 | # Fast mode 70 | params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5} 71 | params['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128, 72 | 'tenacity': 3, 'epoch_size': 2} 73 | elif args.mode == 'test': 74 | # Full mode 75 | params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 10} 76 | params['classifier'] = {'nhid': 0, 'optim': 'adam', 'batch_size': 64, 77 | 'tenacity': 5, 'epoch_size': 4} 78 | else: 79 | raise NotImplementedError 80 | 81 | # SentEval prepare and batcher 82 | def prepare(params, samples): 83 | return 84 | 85 | def batcher(params, batch, max_length=None): 86 | # Handle rare token encoding issues in the dataset 87 | if len(batch) >= 1 and len(batch[0]) >= 1 and isinstance(batch[0][0], bytes): 88 | batch = [[word.decode('utf-8') for word in s] for s in batch] 89 | 90 | sentences = [' '.join(s) for s in batch] 91 | 92 | # Tokenization 93 | if max_length is not None: 94 | batch = tokenizer.batch_encode_plus( 95 | sentences, 96 | return_tensors='pt', 97 | padding=True, 98 | 
max_length=max_length, 99 | truncation=True 100 | ) 101 | else: 102 | batch = tokenizer.batch_encode_plus( 103 | sentences, 104 | return_tensors='pt', 105 | padding=True, 106 | ) 107 | 108 | # Move to the correct device 109 | for k in batch: 110 | batch[k] = batch[k].to(device) 111 | 112 | # Get raw embeddings 113 | with torch.no_grad(): 114 | outputs = model(**batch, output_hidden_states=True, return_dict=True) 115 | last_hidden = outputs.last_hidden_state 116 | pooler_output = outputs.pooler_output 117 | hidden_states = outputs.hidden_states 118 | 119 | # Apply different poolers 120 | if args.pooler == 'cls': 121 | # There is a linear+activation layer after CLS representation 122 | return pooler_output.cpu() 123 | elif args.pooler == 'cls_before_pooler': 124 | return last_hidden[:, 0].cpu() 125 | elif args.pooler == "avg": 126 | return ((last_hidden * batch['attention_mask'].unsqueeze(-1)).sum(1) / batch['attention_mask'].sum(-1).unsqueeze(-1)).cpu() 127 | elif args.pooler == "avg_first_last": 128 | first_hidden = hidden_states[0] 129 | last_hidden = hidden_states[-1] 130 | pooled_result = ((first_hidden + last_hidden) / 2.0 * batch['attention_mask'].unsqueeze(-1)).sum(1) / batch['attention_mask'].sum(-1).unsqueeze(-1) 131 | return pooled_result.cpu() 132 | elif args.pooler == "avg_top2": 133 | second_last_hidden = hidden_states[-2] 134 | last_hidden = hidden_states[-1] 135 | pooled_result = ((last_hidden + second_last_hidden) / 2.0 * batch['attention_mask'].unsqueeze(-1)).sum(1) / batch['attention_mask'].sum(-1).unsqueeze(-1) 136 | return pooled_result.cpu() 137 | else: 138 | raise NotImplementedError 139 | 140 | results = {} 141 | 142 | for task in args.tasks: 143 | se = senteval.engine.SE(params, batcher, prepare) 144 | result = se.eval(task) 145 | results[task] = result 146 | 147 | # Print evaluation results 148 | if args.mode == 'dev': 149 | print("------ %s ------" % (args.mode)) 150 | 151 | task_names = [] 152 | scores = [] 153 | for task in 
['STSBenchmark', 'SICKRelatedness']: 154 | task_names.append(task) 155 | if task in results: 156 | scores.append("%.2f" % (results[task]['dev']['spearman'][0] * 100)) 157 | else: 158 | scores.append("0.00") 159 | print_table(task_names, scores) 160 | 161 | task_names = [] 162 | scores = [] 163 | for task in ['MR', 'CR', 'SUBJ', 'MPQA', 'SST2', 'TREC', 'MRPC']: 164 | task_names.append(task) 165 | if task in results: 166 | scores.append("%.2f" % (results[task]['devacc'])) 167 | else: 168 | scores.append("0.00") 169 | task_names.append("Avg.") 170 | scores.append("%.2f" % (sum([float(score) for score in scores]) / len(scores))) 171 | print_table(task_names, scores) 172 | 173 | elif args.mode == 'test' or args.mode == 'fasttest': 174 | print("------ %s ------" % (args.mode)) 175 | 176 | task_names = [] 177 | scores = [] 178 | for task in ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark', 'SICKRelatedness']: 179 | task_names.append(task) 180 | if task in results: 181 | if task in ['STS12', 'STS13', 'STS14', 'STS15', 'STS16']: 182 | scores.append("%.2f" % (results[task]['all']['spearman']['all'] * 100)) 183 | else: 184 | scores.append("%.2f" % (results[task]['test']['spearman'].correlation * 100)) 185 | else: 186 | scores.append("0.00") 187 | task_names.append("Avg.") 188 | scores.append("%.2f" % (sum([float(score) for score in scores]) / len(scores))) 189 | print_table(task_names, scores) 190 | 191 | task_names = [] 192 | scores = [] 193 | for task in ['MR', 'CR', 'SUBJ', 'MPQA', 'SST2', 'TREC', 'MRPC']: 194 | task_names.append(task) 195 | if task in results: 196 | scores.append("%.2f" % (results[task]['devacc'])) 197 | else: 198 | scores.append("0.00") 199 | task_names.append("Avg.") 200 | scores.append("%.2f" % (sum([float(score) for score in scores]) / len(scores))) 201 | print_table(task_names, scores) 202 | 203 | 204 | if __name__ == "__main__": 205 | main() 206 | -------------------------------------------------------------------------------- 
/senteval/sick.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | ''' 9 | SICK Relatedness and Entailment 10 | ''' 11 | from __future__ import absolute_import, division, unicode_literals 12 | 13 | import os 14 | import io 15 | import logging 16 | import numpy as np 17 | 18 | from sklearn.metrics import mean_squared_error 19 | from scipy.stats import pearsonr, spearmanr 20 | 21 | from senteval.tools.relatedness import RelatednessPytorch 22 | from senteval.tools.validation import SplitClassifier 23 | 24 | class SICKEval(object): 25 | def __init__(self, task_path, seed=1111): 26 | logging.debug('***** Transfer task : SICK-Relatedness*****\n\n') 27 | self.seed = seed 28 | train = self.loadFile(os.path.join(task_path, 'SICK_train.txt')) 29 | dev = self.loadFile(os.path.join(task_path, 'SICK_trial.txt')) 30 | test = self.loadFile(os.path.join(task_path, 'SICK_test_annotated.txt')) 31 | self.sick_data = {'train': train, 'dev': dev, 'test': test} 32 | 33 | def do_prepare(self, params, prepare): 34 | samples = self.sick_data['train']['X_A'] + \ 35 | self.sick_data['train']['X_B'] + \ 36 | self.sick_data['dev']['X_A'] + \ 37 | self.sick_data['dev']['X_B'] + \ 38 | self.sick_data['test']['X_A'] + self.sick_data['test']['X_B'] 39 | return prepare(params, samples) 40 | 41 | def loadFile(self, fpath): 42 | skipFirstLine = True 43 | sick_data = {'X_A': [], 'X_B': [], 'y': []} 44 | with io.open(fpath, 'r', encoding='utf-8') as f: 45 | for line in f: 46 | if skipFirstLine: 47 | skipFirstLine = False 48 | else: 49 | text = line.strip().split('\t') 50 | sick_data['X_A'].append(text[1].split()) 51 | sick_data['X_B'].append(text[2].split()) 52 | sick_data['y'].append(text[3]) 53 | 54 | sick_data['y'] = [float(s) for s in 
sick_data['y']] 55 | return sick_data 56 | 57 | def run(self, params, batcher): 58 | sick_embed = {'train': {}, 'dev': {}, 'test': {}} 59 | bsize = params.batch_size 60 | 61 | for key in self.sick_data: 62 | logging.info('Computing embedding for {0}'.format(key)) 63 | # Sort to reduce padding 64 | sorted_corpus = sorted(zip(self.sick_data[key]['X_A'], 65 | self.sick_data[key]['X_B'], 66 | self.sick_data[key]['y']), 67 | key=lambda z: (len(z[0]), len(z[1]), z[2])) 68 | 69 | self.sick_data[key]['X_A'] = [x for (x, y, z) in sorted_corpus] 70 | self.sick_data[key]['X_B'] = [y for (x, y, z) in sorted_corpus] 71 | self.sick_data[key]['y'] = [z for (x, y, z) in sorted_corpus] 72 | 73 | for txt_type in ['X_A', 'X_B']: 74 | sick_embed[key][txt_type] = [] 75 | for ii in range(0, len(self.sick_data[key]['y']), bsize): 76 | batch = self.sick_data[key][txt_type][ii:ii + bsize] 77 | embeddings = batcher(params, batch) 78 | sick_embed[key][txt_type].append(embeddings) 79 | sick_embed[key][txt_type] = np.vstack(sick_embed[key][txt_type]) 80 | sick_embed[key]['y'] = np.array(self.sick_data[key]['y']) 81 | logging.info('Computed {0} embeddings'.format(key)) 82 | 83 | # Train 84 | trainA = sick_embed['train']['X_A'] 85 | trainB = sick_embed['train']['X_B'] 86 | trainF = np.c_[np.abs(trainA - trainB), trainA * trainB] 87 | trainY = self.encode_labels(self.sick_data['train']['y']) 88 | 89 | # Dev 90 | devA = sick_embed['dev']['X_A'] 91 | devB = sick_embed['dev']['X_B'] 92 | devF = np.c_[np.abs(devA - devB), devA * devB] 93 | devY = self.encode_labels(self.sick_data['dev']['y']) 94 | 95 | # Test 96 | testA = sick_embed['test']['X_A'] 97 | testB = sick_embed['test']['X_B'] 98 | testF = np.c_[np.abs(testA - testB), testA * testB] 99 | testY = self.encode_labels(self.sick_data['test']['y']) 100 | 101 | config = {'seed': self.seed, 'nclasses': 5} 102 | clf = RelatednessPytorch(train={'X': trainF, 'y': trainY}, 103 | valid={'X': devF, 'y': devY}, 104 | test={'X': testF, 'y': testY}, 105 | 
devscores=self.sick_data['dev']['y'], 106 | config=config) 107 | 108 | devspr, yhat = clf.run() 109 | 110 | pr = pearsonr(yhat, self.sick_data['test']['y'])[0] 111 | sr = spearmanr(yhat, self.sick_data['test']['y'])[0] 112 | pr = 0 if pr != pr else pr 113 | sr = 0 if sr != sr else sr 114 | se = mean_squared_error(yhat, self.sick_data['test']['y']) 115 | logging.debug('Dev : Spearman {0}'.format(devspr)) 116 | logging.debug('Test : Pearson {0} Spearman {1} MSE {2} \ 117 | for SICK Relatedness\n'.format(pr, sr, se)) 118 | 119 | return {'devspearman': devspr, 'pearson': pr, 'spearman': sr, 'mse': se, 120 | 'yhat': yhat, 'ndev': len(devA), 'ntest': len(testA)} 121 | 122 | def encode_labels(self, labels, nclass=5): 123 | """ 124 | Label encoding from Tree LSTM paper (Tai, Socher, Manning) 125 | """ 126 | Y = np.zeros((len(labels), nclass)).astype('float32') 127 | for j, y in enumerate(labels): 128 | for i in range(nclass): 129 | if i+1 == np.floor(y) + 1: 130 | Y[j, i] = y - np.floor(y) 131 | if i+1 == np.floor(y): 132 | Y[j, i] = np.floor(y) - y + 1 133 | return Y 134 | 135 | 136 | class SICKEntailmentEval(SICKEval): 137 | def __init__(self, task_path, seed=1111): 138 | logging.debug('***** Transfer task : SICK-Entailment*****\n\n') 139 | self.seed = seed 140 | train = self.loadFile(os.path.join(task_path, 'SICK_train.txt')) 141 | dev = self.loadFile(os.path.join(task_path, 'SICK_trial.txt')) 142 | test = self.loadFile(os.path.join(task_path, 'SICK_test_annotated.txt')) 143 | self.sick_data = {'train': train, 'dev': dev, 'test': test} 144 | 145 | def loadFile(self, fpath): 146 | label2id = {'CONTRADICTION': 0, 'NEUTRAL': 1, 'ENTAILMENT': 2} 147 | skipFirstLine = True 148 | sick_data = {'X_A': [], 'X_B': [], 'y': []} 149 | with io.open(fpath, 'r', encoding='utf-8') as f: 150 | for line in f: 151 | if skipFirstLine: 152 | skipFirstLine = False 153 | else: 154 | text = line.strip().split('\t') 155 | sick_data['X_A'].append(text[1].split()) 156 | 
sick_data['X_B'].append(text[2].split()) 157 | sick_data['y'].append(text[4]) 158 | sick_data['y'] = [label2id[s] for s in sick_data['y']] 159 | return sick_data 160 | 161 | def run(self, params, batcher): 162 | sick_embed = {'train': {}, 'dev': {}, 'test': {}} 163 | bsize = params.batch_size 164 | 165 | for key in self.sick_data: 166 | logging.info('Computing embedding for {0}'.format(key)) 167 | # Sort to reduce padding 168 | sorted_corpus = sorted(zip(self.sick_data[key]['X_A'], 169 | self.sick_data[key]['X_B'], 170 | self.sick_data[key]['y']), 171 | key=lambda z: (len(z[0]), len(z[1]), z[2])) 172 | 173 | self.sick_data[key]['X_A'] = [x for (x, y, z) in sorted_corpus] 174 | self.sick_data[key]['X_B'] = [y for (x, y, z) in sorted_corpus] 175 | self.sick_data[key]['y'] = [z for (x, y, z) in sorted_corpus] 176 | 177 | for txt_type in ['X_A', 'X_B']: 178 | sick_embed[key][txt_type] = [] 179 | for ii in range(0, len(self.sick_data[key]['y']), bsize): 180 | batch = self.sick_data[key][txt_type][ii:ii + bsize] 181 | embeddings = batcher(params, batch) 182 | sick_embed[key][txt_type].append(embeddings) 183 | sick_embed[key][txt_type] = np.vstack(sick_embed[key][txt_type]) 184 | logging.info('Computed {0} embeddings'.format(key)) 185 | 186 | # Train 187 | trainA = sick_embed['train']['X_A'] 188 | trainB = sick_embed['train']['X_B'] 189 | trainF = np.c_[np.abs(trainA - trainB), trainA * trainB] 190 | trainY = np.array(self.sick_data['train']['y']) 191 | 192 | # Dev 193 | devA = sick_embed['dev']['X_A'] 194 | devB = sick_embed['dev']['X_B'] 195 | devF = np.c_[np.abs(devA - devB), devA * devB] 196 | devY = np.array(self.sick_data['dev']['y']) 197 | 198 | # Test 199 | testA = sick_embed['test']['X_A'] 200 | testB = sick_embed['test']['X_B'] 201 | testF = np.c_[np.abs(testA - testB), testA * testB] 202 | testY = np.array(self.sick_data['test']['y']) 203 | 204 | config = {'nclasses': 3, 'seed': self.seed, 205 | 'usepytorch': params.usepytorch, 206 | 'classifier': 
params.classifier, 207 | 'nhid': params.nhid} 208 | clf = SplitClassifier(X={'train': trainF, 'valid': devF, 'test': testF}, 209 | y={'train': trainY, 'valid': devY, 'test': testY}, 210 | config=config) 211 | 212 | devacc, testacc = clf.run() 213 | logging.debug('\nDev acc : {0} Test acc : {1} for \ 214 | SICK entailment\n'.format(devacc, testacc)) 215 | return {'devacc': devacc, 'acc': testacc, 216 | 'ndev': len(devA), 'ntest': len(testA)} 217 | -------------------------------------------------------------------------------- /senteval/tools/validation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | """ 9 | Validation and classification 10 | (train) : inner-kfold classifier 11 | (train, test) : kfold classifier 12 | (train, dev, test) : split classifier 13 | 14 | """ 15 | from __future__ import absolute_import, division, unicode_literals 16 | 17 | import logging 18 | import numpy as np 19 | from senteval.tools.classifier import MLP 20 | 21 | import sklearn 22 | assert(sklearn.__version__ >= "0.18.0"), \ 23 | "need to update sklearn to version >= 0.18.0" 24 | from sklearn.linear_model import LogisticRegression 25 | from sklearn.model_selection import StratifiedKFold 26 | 27 | 28 | def get_classif_name(classifier_config, usepytorch): 29 | if not usepytorch: 30 | modelname = 'sklearn-LogReg' 31 | else: 32 | nhid = classifier_config['nhid'] 33 | optim = 'adam' if 'optim' not in classifier_config else classifier_config['optim'] 34 | bs = 64 if 'batch_size' not in classifier_config else classifier_config['batch_size'] 35 | modelname = 'pytorch-MLP-nhid%s-%s-bs%s' % (nhid, optim, bs) 36 | return modelname 37 | 38 | # Pytorch version 39 | class InnerKFoldClassifier(object): 40 | """ 41 | (train) split classifier : 
InnerKfold. 42 | """ 43 | def __init__(self, X, y, config): 44 | self.X = X 45 | self.y = y 46 | self.featdim = X.shape[1] 47 | self.nclasses = config['nclasses'] 48 | self.seed = config['seed'] 49 | self.devresults = [] 50 | self.testresults = [] 51 | self.usepytorch = config['usepytorch'] 52 | self.classifier_config = config['classifier'] 53 | self.modelname = get_classif_name(self.classifier_config, self.usepytorch) 54 | 55 | self.k = 5 if 'kfold' not in config else config['kfold'] 56 | 57 | def run(self): 58 | logging.info('Training {0} with (inner) {1}-fold cross-validation' 59 | .format(self.modelname, self.k)) 60 | 61 | regs = [10**t for t in range(-5, -1)] if self.usepytorch else \ 62 | [2**t for t in range(-2, 4, 1)] 63 | skf = StratifiedKFold(n_splits=self.k, shuffle=True, random_state=1111) 64 | innerskf = StratifiedKFold(n_splits=self.k, shuffle=True, 65 | random_state=1111) 66 | count = 0 67 | for train_idx, test_idx in skf.split(self.X, self.y): 68 | count += 1 69 | X_train, X_test = self.X[train_idx], self.X[test_idx] 70 | y_train, y_test = self.y[train_idx], self.y[test_idx] 71 | scores = [] 72 | for reg in regs: 73 | regscores = [] 74 | for inner_train_idx, inner_test_idx in innerskf.split(X_train, y_train): 75 | X_in_train, X_in_test = X_train[inner_train_idx], X_train[inner_test_idx] 76 | y_in_train, y_in_test = y_train[inner_train_idx], y_train[inner_test_idx] 77 | if self.usepytorch: 78 | clf = MLP(self.classifier_config, inputdim=self.featdim, 79 | nclasses=self.nclasses, l2reg=reg, 80 | seed=self.seed) 81 | clf.fit(X_in_train, y_in_train, 82 | validation_data=(X_in_test, y_in_test)) 83 | else: 84 | clf = LogisticRegression(C=reg, random_state=self.seed) 85 | clf.fit(X_in_train, y_in_train) 86 | regscores.append(clf.score(X_in_test, y_in_test)) 87 | scores.append(round(100*np.mean(regscores), 2)) 88 | optreg = regs[np.argmax(scores)] 89 | logging.info('Best param found at split {0}: l2reg = {1} \ 90 | with score {2}'.format(count, optreg, 
np.max(scores))) 91 | self.devresults.append(np.max(scores)) 92 | 93 | if self.usepytorch: 94 | clf = MLP(self.classifier_config, inputdim=self.featdim, 95 | nclasses=self.nclasses, l2reg=optreg, 96 | seed=self.seed) 97 | 98 | clf.fit(X_train, y_train, validation_split=0.05) 99 | else: 100 | clf = LogisticRegression(C=optreg, random_state=self.seed) 101 | clf.fit(X_train, y_train) 102 | 103 | self.testresults.append(round(100*clf.score(X_test, y_test), 2)) 104 | 105 | devaccuracy = round(np.mean(self.devresults), 2) 106 | testaccuracy = round(np.mean(self.testresults), 2) 107 | return devaccuracy, testaccuracy 108 | 109 | 110 | class KFoldClassifier(object): 111 | """ 112 | (train, test) split classifier : cross-validation on train. 113 | """ 114 | def __init__(self, train, test, config): 115 | self.train = train 116 | self.test = test 117 | self.featdim = self.train['X'].shape[1] 118 | self.nclasses = config['nclasses'] 119 | self.seed = config['seed'] 120 | self.usepytorch = config['usepytorch'] 121 | self.classifier_config = config['classifier'] 122 | self.modelname = get_classif_name(self.classifier_config, self.usepytorch) 123 | 124 | self.k = 5 if 'kfold' not in config else config['kfold'] 125 | 126 | def run(self): 127 | # cross-validation 128 | logging.info('Training {0} with {1}-fold cross-validation' 129 | .format(self.modelname, self.k)) 130 | regs = [10**t for t in range(-5, -1)] if self.usepytorch else \ 131 | [2**t for t in range(-1, 6, 1)] 132 | skf = StratifiedKFold(n_splits=self.k, shuffle=True, 133 | random_state=self.seed) 134 | scores = [] 135 | 136 | for reg in regs: 137 | scanscores = [] 138 | for train_idx, test_idx in skf.split(self.train['X'], 139 | self.train['y']): 140 | # Split data 141 | X_train, y_train = self.train['X'][train_idx], self.train['y'][train_idx] 142 | 143 | X_test, y_test = self.train['X'][test_idx], self.train['y'][test_idx] 144 | 145 | # Train classifier 146 | if self.usepytorch: 147 | clf = MLP(self.classifier_config, 
inputdim=self.featdim, 148 | nclasses=self.nclasses, l2reg=reg, 149 | seed=self.seed) 150 | clf.fit(X_train, y_train, validation_data=(X_test, y_test)) 151 | else: 152 | clf = LogisticRegression(C=reg, random_state=self.seed) 153 | clf.fit(X_train, y_train) 154 | score = clf.score(X_test, y_test) 155 | scanscores.append(score) 156 | # Append mean score 157 | scores.append(round(100*np.mean(scanscores), 2)) 158 | 159 | # evaluation 160 | logging.info([('reg:' + str(regs[idx]), scores[idx]) 161 | for idx in range(len(scores))]) 162 | optreg = regs[np.argmax(scores)] 163 | devaccuracy = np.max(scores) 164 | logging.info('Cross-validation : best param found is reg = {0} \ 165 | with score {1}'.format(optreg, devaccuracy)) 166 | 167 | logging.info('Evaluating...') 168 | if self.usepytorch: 169 | clf = MLP(self.classifier_config, inputdim=self.featdim, 170 | nclasses=self.nclasses, l2reg=optreg, 171 | seed=self.seed) 172 | clf.fit(self.train['X'], self.train['y'], validation_split=0.05) 173 | else: 174 | clf = LogisticRegression(C=optreg, random_state=self.seed) 175 | clf.fit(self.train['X'], self.train['y']) 176 | yhat = clf.predict(self.test['X']) 177 | 178 | testaccuracy = clf.score(self.test['X'], self.test['y']) 179 | testaccuracy = round(100*testaccuracy, 2) 180 | 181 | return devaccuracy, testaccuracy, yhat 182 | 183 | 184 | class SplitClassifier(object): 185 | """ 186 | (train, valid, test) split classifier. 
187 | """ 188 | def __init__(self, X, y, config): 189 | self.X = X 190 | self.y = y 191 | self.nclasses = config['nclasses'] 192 | self.featdim = self.X['train'].shape[1] 193 | self.seed = config['seed'] 194 | self.usepytorch = config['usepytorch'] 195 | self.classifier_config = config['classifier'] 196 | self.cudaEfficient = False if 'cudaEfficient' not in config else \ 197 | config['cudaEfficient'] 198 | self.modelname = get_classif_name(self.classifier_config, self.usepytorch) 199 | self.noreg = False if 'noreg' not in config else config['noreg'] 200 | self.config = config 201 | 202 | def run(self): 203 | logging.info('Training {0} with standard validation..' 204 | .format(self.modelname)) 205 | regs = [10**t for t in range(-5, -1)] if self.usepytorch else \ 206 | [2**t for t in range(-2, 4, 1)] 207 | if self.noreg: 208 | regs = [1e-9 if self.usepytorch else 1e9] 209 | scores = [] 210 | for reg in regs: 211 | if self.usepytorch: 212 | clf = MLP(self.classifier_config, inputdim=self.featdim, 213 | nclasses=self.nclasses, l2reg=reg, 214 | seed=self.seed, cudaEfficient=self.cudaEfficient) 215 | 216 | # TODO: Find a hack for reducing nb epoches in SNLI 217 | clf.fit(self.X['train'], self.y['train'], 218 | validation_data=(self.X['valid'], self.y['valid'])) 219 | else: 220 | clf = LogisticRegression(C=reg, random_state=self.seed) 221 | clf.fit(self.X['train'], self.y['train']) 222 | scores.append(round(100*clf.score(self.X['valid'], 223 | self.y['valid']), 2)) 224 | logging.info([('reg:'+str(regs[idx]), scores[idx]) 225 | for idx in range(len(scores))]) 226 | optreg = regs[np.argmax(scores)] 227 | devaccuracy = np.max(scores) 228 | logging.info('Validation : best param found is reg = {0} with score \ 229 | {1}'.format(optreg, devaccuracy)) 230 | clf = LogisticRegression(C=optreg, random_state=self.seed) 231 | logging.info('Evaluating...') 232 | if self.usepytorch: 233 | clf = MLP(self.classifier_config, inputdim=self.featdim, 234 | nclasses=self.nclasses, 
l2reg=optreg, 235 | seed=self.seed, cudaEfficient=self.cudaEfficient) 236 | 237 | # TODO: Find a hack for reducing nb epoches in SNLI 238 | clf.fit(self.X['train'], self.y['train'], 239 | validation_data=(self.X['valid'], self.y['valid'])) 240 | else: 241 | clf = LogisticRegression(C=optreg, random_state=self.seed) 242 | clf.fit(self.X['train'], self.y['train']) 243 | 244 | testaccuracy = clf.score(self.X['test'], self.y['test']) 245 | testaccuracy = round(100*testaccuracy, 2) 246 | return devaccuracy, testaccuracy 247 | -------------------------------------------------------------------------------- /simcse/tool.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from tqdm import tqdm 3 | import numpy as np 4 | from numpy import ndarray 5 | import torch 6 | from torch import Tensor, device 7 | import transformers 8 | from transformers import AutoModel, AutoTokenizer 9 | from sklearn.metrics.pairwise import cosine_similarity 10 | from sklearn.preprocessing import normalize 11 | from typing import List, Dict, Tuple, Type, Union 12 | 13 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S', 14 | level=logging.INFO) 15 | logger = logging.getLogger(__name__) 16 | 17 | class SimCSE(object): 18 | """ 19 | A class for embedding sentences, calculating similarities, and retriving sentences by SimCSE. 
20 | """ 21 | def __init__(self, model_name_or_path: str, 22 | device: str = None, 23 | num_cells: int = 100, 24 | num_cells_in_search: int = 10, 25 | pooler = None): 26 | 27 | self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) 28 | self.model = AutoModel.from_pretrained(model_name_or_path) 29 | if device is None: 30 | device = "cuda" if torch.cuda.is_available() else "cpu" 31 | self.device = device 32 | 33 | self.index = None 34 | self.is_faiss_index = False 35 | self.num_cells = num_cells 36 | self.num_cells_in_search = num_cells_in_search 37 | 38 | if pooler is not None: 39 | self.pooler = pooler 40 | elif "unsup" in model_name_or_path: 41 | logger.info("Use `cls_before_pooler` for unsupervised models. If you want to use other pooling policy, specify `pooler` argument.") 42 | self.pooler = "cls_before_pooler" 43 | else: 44 | self.pooler = "cls" 45 | 46 | def encode(self, sentence: Union[str, List[str]], 47 | device: str = None, 48 | return_numpy: bool = False, 49 | normalize_to_unit: bool = True, 50 | keepdim: bool = False, 51 | batch_size: int = 64, 52 | max_length: int = 128) -> Union[ndarray, Tensor]: 53 | 54 | target_device = self.device if device is None else device 55 | self.model = self.model.to(target_device) 56 | 57 | single_sentence = False 58 | if isinstance(sentence, str): 59 | sentence = [sentence] 60 | single_sentence = True 61 | 62 | embedding_list = [] 63 | with torch.no_grad(): 64 | total_batch = len(sentence) // batch_size + (1 if len(sentence) % batch_size > 0 else 0) 65 | for batch_id in tqdm(range(total_batch)): 66 | inputs = self.tokenizer( 67 | sentence[batch_id*batch_size:(batch_id+1)*batch_size], 68 | padding=True, 69 | truncation=True, 70 | max_length=max_length, 71 | return_tensors="pt" 72 | ) 73 | inputs = {k: v.to(target_device) for k, v in inputs.items()} 74 | outputs = self.model(**inputs, return_dict=True) 75 | if self.pooler == "cls": 76 | embeddings = outputs.pooler_output 77 | elif self.pooler == 
"cls_before_pooler": 78 | embeddings = outputs.last_hidden_state[:, 0] 79 | else: 80 | raise NotImplementedError 81 | if normalize_to_unit: 82 | embeddings = embeddings / embeddings.norm(dim=1, keepdim=True) 83 | embedding_list.append(embeddings.cpu()) 84 | embeddings = torch.cat(embedding_list, 0) 85 | 86 | if single_sentence and not keepdim: 87 | embeddings = embeddings[0] 88 | 89 | if return_numpy and not isinstance(embeddings, ndarray): 90 | return embeddings.numpy() 91 | return embeddings 92 | 93 | def similarity(self, queries: Union[str, List[str]], 94 | keys: Union[str, List[str], ndarray], 95 | device: str = None) -> Union[float, ndarray]: 96 | 97 | query_vecs = self.encode(queries, device=device, return_numpy=True) # suppose N queries 98 | 99 | if not isinstance(keys, ndarray): 100 | key_vecs = self.encode(keys, device=device, return_numpy=True) # suppose M keys 101 | else: 102 | key_vecs = keys 103 | 104 | # check whether N == 1 or M == 1 105 | single_query, single_key = len(query_vecs.shape) == 1, len(key_vecs.shape) == 1 106 | if single_query: 107 | query_vecs = query_vecs.reshape(1, -1) 108 | if single_key: 109 | key_vecs = key_vecs.reshape(1, -1) 110 | 111 | # returns an N*M similarity array 112 | similarities = cosine_similarity(query_vecs, key_vecs) 113 | 114 | if single_query: 115 | similarities = similarities[0] 116 | if single_key: 117 | similarities = float(similarities[0]) 118 | 119 | return similarities 120 | 121 | def build_index(self, sentences_or_file_path: Union[str, List[str]], 122 | use_faiss: bool = None, 123 | faiss_fast: bool = False, 124 | device: str = None, 125 | batch_size: int = 64): 126 | 127 | if use_faiss is None or use_faiss: 128 | try: 129 | import faiss 130 | assert hasattr(faiss, "IndexFlatIP") 131 | use_faiss = True 132 | except: 133 | logger.warning("Fail to import faiss. If you want to use faiss, install faiss through PyPI. 
Now the program continues with brute force search.") 134 | use_faiss = False 135 | 136 | # if the input sentence is a string, we assume it's the path of file that stores various sentences 137 | if isinstance(sentences_or_file_path, str): 138 | sentences = [] 139 | with open(sentences_or_file_path, "r") as f: 140 | logging.info("Loading sentences from %s ..." % (sentences_or_file_path)) 141 | for line in tqdm(f): 142 | sentences.append(line.rstrip()) 143 | sentences_or_file_path = sentences 144 | 145 | logger.info("Encoding embeddings for sentences...") 146 | embeddings = self.encode(sentences_or_file_path, device=device, batch_size=batch_size, normalize_to_unit=True, return_numpy=True) 147 | 148 | logger.info("Building index...") 149 | self.index = {"sentences": sentences_or_file_path} 150 | 151 | if use_faiss: 152 | quantizer = faiss.IndexFlatIP(embeddings.shape[1]) 153 | if faiss_fast: 154 | index = faiss.IndexIVFFlat(quantizer, embeddings.shape[1], min(self.num_cells, len(sentences_or_file_path))) 155 | else: 156 | index = quantizer 157 | 158 | if (self.device == "cuda" and device != "cpu") or device == "cuda": 159 | if hasattr(faiss, "StandardGpuResources"): 160 | logger.info("Use GPU-version faiss") 161 | res = faiss.StandardGpuResources() 162 | res.setTempMemory(20 * 1024 * 1024 * 1024) 163 | index = faiss.index_cpu_to_gpu(res, 0, index) 164 | else: 165 | logger.info("Use CPU-version faiss") 166 | else: 167 | logger.info("Use CPU-version faiss") 168 | 169 | if faiss_fast: 170 | index.train(embeddings.astype(np.float32)) 171 | index.add(embeddings.astype(np.float32)) 172 | index.nprobe = min(self.num_cells_in_search, len(sentences_or_file_path)) 173 | self.is_faiss_index = True 174 | else: 175 | index = embeddings 176 | self.is_faiss_index = False 177 | self.index["index"] = index 178 | logger.info("Finished") 179 | 180 | def search(self, queries: Union[str, List[str]], 181 | device: str = None, 182 | threshold: float = 0.6, 183 | top_k: int = 5) -> 
Union[List[Tuple[str, float]], List[List[Tuple[str, float]]]]: 184 | 185 | if not self.is_faiss_index: 186 | if isinstance(queries, list): 187 | combined_results = [] 188 | for query in queries: 189 | results = self.search(query, device) 190 | combined_results.append(results) 191 | return combined_results 192 | 193 | similarities = self.similarity(queries, self.index["index"]).tolist() 194 | id_and_score = [] 195 | for i, s in enumerate(similarities): 196 | if s >= threshold: 197 | id_and_score.append((i, s)) 198 | id_and_score = sorted(id_and_score, key=lambda x: x[1], reverse=True)[:top_k] 199 | results = [(self.index["sentences"][idx], score) for idx, score in id_and_score] 200 | return results 201 | else: 202 | query_vecs = self.encode(queries, device=device, normalize_to_unit=True, keepdim=True, return_numpy=True) 203 | 204 | distance, idx = self.index["index"].search(query_vecs.astype(np.float32), top_k) 205 | 206 | def pack_single_result(dist, idx): 207 | results = [(self.index["sentences"][i], s) for i, s in zip(idx, dist) if s >= threshold] 208 | return results 209 | 210 | if isinstance(queries, list): 211 | combined_results = [] 212 | for i in range(len(queries)): 213 | results = pack_single_result(distance[i], idx[i]) 214 | combined_results.append(results) 215 | return combined_results 216 | else: 217 | return pack_single_result(distance[0], idx[0]) 218 | 219 | if __name__=="__main__": 220 | example_sentences = [ 221 | 'An animal is biting a persons finger.', 222 | 'A woman is reading.', 223 | 'A man is lifting weights in a garage.', 224 | 'A man plays the violin.', 225 | 'A man is eating food.', 226 | 'A man plays the piano.', 227 | 'A panda is climbing.', 228 | 'A man plays a guitar.', 229 | 'A woman is slicing a meat.', 230 | 'A woman is taking a picture.' 231 | ] 232 | example_queries = [ 233 | 'A man is playing music.', 234 | 'A woman is making a photo.' 
235 | ] 236 | 237 | model_name = "princeton-nlp/sup-simcse-bert-base-uncased" 238 | simcse = SimCSE(model_name) 239 | 240 | print("\n=========Calculate cosine similarities between queries and sentences============\n") 241 | similarities = simcse.similarity(example_queries, example_sentences) 242 | print(similarities) 243 | 244 | print("\n=========Naive brute force search============\n") 245 | simcse.build_index(example_sentences, use_faiss=False) 246 | results = simcse.search(example_queries) 247 | for i, result in enumerate(results): 248 | print("Retrieval results for query: {}".format(example_queries[i])) 249 | for sentence, score in result: 250 | print(" {} (cosine similarity: {:.4f})".format(sentence, score)) 251 | print("") 252 | 253 | print("\n=========Search with Faiss backend============\n") 254 | simcse.build_index(example_sentences, use_faiss=True) 255 | results = simcse.search(example_queries) 256 | for i, result in enumerate(results): 257 | print("Retrieval results for query: {}".format(example_queries[i])) 258 | for sentence, score in result: 259 | print(" {} (cosine similarity: {:.4f})".format(sentence, score)) 260 | print("") 261 | 262 | -------------------------------------------------------------------------------- /senteval/sts.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | ''' 9 | STS-{2012,2013,2014,2015,2016} (unsupervised) and 10 | STS-benchmark (supervised) tasks 11 | ''' 12 | 13 | from __future__ import absolute_import, division, unicode_literals 14 | 15 | import os 16 | import io 17 | import numpy as np 18 | import logging 19 | 20 | from scipy.stats import spearmanr, pearsonr 21 | import torch 22 | 23 | from senteval.utils import cosine 24 | from senteval.sick import SICKEval 25 | 26 | 27 | class STSEval(object): 28 | def loadFile(self, fpath): 29 | self.data = {} 30 | self.samples = [] 31 | 32 | for dataset in self.datasets: 33 | sent1, sent2 = zip(*[l.split("\t") for l in 34 | io.open(fpath + '/STS.input.%s.txt' % dataset, 35 | encoding='utf8').read().splitlines()]) 36 | raw_scores = np.array([x for x in 37 | io.open(fpath + '/STS.gs.%s.txt' % dataset, 38 | encoding='utf8') 39 | .read().splitlines()]) 40 | not_empty_idx = raw_scores != '' 41 | 42 | gs_scores = [float(x) for x in raw_scores[not_empty_idx]] 43 | sent1 = np.array([s.split() for s in sent1])[not_empty_idx] 44 | sent2 = np.array([s.split() for s in sent2])[not_empty_idx] 45 | # sort data by length to minimize padding in batcher 46 | sorted_data = sorted(zip(sent1, sent2, gs_scores), 47 | key=lambda z: (len(z[0]), len(z[1]), z[2])) 48 | sent1, sent2, gs_scores = map(list, zip(*sorted_data)) 49 | 50 | self.data[dataset] = (sent1, sent2, gs_scores) 51 | self.samples += sent1 + sent2 52 | 53 | def do_prepare(self, params, prepare): 54 | if 'similarity' in params: 55 | self.similarity = params.similarity 56 | else: # Default similarity is cosine 57 | self.similarity = lambda s1, s2: np.nan_to_num(cosine(np.nan_to_num(s1), np.nan_to_num(s2))) 58 | return prepare(params, self.samples) 59 | 60 | def run(self, params, batcher): 61 | def align_loss(x, y, alpha=2): 62 | #x=x/torch.norm(x, p=2, dim=-1, keepdim=True) 63 | #y=y/torch.norm(y, p=2, dim=-1, keepdim=True) 64 | return (x - y).norm(p=2, dim=1).pow(alpha).mean() 65 | 66 | def uniform_loss(x, t=2): 67 | 
#x = x / torch.norm(x, p=2, dim=-1, keepdim=True) 68 | return torch.pdist(x, p=2).pow(2).mul(-t).exp().mean().log() 69 | 70 | results = {} 71 | all_sys_scores = [] 72 | all_gs_scores = [] 73 | ################# newly added 74 | all_loss_align = [] 75 | all_loss_uniform = [] 76 | ################# 77 | for dataset in self.datasets: 78 | sys_scores = [] 79 | input1, input2, gs_scores = self.data[dataset] 80 | for ii in range(0, len(gs_scores), params.batch_size): 81 | batch1 = input1[ii:ii + params.batch_size] 82 | batch2 = input2[ii:ii + params.batch_size] 83 | batch_gs_scores = gs_scores[ii:ii + params.batch_size] # newly added 84 | 85 | # we assume get_batch already throws out the faulty ones 86 | if len(batch1) == len(batch2) and len(batch1) > 0: 87 | enc1 = batcher(params, batch1) 88 | enc2 = batcher(params, batch2) 89 | 90 | ################# newly added 91 | pos_indices = [i for i in range(len(batch_gs_scores)) if batch_gs_scores[i] >= 4.0] 92 | enc1_norm = enc1/torch.norm(enc1, p=2, dim=-1, keepdim=True) 93 | enc2_norm = enc2/torch.norm(enc2, p=2, dim=-1, keepdim=True) 94 | enc1_pos = enc1_norm[pos_indices] 95 | enc2_pos = enc2_norm[pos_indices] 96 | loss_align = align_loss(enc1_pos, enc2_pos) 97 | loss_uniform = uniform_loss(torch.cat((enc1_norm, enc2_norm), dim=0)) 98 | all_loss_align.append(loss_align) 99 | all_loss_uniform.append(loss_uniform) 100 | ################# 101 | 102 | for kk in range(enc2.shape[0]): 103 | sys_score = self.similarity(enc1[kk], enc2[kk]) 104 | sys_scores.append(sys_score) 105 | all_sys_scores.extend(sys_scores) 106 | all_gs_scores.extend(gs_scores) 107 | results[dataset] = {'pearson': pearsonr(sys_scores, gs_scores), 108 | 'spearman': spearmanr(sys_scores, gs_scores), 109 | 'nsamples': len(sys_scores), 110 | 'align_loss': float(np.mean(all_loss_align)), # newly added 111 | 'uniform_loss': float(np.mean(all_loss_uniform))} # newly added 112 | logging.debug('%s : pearson = %.4f, spearman = %.4f, align_loss = %.4f, uniform_loss = 
%.4f' % 113 | (dataset, results[dataset]['pearson'][0], 114 | results[dataset]['spearman'][0], results[dataset]['align_loss'], 115 | results[dataset]['uniform_loss'])) 116 | 117 | weights = [results[dset]['nsamples'] for dset in results.keys()] 118 | list_prs = np.array([results[dset]['pearson'][0] for 119 | dset in results.keys()]) 120 | list_spr = np.array([results[dset]['spearman'][0] for 121 | dset in results.keys()]) 122 | 123 | avg_pearson = np.average(list_prs) 124 | avg_spearman = np.average(list_spr) 125 | wavg_pearson = np.average(list_prs, weights=weights) 126 | wavg_spearman = np.average(list_spr, weights=weights) 127 | all_pearson = pearsonr(all_sys_scores, all_gs_scores) 128 | all_spearman = spearmanr(all_sys_scores, all_gs_scores) 129 | results['all'] = {'pearson': {'all': all_pearson[0], 130 | 'mean': avg_pearson, 131 | 'wmean': wavg_pearson}, 132 | 'spearman': {'all': all_spearman[0], 133 | 'mean': avg_spearman, 134 | 'wmean': wavg_spearman}} 135 | logging.debug('ALL : Pearson = %.4f, \ 136 | Spearman = %.4f' % (all_pearson[0], all_spearman[0])) 137 | logging.debug('ALL (weighted average) : Pearson = %.4f, \ 138 | Spearman = %.4f' % (wavg_pearson, wavg_spearman)) 139 | logging.debug('ALL (average) : Pearson = %.4f, \ 140 | Spearman = %.4f\n' % (avg_pearson, avg_spearman)) 141 | 142 | return results 143 | 144 | 145 | class STS12Eval(STSEval): 146 | def __init__(self, taskpath, seed=1111): 147 | logging.debug('***** Transfer task : STS12 *****\n\n') 148 | self.seed = seed 149 | self.datasets = ['MSRpar', 'MSRvid', 'SMTeuroparl', 150 | 'surprise.OnWN', 'surprise.SMTnews'] 151 | self.loadFile(taskpath) 152 | 153 | 154 | class STS13Eval(STSEval): 155 | # STS13 here does not contain the "SMT" subtask due to LICENSE issue 156 | def __init__(self, taskpath, seed=1111): 157 | logging.debug('***** Transfer task : STS13 (-SMT) *****\n\n') 158 | self.seed = seed 159 | self.datasets = ['FNWN', 'headlines', 'OnWN'] 160 | self.loadFile(taskpath) 161 | 162 | 163 | 
class STS14Eval(STSEval): 164 | def __init__(self, taskpath, seed=1111): 165 | logging.debug('***** Transfer task : STS14 *****\n\n') 166 | self.seed = seed 167 | self.datasets = ['deft-forum', 'deft-news', 'headlines', 168 | 'images', 'OnWN', 'tweet-news'] 169 | self.loadFile(taskpath) 170 | 171 | 172 | class STS15Eval(STSEval): 173 | def __init__(self, taskpath, seed=1111): 174 | logging.debug('***** Transfer task : STS15 *****\n\n') 175 | self.seed = seed 176 | self.datasets = ['answers-forums', 'answers-students', 177 | 'belief', 'headlines', 'images'] 178 | self.loadFile(taskpath) 179 | 180 | 181 | class STS16Eval(STSEval): 182 | def __init__(self, taskpath, seed=1111): 183 | logging.debug('***** Transfer task : STS16 *****\n\n') 184 | self.seed = seed 185 | self.datasets = ['answer-answer', 'headlines', 'plagiarism', 186 | 'postediting', 'question-question'] 187 | self.loadFile(taskpath) 188 | 189 | 190 | class STSBenchmarkEval(STSEval): 191 | def __init__(self, task_path, seed=1111): 192 | logging.debug('\n\n***** Transfer task : STSBenchmark*****\n\n') 193 | self.seed = seed 194 | self.samples = [] 195 | train = self.loadFile(os.path.join(task_path, 'sts-train.csv')) 196 | dev = self.loadFile(os.path.join(task_path, 'sts-dev.csv')) 197 | test = self.loadFile(os.path.join(task_path, 'sts-test.csv')) 198 | self.datasets = ['train', 'dev', 'test'] 199 | self.data = {'train': train, 'dev': dev, 'test': test} 200 | 201 | def loadFile(self, fpath): 202 | sick_data = {'X_A': [], 'X_B': [], 'y': []} 203 | with io.open(fpath, 'r', encoding='utf-8') as f: 204 | for line in f: 205 | text = line.strip().split('\t') 206 | sick_data['X_A'].append(text[5].split()) 207 | sick_data['X_B'].append(text[6].split()) 208 | sick_data['y'].append(text[4]) 209 | 210 | sick_data['y'] = [float(s) for s in sick_data['y']] 211 | self.samples += sick_data['X_A'] + sick_data["X_B"] 212 | return (sick_data['X_A'], sick_data["X_B"], sick_data['y']) 213 | 214 | class 
STSBenchmarkFinetune(SICKEval): 215 | def __init__(self, task_path, seed=1111): 216 | logging.debug('\n\n***** Transfer task : STSBenchmark*****\n\n') 217 | self.seed = seed 218 | train = self.loadFile(os.path.join(task_path, 'sts-train.csv')) 219 | dev = self.loadFile(os.path.join(task_path, 'sts-dev.csv')) 220 | test = self.loadFile(os.path.join(task_path, 'sts-test.csv')) 221 | self.sick_data = {'train': train, 'dev': dev, 'test': test} 222 | 223 | def loadFile(self, fpath): 224 | sick_data = {'X_A': [], 'X_B': [], 'y': []} 225 | with io.open(fpath, 'r', encoding='utf-8') as f: 226 | for line in f: 227 | text = line.strip().split('\t') 228 | sick_data['X_A'].append(text[5].split()) 229 | sick_data['X_B'].append(text[6].split()) 230 | sick_data['y'].append(text[4]) 231 | 232 | sick_data['y'] = [float(s) for s in sick_data['y']] 233 | return sick_data 234 | 235 | class SICKRelatednessEval(STSEval): 236 | def __init__(self, task_path, seed=1111): 237 | logging.debug('\n\n***** Transfer task : SICKRelatedness*****\n\n') 238 | self.seed = seed 239 | self.samples = [] 240 | train = self.loadFile(os.path.join(task_path, 'SICK_train.txt')) 241 | dev = self.loadFile(os.path.join(task_path, 'SICK_trial.txt')) 242 | test = self.loadFile(os.path.join(task_path, 'SICK_test_annotated.txt')) 243 | self.datasets = ['train', 'dev', 'test'] 244 | self.data = {'train': train, 'dev': dev, 'test': test} 245 | 246 | def loadFile(self, fpath): 247 | skipFirstLine = True 248 | sick_data = {'X_A': [], 'X_B': [], 'y': []} 249 | with io.open(fpath, 'r', encoding='utf-8') as f: 250 | for line in f: 251 | if skipFirstLine: 252 | skipFirstLine = False 253 | else: 254 | text = line.strip().split('\t') 255 | sick_data['X_A'].append(text[1].split()) 256 | sick_data['X_B'].append(text[2].split()) 257 | sick_data['y'].append(text[3]) 258 | 259 | sick_data['y'] = [float(s) for s in sick_data['y']] 260 | self.samples += sick_data['X_A'] + sick_data["X_B"] 261 | return (sick_data['X_A'], 
sick_data["X_B"], sick_data['y']) 262 | -------------------------------------------------------------------------------- /senteval/tools/ranking.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | """ 9 | Image Annotation/Search for COCO with Pytorch 10 | """ 11 | from __future__ import absolute_import, division, unicode_literals 12 | 13 | import logging 14 | import copy 15 | import numpy as np 16 | 17 | import torch 18 | from torch import nn 19 | from torch.autograd import Variable 20 | import torch.optim as optim 21 | 22 | 23 | class COCOProjNet(nn.Module): 24 | def __init__(self, config): 25 | super(COCOProjNet, self).__init__() 26 | self.imgdim = config['imgdim'] 27 | self.sentdim = config['sentdim'] 28 | self.projdim = config['projdim'] 29 | self.imgproj = nn.Sequential( 30 | nn.Linear(self.imgdim, self.projdim), 31 | ) 32 | self.sentproj = nn.Sequential( 33 | nn.Linear(self.sentdim, self.projdim), 34 | ) 35 | 36 | def forward(self, img, sent, imgc, sentc): 37 | # imgc : (bsize, ncontrast, imgdim) 38 | # sentc : (bsize, ncontrast, sentdim) 39 | # img : (bsize, imgdim) 40 | # sent : (bsize, sentdim) 41 | img = img.unsqueeze(1).expand_as(imgc).contiguous() 42 | img = img.view(-1, self.imgdim) 43 | imgc = imgc.view(-1, self.imgdim) 44 | sent = sent.unsqueeze(1).expand_as(sentc).contiguous() 45 | sent = sent.view(-1, self.sentdim) 46 | sentc = sentc.view(-1, self.sentdim) 47 | 48 | imgproj = self.imgproj(img) 49 | imgproj = imgproj / torch.sqrt(torch.pow(imgproj, 2).sum(1, keepdim=True)).expand_as(imgproj) 50 | imgcproj = self.imgproj(imgc) 51 | imgcproj = imgcproj / torch.sqrt(torch.pow(imgcproj, 2).sum(1, keepdim=True)).expand_as(imgcproj) 52 | sentproj = self.sentproj(sent) 53 | sentproj = sentproj / 
torch.sqrt(torch.pow(sentproj, 2).sum(1, keepdim=True)).expand_as(sentproj) 54 | sentcproj = self.sentproj(sentc) 55 | sentcproj = sentcproj / torch.sqrt(torch.pow(sentcproj, 2).sum(1, keepdim=True)).expand_as(sentcproj) 56 | # (bsize*ncontrast, projdim) 57 | 58 | anchor1 = torch.sum((imgproj*sentproj), 1) 59 | anchor2 = torch.sum((sentproj*imgproj), 1) 60 | img_sentc = torch.sum((imgproj*sentcproj), 1) 61 | sent_imgc = torch.sum((sentproj*imgcproj), 1) 62 | 63 | # (bsize*ncontrast) 64 | return anchor1, anchor2, img_sentc, sent_imgc 65 | 66 | def proj_sentence(self, sent): 67 | output = self.sentproj(sent) 68 | output = output / torch.sqrt(torch.pow(output, 2).sum(1, keepdim=True)).expand_as(output) 69 | return output # (bsize, projdim) 70 | 71 | def proj_image(self, img): 72 | output = self.imgproj(img) 73 | output = output / torch.sqrt(torch.pow(output, 2).sum(1, keepdim=True)).expand_as(output) 74 | return output # (bsize, projdim) 75 | 76 | 77 | class PairwiseRankingLoss(nn.Module): 78 | """ 79 | Pairwise ranking loss 80 | """ 81 | def __init__(self, margin): 82 | super(PairwiseRankingLoss, self).__init__() 83 | self.margin = margin 84 | 85 | def forward(self, anchor1, anchor2, img_sentc, sent_imgc): 86 | 87 | cost_sent = torch.clamp(self.margin - anchor1 + img_sentc, 88 | min=0.0).sum() 89 | cost_img = torch.clamp(self.margin - anchor2 + sent_imgc, 90 | min=0.0).sum() 91 | loss = cost_sent + cost_img 92 | return loss 93 | 94 | 95 | class ImageSentenceRankingPytorch(object): 96 | # Image Sentence Ranking on COCO with Pytorch 97 | def __init__(self, train, valid, test, config): 98 | # fix seed 99 | self.seed = config['seed'] 100 | np.random.seed(self.seed) 101 | torch.manual_seed(self.seed) 102 | torch.cuda.manual_seed(self.seed) 103 | 104 | self.train = train 105 | self.valid = valid 106 | self.test = test 107 | 108 | self.imgdim = len(train['imgfeat'][0]) 109 | self.sentdim = len(train['sentfeat'][0]) 110 | self.projdim = config['projdim'] 111 | self.margin = 
config['margin'] 112 | 113 | self.batch_size = 128 114 | self.ncontrast = 30 115 | self.maxepoch = 20 116 | self.early_stop = True 117 | 118 | config_model = {'imgdim': self.imgdim,'sentdim': self.sentdim, 119 | 'projdim': self.projdim} 120 | self.model = COCOProjNet(config_model).cuda() 121 | 122 | self.loss_fn = PairwiseRankingLoss(margin=self.margin).cuda() 123 | 124 | self.optimizer = optim.Adam(self.model.parameters()) 125 | 126 | def prepare_data(self, trainTxt, trainImg, devTxt, devImg, 127 | testTxt, testImg): 128 | trainTxt = torch.FloatTensor(trainTxt) 129 | trainImg = torch.FloatTensor(trainImg) 130 | devTxt = torch.FloatTensor(devTxt).cuda() 131 | devImg = torch.FloatTensor(devImg).cuda() 132 | testTxt = torch.FloatTensor(testTxt).cuda() 133 | testImg = torch.FloatTensor(testImg).cuda() 134 | 135 | return trainTxt, trainImg, devTxt, devImg, testTxt, testImg 136 | 137 | def run(self): 138 | self.nepoch = 0 139 | bestdevscore = -1 140 | early_stop_count = 0 141 | stop_train = False 142 | 143 | # Preparing data 144 | logging.info('prepare data') 145 | trainTxt, trainImg, devTxt, devImg, testTxt, testImg = \ 146 | self.prepare_data(self.train['sentfeat'], self.train['imgfeat'], 147 | self.valid['sentfeat'], self.valid['imgfeat'], 148 | self.test['sentfeat'], self.test['imgfeat']) 149 | 150 | # Training 151 | while not stop_train and self.nepoch <= self.maxepoch: 152 | logging.info('start epoch') 153 | self.trainepoch(trainTxt, trainImg, devTxt, devImg, nepoches=1) 154 | logging.info('Epoch {0} finished'.format(self.nepoch)) 155 | 156 | results = {'i2t': {'r1': 0, 'r5': 0, 'r10': 0, 'medr': 0}, 157 | 't2i': {'r1': 0, 'r5': 0, 'r10': 0, 'medr': 0}, 158 | 'dev': bestdevscore} 159 | score = 0 160 | for i in range(5): 161 | devTxt_i = devTxt[i*5000:(i+1)*5000] 162 | devImg_i = devImg[i*5000:(i+1)*5000] 163 | # Compute dev ranks img2txt 164 | r1_i2t, r5_i2t, r10_i2t, medr_i2t = self.i2t(devImg_i, 165 | devTxt_i) 166 | results['i2t']['r1'] += r1_i2t / 5 167 | 
results['i2t']['r5'] += r5_i2t / 5 168 | results['i2t']['r10'] += r10_i2t / 5 169 | results['i2t']['medr'] += medr_i2t / 5 170 | logging.info("Image to text: {0}, {1}, {2}, {3}" 171 | .format(r1_i2t, r5_i2t, r10_i2t, medr_i2t)) 172 | # Compute dev ranks txt2img 173 | r1_t2i, r5_t2i, r10_t2i, medr_t2i = self.t2i(devImg_i, 174 | devTxt_i) 175 | results['t2i']['r1'] += r1_t2i / 5 176 | results['t2i']['r5'] += r5_t2i / 5 177 | results['t2i']['r10'] += r10_t2i / 5 178 | results['t2i']['medr'] += medr_t2i / 5 179 | logging.info("Text to Image: {0}, {1}, {2}, {3}" 180 | .format(r1_t2i, r5_t2i, r10_t2i, medr_t2i)) 181 | score += (r1_i2t + r5_i2t + r10_i2t + 182 | r1_t2i + r5_t2i + r10_t2i) / 5 183 | 184 | logging.info("Dev mean Text to Image: {0}, {1}, {2}, {3}".format( 185 | results['t2i']['r1'], results['t2i']['r5'], 186 | results['t2i']['r10'], results['t2i']['medr'])) 187 | logging.info("Dev mean Image to text: {0}, {1}, {2}, {3}".format( 188 | results['i2t']['r1'], results['i2t']['r5'], 189 | results['i2t']['r10'], results['i2t']['medr'])) 190 | 191 | # early stop on Pearson 192 | if score > bestdevscore: 193 | bestdevscore = score 194 | bestmodel = copy.deepcopy(self.model) 195 | elif self.early_stop: 196 | if early_stop_count >= 3: 197 | stop_train = True 198 | early_stop_count += 1 199 | self.model = bestmodel 200 | 201 | # Compute test for the 5 splits 202 | results = {'i2t': {'r1': 0, 'r5': 0, 'r10': 0, 'medr': 0}, 203 | 't2i': {'r1': 0, 'r5': 0, 'r10': 0, 'medr': 0}, 204 | 'dev': bestdevscore} 205 | for i in range(5): 206 | testTxt_i = testTxt[i*5000:(i+1)*5000] 207 | testImg_i = testImg[i*5000:(i+1)*5000] 208 | # Compute test ranks img2txt 209 | r1_i2t, r5_i2t, r10_i2t, medr_i2t = self.i2t(testImg_i, testTxt_i) 210 | results['i2t']['r1'] += r1_i2t / 5 211 | results['i2t']['r5'] += r5_i2t / 5 212 | results['i2t']['r10'] += r10_i2t / 5 213 | results['i2t']['medr'] += medr_i2t / 5 214 | # Compute test ranks txt2img 215 | r1_t2i, r5_t2i, r10_t2i, medr_t2i = 
self.t2i(testImg_i, testTxt_i) 216 | results['t2i']['r1'] += r1_t2i / 5 217 | results['t2i']['r5'] += r5_t2i / 5 218 | results['t2i']['r10'] += r10_t2i / 5 219 | results['t2i']['medr'] += medr_t2i / 5 220 | 221 | return bestdevscore, results['i2t']['r1'], results['i2t']['r5'], \ 222 | results['i2t']['r10'], results['i2t']['medr'], \ 223 | results['t2i']['r1'], results['t2i']['r5'], \ 224 | results['t2i']['r10'], results['t2i']['medr'] 225 | 226 | def trainepoch(self, trainTxt, trainImg, devTxt, devImg, nepoches=1): 227 | self.model.train() 228 | for _ in range(self.nepoch, self.nepoch + nepoches): 229 | permutation = list(np.random.permutation(len(trainTxt))) 230 | all_costs = [] 231 | for i in range(0, len(trainTxt), self.batch_size): 232 | # forward 233 | if i % (self.batch_size*500) == 0 and i > 0: 234 | logging.info('samples : {0}'.format(i)) 235 | r1_i2t, r5_i2t, r10_i2t, medr_i2t = self.i2t(devImg, 236 | devTxt) 237 | logging.info("Image to text: {0}, {1}, {2}, {3}".format( 238 | r1_i2t, r5_i2t, r10_i2t, medr_i2t)) 239 | # Compute test ranks txt2img 240 | r1_t2i, r5_t2i, r10_t2i, medr_t2i = self.t2i(devImg, 241 | devTxt) 242 | logging.info("Text to Image: {0}, {1}, {2}, {3}".format( 243 | r1_t2i, r5_t2i, r10_t2i, medr_t2i)) 244 | idx = torch.LongTensor(permutation[i:i + self.batch_size]) 245 | imgbatch = Variable(trainImg.index_select(0, idx)).cuda() 246 | sentbatch = Variable(trainTxt.index_select(0, idx)).cuda() 247 | 248 | idximgc = np.random.choice(permutation[:i] + 249 | permutation[i + self.batch_size:], 250 | self.ncontrast*idx.size(0)) 251 | idxsentc = np.random.choice(permutation[:i] + 252 | permutation[i + self.batch_size:], 253 | self.ncontrast*idx.size(0)) 254 | idximgc = torch.LongTensor(idximgc) 255 | idxsentc = torch.LongTensor(idxsentc) 256 | # Get indexes for contrastive images and sentences 257 | imgcbatch = Variable(trainImg.index_select(0, idximgc)).view( 258 | -1, self.ncontrast, self.imgdim).cuda() 259 | sentcbatch = 
Variable(trainTxt.index_select(0, idxsentc)).view( 260 | -1, self.ncontrast, self.sentdim).cuda() 261 | 262 | anchor1, anchor2, img_sentc, sent_imgc = self.model( 263 | imgbatch, sentbatch, imgcbatch, sentcbatch) 264 | # loss 265 | loss = self.loss_fn(anchor1, anchor2, img_sentc, sent_imgc) 266 | all_costs.append(loss.data.item()) 267 | # backward 268 | self.optimizer.zero_grad() 269 | loss.backward() 270 | # Update parameters 271 | self.optimizer.step() 272 | self.nepoch += nepoches 273 | 274 | def t2i(self, images, captions): 275 | """ 276 | Images: (5N, imgdim) matrix of images 277 | Captions: (5N, sentdim) matrix of captions 278 | """ 279 | with torch.no_grad(): 280 | # Project images and captions 281 | img_embed, sent_embed = [], [] 282 | for i in range(0, len(images), self.batch_size): 283 | img_embed.append(self.model.proj_image( 284 | Variable(images[i:i + self.batch_size]))) 285 | sent_embed.append(self.model.proj_sentence( 286 | Variable(captions[i:i + self.batch_size]))) 287 | img_embed = torch.cat(img_embed, 0).data 288 | sent_embed = torch.cat(sent_embed, 0).data 289 | 290 | npts = int(img_embed.size(0) / 5) 291 | idxs = torch.cuda.LongTensor(range(0, len(img_embed), 5)) 292 | ims = img_embed.index_select(0, idxs) 293 | 294 | ranks = np.zeros(5 * npts) 295 | for index in range(npts): 296 | 297 | # Get query captions 298 | queries = sent_embed[5*index: 5*index + 5] 299 | 300 | # Compute scores 301 | scores = torch.mm(queries, ims.transpose(0, 1)).cpu().numpy() 302 | inds = np.zeros(scores.shape) 303 | for i in range(len(inds)): 304 | inds[i] = np.argsort(scores[i])[::-1] 305 | ranks[5 * index + i] = np.where(inds[i] == index)[0][0] 306 | 307 | # Compute metrics 308 | r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 309 | r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 310 | r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 311 | medr = np.floor(np.median(ranks)) + 1 312 | return (r1, r5, r10, medr) 313 | 314 | def i2t(self, images, 
captions): 315 | """ 316 | Images: (5N, imgdim) matrix of images 317 | Captions: (5N, sentdim) matrix of captions 318 | """ 319 | with torch.no_grad(): 320 | # Project images and captions 321 | img_embed, sent_embed = [], [] 322 | for i in range(0, len(images), self.batch_size): 323 | img_embed.append(self.model.proj_image( 324 | Variable(images[i:i + self.batch_size]))) 325 | sent_embed.append(self.model.proj_sentence( 326 | Variable(captions[i:i + self.batch_size]))) 327 | img_embed = torch.cat(img_embed, 0).data 328 | sent_embed = torch.cat(sent_embed, 0).data 329 | 330 | npts = int(img_embed.size(0) / 5) 331 | index_list = [] 332 | 333 | ranks = np.zeros(npts) 334 | for index in range(npts): 335 | 336 | # Get query image 337 | query_img = img_embed[5 * index] 338 | 339 | # Compute scores 340 | scores = torch.mm(query_img.view(1, -1), 341 | sent_embed.transpose(0, 1)).view(-1) 342 | scores = scores.cpu().numpy() 343 | inds = np.argsort(scores)[::-1] 344 | index_list.append(inds[0]) 345 | 346 | # Score 347 | rank = 1e20 348 | for i in range(5*index, 5*index + 5, 1): 349 | tmp = np.where(inds == i)[0][0] 350 | if tmp < rank: 351 | rank = tmp 352 | ranks[index] = rank 353 | 354 | # Compute metrics 355 | r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 356 | r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 357 | r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 358 | medr = np.floor(np.median(ranks)) + 1 359 | return (r1, r5, r10, medr) 360 | -------------------------------------------------------------------------------- /simcse/train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import os 4 | import sys 5 | from dataclasses import dataclass, field 6 | from typing import Optional, Union, List, Dict, Tuple 7 | import torch 8 | import collections 9 | import random 10 | 11 | from datasets import load_dataset 12 | 13 | import transformers 14 | from transformers import ( 15 | 
CONFIG_MAPPING, 16 | MODEL_FOR_MASKED_LM_MAPPING, 17 | AutoConfig, 18 | AutoModelForMaskedLM, 19 | AutoModelForSequenceClassification, 20 | AutoTokenizer, 21 | DataCollatorForLanguageModeling, 22 | DataCollatorWithPadding, 23 | HfArgumentParser, 24 | Trainer, 25 | TrainingArguments, 26 | default_data_collator, 27 | set_seed, 28 | EvalPrediction, 29 | BertModel, 30 | BertForPreTraining, 31 | RobertaModel 32 | ) 33 | from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTrainedTokenizerBase 34 | from transformers.trainer_utils import is_main_process 35 | from transformers.data.data_collator import DataCollatorForLanguageModeling 36 | from transformers.file_utils import cached_property, torch_required, is_torch_available, is_torch_tpu_available 37 | from simcse.models import RobertaForCL, BertForCL 38 | from simcse.trainers import CLTrainer 39 | 40 | logger = logging.getLogger(__name__) 41 | MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys()) 42 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 43 | 44 | @dataclass 45 | class ModelArguments: 46 | """ 47 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. 48 | """ 49 | 50 | # Huggingface's original arguments 51 | model_name_or_path: Optional[str] = field( 52 | default=None, 53 | metadata={ 54 | "help": "The model checkpoint for weights initialization." 55 | "Don't set if you want to train a model from scratch." 
56 | }, 57 | ) 58 | model_type: Optional[str] = field( 59 | default=None, 60 | metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}, 61 | ) 62 | config_name: Optional[str] = field( 63 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 64 | ) 65 | tokenizer_name: Optional[str] = field( 66 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 67 | ) 68 | cache_dir: Optional[str] = field( 69 | default=None, 70 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, 71 | ) 72 | use_fast_tokenizer: bool = field( 73 | default=True, 74 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 75 | ) 76 | model_revision: str = field( 77 | default="main", 78 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 79 | ) 80 | use_auth_token: bool = field( 81 | default=False, 82 | metadata={ 83 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " 84 | "with private models)." 85 | }, 86 | ) 87 | 88 | # SimCSE's arguments 89 | temp: float = field( 90 | default=0.05, 91 | metadata={ 92 | "help": "Temperature for softmax." 93 | } 94 | ) 95 | pooler_type: str = field( 96 | default="cls", 97 | metadata={ 98 | "help": "What kind of pooler to use (cls, cls_before_pooler, avg, avg_top2, avg_first_last)." 99 | } 100 | ) 101 | hard_negative_weight: float = field( 102 | default=0, 103 | metadata={ 104 | "help": "The **logit** of weight for hard negatives (only effective if hard negatives are used)." 105 | } 106 | ) 107 | do_mlm: bool = field( 108 | default=False, 109 | metadata={ 110 | "help": "Whether to use MLM auxiliary objective." 
111 | } 112 | ) 113 | mlm_weight: float = field( 114 | default=0.1, 115 | metadata={ 116 | "help": "Weight for MLM auxiliary objective (only effective if --do_mlm)." 117 | } 118 | ) 119 | mlp_only_train: bool = field( 120 | default=False, 121 | metadata={ 122 | "help": "Use MLP only during training" 123 | } 124 | ) 125 | 126 | 127 | @dataclass 128 | class DataTrainingArguments: 129 | """ 130 | Arguments pertaining to what data we are going to input our model for training and eval. 131 | """ 132 | 133 | # Huggingface's original arguments. 134 | dataset_name: Optional[str] = field( 135 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 136 | ) 137 | dataset_config_name: Optional[str] = field( 138 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 139 | ) 140 | overwrite_cache: bool = field( 141 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 142 | ) 143 | validation_split_percentage: Optional[int] = field( 144 | default=5, 145 | metadata={ 146 | "help": "The percentage of the train set used as validation set in case there's no validation split" 147 | }, 148 | ) 149 | preprocessing_num_workers: Optional[int] = field( 150 | default=None, 151 | metadata={"help": "The number of processes to use for the preprocessing."}, 152 | ) 153 | 154 | # SimCSE's arguments 155 | train_file: Optional[str] = field( 156 | default=None, 157 | metadata={"help": "The training data file (.txt or .csv)."} 158 | ) 159 | max_seq_length: Optional[int] = field( 160 | default=32, 161 | metadata={ 162 | "help": "The maximum total input sequence length after tokenization. Sequences longer " 163 | "than this will be truncated." 164 | }, 165 | ) 166 | pad_to_max_length: bool = field( 167 | default=False, 168 | metadata={ 169 | "help": "Whether to pad all samples to `max_seq_length`. 
" 170 | "If False, will pad the samples dynamically when batching to the maximum length in the batch." 171 | }, 172 | ) 173 | mlm_probability: float = field( 174 | default=0.15, 175 | metadata={"help": "Ratio of tokens to mask for MLM (only effective if --do_mlm)"} 176 | ) 177 | 178 | def __post_init__(self): 179 | if self.dataset_name is None and self.train_file is None and self.validation_file is None: 180 | raise ValueError("Need either a dataset name or a training/validation file.") 181 | else: 182 | if self.train_file is not None: 183 | extension = self.train_file.split(".")[-1] 184 | assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file." 185 | 186 | 187 | @dataclass 188 | class OurTrainingArguments(TrainingArguments): 189 | # Evaluation 190 | ## By default, we evaluate STS (dev) during training (for selecting best checkpoints) and evaluate 191 | ## both STS and transfer tasks (dev) at the end of training. Using --eval_transfer will allow evaluating 192 | ## both STS and transfer tasks (dev) during training. 193 | eval_transfer: bool = field( 194 | default=False, 195 | metadata={"help": "Evaluate transfer task dev sets (in validation)."} 196 | ) 197 | 198 | @cached_property 199 | @torch_required 200 | def _setup_devices(self) -> "torch.device": 201 | logger.info("PyTorch: setting up devices") 202 | if self.no_cuda: 203 | device = torch.device("cpu") 204 | self._n_gpu = 0 205 | elif is_torch_tpu_available(): 206 | device = xm.xla_device() 207 | self._n_gpu = 0 208 | elif self.local_rank == -1: 209 | # if n_gpu is > 1 we'll use nn.DataParallel. 210 | # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0` 211 | # Explicitly set CUDA to the first (index 0) CUDA device, otherwise `set_device` will 212 | # trigger an error that a device index is missing. 
Index 0 takes into account the 213 | # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0` 214 | # will use the first GPU in that env, i.e. GPU#1 215 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 216 | # Sometimes the line in the postinit has not been run before we end up here, so just checking we're not at 217 | # the default value. 218 | self._n_gpu = torch.cuda.device_count() 219 | else: 220 | # Here, we'll use torch.distributed. 221 | # Initializes the distributed backend which will take care of synchronizing nodes/GPUs 222 | # 223 | # deepspeed performs its own DDP internally, and requires the program to be started with: 224 | # deepspeed ./program.py 225 | # rather than: 226 | # python -m torch.distributed.launch --nproc_per_node=2 ./program.py 227 | if self.deepspeed: 228 | from .integrations import is_deepspeed_available 229 | 230 | if not is_deepspeed_available(): 231 | raise ImportError("--deepspeed requires deepspeed: `pip install deepspeed`.") 232 | import deepspeed 233 | 234 | deepspeed.init_distributed() 235 | else: 236 | torch.distributed.init_process_group(backend="nccl") 237 | device = torch.device("cuda", self.local_rank) 238 | self._n_gpu = 1 239 | 240 | if device.type == "cuda": 241 | torch.cuda.set_device(device) 242 | 243 | return device 244 | 245 | 246 | def main(): 247 | # See all possible arguments in src/transformers/training_args.py 248 | # or by passing the --help flag to this script. 249 | # We now keep distinct sets of args, for a cleaner separation of concerns. 250 | 251 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, OurTrainingArguments)) 252 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 253 | # If we pass only one argument to the script and it's the path to a json file, 254 | # let's parse it to get our arguments. 
255 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 256 | else: 257 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 258 | 259 | if ( 260 | os.path.exists(training_args.output_dir) 261 | and os.listdir(training_args.output_dir) 262 | and training_args.do_train 263 | and not training_args.overwrite_output_dir 264 | ): 265 | raise ValueError( 266 | f"Output directory ({training_args.output_dir}) already exists and is not empty." 267 | "Use --overwrite_output_dir to overcome." 268 | ) 269 | 270 | # Setup logging 271 | logging.basicConfig( 272 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 273 | datefmt="%m/%d/%Y %H:%M:%S", 274 | level=logging.INFO if is_main_process(training_args.local_rank) else logging.WARN, 275 | ) 276 | 277 | # Log on each process the small summary: 278 | logger.warning( 279 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 280 | + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 281 | ) 282 | # Set the verbosity to info of the Transformers logger (on main process only): 283 | if is_main_process(training_args.local_rank): 284 | transformers.utils.logging.set_verbosity_info() 285 | transformers.utils.logging.enable_default_handler() 286 | transformers.utils.logging.enable_explicit_format() 287 | logger.info("Training/evaluation parameters %s", training_args) 288 | 289 | # Set seed before initializing model. 
290 | set_seed(training_args.seed) 291 | 292 | # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) 293 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ 294 | # (the dataset will be downloaded automatically from the datasets Hub 295 | # 296 | # For CSV/JSON files, this script will use the column called 'text' or the first column. You can easily tweak this 297 | # behavior (see below) 298 | # 299 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 300 | # download the dataset. 301 | data_files = {} 302 | if data_args.train_file is not None: 303 | data_files["train"] = data_args.train_file 304 | extension = data_args.train_file.split(".")[-1] 305 | if extension == "txt": 306 | extension = "text" 307 | if extension == "csv": 308 | datasets = load_dataset(extension, data_files=data_files, cache_dir="./data/", delimiter="\t" if "tsv" in data_args.train_file else ",") 309 | else: 310 | datasets = load_dataset(extension, data_files=data_files, cache_dir="./data/") 311 | 312 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at 313 | # https://huggingface.co/docs/datasets/loading_datasets.html. 314 | 315 | # Load pretrained model and tokenizer 316 | # 317 | # Distributed training: 318 | # The .from_pretrained methods guarantee that only one local process can concurrently 319 | # download model & vocab. 
320 | config_kwargs = { 321 | "cache_dir": model_args.cache_dir, 322 | "revision": model_args.model_revision, 323 | "use_auth_token": True if model_args.use_auth_token else None, 324 | } 325 | if model_args.config_name: 326 | config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs) 327 | elif model_args.model_name_or_path: 328 | config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs) 329 | else: 330 | config = CONFIG_MAPPING[model_args.model_type]() 331 | logger.warning("You are instantiating a new config instance from scratch.") 332 | 333 | tokenizer_kwargs = { 334 | "cache_dir": model_args.cache_dir, 335 | "use_fast": model_args.use_fast_tokenizer, 336 | "revision": model_args.model_revision, 337 | "use_auth_token": True if model_args.use_auth_token else None, 338 | } 339 | if model_args.tokenizer_name: 340 | tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs) 341 | elif model_args.model_name_or_path: 342 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs) 343 | else: 344 | raise ValueError( 345 | "You are instantiating a new tokenizer from scratch. This is not supported by this script." 346 | "You can do it from another script, save it, and load it from here, using --tokenizer_name." 
347 | ) 348 | 349 | if model_args.model_name_or_path: 350 | if 'roberta' in model_args.model_name_or_path: 351 | model = RobertaForCL.from_pretrained( 352 | model_args.model_name_or_path, 353 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 354 | config=config, 355 | cache_dir=model_args.cache_dir, 356 | revision=model_args.model_revision, 357 | use_auth_token=True if model_args.use_auth_token else None, 358 | model_args=model_args 359 | ) 360 | elif 'bert' in model_args.model_name_or_path: 361 | model = BertForCL.from_pretrained( 362 | model_args.model_name_or_path, 363 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 364 | config=config, 365 | cache_dir=model_args.cache_dir, 366 | revision=model_args.model_revision, 367 | use_auth_token=True if model_args.use_auth_token else None, 368 | model_args=model_args 369 | ) 370 | if model_args.do_mlm: 371 | pretrained_model = BertForPreTraining.from_pretrained(model_args.model_name_or_path) 372 | model.lm_head.load_state_dict(pretrained_model.cls.predictions.state_dict()) 373 | else: 374 | raise NotImplementedError 375 | else: 376 | raise NotImplementedError 377 | logger.info("Training new model from scratch") 378 | model = AutoModelForMaskedLM.from_config(config) 379 | 380 | model.resize_token_embeddings(len(tokenizer)) 381 | 382 | # Prepare features 383 | column_names = datasets["train"].column_names 384 | sent2_cname = None 385 | if len(column_names) == 2: 386 | # Pair datasets 387 | sent0_cname = column_names[0] 388 | sent1_cname = column_names[1] 389 | elif len(column_names) == 3: 390 | # Pair datasets with hard negatives 391 | sent0_cname = column_names[0] 392 | sent1_cname = column_names[1] 393 | sent2_cname = column_names[2] 394 | elif len(column_names) == 1: 395 | # Unsupervised datasets 396 | sent0_cname = column_names[0] 397 | sent1_cname = column_names[0] 398 | else: 399 | raise NotImplementedError 400 | 401 | def prepare_features(examples): 402 | # padding = longest (default) 403 | # If no 
sentence in the batch exceed the max length, then use 404 | # the max sentence length in the batch, otherwise use the 405 | # max sentence length in the argument and truncate those that 406 | # exceed the max length. 407 | # padding = max_length (when pad_to_max_length, for pressure test) 408 | # All sentences are padded/truncated to data_args.max_seq_length. 409 | total = len(examples[sent0_cname]) 410 | 411 | # Avoid "None" fields 412 | for idx in range(total): 413 | if examples[sent0_cname][idx] is None: 414 | examples[sent0_cname][idx] = " " 415 | if examples[sent1_cname][idx] is None: 416 | examples[sent1_cname][idx] = " " 417 | 418 | sentences = examples[sent0_cname] + examples[sent1_cname] 419 | 420 | # If hard negative exists 421 | if sent2_cname is not None: 422 | for idx in range(total): 423 | if examples[sent2_cname][idx] is None: 424 | examples[sent2_cname][idx] = " " 425 | sentences += examples[sent2_cname] 426 | 427 | sent_features = tokenizer( 428 | sentences, 429 | max_length=data_args.max_seq_length, 430 | truncation=True, 431 | padding="max_length" if data_args.pad_to_max_length else False, 432 | ) 433 | 434 | features = {} 435 | if sent2_cname is not None: 436 | for key in sent_features: 437 | features[key] = [[sent_features[key][i], sent_features[key][i+total], sent_features[key][i+total*2]] for i in range(total)] 438 | else: 439 | for key in sent_features: 440 | features[key] = [[sent_features[key][i], sent_features[key][i+total]] for i in range(total)] 441 | 442 | return features 443 | 444 | if training_args.do_train: 445 | train_dataset = datasets["train"].map( 446 | prepare_features, 447 | batched=True, 448 | num_proc=data_args.preprocessing_num_workers, 449 | remove_columns=column_names, 450 | load_from_cache_file=not data_args.overwrite_cache, 451 | ) 452 | 453 | # Data collator 454 | @dataclass 455 | class OurDataCollatorWithPadding: 456 | 457 | tokenizer: PreTrainedTokenizerBase 458 | padding: Union[bool, str, PaddingStrategy] = True 459 
| max_length: Optional[int] = None 460 | pad_to_multiple_of: Optional[int] = None 461 | mlm: bool = True 462 | mlm_probability: float = data_args.mlm_probability 463 | 464 | def __call__(self, features: List[Dict[str, Union[List[int], List[List[int]], torch.Tensor]]]) -> Dict[str, torch.Tensor]: 465 | special_keys = ['input_ids', 'attention_mask', 'token_type_ids', 'mlm_input_ids', 'mlm_labels'] 466 | bs = len(features) 467 | if bs > 0: 468 | num_sent = len(features[0]['input_ids']) 469 | else: 470 | return 471 | flat_features = [] 472 | for feature in features: 473 | for i in range(num_sent): 474 | flat_features.append({k: feature[k][i] if k in special_keys else feature[k] for k in feature}) 475 | 476 | batch = self.tokenizer.pad( 477 | flat_features, 478 | padding=self.padding, 479 | max_length=self.max_length, 480 | pad_to_multiple_of=self.pad_to_multiple_of, 481 | return_tensors="pt", 482 | ) 483 | if model_args.do_mlm: 484 | batch["mlm_input_ids"], batch["mlm_labels"] = self.mask_tokens(batch["input_ids"]) 485 | 486 | batch = {k: batch[k].view(bs, num_sent, -1) if k in special_keys else batch[k].view(bs, num_sent, -1)[:, 0] for k in batch} 487 | 488 | if "label" in batch: 489 | batch["labels"] = batch["label"] 490 | del batch["label"] 491 | if "label_ids" in batch: 492 | batch["labels"] = batch["label_ids"] 493 | del batch["label_ids"] 494 | 495 | return batch 496 | 497 | def mask_tokens( 498 | self, inputs: torch.Tensor, special_tokens_mask: Optional[torch.Tensor] = None 499 | ) -> Tuple[torch.Tensor, torch.Tensor]: 500 | """ 501 | Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. 
502 | """ 503 | labels = inputs.clone() 504 | # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`) 505 | probability_matrix = torch.full(labels.shape, self.mlm_probability) 506 | if special_tokens_mask is None: 507 | special_tokens_mask = [ 508 | self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist() 509 | ] 510 | special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool) 511 | else: 512 | special_tokens_mask = special_tokens_mask.bool() 513 | 514 | probability_matrix.masked_fill_(special_tokens_mask, value=0.0) 515 | masked_indices = torch.bernoulli(probability_matrix).bool() 516 | labels[~masked_indices] = -100 # We only compute loss on masked tokens 517 | 518 | # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) 519 | indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices 520 | inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token) 521 | 522 | # 10% of the time, we replace masked input tokens with random word 523 | indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced 524 | random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long) 525 | inputs[indices_random] = random_words[indices_random] 526 | 527 | # The rest of the time (10% of the time) we keep the masked input tokens unchanged 528 | return inputs, labels 529 | 530 | data_collator = default_data_collator if data_args.pad_to_max_length else OurDataCollatorWithPadding(tokenizer) 531 | 532 | trainer = CLTrainer( 533 | model=model, 534 | args=training_args, 535 | train_dataset=train_dataset if training_args.do_train else None, 536 | tokenizer=tokenizer, 537 | data_collator=data_collator, 538 | ) 539 | trainer.model_args = model_args 540 | 541 | # Training 542 | if training_args.do_train: 543 | model_path = ( 544 | 
model_args.model_name_or_path 545 | if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path)) 546 | else None 547 | ) 548 | train_result = trainer.train(model_path=model_path) 549 | trainer.save_model() # Saves the tokenizer too for easy upload 550 | 551 | output_train_file = os.path.join(training_args.output_dir, "train_results.txt") 552 | if trainer.is_world_process_zero(): 553 | with open(output_train_file, "w") as writer: 554 | logger.info("***** Train results *****") 555 | for key, value in sorted(train_result.metrics.items()): 556 | logger.info(f" {key} = {value}") 557 | writer.write(f"{key} = {value}\n") 558 | 559 | # Need to save the state, since Trainer.save_model saves only the tokenizer with the model 560 | trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json")) 561 | 562 | # Evaluation 563 | results = {} 564 | if training_args.do_eval: 565 | logger.info("*** Evaluate ***") 566 | results = trainer.evaluate(eval_senteval_transfer=True) 567 | 568 | output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt") 569 | if trainer.is_world_process_zero(): 570 | with open(output_eval_file, "w") as writer: 571 | logger.info("***** Eval results *****") 572 | for key, value in sorted(results.items()): 573 | logger.info(f" {key} = {value}") 574 | writer.write(f"{key} = {value}\n") 575 | 576 | return results 577 | 578 | def _mp_fn(index): 579 | # For xla_spawn (TPUs) 580 | main() 581 | 582 | 583 | if __name__ == "__main__": 584 | main() 585 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import os 4 | os.environ["WANDB_DISABLED"] = "true" 5 | import sys 6 | from dataclasses import dataclass, field 7 | from typing import Optional, Union, List, Dict, Tuple 8 | import torch 9 | import collections 10 | import random 11 | 
12 | from datasets import load_dataset 13 | 14 | import transformers 15 | from transformers import ( 16 | CONFIG_MAPPING, 17 | MODEL_FOR_MASKED_LM_MAPPING, 18 | AutoConfig, 19 | AutoModelForMaskedLM, 20 | AutoModelForSequenceClassification, 21 | AutoTokenizer, 22 | DataCollatorForLanguageModeling, 23 | DataCollatorWithPadding, 24 | HfArgumentParser, 25 | Trainer, 26 | TrainingArguments, 27 | default_data_collator, 28 | set_seed, 29 | EvalPrediction, 30 | BertModel, 31 | BertForPreTraining, 32 | RobertaModel 33 | ) 34 | from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTrainedTokenizerBase 35 | from transformers.trainer_utils import is_main_process 36 | from transformers.data.data_collator import DataCollatorForLanguageModeling 37 | from transformers.file_utils import cached_property, torch_required, is_torch_available, is_torch_tpu_available 38 | from simcse.models import RobertaForCL, BertForCL 39 | from simcse.trainers import CLTrainer 40 | 41 | logger = logging.getLogger(__name__) 42 | MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys()) 43 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 44 | 45 | @dataclass 46 | class ModelArguments: 47 | """ 48 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. 49 | """ 50 | 51 | # Huggingface's original arguments 52 | model_name_or_path: Optional[str] = field( 53 | default=None, 54 | metadata={ 55 | "help": "The model checkpoint for weights initialization." 56 | "Don't set if you want to train a model from scratch." 57 | }, 58 | ) 59 | c_model_name_or_path: Optional[str] = field( 60 | default=None, 61 | metadata={ 62 | "help": "The model checkpoint for weights initialization." 63 | "Don't set if you want to train a model from scratch." 
64 | }, 65 | ) 66 | model_type: Optional[str] = field( 67 | default=None, 68 | metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}, 69 | ) 70 | config_name: Optional[str] = field( 71 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 72 | ) 73 | tokenizer_name: Optional[str] = field( 74 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 75 | ) 76 | cache_dir: Optional[str] = field( 77 | default=None, 78 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, 79 | ) 80 | use_fast_tokenizer: bool = field( 81 | default=True, 82 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 83 | ) 84 | model_revision: str = field( 85 | default="main", 86 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 87 | ) 88 | use_auth_token: bool = field( 89 | default=False, 90 | metadata={ 91 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " 92 | "with private models)." 93 | }, 94 | ) 95 | 96 | # SimCSE's arguments 97 | temp: float = field( 98 | default=0.05, 99 | metadata={ 100 | "help": "Temperature for softmax." 101 | } 102 | ) 103 | pooler_type: str = field( 104 | default="cls", 105 | metadata={ 106 | "help": "What kind of pooler to use (cls, cls_before_pooler, avg, avg_top2, avg_first_last)." 107 | } 108 | ) 109 | hard_negative_weight: float = field( 110 | default=0, 111 | metadata={ 112 | "help": "The **logit** of weight for hard negatives (only effective if hard negatives are used)." 113 | } 114 | ) 115 | do_mlm: bool = field( 116 | default=False, 117 | metadata={ 118 | "help": "Whether to use MLM auxiliary objective." 
119 | } 120 | ) 121 | mlm_weight: float = field( 122 | default=0.1, 123 | metadata={ 124 | "help": "Weight for MLM auxiliary objective (only effective if --do_mlm)." 125 | } 126 | ) 127 | mlp_only_train: bool = field( 128 | default=False, 129 | metadata={ 130 | "help": "Use MLP only during training" 131 | } 132 | ) 133 | phi: float = field( 134 | default=0.85, 135 | metadata={ 136 | "help": "Weight for instance weighting." 137 | } 138 | ) 139 | noise_times: float = field( 140 | default=1, 141 | metadata={ 142 | "help": "Weight for noise-based negatives number." 143 | } 144 | ) 145 | pgd: int = field( 146 | default=4, 147 | metadata={ 148 | "help": "Weight for PGD turns." 149 | } 150 | ) 151 | is_base: bool = field( 152 | default=True, 153 | metadata={ 154 | "help": "Is the base model? or large ones?" 155 | } 156 | ) 157 | 158 | 159 | @dataclass 160 | class DataTrainingArguments: 161 | """ 162 | Arguments pertaining to what data we are going to input our model for training and eval. 163 | """ 164 | 165 | # Huggingface's original arguments. 
166 | dataset_name: Optional[str] = field( 167 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 168 | ) 169 | dataset_config_name: Optional[str] = field( 170 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 171 | ) 172 | overwrite_cache: bool = field( 173 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 174 | ) 175 | validation_split_percentage: Optional[int] = field( 176 | default=5, 177 | metadata={ 178 | "help": "The percentage of the train set used as validation set in case there's no validation split" 179 | }, 180 | ) 181 | preprocessing_num_workers: Optional[int] = field( 182 | default=None, 183 | metadata={"help": "The number of processes to use for the preprocessing."}, 184 | ) 185 | 186 | # SimCSE's arguments 187 | train_file: Optional[str] = field( 188 | default=None, 189 | metadata={"help": "The training data file (.txt or .csv)."} 190 | ) 191 | max_seq_length: Optional[int] = field( 192 | default=32, 193 | metadata={ 194 | "help": "The maximum total input sequence length after tokenization. Sequences longer " 195 | "than this will be truncated." 196 | }, 197 | ) 198 | pad_to_max_length: bool = field( 199 | default=False, 200 | metadata={ 201 | "help": "Whether to pad all samples to `max_seq_length`. " 202 | "If False, will pad the samples dynamically when batching to the maximum length in the batch." 
203 | }, 204 | ) 205 | mlm_probability: float = field( 206 | default=0.15, 207 | metadata={"help": "Ratio of tokens to mask for MLM (only effective if --do_mlm)"} 208 | ) 209 | 210 | def __post_init__(self): 211 | if self.dataset_name is None and self.train_file is None and self.validation_file is None: 212 | raise ValueError("Need either a dataset name or a training/validation file.") 213 | else: 214 | if self.train_file is not None: 215 | extension = self.train_file.split(".")[-1] 216 | assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file." 217 | 218 | 219 | @dataclass 220 | class OurTrainingArguments(TrainingArguments): 221 | # Evaluation 222 | ## By default, we evaluate STS (dev) during training (for selecting best checkpoints) and evaluate 223 | ## both STS and transfer tasks (dev) at the end of training. Using --eval_transfer will allow evaluating 224 | ## both STS and transfer tasks (dev) during training. 225 | eval_transfer: bool = field( 226 | default=False, 227 | metadata={"help": "Evaluate transfer task dev sets (in validation)."} 228 | ) 229 | 230 | gradient_accumulation_steps: int = field( 231 | default=1, 232 | metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."}, 233 | ) 234 | 235 | @cached_property 236 | @torch_required 237 | def _setup_devices(self) -> "torch.device": 238 | logger.info("PyTorch: setting up devices") 239 | if self.no_cuda: 240 | device = torch.device("cpu") 241 | self._n_gpu = 0 242 | elif is_torch_tpu_available(): 243 | device = xm.xla_device() 244 | self._n_gpu = 0 245 | elif self.local_rank == -1: 246 | # if n_gpu is > 1 we'll use nn.DataParallel. 247 | # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0` 248 | # Explicitly set CUDA to the first (index 0) CUDA device, otherwise `set_device` will 249 | # trigger an error that a device index is missing. 
Index 0 takes into account the 250 | # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0` 251 | # will use the first GPU in that env, i.e. GPU#1 252 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 253 | # Sometimes the line in the postinit has not been run before we end up here, so just checking we're not at 254 | # the default value. 255 | self._n_gpu = torch.cuda.device_count() 256 | else: 257 | # Here, we'll use torch.distributed. 258 | # Initializes the distributed backend which will take care of synchronizing nodes/GPUs 259 | # 260 | # deepspeed performs its own DDP internally, and requires the program to be started with: 261 | # deepspeed ./program.py 262 | # rather than: 263 | # python -m torch.distributed.launch --nproc_per_node=2 ./program.py 264 | if self.deepspeed: 265 | from .integrations import is_deepspeed_available 266 | 267 | if not is_deepspeed_available(): 268 | raise ImportError("--deepspeed requires deepspeed: `pip install deepspeed`.") 269 | import deepspeed 270 | 271 | deepspeed.init_distributed() 272 | else: 273 | torch.distributed.init_process_group(backend="nccl") 274 | device = torch.device("cuda", self.local_rank) 275 | self._n_gpu = 1 276 | 277 | if device.type == "cuda": 278 | torch.cuda.set_device(device) 279 | 280 | return device 281 | 282 | 283 | def main(): 284 | # See all possible arguments in src/transformers/training_args.py 285 | # or by passing the --help flag to this script. 286 | # We now keep distinct sets of args, for a cleaner separation of concerns. 287 | 288 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, OurTrainingArguments)) 289 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 290 | # If we pass only one argument to the script and it's the path to a json file, 291 | # let's parse it to get our arguments. 
292 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 293 | else: 294 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 295 | 296 | if ( 297 | os.path.exists(training_args.output_dir) 298 | and os.listdir(training_args.output_dir) 299 | and training_args.do_train 300 | and not training_args.overwrite_output_dir 301 | ): 302 | raise ValueError( 303 | f"Output directory ({training_args.output_dir}) already exists and is not empty." 304 | "Use --overwrite_output_dir to overcome." 305 | ) 306 | 307 | # Setup logging 308 | logging.basicConfig( 309 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 310 | datefmt="%m/%d/%Y %H:%M:%S", 311 | level=logging.INFO if is_main_process(training_args.local_rank) else logging.WARN, 312 | ) 313 | 314 | # Log on each process the small summary: 315 | logger.warning( 316 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 317 | + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 318 | ) 319 | # Set the verbosity to info of the Transformers logger (on main process only): 320 | if is_main_process(training_args.local_rank): 321 | transformers.utils.logging.set_verbosity_info() 322 | transformers.utils.logging.enable_default_handler() 323 | transformers.utils.logging.enable_explicit_format() 324 | logger.info("Training/evaluation parameters %s", training_args) 325 | 326 | # Set seed before initializing model. 
327 | set_seed(training_args.seed) 328 | 329 | # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) 330 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ 331 | # (the dataset will be downloaded automatically from the datasets Hub 332 | # 333 | # For CSV/JSON files, this script will use the column called 'text' or the first column. You can easily tweak this 334 | # behavior (see below) 335 | # 336 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 337 | # download the dataset. 338 | data_files = {} 339 | if data_args.train_file is not None: 340 | data_files["train"] = data_args.train_file 341 | extension = data_args.train_file.split(".")[-1] 342 | if extension == "txt": 343 | extension = "text" 344 | if extension == "csv": 345 | datasets = load_dataset(extension, data_files=data_files, cache_dir="./data/", delimiter="\t" if "tsv" in data_args.train_file else ",") 346 | else: 347 | datasets = load_dataset(extension, data_files=data_files, cache_dir="./data/") 348 | 349 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at 350 | # https://huggingface.co/docs/datasets/loading_datasets.html. 351 | 352 | # Load pretrained model and tokenizer 353 | # 354 | # Distributed training: 355 | # The .from_pretrained methods guarantee that only one local process can concurrently 356 | # download model & vocab. 
357 | config_kwargs = { 358 | "cache_dir": model_args.cache_dir, 359 | "revision": model_args.model_revision, 360 | "use_auth_token": True if model_args.use_auth_token else None, 361 | } 362 | if model_args.config_name: 363 | config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs) 364 | elif model_args.model_name_or_path: 365 | config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs) 366 | else: 367 | config = CONFIG_MAPPING[model_args.model_type]() 368 | logger.warning("You are instantiating a new config instance from scratch.") 369 | 370 | tokenizer_kwargs = { 371 | "cache_dir": model_args.cache_dir, 372 | "use_fast": model_args.use_fast_tokenizer, 373 | "revision": model_args.model_revision, 374 | "use_auth_token": True if model_args.use_auth_token else None, 375 | } 376 | if model_args.tokenizer_name: 377 | tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs) 378 | elif model_args.model_name_or_path: 379 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs) 380 | else: 381 | raise ValueError( 382 | "You are instantiating a new tokenizer from scratch. This is not supported by this script." 383 | "You can do it from another script, save it, and load it from here, using --tokenizer_name." 
384 | ) 385 | 386 | if model_args.model_name_or_path: 387 | if 'roberta' in model_args.model_name_or_path: 388 | model = RobertaForCL.from_pretrained( 389 | model_args.model_name_or_path, 390 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 391 | config=config, 392 | cache_dir=model_args.cache_dir, 393 | revision=model_args.model_revision, 394 | use_auth_token=True if model_args.use_auth_token else None, 395 | model_args=model_args 396 | ) 397 | fix_bert = RobertaModel.from_pretrained(model_args.c_model_name_or_path) 398 | model.fix_bert.load_state_dict(fix_bert.state_dict()) 399 | elif 'bert' in model_args.model_name_or_path: 400 | model = BertForCL.from_pretrained( 401 | model_args.model_name_or_path, 402 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 403 | config=config, 404 | cache_dir=model_args.cache_dir, 405 | revision=model_args.model_revision, 406 | use_auth_token=True if model_args.use_auth_token else None, 407 | model_args=model_args 408 | ) 409 | fix_bert = BertModel.from_pretrained(model_args.c_model_name_or_path) 410 | model.fix_bert.load_state_dict(fix_bert.state_dict()) 411 | if model_args.do_mlm: 412 | pretrained_model = BertForPreTraining.from_pretrained(model_args.model_name_or_path) 413 | model.lm_head.load_state_dict(pretrained_model.cls.predictions.state_dict()) 414 | else: 415 | raise NotImplementedError 416 | else: 417 | raise NotImplementedError 418 | logger.info("Training new model from scratch") 419 | model = AutoModelForMaskedLM.from_config(config) 420 | 421 | model.resize_token_embeddings(len(tokenizer)) 422 | 423 | # Prepare features 424 | column_names = datasets["train"].column_names 425 | sent2_cname = None 426 | if len(column_names) == 2: 427 | # Pair datasets 428 | sent0_cname = column_names[0] 429 | sent1_cname = column_names[1] 430 | elif len(column_names) == 3: 431 | # Pair datasets with hard negatives 432 | sent0_cname = column_names[0] 433 | sent1_cname = column_names[1] 434 | sent2_cname = column_names[2] 
435 | elif len(column_names) == 1: 436 | # Unsupervised datasets 437 | sent0_cname = column_names[0] 438 | sent1_cname = column_names[0] 439 | else: 440 | raise NotImplementedError 441 | 442 | def prepare_features(examples): 443 | # padding = longest (default) 444 | # If no sentence in the batch exceed the max length, then use 445 | # the max sentence length in the batch, otherwise use the 446 | # max sentence length in the argument and truncate those that 447 | # exceed the max length. 448 | # padding = max_length (when pad_to_max_length, for pressure test) 449 | # All sentences are padded/truncated to data_args.max_seq_length. 450 | total = len(examples[sent0_cname]) 451 | 452 | # Avoid "None" fields 453 | for idx in range(total): 454 | if examples[sent0_cname][idx] is None: 455 | examples[sent0_cname][idx] = " " 456 | if examples[sent1_cname][idx] is None: 457 | examples[sent1_cname][idx] = " " 458 | 459 | sentences = examples[sent0_cname] + examples[sent1_cname] 460 | 461 | # If hard negative exists 462 | if sent2_cname is not None: 463 | for idx in range(total): 464 | if examples[sent2_cname][idx] is None: 465 | examples[sent2_cname][idx] = " " 466 | sentences += examples[sent2_cname] 467 | 468 | sent_features = tokenizer( 469 | sentences, 470 | max_length=data_args.max_seq_length, 471 | truncation=True, 472 | padding="max_length" if data_args.pad_to_max_length else False, 473 | ) 474 | 475 | features = {} 476 | if sent2_cname is not None: 477 | for key in sent_features: 478 | features[key] = [[sent_features[key][i], sent_features[key][i+total], sent_features[key][i+total*2]] for i in range(total)] 479 | else: 480 | for key in sent_features: 481 | features[key] = [[sent_features[key][i], sent_features[key][i+total]] for i in range(total)] 482 | 483 | return features 484 | 485 | if training_args.do_train: 486 | train_dataset = datasets["train"].map( 487 | prepare_features, 488 | batched=True, 489 | num_proc=data_args.preprocessing_num_workers, 490 | 
remove_columns=column_names, 491 | load_from_cache_file=not data_args.overwrite_cache, 492 | ) 493 | 494 | # Data collator 495 | @dataclass 496 | class OurDataCollatorWithPadding: 497 | 498 | tokenizer: PreTrainedTokenizerBase 499 | padding: Union[bool, str, PaddingStrategy] = True 500 | max_length: Optional[int] = None 501 | pad_to_multiple_of: Optional[int] = None 502 | mlm: bool = True 503 | mlm_probability: float = data_args.mlm_probability 504 | 505 | def __call__(self, features: List[Dict[str, Union[List[int], List[List[int]], torch.Tensor]]]) -> Dict[str, torch.Tensor]: 506 | special_keys = ['input_ids', 'attention_mask', 'token_type_ids', 'mlm_input_ids', 'mlm_labels'] 507 | bs = len(features) 508 | if bs > 0: 509 | num_sent = len(features[0]['input_ids']) 510 | else: 511 | return 512 | flat_features = [] 513 | for feature in features: 514 | for i in range(num_sent): 515 | flat_features.append({k: feature[k][i] if k in special_keys else feature[k] for k in feature}) 516 | 517 | batch = self.tokenizer.pad( 518 | flat_features, 519 | padding=self.padding, 520 | max_length=self.max_length, 521 | pad_to_multiple_of=self.pad_to_multiple_of, 522 | return_tensors="pt", 523 | ) 524 | if model_args.do_mlm: 525 | batch["mlm_input_ids"], batch["mlm_labels"] = self.mask_tokens(batch["input_ids"]) 526 | 527 | batch = {k: batch[k].view(bs, num_sent, -1) if k in special_keys else batch[k].view(bs, num_sent, -1)[:, 0] for k in batch} 528 | 529 | if "label" in batch: 530 | batch["labels"] = batch["label"] 531 | del batch["label"] 532 | if "label_ids" in batch: 533 | batch["labels"] = batch["label_ids"] 534 | del batch["label_ids"] 535 | 536 | return batch 537 | 538 | def mask_tokens( 539 | self, inputs: torch.Tensor, special_tokens_mask: Optional[torch.Tensor] = None 540 | ) -> Tuple[torch.Tensor, torch.Tensor]: 541 | """ 542 | Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. 
543 | """ 544 | labels = inputs.clone() 545 | # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`) 546 | probability_matrix = torch.full(labels.shape, self.mlm_probability) 547 | if special_tokens_mask is None: 548 | special_tokens_mask = [ 549 | self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist() 550 | ] 551 | special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool) 552 | else: 553 | special_tokens_mask = special_tokens_mask.bool() 554 | 555 | probability_matrix.masked_fill_(special_tokens_mask, value=0.0) 556 | masked_indices = torch.bernoulli(probability_matrix).bool() 557 | labels[~masked_indices] = -100 # We only compute loss on masked tokens 558 | 559 | # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) 560 | indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices 561 | inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token) 562 | 563 | # 10% of the time, we replace masked input tokens with random word 564 | indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced 565 | random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long) 566 | inputs[indices_random] = random_words[indices_random] 567 | 568 | # The rest of the time (10% of the time) we keep the masked input tokens unchanged 569 | return inputs, labels 570 | 571 | data_collator = default_data_collator if data_args.pad_to_max_length else OurDataCollatorWithPadding(tokenizer) 572 | 573 | trainer = CLTrainer( 574 | model=model, 575 | args=training_args, 576 | train_dataset=train_dataset if training_args.do_train else None, 577 | tokenizer=tokenizer, 578 | data_collator=data_collator, 579 | ) 580 | trainer.model_args = model_args 581 | 582 | # Training 583 | if training_args.do_train: 584 | model_path = ( 585 | 
model_args.model_name_or_path 586 | if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path)) 587 | else None 588 | ) 589 | train_result = trainer.train(model_path=model_path) 590 | trainer.save_model() # Saves the tokenizer too for easy upload 591 | 592 | output_train_file = os.path.join(training_args.output_dir, "train_results.txt") 593 | if trainer.is_world_process_zero(): 594 | with open(output_train_file, "w") as writer: 595 | logger.info("***** Train results *****") 596 | for key, value in sorted(train_result.metrics.items()): 597 | logger.info(f" {key} = {value}") 598 | writer.write(f"{key} = {value}\n") 599 | 600 | # Need to save the state, since Trainer.save_model saves only the tokenizer with the model 601 | trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json")) 602 | 603 | # Evaluation 604 | results = {} 605 | if training_args.do_eval: 606 | logger.info("*** Evaluate ***") 607 | results = trainer.evaluate(eval_senteval_transfer=True) 608 | 609 | output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt") 610 | if trainer.is_world_process_zero(): 611 | with open(output_eval_file, "w") as writer: 612 | logger.info("***** Eval results *****") 613 | for key, value in sorted(results.items()): 614 | logger.info(f" {key} = {value}") 615 | writer.write(f"{key} = {value}\n") 616 | 617 | return results 618 | 619 | def _mp_fn(index): 620 | # For xla_spawn (TPUs) 621 | main() 622 | 623 | 624 | if __name__ == "__main__": 625 | main() 626 | -------------------------------------------------------------------------------- /simcse/trainers.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import inspect 3 | import math 4 | import sys 5 | import os 6 | import re 7 | import json 8 | import shutil 9 | import time 10 | import warnings 11 | from pathlib import Path 12 | import importlib.util 13 | from packaging import version 
from transformers import Trainer
from transformers.modeling_utils import PreTrainedModel
from transformers.training_args import ParallelMode, TrainingArguments
from transformers.utils import logging
from transformers.trainer_utils import (
    PREFIX_CHECKPOINT_DIR,
    BestRun,
    EvalPrediction,
    HPSearchBackend,
    PredictionOutput,
    TrainOutput,
    default_compute_objective,
    default_hp_space,
    set_seed,
    speed_metrics,
    ShardedDDPOption
)
from transformers.file_utils import (
    WEIGHTS_NAME,
    is_apex_available,
    is_datasets_available,
    is_in_notebook,
    is_torch_tpu_available,
    is_sagemaker_mp_enabled,
)
from transformers.trainer_callback import (
    CallbackHandler,
    DefaultFlowCallback,
    PrinterCallback,
    ProgressCallback,
    TrainerCallback,
    TrainerControl,
    TrainerState,
)
from transformers.trainer_pt_utils import (
    reissue_pt_warnings, get_parameter_names
)

from transformers.utils import logging
from transformers.data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
import torch
import torch.nn as nn
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, SequentialSampler

if is_torch_tpu_available():
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
    import torch_xla.distributed.parallel_loader as pl

if is_apex_available():
    from apex import amp

if version.parse(torch.__version__) >= version.parse("1.6"):
    _is_native_amp_available = True
    from torch.cuda.amp import autocast

if is_datasets_available():
    import datasets

from transformers.optimization import Adafactor, AdamW, get_scheduler
import copy
# Set path to SentEval
PATH_TO_SENTEVAL = './SentEval'
PATH_TO_DATA = './SentEval/data'

# Import SentEval
sys.path.insert(0, PATH_TO_SENTEVAL)
import senteval
import numpy as np
from datetime import datetime
from filelock import FileLock

logger = logging.get_logger(__name__)


class CLTrainer(Trainer):
    """``transformers.Trainer`` subclass for contrastive sentence-embedding training.

    Differences from the stock Trainer implemented here:

    * ``evaluate`` scores the model with SentEval (STS-B / SICK-R dev sets and,
      optionally, the transfer tasks) instead of running a prediction loop.
    * ``_save_checkpoint`` keeps only the best model (written directly into
      ``args.output_dir``) when ``metric_for_best_model`` is configured.
    * ``create_optimizer`` leaves out every parameter whose name contains ``'fix'``.
    * ``train`` reloads the best checkpoint passing ``model_args=self.model_args``
      (that attribute is not set in this class; presumably the caller assigns it).
    """

    def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
        eval_senteval_transfer: bool = False,
    ) -> Dict[str, float]:
        """Evaluate the sentence embeddings with SentEval and log the metrics.

        ``eval_dataset``, ``ignore_keys`` and ``metric_key_prefix`` are accepted only
        for signature compatibility with ``Trainer.evaluate``; they are ignored.

        Args:
            eval_senteval_transfer: Also run the transfer tasks (MR, CR, SUBJ, MPQA,
                SST2, TREC, MRPC). They also run when ``self.args.eval_transfer`` is set.

        Returns:
            Dict with ``eval_stsb_spearman``, ``eval_sickr_spearman``,
            ``eval_avg_sts``, ``eval_align_loss`` and ``eval_uniform_loss``; when the
            transfer tasks run, also ``eval_<TASK>`` (dev accuracy) per task and
            ``eval_avg_transfer``.
        """
        sts_tasks = ['STSBenchmark', 'SICKRelatedness']
        transfer_tasks = ['MR', 'CR', 'SUBJ', 'MPQA', 'SST2', 'TREC', 'MRPC']

        # SentEval prepare and batcher
        def prepare(params, samples):
            # Nothing to build: the HF tokenizer handles the vocabulary.
            return

        def batcher(params, batch):
            # SentEval datasets yield sentences as word lists; re-join them for the tokenizer.
            sentences = [' '.join(s) for s in batch]
            batch = self.tokenizer.batch_encode_plus(
                sentences,
                return_tensors='pt',
                padding=True,
            )
            for k in batch:
                batch[k] = batch[k].to(self.args.device)
            with torch.no_grad():
                outputs = self.model(**batch, output_hidden_states=True, return_dict=True, sent_emb=True)
                pooler_output = outputs.pooler_output
            return pooler_output.cpu()

        # Set params for SentEval (fastmode)
        params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5}
        params['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128,
                                'tenacity': 3, 'epoch_size': 2}

        se = senteval.engine.SE(params, batcher, prepare)
        run_transfer = eval_senteval_transfer or self.args.eval_transfer
        tasks = sts_tasks + transfer_tasks if run_transfer else sts_tasks
        self.model.eval()
        results = se.eval(tasks)

        stsb_dev = results['STSBenchmark']['dev']
        stsb_spearman = stsb_dev['spearman'][0]
        sickr_spearman = results['SICKRelatedness']['dev']['spearman'][0]

        metrics = {
            "eval_stsb_spearman": stsb_spearman,
            "eval_sickr_spearman": sickr_spearman,
            "eval_avg_sts": (stsb_spearman + sickr_spearman) / 2,
            "eval_align_loss": stsb_dev['align_loss'],
            "eval_uniform_loss": stsb_dev['uniform_loss'],
        }

        if run_transfer:
            avg_transfer = 0
            for task in transfer_tasks:
                avg_transfer += results[task]['devacc']
                metrics['eval_{}'.format(task)] = results[task]['devacc']
            # Derived from the task list instead of a hard-coded 7.
            metrics['eval_avg_transfer'] = avg_transfer / len(transfer_tasks)

        self.log(metrics)
        return metrics

    def _cl_save_optimizer_and_scheduler(self, output_dir):
        """Write DeepSpeed / optimizer / LR-scheduler states into ``output_dir``.

        Shared by both saving paths of ``_save_checkpoint``.
        """
        if self.deepspeed:
            self.deepspeed.save_checkpoint(output_dir)

        if is_torch_tpu_available():
            xm.rendezvous("saving_optimizer_states")
            xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
            with warnings.catch_warnings(record=True) as caught_warnings:
                xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
            # Reissue outside the recording context, otherwise the warnings are re-recorded and never shown.
            reissue_pt_warnings(caught_warnings)
        elif self.is_world_process_zero() and not self.deepspeed:
            # deepspeed.save_checkpoint above saves model/optim/sched
            torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
            with warnings.catch_warnings(record=True) as caught_warnings:
                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
            reissue_pt_warnings(caught_warnings)

    def _save_checkpoint(self, model, trial, metrics=None):
        """Save a checkpoint, keeping only the best one when a metric is configured.

        Compared to the original implementation, when both ``metrics`` and
        ``args.metric_for_best_model`` are available, the model plus optimizer and
        scheduler states are written straight into ``args.output_dir`` and only
        when the tracked metric improved. The trainer state is deliberately not
        saved on that path. Otherwise a regular ``checkpoint-<global_step>`` folder
        is written, the trainer state is saved and old checkpoints are rotated.
        """
        # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
        # want to save.

        # Determine the new best metric / best model checkpoint
        if metrics is not None and self.args.metric_for_best_model is not None:
            metric_to_check = self.args.metric_for_best_model
            if not metric_to_check.startswith("eval_"):
                metric_to_check = f"eval_{metric_to_check}"
            metric_value = metrics[metric_to_check]

            operator = np.greater if self.args.greater_is_better else np.less
            is_new_best = (
                self.state.best_metric is None
                or self.state.best_model_checkpoint is None
                or operator(metric_value, self.state.best_metric)
            )
            if not is_new_best:
                return

            output_dir = self.args.output_dir
            self.state.best_metric = metric_value
            self.state.best_model_checkpoint = output_dir

            # Only save model when it is the best one
            self.save_model(output_dir)
            self._cl_save_optimizer_and_scheduler(output_dir)
            return

        # Save model checkpoint
        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

        if self.hp_search_backend is not None and trial is not None:
            if self.hp_search_backend == HPSearchBackend.OPTUNA:
                run_id = trial.number
            else:
                from ray import tune

                run_id = tune.get_trial_id()
            run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
            output_dir = os.path.join(self.args.output_dir, run_name, checkpoint_folder)
        else:
            output_dir = os.path.join(self.args.output_dir, checkpoint_folder)

        self.store_flos()

        self.save_model(output_dir)
        self._cl_save_optimizer_and_scheduler(output_dir)

        # Save the Trainer state
        if self.is_world_process_zero():
            self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))

        # Maybe delete some older checkpoints.
        if self.is_world_process_zero():
            self._rotate_checkpoints(use_mtime=True)

    def create_optimizer(self):
        """
        Setup the optimizer.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through :obj:`optimizers`, or subclass and override this method in a subclass.

        Parameters whose name contains ``'fix'`` are excluded from optimization.
        Biases and LayerNorm parameters get no weight decay.
        """
        if self.optimizer is None:
            decay_parameters = get_parameter_names(self.model, [nn.LayerNorm])
            # A set keeps the membership tests below O(1) per parameter.
            decay_parameters = {name for name in decay_parameters if "bias" not in name}
            trainable = [(n, p) for n, p in self.model.named_parameters() if 'fix' not in n]
            optimizer_grouped_parameters = [
                {
                    "params": [p for n, p in trainable if n in decay_parameters],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [p for n, p in trainable if n not in decay_parameters],
                    "weight_decay": 0.0,
                },
            ]
            if self.args.adafactor:
                optimizer_cls = Adafactor
                optimizer_kwargs = {"scale_parameter": False, "relative_step": False}
            else:
                optimizer_cls = AdamW
                optimizer_kwargs = {
                    "betas": (self.args.adam_beta1, self.args.adam_beta2),
                    "eps": self.args.adam_epsilon,
                }
            optimizer_kwargs["lr"] = self.args.learning_rate
            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
                # `OSS` was never imported in this file. Resolve fairscale's OSS from
                # transformers.trainer, which imports it when fairscale is installed.
                from transformers.trainer import OSS

                self.optimizer = OSS(
                    params=optimizer_grouped_parameters,
                    optim=optimizer_cls,
                    **optimizer_kwargs,
                )
            else:
                self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

        if is_sagemaker_mp_enabled():
            # `smp` was never imported in this file. transformers.trainer imports it
            # when SageMaker model parallelism is enabled.
            from transformers.trainer import smp

            self.optimizer = smp.DistributedOptimizer(self.optimizer)

        return self.optimizer

    def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", Dict[str, Any]] = None):
        """
        Main training entry point.

        Args:
            model_path (:obj:`str`, `optional`):
                Local path to the model if the model to train has been instantiated from a local path. If present,
                training will resume from the optimizer/scheduler states loaded here.
            trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`):
                The trial run or the hyperparameter dictionary for hyperparameter search.

        Returns:
            ``TrainOutput(global_step, average training loss, metrics)``. The average
            loss is ``0.0`` when no optimization step was performed.

        The main difference between ours and Huggingface's original implementation is that we
        also load model_args when reloading best checkpoints for evaluation.
        """
        # This might change the seed so needs to run first.
        self._hp_search_setup(trial)

        # Model re-init
        if self.model_init is not None:
            # Seed must be set before instantiating the model when using model_init.
            set_seed(self.args.seed)

            model = self.call_model_init(trial)
            if not self.is_model_parallel:
                model = model.to(self.args.device)

            self.model = model
            self.model_wrapped = model

            # Reinitializes optimizer and scheduler
            self.optimizer, self.lr_scheduler = None, None

        # Keeping track whether we can can len() on the dataset or not
        train_dataset_is_sized = isinstance(self.train_dataset, collections.abc.Sized)

        # Data loader and number of training steps
        train_dataloader = self.get_train_dataloader()

        # Setting up training control variables:
        # number of training epochs: num_train_epochs
        # number of training steps per epoch: num_update_steps_per_epoch
        # total number of training steps to execute: max_steps
        if train_dataset_is_sized:
            num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps
            num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
            if self.args.max_steps > 0:
                max_steps = self.args.max_steps
                num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int(
                    self.args.max_steps % num_update_steps_per_epoch > 0
                )
            else:
                max_steps = math.ceil(self.args.num_train_epochs * num_update_steps_per_epoch)
                num_train_epochs = math.ceil(self.args.num_train_epochs)
        else:
            # see __init__. max_steps is set when the dataset has no __len__
            max_steps = self.args.max_steps
            num_train_epochs = 1
            num_update_steps_per_epoch = max_steps

        if self.args.deepspeed:
            # `init_deepspeed` was never imported in this file. Import it lazily so the
            # non-deepspeed path does not depend on it.
            from transformers.integrations import init_deepspeed

            model, optimizer, lr_scheduler = init_deepspeed(self, num_training_steps=max_steps)
            self.model = model.module
            self.model_wrapped = model  # will get further wrapped in DDP
            self.deepspeed = model  # DeepSpeedEngine object
            self.optimizer = optimizer
            self.lr_scheduler = lr_scheduler
        else:
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)

        self.state = TrainerState()
        self.state.is_hyper_param_search = trial is not None

        # Check if saved optimizer or scheduler states exist
        self._load_optimizer_and_scheduler(model_path)

        model = self.model_wrapped

        # Mixed precision training with apex (torch < 1.6)
        if self.use_apex:
            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)

        # Multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)
        elif self.args.local_rank != -1:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,
                find_unused_parameters=(
                    not getattr(model.config, "gradient_checkpointing", False)
                    if isinstance(model, PreTrainedModel)
                    else True
                ),
            )
            # find_unused_parameters breaks checkpointing as per
            # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021

        # for the rest of this function `model` is the outside model, whether it was wrapped or not
        if model is not self.model:
            self.model_wrapped = model

        # important: at this point:
        # self.model         is the Transformers Model
        # self.model_wrapped is DDP(Transformers Model), DDP(Deepspeed(Transformers Model)), etc.

        # Train!
        if is_torch_tpu_available():
            total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
        else:
            total_train_batch_size = (
                self.args.train_batch_size
                * self.args.gradient_accumulation_steps
                * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
            )

        num_examples = (
            self.num_examples(train_dataloader)
            if train_dataset_is_sized
            else total_train_batch_size * self.args.max_steps
        )

        logger.info("***** Running training *****")
        logger.info(f"  Num examples = {num_examples}")
        logger.info(f"  Num Epochs = {num_train_epochs}")
        logger.info(f"  Instantaneous batch size per device = {self.args.per_device_train_batch_size}")
        logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
        logger.info(f"  Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
        logger.info(f"  Total optimization steps = {max_steps}")

        self.state.epoch = 0
        start_time = time.time()
        epochs_trained = 0
        steps_trained_in_current_epoch = 0

        # Check if continuing training from a checkpoint
        if model_path and os.path.isfile(os.path.join(model_path, "trainer_state.json")):
            self.state = TrainerState.load_from_json(os.path.join(model_path, "trainer_state.json"))
            epochs_trained = self.state.global_step // num_update_steps_per_epoch
            if not self.args.ignore_data_skip:
                steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
                steps_trained_in_current_epoch *= self.args.gradient_accumulation_steps
            else:
                steps_trained_in_current_epoch = 0

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info(f"  Continuing training from epoch {epochs_trained}")
            logger.info(f"  Continuing training from global step {self.state.global_step}")
            if not self.args.ignore_data_skip:
                logger.info(
                    f"  Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} "
                    "batches in the first epoch."
                )

        # Update the references
        self.callback_handler.model = self.model
        self.callback_handler.optimizer = self.optimizer
        self.callback_handler.lr_scheduler = self.lr_scheduler
        self.callback_handler.train_dataloader = train_dataloader
        self.state.trial_name = self.hp_name(trial) if self.hp_name is not None else None
        if trial is not None:
            # `hp_params` was never imported in this file; only needed for hyperparameter search.
            from transformers.integrations import hp_params

            self.state.trial_params = hp_params(trial)
        else:
            self.state.trial_params = None
        # This should be the same if the state has been saved but in case the training arguments changed, it's safer
        # to set this after the load.
        self.state.max_steps = max_steps
        self.state.num_train_epochs = num_train_epochs
        self.state.is_local_process_zero = self.is_local_process_zero()
        self.state.is_world_process_zero = self.is_world_process_zero()

        # tr_loss is a tensor to avoid synchronization of TPUs through .item()
        tr_loss = torch.tensor(0.0).to(self.args.device)
        # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
        self._total_loss_scalar = 0.0
        self._globalstep_last_logged = 0
        self._total_flos = self.state.total_flos
        model.zero_grad()

        self.control = self.callback_handler.on_train_begin(self.args, self.state, self.control)

        # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
        if not self.args.ignore_data_skip:
            for epoch in range(epochs_trained):
                # We just need to begin an iteration to create the randomization of the sampler.
                for _ in train_dataloader:
                    break
        for epoch in range(epochs_trained, num_train_epochs):
            if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                train_dataloader.sampler.set_epoch(epoch)
            epoch_iterator = train_dataloader

            # Reset the past mems state at the beginning of each epoch if necessary.
            if self.args.past_index >= 0:
                self._past = None

            steps_in_epoch = len(train_dataloader) if train_dataset_is_sized else self.args.max_steps
            self.control = self.callback_handler.on_epoch_begin(self.args, self.state, self.control)

            assert train_dataset_is_sized, "currently we only support sized dataloader!"

            for step, inputs in enumerate(epoch_iterator):
                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

                if (step + 1) % self.args.gradient_accumulation_steps == 0:
                    self.control = self.callback_handler.on_step_begin(self.args, self.state, self.control)

                if ((step + 1) % self.args.gradient_accumulation_steps != 0) and self.args.local_rank != -1:
                    # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
                    with model.no_sync():
                        tr_loss += self.training_step(model, inputs)
                else:
                    tr_loss += self.training_step(model, inputs)
                self._total_flos += self.floating_point_ops(inputs)

                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    steps_in_epoch <= self.args.gradient_accumulation_steps
                    and (step + 1) == steps_in_epoch
                ):
                    # Gradient clipping
                    if self.args.max_grad_norm is not None and self.args.max_grad_norm > 0 and not self.deepspeed:
                        # deepspeed does its own clipping

                        if self.use_amp:
                            # AMP: gradients need unscaling
                            self.scaler.unscale_(self.optimizer)

                        if hasattr(self.optimizer, "clip_grad_norm"):
                            # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
                            self.optimizer.clip_grad_norm(self.args.max_grad_norm)
                        else:
                            # Revert to normal clipping otherwise, handling Apex or full precision
                            torch.nn.utils.clip_grad_norm_(
                                amp.master_params(self.optimizer) if self.use_apex else model.parameters(),
                                self.args.max_grad_norm,
                            )

                    # Optimizer step
                    if self.deepspeed:
                        # The DeepSpeed engine owns the optimizer and LR scheduler and steps both.
                        self.deepspeed.step()
                    elif is_torch_tpu_available():
                        xm.optimizer_step(self.optimizer)
                    elif self.use_amp:
                        self.scaler.step(self.optimizer)
                        self.scaler.update()
                    else:
                        self.optimizer.step()

                    if not self.deepspeed:
                        self.lr_scheduler.step()

                    model.zero_grad()

                    self.state.global_step += 1
                    self.state.epoch = epoch + (step + 1) / steps_in_epoch
                    self.control = self.callback_handler.on_step_end(self.args, self.state, self.control)

                    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch)

                if self.control.should_epoch_stop or self.control.should_training_stop:
                    break

            self.control = self.callback_handler.on_epoch_end(self.args, self.state, self.control)
            self._maybe_log_save_evaluate(tr_loss, model, trial, epoch)

            if self.args.tpu_metrics_debug or self.args.debug:
                if is_torch_tpu_available():
                    # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                    xm.master_print(met.metrics_report())
                else:
                    logger.warning(
                        "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
                        "configured. Check your training configuration if this is unexpected."
                    )
            if self.control.should_training_stop:
                break

        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of training
            delattr(self, "_past")

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
        if self.args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
            logger.info(
                f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
            )
            if isinstance(self.model, PreTrainedModel):
                self.model = self.model.from_pretrained(self.state.best_model_checkpoint, model_args=self.model_args)
                if not self.is_model_parallel:
                    self.model = self.model.to(self.args.device)
            else:
                state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
                self.model.load_state_dict(state_dict)

            if self.deepspeed:
                self.deepspeed.load_checkpoint(
                    self.state.best_model_checkpoint, load_optimizer_states=False, load_lr_scheduler_states=False
                )

        metrics = speed_metrics("train", start_time, self.state.max_steps)
        if self._total_flos is not None:
            self.store_flos()
            metrics["total_flos"] = self.state.total_flos
        self.log(metrics)

        self.control = self.callback_handler.on_train_end(self.args, self.state, self.control)
        # add remaining tr_loss
        self._total_loss_scalar += tr_loss.item()

        # Guard against ZeroDivisionError when no optimization step was taken.
        if self.state.global_step > 0:
            train_loss = self._total_loss_scalar / self.state.global_step
        else:
            train_loss = 0.0
        return TrainOutput(self.state.global_step, train_loss, metrics)
# --------------------------------------------------------------------------------