├── README.md
├── code
    ├── PromptBERT_RankEncoder
    │   ├── README.md
    │   ├── evaluation.py
    │   ├── get_embedding.py
    │   ├── get_prompt_bert_embedding.sh
    │   ├── get_prompt_bert_rank_encoder_embedding.sh
    │   ├── prompt_bert
    │   │   ├── __init__.py
    │   │   ├── models.py
    │   │   └── trainers.py
    │   ├── prompt_bert_rank_encoder
    │   │   ├── __init__.py
    │   │   ├── models.py
    │   │   └── trainers.py
    │   ├── prompt_bert_rank_encoder_inference.py
    │   ├── prompt_bert_rank_encoder_inference.sh
    │   ├── promptbert_module.py
    │   ├── run_prompt_bert.sh
    │   ├── run_prompt_bert_rank_encoder.sh
    │   ├── train_prompt_bert.py
    │   └── train_prompt_bert_rank_encoder.py
    ├── SNCSE_RankEncoder
    │   ├── README.md
    │   ├── evaluation.py
    │   ├── evaluation_sncse.sh
    │   ├── evaluation_sncse_rank_encoder.sh
    │   ├── generate_soft_negative_samples.py
    │   ├── generate_soft_negative_samples.sh
    │   ├── get_embedding.py
    │   ├── get_sncse_embedding.sh
    │   ├── get_sncse_rank_encoder_embedding.sh
    │   ├── sncse_rank_encoder
    │   │   ├── simcse
    │   │   │   ├── __init__.py
    │   │   │   ├── models.py
    │   │   │   ├── tool.py
    │   │   │   └── trainers.py
    │   │   ├── sncse_rank_encoder.sh
    │   │   └── train_SNCSE.py
    │   ├── sncse_rank_encoder_inference.py
    │   └── sncse_rank_encoder_inference.sh
    ├── SimCSE_RankEncoder
    │   ├── README.md
    │   ├── evaluation.py
    │   ├── evaluation.sh
    │   ├── get_embedding.py
    │   ├── get_rank_encoder_embedding.sh
    │   ├── get_simcse_embedding.sh
    │   ├── run_simcse.sh
    │   ├── run_simcse_rank_encoder.sh
    │   ├── simcse
    │   │   ├── __init__.py
    │   │   ├── models.py
    │   │   ├── tool.py
    │   │   └── trainers.py
    │   ├── simcse_rank_encoder
    │   │   ├── __init__.py
    │   │   ├── models.py
    │   │   ├── tool.py
    │   │   └── trainers.py
    │   ├── simcse_rank_encoder_inference.py
    │   ├── simcse_rank_encoder_inference.sh
    │   ├── train_simcse.py
    │   └── train_simcse_rank_encoder.py
    └── file_utils
        ├── random_sampling_sentences.py
        └── random_sampling_sentences.sh
└── requirements.txt

/README.md:
--------------------------------------------------------------------------------
## RANKING-ENHANCED UNSUPERVISED SENTENCE REPRESENTATION LEARNING
**[Yeon Seonwoo](https://yeonsw.github.io/), Guoyin Wang, Changmin Seo, Sajal Choudhary, Jiwei Li, Xiang Li, Puyang Xu, Sunghyun Park, Alice Oh** | ACL 2023 | [Paper](https://arxiv.org/abs/2209.04333)

KAIST, Amazon, Zhejiang Univ.

## Abstract
Unsupervised sentence representation learning has progressed through contrastive learning and data augmentation methods such as dropout masking. Despite this progress, sentence encoders are still limited to using only an input sentence when predicting its semantic vector. In this work, we show that the semantic meaning of a sentence is also determined by nearest-neighbor sentences that are similar to the input sentence. Based on this finding, we propose a novel unsupervised sentence encoder, RankEncoder. RankEncoder predicts the semantic vector of an input sentence by leveraging its relationship with other sentences in an external corpus, as well as the input sentence itself. We evaluate RankEncoder on semantic textual benchmark datasets. From the experimental results, we verify that 1) RankEncoder achieves 80.07% Spearman's correlation, a 1.1% absolute improvement compared to the previous state-of-the-art performance, 2) RankEncoder is universally applicable to existing unsupervised sentence embedding methods, and 3) RankEncoder is specifically effective for predicting the similarity scores of similar sentence pairs.
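
The idea in the abstract, describing a sentence by how it relates to the sentences of an external corpus, can be illustrated with a small NumPy sketch. This is only an illustration under stated assumptions, not the code in this repository: the helper names (`rank_vector`, `rank_similarity`) are hypothetical, the embeddings are random stand-ins for a base encoder's output, and the similarity is a plain Spearman-style correlation between two rank vectors.

```python
import numpy as np

def l2_normalize(x):
    # Scale vectors to unit length along the last axis.
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

def rank_vector(sentence_vec, corpus_vecs):
    """Rank of every corpus sentence by cosine similarity to the input (0 = least similar)."""
    sims = l2_normalize(corpus_vecs) @ l2_normalize(sentence_vec)
    # argsort of argsort turns similarity scores into ranks.
    return sims.argsort().argsort().astype(np.float64)

def rank_similarity(vec_a, vec_b, corpus_vecs):
    """Spearman-style similarity: Pearson correlation between the two rank vectors."""
    ra = rank_vector(vec_a, corpus_vecs)
    rb = rank_vector(vec_b, corpus_vecs)
    ra -= ra.mean()
    rb -= rb.mean()
    return float(ra @ rb / (np.linalg.norm(ra) * np.linalg.norm(rb)))

# Random stand-ins for sentence embeddings (assumption: any base encoder could supply these).
rng = np.random.default_rng(0)
corpus = rng.normal(size=(1000, 16))
a = rng.normal(size=16)
near_a = a + 0.05 * rng.normal(size=16)   # a slightly perturbed copy of `a`
far = rng.normal(size=16)                 # an unrelated vector

sim_near = rank_similarity(a, near_a, corpus)
sim_far = rank_similarity(a, far, corpus)
print(sim_near, sim_far)
```

Two sentences that rank the corpus in nearly the same order receive a similarity close to 1, while unrelated sentences receive a lower value.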

## Getting Started
### SimCSE-RankEncoder
Please see [README.md](https://github.com/yeonsw/RankEncoder/tree/main/code/SimCSE_RankEncoder) at code/SimCSE\_RankEncoder

### PromptBERT-RankEncoder
Please see [README.md](https://github.com/yeonsw/RankEncoder/tree/main/code/PromptBERT_RankEncoder) at code/PromptBERT\_RankEncoder

### SNCSE-RankEncoder
Please see [README.md](https://github.com/yeonsw/RankEncoder/tree/main/code/SNCSE_RankEncoder) at code/SNCSE\_RankEncoder

## Code Reference
https://github.com/princeton-nlp/SimCSE

https://github.com/kongds/Prompt-BERT

https://github.com/Sense-GVT/SNCSE

--------------------------------------------------------------------------------
/code/PromptBERT_RankEncoder/README.md:
--------------------------------------------------------------------------------
# Getting Started

### Requirements
You need at least two GPUs to follow the instructions below.

### Setting
1. Set the project directory
```bash
export PROJECT_DIR=/path/to/this/project/folder
```
Note that the path must not end with "/", e.g., /home/RankEncoder.

2. Download the SentEval folder from https://github.com/princeton-nlp/SimCSE and place it in code/PromptBERT\_RankEncoder/
3. Go to SentEval/data/downstream and execute the following command
```bash
bash download_dataset.sh
```
4. Download the Wiki1m dataset with the following commands, which place it at data/corpus/corpus.txt
```bash
wget https://huggingface.co/datasets/princeton-nlp/datasets-for-simcse/resolve/main/wiki1m_for_simcse.txt
mv wiki1m_for_simcse.txt ../../data/corpus/corpus.txt
```
5. Go to code/file\_utils/ and execute the following command.
```bash
bash random_sampling_sentences.sh
```

### Training the base encoder ([PromptBERT](https://arxiv.org/abs/2201.04337))
Go to code/PromptBERT\_RankEncoder and execute the following command
```bash
bash run_prompt_bert.sh
```

### Training RankEncoder
1. Get sentence vectors with the base encoder
```bash
bash get_prompt_bert_embedding.sh
```
2. Train RankEncoder
```bash
bash run_prompt_bert_rank_encoder.sh
```

The above command reports the performance of RankEncoder without Eq.7 in [our paper](https://arxiv.org/pdf/2209.04333.pdf).

We provide the checkpoint of the trained RankEncoder-PromptBERT [here](https://drive.google.com/file/d/1ixIt_TNx2c1fSfpzMPeuVMQVuqDAYMZo/view?usp=sharing)

### Evaluation

The following commands compute the performance of RankEncoder with Eq.7 in [our paper](https://arxiv.org/pdf/2209.04333.pdf).

```bash
bash get_prompt_bert_rank_encoder_embedding.sh
bash prompt_bert_rank_encoder_inference.sh
```
Note that we only sample 10,000 sentences for computational efficiency. Please use 100,000 sentences to replicate our experimental results.
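
The sentence sampling mentioned in the note can be pictured with the short sketch below. It is a hedged illustration only: `sample_sentences` is a hypothetical helper, not the contents of `random_sampling_sentences.py`, and the in-memory list stands in for the lines of data/corpus/corpus.txt.

```python
import random

def sample_sentences(sentences, n, seed=0):
    """Sample n non-empty sentences without replacement (all of them if fewer than n exist)."""
    rng = random.Random(seed)  # fixed seed so repeated runs pick the same subset
    cleaned = [s.strip() for s in sentences if s.strip()]
    if len(cleaned) <= n:
        return cleaned
    return rng.sample(cleaned, n)

# Stand-in corpus; in practice the lines would be read from the corpus file.
corpus = [f"sentence {i}" for i in range(1000)]
subset = sample_sentences(corpus, 100)
print(len(subset))
```

Raising `n` from 10,000 to 100,000 in such a step trades extra computation for a larger reference corpus, which is the trade-off the note describes.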

--------------------------------------------------------------------------------
/code/PromptBERT_RankEncoder/evaluation.py:
--------------------------------------------------------------------------------
import re
import sys
import io, os
import torch
import numpy as np
import logging
import tqdm
import argparse
from prettytable import PrettyTable
import transformers
from transformers import AutoModel, AutoTokenizer

# Set up logger
logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG)

# Set PATHs
PATH_TO_SENTEVAL = './SentEval'
PATH_TO_DATA = './SentEval/data'

# Import SentEval
sys.path.insert(0, PATH_TO_SENTEVAL)
import senteval

def cal_avg_cosine(k, n=100000):
    """Average cosine similarity between each of the first n embeddings and the embedding pool."""
    cos = torch.nn.CosineSimilarity(dim=-1)
    s = torch.tensor(k[:100000]).cuda()
    # Never iterate past the number of embeddings that are actually available
    n = min(n, s.shape[0])
    kk = []
    pbar = tqdm.tqdm(total=n)
    with torch.no_grad():
        for i in range(n):
            kk.append(cos(s[i:i+1], s).mean().item())
            pbar.set_postfix({'cosine': sum(kk)/len(kk)})
            pbar.update(1)
    return sum(kk) /len(kk)

def s_eval(args):
    se, task = args[0], args[1]
    return se.eval(task)

def print_table(task_names, scores):
    tb = PrettyTable()
    tb.field_names = task_names
    tb.add_row(scores)
    print(tb)

def get_delta(model, template, tokenizer, device, args):
    """Encode the template without a sentence and return the hidden states at the mask position."""
    model.eval()

    template = template.replace('*mask*', tokenizer.mask_token)\
                    .replace('*sep+*', '')\
                    .replace('*cls*', '').replace('*sent_0*', ' ')
    # strip for roberta tokenizer
    bs_length = len(tokenizer.encode(template.split(' ')[0].replace('_', ' ').strip())) - 2 + 1
    # replace for roberta tokenizer
    batch = tokenizer([template.replace('_', ' ').strip().replace(' ', ' ')], return_tensors='pt')
    batch['position_ids'] = torch.arange(batch['input_ids'].shape[1]).to(device).unsqueeze(0)
    for k in batch:
        batch[k] = batch[k].repeat(128, 1).to(device)
    m_mask = batch['input_ids'] == tokenizer.mask_token_id

    with torch.no_grad():
        outputs = model(**batch, output_hidden_states=True, return_dict=True)
        last_hidden = outputs.hidden_states[-1]
        delta = last_hidden[m_mask]
    delta.requires_grad = False
    #import pdb;pdb.set_trace()
    template_len = batch['input_ids'].shape[1]
    return delta, template_len

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--embedding_only", action='store_true')
    parser.add_argument('--mlm_head_predict', action='store_true')
    parser.add_argument('--remove_continue_word', action='store_true')
    parser.add_argument('--mask_embedding_sentence', action='store_true')
    parser.add_argument('--mask_embedding_sentence_use_org_pooler', action='store_true')
    parser.add_argument('--mask_embedding_sentence_template', type=str, default=None)
    parser.add_argument('--mask_embedding_sentence_delta', action='store_true')
    parser.add_argument('--mask_embedding_sentence_use_pooler', action='store_true')
    parser.add_argument('--mask_embedding_sentence_autoprompt', action='store_true')
    parser.add_argument('--mask_embedding_sentence_org_mlp', action='store_true')
    parser.add_argument("--tokenizer_name", type=str, default='')
    parser.add_argument("--model_name_or_path", type=str,
            help="Transformers' model name or path")
    parser.add_argument("--pooler", type=str,
            choices=['cls', 'cls_before_pooler', 'avg', 'avg_first_last'],
            default='cls',
            help="Which pooler to use")
    parser.add_argument("--mode", type=str,
            choices=['dev', 'test', 'fasttest'],
            default='test',
            help="What evaluation mode to use (dev: fast mode, dev results; test: full mode, test results); fasttest: fast mode, test results")
    parser.add_argument("--task_set", type=str,
            choices=['sts', 'transfer', 'full', 'na'],
            default='sts',
            help="What set of tasks to evaluate on. If not 'na', this will override '--tasks'")
    parser.add_argument('--calc_anisotropy', action='store_true')

    args = parser.parse_args()

    # Load transformers' model checkpoint
    if args.mask_embedding_sentence_org_mlp:
        #only for bert-base
        from transformers import BertForMaskedLM, BertConfig
        config = BertConfig.from_pretrained("bert-base-uncased")
        mlp = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config).cls.predictions.transform
        if 'result' in args.model_name_or_path:
            state_dict = torch.load(args.model_name_or_path+'/pytorch_model.bin')
            new_state_dict = {}
            for key, param in state_dict.items():
                # Replace "mlp" to "pooler"
                if 'pooler' in key:
                    key = key.replace("pooler.", "")
                    new_state_dict[key] = param
            mlp.load_state_dict(new_state_dict)
    model = AutoModel.from_pretrained(args.model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True)

    if args.mask_embedding_sentence_autoprompt:
        state_dict = torch.load(args.model_name_or_path+'/pytorch_model.bin')
        p_mbv = state_dict['p_mbv']
        template = args.mask_embedding_sentence_template
        template = template.replace('*mask*', tokenizer.mask_token)\
                .replace('*sep+*', '')\
                .replace('*cls*', '').replace('*sent_0*', ' ').replace('_', ' ')
        mask_embedding_template = tokenizer.encode(template)
        mask_index = mask_embedding_template.index(tokenizer.mask_token_id)
        index_mbv = mask_embedding_template[1:mask_index] + mask_embedding_template[mask_index+1:-1]
        #mask_embedding_template = [ 101, 2023, 6251, 1997, 1000, 1000, 2965, 103, 1012, 102]
        #index_mbv = mask_embedding_template[1:7] + mask_embedding_template[8:9]

        dict_mbv = index_mbv
        fl_mbv = [i <= 3 for i, k in enumerate(index_mbv)]

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


    #device = torch.device("cpu")
    model = model.to(device)
    if args.mask_embedding_sentence_org_mlp:
        mlp = mlp.to(device)

    if args.mask_embedding_sentence_delta:
        delta, template_len = get_delta(model, args.mask_embedding_sentence_template, tokenizer, device, args)

    # Set up the tasks
    #args.tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark', 'SICKRelatedness']
    #args.tasks = ['MR']
    if args.task_set == 'sts':
        args.tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark', 'SICKRelatedness']
    elif args.task_set == 'transfer':
        args.tasks = ['MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'TREC', 'MRPC']
    elif args.task_set == 'full':
        args.tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark', 'SICKRelatedness']
        args.tasks += ['MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'TREC', 'MRPC']

    # Set params for SentEval
    if args.mode == 'dev' or args.mode == 'fasttest':
        # Fast mode
        params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5}
        params['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128,
                                'tenacity': 3, 'epoch_size': 2}
    elif args.mode == 'test':
        # Full mode
        params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 10}
        params['classifier'] = {'nhid': 0, 'optim': 'adam', 'batch_size': 64,
                                'tenacity': 5, 'epoch_size': 4}
    else:
        raise NotImplementedError

    # SentEval prepare and batcher
    def prepare(params, samples):
        return

    if args.remove_continue_word:
        pun_remove_set = {'?', '*', '#', '´', '’', '=', '…', '|', '~', '/', '‚', '¿', '–', '»', '-', '€', '‘', '"', '(', '•', '`', '$', ':', '[', '”', '%', '£', '<', '[UNK]', ';', '“', '@', '_', '{', '^', ',', '.', '!', '™', '&', ']', '>', '\\', "'", ')', '+', '—'}
        if args.model_name_or_path == 'roberta-base':
            remove_set = {'Ġ.', 'Ġa', 'Ġthe', 'Ġin', 'a', 'Ġ, ', 'Ġis', 'Ġto', 'Ġof', 'Ġand', 'Ġon', 'Ġ\'', 's', '.', 'the', 'Ġman', '-', 'Ġwith', 'Ġfor', 'Ġat', 'Ġwoman', 'Ġare', 'Ġ"', 'Ġthat', 'Ġit', 'Ġdog', 'Ġsaid', 'Ġplaying', 'Ġwas', 'Ġas', 'Ġfrom', 'Ġ:', 'Ġyou', 'Ġan', 'i', 'Ġby'}
        else:
            remove_set = {".", "a", "the", "in", ",", "is", "to", "of", "and", "'", "on", "man", "-", "s", "with", "for", "\"", "at", "##s", "woman", "are", "it", "two", "that", "you", "dog", "said", "playing", "i", "an", "as", "was", "from", ":", "by", "white"}

        vocab = tokenizer.get_vocab()


    def batcher(params, batch, max_length=None):
        # Handle rare token encoding issues in the dataset
        if len(batch) >= 1 and len(batch[0]) >= 1 and isinstance(batch[0][0], bytes):
            batch = [[word.decode('utf-8') for word in s] for s in batch]

        sentences = [' '.join(s) for s in batch]
        if args.mask_embedding_sentence and args.mask_embedding_sentence_template is not None:
            # *cls*_This_sentence_of_"*sent_0*"_means*mask*.*sep+*
            template = args.mask_embedding_sentence_template
            template = template.replace('*mask*', tokenizer.mask_token )\
                               .replace('_', ' ').replace('*sep+*', '')\
                               .replace('*cls*', '')

            for i, s in enumerate(sentences):
                if len(s) > 0 and s[-1] not in '.?"\'': s += '.'
                sentences[i] = template.replace('*sent 0*', s).strip()
        elif args.remove_continue_word:
            for i, s in enumerate(sentences):
                sentences[i] = ' ' if args.model_name_or_path == 'roberta-base' else ''
                es = tokenizer.encode(' ' + s, add_special_tokens=False)
                for iw, w in enumerate(tokenizer.convert_ids_to_tokens(es)):
                    if args.model_name_or_path == 'roberta-base':
                        # roberta base
                        if 'Ġ' not in w or w in remove_set:
                            pass
                        else:
                            if re.search('[a-zA-Z0-9]', w) is not None:
                                sentences[i] += w.replace('Ġ', '').lower() + ' '
                    elif w not in remove_set and w not in pun_remove_set and '##' not in w:
                        # bert base
                        sentences[i] += w.lower() + ' '
                if len(sentences[i]) == 0: sentences[i] = '[PAD]'

        if max_length is not None:
            batch = tokenizer.batch_encode_plus(
                sentences,
                return_tensors='pt',
                padding=True,
                max_length=max_length,
                truncation=True
            )
        else:
            batch = tokenizer.batch_encode_plus(
                sentences,
                return_tensors='pt',
                padding=True,
            )

        # Move to the correct device
        for k in batch:
            batch[k] = batch[k].to(device) if batch[k] is not None else None

        # Get raw embeddings
        with torch.no_grad():
            if args.embedding_only:
                hidden_states = None
                pooler_output = None
                last_hidden = model.embeddings.word_embeddings(batch['input_ids'])
                position_ids = model.embeddings.position_ids[:, 0 : last_hidden.shape[1]]
                token_type_ids = torch.zeros(batch['input_ids'].shape, dtype=torch.long,
                                                device=model.embeddings.position_ids.device)

                position_embeddings = model.embeddings.position_embeddings(position_ids)
                token_type_embeddings = model.embeddings.token_type_embeddings(token_type_ids)

                if args.remove_continue_word:
                    batch['attention_mask'][batch['input_ids'] == tokenizer.cls_token_id] = 0
                    batch['attention_mask'][batch['input_ids'] == tokenizer.sep_token_id] = 0
            elif args.mask_embedding_sentence_autoprompt:
                input_ids = batch['input_ids']
                inputs_embeds = model.embeddings.word_embeddings(input_ids)
                p = torch.arange(input_ids.shape[1]).to(input_ids.device).view(1, -1)
                b = torch.arange(input_ids.shape[0]).to(input_ids.device)
                for i, k in enumerate(dict_mbv):
                    if fl_mbv[i]:
                        index = ((input_ids == k) * p).max(-1)[1]
                    else:
                        index = ((input_ids == k) * -p).min(-1)[1]
                    inputs_embeds[b, index] = p_mbv[i]
                batch['input_ids'], batch['inputs_embeds'] = None, inputs_embeds
                outputs = model(**batch, output_hidden_states=True, return_dict=True)
                batch['input_ids'] = input_ids

                last_hidden = outputs.last_hidden_state
                pooler_output = last_hidden[input_ids == tokenizer.mask_token_id]

                if args.mask_embedding_sentence_org_mlp:
                    pooler_output = mlp(pooler_output)
                if args.mask_embedding_sentence_delta:
                    blen = batch['attention_mask'].sum(-1) - template_len
                    if args.mask_embedding_sentence_org_mlp:
                        pooler_output -= mlp(delta[blen])
                    else:
                        pooler_output -= delta[blen]
                if args.mask_embedding_sentence_use_pooler:
                    pooler_output = model.pooler.dense(pooler_output)
                    pooler_output = model.pooler.activation(pooler_output)

            else:
                outputs = model(**batch, output_hidden_states=True, return_dict=True)

                try:
                    pooler_output = outputs.pooler_output
                except AttributeError:
                    pooler_output = outputs['last_hidden_state'][:, 0, :]
                if args.mask_embedding_sentence:
                    last_hidden = outputs.last_hidden_state
                    pooler_output = last_hidden[batch['input_ids'] == tokenizer.mask_token_id]
                    if args.mask_embedding_sentence_org_mlp:
                        pooler_output = mlp(pooler_output)
                    if args.mask_embedding_sentence_delta:
                        blen = batch['attention_mask'].sum(-1) - template_len
                        if args.mask_embedding_sentence_org_mlp:
                            pooler_output -= mlp(delta[blen])
                        else:
                            pooler_output -= delta[blen]
                    if args.mask_embedding_sentence_use_org_pooler:
                        pooler_output = mlp(pooler_output)
                    if args.mask_embedding_sentence_use_pooler:
                        pooler_output = model.pooler.dense(pooler_output)
                        pooler_output = model.pooler.activation(pooler_output)
                else:
                    last_hidden = outputs.last_hidden_state
                    hidden_states = outputs.hidden_states


        # Apply different pooler
        if args.mask_embedding_sentence:
            return pooler_output.view(batch['input_ids'].shape[0], -1).cpu()
        elif args.pooler == 'cls':
            # There is a linear+activation layer after CLS representation
            return pooler_output.cpu()
        elif args.pooler == 'cls_before_pooler':
            batch['input_ids'][(batch['input_ids'] == 0) | (batch['input_ids'] == 101) | (batch['input_ids'] == 102)] = batch['input_ids'].max()
            index = batch['input_ids'].topk(3, dim=-1, largest=False)[1]
            index2 = torch.arange(batch['input_ids'].shape[0]).to(index.device)
            r = last_hidden[index2, index[:, 0], :]
            for i in range(1, 3):
                r += last_hidden[index2, index[:, i], :]
            return (r/3).cpu()
        elif args.pooler == "avg":
            return ((last_hidden * batch['attention_mask'].unsqueeze(-1)).sum(1) / batch['attention_mask'].sum(-1).unsqueeze(-1)).cpu()
        elif args.pooler == "avg_first_last":
            first_hidden = hidden_states[0]
            last_hidden = hidden_states[-1]
            pooled_result = ((first_hidden + last_hidden) / 2.0 * batch['attention_mask'].unsqueeze(-1)).sum(1) / batch['attention_mask'].sum(-1).unsqueeze(-1)
            return pooled_result.cpu()
        elif args.pooler == "avg_top2":
            second_last_hidden = hidden_states[-2]
            last_hidden = hidden_states[-1]
            pooled_result = ((last_hidden + second_last_hidden) / 2.0 * batch['attention_mask'].unsqueeze(-1)).sum(1) / batch['attention_mask'].sum(-1).unsqueeze(-1)
            return pooled_result.cpu()
        else:
            raise NotImplementedError

    if args.calc_anisotropy:
        with open('./data/wiki1m_for_simcse.txt') as f:
            lines = f.readlines()[:100000]
        batch, embeds = [], []
        print('Get Sentence Embeddings....')
        for line in tqdm.tqdm(lines):
            batch.append(line.replace('\n', '').lower().split()[:32])
            if len(batch) >= 128:
                embeds.append(batcher(None, batch).detach().numpy())
                batch = []
        # Encode the remaining sentences; skip when the last full batch consumed everything
        if batch:
            embeds.append(batcher(None, batch).detach().numpy())
        print('Calculate anisotropy....')
        embeds = np.concatenate(embeds, axis=0)
        cosine = cal_avg_cosine(embeds)
        print('Avg. Cos:', cosine)
        exit(0)

    results = {}

    for task in args.tasks:
        se = senteval.engine.SE(params, batcher, prepare)
        result = se.eval(task)
        results[task] = result

    # Print evaluation results
    if args.mode == 'dev':
        print("------ %s ------" % (args.mode))

        task_names = []
        scores = []
        for task in ['STSBenchmark', 'SICKRelatedness']:
            task_names.append(task)
            if task in results:
                scores.append("%.2f" % (results[task]['dev']['spearman'][0] * 100))
            else:
                scores.append("0.00")
        print_table(task_names, scores)

        task_names = []
        scores = []
        for task in ['MR', 'CR', 'SUBJ', 'MPQA', 'SST2', 'TREC', 'MRPC']:
            task_names.append(task)
            if task in results:
                scores.append("%.2f" % (results[task]['devacc']))
            else:
                scores.append("0.00")
        task_names.append("Avg.")
        scores.append("%.2f" % (sum([float(score) for score in scores]) / len(scores)))
        print_table(task_names, scores)

    elif args.mode == 'test' or args.mode == 'fasttest':
        print("------ %s ------" % (args.mode))

        task_names = []
        scores = []
        for task in ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark', 'SICKRelatedness']:
            task_names.append(task)
            if task in results:
                if task in ['STS12', 'STS13', 'STS14', 'STS15', 'STS16']:
                    scores.append("%.2f" % (results[task]['all']['spearman']['all'] * 100))
                else:
                    scores.append("%.2f" % (results[task]['test']['spearman'].correlation * 100))
            else:
                scores.append("0.00")
        task_names.append("Avg.")
        scores.append("%.2f" % (sum([float(score) for score in scores]) / len(scores)))
        print_table(task_names, scores)

        task_names = []
        scores = []
        for task in ['MR', 'CR', 'SUBJ', 'MPQA', 'SST2', 'TREC', 'MRPC']:
            task_names.append(task)
            if task in results:
                scores.append("%.2f" % (results[task]['acc']))
            else:
                scores.append("0.00")
        task_names.append("Avg.")
        scores.append("%.2f" % (sum([float(score) for score in scores]) / len(scores)))
        print_table(task_names, scores)

if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/code/PromptBERT_RankEncoder/get_embedding.py:
--------------------------------------------------------------------------------
import re
import sys
import io, os
import math
import torch
import numpy as np
import logging
import argparse
from prettytable import PrettyTable
import transformers
from transformers import AutoModel, AutoTokenizer
from tqdm import tqdm

# Set up logger
logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG)

def get_delta(model, template, tokenizer, device, args):
    """Encode the template without a sentence and return the hidden states at the mask position."""
    model.eval()

    template = template.replace('*mask*', tokenizer.mask_token)\
                    .replace('*sep+*', '')\
                    .replace('*cls*', '').replace('*sent_0*', ' ')
    # strip for roberta tokenizer
    bs_length = len(tokenizer.encode(template.split(' ')[0].replace('_', ' ').strip())) - 2 + 1
    # replace for roberta tokenizer
    batch = tokenizer([template.replace('_', ' ').strip().replace(' ', ' ')], return_tensors='pt')
    batch['position_ids'] = torch.arange(batch['input_ids'].shape[1]).to(device).unsqueeze(0)
    for k in batch:
        batch[k] = batch[k].repeat(512, 1).to(device)
    m_mask = batch['input_ids'] == tokenizer.mask_token_id

    with torch.no_grad():
        outputs = model(**batch, output_hidden_states=True, return_dict=True)
        last_hidden = outputs.hidden_states[-1]
        delta = last_hidden[m_mask]
    delta.requires_grad = False
    #import pdb;pdb.set_trace()
    template_len = batch['input_ids'].shape[1]
    return delta, template_len

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--sentence_file", type=str, required=True)
    parser.add_argument("--vector_file", type=str, required=True)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--embedding_only", action='store_true')
    parser.add_argument('--mlm_head_predict', action='store_true')
    parser.add_argument('--remove_continue_word', action='store_true')
    parser.add_argument('--mask_embedding_sentence', action='store_true')
    parser.add_argument('--mask_embedding_sentence_use_org_pooler', action='store_true')
    parser.add_argument('--mask_embedding_sentence_template', type=str, default=None)
    parser.add_argument('--mask_embedding_sentence_delta', action='store_true')
    parser.add_argument('--mask_embedding_sentence_use_pooler', action='store_true')
    parser.add_argument('--mask_embedding_sentence_autoprompt', action='store_true')
    parser.add_argument('--mask_embedding_sentence_org_mlp', action='store_true')
    parser.add_argument("--tokenizer_name", type=str, default='')
    parser.add_argument("--model_name_or_path", type=str,
            help="Transformers' model name or path")
    parser.add_argument("--pooler", type=str,
            choices=['cls', 'cls_before_pooler', 'avg', 'avg_first_last'],
            default='cls',
            help="Which pooler to use")

    args = parser.parse_args()

    # Load transformers' model checkpoint
    if args.mask_embedding_sentence_org_mlp:
        #only for bert-base
        from transformers import BertForMaskedLM, BertConfig
        config = BertConfig.from_pretrained("bert-base-uncased")
        mlp = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config).cls.predictions.transform
        if 'result' in args.model_name_or_path:
            state_dict = torch.load(args.model_name_or_path+'/pytorch_model.bin')
            new_state_dict = {}
            for key, param in state_dict.items():
                # Replace "mlp" to "pooler"
                if 'pooler' in key:
                    key = key.replace("pooler.", "")
                    new_state_dict[key] = param
            mlp.load_state_dict(new_state_dict)
    model = AutoModel.from_pretrained(args.model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True)

    if args.mask_embedding_sentence_autoprompt:
        state_dict = torch.load(args.model_name_or_path+'/pytorch_model.bin')
        p_mbv = state_dict['p_mbv']
        template = args.mask_embedding_sentence_template
        template = template.replace('*mask*', tokenizer.mask_token)\
                .replace('*sep+*', '')\
                .replace('*cls*', '').replace('*sent_0*', ' ').replace('_', ' ')
        mask_embedding_template = tokenizer.encode(template)
        mask_index = mask_embedding_template.index(tokenizer.mask_token_id)
        index_mbv = mask_embedding_template[1:mask_index] + mask_embedding_template[mask_index+1:-1]
        #mask_embedding_template = [ 101, 2023, 6251, 1997, 1000, 1000, 2965, 103, 1012, 102]
        #index_mbv = mask_embedding_template[1:7] + mask_embedding_template[8:9]

        dict_mbv = index_mbv
        fl_mbv = [i <= 3 for i, k in enumerate(index_mbv)]

    #device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda")
    # device_count() is 0 without CUDA, so the multi-GPU branch below is skipped on CPU
    n_gpu = torch.cuda.device_count()
    encoder = model

    if n_gpu > 1:
        model = torch.nn.DataParallel(model)
        encoder = model.module
    model.eval()

    #device = torch.device("cpu")
    model = model.to(device)
    if args.mask_embedding_sentence_org_mlp:
        mlp = mlp.to(device)

    if args.mask_embedding_sentence_delta:
        delta, template_len = get_delta(model, args.mask_embedding_sentence_template, tokenizer, device, args)

    if args.remove_continue_word:
        pun_remove_set = {'?', '*', '#', '´', '’', '=', '…', '|', '~', '/', '‚', '¿', '–', '»', '-', '€', '‘', '"', '(', '•', '`', '$', ':', '[', '”', '%', '£', '<', '[UNK]', ';', '“', '@', '_', '{', '^', ',', '.', '!', '™', '&', ']', '>', '\\', "'", ')', '+', '—'}
        if args.model_name_or_path == 'roberta-base':
            remove_set = {'Ġ.', 'Ġa', 'Ġthe', 'Ġin', 'a', 'Ġ, ', 'Ġis', 'Ġto', 'Ġof', 'Ġand', 'Ġon', 'Ġ\'', 's', '.', 'the', 'Ġman', '-', 'Ġwith', 'Ġfor', 'Ġat', 'Ġwoman', 'Ġare', 'Ġ"', 'Ġthat', 'Ġit', 'Ġdog', 'Ġsaid', 'Ġplaying', 'Ġwas', 'Ġas', 'Ġfrom', 'Ġ:', 'Ġyou', 'Ġan', 'i', 'Ġby'}
        else:
            remove_set = {".", "a", "the", "in", ",", "is", "to", "of", "and", "'", "on", "man", "-", "s", "with", "for", "\"", "at", "##s", "woman", "are", "it", "two", "that", "you", "dog", "said", "playing", "i", "an", "as", "was", "from", ":", "by", "white"}

        vocab = tokenizer.get_vocab()


    def batcher(batch, max_length=None):
        # Handle rare token encoding issues in the dataset
        if len(batch) >= 1 and len(batch[0]) >= 1 and isinstance(batch[0][0], bytes):
            batch = [[word.decode('utf-8') for word in s] for s in batch]

        sentences = [' '.join(s) for s in batch]
        if args.mask_embedding_sentence and args.mask_embedding_sentence_template is not None:
            # *cls*_This_sentence_of_"*sent_0*"_means*mask*.*sep+*
            template = args.mask_embedding_sentence_template
            template = template.replace('*mask*', tokenizer.mask_token )\
                               .replace('_', ' ').replace('*sep+*', '')\
                               .replace('*cls*', '')
            template_tokens = tokenizer.encode(template.replace('*sent 0*', ""), add_special_tokens=False)
            len_template_tokens = len(template_tokens)

            for i, s in enumerate(sentences):
                if len(s) > 0 and s[-1] not in '.?"\'': s += '.'
                s_tokens = tokenizer.encode(s, add_special_tokens=False)
                # Leave room for the template tokens and the two special tokens
                limit = 512 - 2 - len_template_tokens
                if limit < len(s_tokens):
                    s_tokens = s_tokens[:limit]
                    s = tokenizer.decode(s_tokens)
                sentences[i] = template.replace('*sent 0*', s).strip()
        elif args.remove_continue_word:
            for i, s in enumerate(sentences):
                sentences[i] = ' ' if args.model_name_or_path == 'roberta-base' else ''
                es = tokenizer.encode(' ' + s, add_special_tokens=False)
                for iw, w in enumerate(tokenizer.convert_ids_to_tokens(es)):
                    if args.model_name_or_path == 'roberta-base':
                        # roberta base
                        if 'Ġ' not in w or w in remove_set:
                            pass
                        else:
                            if re.search('[a-zA-Z0-9]', w) is not None:
                                sentences[i] += w.replace('Ġ', '').lower() + ' '
                    elif w not in remove_set and w not in pun_remove_set and '##' not in w:
                        # bert base
                        sentences[i] += w.lower() + ' '
                if len(sentences[i]) == 0: sentences[i] = '[PAD]'

        if max_length is not None:
            batch = tokenizer.batch_encode_plus(
                sentences,
                return_tensors='pt',
                padding=True,
                max_length=max_length,
                truncation=True
            )
        else:
            batch = tokenizer.batch_encode_plus(
                sentences,
                return_tensors='pt',
                padding=True,
                truncation=True
            )

        # Move to the correct device
        for k in batch:
            batch[k] = batch[k].to(device) if batch[k] is not None else None

        # Get raw embeddings
        with torch.no_grad():
            if args.embedding_only:
                hidden_states = None
                pooler_output = None
                last_hidden = encoder.embeddings.word_embeddings(batch['input_ids'])
                position_ids = encoder.embeddings.position_ids[:, 0 : last_hidden.shape[1]]
                token_type_ids = torch.zeros(batch['input_ids'].shape, dtype=torch.long,
                                                device=encoder.embeddings.position_ids.device)

                position_embeddings = encoder.embeddings.position_embeddings(position_ids)
                token_type_embeddings = encoder.embeddings.token_type_embeddings(token_type_ids)

                if args.remove_continue_word:
                    batch['attention_mask'][batch['input_ids'] == tokenizer.cls_token_id] = 0
                    batch['attention_mask'][batch['input_ids'] == tokenizer.sep_token_id] = 0
            elif args.mask_embedding_sentence_autoprompt:
                input_ids = batch['input_ids']
                inputs_embeds = encoder.embeddings.word_embeddings(input_ids)
                p = torch.arange(input_ids.shape[1]).to(input_ids.device).view(1, -1)
                b = torch.arange(input_ids.shape[0]).to(input_ids.device)
                for i, k in enumerate(dict_mbv):
                    if fl_mbv[i]:
                        index = ((input_ids == k) * p).max(-1)[1]
                    else:
                        index = ((input_ids == k) * -p).min(-1)[1]
                    inputs_embeds[b, index] = p_mbv[i]
                batch['input_ids'], batch['inputs_embeds'] = None, inputs_embeds
                outputs = model(**batch, output_hidden_states=True, return_dict=True)
                batch['input_ids'] = input_ids

                last_hidden = outputs.last_hidden_state
                pooler_output = last_hidden[input_ids == tokenizer.mask_token_id]

                if args.mask_embedding_sentence_org_mlp:
                    pooler_output = mlp(pooler_output)
                if args.mask_embedding_sentence_delta:
                    blen = batch['attention_mask'].sum(-1) - template_len
                    if args.mask_embedding_sentence_org_mlp:
                        pooler_output -= mlp(delta[blen])
                    else:
                        pooler_output -= delta[blen]
                if args.mask_embedding_sentence_use_pooler:
                    pooler_output = encoder.pooler.dense(pooler_output)
                    pooler_output = encoder.pooler.activation(pooler_output)

else: 239 | outputs = model(**batch, output_hidden_states=True, return_dict=True) 240 | 241 | try: 242 | pooler_output = outputs.pooler_output 243 | except AttributeError: 244 | pooler_output = outputs['last_hidden_state'][:, 0, :] 245 | if args.mask_embedding_sentence: 246 | last_hidden = outputs.last_hidden_state 247 | pooler_output = last_hidden[batch['input_ids'] == tokenizer.mask_token_id] 248 | if args.mask_embedding_sentence_org_mlp: 249 | pooler_output = mlp(pooler_output) 250 | if args.mask_embedding_sentence_delta: 251 | blen = batch['attention_mask'].sum(-1) - template_len 252 | if args.mask_embedding_sentence_org_mlp: 253 | pooler_output -= mlp(delta[blen]) 254 | else: 255 | pooler_output -= delta[blen] 256 | if args.mask_embedding_sentence_use_org_pooler: 257 | pooler_output = mlp(pooler_output) 258 | if args.mask_embedding_sentence_use_pooler: 259 | pooler_output = encoder.pooler.dense(pooler_output) 260 | pooler_output = encoder.pooler.activation(pooler_output) 261 | else: 262 | last_hidden = outputs.last_hidden_state 263 | hidden_states = outputs.hidden_states 264 | 265 | # Apply different pooler 266 | if args.mask_embedding_sentence: 267 | return pooler_output.view(batch['input_ids'].shape[0], -1).cpu() 268 | elif args.pooler == 'cls': 269 | # There is a linear+activation layer after CLS representation 270 | return pooler_output.cpu() 271 | elif args.pooler == 'cls_before_pooler': 272 | batch['input_ids'][(batch['input_ids'] == 0) | (batch['input_ids'] == 101) | (batch['input_ids'] == 102)] = batch['input_ids'].max() 273 | index = batch['input_ids'].topk(3, dim=-1, largest=False)[1] 274 | index2 = torch.arange(batch['input_ids'].shape[0]).to(index.device) 275 | r = last_hidden[index2, index[:, 0], :] 276 | for i in range(1, 3): 277 | r += last_hidden[index2, index[:, i], :] 278 | return (r/3).cpu() 279 | elif args.pooler == "avg": 280 | return ((last_hidden * batch['attention_mask'].unsqueeze(-1)).sum(1) / 
batch['attention_mask'].sum(-1).unsqueeze(-1)).cpu() 281 | elif args.pooler == "avg_first_last": 282 | first_hidden = hidden_states[0] 283 | last_hidden = hidden_states[-1] 284 | pooled_result = ((first_hidden + last_hidden) / 2.0 * batch['attention_mask'].unsqueeze(-1)).sum(1) / batch['attention_mask'].sum(-1).unsqueeze(-1) 285 | return pooled_result.cpu() 286 | elif args.pooler == "avg_top2": 287 | second_last_hidden = hidden_states[-2] 288 | last_hidden = hidden_states[-1] 289 | pooled_result = ((last_hidden + second_last_hidden) / 2.0 * batch['attention_mask'].unsqueeze(-1)).sum(1) / batch['attention_mask'].sum(-1).unsqueeze(-1) 290 | return pooled_result.cpu() 291 | else: 292 | raise NotImplementedError 293 | 294 | sentences = None 295 | with open(args.sentence_file, "r") as f: 296 | sentences = f.readlines() 297 | sentences = [s.strip().split() for s in tqdm(sentences, desc="Preprocessing")] 298 | 299 | sentences = sorted([(i, s) for i, s in enumerate(sentences)], key=lambda x: len(x[1]), reverse=True) 300 | inds, sentences = map(list, zip(*sentences)) 301 | sort_inds = sorted([(i, j) for i, j in enumerate(inds)], key=lambda x: x[1]) 302 | sort_inds, _ = map(list, zip(*sort_inds)) 303 | 304 | sentence_vectors = [] 305 | n_batch = math.ceil(len(sentences) / args.batch_size) 306 | for i in tqdm(range(0, len(sentences), args.batch_size), desc="Embedding sentences..."): 307 | batch = sentences[i:i + args.batch_size] 308 | vectors = batcher(batch).numpy() 309 | assert vectors.shape[1] == encoder.config.hidden_size 310 | sentence_vectors.append(vectors) 311 | sentence_vectors = np.concatenate(sentence_vectors) 312 | sentence_vectors = sentence_vectors[sort_inds] 313 | 314 | print("Saving...") 315 | os.makedirs(os.path.dirname(args.vector_file), exist_ok=True) 316 | with open(args.vector_file, "wb") as f: 317 | np.save(f, sentence_vectors) 318 | 319 | if __name__ == "__main__": 320 | main() 321 | 
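The tail of `get_embedding.py` sorts sentences by length before batching, so each padded batch holds sentences of similar size, and then restores the input order with `sort_inds`. The sketch below isolates that step. It is a minimal illustration and not the repository's code: `embed_in_length_order` and `toy_encoder` are hypothetical names, and `np.argsort` stands in for the script's explicit `sorted(...)` over `(i, j)` pairs, which computes the same inverse permutation.

```python
import numpy as np


def embed_in_length_order(sentences, encode_batch, batch_size=2):
    """Encode sentences longest-first, then return vectors in input order."""
    # Positions of the input sentences, longest first (ties keep input order).
    order = sorted(range(len(sentences)),
                   key=lambda i: len(sentences[i]),
                   reverse=True)

    chunks = []
    for start in range(0, len(order), batch_size):
        batch = [sentences[i] for i in order[start:start + batch_size]]
        chunks.append(encode_batch(batch))
    vectors = np.concatenate(chunks)

    # inverse[j] is the row of `vectors` that holds original sentence j,
    # which is the role `sort_inds` plays in the script.
    inverse = np.argsort(order)
    return vectors[inverse]


def toy_encoder(batch):
    # Hypothetical stand-in for the model: one feature, the token count.
    return np.array([[float(len(tokens))] for tokens in batch])


sentences = [["a"], ["a", "b", "c"], ["a", "b"]]
restored = embed_in_length_order(sentences, toy_encoder)
print(restored.ravel())
```

Because the toy encoder returns each sentence's token count, row `j` of `restored` can be checked directly against `len(sentences[j])`. That makes a misaligned inverse permutation easy to spot before running the real model over a large corpus.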
-------------------------------------------------------------------------------- /code/PromptBERT_RankEncoder/get_prompt_bert_embedding.sh: -------------------------------------------------------------------------------- 1 | OUTPUT_DIR=$PROJECT_DIR/outputs 2 | 3 | SEED=61507 4 | CHECKPOINT=$OUTPUT_DIR/promptbert/checkpoints/prompt_bert_seed_$SEED 5 | CUDA_VISIBLE_DEVICES=0,1 python get_embedding.py \ 6 | --model_name_or_path $CHECKPOINT \ 7 | --pooler avg \ 8 | --mask_embedding_sentence \ 9 | --mask_embedding_sentence_template "*cls*_This_sentence_of_\"*sent_0*\"_means*mask*.*sep+*" \ 10 | --sentence_file $OUTPUT_DIR/corpus/corpus_0.01.txt \ 11 | --vector_file $OUTPUT_DIR/promptbert/index_vecs/corpus_0.01_prompt_bert_seed_$SEED.npy \ 12 | --batch_size 64 13 | -------------------------------------------------------------------------------- /code/PromptBERT_RankEncoder/get_prompt_bert_rank_encoder_embedding.sh: -------------------------------------------------------------------------------- 1 | OUTPUT_DIR=$PROJECT_DIR/outputs 2 | 3 | SEED=61507 4 | CHECKPOINT=$OUTPUT_DIR/promptbert/checkpoints/prompt_bert_rank_encoder_seed_$SEED 5 | CUDA_VISIBLE_DEVICES=0,1 python get_embedding.py \ 6 | --model_name_or_path $CHECKPOINT \ 7 | --pooler avg \ 8 | --mask_embedding_sentence \ 9 | --mask_embedding_sentence_template "*cls*_This_sentence_of_\"*sent_0*\"_means*mask*.*sep+*" \ 10 | --sentence_file $OUTPUT_DIR/corpus/corpus_0.01.txt \ 11 | --vector_file $OUTPUT_DIR/promptbert/index_vecs/corpus_0.01_prompt_bert_rank_encoder_seed_$SEED.npy \ 12 | --batch_size 64 13 | -------------------------------------------------------------------------------- /code/PromptBERT_RankEncoder/prompt_bert/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yeonsw/RankEncoder/fbf762b252afe5d9683c713fccf7475ef526c8b6/code/PromptBERT_RankEncoder/prompt_bert/__init__.py 
-------------------------------------------------------------------------------- /code/PromptBERT_RankEncoder/prompt_bert_rank_encoder/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yeonsw/RankEncoder/fbf762b252afe5d9683c713fccf7475ef526c8b6/code/PromptBERT_RankEncoder/prompt_bert_rank_encoder/__init__.py -------------------------------------------------------------------------------- /code/PromptBERT_RankEncoder/prompt_bert_rank_encoder_inference.py: -------------------------------------------------------------------------------- 1 | import apex 2 | import re 3 | import sys 4 | import io, os 5 | import faiss 6 | import math 7 | import json 8 | import torch 9 | import numpy as np 10 | import logging 11 | import tqdm 12 | import time 13 | import argparse 14 | from prettytable import PrettyTable 15 | from scipy.stats import spearmanr, pearsonr 16 | from scipy.special import softmax 17 | from scipy.stats import rankdata 18 | import string 19 | import torch 20 | import transformers 21 | from transformers import AutoModel, AutoTokenizer 22 | from transformers import BertTokenizer, BertModel 23 | from tqdm import tqdm 24 | # Set up logger 25 | logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG) 26 | 27 | # Set PATHs 28 | PATH_TO_SENTEVAL = './SentEval' 29 | 30 | # Import SentEval 31 | sys.path.insert(0, PATH_TO_SENTEVAL) 32 | 33 | import senteval 34 | from senteval.sts import STS12Eval, STS13Eval, STS14Eval, STS15Eval, STS16Eval, STSBenchmarkEval 35 | from senteval.sts import SICKRelatednessEval 36 | 37 | def normalize(vecs): 38 | eps = 1e-8 39 | return vecs / (np.sqrt(np.sum(np.square(vecs), axis=1)) + eps)[:,None] 40 | 41 | def print_table(task_names, scores): 42 | tb = PrettyTable() 43 | tb.field_names = task_names 44 | tb.add_row(scores) 45 | print(tb) 46 | 47 | def read_benchmark_data(senteval_path, task): 48 | task2class = { \ 49 | 'STS12': STS12Eval, 50 | 
'STS13': STS13Eval, 51 | 'STS14': STS14Eval, 52 | 'STS15': STS15Eval, 53 | 'STS16': STS16Eval, 54 | 'STSBenchmark': STSBenchmarkEval, 55 | 'SICKRelatedness': SICKRelatednessEval 56 | } 57 | dataset_path = None 58 | print("SentEval path: {}".format(senteval_path)) 59 | if task in ["STS12", "STS13", "STS14", "STS15", "STS16"]: 60 | dataset_path = os.path.join(senteval_path, "downstream/STS/", "{}{}".format(task, "-en-test")) 61 | elif task == "STSBenchmark": 62 | dataset_path = os.path.join(senteval_path, "downstream/STS/", "{}".format(task)) 63 | elif task == "SICKRelatedness": 64 | dataset_path = os.path.join(senteval_path, "downstream/SICK") 65 | print(dataset_path) 66 | data = {} 67 | task_data = task2class[task](dataset_path) 68 | for dset in task_data.datasets: 69 | input1, input2, gs_scores = task_data.data[dset] 70 | data[dset] = (input1, input2, gs_scores) 71 | return data 72 | 73 | def compute_similarity(q0, q0_sim, q1, q1_sim, lmb=0.0): 74 | normalized_q0 = normalize(np.reshape(q0, (1, -1))) 75 | normalized_q1 = normalize(np.reshape(q1, (1, -1))) 76 | add_score, _ = spearmanr(q0_sim, q1_sim) 77 | score = np.sum(np.matmul(normalized_q0, normalized_q1.T)) 78 | score = lmb * score + (1.0 - lmb) * add_score 79 | return score 80 | 81 | def evaluate_retrieval_augmented_promptbert(args, data, batcher, sentence_vecs): 82 | results = {} 83 | all_sys_scores = [] 84 | all_gs_scores = [] 85 | for dset in data: 86 | sys_scores = [] 87 | input1, input2, gs_scores = data[dset] 88 | for ii in range(0, len(gs_scores), args.batch_size): 89 | batch1 = input1[ii:ii + args.batch_size] 90 | batch2 = input2[ii:ii + args.batch_size] 91 | 92 | # we assume get_batch already throws out the faulty ones 93 | if len(batch1) == len(batch2) and len(batch1) > 0: 94 | enc1 = batcher(batch1) 95 | enc2 = batcher(batch2) 96 | sim1 = np.matmul( \ 97 | enc1, sentence_vecs.T \ 98 | ) 99 | sim2 = np.matmul( \ 100 | enc2, sentence_vecs.T \ 101 | ) 102 | 103 | for kk in range(enc1.shape[0]): 104 | 
sys_score = compute_similarity( \ 105 | enc1[kk], sim1[kk], \ 106 | enc2[kk], sim2[kk], \ 107 | args.lmb \ 108 | ) 109 | sys_scores.append(sys_score) 110 | all_sys_scores.extend(sys_scores) 111 | all_gs_scores.extend(gs_scores) 112 | results[dset] = { 113 | 'pearson': pearsonr(sys_scores, gs_scores), 114 | 'spearman': spearmanr(sys_scores, gs_scores), 115 | 'nsamples': len(sys_scores) 116 | } 117 | logging.debug('%s : pearson = %.4f, spearman = %.4f' % 118 | (dset, results[dset]['pearson'][0], 119 | results[dset]['spearman'][0])) 120 | 121 | weights = [results[dset]['nsamples'] for dset in results.keys()] 122 | list_prs = np.array([results[dset]['pearson'][0] for 123 | dset in results.keys()]) 124 | list_spr = np.array([results[dset]['spearman'][0] for 125 | dset in results.keys()]) 126 | 127 | avg_pearson = np.average(list_prs) 128 | avg_spearman = np.average(list_spr) 129 | wavg_pearson = np.average(list_prs, weights=weights) 130 | wavg_spearman = np.average(list_spr, weights=weights) 131 | all_pearson = pearsonr(all_sys_scores, all_gs_scores) 132 | all_spearman = spearmanr(all_sys_scores, all_gs_scores) 133 | results['all'] = {'pearson': {'all': all_pearson[0], 134 | 'mean': avg_pearson, 135 | 'wmean': wavg_pearson}, 136 | 'spearman': {'all': all_spearman[0], 137 | 'mean': avg_spearman, 138 | 'wmean': wavg_spearman}} 139 | logging.debug('ALL : Pearson = %.4f, \ 140 | Spearman = %.4f' % (all_pearson[0], all_spearman[0])) 141 | logging.debug('ALL (weighted average) : Pearson = %.4f, \ 142 | Spearman = %.4f' % (wavg_pearson, wavg_spearman)) 143 | logging.debug('ALL (average) : Pearson = %.4f, \ 144 | Spearman = %.4f\n' % (avg_pearson, avg_spearman)) 145 | results["pred_scores"] = all_sys_scores 146 | results["gs_scores"] = all_gs_scores 147 | return results 148 | 149 | def parse_args(): 150 | parser = argparse.ArgumentParser() 151 | parser.add_argument("--sentence_vecs", type=str, required=True) 152 | parser.add_argument("--senteval_path", type=str, 
default="SentEval/data") 153 | parser.add_argument("--batch_size", type=int, default=32) 154 | parser.add_argument("--lmb", type=float, default=1.0) 155 | 156 | # PromptBERT args 157 | parser.add_argument('--mask_embedding_sentence', action='store_true') 158 | parser.add_argument('--mask_embedding_sentence_template', type=str, default=None) 159 | parser.add_argument("--model_name_or_path", type=str, 160 | help="Transformers' model name or path") 161 | args = parser.parse_args() 162 | return args 163 | 164 | def main(args): 165 | # Load transformers' model checkpoint 166 | model = AutoModel.from_pretrained(args.model_name_or_path) 167 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True) 168 | 169 | device = torch.device("cpu") 170 | if torch.cuda.is_available(): 171 | device = torch.device("cuda") 172 | n_gpu = torch.cuda.device_count() 173 | 174 | model = model.to(device) 175 | encoder = model 176 | if n_gpu > 1: 177 | model = torch.nn.DataParallel(model) 178 | encoder = model.module 179 | model.eval() 180 | 181 | def batcher(batch, max_length=None): 182 | # Handle rare token encoding issues in the dataset 183 | if len(batch) >= 1 and len(batch[0]) >= 1 and isinstance(batch[0][0], bytes): 184 | batch = [[word.decode('utf-8') for word in s] for s in batch] 185 | 186 | sentences = [' '.join(s) for s in batch] 187 | if args.mask_embedding_sentence and args.mask_embedding_sentence_template is not None: 188 | # *cls*_This_sentence_of_"*sent_0*"_means*mask*.*sep+* 189 | template = args.mask_embedding_sentence_template 190 | template = template.replace('*mask*', tokenizer.mask_token )\ 191 | .replace('_', ' ').replace('*sep+*', '')\ 192 | .replace('*cls*', '') 193 | 194 | for i, s in enumerate(sentences): 195 | if len(s) > 0 and s[-1] not in '.?"\'': s += '.' 
196 | sentences[i] = template.replace('*sent 0*', s).strip() 197 | 198 | if max_length is not None: 199 | batch = tokenizer.batch_encode_plus( 200 | sentences, 201 | return_tensors='pt', 202 | padding=True, 203 | max_length=max_length, 204 | truncation=True 205 | ) 206 | else: 207 | batch = tokenizer.batch_encode_plus( 208 | sentences, 209 | return_tensors='pt', 210 | padding=True, 211 | truncation=True 212 | ) 213 | 214 | # Move to the correct device 215 | for k in batch: 216 | batch[k] = batch[k].to(device) if batch[k] is not None else None 217 | 218 | # Get raw embeddings 219 | with torch.no_grad(): 220 | outputs = model(**batch, output_hidden_states=True, return_dict=True) 221 | 222 | try: 223 | pooler_output = outputs.pooler_output 224 | except AttributeError: 225 | pooler_output = outputs['last_hidden_state'][:, 0, :] 226 | if args.mask_embedding_sentence: 227 | last_hidden = outputs.last_hidden_state 228 | pooler_output = last_hidden[batch['input_ids'] == tokenizer.mask_token_id] 229 | else: 230 | last_hidden = outputs.last_hidden_state 231 | hidden_states = outputs.hidden_states 232 | 233 | sentence_embedding = None 234 | if args.mask_embedding_sentence: 235 | sentence_embedding = pooler_output.view(batch['input_ids'].shape[0], -1).cpu() 236 | else: 237 | raise NotImplementedError 238 | 239 | sentence_embedding = normalize(sentence_embedding.numpy()) 240 | 241 | return sentence_embedding 242 | 243 | print("Loading {}".format(args.sentence_vecs)) 244 | sentence_vecs = np.load(args.sentence_vecs) 245 | 246 | # Load benchmark datasets 247 | target_tasks = [ \ 248 | 'STS12', 'STS13', 'STS14', 'STS15', 'STS16', \ 249 | 'STSBenchmark', \ 250 | 'SICKRelatedness' \ 251 | ] 252 | # Reference: https://github.com/facebookresearch/SentEval/blob/main/senteval/sts.py 253 | results = {} 254 | for task in target_tasks: 255 | data = read_benchmark_data(args.senteval_path, task) 256 | result = evaluate_retrieval_augmented_promptbert(args, data, batcher, sentence_vecs) 257 | 
results[task] = result 258 | 259 | task_names = [] 260 | scores = [] 261 | for task in ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark', 'SICKRelatedness']: 262 | task_names.append(task) 263 | if task in results: 264 | if task in ['STS12', 'STS13', 'STS14', 'STS15', 'STS16']: 265 | scores.append("%.2f" % (results[task]['all']['spearman']['all'] * 100)) 266 | else: 267 | scores.append("%.2f" % (results[task]['test']['spearman'].correlation * 100)) 268 | else: 269 | scores.append("0.00") 270 | task_names.append("Avg.") 271 | scores.append("%.2f" % (sum([float(score) for score in scores]) / len(scores))) 272 | print_table(task_names, scores) 273 | 274 | return 0 275 | 276 | if __name__ == "__main__": 277 | args = parse_args() 278 | _ = main(args) 279 | -------------------------------------------------------------------------------- /code/PromptBERT_RankEncoder/prompt_bert_rank_encoder_inference.sh: -------------------------------------------------------------------------------- 1 | OUTPUT_DIR=$PROJECT_DIR/outputs 2 | 3 | SEED=61507 4 | CUDA_VISIBLE_DEVICES=0,1 python prompt_bert_rank_encoder_inference.py \ 5 | --sentence_vecs $OUTPUT_DIR/promptbert/index_vecs/corpus_0.01_prompt_bert_rank_encoder_seed_$SEED.npy \ 6 | --senteval_path SentEval/data \ 7 | --batch_size 256 \ 8 | --model_name_or_path $OUTPUT_DIR/promptbert/checkpoints/prompt_bert_rank_encoder_seed_$SEED \ 9 | --mask_embedding_sentence \ 10 | --mask_embedding_sentence_template "*cls*_This_sentence_of_\"*sent_0*\"_means*mask*.*sep+*" \ 11 | --lmb 0.9 12 | -------------------------------------------------------------------------------- /code/PromptBERT_RankEncoder/promptbert_module.py: -------------------------------------------------------------------------------- 1 | import apex 2 | import re 3 | import sys 4 | import io, os 5 | import faiss 6 | import csv 7 | import math 8 | import json 9 | import torch 10 | import numpy as np 11 | import logging 12 | import tqdm 13 | import time 14 | import 
argparse 15 | from prettytable import PrettyTable 16 | import torch 17 | import transformers 18 | from transformers import AutoModel, AutoTokenizer 19 | from torch.nn.functional import normalize 20 | from tqdm import tqdm 21 | # Set up logger 22 | 23 | class PromptBERTEncoder: 24 | def __init__(self, args, gpu_ids, device): 25 | self.args = args 26 | 27 | # Load transformers' model checkpoint 28 | if self.args.mask_embedding_sentence_org_mlp: 29 | #only for bert-base 30 | from transformers import BertForMaskedLM, BertConfig 31 | self.config = BertConfig.from_pretrained("bert-base-uncased") 32 | self.mlp = BertForMaskedLM.from_pretrained('bert-base-uncased', config=self.config).cls.predictions.transform 33 | if 'result' in self.args.model_name_or_path: 34 | state_dict = torch.load(self.args.model_name_or_path+'/pytorch_model.bin') 35 | new_state_dict = {} 36 | for key, param in state_dict.items(): 37 | # Replace "mlp" to "pooler" 38 | if 'pooler' in key: 39 | key = key.replace("pooler.", "") 40 | new_state_dict[key] = param 41 | self.mlp.load_state_dict(new_state_dict) 42 | self.model = AutoModel.from_pretrained(self.args.model_name_or_path) 43 | self.tokenizer = AutoTokenizer.from_pretrained(self.args.model_name_or_path, use_fast=True) 44 | 45 | if self.args.mask_embedding_sentence_autoprompt: 46 | state_dict = torch.load(self.args.model_name_or_path+'/pytorch_model.bin') 47 | self.p_mbv = state_dict['p_mbv'] 48 | template = self.args.mask_embedding_sentence_template 49 | template = template.replace('*mask*', self.tokenizer.mask_token)\ 50 | .replace('*sep+*', '')\ 51 | .replace('*cls*', '').replace('*sent_0*', ' ').replace('_', ' ') 52 | mask_embedding_template = self.tokenizer.encode(template) 53 | mask_index = mask_embedding_template.index(self.tokenizer.mask_token_id) 54 | index_mbv = mask_embedding_template[1:mask_index] + mask_embedding_template[mask_index+1:-1] 55 | #mask_embedding_template = [ 101, 2023, 6251, 1997, 1000, 1000, 2965, 103, 1012, 102] 56 | 
#index_mbv = mask_embedding_template[1:7] + mask_embedding_template[8:9] 57 | 58 | self.dict_mbv = index_mbv 59 | self.fl_mbv = [i <= 3 for i, k in enumerate(index_mbv)] 60 | 61 | self.device = torch.device(device) 62 | n_gpu = len(gpu_ids) 63 | self.encoder = self.model 64 | 65 | if n_gpu > 1: 66 | self.model = torch.nn.DataParallel(self.model, device_ids=gpu_ids) 67 | self.encoder = self.model.module 68 | self.model.eval() 69 | #self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 70 | 71 | #device = torch.device("cpu") 72 | self.model = self.model.to(self.device) 73 | if self.args.mask_embedding_sentence_org_mlp: 74 | self.mlp = self.mlp.to(self.device) 75 | 76 | if self.args.mask_embedding_sentence_delta: 77 | self.delta, self.template_len = self.get_delta(self.model, self.args.mask_embedding_sentence_template, self.tokenizer, self.device, self.args) 78 | 79 | if self.args.remove_continue_word: 80 | self.pun_remove_set = {'?', '*', '#', '´', '’', '=', '…', '|', '~', '/', '‚', '¿', '–', '»', '-', '€', '‘', '"', '(', '•', '`', '$', ':', '[', '”', '%', '£', '<', '[UNK]', ';', '“', '@', '_', '{', '^', ',', '.', '!', '™', '&', ']', '>', '\\', "'", ')', '+', '—'} 81 | if self.args.model_name_or_path == 'roberta-base': 82 | self.remove_set = {'Ġ.', 'Ġa', 'Ġthe', 'Ġin', 'a', 'Ġ, ', 'Ġis', 'Ġto', 'Ġof', 'Ġand', 'Ġon', 'Ġ\'', 's', '.', 'the', 'Ġman', '-', 'Ġwith', 'Ġfor', 'Ġat', 'Ġwoman', 'Ġare', 'Ġ"', 'Ġthat', 'Ġit', 'Ġdog', 'Ġsaid', 'Ġplaying', 'Ġwas', 'Ġas', 'Ġfrom', 'Ġ:', 'Ġyou', 'Ġan', 'i', 'Ġby'} 83 | else: 84 | self.remove_set = {".", "a", "the", "in", ",", "is", "to", "of", "and", "'", "on", "man", "-", "s", "with", "for", "\"", "at", "##s", "woman", "are", "it", "two", "that", "you", "dog", "said", "playing", "i", "an", "as", "was", "from", ":", "by", "white"} 85 | 86 | self.vocab = self.tokenizer.get_vocab() 87 | 88 | def get_delta(self, model, template, tokenizer, device, args): 89 | model.eval() 90 | 91 | template = 
template.replace('*mask*', tokenizer.mask_token)\ 92 | .replace('*sep+*', '')\ 93 | .replace('*cls*', '').replace('*sent_0*', ' ') 94 | # strip for roberta tokenizer 95 | bs_length = len(tokenizer.encode(template.split(' ')[0].replace('_', ' ').strip())) - 2 + 1 96 | # replace for roberta tokenizer 97 | batch = tokenizer([template.replace('_', ' ').strip().replace(' ', ' ')], return_tensors='pt') 98 | batch['position_ids'] = torch.arange(batch['input_ids'].shape[1]).to(self.device).unsqueeze(0) 99 | for k in batch: 100 | batch[k] = batch[k].repeat(512, 1).to(device) 101 | m_mask = batch['input_ids'] == tokenizer.mask_token_id 102 | 103 | with torch.no_grad(): 104 | outputs = self.model(**batch, output_hidden_states=True, return_dict=True) 105 | last_hidden = outputs.hidden_states[-1] 106 | delta = last_hidden[m_mask] 107 | delta.requires_grad = False 108 | #import pdb;pdb.set_trace() 109 | template_len = batch['input_ids'].shape[1] 110 | return delta, template_len 111 | 112 | def batcher(self, batch, max_length=None): 113 | sentences = batch 114 | if self.args.mask_embedding_sentence and self.args.mask_embedding_sentence_template is not None: 115 | # *cls*_This_sentence_of_"*sent_0*"_means*mask*.*sep+* 116 | template = self.args.mask_embedding_sentence_template 117 | template = template.replace('*mask*', self.tokenizer.mask_token )\ 118 | .replace('_', ' ').replace('*sep+*', '')\ 119 | .replace('*cls*', '') 120 | template_tokens = self.tokenizer.encode(template.replace('*sent 0*', ""), add_special_tokens=False) 121 | len_template_tokens = len(template_tokens) 122 | 123 | for i, s in enumerate(sentences): 124 | if len(s) > 0 and s[-1] not in '.?"\'': s += '.' 
125 | s_tokens = self.tokenizer.encode(s, add_special_tokens=False) 126 | limit = 512 - 2 - len_template_tokens 127 | if limit < len(s_tokens): 128 | s_tokens = s_tokens[:limit] 129 | s = self.tokenizer.decode(s_tokens) 130 | sentences[i] = template.replace('*sent 0*', s).strip() 131 | elif self.args.remove_continue_word: 132 | for i, s in enumerate(sentences): 133 | sentences[i] = ' ' if self.args.model_name_or_path == 'roberta-base' else '' 134 | es = self.tokenizer.encode(' ' + s, add_special_tokens=False) 135 | for iw, w in enumerate(self.tokenizer.convert_ids_to_tokens(es)): 136 | if self.args.model_name_or_path == 'roberta-base': 137 | # roberta base 138 | if 'Ġ' not in w or w in self.remove_set: 139 | pass 140 | else: 141 | if re.search('[a-zA-Z0-9]', w) is not None: 142 | sentences[i] += w.replace('Ġ', '').lower() + ' ' 143 | elif w not in self.remove_set and w not in self.pun_remove_set and '##' not in w: 144 | # bert base 145 | sentences[i] += w.lower() + ' ' 146 | if len(sentences[i]) == 0: sentences[i] = '[PAD]' 147 | 148 | if max_length is not None: 149 | batch = self.tokenizer.batch_encode_plus( 150 | sentences, 151 | return_tensors='pt', 152 | padding=True, 153 | max_length=max_length, 154 | truncation=True 155 | ) 156 | else: 157 | batch = self.tokenizer.batch_encode_plus( 158 | sentences, 159 | return_tensors='pt', 160 | padding=True, 161 | truncation=True 162 | ) 163 | 164 | # Move to the correct device 165 | for k in batch: 166 | batch[k] = batch[k].to(self.device) if batch[k] is not None else None 167 | 168 | # Get raw embeddings 169 | with torch.no_grad(): 170 | if self.args.embedding_only: 171 | hidden_states = None 172 | pooler_output = None 173 | last_hidden = self.encoder.embeddings.word_embeddings(batch['input_ids']) 174 | position_ids = self.encoder.embeddings.position_ids[:, 0 : last_hidden.shape[1]] 175 | token_type_ids = torch.zeros(batch['input_ids'].shape, dtype=torch.long, 176 | device=self.encoder.embeddings.position_ids.device) 
177 | 178 | position_embeddings = self.encoder.embeddings.position_embeddings(position_ids) 179 | token_type_embeddings = self.encoder.embeddings.token_type_embeddings(token_type_ids) 180 | 181 | if self.args.remove_continue_word: 182 | batch['attention_mask'][batch['input_ids'] == self.tokenizer.cls_token_id] = 0 183 | batch['attention_mask'][batch['input_ids'] == self.tokenizer.sep_token_id] = 0 184 | elif self.args.mask_embedding_sentence_autoprompt: 185 | input_ids = batch['input_ids'] 186 | inputs_embeds = self.encoder.embeddings.word_embeddings(input_ids) 187 | p = torch.arange(input_ids.shape[1]).to(input_ids.device).view(1, -1) 188 | b = torch.arange(input_ids.shape[0]).to(input_ids.device) 189 | for i, k in enumerate(self.dict_mbv): 190 | if self.fl_mbv[i]: 191 | index = ((input_ids == k) * p).max(-1)[1] 192 | else: 193 | index = ((input_ids == k) * -p).min(-1)[1] 194 | inputs_embeds[b, index] = self.p_mbv[i] 195 | batch['input_ids'], batch['inputs_embeds'] = None, inputs_embeds 196 | outputs = self.model(**batch, output_hidden_states=True, return_dict=True) 197 | batch['input_ids'] = input_ids 198 | 199 | last_hidden = outputs.last_hidden_state 200 | pooler_output = last_hidden[input_ids == self.tokenizer.mask_token_id] 201 | 202 | if self.args.mask_embedding_sentence_org_mlp: 203 | pooler_output = self.mlp(pooler_output) 204 | if self.args.mask_embedding_sentence_delta: 205 | blen = batch['attention_mask'].sum(-1) - self.template_len 206 | if self.args.mask_embedding_sentence_org_mlp: 207 | pooler_output -= self.mlp(self.delta[blen]) 208 | else: 209 | pooler_output -= self.delta[blen] 210 | if self.args.mask_embedding_sentence_use_pooler: 211 | pooler_output = self.encoder.pooler.dense(pooler_output) 212 | pooler_output = self.encoder.pooler.activation(pooler_output) 213 | 214 | else: 215 | outputs = self.model(**batch, output_hidden_states=True, return_dict=True) 216 | 217 | try: 218 | pooler_output = outputs.pooler_output 219 | except AttributeError: 
220 | pooler_output = outputs['last_hidden_state'][:, 0, :] 221 | if self.args.mask_embedding_sentence: 222 | last_hidden = outputs.last_hidden_state 223 | pooler_output = last_hidden[batch['input_ids'] == self.tokenizer.mask_token_id] 224 | if self.args.mask_embedding_sentence_org_mlp: 225 | pooler_output = self.mlp(pooler_output) 226 | if self.args.mask_embedding_sentence_delta: 227 | blen = batch['attention_mask'].sum(-1) - self.template_len 228 | if self.args.mask_embedding_sentence_org_mlp: 229 | pooler_output -= self.mlp(self.delta[blen]) 230 | else: 231 | pooler_output -= self.delta[blen] 232 | if self.args.mask_embedding_sentence_use_org_pooler: 233 | pooler_output = self.mlp(pooler_output) 234 | if self.args.mask_embedding_sentence_use_pooler: 235 | pooler_output = self.encoder.pooler.dense(pooler_output) 236 | pooler_output = self.encoder.pooler.activation(pooler_output) 237 | else: 238 | last_hidden = outputs.last_hidden_state 239 | hidden_states = outputs.hidden_states 240 | 241 | sentence_embedding = None 242 | # Apply different pooler 243 | if self.args.mask_embedding_sentence: 244 | sentence_embedding = pooler_output.view(batch['input_ids'].shape[0], -1) 245 | elif self.args.pooler == 'cls': 246 | # There is a linear+activation layer after CLS representation 247 | sentence_embedding = pooler_output 248 | elif self.args.pooler == 'cls_before_pooler': 249 | batch['input_ids'][(batch['input_ids'] == 0) | (batch['input_ids'] == 101) | (batch['input_ids'] == 102)] = batch['input_ids'].max() 250 | index = batch['input_ids'].topk(3, dim=-1, largest=False)[1] 251 | index2 = torch.arange(batch['input_ids'].shape[0]).to(index.device) 252 | r = last_hidden[index2, index[:, 0], :] 253 | for i in range(1, 3): 254 | r += last_hidden[index2, index[:, i], :] 255 | sentence_embedding = (r/3) 256 | elif self.args.pooler == "avg": 257 | sentence_embedding = ( \ 258 | (last_hidden * batch['attention_mask'].unsqueeze(-1)).sum(1) / 
batch['attention_mask'].sum(-1).unsqueeze(-1) \ 259 | ) 260 | elif self.args.pooler == "avg_first_last": 261 | first_hidden = hidden_states[0] 262 | last_hidden = hidden_states[-1] 263 | pooled_result = ((first_hidden + last_hidden) / 2.0 * batch['attention_mask'].unsqueeze(-1)).sum(1) / batch['attention_mask'].sum(-1).unsqueeze(-1) 264 | sentence_embedding = pooled_result 265 | elif self.args.pooler == "avg_top2": 266 | second_last_hidden = hidden_states[-2] 267 | last_hidden = hidden_states[-1] 268 | pooled_result = ((last_hidden + second_last_hidden) / 2.0 * batch['attention_mask'].unsqueeze(-1)).sum(1) / batch['attention_mask'].sum(-1).unsqueeze(-1) 269 | sentence_embedding = pooled_result 270 | else: 271 | raise NotImplementedError 272 | 273 | # Retriever encoding 274 | #retriever_embedding = normalize(sentence_embedding) 275 | 276 | return sentence_embedding 277 | 278 | def embed(self, sentences): 279 | sentence_vectors = [] 280 | for i in range(0, len(sentences), self.args.batch_size): 281 | batch = sentences[i:i+self.args.batch_size] 282 | vectors = self.batcher(batch) 283 | assert vectors.size()[1] == self.encoder.config.hidden_size 284 | sentence_vectors.append(vectors) 285 | sentence_vectors = torch.cat(sentence_vectors, dim=0) 286 | return sentence_vectors 287 | -------------------------------------------------------------------------------- /code/PromptBERT_RankEncoder/run_prompt_bert.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATA_DIR=$PROJECT_DIR/data 4 | OUTPUT_DIR=$PROJECT_DIR/outputs 5 | SEED=61507 6 | 7 | GPU=0 8 | ES=125 # --eval_steps 9 | BMETRIC=stsb_spearman # --metric_for_best_model 10 | TRAIN_FILE=$DATA_DIR/corpus/corpus.txt 11 | 12 | args=() # flags for training 13 | eargs=() # flags for evaluation 14 | 15 | BC=(python train_prompt_bert.py) 16 | BATCH=256 17 | EPOCH=1 18 | LR=1e-5 19 | TEMPLATE="*cls*_This_sentence_of_\"*sent_0*\"_means*mask*.*sep+*" 20 | 
TEMPLATE2="*cls*_This_sentence_:_\"*sent_0*\"_means*mask*.*sep+*" 21 | MODEL=bert-base-uncased 22 | args=(--mlp_only_train --mask_embedding_sentence\ 23 | --mask_embedding_sentence_template $TEMPLATE\ 24 | --mask_embedding_sentence_different_template $TEMPLATE2\ 25 | --mask_embedding_sentence_delta\ 26 | --seed $SEED ) 27 | eargs=(--mask_embedding_sentence \ 28 | --mask_embedding_sentence_template $TEMPLATE ) 29 | 30 | CHECKPOINT=$OUTPUT_DIR/promptbert/checkpoints/prompt_bert_seed_$SEED 31 | CUDA_VISIBLE_DEVICES=$GPU ${BC[@]}\ 32 | --model_name_or_path $MODEL\ 33 | --train_file $TRAIN_FILE\ 34 | --output_dir $CHECKPOINT\ 35 | --num_train_epochs $EPOCH\ 36 | --per_device_train_batch_size $BATCH \ 37 | --learning_rate $LR \ 38 | --max_seq_length 32\ 39 | --evaluation_strategy steps\ 40 | --metric_for_best_model $BMETRIC\ 41 | --load_best_model_at_end\ 42 | --eval_steps $ES\ 43 | --overwrite_output_dir\ 44 | --temp 0.05\ 45 | --do_train\ 46 | --fp16\ 47 | --preprocessing_num_workers 10\ 48 | ${args[@]} 49 | 50 | CUDA_VISIBLE_DEVICES=$GPU python evaluation.py \ 51 | --model_name_or_path $CHECKPOINT \ 52 | --pooler avg\ 53 | --mode test\ 54 | ${eargs[@]} 55 | -------------------------------------------------------------------------------- /code/PromptBERT_RankEncoder/run_prompt_bert_rank_encoder.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | DATA_DIR=$PROJECT_DIR/data 3 | OUTPUT_DIR=$PROJECT_DIR/outputs 4 | 5 | SEED=61507 6 | GPU=0,1 7 | ES=125 # --eval_steps 8 | BMETRIC=avg_sts # --metric_for_best_model 9 | TRAIN_FILE=$DATA_DIR/corpus/corpus.txt 10 | 11 | args=() # flags for training 12 | eargs=() # flags for evaluation 13 | 14 | BC=(python train_prompt_bert_rank_encoder.py) 15 | BATCH=256 16 | EPOCH=1 17 | LR=1e-5 18 | TEMPLATE="*cls*_This_sentence_of_\"*sent_0*\"_means*mask*.*sep+*" 19 | TEMPLATE2="*cls*_This_sentence_:_\"*sent_0*\"_means*mask*.*sep+*" 20 | MODEL=bert-base-uncased 21 | args=(--mlp_only_train 
--mask_embedding_sentence\ 22 | --mask_embedding_sentence_template $TEMPLATE\ 23 | --mask_embedding_sentence_different_template $TEMPLATE2\ 24 | --mask_embedding_sentence_delta\ 25 | --seed $SEED \ 26 | --baseE_model_name_or_checkpoint $OUTPUT_DIR/promptbert/checkpoints/prompt_bert_seed_$SEED \ 27 | --corpus_vecs $OUTPUT_DIR/promptbert/index_vecs/corpus_0.01_prompt_bert_seed_$SEED.npy \ 28 | --baseE_lmb 0.05 \ 29 | --baseE_sim_thresh_low 0.5 \ 30 | --baseE_sim_thresh_upp 0.8 \ 31 | --loss_type hinge \ 32 | --mask_embedding_sentence_delta_no_delta_eval ) 33 | eargs=(--mask_embedding_sentence \ 34 | --mask_embedding_sentence_template $TEMPLATE ) 35 | 36 | CHECKPOINT=$OUTPUT_DIR/promptbert/checkpoints/prompt_bert_rank_encoder_seed_$SEED 37 | CUDA_VISIBLE_DEVICES=$GPU ${BC[@]}\ 38 | --model_name_or_path $MODEL\ 39 | --train_file $TRAIN_FILE\ 40 | --output_dir $CHECKPOINT\ 41 | --num_train_epochs $EPOCH\ 42 | --per_device_train_batch_size $BATCH \ 43 | --learning_rate $LR \ 44 | --max_seq_length 32\ 45 | --evaluation_strategy steps\ 46 | --metric_for_best_model $BMETRIC\ 47 | --load_best_model_at_end\ 48 | --eval_steps $ES \ 49 | --overwrite_output_dir \ 50 | --temp 0.05\ 51 | --do_train\ 52 | --fp16\ 53 | --preprocessing_num_workers 10\ 54 | ${args[@]} 55 | 56 | CUDA_VISIBLE_DEVICES=$GPU python evaluation.py \ 57 | --model_name_or_path $CHECKPOINT \ 58 | --pooler avg\ 59 | --mode test\ 60 | ${eargs[@]} 61 | -------------------------------------------------------------------------------- /code/SNCSE_RankEncoder/README.md: -------------------------------------------------------------------------------- 1 | # Getting Started 2 | 3 | ### Requirements 4 | You need at least two GPUs to follow the instructions below. 5 | 6 | ### Setting 7 | 1. Set the project directory 8 | ```bash 9 | export PROJECT_DIR=/path/to/this/project/folder 10 | ``` 11 | Note that the path must not end with "/", e.g., /home/RankEncoder. 12 | 13 | 2. 
Download the SentEval folder from https://github.com/princeton-nlp/SimCSE and place it in code/SNCSE\_RankEncoder/ 14 | 3. Go to SentEval/data/downstream and execute the following command 15 | ```bash 16 | bash download_dataset.sh 17 | ``` 18 | 4. Download the Wiki1m dataset with the following commands and place it at data/corpus/corpus.txt 19 | ```bash 20 | wget https://huggingface.co/datasets/princeton-nlp/datasets-for-simcse/resolve/main/wiki1m_for_simcse.txt 21 | mv wiki1m_for_simcse.txt ../../data/corpus/corpus.txt 22 | ``` 23 | 5. Go to code/file\_utils/ and execute the following command. 24 | ```bash 25 | bash random_sampling_sentences.sh 26 | ``` 27 | 28 | ### Training the base encoder ([SNCSE](https://arxiv.org/abs/2201.05979)) 29 | We recommend using the SNCSE checkpoint provided by the authors. Please follow the instructions below. 30 | 1. Download the SNCSE-bert-base-uncased.zip file from [this link](https://drive.google.com/drive/folders/1w2srzbtTMLlaxUx-7ETV9vQWdw_6lVuN?usp=sharing) 31 | 2. Unzip this file and place the SNCSE-bert-base-uncased folder in $PROJECT\_DIR/outputs/sncse/checkpoints/ 32 | 33 | If you want to train SNCSE from scratch, please see the details at [this link](https://github.com/Sense-GVT/SNCSE). 34 | 35 | ### Training RankEncoder 36 | 1. Download the following spaCy model 37 | ```bash 38 | python -m spacy download en_core_web_sm 39 | ``` 40 | 2. Generate soft negative samples (please see details in the [SNCSE paper](https://arxiv.org/abs/2201.05979)) 41 | ```bash 42 | bash generate_soft_negative_samples.sh 43 | ``` 44 | 3. Get sentence vectors with the base encoder 45 | ```bash 46 | bash get_sncse_embedding.sh 47 | ``` 48 | 4. 
Train RankEncoder 49 | ```bash 50 | cd ./sncse_rank_encoder/ 51 | bash sncse_rank_encoder.sh 52 | ``` 53 | We provide the checkpoint of the trained RankEncoder-SNCSE [here](https://drive.google.com/file/d/1YSxcTl6bVXqkC2oARq9MyH_9RdBT9G-d/view?usp=share_link) 54 | 55 | ### Evaluation 56 | The following command reports the performance of RankEncoder (without Eq. 7 in [our paper](https://arxiv.org/pdf/2209.04333.pdf)) 57 | ```bash 58 | bash evaluation_sncse_rank_encoder.sh 59 | ``` 60 | 61 | The following commands compute the performance of RankEncoder (with Eq. 7 in [our paper](https://arxiv.org/pdf/2209.04333.pdf)) 62 | 63 | ```bash 64 | bash get_sncse_rank_encoder_embedding.sh 65 | bash sncse_rank_encoder_inference.sh 66 | ``` 67 | Note that we only sample 10,000 sentences for computational efficiency. Please use 100,000 sentences to replicate our experimental results. 68 | -------------------------------------------------------------------------------- /code/SNCSE_RankEncoder/evaluation.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import io, os 3 | import numpy as np 4 | import logging 5 | import argparse 6 | from prettytable import PrettyTable 7 | import string 8 | import torch 9 | import transformers 10 | from transformers import AutoModel, AutoTokenizer 11 | from transformers import BertTokenizer, BertModel 12 | 13 | # Set up logger 14 | logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG) 15 | 16 | # Set PATHs 17 | PATH_TO_SENTEVAL = './SentEval' 18 | PATH_TO_DATA = './SentEval/data' 19 | 20 | # Import SentEval 21 | sys.path.insert(0, PATH_TO_SENTEVAL) 22 | import senteval 23 | 24 | PUNCTUATION = list(string.punctuation) 25 | 26 | def print_table(task_names, scores): 27 | tb = PrettyTable() 28 | tb.field_names = task_names 29 | tb.add_row(scores) 30 | print(tb) 31 | 32 | def main(): 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument("--model_name_or_path", 
type=str, 35 | help="Transformers' model name or path") 36 | parser.add_argument("--mode", type=str, 37 | choices=['dev', 'test', 'fasttest'], 38 | default='test', 39 | help="What evaluation mode to use (dev: fast mode, dev results; test: full mode, test results; fasttest: fast mode, test results)") 40 | parser.add_argument("--task_set", type=str, 41 | choices=['sts', 'transfer', 'full', 'na'], 42 | default='sts', 43 | help="What set of tasks to evaluate on. If not 'na', this will override '--tasks'") 44 | parser.add_argument("--tasks", type=str, nargs='+', 45 | default=['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 46 | 'MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'TREC', 'MRPC', 47 | 'SICKRelatedness', 'STSBenchmark'], 48 | help="Tasks to evaluate on. If '--task_set' is specified, this will be overridden") 49 | 50 | args = parser.parse_args() 51 | 52 | # Load transformers' model checkpoint 53 | model = BertModel.from_pretrained(args.model_name_or_path) 54 | tokenizer = BertTokenizer(vocab_file=os.path.join(args.model_name_or_path, "vocab.txt")) 55 | temp = {"mask_token": tokenizer.mask_token} 56 | tokenizer.add_special_tokens(temp) 57 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 58 | model = model.to(device) 59 | 60 | # Set up the tasks 61 | if args.task_set == 'sts': 62 | args.tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark', 'SICKRelatedness'] 63 | elif args.task_set == 'transfer': 64 | args.tasks = ['MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'TREC', 'MRPC'] 65 | elif args.task_set == 'full': 66 | args.tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark', 'SICKRelatedness'] 67 | args.tasks += ['MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'TREC', 'MRPC'] 68 | 69 | # Set params for SentEval 70 | if args.mode == 'dev' or args.mode == 'fasttest': 71 | # Fast mode 72 | params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5} 73 | params['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128, 74 | 
'tenacity': 3, 'epoch_size': 2} 75 | elif args.mode == 'test': 76 | # Full mode 77 | params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 10} 78 | params['classifier'] = {'nhid': 0, 'optim': 'adam', 'batch_size': 64, 79 | 'tenacity': 5, 'epoch_size': 4} 80 | else: 81 | raise NotImplementedError 82 | 83 | # SentEval prepare and batcher 84 | def prepare(params, samples): 85 | return 86 | 87 | def batcher(params, batch, max_length=None): 88 | # Handle rare token encoding issues in the dataset 89 | if len(batch) >= 1 and len(batch[0]) >= 1 and isinstance(batch[0][0], bytes): 90 | batch = [[word.decode('utf-8') for word in s] for s in batch] 91 | 92 | sentences = [' '.join(s) for s in batch] 93 | sentences = [ \ 94 | s + " ." if s.strip()[-1] not in PUNCTUATION else s \ 95 | for s in sentences \ 96 | ] 97 | sentences = [ \ 98 | '''This sentence : " ''' + s + ''' " means [MASK] .''' \ 99 | for s in sentences \ 100 | ] 101 | 102 | # Tokenization 103 | if max_length is not None: 104 | batch = tokenizer.batch_encode_plus( 105 | sentences, 106 | return_tensors='pt', 107 | padding=True, 108 | max_length=max_length, 109 | truncation=True 110 | ) 111 | else: 112 | batch = tokenizer.batch_encode_plus( 113 | sentences, 114 | return_tensors='pt', 115 | padding=True, 116 | ) 117 | 118 | # Move to the correct device 119 | for k in batch: 120 | batch[k] = batch[k].to(device) 121 | 122 | # Get raw embeddings 123 | with torch.no_grad(): 124 | outputs = model(**batch, output_hidden_states=True, return_dict=True) 125 | last_hidden = outputs.last_hidden_state 126 | sent_vecs = last_hidden[batch["input_ids"] == tokenizer.mask_token_id].cpu() 127 | return sent_vecs 128 | 129 | results = {} 130 | 131 | for task in args.tasks: 132 | se = senteval.engine.SE(params, batcher, prepare) 133 | result = se.eval(task) 134 | results[task] = result 135 | 136 | # Print evaluation results 137 | if args.mode == 'dev': 138 | print("------ %s ------" % (args.mode)) 139 | 140 | task_names = [] 
141 | scores = [] 142 | for task in ['STSBenchmark', 'SICKRelatedness']: 143 | task_names.append(task) 144 | if task in results: 145 | scores.append("%.2f" % (results[task]['dev']['spearman'][0] * 100)) 146 | else: 147 | scores.append("0.00") 148 | print_table(task_names, scores) 149 | 150 | task_names = [] 151 | scores = [] 152 | for task in ['MR', 'CR', 'SUBJ', 'MPQA', 'SST2', 'TREC', 'MRPC']: 153 | task_names.append(task) 154 | if task in results: 155 | scores.append("%.2f" % (results[task]['devacc'])) 156 | else: 157 | scores.append("0.00") 158 | task_names.append("Avg.") 159 | scores.append("%.2f" % (sum([float(score) for score in scores]) / len(scores))) 160 | print_table(task_names, scores) 161 | 162 | elif args.mode == 'test' or args.mode == 'fasttest': 163 | print("------ %s ------" % (args.mode)) 164 | 165 | task_names = [] 166 | scores = [] 167 | for task in ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark', 'SICKRelatedness']: 168 | task_names.append(task) 169 | if task in results: 170 | if task in ['STS12', 'STS13', 'STS14', 'STS15', 'STS16']: 171 | scores.append("%.2f" % (results[task]['all']['spearman']['all'] * 100)) 172 | else: 173 | scores.append("%.2f" % (results[task]['test']['spearman'].correlation * 100)) 174 | else: 175 | scores.append("0.00") 176 | task_names.append("Avg.") 177 | scores.append("%.2f" % (sum([float(score) for score in scores]) / len(scores))) 178 | print_table(task_names, scores) 179 | 180 | task_names = [] 181 | scores = [] 182 | for task in ['MR', 'CR', 'SUBJ', 'MPQA', 'SST2', 'TREC', 'MRPC']: 183 | task_names.append(task) 184 | if task in results: 185 | scores.append("%.2f" % (results[task]['acc'])) 186 | else: 187 | scores.append("0.00") 188 | task_names.append("Avg.") 189 | scores.append("%.2f" % (sum([float(score) for score in scores]) / len(scores))) 190 | print_table(task_names, scores) 191 | 192 | 193 | if __name__ == "__main__": 194 | main() 195 | 
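The `batcher` in `evaluation.py` above wraps each sentence in a prompt and reads the sentence vector from the `[MASK]` position. A minimal sketch of only the string-wrapping step is given below. `wrap_with_prompt` is a hypothetical helper name that does not exist in the repository, and the empty-sentence guard is an added assumption: the original expression `s.strip()[-1]` would raise `IndexError` on an empty string.

```python
import string

PUNCTUATION = set(string.punctuation)


def wrap_with_prompt(sentence: str, mask_token: str = "[MASK]") -> str:
    """Wrap a sentence in the prompt format used by batcher().

    Hypothetical helper for illustration; not part of the repository.
    """
    stripped = sentence.strip()
    # Append " ." when the sentence is empty or does not already end
    # with a punctuation character (assumption: empty input is allowed).
    if not stripped or stripped[-1] not in PUNCTUATION:
        sentence = sentence + " ."
    return 'This sentence : " ' + sentence + ' " means ' + mask_token + " ."


if __name__ == "__main__":
    print(wrap_with_prompt("A cat sat"))
    print(wrap_with_prompt("Hello!"))
```

The same wrapping is repeated in `get_embedding.py`, so a shared helper of this kind could keep the two prompts in sync.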
-------------------------------------------------------------------------------- /code/SNCSE_RankEncoder/evaluation_sncse.sh: -------------------------------------------------------------------------------- 1 | OUTPUT_DIR=$PROJECT_DIR/outputs 2 | 3 | SEED=61507 4 | python evaluation.py \ 5 | --model_name_or_path $OUTPUT_DIR/sncse/checkpoints/SNCSE-bert-base-uncased \ 6 | --task_set sts \ 7 | --mode test 8 | -------------------------------------------------------------------------------- /code/SNCSE_RankEncoder/evaluation_sncse_rank_encoder.sh: -------------------------------------------------------------------------------- 1 | OUTPUT_DIR=$PROJECT_DIR/outputs 2 | 3 | SEED=61507 4 | python evaluation.py \ 5 | --model_name_or_path $OUTPUT_DIR/sncse/checkpoints/sncse_rank_encoder_seed_$SEED \ 6 | --task_set sts \ 7 | --mode test 8 | -------------------------------------------------------------------------------- /code/SNCSE_RankEncoder/generate_soft_negative_samples.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import en_core_web_sm 4 | 5 | 6 | special_words = ["am", "is", "was", "are", "were", "can", "could", "will", 7 | "would", "shall", "should", "may", "must", "might"] 8 | 9 | 10 | def convert_to_negation(parser, sentence): 11 | 12 | parsered_sentence = parser(sentence) 13 | tokens = [str(_) for _ in parsered_sentence] 14 | deps = [_.dep_ for _ in parsered_sentence] 15 | tags = [_.tag_ for _ in parsered_sentence] 16 | lemmas = [_.lemma_ for _ in parsered_sentence] 17 | 18 | if "not" in tokens: 19 | index = tokens.index("not") 20 | del tokens[index] 21 | sentence_negation = " ".join(tokens) 22 | return sentence_negation 23 | 24 | flag = 0 25 | for dep in deps: 26 | if dep == "aux" or dep == "auxpass": 27 | flag = 1 28 | break 29 | if dep == "ROOT": 30 | flag = 2 31 | break 32 | 33 | if flag == 1: 34 | for i, dep in enumerate(deps): 35 | if dep == "aux" or dep == "auxpass": 36 | tokens[i] 
+= " not" 37 | break 38 | elif flag == 2: 39 | index = deps.index("ROOT") 40 | if tokens[index].lower() in special_words: 41 | tokens[index] += " not" 42 | elif tags[index] == "VBP": 43 | tokens[index] = "do not " + lemmas[index] 44 | elif tags[index] == "VBZ": 45 | tokens[index] = "does not " + lemmas[index] 46 | elif tags[index] == "VBD": 47 | tokens[index] = "did not " + lemmas[index] 48 | else: 49 | tokens.insert(0, "Not") 50 | else: 51 | tokens.insert(0, "Not") 52 | 53 | sentence_negation = " ".join(tokens) 54 | 55 | return sentence_negation 56 | 57 | def parse_args(): 58 | parser = argparse.ArgumentParser() 59 | parser.add_argument("--in_file", type=str) 60 | parser.add_argument("--out_file", type=str) 61 | args = parser.parse_args() 62 | return args 63 | 64 | if __name__ == "__main__": 65 | args = parse_args() 66 | parser = en_core_web_sm.load() 67 | 68 | in_file = args.in_file 69 | 70 | out_file = args.out_file 71 | 72 | f = open(in_file) 73 | 74 | f1 = open(out_file, "w") 75 | 76 | for line in f: 77 | sentence = line.strip() 78 | negation = convert_to_negation(parser=parser, sentence=sentence) 79 | temp = [sentence, negation] 80 | f1.write(json.dumps(temp) + "\n") 81 | 82 | f.close() 83 | f1.close() 84 | -------------------------------------------------------------------------------- /code/SNCSE_RankEncoder/generate_soft_negative_samples.sh: -------------------------------------------------------------------------------- 1 | DATA_DIR=$PROJECT_DIR/data 2 | OUTPUT_DIR=$PROJECT_DIR/outputs 3 | 4 | python generate_soft_negative_samples.py \ 5 | --in_file $DATA_DIR/corpus/corpus.txt \ 6 | --out_file $OUTPUT_DIR/sncse/soft_negative_samples.txt 7 | -------------------------------------------------------------------------------- /code/SNCSE_RankEncoder/get_embedding.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import numpy as np 5 | import torch 6 | from 
scipy.spatial.distance import cosine 7 | from transformers import BertTokenizer, BertModel 8 | from tqdm import tqdm 9 | from scipy.stats import spearmanr 10 | import torch.nn as nn 11 | import string 12 | from tqdm import tqdm 13 | 14 | PUNCTUATION = list(string.punctuation) 15 | 16 | def calculate_vectors(tokenizer, model, texts): 17 | 18 | inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt") 19 | 20 | for _ in inputs: 21 | inputs[_] = inputs[_].cuda() 22 | 23 | temp = inputs["input_ids"] 24 | 25 | # Get the embeddings 26 | with torch.no_grad(): 27 | embeddings = model(**inputs, output_hidden_states=True, return_dict=True).last_hidden_state.cpu() 28 | 29 | embeddings = embeddings[temp == tokenizer.mask_token_id] 30 | 31 | embeddings = embeddings.numpy() 32 | 33 | return embeddings 34 | 35 | def parse_args(): 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument("--checkpoint", type=str, required=True) 38 | parser.add_argument("--corpus_file", type=str, required=True) 39 | parser.add_argument("--sentence_vectors_np_file", type=str, required=True) 40 | args = parser.parse_args() 41 | return args 42 | 43 | if __name__ == "__main__": 44 | args = parse_args() 45 | tokenizer = BertTokenizer(vocab_file=os.path.join(args.checkpoint, "vocab.txt")) 46 | 47 | temp = {"mask_token": tokenizer.mask_token} 48 | tokenizer.add_special_tokens(temp) 49 | 50 | model = BertModel.from_pretrained(args.checkpoint).cuda() 51 | device = torch.device("cpu") 52 | if torch.cuda.is_available(): 53 | device = torch.device("cuda") 54 | n_gpu = torch.cuda.device_count() 55 | 56 | if n_gpu > 1: 57 | model = torch.nn.DataParallel(model) 58 | model.eval() 59 | 60 | #device = torch.device("cpu") 61 | model = model.to(device) 62 | 63 | batch_size = 128 64 | 65 | with open(args.corpus_file, "r") as f: 66 | sentences = f.readlines() 67 | 68 | outputs = [] 69 | for i in tqdm(range(0, len(sentences), batch_size), desc="Computing..."): 70 | batch_sentences = 
sentences[i:i+batch_size] 71 | batch = [] 72 | for line in batch_sentences: 73 | text = line.strip() 74 | text = text + " ." if text.strip()[-1] not in PUNCTUATION else text 75 | text = '''This sentence : " ''' + text + ''' " means [MASK] .''' 76 | batch.append(text) 77 | vectors = calculate_vectors(tokenizer=tokenizer, model=model, texts=batch) 78 | outputs.append(vectors) 79 | outputs = np.concatenate(outputs, axis=0) 80 | 81 | os.makedirs(os.path.dirname(args.sentence_vectors_np_file), exist_ok=True) 82 | with open(args.sentence_vectors_np_file, "wb") as f: 83 | np.save(f, outputs) 84 | -------------------------------------------------------------------------------- /code/SNCSE_RankEncoder/get_sncse_embedding.sh: -------------------------------------------------------------------------------- 1 | OUTPUT_DIR=$PROJECT_DIR/outputs 2 | 3 | SEED=61507 4 | CUDA_VISIBLE_DEVICES=0,1 python get_embedding.py \ 5 | --checkpoint $OUTPUT_DIR/sncse/checkpoints/SNCSE-bert-base-uncased \ 6 | --corpus_file $OUTPUT_DIR/corpus/corpus_0.01.txt \ 7 | --sentence_vectors_np_file $OUTPUT_DIR/sncse/index_vecs/corpus_0.01_sncse.npy 8 | -------------------------------------------------------------------------------- /code/SNCSE_RankEncoder/get_sncse_rank_encoder_embedding.sh: -------------------------------------------------------------------------------- 1 | OUTPUT_DIR=$PROJECT_DIR/outputs 2 | 3 | SEED=61507 4 | CUDA_VISIBLE_DEVICES=0,1 python get_embedding.py \ 5 | --checkpoint $OUTPUT_DIR/sncse/checkpoints/sncse_rank_encoder_seed_$SEED \ 6 | --corpus_file $OUTPUT_DIR/corpus/corpus_0.01.txt \ 7 | --sentence_vectors_np_file $OUTPUT_DIR/sncse/index_vecs/corpus_0.01_sncse_rank_encoder_seed_$SEED.npy 8 | -------------------------------------------------------------------------------- /code/SNCSE_RankEncoder/sncse_rank_encoder/simcse/__init__.py: -------------------------------------------------------------------------------- 1 | from .tool import SimCSE 2 | 
-------------------------------------------------------------------------------- /code/SNCSE_RankEncoder/sncse_rank_encoder/simcse/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel, RobertaModel, RobertaLMHead 4 | from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertModel, BertLMPredictionHead 5 | 6 | from transformers.modeling_outputs import SequenceClassifierOutput 7 | 8 | 9 | class MLPLayer(nn.Module): 10 | """ 11 | Head for getting sentence representations over RoBERTa/BERT's CLS representation. 12 | """ 13 | 14 | def __init__(self, config): 15 | super().__init__() 16 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 17 | self.activation = nn.Tanh() 18 | 19 | def forward(self, features, **kwargs): 20 | x = self.dense(features) 21 | x = self.activation(x) 22 | 23 | return x 24 | 25 | 26 | class Similarity(nn.Module): 27 | """ 28 | Dot product or cosine similarity 29 | """ 30 | 31 | def __init__(self, temp): 32 | super().__init__() 33 | self.temp = temp 34 | self.cos = nn.CosineSimilarity(dim=-1) 35 | 36 | def forward(self, x, y): 37 | return self.cos(x, y) / self.temp 38 | 39 | class Pooler(nn.Module): 40 | """ 41 | Parameter-free poolers to get the sentence embedding 42 | 'cls': [CLS] representation with BERT/RoBERTa's MLP pooler. 43 | 'cls_before_pooler': [CLS] representation without the original MLP pooler. 44 | 'avg': average of the last layers' hidden states at each token. 45 | 'avg_top2': average of the last two layers. 46 | 'avg_first_last': average of the first and the last layers. 
47 | """ 48 | def __init__(self, pooler_type): 49 | super().__init__() 50 | self.pooler_type = pooler_type 51 | assert self.pooler_type in ["cls", "cls_before_pooler", "avg", "avg_top2", "avg_first_last"], "unrecognized pooling type %s" % self.pooler_type 52 | 53 | def forward(self, attention_mask, outputs): 54 | last_hidden = outputs.last_hidden_state 55 | pooler_output = outputs.pooler_output 56 | hidden_states = outputs.hidden_states 57 | 58 | if self.pooler_type in ['cls_before_pooler', 'cls']: 59 | return last_hidden[:, 0] 60 | elif self.pooler_type == "avg": 61 | return ((last_hidden * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(-1).unsqueeze(-1)) 62 | elif self.pooler_type == "avg_first_last": 63 | first_hidden = hidden_states[0] 64 | last_hidden = hidden_states[-1] 65 | pooled_result = ((first_hidden + last_hidden) / 2.0 * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(-1).unsqueeze(-1) 66 | return pooled_result 67 | elif self.pooler_type == "avg_top2": 68 | second_last_hidden = hidden_states[-2] 69 | last_hidden = hidden_states[-1] 70 | pooled_result = ((last_hidden + second_last_hidden) / 2.0 * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(-1).unsqueeze(-1) 71 | return pooled_result 72 | else: 73 | raise NotImplementedError 74 | 75 | def _get_ranks(x: torch.Tensor) -> torch.Tensor: 76 | x_rank = x.argsort(dim=1) 77 | ranks = torch.zeros_like(x_rank, dtype=torch.float) 78 | n, d = x_rank.size() 79 | 80 | for i in range(n): 81 | ranks[i][x_rank[i]] = torch.arange(d, dtype=torch.float).to(ranks.device) 82 | return ranks 83 | 84 | def cal_spr_corr(x: torch.Tensor, y: torch.Tensor): 85 | x_rank = _get_ranks(x) 86 | y_rank = _get_ranks(y) 87 | x_rank_mean = torch.mean(x_rank, dim=1).unsqueeze(1) 88 | y_rank_mean = torch.mean(y_rank, dim=1).unsqueeze(1) 89 | xn = x_rank - x_rank_mean 90 | yn = y_rank - y_rank_mean 91 | x_var = torch.sqrt(torch.sum(torch.square(xn), dim=1).unsqueeze(1)) 92 | y_var = 
torch.sqrt(torch.sum(torch.square(yn), dim=1).unsqueeze(1)) 93 | xn = xn / x_var 94 | yn = yn / y_var 95 | 96 | return torch.mm(xn, torch.transpose(yn, 0, 1)) 97 | 98 | def cl_init(cls, config, model_loss=None): 99 | """ 100 | Contrastive learning class init function. 101 | """ 102 | cls.pooler_type = cls.model_args.pooler_type 103 | cls.pooler = Pooler(cls.model_args.pooler_type) 104 | if cls.model_args.pooler_type == "cls": 105 | cls.mlp = MLPLayer(config) 106 | cls.sim = Similarity(temp=cls.model_args.temp) 107 | cls.init_weights() 108 | 109 | 110 | def cl_forward(cls, 111 | encoder, 112 | input_ids=None, 113 | attention_mask=None, 114 | token_type_ids=None, 115 | position_ids=None, 116 | head_mask=None, 117 | inputs_embeds=None, 118 | labels=None, 119 | output_attentions=None, 120 | output_hidden_states=None, 121 | return_dict=None, 122 | mlm_input_ids=None, 123 | mlm_labels=None, 124 | distances1=None, 125 | distances2=None, 126 | baseE_vecs1=None, 127 | baseE_vecs2=None, 128 | ): 129 | return_dict = return_dict if return_dict is not None else cls.config.use_return_dict 130 | 131 | batch_size = input_ids.size(0) 132 | 133 | # Number of sentences in one instance 134 | num_sent = input_ids.size(1) 135 | 136 | # Flatten input for encoding 137 | input_ids = input_ids.view((-1, input_ids.size(-1))) # (bs * num_sent, len) 138 | attention_mask = attention_mask.view((-1, attention_mask.size(-1))) # (bs * num_sent len) 139 | if token_type_ids is not None: 140 | token_type_ids = token_type_ids.view((-1, token_type_ids.size(-1))) # (bs * num_sent, len) 141 | 142 | torch.cuda.empty_cache() 143 | 144 | # Get raw embeddings 145 | outputs = encoder( 146 | input_ids, 147 | attention_mask=attention_mask, 148 | token_type_ids=token_type_ids, 149 | position_ids=position_ids, 150 | head_mask=head_mask, 151 | inputs_embeds=inputs_embeds, 152 | output_attentions=True, 153 | output_hidden_states=True, 154 | return_dict=True, 155 | ) 156 | 157 | # Obtain sentence embeddings from 
[MASK] token 158 | index = input_ids == cls.mask_token_id 159 | last_hidden_state = outputs.last_hidden_state[index] 160 | assert last_hidden_state.size() == torch.Size([batch_size * num_sent, last_hidden_state.size(-1)]) 161 | 162 | # During training, add an extra MLP layer with activation function 163 | if cls.pooler_type == "cls": 164 | pooler_output = cls.mlp(last_hidden_state) 165 | else: 166 | pooler_output = last_hidden_state 167 | 168 | pooler_output = pooler_output.view((batch_size, num_sent, pooler_output.size(-1))) 169 | return pooler_output, last_hidden_state 170 | 171 | class BertForCL(BertPreTrainedModel): 172 | _keys_to_ignore_on_load_missing = [r"position_ids"] 173 | 174 | def __init__(self, config, mask_token_id=None, alpha=None, beta=None, lambda_=None, **model_kargs): 175 | super().__init__(config) 176 | self.model_args = model_kargs["model_args"] 177 | self.bert = BertModel(config, add_pooling_layer=False) 178 | self.mask_token_id = mask_token_id 179 | 180 | if self.model_args.do_mlm: 181 | self.lm_head = BertLMPredictionHead(config) 182 | 183 | cl_init(self, config) 184 | 185 | self.alpha = alpha # 0.1 186 | self.beta = beta # 0.3 187 | self.lambda_ = lambda_ # 1e-3 188 | 189 | def forward(self, 190 | input_ids=None, 191 | attention_mask=None, 192 | token_type_ids=None, 193 | position_ids=None, 194 | head_mask=None, 195 | inputs_embeds=None, 196 | labels=None, 197 | output_attentions=None, 198 | output_hidden_states=None, 199 | return_dict=None, 200 | sent_emb=False, 201 | mlm_input_ids=None, 202 | mlm_labels=None, 203 | distances1=None, 204 | distances2=None, 205 | baseE_vecs1=None, 206 | baseE_vecs2=None, 207 | ): 208 | return cl_forward(self, 209 | self.bert, 210 | input_ids=input_ids, 211 | attention_mask=attention_mask, 212 | token_type_ids=token_type_ids, 213 | position_ids=position_ids, 214 | head_mask=head_mask, 215 | inputs_embeds=inputs_embeds, 216 | labels=labels, 217 | output_attentions=output_attentions, 218 | 
output_hidden_states=output_hidden_states, 219 | return_dict=return_dict, 220 | mlm_input_ids=mlm_input_ids, 221 | mlm_labels=mlm_labels, 222 | distances1=distances1, 223 | distances2=distances2, 224 | baseE_vecs1=baseE_vecs1, 225 | baseE_vecs2=baseE_vecs2, 226 | ) 227 | 228 | 229 | class RobertaForCL(RobertaPreTrainedModel): 230 | _keys_to_ignore_on_load_missing = [r"position_ids"] 231 | 232 | def __init__(self, config, mask_token_id, alpha=None, beta=None, lambda_=None, **model_kargs): 233 | super().__init__(config) 234 | self.model_args = model_kargs["model_args"] 235 | self.roberta = RobertaModel(config, add_pooling_layer=False) 236 | self.mask_token_id = mask_token_id 237 | 238 | if self.model_args.do_mlm: 239 | self.lm_head = RobertaLMHead(config) 240 | 241 | cl_init(self, config) 242 | 243 | self.alpha = alpha # 0.1 244 | self.beta = beta # 0.3 245 | self.lambda_ = lambda_ # 5e-4 246 | 247 | def forward(self, 248 | input_ids=None, 249 | attention_mask=None, 250 | token_type_ids=None, 251 | position_ids=None, 252 | head_mask=None, 253 | inputs_embeds=None, 254 | labels=None, 255 | output_attentions=None, 256 | output_hidden_states=None, 257 | return_dict=None, 258 | sent_emb=False, 259 | mlm_input_ids=None, 260 | mlm_labels=None, 261 | distances1=None, 262 | distances2=None, 263 | baseE_vecs1=None, 264 | baseE_vecs2=None, 265 | ): 266 | 267 | return cl_forward(self, self.roberta, 268 | input_ids=input_ids, 269 | attention_mask=attention_mask, 270 | token_type_ids=token_type_ids, 271 | position_ids=position_ids, 272 | head_mask=head_mask, 273 | inputs_embeds=inputs_embeds, 274 | labels=labels, 275 | output_attentions=output_attentions, 276 | output_hidden_states=output_hidden_states, 277 | return_dict=return_dict, 278 | mlm_input_ids=mlm_input_ids, 279 | mlm_labels=mlm_labels, 280 | distances1=distances1, 281 | distances2=distances2, 282 | baseE_vecs1=baseE_vecs1, 283 | baseE_vecs2=baseE_vecs2, 284 | ) 285 | 
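`_get_ranks` and `cal_spr_corr` in `models.py` above turn each row of two score matrices into ranks and return the matrix of Spearman correlations between every row pair. The dependency-free sketch below mirrors that logic in plain Python and may serve as a reference when checking the tensor version on small inputs. The names `row_ranks` and `spearman_matrix` are illustrative, ties are broken by position (the tensor code relies on `argsort`, whose tie order may differ), and every row is assumed to hold at least two entries so the normalisation never divides by zero.

```python
import math
from typing import List, Sequence


def row_ranks(row: Sequence[float]) -> List[float]:
    """Return the 0-based rank of each value in `row` (ties broken by position)."""
    order = sorted(range(len(row)), key=lambda i: row[i])
    ranks = [0.0] * len(row)
    for rank, idx in enumerate(order):
        ranks[idx] = float(rank)
    return ranks


def _centred_unit(row: Sequence[float]) -> List[float]:
    """Rank the row, subtract the mean rank, and scale to unit length."""
    ranks = row_ranks(row)
    mean = sum(ranks) / len(ranks)
    centred = [r - mean for r in ranks]
    norm = math.sqrt(sum(c * c for c in centred))
    return [c / norm for c in centred]


def spearman_matrix(x: Sequence[Sequence[float]],
                    y: Sequence[Sequence[float]]) -> List[List[float]]:
    """Spearman correlation between every row of x and every row of y."""
    xn = [_centred_unit(r) for r in x]
    yn = [_centred_unit(r) for r in y]
    # Entry [i][j] is the dot product of the normalised rank vectors,
    # matching torch.mm(xn, yn^T) in cal_spr_corr.
    return [[sum(a * b for a, b in zip(xr, yr)) for yr in yn] for xr in xn]


if __name__ == "__main__":
    scores_a = [[0.1, 0.5, 0.3]]
    scores_b = [[1.0, 3.0, 2.0], [3.0, 1.0, 2.0]]
    print(spearman_matrix(scores_a, scores_b))
```

Rows with identical orderings correlate at 1 and fully reversed orderings at -1, which gives a quick sanity check for the tensor implementation.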
-------------------------------------------------------------------------------- /code/SNCSE_RankEncoder/sncse_rank_encoder/simcse/tool.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from tqdm import tqdm 3 | import numpy as np 4 | from numpy import ndarray 5 | import torch 6 | from torch import Tensor, device 7 | import transformers 8 | from transformers import AutoModel, AutoTokenizer 9 | from sklearn.metrics.pairwise import cosine_similarity 10 | from sklearn.preprocessing import normalize 11 | from typing import List, Dict, Tuple, Type, Union 12 | 13 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S', 14 | level=logging.INFO) 15 | logger = logging.getLogger(__name__) 16 | 17 | class SimCSE(object): 18 | """ 19 | A class for embedding sentences, calculating similarities, and retrieving sentences by SimCSE. 20 | """ 21 | def __init__(self, model_name_or_path: str, 22 | device: str = None, 23 | num_cells: int = 100, 24 | num_cells_in_search: int = 10, 25 | pooler = None): 26 | 27 | self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) 28 | self.model = AutoModel.from_pretrained(model_name_or_path) 29 | if device is None: 30 | device = "cuda" if torch.cuda.is_available() else "cpu" 31 | self.device = device 32 | 33 | self.index = None 34 | self.is_faiss_index = False 35 | self.num_cells = num_cells 36 | self.num_cells_in_search = num_cells_in_search 37 | 38 | if pooler is not None: 39 | self.pooler = pooler 40 | elif "unsup" in model_name_or_path: 41 | logger.info("Use `cls_before_pooler` for unsupervised models. 
If you want to use other pooling policy, specify `pooler` argument.") 42 | self.pooler = "cls_before_pooler" 43 | else: 44 | self.pooler = "cls" 45 | 46 | def encode(self, sentence: Union[str, List[str]], 47 | device: str = None, 48 | return_numpy: bool = False, 49 | normalize_to_unit: bool = True, 50 | keepdim: bool = False, 51 | batch_size: int = 64, 52 | max_length: int = 128) -> Union[ndarray, Tensor]: 53 | 54 | target_device = self.device if device is None else device 55 | self.model = self.model.to(target_device) 56 | 57 | single_sentence = False 58 | if isinstance(sentence, str): 59 | sentence = [sentence] 60 | single_sentence = True 61 | 62 | embedding_list = [] 63 | with torch.no_grad(): 64 | total_batch = len(sentence) // batch_size + (1 if len(sentence) % batch_size > 0 else 0) 65 | for batch_id in tqdm(range(total_batch)): 66 | inputs = self.tokenizer( 67 | sentence[batch_id*batch_size:(batch_id+1)*batch_size], 68 | padding=True, 69 | truncation=True, 70 | max_length=max_length, 71 | return_tensors="pt" 72 | ) 73 | inputs = {k: v.to(target_device) for k, v in inputs.items()} 74 | outputs = self.model(**inputs, return_dict=True) 75 | if self.pooler == "cls": 76 | embeddings = outputs.pooler_output 77 | elif self.pooler == "cls_before_pooler": 78 | embeddings = outputs.last_hidden_state[:, 0] 79 | else: 80 | raise NotImplementedError 81 | if normalize_to_unit: 82 | embeddings = embeddings / embeddings.norm(dim=1, keepdim=True) 83 | embedding_list.append(embeddings.cpu()) 84 | embeddings = torch.cat(embedding_list, 0) 85 | 86 | if single_sentence and not keepdim: 87 | embeddings = embeddings[0] 88 | 89 | if return_numpy and not isinstance(embeddings, ndarray): 90 | return embeddings.numpy() 91 | return embeddings 92 | 93 | def similarity(self, queries: Union[str, List[str]], 94 | keys: Union[str, List[str], ndarray], 95 | device: str = None) -> Union[float, ndarray]: 96 | 97 | query_vecs = self.encode(queries, device=device, return_numpy=True) # suppose N 
queries 98 | 99 | if not isinstance(keys, ndarray): 100 | key_vecs = self.encode(keys, device=device, return_numpy=True) # suppose M keys 101 | else: 102 | key_vecs = keys 103 | 104 | # check whether N == 1 or M == 1 105 | single_query, single_key = len(query_vecs.shape) == 1, len(key_vecs.shape) == 1 106 | if single_query: 107 | query_vecs = query_vecs.reshape(1, -1) 108 | if single_key: 109 | key_vecs = key_vecs.reshape(1, -1) 110 | 111 | # returns an N*M similarity array 112 | similarities = cosine_similarity(query_vecs, key_vecs) 113 | 114 | if single_query: 115 | similarities = similarities[0] 116 | if single_key: 117 | similarities = float(similarities[0]) 118 | 119 | return similarities 120 | 121 | def build_index(self, sentences_or_file_path: Union[str, List[str]], 122 | use_faiss: bool = None, 123 | faiss_fast: bool = False, 124 | device: str = None, 125 | batch_size: int = 64): 126 | 127 | if use_faiss is None or use_faiss: 128 | try: 129 | import faiss 130 | assert hasattr(faiss, "IndexFlatIP") 131 | use_faiss = True 132 | except: 133 | logger.warning("Fail to import faiss. If you want to use faiss, install faiss through PyPI. Now the program continues with brute force search.") 134 | use_faiss = False 135 | 136 | # if the input sentence is a string, we assume it's the path of file that stores various sentences 137 | if isinstance(sentences_or_file_path, str): 138 | sentences = [] 139 | with open(sentences_or_file_path, "r") as f: 140 | logging.info("Loading sentences from %s ..." 
% (sentences_or_file_path))
                for line in tqdm(f):
                    sentences.append(line.rstrip())
            sentences_or_file_path = sentences

        logger.info("Encoding embeddings for sentences...")
        embeddings = self.encode(sentences_or_file_path, device=device, batch_size=batch_size, normalize_to_unit=True, return_numpy=True)

        logger.info("Building index...")
        self.index = {"sentences": sentences_or_file_path}

        if use_faiss:
            quantizer = faiss.IndexFlatIP(embeddings.shape[1])
            if faiss_fast:
                index = faiss.IndexIVFFlat(quantizer, embeddings.shape[1], min(self.num_cells, len(sentences_or_file_path)))
            else:
                index = quantizer

            if (self.device == "cuda" and device != "cpu") or device == "cuda":
                if hasattr(faiss, "StandardGpuResources"):
                    logger.info("Use GPU-version faiss")
                    res = faiss.StandardGpuResources()
                    res.setTempMemory(20 * 1024 * 1024 * 1024)
                    index = faiss.index_cpu_to_gpu(res, 0, index)
                else:
                    logger.info("Use CPU-version faiss")
            else:
                logger.info("Use CPU-version faiss")

            if faiss_fast:
                index.train(embeddings.astype(np.float32))
            index.add(embeddings.astype(np.float32))
            index.nprobe = min(self.num_cells_in_search, len(sentences_or_file_path))
            self.is_faiss_index = True
        else:
            index = embeddings
            self.is_faiss_index = False
        self.index["index"] = index
        logger.info("Finished")

    def search(self, queries: Union[str, List[str]],
                device: str = None,
                threshold: float = 0.6,
                top_k: int = 5) -> Union[List[Tuple[str, float]], List[List[Tuple[str, float]]]]:

        if not self.is_faiss_index:
            if isinstance(queries, list):
                combined_results = []
                for query in queries:
                    # Forward threshold and top_k so that list queries honor the caller's settings
                    results = self.search(query, device, threshold, top_k)
                    combined_results.append(results)
                return combined_results

            similarities = self.similarity(queries,
self.index["index"]).tolist() 194 | id_and_score = [] 195 | for i, s in enumerate(similarities): 196 | if s >= threshold: 197 | id_and_score.append((i, s)) 198 | id_and_score = sorted(id_and_score, key=lambda x: x[1], reverse=True)[:top_k] 199 | results = [(self.index["sentences"][idx], score) for idx, score in id_and_score] 200 | return results 201 | else: 202 | query_vecs = self.encode(queries, device=device, normalize_to_unit=True, keepdim=True, return_numpy=True) 203 | 204 | distance, idx = self.index["index"].search(query_vecs.astype(np.float32), top_k) 205 | 206 | def pack_single_result(dist, idx): 207 | results = [(self.index["sentences"][i], s) for i, s in zip(idx, dist) if s >= threshold] 208 | return results 209 | 210 | if isinstance(queries, list): 211 | combined_results = [] 212 | for i in range(len(queries)): 213 | results = pack_single_result(distance[i], idx[i]) 214 | combined_results.append(results) 215 | return combined_results 216 | else: 217 | return pack_single_result(distance[0], idx[0]) 218 | 219 | if __name__=="__main__": 220 | example_sentences = [ 221 | 'An animal is biting a persons finger.', 222 | 'A woman is reading.', 223 | 'A man is lifting weights in a garage.', 224 | 'A man plays the violin.', 225 | 'A man is eating food.', 226 | 'A man plays the piano.', 227 | 'A panda is climbing.', 228 | 'A man plays a guitar.', 229 | 'A woman is slicing a meat.', 230 | 'A woman is taking a picture.' 231 | ] 232 | example_queries = [ 233 | 'A man is playing music.', 234 | 'A woman is making a photo.' 
235 | ] 236 | 237 | model_name = "princeton-nlp/sup-simcse-bert-base-uncased" 238 | simcse = SimCSE(model_name) 239 | 240 | print("\n=========Calculate cosine similarities between queries and sentences============\n") 241 | similarities = simcse.similarity(example_queries, example_sentences) 242 | print(similarities) 243 | 244 | print("\n=========Naive brute force search============\n") 245 | simcse.build_index(example_sentences, use_faiss=False) 246 | results = simcse.search(example_queries) 247 | for i, result in enumerate(results): 248 | print("Retrieval results for query: {}".format(example_queries[i])) 249 | for sentence, score in result: 250 | print(" {} (cosine similarity: {:.4f})".format(sentence, score)) 251 | print("") 252 | 253 | print("\n=========Search with Faiss backend============\n") 254 | simcse.build_index(example_sentences, use_faiss=True) 255 | results = simcse.search(example_queries) 256 | for i, result in enumerate(results): 257 | print("Retrieval results for query: {}".format(example_queries[i])) 258 | for sentence, score in result: 259 | print(" {} (cosine similarity: {:.4f})".format(sentence, score)) 260 | print("") 261 | 262 | -------------------------------------------------------------------------------- /code/SNCSE_RankEncoder/sncse_rank_encoder/sncse_rank_encoder.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | OUTPUT_DIR=$PROJECT_DIR/outputs 4 | DATA_DIR=$PROJECT_DIR/data 5 | 6 | SEED=61507 7 | CUDA_VISIBLE_DEVICES=0,1 python train_SNCSE.py \ 8 | --baseE_sim_thresh_upp 0.9999 \ 9 | --baseE_sim_thresh_low 0.5 \ 10 | --baseE_lmb 0.05 \ 11 | --baseE_model_name_or_path $OUTPUT_DIR/sncse/checkpoints/SNCSE-bert-base-uncased \ 12 | --corpus_vecs $OUTPUT_DIR/sncse/index_vecs/corpus_0.01_sncse.npy \ 13 | --model_name_or_path bert-base-uncased \ 14 | --train_file $DATA_DIR/corpus/corpus.txt \ 15 | --output_dir $OUTPUT_DIR/sncse/checkpoints/sncse_rank_encoder_seed_$SEED \ 16 | 
--num_train_epoch 1 \ 17 | --per_device_train_batch_size 128 \ 18 | --learning_rate 1e-5 \ 19 | --max_seq_length 32 \ 20 | --evaluation_strategy steps \ 21 | --metric_for_best_model stsb_spearman \ 22 | --load_best_model_at_end \ 23 | --eval_step 125 \ 24 | --pooler_type cls \ 25 | --mlp_only_train \ 26 | --overwrite_output_dir \ 27 | --temp 0.05 \ 28 | --do_train \ 29 | --do_eval \ 30 | --preprocessing_num_workers 10 \ 31 | --seed $SEED \ 32 | --soft_negative_file $OUTPUT_DIR/sncse/soft_negative_samples.txt 33 | -------------------------------------------------------------------------------- /code/SNCSE_RankEncoder/sncse_rank_encoder_inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import apex 3 | import re 4 | import sys 5 | import io, os 6 | import faiss 7 | import math 8 | import json 9 | import numpy as np 10 | import logging 11 | import tqdm 12 | import torch 13 | import time 14 | from prettytable import PrettyTable 15 | from scipy.stats import spearmanr, pearsonr 16 | from scipy.special import softmax 17 | from scipy.stats import rankdata 18 | import string 19 | import torch 20 | import transformers 21 | from transformers import AutoModel, AutoTokenizer 22 | from transformers import BertTokenizer, BertModel 23 | from tqdm import tqdm 24 | # Set up logger 25 | logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG) 26 | 27 | # Set PATHs 28 | PATH_TO_SENTEVAL = './SentEval' 29 | 30 | # Import SentEval 31 | sys.path.insert(0, PATH_TO_SENTEVAL) 32 | 33 | import senteval 34 | from senteval.sts import STS12Eval, STS13Eval, STS14Eval, STS15Eval, STS16Eval, STSBenchmarkEval 35 | from senteval.sts import SICKRelatednessEval 36 | from senteval.utils import cosine 37 | 38 | PUNCTUATION = list(string.punctuation) 39 | 40 | def normalize(vecs): 41 | eps = 1e-8 42 | return vecs / (np.sqrt(np.sum(np.square(vecs), axis=1)) + eps)[:,None] 43 | 44 | def print_table(task_names, scores): 
45 | tb = PrettyTable() 46 | tb.field_names = task_names 47 | tb.add_row(scores) 48 | print(tb) 49 | 50 | def read_benchmark_data(senteval_path, task): 51 | task2class = { \ 52 | 'STS12': STS12Eval, 53 | 'STS13': STS13Eval, 54 | 'STS14': STS14Eval, 55 | 'STS15': STS15Eval, 56 | 'STS16': STS16Eval, 57 | 'STSBenchmark': STSBenchmarkEval, 58 | 'SICKRelatedness': SICKRelatednessEval 59 | } 60 | dataset_path = None 61 | print("SentEval path: {}".format(senteval_path)) 62 | if task in ["STS12", "STS13", "STS14", "STS15", "STS16"]: 63 | dataset_path = os.path.join(senteval_path, "downstream/STS/", "{}{}".format(task, "-en-test")) 64 | elif task == "STSBenchmark": 65 | dataset_path = os.path.join(senteval_path, "downstream/STS/", "{}".format(task)) 66 | elif task == "SICKRelatedness": 67 | dataset_path = os.path.join(senteval_path, "downstream/SICK") 68 | print(dataset_path) 69 | data = {} 70 | task_data = task2class[task](dataset_path) 71 | for dset in task_data.datasets: 72 | input1, input2, gs_scores = task_data.data[dset] 73 | data[dset] = (input1, input2, gs_scores) 74 | return data 75 | 76 | def compute_similarity(q0, q0_sim, q1, q1_sim, lmb=0.0): 77 | normalized_q0 = normalize(np.reshape(q0, (1, -1))) 78 | normalized_q1 = normalize(np.reshape(q1, (1, -1))) 79 | add_score, _ = spearmanr(q0_sim, q1_sim) 80 | score = np.sum(np.matmul(normalized_q0, normalized_q1.T)) 81 | score = lmb * score + (1.0 - lmb) * add_score 82 | return score 83 | 84 | def evaluate_retrieval_augmented_promptbert(args, data, batcher, sentence_vecs): 85 | results = {} 86 | all_sys_scores = [] 87 | all_gs_scores = [] 88 | for dset in data: 89 | sys_scores = [] 90 | input1, input2, gs_scores = data[dset] 91 | for ii in range(0, len(gs_scores), args.batch_size): 92 | batch1 = input1[ii:ii + args.batch_size] 93 | batch2 = input2[ii:ii + args.batch_size] 94 | 95 | # we assume get_batch already throws out the faulty ones 96 | if len(batch1) == len(batch2) and len(batch1) > 0: 97 | enc1 = 
batcher(batch1) 98 | enc2 = batcher(batch2) 99 | sim1 = np.matmul(enc1, sentence_vecs.T) 100 | sim2 = np.matmul(enc2, sentence_vecs.T) 101 | 102 | for kk in range(enc1.shape[0]): 103 | sys_score = compute_similarity( \ 104 | enc1[kk], sim1[kk], \ 105 | enc2[kk], sim2[kk], \ 106 | args.lmb \ 107 | ) 108 | sys_scores.append(sys_score) 109 | all_sys_scores.extend(sys_scores) 110 | all_gs_scores.extend(gs_scores) 111 | results[dset] = { 112 | 'pearson': pearsonr(sys_scores, gs_scores), 113 | 'spearman': spearmanr(sys_scores, gs_scores), 114 | 'nsamples': len(sys_scores) 115 | } 116 | logging.debug('%s : pearson = %.4f, spearman = %.4f' % 117 | (dset, results[dset]['pearson'][0], 118 | results[dset]['spearman'][0])) 119 | 120 | weights = [results[dset]['nsamples'] for dset in results.keys()] 121 | list_prs = np.array([results[dset]['pearson'][0] for 122 | dset in results.keys()]) 123 | list_spr = np.array([results[dset]['spearman'][0] for 124 | dset in results.keys()]) 125 | 126 | avg_pearson = np.average(list_prs) 127 | avg_spearman = np.average(list_spr) 128 | wavg_pearson = np.average(list_prs, weights=weights) 129 | wavg_spearman = np.average(list_spr, weights=weights) 130 | all_pearson = pearsonr(all_sys_scores, all_gs_scores) 131 | all_spearman = spearmanr(all_sys_scores, all_gs_scores) 132 | results['all'] = {'pearson': {'all': all_pearson[0], 133 | 'mean': avg_pearson, 134 | 'wmean': wavg_pearson}, 135 | 'spearman': {'all': all_spearman[0], 136 | 'mean': avg_spearman, 137 | 'wmean': wavg_spearman}} 138 | logging.debug('ALL : Pearson = %.4f, \ 139 | Spearman = %.4f' % (all_pearson[0], all_spearman[0])) 140 | logging.debug('ALL (weighted average) : Pearson = %.4f, \ 141 | Spearman = %.4f' % (wavg_pearson, wavg_spearman)) 142 | logging.debug('ALL (average) : Pearson = %.4f, \ 143 | Spearman = %.4f\n' % (avg_pearson, avg_spearman)) 144 | results["pred_scores"] = all_sys_scores 145 | results["gs_scores"] = all_gs_scores 146 | return results 147 | 148 | def 
parse_args(): 149 | parser = argparse.ArgumentParser() 150 | parser.add_argument("--sentence_vecs", type=str, required=True) 151 | parser.add_argument("--senteval_path", type=str, default="SentEval/data") 152 | parser.add_argument("--batch_size", type=int, default=32) 153 | parser.add_argument("--lmb", type=float, default=1.0) 154 | parser.add_argument("--model_name_or_path", type=str) 155 | args = parser.parse_args() 156 | return args 157 | 158 | def main(args): 159 | device = torch.device("cpu") 160 | if torch.cuda.is_available(): 161 | device = torch.device("cuda") 162 | n_gpu = torch.cuda.device_count() 163 | 164 | model = BertModel.from_pretrained(args.model_name_or_path) 165 | model.to(device) 166 | if n_gpu > 1: 167 | model = torch.nn.DataParallel(model) 168 | model.eval() 169 | 170 | tokenizer = BertTokenizer(vocab_file=os.path.join(args.model_name_or_path, "vocab.txt")) 171 | temp = {"mask_token": tokenizer.mask_token} 172 | tokenizer.add_special_tokens(temp) 173 | 174 | def batcher(batch, max_length=None): 175 | # Handle rare token encoding issues in the dataset 176 | if len(batch) >= 1 and len(batch[0]) >= 1 and isinstance(batch[0][0], bytes): 177 | batch = [[word.decode('utf-8') for word in s] for s in batch] 178 | 179 | sentences = [' '.join(s) for s in batch] 180 | sentences = [ \ 181 | s + " ." 
if s.strip()[-1] not in PUNCTUATION else s \ 182 | for s in sentences \ 183 | ] 184 | sentences = [ \ 185 | '''This sentence : " ''' + s + ''' " means [MASK] .''' \ 186 | for s in sentences \ 187 | ] 188 | 189 | # Tokenization 190 | if max_length is not None: 191 | batch = tokenizer.batch_encode_plus( 192 | sentences, 193 | return_tensors='pt', 194 | padding=True, 195 | max_length=max_length, 196 | truncation=True 197 | ) 198 | else: 199 | batch = tokenizer.batch_encode_plus( 200 | sentences, 201 | return_tensors='pt', 202 | padding=True, 203 | ) 204 | 205 | # Move to the correct device 206 | for k in batch: 207 | batch[k] = batch[k].to(device) 208 | 209 | # Get raw embeddings 210 | with torch.no_grad(): 211 | outputs = model(**batch, output_hidden_states=True, return_dict=True) 212 | last_hidden = outputs.last_hidden_state 213 | sent_vecs = last_hidden[batch["input_ids"] == tokenizer.mask_token_id].cpu().numpy() 214 | 215 | sent_vecs = normalize(sent_vecs) 216 | return sent_vecs 217 | 218 | print("Loading {}".format(args.sentence_vecs)) 219 | sentence_vecs = np.load(args.sentence_vecs) 220 | 221 | # Load benchmark datasets 222 | target_tasks = [ \ 223 | 'STS12', 'STS13', 'STS14', 'STS15', 'STS16', \ 224 | 'STSBenchmark', \ 225 | 'SICKRelatedness' \ 226 | ] 227 | # Reference: https://github.com/facebookresearch/SentEval/blob/main/senteval/sts.py 228 | results = {} 229 | for task in target_tasks: 230 | data = read_benchmark_data(args.senteval_path, task) 231 | result = evaluate_retrieval_augmented_promptbert(args, data, batcher, sentence_vecs) 232 | results[task] = result 233 | 234 | task_names = [] 235 | scores = [] 236 | for task in ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark', 'SICKRelatedness']: 237 | task_names.append(task) 238 | if task in results: 239 | if task in ['STS12', 'STS13', 'STS14', 'STS15', 'STS16']: 240 | scores.append("%.2f" % (results[task]['all']['spearman']['all'] * 100)) 241 | else: 242 | scores.append("%.2f" % 
(results[task]['test']['spearman'].correlation * 100))
        else:
            scores.append("0.00")
    task_names.append("Avg.")
    scores.append("%.2f" % (sum([float(score) for score in scores]) / len(scores)))
    print_table(task_names, scores)

    return 0

if __name__ == "__main__":
    args = parse_args()
    _ = main(args)

--------------------------------------------------------------------------------
/code/SNCSE_RankEncoder/sncse_rank_encoder_inference.sh:
--------------------------------------------------------------------------------
OUTPUT_DIR=$PROJECT_DIR/outputs

SEED=61507
CUDA_VISIBLE_DEVICES=0,1 python sncse_rank_encoder_inference.py \
    --sentence_vecs $OUTPUT_DIR/sncse/index_vecs/corpus_0.01_sncse_rank_encoder_seed_$SEED.npy \
    --senteval_path SentEval/data \
    --batch_size 256 \
    --model_name_or_path $OUTPUT_DIR/sncse/checkpoints/sncse_rank_encoder_seed_$SEED \
    --lmb 0.9

--------------------------------------------------------------------------------
/code/SimCSE_RankEncoder/README.md:
--------------------------------------------------------------------------------
# Getting Started

### Requirements
You need at least two GPUs to follow the instructions below.

### Setting
1. Set the project directory
```bash
export PROJECT_DIR=/path/to/this/project/folder
```
Note that the path must not end with "/", e.g., /home/RankEncoder.

2. Download the SentEval folder from https://github.com/princeton-nlp/SimCSE and place it in code/SimCSE\_RankEncoder/
3. Go to SentEval/data/downstream and run the following command
```bash
bash download_dataset.sh
```
4. 
   Download the Wiki1m dataset with the following commands and place it at data/corpus/corpus.txt
```bash
wget https://huggingface.co/datasets/princeton-nlp/datasets-for-simcse/resolve/main/wiki1m_for_simcse.txt
mv wiki1m_for_simcse.txt ../../data/corpus/corpus.txt
```
5. Go to code/file\_utils/ and run the following command
```bash
bash random_sampling_sentences.sh
```

### Training the base encoder ([SimCSE](https://aclanthology.org/2021.emnlp-main.552/))
Go to code/SimCSE\_RankEncoder and run the following command
```bash
bash run_simcse.sh
```

### Training RankEncoder
1. Get sentence vectors with the base encoder
```bash
bash get_simcse_embedding.sh
```
2. Train RankEncoder
```bash
bash run_simcse_rank_encoder.sh
```

We provide the checkpoint of the trained RankEncoder-SimCSE [here](https://drive.google.com/file/d/15BvamHk4zuCSU1slOWb37bnncJ2GGIwX/view?usp=sharing)

### Evaluation
1. The following command computes the performance of RankEncoder (without Eq.7 in [our paper](https://arxiv.org/pdf/2209.04333.pdf))
```bash
bash evaluation.sh
```

2. The following commands compute the performance of RankEncoder (with Eq.7 in [our paper](https://arxiv.org/pdf/2209.04333.pdf))
```bash
bash get_rank_encoder_embedding.sh
bash simcse_rank_encoder_inference.sh
```
Note that these scripts sample only 10,000 sentences for computational efficiency. Please use 100,000 sentences to replicate our experimental results.
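
### Sketch of the rank-based score
For readers who want to see what the second evaluation step adds, the following is a minimal, dependency-free sketch of the score computed by `compute_similarity` in `sncse_rank_encoder_inference.py`: the cosine similarity of the two query vectors, interpolated by `lmb` with the Spearman correlation of their similarity lists over the sampled corpus vectors. It assumes the SimCSE inference script follows the same logic. The helper names (`average_ranks`, `pearson`, `spearman`, `rank_similarity`) and the toy vectors are illustrative and not part of the repository, which uses NumPy and `scipy.stats.spearmanr` instead.

```python
import math


def normalize(vec):
    """Scale a vector to unit length (the small eps avoids division by zero)."""
    norm = math.sqrt(sum(x * x for x in vec)) + 1e-8
    return [x / norm for x in vec]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def average_ranks(values):
    """Return 1-based ranks; tied values share the mean of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    start = 0
    while start < len(order):
        end = start
        while end + 1 < len(order) and values[order[end + 1]] == values[order[start]]:
            end += 1
        mean_rank = (start + end) / 2.0 + 1.0
        for pos in range(start, end + 1):
            ranks[order[pos]] = mean_rank
        start = end + 1
    return ranks


def pearson(xs, ys):
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    vx = sum((x - mx) ** 2 for x in xs)
    vy = sum((y - my) ** 2 for y in ys)
    if vx == 0.0 or vy == 0.0:
        return 0.0  # correlation is undefined for constant input; treat as no signal
    return cov / math.sqrt(vx * vy)


def spearman(xs, ys):
    """Spearman correlation = Pearson correlation of the rank vectors."""
    return pearson(average_ranks(xs), average_ranks(ys))


def rank_similarity(q0, q1, corpus, lmb=0.9):
    """Interpolate cosine similarity with the rank correlation over the corpus."""
    q0, q1 = normalize(q0), normalize(q1)
    corpus = [normalize(c) for c in corpus]
    sims0 = [dot(q0, c) for c in corpus]  # similarities of q0 to every corpus vector
    sims1 = [dot(q1, c) for c in corpus]  # similarities of q1 to every corpus vector
    return lmb * dot(q0, q1) + (1.0 - lmb) * spearman(sims0, sims1)


# Toy usage with hypothetical 2-d vectors
toy_corpus = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
print(rank_similarity([1.0, 0.0], [0.0, 1.0], toy_corpus))
```

With `lmb=1.0` (the script's default) the score reduces to plain cosine similarity; the provided inference script passes `--lmb 0.9`.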
58 | -------------------------------------------------------------------------------- /code/SimCSE_RankEncoder/evaluation.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import io, os 3 | import numpy as np 4 | import logging 5 | import argparse 6 | from prettytable import PrettyTable 7 | import torch 8 | import transformers 9 | from transformers import AutoModel, AutoTokenizer 10 | 11 | # Set up logger 12 | logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG) 13 | 14 | # Set PATHs 15 | PATH_TO_SENTEVAL = './SentEval' 16 | PATH_TO_DATA = './SentEval/data' 17 | 18 | # Import SentEval 19 | sys.path.insert(0, PATH_TO_SENTEVAL) 20 | import senteval 21 | 22 | def print_table(task_names, scores): 23 | tb = PrettyTable() 24 | tb.field_names = task_names 25 | tb.add_row(scores) 26 | print(tb) 27 | 28 | def main(): 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument("--model_name_or_path", type=str, 31 | help="Transformers' model name or path") 32 | parser.add_argument("--pooler", type=str, 33 | choices=['cls', 'cls_before_pooler', 'avg', 'avg_top2', 'avg_first_last'], 34 | default='cls', 35 | help="Which pooler to use") 36 | parser.add_argument("--mode", type=str, 37 | choices=['dev', 'test', 'fasttest'], 38 | default='test', 39 | help="What evaluation mode to use (dev: fast mode, dev results; test: full mode, test results); fasttest: fast mode, test results") 40 | parser.add_argument("--task_set", type=str, 41 | choices=['sts', 'transfer', 'full', 'na'], 42 | default='sts', 43 | help="What set of tasks to evaluate on. If not 'na', this will override '--tasks'") 44 | parser.add_argument("--tasks", type=str, nargs='+', 45 | default=['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 46 | 'MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'TREC', 'MRPC', 47 | 'SICKRelatedness', 'STSBenchmark'], 48 | help="Tasks to evaluate on. 
If '--task_set' is specified, this will be overridden") 49 | 50 | args = parser.parse_args() 51 | 52 | # Load transformers' model checkpoint 53 | model = AutoModel.from_pretrained(args.model_name_or_path) 54 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) 55 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 56 | model = model.to(device) 57 | 58 | # Set up the tasks 59 | if args.task_set == 'sts': 60 | args.tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark', 'SICKRelatedness'] 61 | elif args.task_set == 'transfer': 62 | args.tasks = ['MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'TREC', 'MRPC'] 63 | elif args.task_set == 'full': 64 | args.tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark', 'SICKRelatedness'] 65 | args.tasks += ['MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'TREC', 'MRPC'] 66 | 67 | # Set params for SentEval 68 | if args.mode == 'dev' or args.mode == 'fasttest': 69 | # Fast mode 70 | params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5} 71 | params['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128, 72 | 'tenacity': 3, 'epoch_size': 2} 73 | elif args.mode == 'test': 74 | # Full mode 75 | params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 10} 76 | params['classifier'] = {'nhid': 0, 'optim': 'adam', 'batch_size': 64, 77 | 'tenacity': 5, 'epoch_size': 4} 78 | else: 79 | raise NotImplementedError 80 | 81 | # SentEval prepare and batcher 82 | def prepare(params, samples): 83 | return 84 | 85 | def batcher(params, batch, max_length=None): 86 | # Handle rare token encoding issues in the dataset 87 | if len(batch) >= 1 and len(batch[0]) >= 1 and isinstance(batch[0][0], bytes): 88 | batch = [[word.decode('utf-8') for word in s] for s in batch] 89 | 90 | sentences = [' '.join(s) for s in batch] 91 | 92 | # Tokenization 93 | if max_length is not None: 94 | batch = tokenizer.batch_encode_plus( 95 | sentences, 96 | return_tensors='pt', 97 | padding=True, 98 | 
max_length=max_length, 99 | truncation=True 100 | ) 101 | else: 102 | batch = tokenizer.batch_encode_plus( 103 | sentences, 104 | return_tensors='pt', 105 | padding=True, 106 | ) 107 | 108 | # Move to the correct device 109 | for k in batch: 110 | batch[k] = batch[k].to(device) 111 | 112 | # Get raw embeddings 113 | with torch.no_grad(): 114 | outputs = model(**batch, output_hidden_states=True, return_dict=True) 115 | last_hidden = outputs.last_hidden_state 116 | pooler_output = outputs.pooler_output 117 | hidden_states = outputs.hidden_states 118 | 119 | # Apply different poolers 120 | if args.pooler == 'cls': 121 | # There is a linear+activation layer after CLS representation 122 | return pooler_output.cpu() 123 | elif args.pooler == 'cls_before_pooler': 124 | return last_hidden[:, 0].cpu() 125 | elif args.pooler == "avg": 126 | return ((last_hidden * batch['attention_mask'].unsqueeze(-1)).sum(1) / batch['attention_mask'].sum(-1).unsqueeze(-1)).cpu() 127 | elif args.pooler == "avg_first_last": 128 | first_hidden = hidden_states[0] 129 | last_hidden = hidden_states[-1] 130 | pooled_result = ((first_hidden + last_hidden) / 2.0 * batch['attention_mask'].unsqueeze(-1)).sum(1) / batch['attention_mask'].sum(-1).unsqueeze(-1) 131 | return pooled_result.cpu() 132 | elif args.pooler == "avg_top2": 133 | second_last_hidden = hidden_states[-2] 134 | last_hidden = hidden_states[-1] 135 | pooled_result = ((last_hidden + second_last_hidden) / 2.0 * batch['attention_mask'].unsqueeze(-1)).sum(1) / batch['attention_mask'].sum(-1).unsqueeze(-1) 136 | return pooled_result.cpu() 137 | else: 138 | raise NotImplementedError 139 | 140 | results = {} 141 | 142 | for task in args.tasks: 143 | se = senteval.engine.SE(params, batcher, prepare) 144 | result = se.eval(task) 145 | results[task] = result 146 | 147 | # Print evaluation results 148 | if args.mode == 'dev': 149 | print("------ %s ------" % (args.mode)) 150 | 151 | task_names = [] 152 | scores = [] 153 | for task in 
['STSBenchmark', 'SICKRelatedness']: 154 | task_names.append(task) 155 | if task in results: 156 | scores.append("%.2f" % (results[task]['dev']['spearman'][0] * 100)) 157 | else: 158 | scores.append("0.00") 159 | print_table(task_names, scores) 160 | 161 | task_names = [] 162 | scores = [] 163 | for task in ['MR', 'CR', 'SUBJ', 'MPQA', 'SST2', 'TREC', 'MRPC']: 164 | task_names.append(task) 165 | if task in results: 166 | scores.append("%.2f" % (results[task]['devacc'])) 167 | else: 168 | scores.append("0.00") 169 | task_names.append("Avg.") 170 | scores.append("%.2f" % (sum([float(score) for score in scores]) / len(scores))) 171 | print_table(task_names, scores) 172 | 173 | elif args.mode == 'test' or args.mode == 'fasttest': 174 | print("------ %s ------" % (args.mode)) 175 | 176 | task_names = [] 177 | scores = [] 178 | for task in ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark', 'SICKRelatedness']: 179 | task_names.append(task) 180 | if task in results: 181 | if task in ['STS12', 'STS13', 'STS14', 'STS15', 'STS16']: 182 | scores.append("%.2f" % (results[task]['all']['spearman']['all'] * 100)) 183 | else: 184 | scores.append("%.2f" % (results[task]['test']['spearman'].correlation * 100)) 185 | else: 186 | scores.append("0.00") 187 | task_names.append("Avg.") 188 | scores.append("%.2f" % (sum([float(score) for score in scores]) / len(scores))) 189 | print_table(task_names, scores) 190 | 191 | task_names = [] 192 | scores = [] 193 | for task in ['MR', 'CR', 'SUBJ', 'MPQA', 'SST2', 'TREC', 'MRPC']: 194 | task_names.append(task) 195 | if task in results: 196 | scores.append("%.2f" % (results[task]['acc'])) 197 | else: 198 | scores.append("0.00") 199 | task_names.append("Avg.") 200 | scores.append("%.2f" % (sum([float(score) for score in scores]) / len(scores))) 201 | print_table(task_names, scores) 202 | 203 | 204 | if __name__ == "__main__": 205 | main() 206 | -------------------------------------------------------------------------------- 
/code/SimCSE_RankEncoder/evaluation.sh: -------------------------------------------------------------------------------- 1 | OUTPUT_DIR=$PROJECT_DIR/outputs 2 | 3 | SEED=61507 4 | python evaluation.py \ 5 | --model_name_or_path $OUTPUT_DIR/simcse/checkpoints/simcse_unsup_rank_encoder_seed_$SEED \ 6 | --pooler cls_before_pooler \ 7 | --task_set sts \ 8 | --mode test 9 | -------------------------------------------------------------------------------- /code/SimCSE_RankEncoder/get_embedding.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import io, os 3 | import numpy as np 4 | import logging 5 | import argparse 6 | from prettytable import PrettyTable 7 | import torch 8 | import transformers 9 | from transformers import AutoModel, AutoTokenizer 10 | from tqdm import tqdm 11 | 12 | # Set up logger 13 | logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG) 14 | 15 | # Set PATHs 16 | PATH_TO_SENTEVAL = './SentEval' 17 | PATH_TO_DATA = './SentEval/data' 18 | 19 | # Import SentEval 20 | sys.path.insert(0, PATH_TO_SENTEVAL) 21 | import senteval 22 | 23 | def print_table(task_names, scores): 24 | tb = PrettyTable() 25 | tb.field_names = task_names 26 | tb.add_row(scores) 27 | print(tb) 28 | 29 | def main(): 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument("--corpus_file", type=str, required=True) 32 | parser.add_argument("--vector_file", type=str, required=True) 33 | parser.add_argument("--batch_size", type=int, default=128) 34 | 35 | parser.add_argument("--model_name_or_path", type=str, 36 | help="Transformers' model name or path") 37 | parser.add_argument("--pooler", type=str, 38 | choices=['cls', 'cls_before_pooler', 'avg', 'avg_top2', 'avg_first_last'], 39 | default='cls', 40 | help="Which pooler to use") 41 | parser.add_argument("--mode", type=str, 42 | choices=['dev', 'test', 'fasttest'], 43 | default='test', 44 | help="What evaluation mode to use (dev: fast mode, dev results; 
test: full mode, test results); fasttest: fast mode, test results") 45 | parser.add_argument("--task_set", type=str, 46 | choices=['sts', 'transfer', 'full', 'na'], 47 | default='sts', 48 | help="What set of tasks to evaluate on. If not 'na', this will override '--tasks'") 49 | parser.add_argument("--tasks", type=str, nargs='+', 50 | default=['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 51 | 'MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'TREC', 'MRPC', 52 | 'SICKRelatedness', 'STSBenchmark'], 53 | help="Tasks to evaluate on. If '--task_set' is specified, this will be overridden") 54 | 55 | args = parser.parse_args() 56 | 57 | # Load transformers' model checkpoint 58 | model = AutoModel.from_pretrained(args.model_name_or_path) 59 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) 60 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 61 | model = model.to(device) 62 | 63 | # Set up the tasks 64 | if args.task_set == 'sts': 65 | args.tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark', 'SICKRelatedness'] 66 | elif args.task_set == 'transfer': 67 | args.tasks = ['MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'TREC', 'MRPC'] 68 | elif args.task_set == 'full': 69 | args.tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark', 'SICKRelatedness'] 70 | args.tasks += ['MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'TREC', 'MRPC'] 71 | 72 | # Set params for SentEval 73 | if args.mode == 'dev' or args.mode == 'fasttest': 74 | # Fast mode 75 | params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5} 76 | params['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128, 77 | 'tenacity': 3, 'epoch_size': 2} 78 | elif args.mode == 'test': 79 | # Full mode 80 | params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 10} 81 | params['classifier'] = {'nhid': 0, 'optim': 'adam', 'batch_size': 64, 82 | 'tenacity': 5, 'epoch_size': 4} 83 | else: 84 | raise NotImplementedError 85 | 86 | # SentEval prepare and batcher 87 
| def prepare(params, samples): 88 | return 89 | 90 | def batcher(batch, max_length=None): 91 | # Handle rare token encoding issues in the dataset 92 | if len(batch) >= 1 and len(batch[0]) >= 1 and isinstance(batch[0][0], bytes): 93 | batch = [[word.decode('utf-8') for word in s] for s in batch] 94 | 95 | sentences = [' '.join(s) for s in batch] 96 | 97 | # Tokenization 98 | if max_length is not None: 99 | batch = tokenizer.batch_encode_plus( 100 | sentences, 101 | return_tensors='pt', 102 | padding=True, 103 | max_length=max_length, 104 | truncation=True 105 | ) 106 | else: 107 | batch = tokenizer.batch_encode_plus( 108 | sentences, 109 | return_tensors='pt', 110 | padding=True, 111 | ) 112 | 113 | # Move to the correct device 114 | for k in batch: 115 | batch[k] = batch[k].to(device) 116 | 117 | # Get raw embeddings 118 | with torch.no_grad(): 119 | outputs = model(**batch, output_hidden_states=True, return_dict=True) 120 | last_hidden = outputs.last_hidden_state 121 | pooler_output = outputs.pooler_output 122 | hidden_states = outputs.hidden_states 123 | 124 | # Apply different poolers 125 | if args.pooler == 'cls': 126 | # There is a linear+activation layer after CLS representation 127 | return pooler_output.cpu() 128 | elif args.pooler == 'cls_before_pooler': 129 | return last_hidden[:, 0].cpu() 130 | elif args.pooler == "avg": 131 | return ((last_hidden * batch['attention_mask'].unsqueeze(-1)).sum(1) / batch['attention_mask'].sum(-1).unsqueeze(-1)).cpu() 132 | elif args.pooler == "avg_first_last": 133 | first_hidden = hidden_states[0] 134 | last_hidden = hidden_states[-1] 135 | pooled_result = ((first_hidden + last_hidden) / 2.0 * batch['attention_mask'].unsqueeze(-1)).sum(1) / batch['attention_mask'].sum(-1).unsqueeze(-1) 136 | return pooled_result.cpu() 137 | elif args.pooler == "avg_top2": 138 | second_last_hidden = hidden_states[-2] 139 | last_hidden = hidden_states[-1] 140 | pooled_result = ((last_hidden + second_last_hidden) / 2.0 * 
batch['attention_mask'].unsqueeze(-1)).sum(1) / batch['attention_mask'].sum(-1).unsqueeze(-1) 141 | return pooled_result.cpu() 142 | else: 143 | raise NotImplementedError 144 | 145 | with open(args.corpus_file, "r") as f: 146 | sentences = f.readlines() 147 | sentences = [s.strip().split() for s in sentences] 148 | 149 | sentences = sorted([(i, s) for i, s in enumerate(sentences)], key=lambda x: len(x[1]), reverse=True) 150 | inds, sentences = map(list, zip(*sentences)) 151 | sort_inds = sorted([(i, j) for i, j in enumerate(inds)], key=lambda x: x[1]) 152 | sort_inds, _ = map(list, zip(*sort_inds)) 153 | 154 | sentence_vectors = [] 155 | for i in tqdm(range(0, len(sentences), args.batch_size), desc="Computing sentence vectors"): 156 | batch = sentences[i:i+args.batch_size] 157 | vectors = batcher(batch).numpy() 158 | assert vectors.shape[1] == model.config.hidden_size 159 | sentence_vectors.append(vectors) 160 | sentence_vectors = np.concatenate(sentence_vectors, axis=0) 161 | sentence_vectors = sentence_vectors[sort_inds] 162 | 163 | print("Saving...") 164 | os.makedirs(os.path.dirname(args.vector_file), exist_ok=True) 165 | with open(args.vector_file, "wb") as f: 166 | np.save(f, sentence_vectors) 167 | 168 | if __name__ == "__main__": 169 | main() 170 | -------------------------------------------------------------------------------- /code/SimCSE_RankEncoder/get_rank_encoder_embedding.sh: -------------------------------------------------------------------------------- 1 | OUTPUT_DIR=$PROJECT_DIR/outputs 2 | 3 | SEED=61507 4 | CUDA_VISIBLE_DEVICES=1,2 python get_embedding.py \ 5 | --corpus_file $OUTPUT_DIR/corpus/corpus_0.01.txt \ 6 | --vector_file $OUTPUT_DIR/simcse/index_vecs/corpus_0.01_rank_encoder_seed_$SEED.npy \ 7 | --batch_size 128 \ 8 | --model_name_or_path $OUTPUT_DIR/simcse/checkpoints/simcse_unsup_rank_encoder_seed_$SEED \ 9 | --pooler cls_before_pooler \ 10 | --task_set sts \ 11 | --mode test 12 | 
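Note on `get_embedding.py` above: the script sorts the corpus longest-first before batching (so each batch pads to a similar length) and then uses `sort_inds` to put the vectors back in corpus order before saving. The following is a minimal sketch of that idea, not the repository's code; the helper name `sort_by_length_with_restore` and the toy "encoder" are assumptions made for illustration.

```python
import numpy as np


def sort_by_length_with_restore(sentences):
    """Sort tokenized sentences longest-first.

    Returns the sorted sentences and an index array `restore` such that
    `vectors_in_sorted_order[restore]` is back in the original order.
    (Hypothetical helper, written for this illustration only.)
    """
    # order[p] is the original index of the sentence placed at sorted position p
    order = sorted(range(len(sentences)), key=lambda i: len(sentences[i]), reverse=True)
    sorted_sentences = [sentences[i] for i in order]
    # argsort of a permutation is its inverse:
    # restore[k] is the sorted position of original sentence k
    restore = np.argsort(order)
    return sorted_sentences, restore


sentences = [s.split() for s in ["a b", "a b c d", "a", "a b c"]]
sorted_sentences, restore = sort_by_length_with_restore(sentences)

# Stand-in for the encoder: one "vector" per sentence holding its token count
vectors = np.array([[len(s)] for s in sorted_sentences], dtype=np.float32)

# Undo the length sort so row k corresponds to the k-th line of the corpus
restored = vectors[restore]
print(restored[:, 0].tolist())
```

Restoring the order matters because the saved `.npy` file is later indexed by corpus line; if the vectors stayed in length-sorted order, they would no longer line up with the sentences in the corpus file.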
-------------------------------------------------------------------------------- /code/SimCSE_RankEncoder/get_simcse_embedding.sh: -------------------------------------------------------------------------------- 1 | OUTPUT_DIR=$PROJECT_DIR/outputs 2 | 3 | SEED=61507 4 | CUDA_VISIBLE_DEVICES=1,2 python get_embedding.py \ 5 | --corpus_file $OUTPUT_DIR/corpus/corpus_0.01.txt \ 6 | --vector_file $OUTPUT_DIR/simcse/index_vecs/corpus_0.01_seed_$SEED.npy \ 7 | --batch_size 128 \ 8 | --model_name_or_path $OUTPUT_DIR/simcse/checkpoints/simcse_unsup_seed_$SEED \ 9 | --pooler cls_before_pooler \ 10 | --task_set sts \ 11 | --mode test 12 | -------------------------------------------------------------------------------- /code/SimCSE_RankEncoder/run_simcse.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # In this example, we show how to train SimCSE on unsupervised Wikipedia data. 4 | # If you want to train it with multiple GPU cards, see "run_sup_example.sh" 5 | # about how to use PyTorch's distributed data parallel. 
6 | DATA_DIR=$PROJECT_DIR/data 7 | OUTPUT_DIR=$PROJECT_DIR/outputs 8 | SEED=61507 9 | CUDA_VISIBLE_DEVICES=0 python train_simcse.py \ 10 | --model_name_or_path bert-base-uncased \ 11 | --train_file $DATA_DIR/corpus/corpus.txt \ 12 | --output_dir $OUTPUT_DIR/simcse/checkpoints/simcse_unsup_seed_$SEED \ 13 | --num_train_epochs 1 \ 14 | --per_device_train_batch_size 64 \ 15 | --learning_rate 3e-5 \ 16 | --max_seq_length 32 \ 17 | --evaluation_strategy steps \ 18 | --metric_for_best_model stsb_spearman \ 19 | --eval_steps 125 \ 20 | --pooler_type cls \ 21 | --mlp_only_train \ 22 | --overwrite_output_dir \ 23 | --temp 0.05 \ 24 | --do_train \ 25 | --do_eval \ 26 | --fp16 \ 27 | --seed $SEED \ 28 | --load_best_model_at_end \ 29 | "$@" 30 | -------------------------------------------------------------------------------- /code/SimCSE_RankEncoder/run_simcse_rank_encoder.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | OUTPUT_DIR=$PROJECT_DIR/outputs 4 | DATA_DIR=$PROJECT_DIR/data 5 | 6 | SEED=61507 7 | VECTOR_FILE=$OUTPUT_DIR/simcse/index_vecs/corpus_0.01_seed_$SEED.npy 8 | CUDA_VISIBLE_DEVICES=0,1 python train_simcse_rank_encoder.py \ 9 | --baseE_sim_thresh_upp 0.9999 \ 10 | --baseE_sim_thresh_low 0.5 \ 11 | --baseE_lmb 0.05 \ 12 | --simf Spearmanr \ 13 | --loss_type weighted_sum \ 14 | --baseE_model_name_or_path $OUTPUT_DIR/simcse/checkpoints/simcse_unsup_seed_$SEED \ 15 | --corpus_vecs $VECTOR_FILE \ 16 | --model_name_or_path bert-base-uncased \ 17 | --train_file $DATA_DIR/corpus/corpus.txt \ 18 | --output_dir $OUTPUT_DIR/simcse/checkpoints/simcse_unsup_rank_encoder_seed_$SEED \ 19 | --num_train_epochs 1 \ 20 | --per_device_train_batch_size 64 \ 21 | --learning_rate 3e-5 \ 22 | --max_seq_length 32 \ 23 | --evaluation_strategy steps \ 24 | --metric_for_best_model avg_sts \ 25 | --load_best_model_at_end \ 26 | --eval_steps 125 \ 27 | --pooler_type cls \ 28 | --mlp_only_train \ 29 | --overwrite_output_dir \ 30 | 
--temp 0.05 \ 31 | --do_train \ 32 | --do_eval \ 33 | --fp16 \ 34 | --seed $SEED \ 35 | "$@" 36 | -------------------------------------------------------------------------------- /code/SimCSE_RankEncoder/simcse/__init__.py: -------------------------------------------------------------------------------- 1 | from .tool import SimCSE 2 | -------------------------------------------------------------------------------- /code/SimCSE_RankEncoder/simcse/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.distributed as dist 5 | 6 | import transformers 7 | from transformers import RobertaTokenizer 8 | from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel, RobertaModel, RobertaLMHead 9 | from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertModel, BertLMPredictionHead 10 | from transformers.activations import gelu 11 | from transformers.file_utils import ( 12 | add_code_sample_docstrings, 13 | add_start_docstrings, 14 | add_start_docstrings_to_model_forward, 15 | replace_return_docstrings, 16 | ) 17 | from transformers.modeling_outputs import SequenceClassifierOutput, BaseModelOutputWithPoolingAndCrossAttentions 18 | 19 | class MLPLayer(nn.Module): 20 | """ 21 | Head for getting sentence representations over RoBERTa/BERT's CLS representation. 
22 | """ 23 | 24 | def __init__(self, config): 25 | super().__init__() 26 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 27 | self.activation = nn.Tanh() 28 | 29 | def forward(self, features, **kwargs): 30 | x = self.dense(features) 31 | x = self.activation(x) 32 | 33 | return x 34 | 35 | class Similarity(nn.Module): 36 | """ 37 | Cosine similarity divided by the temperature `temp` 38 | """ 39 | 40 | def __init__(self, temp): 41 | super().__init__() 42 | self.temp = temp 43 | self.cos = nn.CosineSimilarity(dim=-1) 44 | 45 | def forward(self, x, y): 46 | return self.cos(x, y) / self.temp 47 | 48 | 49 | class Pooler(nn.Module): 50 | """ 51 | Parameter-free poolers to get the sentence embedding 52 | 'cls': [CLS] representation with BERT/RoBERTa's MLP pooler. 53 | 'cls_before_pooler': [CLS] representation without the original MLP pooler. 54 | 'avg': average of the last layer's hidden states over all tokens. 55 | 'avg_top2': average of the last two layers. 56 | 'avg_first_last': average of the first and the last layers. 
57 | """ 58 | def __init__(self, pooler_type): 59 | super().__init__() 60 | self.pooler_type = pooler_type 61 | assert self.pooler_type in ["cls", "cls_before_pooler", "avg", "avg_top2", "avg_first_last"], "unrecognized pooling type %s" % self.pooler_type 62 | 63 | def forward(self, attention_mask, outputs): 64 | last_hidden = outputs.last_hidden_state 65 | pooler_output = outputs.pooler_output 66 | hidden_states = outputs.hidden_states 67 | 68 | if self.pooler_type in ['cls_before_pooler', 'cls']: 69 | return last_hidden[:, 0] 70 | elif self.pooler_type == "avg": 71 | return ((last_hidden * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(-1).unsqueeze(-1)) 72 | elif self.pooler_type == "avg_first_last": 73 | first_hidden = hidden_states[0] 74 | last_hidden = hidden_states[-1] 75 | pooled_result = ((first_hidden + last_hidden) / 2.0 * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(-1).unsqueeze(-1) 76 | return pooled_result 77 | elif self.pooler_type == "avg_top2": 78 | second_last_hidden = hidden_states[-2] 79 | last_hidden = hidden_states[-1] 80 | pooled_result = ((last_hidden + second_last_hidden) / 2.0 * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(-1).unsqueeze(-1) 81 | return pooled_result 82 | else: 83 | raise NotImplementedError 84 | 85 | 86 | def cl_init(cls, config): 87 | """ 88 | Contrastive learning class init function. 
89 | """ 90 | cls.pooler_type = cls.model_args.pooler_type 91 | cls.pooler = Pooler(cls.model_args.pooler_type) 92 | if cls.model_args.pooler_type == "cls": 93 | cls.mlp = MLPLayer(config) 94 | cls.sim = Similarity(temp=cls.model_args.temp) 95 | cls.init_weights() 96 | 97 | def cl_forward(cls, 98 | encoder, 99 | input_ids=None, 100 | attention_mask=None, 101 | token_type_ids=None, 102 | position_ids=None, 103 | head_mask=None, 104 | inputs_embeds=None, 105 | labels=None, 106 | output_attentions=None, 107 | output_hidden_states=None, 108 | return_dict=None, 109 | mlm_input_ids=None, 110 | mlm_labels=None, 111 | ): 112 | return_dict = return_dict if return_dict is not None else cls.config.use_return_dict 113 | ori_input_ids = input_ids 114 | batch_size = input_ids.size(0) 115 | # Number of sentences in one instance 116 | # 2: pair instance; 3: pair instance with a hard negative 117 | num_sent = input_ids.size(1) 118 | 119 | mlm_outputs = None 120 | # Flatten input for encoding 121 | input_ids = input_ids.view((-1, input_ids.size(-1))) # (bs * num_sent, len) 122 | attention_mask = attention_mask.view((-1, attention_mask.size(-1))) # (bs * num_sent len) 123 | if token_type_ids is not None: 124 | token_type_ids = token_type_ids.view((-1, token_type_ids.size(-1))) # (bs * num_sent, len) 125 | 126 | # Get raw embeddings 127 | outputs = encoder( 128 | input_ids, 129 | attention_mask=attention_mask, 130 | token_type_ids=token_type_ids, 131 | position_ids=position_ids, 132 | head_mask=head_mask, 133 | inputs_embeds=inputs_embeds, 134 | output_attentions=output_attentions, 135 | output_hidden_states=True if cls.model_args.pooler_type in ['avg_top2', 'avg_first_last'] else False, 136 | return_dict=True, 137 | ) 138 | 139 | # MLM auxiliary objective 140 | if mlm_input_ids is not None: 141 | mlm_input_ids = mlm_input_ids.view((-1, mlm_input_ids.size(-1))) 142 | mlm_outputs = encoder( 143 | mlm_input_ids, 144 | attention_mask=attention_mask, 145 | token_type_ids=token_type_ids, 
146 | position_ids=position_ids, 147 | head_mask=head_mask, 148 | inputs_embeds=inputs_embeds, 149 | output_attentions=output_attentions, 150 | output_hidden_states=True if cls.model_args.pooler_type in ['avg_top2', 'avg_first_last'] else False, 151 | return_dict=True, 152 | ) 153 | 154 | # Pooling 155 | pooler_output = cls.pooler(attention_mask, outputs) 156 | pooler_output = pooler_output.view((batch_size, num_sent, pooler_output.size(-1))) # (bs, num_sent, hidden) 157 | 158 | # If using "cls", we add an extra MLP layer 159 | # (same as BERT's original implementation) over the representation. 160 | if cls.pooler_type == "cls": 161 | pooler_output = cls.mlp(pooler_output) 162 | 163 | # Separate representation 164 | z1, z2 = pooler_output[:,0], pooler_output[:,1] 165 | 166 | # Hard negative 167 | if num_sent == 3: 168 | z3 = pooler_output[:, 2] 169 | 170 | # Gather all embeddings if using distributed training 171 | if dist.is_initialized() and cls.training: 172 | # Gather hard negative 173 | if num_sent >= 3: 174 | z3_list = [torch.zeros_like(z3) for _ in range(dist.get_world_size())] 175 | dist.all_gather(tensor_list=z3_list, tensor=z3.contiguous()) 176 | z3_list[dist.get_rank()] = z3 177 | z3 = torch.cat(z3_list, 0) 178 | 179 | # Dummy vectors for allgather 180 | z1_list = [torch.zeros_like(z1) for _ in range(dist.get_world_size())] 181 | z2_list = [torch.zeros_like(z2) for _ in range(dist.get_world_size())] 182 | # Allgather 183 | dist.all_gather(tensor_list=z1_list, tensor=z1.contiguous()) 184 | dist.all_gather(tensor_list=z2_list, tensor=z2.contiguous()) 185 | 186 | # Since allgather results do not have gradients, we replace the 187 | # current process's corresponding embeddings with original tensors 188 | z1_list[dist.get_rank()] = z1 189 | z2_list[dist.get_rank()] = z2 190 | # Get full batch embeddings: (bs x N, hidden) 191 | z1 = torch.cat(z1_list, 0) 192 | z2 = torch.cat(z2_list, 0) 193 | 194 | cos_sim = cls.sim(z1.unsqueeze(1), z2.unsqueeze(0)) 195 | # 
Hard negative 196 | if num_sent >= 3: 197 | z1_z3_cos = cls.sim(z1.unsqueeze(1), z3.unsqueeze(0)) 198 | cos_sim = torch.cat([cos_sim, z1_z3_cos], 1) 199 | 200 | labels = torch.arange(cos_sim.size(0)).long().to(cls.device) 201 | loss_fct = nn.CrossEntropyLoss() 202 | 203 | # Calculate loss with hard negatives 204 | if num_sent == 3: 205 | # Note that weights are actually logits of weights 206 | z3_weight = cls.model_args.hard_negative_weight 207 | weights = torch.tensor( 208 | [[0.0] * (cos_sim.size(-1) - z1_z3_cos.size(-1)) + [0.0] * i + [z3_weight] + [0.0] * (z1_z3_cos.size(-1) - i - 1) for i in range(z1_z3_cos.size(-1))] 209 | ).to(cls.device) 210 | cos_sim = cos_sim + weights 211 | 212 | loss = loss_fct(cos_sim, labels) 213 | 214 | # Calculate loss for MLM 215 | if mlm_outputs is not None and mlm_labels is not None: 216 | mlm_labels = mlm_labels.view(-1, mlm_labels.size(-1)) 217 | prediction_scores = cls.lm_head(mlm_outputs.last_hidden_state) 218 | masked_lm_loss = loss_fct(prediction_scores.view(-1, cls.config.vocab_size), mlm_labels.view(-1)) 219 | loss = loss + cls.model_args.mlm_weight * masked_lm_loss 220 | 221 | if not return_dict: 222 | output = (cos_sim,) + outputs[2:] 223 | return ((loss,) + output) if loss is not None else output 224 | return SequenceClassifierOutput( 225 | loss=loss, 226 | logits=cos_sim, 227 | hidden_states=outputs.hidden_states, 228 | attentions=outputs.attentions, 229 | ) 230 | 231 | 232 | def sentemb_forward( 233 | cls, 234 | encoder, 235 | input_ids=None, 236 | attention_mask=None, 237 | token_type_ids=None, 238 | position_ids=None, 239 | head_mask=None, 240 | inputs_embeds=None, 241 | labels=None, 242 | output_attentions=None, 243 | output_hidden_states=None, 244 | return_dict=None, 245 | ): 246 | 247 | return_dict = return_dict if return_dict is not None else cls.config.use_return_dict 248 | 249 | outputs = encoder( 250 | input_ids, 251 | attention_mask=attention_mask, 252 | token_type_ids=token_type_ids, 253 | 
position_ids=position_ids, 254 | head_mask=head_mask, 255 | inputs_embeds=inputs_embeds, 256 | output_attentions=output_attentions, 257 | output_hidden_states=True if cls.pooler_type in ['avg_top2', 'avg_first_last'] else False, 258 | return_dict=True, 259 | ) 260 | 261 | pooler_output = cls.pooler(attention_mask, outputs) 262 | if cls.pooler_type == "cls" and not cls.model_args.mlp_only_train: 263 | pooler_output = cls.mlp(pooler_output) 264 | 265 | if not return_dict: 266 | return (outputs[0], pooler_output) + outputs[2:] 267 | 268 | return BaseModelOutputWithPoolingAndCrossAttentions( 269 | pooler_output=pooler_output, 270 | last_hidden_state=outputs.last_hidden_state, 271 | hidden_states=outputs.hidden_states, 272 | ) 273 | 274 | 275 | class BertForCL(BertPreTrainedModel): 276 | _keys_to_ignore_on_load_missing = [r"position_ids"] 277 | 278 | def __init__(self, config, *model_args, **model_kargs): 279 | super().__init__(config) 280 | self.model_args = model_kargs["model_args"] 281 | self.bert = BertModel(config, add_pooling_layer=False) 282 | 283 | if self.model_args.do_mlm: 284 | self.lm_head = BertLMPredictionHead(config) 285 | 286 | cl_init(self, config) 287 | 288 | def forward(self, 289 | input_ids=None, 290 | attention_mask=None, 291 | token_type_ids=None, 292 | position_ids=None, 293 | head_mask=None, 294 | inputs_embeds=None, 295 | labels=None, 296 | output_attentions=None, 297 | output_hidden_states=None, 298 | return_dict=None, 299 | sent_emb=False, 300 | mlm_input_ids=None, 301 | mlm_labels=None, 302 | ): 303 | if sent_emb: 304 | return sentemb_forward(self, self.bert, 305 | input_ids=input_ids, 306 | attention_mask=attention_mask, 307 | token_type_ids=token_type_ids, 308 | position_ids=position_ids, 309 | head_mask=head_mask, 310 | inputs_embeds=inputs_embeds, 311 | labels=labels, 312 | output_attentions=output_attentions, 313 | output_hidden_states=output_hidden_states, 314 | return_dict=return_dict, 315 | ) 316 | else: 317 | return cl_forward(self, 
self.bert, 318 | input_ids=input_ids, 319 | attention_mask=attention_mask, 320 | token_type_ids=token_type_ids, 321 | position_ids=position_ids, 322 | head_mask=head_mask, 323 | inputs_embeds=inputs_embeds, 324 | labels=labels, 325 | output_attentions=output_attentions, 326 | output_hidden_states=output_hidden_states, 327 | return_dict=return_dict, 328 | mlm_input_ids=mlm_input_ids, 329 | mlm_labels=mlm_labels, 330 | ) 331 | 332 | 333 | 334 | class RobertaForCL(RobertaPreTrainedModel): 335 | _keys_to_ignore_on_load_missing = [r"position_ids"] 336 | 337 | def __init__(self, config, *model_args, **model_kargs): 338 | super().__init__(config) 339 | self.model_args = model_kargs["model_args"] 340 | self.roberta = RobertaModel(config, add_pooling_layer=False) 341 | 342 | if self.model_args.do_mlm: 343 | self.lm_head = RobertaLMHead(config) 344 | 345 | cl_init(self, config) 346 | 347 | def forward(self, 348 | input_ids=None, 349 | attention_mask=None, 350 | token_type_ids=None, 351 | position_ids=None, 352 | head_mask=None, 353 | inputs_embeds=None, 354 | labels=None, 355 | output_attentions=None, 356 | output_hidden_states=None, 357 | return_dict=None, 358 | sent_emb=False, 359 | mlm_input_ids=None, 360 | mlm_labels=None, 361 | ): 362 | if sent_emb: 363 | return sentemb_forward(self, self.roberta, 364 | input_ids=input_ids, 365 | attention_mask=attention_mask, 366 | token_type_ids=token_type_ids, 367 | position_ids=position_ids, 368 | head_mask=head_mask, 369 | inputs_embeds=inputs_embeds, 370 | labels=labels, 371 | output_attentions=output_attentions, 372 | output_hidden_states=output_hidden_states, 373 | return_dict=return_dict, 374 | ) 375 | else: 376 | return cl_forward(self, self.roberta, 377 | input_ids=input_ids, 378 | attention_mask=attention_mask, 379 | token_type_ids=token_type_ids, 380 | position_ids=position_ids, 381 | head_mask=head_mask, 382 | inputs_embeds=inputs_embeds, 383 | labels=labels, 384 | output_attentions=output_attentions, 385 | 
output_hidden_states=output_hidden_states, 386 | return_dict=return_dict, 387 | mlm_input_ids=mlm_input_ids, 388 | mlm_labels=mlm_labels, 389 | ) 390 | -------------------------------------------------------------------------------- /code/SimCSE_RankEncoder/simcse/tool.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from tqdm import tqdm 3 | import numpy as np 4 | from numpy import ndarray 5 | import torch 6 | from torch import Tensor, device 7 | import transformers 8 | from transformers import AutoModel, AutoTokenizer 9 | from sklearn.metrics.pairwise import cosine_similarity 10 | from sklearn.preprocessing import normalize 11 | from typing import List, Dict, Tuple, Type, Union 12 | 13 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S', 14 | level=logging.INFO) 15 | logger = logging.getLogger(__name__) 16 | 17 | class SimCSE(object): 18 | """ 19 | A class for embedding sentences, calculating similarities, and retrieving sentences by SimCSE. 20 | """ 21 | def __init__(self, model_name_or_path: str, 22 | device: str = None, 23 | num_cells: int = 100, 24 | num_cells_in_search: int = 10, 25 | pooler = None): 26 | 27 | self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) 28 | self.model = AutoModel.from_pretrained(model_name_or_path) 29 | if device is None: 30 | device = "cuda" if torch.cuda.is_available() else "cpu" 31 | self.device = device 32 | 33 | self.index = None 34 | self.is_faiss_index = False 35 | self.num_cells = num_cells 36 | self.num_cells_in_search = num_cells_in_search 37 | 38 | if pooler is not None: 39 | self.pooler = pooler 40 | elif "unsup" in model_name_or_path: 41 | logger.info("Use `cls_before_pooler` for unsupervised models. 
If you want to use other pooling policy, specify `pooler` argument.") 42 | self.pooler = "cls_before_pooler" 43 | else: 44 | self.pooler = "cls" 45 | 46 | def encode(self, sentence: Union[str, List[str]], 47 | device: str = None, 48 | return_numpy: bool = False, 49 | normalize_to_unit: bool = True, 50 | keepdim: bool = False, 51 | batch_size: int = 64, 52 | max_length: int = 128) -> Union[ndarray, Tensor]: 53 | 54 | target_device = self.device if device is None else device 55 | self.model = self.model.to(target_device) 56 | 57 | single_sentence = False 58 | if isinstance(sentence, str): 59 | sentence = [sentence] 60 | single_sentence = True 61 | 62 | embedding_list = [] 63 | with torch.no_grad(): 64 | total_batch = len(sentence) // batch_size + (1 if len(sentence) % batch_size > 0 else 0) 65 | for batch_id in tqdm(range(total_batch)): 66 | inputs = self.tokenizer( 67 | sentence[batch_id*batch_size:(batch_id+1)*batch_size], 68 | padding=True, 69 | truncation=True, 70 | max_length=max_length, 71 | return_tensors="pt" 72 | ) 73 | inputs = {k: v.to(target_device) for k, v in inputs.items()} 74 | outputs = self.model(**inputs, return_dict=True) 75 | if self.pooler == "cls": 76 | embeddings = outputs.pooler_output 77 | elif self.pooler == "cls_before_pooler": 78 | embeddings = outputs.last_hidden_state[:, 0] 79 | else: 80 | raise NotImplementedError 81 | if normalize_to_unit: 82 | embeddings = embeddings / embeddings.norm(dim=1, keepdim=True) 83 | embedding_list.append(embeddings.cpu()) 84 | embeddings = torch.cat(embedding_list, 0) 85 | 86 | if single_sentence and not keepdim: 87 | embeddings = embeddings[0] 88 | 89 | if return_numpy and not isinstance(embeddings, ndarray): 90 | return embeddings.numpy() 91 | return embeddings 92 | 93 | def similarity(self, queries: Union[str, List[str]], 94 | keys: Union[str, List[str], ndarray], 95 | device: str = None) -> Union[float, ndarray]: 96 | 97 | query_vecs = self.encode(queries, device=device, return_numpy=True) # suppose N 
queries 98 | 99 | if not isinstance(keys, ndarray): 100 | key_vecs = self.encode(keys, device=device, return_numpy=True) # suppose M keys 101 | else: 102 | key_vecs = keys 103 | 104 | # check whether N == 1 or M == 1 105 | single_query, single_key = len(query_vecs.shape) == 1, len(key_vecs.shape) == 1 106 | if single_query: 107 | query_vecs = query_vecs.reshape(1, -1) 108 | if single_key: 109 | key_vecs = key_vecs.reshape(1, -1) 110 | 111 | # returns an N*M similarity array 112 | similarities = cosine_similarity(query_vecs, key_vecs) 113 | 114 | if single_query: 115 | similarities = similarities[0] 116 | if single_key: 117 | similarities = float(similarities[0]) 118 | 119 | return similarities 120 | 121 | def build_index(self, sentences_or_file_path: Union[str, List[str]], 122 | use_faiss: bool = None, 123 | faiss_fast: bool = False, 124 | device: str = None, 125 | batch_size: int = 64): 126 | 127 | if use_faiss is None or use_faiss: 128 | try: 129 | import faiss 130 | assert hasattr(faiss, "IndexFlatIP") 131 | use_faiss = True 132 | except Exception: 133 | logger.warning("Failed to import faiss. If you want to use faiss, install faiss through PyPI. Now the program continues with brute force search.") 134 | use_faiss = False 135 | 136 | # if the input is a string, we assume it is the path of a file that stores one sentence per line 137 | if isinstance(sentences_or_file_path, str): 138 | sentences = [] 139 | with open(sentences_or_file_path, "r") as f: 140 | logger.info("Loading sentences from %s ..." 
% (sentences_or_file_path)) 141 | for line in tqdm(f): 142 | sentences.append(line.rstrip()) 143 | sentences_or_file_path = sentences 144 | 145 | logger.info("Encoding embeddings for sentences...") 146 | embeddings = self.encode(sentences_or_file_path, device=device, batch_size=batch_size, normalize_to_unit=True, return_numpy=True) 147 | 148 | logger.info("Building index...") 149 | self.index = {"sentences": sentences_or_file_path} 150 | 151 | if use_faiss: 152 | quantizer = faiss.IndexFlatIP(embeddings.shape[1]) 153 | if faiss_fast: 154 | index = faiss.IndexIVFFlat(quantizer, embeddings.shape[1], min(self.num_cells, len(sentences_or_file_path))) 155 | else: 156 | index = quantizer 157 | 158 | if (self.device == "cuda" and device != "cpu") or device == "cuda": 159 | if hasattr(faiss, "StandardGpuResources"): 160 | logger.info("Use GPU-version faiss") 161 | res = faiss.StandardGpuResources() 162 | res.setTempMemory(20 * 1024 * 1024 * 1024) 163 | index = faiss.index_cpu_to_gpu(res, 0, index) 164 | else: 165 | logger.info("Use CPU-version faiss") 166 | else: 167 | logger.info("Use CPU-version faiss") 168 | 169 | if faiss_fast: 170 | index.train(embeddings.astype(np.float32)) 171 | index.add(embeddings.astype(np.float32)) 172 | index.nprobe = min(self.num_cells_in_search, len(sentences_or_file_path)) 173 | self.is_faiss_index = True 174 | else: 175 | index = embeddings 176 | self.is_faiss_index = False 177 | self.index["index"] = index 178 | logger.info("Finished") 179 | 180 | def search(self, queries: Union[str, List[str]], 181 | device: str = None, 182 | threshold: float = 0.6, 183 | top_k: int = 5) -> Union[List[Tuple[str, float]], List[List[Tuple[str, float]]]]: 184 | 185 | if not self.is_faiss_index: 186 | if isinstance(queries, list): 187 | combined_results = [] 188 | for query in queries: 189 | # forward threshold and top_k so per-query searches honor the caller's settings 190 | results = self.search(query, device, threshold, top_k) 191 | combined_results.append(results) 192 | return combined_results 193 | 194 | similarities = self.similarity(queries, 
self.index["index"]).tolist() 194 | id_and_score = [] 195 | for i, s in enumerate(similarities): 196 | if s >= threshold: 197 | id_and_score.append((i, s)) 198 | id_and_score = sorted(id_and_score, key=lambda x: x[1], reverse=True)[:top_k] 199 | results = [(self.index["sentences"][idx], score) for idx, score in id_and_score] 200 | return results 201 | else: 202 | query_vecs = self.encode(queries, device=device, normalize_to_unit=True, keepdim=True, return_numpy=True) 203 | 204 | distance, idx = self.index["index"].search(query_vecs.astype(np.float32), top_k) 205 | 206 | def pack_single_result(dist, idx): 207 | results = [(self.index["sentences"][i], s) for i, s in zip(idx, dist) if s >= threshold] 208 | return results 209 | 210 | if isinstance(queries, list): 211 | combined_results = [] 212 | for i in range(len(queries)): 213 | results = pack_single_result(distance[i], idx[i]) 214 | combined_results.append(results) 215 | return combined_results 216 | else: 217 | return pack_single_result(distance[0], idx[0]) 218 | 219 | if __name__=="__main__": 220 | example_sentences = [ 221 | 'An animal is biting a persons finger.', 222 | 'A woman is reading.', 223 | 'A man is lifting weights in a garage.', 224 | 'A man plays the violin.', 225 | 'A man is eating food.', 226 | 'A man plays the piano.', 227 | 'A panda is climbing.', 228 | 'A man plays a guitar.', 229 | 'A woman is slicing a meat.', 230 | 'A woman is taking a picture.' 231 | ] 232 | example_queries = [ 233 | 'A man is playing music.', 234 | 'A woman is making a photo.' 
235 | ] 236 | 237 | model_name = "princeton-nlp/sup-simcse-bert-base-uncased" 238 | simcse = SimCSE(model_name) 239 | 240 | print("\n=========Calculate cosine similarities between queries and sentences============\n") 241 | similarities = simcse.similarity(example_queries, example_sentences) 242 | print(similarities) 243 | 244 | print("\n=========Naive brute force search============\n") 245 | simcse.build_index(example_sentences, use_faiss=False) 246 | results = simcse.search(example_queries) 247 | for i, result in enumerate(results): 248 | print("Retrieval results for query: {}".format(example_queries[i])) 249 | for sentence, score in result: 250 | print(" {} (cosine similarity: {:.4f})".format(sentence, score)) 251 | print("") 252 | 253 | print("\n=========Search with Faiss backend============\n") 254 | simcse.build_index(example_sentences, use_faiss=True) 255 | results = simcse.search(example_queries) 256 | for i, result in enumerate(results): 257 | print("Retrieval results for query: {}".format(example_queries[i])) 258 | for sentence, score in result: 259 | print(" {} (cosine similarity: {:.4f})".format(sentence, score)) 260 | print("") 261 | 262 | -------------------------------------------------------------------------------- /code/SimCSE_RankEncoder/simcse_rank_encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .tool import SimCSE 2 | -------------------------------------------------------------------------------- /code/SimCSE_RankEncoder/simcse_rank_encoder/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.distributed as dist 5 | 6 | import transformers 7 | from transformers import RobertaTokenizer 8 | from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel, RobertaModel, RobertaLMHead 9 | from transformers.models.bert.modeling_bert import 
BertPreTrainedModel, BertModel, BertLMPredictionHead 10 | from transformers.activations import gelu 11 | from transformers.file_utils import ( 12 | add_code_sample_docstrings, 13 | add_start_docstrings, 14 | add_start_docstrings_to_model_forward, 15 | replace_return_docstrings, 16 | ) 17 | from transformers.modeling_outputs import SequenceClassifierOutput, BaseModelOutputWithPoolingAndCrossAttentions 18 | 19 | class MLPLayer(nn.Module): 20 | """ 21 | Head for getting sentence representations over RoBERTa/BERT's CLS representation. 22 | """ 23 | 24 | def __init__(self, config): 25 | super().__init__() 26 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 27 | self.activation = nn.Tanh() 28 | 29 | def forward(self, features, **kwargs): 30 | x = self.dense(features) 31 | x = self.activation(x) 32 | 33 | return x 34 | 35 | class Similarity(nn.Module): 36 | """ 37 | Cosine similarity divided by the temperature `temp` 38 | """ 39 | 40 | def __init__(self, temp): 41 | super().__init__() 42 | self.temp = temp 43 | self.cos = nn.CosineSimilarity(dim=-1) 44 | 45 | def forward(self, x, y): 46 | return self.cos(x, y) / self.temp 47 | 48 | 49 | class Pooler(nn.Module): 50 | """ 51 | Parameter-free poolers to get the sentence embedding 52 | 'cls': [CLS] representation with BERT/RoBERTa's MLP pooler. 53 | 'cls_before_pooler': [CLS] representation without the original MLP pooler. 54 | 'avg': average of the last layer's hidden states over all tokens. 55 | 'avg_top2': average of the last two layers. 56 | 'avg_first_last': average of the first and the last layers. 
57 | """ 58 | def __init__(self, pooler_type): 59 | super().__init__() 60 | self.pooler_type = pooler_type 61 | assert self.pooler_type in ["cls", "cls_before_pooler", "avg", "avg_top2", "avg_first_last"], "unrecognized pooling type %s" % self.pooler_type 62 | 63 | def forward(self, attention_mask, outputs): 64 | last_hidden = outputs.last_hidden_state 65 | pooler_output = outputs.pooler_output 66 | hidden_states = outputs.hidden_states 67 | 68 | if self.pooler_type in ['cls_before_pooler', 'cls']: 69 | return last_hidden[:, 0] 70 | elif self.pooler_type == "avg": 71 | return ((last_hidden * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(-1).unsqueeze(-1)) 72 | elif self.pooler_type == "avg_first_last": 73 | first_hidden = hidden_states[0] 74 | last_hidden = hidden_states[-1] 75 | pooled_result = ((first_hidden + last_hidden) / 2.0 * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(-1).unsqueeze(-1) 76 | return pooled_result 77 | elif self.pooler_type == "avg_top2": 78 | second_last_hidden = hidden_states[-2] 79 | last_hidden = hidden_states[-1] 80 | pooled_result = ((last_hidden + second_last_hidden) / 2.0 * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(-1).unsqueeze(-1) 81 | return pooled_result 82 | else: 83 | raise NotImplementedError 84 | 85 | def _get_ranks(x: torch.Tensor) -> torch.Tensor: 86 | x_rank = x.argsort(dim=1) 87 | ranks = torch.zeros_like(x_rank, dtype=torch.float) 88 | n, d = x_rank.size() 89 | 90 | for i in range(n): 91 | ranks[i][x_rank[i]] = torch.arange(d, dtype=torch.float).to(ranks.device) 92 | return ranks 93 | 94 | def cal_spr_corr(x: torch.Tensor, y: torch.Tensor): 95 | x_rank = _get_ranks(x) 96 | y_rank = _get_ranks(y) 97 | x_rank_mean = torch.mean(x_rank, dim=1).unsqueeze(1) 98 | y_rank_mean = torch.mean(y_rank, dim=1).unsqueeze(1) 99 | xn = x_rank - x_rank_mean 100 | yn = y_rank - y_rank_mean 101 | x_var = torch.sqrt(torch.sum(torch.square(xn), dim=1).unsqueeze(1)) 102 | y_var = 
torch.sqrt(torch.sum(torch.square(yn), dim=1).unsqueeze(1)) 103 | xn = xn / x_var 104 | yn = yn / y_var 105 | 106 | return torch.mm(xn, torch.transpose(yn, 0, 1)) 107 | 108 | def cl_init(cls, config): 109 | """ 110 | Contrastive learning class init function. 111 | """ 112 | cls.pooler_type = cls.model_args.pooler_type 113 | cls.pooler = Pooler(cls.model_args.pooler_type) 114 | if cls.model_args.pooler_type == "cls": 115 | cls.mlp = MLPLayer(config) 116 | cls.sim = Similarity(temp=cls.model_args.temp) 117 | cls.init_weights() 118 | 119 | def cl_forward(cls, 120 | encoder, 121 | input_ids=None, 122 | attention_mask=None, 123 | token_type_ids=None, 124 | position_ids=None, 125 | head_mask=None, 126 | inputs_embeds=None, 127 | labels=None, 128 | output_attentions=None, 129 | output_hidden_states=None, 130 | return_dict=None, 131 | mlm_input_ids=None, 132 | mlm_labels=None, 133 | distances1=None, 134 | distances2=None, 135 | baseE_vecs1=None, 136 | baseE_vecs2=None, 137 | ): 138 | return_dict = return_dict if return_dict is not None else cls.config.use_return_dict 139 | ori_input_ids = input_ids 140 | batch_size = input_ids.size(0) 141 | # Number of sentences in one instance 142 | # 2: pair instance; 3: pair instance with a hard negative 143 | num_sent = input_ids.size(1) 144 | 145 | mlm_outputs = None 146 | # Flatten input for encoding 147 | input_ids = input_ids.view((-1, input_ids.size(-1))) # (bs * num_sent, len) 148 | attention_mask = attention_mask.view((-1, attention_mask.size(-1))) # (bs * num_sent, len) 149 | if token_type_ids is not None: 150 | token_type_ids = token_type_ids.view((-1, token_type_ids.size(-1))) # (bs * num_sent, len) 151 | 152 | # Get raw embeddings 153 | outputs = encoder( 154 | input_ids, 155 | attention_mask=attention_mask, 156 | token_type_ids=token_type_ids, 157 | position_ids=position_ids, 158 | head_mask=head_mask, 159 | inputs_embeds=inputs_embeds, 160 | output_attentions=output_attentions, 161 | output_hidden_states=True if
cls.model_args.pooler_type in ['avg_top2', 'avg_first_last'] else False, 162 | return_dict=True, 163 | ) 164 | 165 | # MLM auxiliary objective 166 | if mlm_input_ids is not None: 167 | mlm_input_ids = mlm_input_ids.view((-1, mlm_input_ids.size(-1))) 168 | mlm_outputs = encoder( 169 | mlm_input_ids, 170 | attention_mask=attention_mask, 171 | token_type_ids=token_type_ids, 172 | position_ids=position_ids, 173 | head_mask=head_mask, 174 | inputs_embeds=inputs_embeds, 175 | output_attentions=output_attentions, 176 | output_hidden_states=True if cls.model_args.pooler_type in ['avg_top2', 'avg_first_last'] else False, 177 | return_dict=True, 178 | ) 179 | 180 | # Pooling 181 | pooler_output = cls.pooler(attention_mask, outputs) 182 | pooler_output = pooler_output.view((batch_size, num_sent, pooler_output.size(-1))) # (bs, num_sent, hidden) 183 | 184 | # If using "cls", we add an extra MLP layer 185 | # (same as BERT's original implementation) over the representation. 186 | if cls.pooler_type == "cls": 187 | pooler_output = cls.mlp(pooler_output) 188 | 189 | # Separate representation 190 | z1, z2 = pooler_output[:,0], pooler_output[:,1] 191 | 192 | # Hard negative 193 | if num_sent == 3: 194 | z3 = pooler_output[:, 2] 195 | 196 | # Gather all embeddings if using distributed training 197 | if dist.is_initialized() and cls.training: 198 | # Gather hard negative 199 | if num_sent >= 3: 200 | z3_list = [torch.zeros_like(z3) for _ in range(dist.get_world_size())] 201 | dist.all_gather(tensor_list=z3_list, tensor=z3.contiguous()) 202 | z3_list[dist.get_rank()] = z3 203 | z3 = torch.cat(z3_list, 0) 204 | 205 | # Dummy vectors for allgather 206 | z1_list = [torch.zeros_like(z1) for _ in range(dist.get_world_size())] 207 | z2_list = [torch.zeros_like(z2) for _ in range(dist.get_world_size())] 208 | # Allgather 209 | dist.all_gather(tensor_list=z1_list, tensor=z1.contiguous()) 210 | dist.all_gather(tensor_list=z2_list, tensor=z2.contiguous()) 211 | 212 | # Since allgather results 
do not have gradients, we replace the 213 | # current process's corresponding embeddings with original tensors 214 | z1_list[dist.get_rank()] = z1 215 | z2_list[dist.get_rank()] = z2 216 | # Get full batch embeddings: (bs x N, hidden) 217 | z1 = torch.cat(z1_list, 0) 218 | z2 = torch.cat(z2_list, 0) 219 | 220 | cos_sim = cls.sim(z1.unsqueeze(1), z2.unsqueeze(0)) 221 | # Hard negative 222 | if num_sent >= 3: 223 | z1_z3_cos = cls.sim(z1.unsqueeze(1), z3.unsqueeze(0)) 224 | cos_sim = torch.cat([cos_sim, z1_z3_cos], 1) 225 | 226 | labels = torch.arange(cos_sim.size(0)).long().to(cls.device) 227 | loss_fct = nn.CrossEntropyLoss() 228 | 229 | # Calculate loss with hard negatives 230 | if num_sent == 3: 231 | # Note that weights are actually logits of weights 232 | z3_weight = cls.model_args.hard_negative_weight 233 | weights = torch.tensor( 234 | [[0.0] * (cos_sim.size(-1) - z1_z3_cos.size(-1)) + [0.0] * i + [z3_weight] + [0.0] * (z1_z3_cos.size(-1) - i - 1) for i in range(z1_z3_cos.size(-1))] 235 | ).to(cls.device) 236 | cos_sim = cos_sim + weights 237 | 238 | # Spearmanr 239 | if num_sent == 3: 240 | raise NotImplementedError 241 | 242 | cos_sim_baseE = None 243 | if cls.model_args.simf == "Spearmanr": 244 | cos_sim_baseE = cal_spr_corr( \ 245 | distances1, distances1 \ 246 | ) 247 | else: 248 | raise NotImplementedError 249 | 250 | cos_sim_baseE = cos_sim_baseE.to(cos_sim.device) 251 | loss_fct_baseE = nn.MSELoss(reduction="none") 252 | cos_sim_baseE_bound = torch.logical_and( \ 253 | cos_sim_baseE <= cls.model_args.baseE_sim_thresh_upp, \ 254 | cos_sim_baseE >= cls.model_args.baseE_sim_thresh_low \ 255 | ).type(torch.float).to(cos_sim.device) 256 | mse = loss_fct_baseE(cos_sim * cls.model_args.temp, cos_sim_baseE) 257 | loss_baseE = torch.sum(mse * cos_sim_baseE_bound) / (torch.sum(cos_sim_baseE_bound) + 1e-8) 258 | 259 | 260 | loss_o = loss_fct(cos_sim, labels) 261 | # Spearmanr 262 | if cls.model_args.loss_type == "hinge": 263 | loss = torch.max(loss_o, 
cls.model_args.baseE_lmb * loss_baseE) 264 | elif cls.model_args.loss_type == "weighted_sum": 265 | loss = loss_o + cls.model_args.baseE_lmb * loss_baseE 266 | else: 267 | raise NotImplementedError 268 | 269 | # Calculate loss for MLM 270 | if mlm_outputs is not None and mlm_labels is not None: 271 | mlm_labels = mlm_labels.view(-1, mlm_labels.size(-1)) 272 | prediction_scores = cls.lm_head(mlm_outputs.last_hidden_state) 273 | masked_lm_loss = loss_fct(prediction_scores.view(-1, cls.config.vocab_size), mlm_labels.view(-1)) 274 | loss = loss + cls.model_args.mlm_weight * masked_lm_loss 275 | 276 | if not return_dict: 277 | output = (cos_sim,) + outputs[2:] 278 | return ((loss,) + output) if loss is not None else output 279 | return SequenceClassifierOutput( 280 | loss=loss, 281 | logits=cos_sim, 282 | hidden_states=outputs.hidden_states, 283 | attentions=outputs.attentions, 284 | ) 285 | 286 | 287 | def sentemb_forward( 288 | cls, 289 | encoder, 290 | input_ids=None, 291 | attention_mask=None, 292 | token_type_ids=None, 293 | position_ids=None, 294 | head_mask=None, 295 | inputs_embeds=None, 296 | labels=None, 297 | output_attentions=None, 298 | output_hidden_states=None, 299 | return_dict=None, 300 | ): 301 | 302 | return_dict = return_dict if return_dict is not None else cls.config.use_return_dict 303 | 304 | outputs = encoder( 305 | input_ids, 306 | attention_mask=attention_mask, 307 | token_type_ids=token_type_ids, 308 | position_ids=position_ids, 309 | head_mask=head_mask, 310 | inputs_embeds=inputs_embeds, 311 | output_attentions=output_attentions, 312 | output_hidden_states=True if cls.pooler_type in ['avg_top2', 'avg_first_last'] else False, 313 | return_dict=True, 314 | ) 315 | 316 | pooler_output = cls.pooler(attention_mask, outputs) 317 | if cls.pooler_type == "cls" and not cls.model_args.mlp_only_train: 318 | pooler_output = cls.mlp(pooler_output) 319 | 320 | if not return_dict: 321 | return (outputs[0], pooler_output) + outputs[2:] 322 | 323 | return 
BaseModelOutputWithPoolingAndCrossAttentions( 324 | pooler_output=pooler_output, 325 | last_hidden_state=outputs.last_hidden_state, 326 | hidden_states=outputs.hidden_states, 327 | ) 328 | 329 | 330 | class BertForCL(BertPreTrainedModel): 331 | _keys_to_ignore_on_load_missing = [r"position_ids"] 332 | 333 | def __init__(self, config, *model_args, **model_kargs): 334 | super().__init__(config) 335 | self.model_args = model_kargs["model_args"] 336 | self.bert = BertModel(config, add_pooling_layer=False) 337 | 338 | if self.model_args.do_mlm: 339 | self.lm_head = BertLMPredictionHead(config) 340 | 341 | cl_init(self, config) 342 | 343 | def forward(self, 344 | input_ids=None, 345 | attention_mask=None, 346 | token_type_ids=None, 347 | position_ids=None, 348 | head_mask=None, 349 | inputs_embeds=None, 350 | labels=None, 351 | output_attentions=None, 352 | output_hidden_states=None, 353 | return_dict=None, 354 | sent_emb=False, 355 | mlm_input_ids=None, 356 | mlm_labels=None, 357 | distances1=None, 358 | distances2=None, 359 | baseE_vecs1=None, 360 | baseE_vecs2=None, 361 | ): 362 | if sent_emb: 363 | return sentemb_forward(self, self.bert, 364 | input_ids=input_ids, 365 | attention_mask=attention_mask, 366 | token_type_ids=token_type_ids, 367 | position_ids=position_ids, 368 | head_mask=head_mask, 369 | inputs_embeds=inputs_embeds, 370 | labels=labels, 371 | output_attentions=output_attentions, 372 | output_hidden_states=output_hidden_states, 373 | return_dict=return_dict, 374 | ) 375 | else: 376 | return cl_forward(self, self.bert, 377 | input_ids=input_ids, 378 | attention_mask=attention_mask, 379 | token_type_ids=token_type_ids, 380 | position_ids=position_ids, 381 | head_mask=head_mask, 382 | inputs_embeds=inputs_embeds, 383 | labels=labels, 384 | output_attentions=output_attentions, 385 | output_hidden_states=output_hidden_states, 386 | return_dict=return_dict, 387 | mlm_input_ids=mlm_input_ids, 388 | mlm_labels=mlm_labels, 389 | distances1=distances1, 390 | 
distances2=distances2, 391 | baseE_vecs1=baseE_vecs1, 392 | baseE_vecs2=baseE_vecs2, 393 | ) 394 | 395 | 396 | 397 | class RobertaForCL(RobertaPreTrainedModel): 398 | _keys_to_ignore_on_load_missing = [r"position_ids"] 399 | 400 | def __init__(self, config, *model_args, **model_kargs): 401 | super().__init__(config) 402 | self.model_args = model_kargs["model_args"] 403 | self.roberta = RobertaModel(config, add_pooling_layer=False) 404 | 405 | if self.model_args.do_mlm: 406 | self.lm_head = RobertaLMHead(config) 407 | 408 | cl_init(self, config) 409 | 410 | def forward(self, 411 | input_ids=None, 412 | attention_mask=None, 413 | token_type_ids=None, 414 | position_ids=None, 415 | head_mask=None, 416 | inputs_embeds=None, 417 | labels=None, 418 | output_attentions=None, 419 | output_hidden_states=None, 420 | return_dict=None, 421 | sent_emb=False, 422 | mlm_input_ids=None, 423 | mlm_labels=None, 424 | distances1=None, 425 | distances2=None, 426 | baseE_vecs1=None, 427 | baseE_vecs2=None, 428 | ): 429 | if sent_emb: 430 | return sentemb_forward(self, self.roberta, 431 | input_ids=input_ids, 432 | attention_mask=attention_mask, 433 | token_type_ids=token_type_ids, 434 | position_ids=position_ids, 435 | head_mask=head_mask, 436 | inputs_embeds=inputs_embeds, 437 | labels=labels, 438 | output_attentions=output_attentions, 439 | output_hidden_states=output_hidden_states, 440 | return_dict=return_dict, 441 | ) 442 | else: 443 | return cl_forward(self, self.roberta, 444 | input_ids=input_ids, 445 | attention_mask=attention_mask, 446 | token_type_ids=token_type_ids, 447 | position_ids=position_ids, 448 | head_mask=head_mask, 449 | inputs_embeds=inputs_embeds, 450 | labels=labels, 451 | output_attentions=output_attentions, 452 | output_hidden_states=output_hidden_states, 453 | return_dict=return_dict, 454 | mlm_input_ids=mlm_input_ids, 455 | mlm_labels=mlm_labels, 456 | distances1=distances1, 457 | distances2=distances2, 458 | baseE_vecs1=baseE_vecs1, 459 | 
baseE_vecs2=baseE_vecs2, 460 | ) 461 | -------------------------------------------------------------------------------- /code/SimCSE_RankEncoder/simcse_rank_encoder/tool.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from tqdm import tqdm 3 | import numpy as np 4 | from numpy import ndarray 5 | import torch 6 | from torch import Tensor, device 7 | import transformers 8 | from transformers import AutoModel, AutoTokenizer 9 | from sklearn.metrics.pairwise import cosine_similarity 10 | from sklearn.preprocessing import normalize 11 | from typing import List, Dict, Tuple, Type, Union 12 | 13 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S', 14 | level=logging.INFO) 15 | logger = logging.getLogger(__name__) 16 | 17 | class SimCSE(object): 18 | """ 19 | A class for embedding sentences, calculating similarities, and retrieving sentences with SimCSE. 20 | """ 21 | def __init__(self, model_name_or_path: str, 22 | device: str = None, 23 | num_cells: int = 100, 24 | num_cells_in_search: int = 10, 25 | pooler = None): 26 | 27 | self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) 28 | self.model = AutoModel.from_pretrained(model_name_or_path) 29 | if device is None: 30 | device = "cuda" if torch.cuda.is_available() else "cpu" 31 | self.device = device 32 | 33 | self.index = None 34 | self.is_faiss_index = False 35 | self.num_cells = num_cells 36 | self.num_cells_in_search = num_cells_in_search 37 | 38 | if pooler is not None: 39 | self.pooler = pooler 40 | elif "unsup" in model_name_or_path: 41 | logger.info("Use `cls_before_pooler` for unsupervised models.
If you want to use other pooling policy, specify `pooler` argument.") 42 | self.pooler = "cls_before_pooler" 43 | else: 44 | self.pooler = "cls" 45 | 46 | def encode(self, sentence: Union[str, List[str]], 47 | device: str = None, 48 | return_numpy: bool = False, 49 | normalize_to_unit: bool = True, 50 | keepdim: bool = False, 51 | batch_size: int = 64, 52 | max_length: int = 128) -> Union[ndarray, Tensor]: 53 | 54 | target_device = self.device if device is None else device 55 | self.model = self.model.to(target_device) 56 | 57 | single_sentence = False 58 | if isinstance(sentence, str): 59 | sentence = [sentence] 60 | single_sentence = True 61 | 62 | embedding_list = [] 63 | with torch.no_grad(): 64 | total_batch = len(sentence) // batch_size + (1 if len(sentence) % batch_size > 0 else 0) 65 | for batch_id in tqdm(range(total_batch)): 66 | inputs = self.tokenizer( 67 | sentence[batch_id*batch_size:(batch_id+1)*batch_size], 68 | padding=True, 69 | truncation=True, 70 | max_length=max_length, 71 | return_tensors="pt" 72 | ) 73 | inputs = {k: v.to(target_device) for k, v in inputs.items()} 74 | outputs = self.model(**inputs, return_dict=True) 75 | if self.pooler == "cls": 76 | embeddings = outputs.pooler_output 77 | elif self.pooler == "cls_before_pooler": 78 | embeddings = outputs.last_hidden_state[:, 0] 79 | else: 80 | raise NotImplementedError 81 | if normalize_to_unit: 82 | embeddings = embeddings / embeddings.norm(dim=1, keepdim=True) 83 | embedding_list.append(embeddings.cpu()) 84 | embeddings = torch.cat(embedding_list, 0) 85 | 86 | if single_sentence and not keepdim: 87 | embeddings = embeddings[0] 88 | 89 | if return_numpy and not isinstance(embeddings, ndarray): 90 | return embeddings.numpy() 91 | return embeddings 92 | 93 | def similarity(self, queries: Union[str, List[str]], 94 | keys: Union[str, List[str], ndarray], 95 | device: str = None) -> Union[float, ndarray]: 96 | 97 | query_vecs = self.encode(queries, device=device, return_numpy=True) # suppose N 
queries 98 | 99 | if not isinstance(keys, ndarray): 100 | key_vecs = self.encode(keys, device=device, return_numpy=True) # suppose M keys 101 | else: 102 | key_vecs = keys 103 | 104 | # check whether N == 1 or M == 1 105 | single_query, single_key = len(query_vecs.shape) == 1, len(key_vecs.shape) == 1 106 | if single_query: 107 | query_vecs = query_vecs.reshape(1, -1) 108 | if single_key: 109 | key_vecs = key_vecs.reshape(1, -1) 110 | 111 | # returns an N*M similarity array 112 | similarities = cosine_similarity(query_vecs, key_vecs) 113 | 114 | if single_query: 115 | similarities = similarities[0] 116 | if single_key: 117 | similarities = float(similarities[0]) 118 | 119 | return similarities 120 | 121 | def build_index(self, sentences_or_file_path: Union[str, List[str]], 122 | use_faiss: bool = None, 123 | faiss_fast: bool = False, 124 | device: str = None, 125 | batch_size: int = 64): 126 | 127 | if use_faiss is None or use_faiss: 128 | try: 129 | import faiss 130 | assert hasattr(faiss, "IndexFlatIP") 131 | use_faiss = True 132 | except Exception: 133 | logger.warning("Failed to import faiss. To use faiss, install it through PyPI. Continuing with brute-force search.") 134 | use_faiss = False 135 | 136 | # if the input is a string, we assume it is the path of a file that stores one sentence per line 137 | if isinstance(sentences_or_file_path, str): 138 | sentences = [] 139 | with open(sentences_or_file_path, "r") as f: 140 | logger.info("Loading sentences from %s ..."
% (sentences_or_file_path)) 141 | for line in tqdm(f): 142 | sentences.append(line.rstrip()) 143 | sentences_or_file_path = sentences 144 | 145 | logger.info("Encoding embeddings for sentences...") 146 | embeddings = self.encode(sentences_or_file_path, device=device, batch_size=batch_size, normalize_to_unit=True, return_numpy=True) 147 | 148 | logger.info("Building index...") 149 | self.index = {"sentences": sentences_or_file_path} 150 | 151 | if use_faiss: 152 | quantizer = faiss.IndexFlatIP(embeddings.shape[1]) 153 | if faiss_fast: 154 | index = faiss.IndexIVFFlat(quantizer, embeddings.shape[1], min(self.num_cells, len(sentences_or_file_path))) 155 | else: 156 | index = quantizer 157 | 158 | if (self.device == "cuda" and device != "cpu") or device == "cuda": 159 | if hasattr(faiss, "StandardGpuResources"): 160 | logger.info("Use GPU-version faiss") 161 | res = faiss.StandardGpuResources() 162 | res.setTempMemory(20 * 1024 * 1024 * 1024) 163 | index = faiss.index_cpu_to_gpu(res, 0, index) 164 | else: 165 | logger.info("Use CPU-version faiss") 166 | else: 167 | logger.info("Use CPU-version faiss") 168 | 169 | if faiss_fast: 170 | index.train(embeddings.astype(np.float32)) 171 | index.add(embeddings.astype(np.float32)) 172 | index.nprobe = min(self.num_cells_in_search, len(sentences_or_file_path)) 173 | self.is_faiss_index = True 174 | else: 175 | index = embeddings 176 | self.is_faiss_index = False 177 | self.index["index"] = index 178 | logger.info("Finished") 179 | 180 | def search(self, queries: Union[str, List[str]], 181 | device: str = None, 182 | threshold: float = 0.6, 183 | top_k: int = 5) -> Union[List[Tuple[str, float]], List[List[Tuple[str, float]]]]: 184 | 185 | if not self.is_faiss_index: 186 | if isinstance(queries, list): 187 | combined_results = [] 188 | for query in queries: 189 | results = self.search(query, device) 190 | combined_results.append(results) 191 | return combined_results 192 | 193 | similarities = self.similarity(queries, 
self.index["index"]).tolist() 194 | id_and_score = [] 195 | for i, s in enumerate(similarities): 196 | if s >= threshold: 197 | id_and_score.append((i, s)) 198 | id_and_score = sorted(id_and_score, key=lambda x: x[1], reverse=True)[:top_k] 199 | results = [(self.index["sentences"][idx], score) for idx, score in id_and_score] 200 | return results 201 | else: 202 | query_vecs = self.encode(queries, device=device, normalize_to_unit=True, keepdim=True, return_numpy=True) 203 | 204 | distance, idx = self.index["index"].search(query_vecs.astype(np.float32), top_k) 205 | 206 | def pack_single_result(dist, idx): 207 | results = [(self.index["sentences"][i], s) for i, s in zip(idx, dist) if s >= threshold] 208 | return results 209 | 210 | if isinstance(queries, list): 211 | combined_results = [] 212 | for i in range(len(queries)): 213 | results = pack_single_result(distance[i], idx[i]) 214 | combined_results.append(results) 215 | return combined_results 216 | else: 217 | return pack_single_result(distance[0], idx[0]) 218 | 219 | if __name__=="__main__": 220 | example_sentences = [ 221 | 'An animal is biting a persons finger.', 222 | 'A woman is reading.', 223 | 'A man is lifting weights in a garage.', 224 | 'A man plays the violin.', 225 | 'A man is eating food.', 226 | 'A man plays the piano.', 227 | 'A panda is climbing.', 228 | 'A man plays a guitar.', 229 | 'A woman is slicing a meat.', 230 | 'A woman is taking a picture.' 231 | ] 232 | example_queries = [ 233 | 'A man is playing music.', 234 | 'A woman is making a photo.' 
235 | ] 236 | 237 | model_name = "princeton-nlp/sup-simcse-bert-base-uncased" 238 | simcse = SimCSE(model_name) 239 | 240 | print("\n=========Calculate cosine similarities between queries and sentences============\n") 241 | similarities = simcse.similarity(example_queries, example_sentences) 242 | print(similarities) 243 | 244 | print("\n=========Naive brute force search============\n") 245 | simcse.build_index(example_sentences, use_faiss=False) 246 | results = simcse.search(example_queries) 247 | for i, result in enumerate(results): 248 | print("Retrieval results for query: {}".format(example_queries[i])) 249 | for sentence, score in result: 250 | print(" {} (cosine similarity: {:.4f})".format(sentence, score)) 251 | print("") 252 | 253 | print("\n=========Search with Faiss backend============\n") 254 | simcse.build_index(example_sentences, use_faiss=True) 255 | results = simcse.search(example_queries) 256 | for i, result in enumerate(results): 257 | print("Retrieval results for query: {}".format(example_queries[i])) 258 | for sentence, score in result: 259 | print(" {} (cosine similarity: {:.4f})".format(sentence, score)) 260 | print("") 261 | 262 | -------------------------------------------------------------------------------- /code/SimCSE_RankEncoder/simcse_rank_encoder_inference.py: -------------------------------------------------------------------------------- 1 | import apex 2 | import re 3 | import sys 4 | import io, os 5 | import faiss 6 | import math 7 | import json 8 | import torch 9 | import numpy as np 10 | import logging 11 | import tqdm 12 | import time 13 | import argparse 14 | from prettytable import PrettyTable 15 | from scipy.stats import spearmanr, pearsonr 16 | from scipy.special import softmax 17 | from scipy.stats import rankdata 18 | import string 19 | import torch 20 | import transformers 21 | from transformers import AutoModel, AutoTokenizer 22 | from transformers import BertTokenizer, BertModel 23 | from tqdm import tqdm 24 | # Set up 
logger 25 | logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG) 26 | 27 | # Set PATHs 28 | PATH_TO_SENTEVAL = './SentEval' 29 | 30 | # Import SentEval 31 | sys.path.insert(0, PATH_TO_SENTEVAL) 32 | 33 | import senteval 34 | from senteval.sts import STS12Eval, STS13Eval, STS14Eval, STS15Eval, STS16Eval, STSBenchmarkEval 35 | from senteval.sts import SICKRelatednessEval 36 | 37 | def normalize(vecs): 38 | eps = 1e-8 39 | return vecs / (np.sqrt(np.sum(np.square(vecs), axis=1)) + eps)[:,None] 40 | 41 | def print_table(task_names, scores): 42 | tb = PrettyTable() 43 | tb.field_names = task_names 44 | tb.add_row(scores) 45 | print(tb) 46 | 47 | def read_benchmark_data(senteval_path, task): 48 | task2class = { \ 49 | 'STS12': STS12Eval, 50 | 'STS13': STS13Eval, 51 | 'STS14': STS14Eval, 52 | 'STS15': STS15Eval, 53 | 'STS16': STS16Eval, 54 | 'STSBenchmark': STSBenchmarkEval, 55 | 'SICKRelatedness': SICKRelatednessEval 56 | } 57 | dataset_path = None 58 | print("SentEval path: {}".format(senteval_path)) 59 | if task in ["STS12", "STS13", "STS14", "STS15", "STS16"]: 60 | dataset_path = os.path.join(senteval_path, "downstream/STS/", "{}{}".format(task, "-en-test")) 61 | elif task == "STSBenchmark": 62 | dataset_path = os.path.join(senteval_path, "downstream/STS/", "{}".format(task)) 63 | elif task == "SICKRelatedness": 64 | dataset_path = os.path.join(senteval_path, "downstream/SICK") 65 | print(dataset_path) 66 | data = {} 67 | task_data = task2class[task](dataset_path) 68 | for dset in task_data.datasets: 69 | input1, input2, gs_scores = task_data.data[dset] 70 | data[dset] = (input1, input2, gs_scores) 71 | return data 72 | 73 | def compute_similarity(q0, q0_sim, q1, q1_sim, lmb=0.0): 74 | normalized_q0 = normalize(np.reshape(q0, (1, -1))) 75 | normalized_q1 = normalize(np.reshape(q1, (1, -1))) 76 | add_score, _ = spearmanr(q0_sim, q1_sim) 77 | score = np.sum(np.matmul(normalized_q0, normalized_q1.T)) 78 | score = lmb * score + (1.0 - lmb) * 
add_score 79 | return score 80 | 81 | def evaluate_retrieval_augmented_promptbert(args, data, batcher, sentence_vecs): 82 | results = {} 83 | all_sys_scores = [] 84 | all_gs_scores = [] 85 | for dset in data: 86 | sys_scores = [] 87 | input1, input2, gs_scores = data[dset] 88 | for ii in range(0, len(gs_scores), args.batch_size): 89 | batch1 = input1[ii:ii + args.batch_size] 90 | batch2 = input2[ii:ii + args.batch_size] 91 | 92 | # we assume get_batch already throws out the faulty ones 93 | if len(batch1) == len(batch2) \ 94 | and len(batch1) > 0: 95 | enc1 = batcher(batch1) 96 | enc2 = batcher(batch2) 97 | sim1 = np.matmul( \ 98 | enc1, sentence_vecs.T \ 99 | ) 100 | sim2 = np.matmul( \ 101 | enc2, sentence_vecs.T \ 102 | ) 103 | 104 | for kk in range(enc1.shape[0]): 105 | sys_score = compute_similarity( \ 106 | enc1[kk], sim1[kk], \ 107 | enc2[kk], sim2[kk], \ 108 | args.lmb \ 109 | ) 110 | sys_scores.append(sys_score) 111 | all_sys_scores.extend(sys_scores) 112 | all_gs_scores.extend(gs_scores) 113 | results[dset] = { 114 | 'pearson': pearsonr(sys_scores, gs_scores), 115 | 'spearman': spearmanr(sys_scores, gs_scores), 116 | 'nsamples': len(sys_scores) 117 | } 118 | logging.debug('%s : pearson = %.4f, spearman = %.4f' % 119 | (dset, results[dset]['pearson'][0], 120 | results[dset]['spearman'][0])) 121 | 122 | weights = [results[dset]['nsamples'] for dset in results.keys()] 123 | list_prs = np.array([results[dset]['pearson'][0] for 124 | dset in results.keys()]) 125 | list_spr = np.array([results[dset]['spearman'][0] for 126 | dset in results.keys()]) 127 | 128 | avg_pearson = np.average(list_prs) 129 | avg_spearman = np.average(list_spr) 130 | wavg_pearson = np.average(list_prs, weights=weights) 131 | wavg_spearman = np.average(list_spr, weights=weights) 132 | all_pearson = pearsonr(all_sys_scores, all_gs_scores) 133 | all_spearman = spearmanr(all_sys_scores, all_gs_scores) 134 | results['all'] = {'pearson': {'all': all_pearson[0], 135 | 'mean': avg_pearson, 136 
                                  'wmean': wavg_pearson},
                      'spearman': {'all': all_spearman[0],
                                   'mean': avg_spearman,
                                   'wmean': wavg_spearman}}
    logging.debug('ALL : Pearson = %.4f, \
        Spearman = %.4f' % (all_pearson[0], all_spearman[0]))
    logging.debug('ALL (weighted average) : Pearson = %.4f, \
        Spearman = %.4f' % (wavg_pearson, wavg_spearman))
    logging.debug('ALL (average) : Pearson = %.4f, \
        Spearman = %.4f\n' % (avg_pearson, avg_spearman))
    results["pred_scores"] = all_sys_scores
    results["gs_scores"] = all_gs_scores
    return results

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--sentence_vecs", type=str, required=True)
    parser.add_argument("--senteval_path", type=str, default="SentEval/data")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--lmb", type=float, default=1.0)
    parser.add_argument("--model_name_or_path", type=str,
                        help="Transformers' model name or path")
    args = parser.parse_args()
    return args

def main(args):
    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda")
    # device_count() returns 0 without CUDA, so n_gpu is always defined
    n_gpu = torch.cuda.device_count()

    model = AutoModel.from_pretrained(args.model_name_or_path)
    model.to(device)
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)

    def batcher(batch, max_length=None):
        # Handle rare token encoding issues in the dataset
        if len(batch) >= 1 and len(batch[0]) >= 1 and isinstance(batch[0][0], bytes):
            batch = [[word.decode('utf-8') for word in s] for s in batch]

        sentences = [' '.join(s) for s in batch]
        simcse_batch = tokenizer.batch_encode_plus(
            sentences,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length
        )
        # Move every tokenizer output tensor to the model's device
        gpu_batch = {}
        for key in simcse_batch:
            gpu_batch[key] = simcse_batch[key].to(device)
        with torch.no_grad():
            sentence_embedding = model(**gpu_batch)
        # Use the first-token ([CLS]) hidden state as the sentence vector
        sentence_embedding = sentence_embedding.last_hidden_state[:,0].cpu().numpy()
        sentence_embedding = normalize(sentence_embedding)
        return sentence_embedding

    print("Loading {}".format(args.sentence_vecs))
    sentence_vecs = np.load(args.sentence_vecs)

    # Load benchmark datasets
    target_tasks = [
        'STS12', 'STS13', 'STS14', 'STS15', 'STS16',
        'STSBenchmark',
        'SICKRelatedness'
    ]
    # Reference: https://github.com/facebookresearch/SentEval/blob/main/senteval/sts.py
    results = {}
    for task in target_tasks:
        data = read_benchmark_data(args.senteval_path, task)
        result = evaluate_retrieval_augmented_promptbert(args, data, batcher, sentence_vecs)
        results[task] = result

    # Collect Spearman correlations (as percentages) for the summary table
    task_names = []
    scores = []
    for task in ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark', 'SICKRelatedness']:
        task_names.append(task)
        if task in results:
            if task in ['STS12', 'STS13', 'STS14', 'STS15', 'STS16']:
                scores.append("%.2f" % (results[task]['all']['spearman']['all'] * 100))
            else:
                scores.append("%.2f" % (results[task]['test']['spearman'].correlation * 100))
        else:
            scores.append("0.00")
    task_names.append("Avg.")
    scores.append("%.2f" % (sum([float(score) for score in scores]) / len(scores)))
    print_table(task_names, scores)

    return 0

if __name__ == "__main__":
    args = parse_args()
    _ = main(args)

--------------------------------------------------------------------------------
/code/SimCSE_RankEncoder/simcse_rank_encoder_inference.sh:
--------------------------------------------------------------------------------
OUTPUT_DIR=$PROJECT_DIR/outputs

LMB=0.9
SEED=61507
CUDA_VISIBLE_DEVICES=0,1 python simcse_rank_encoder_inference.py \
    --sentence_vecs $OUTPUT_DIR/simcse/index_vecs/corpus_0.01_rank_encoder_seed_$SEED.npy \
    --batch_size 256 \
    --model_name_or_path $OUTPUT_DIR/simcse/checkpoints/simcse_unsup_rank_encoder_seed_$SEED \
    --lmb $LMB

--------------------------------------------------------------------------------
/code/file_utils/random_sampling_sentences.py:
--------------------------------------------------------------------------------
import os

import argparse
import random
from tqdm import tqdm

def write_corpus_file(sentences, fname):
    # dirname is "" for a bare file name; fall back to the current directory
    os.makedirs(os.path.dirname(fname) or ".", exist_ok=True)
    with open(fname, "w") as f:
        for s in tqdm(sentences, desc="Writing"):
            f.write("{}\n".format(s))
    return 0

def read_wiki1m(fname):
    print("Reading {}".format(fname))
    with open(fname, "r") as f:
        lines = f.readlines()
    lines = [l.strip() for l in lines]
    return lines

def sampling_sentences(sentences, ratio, n_sentences):
    # Use the function arguments rather than the global `args`;
    # `ratio` takes precedence when both are given.
    n = None
    if ratio is not None:
        n = int(len(sentences) * ratio)
    elif n_sentences is not None:
        n = n_sentences
    #assert n_sentences <= len(sentences)
    n = min(n, len(sentences))
    samples = random.sample(sentences, n)
    return samples

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--sentence_file", type=str, required=True)
    parser.add_argument("--output_file", type=str, required=True)
    parser.add_argument("--ratio", type=float, default=None)
    parser.add_argument("--n_sentences", type=int, default=None)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()
    return args

def main(args):
    assert args.ratio is not None or args.n_sentences is not None

    random.seed(args.seed)
    texts = read_wiki1m(args.sentence_file)
    samples = sampling_sentences(texts, args.ratio, args.n_sentences)
    _ = write_corpus_file(samples, args.output_file)
    return 0

if __name__ == "__main__":
    args = parse_args()
    _ = main(args)

--------------------------------------------------------------------------------
/code/file_utils/random_sampling_sentences.sh:
--------------------------------------------------------------------------------
DATA_DIR=$PROJECT_DIR/data
OUTPUT_DIR=$PROJECT_DIR/outputs

SEED=42
python random_sampling_sentences.py \
    --sentence_file $DATA_DIR/corpus/corpus.txt \
    --output_file $OUTPUT_DIR/corpus/corpus_0.01.txt \
    --seed $SEED \
    --n_sentences 10000

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
apex==0.1
datasets==1.2.1
deepspeed==0.6.5
faiss==1.7.2
filelock==3.4.0
matplotlib==3.5.0
nltk==3.6.5
numpy==1.21.2
packaging==21.3
prettytable==2.1.0
ray==1.13.0
scikit_learn==1.1.1
scipy==1.5.4
setuptools==49.3.0
skipthoughts==0.0.1
tensorflow==2.9.1
tensorflow_hub==0.12.0
torch==1.10.0
tqdm==4.49.0
transformers==4.2.1
spacy==3.4.1

--------------------------------------------------------------------------------
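The corpus sampling step in `code/file_utils/random_sampling_sentences.py` accepts either `--ratio` or `--n_sentences`, and the accompanying shell script writes `corpus_0.01.txt` while passing `--n_sentences 10000`, which corresponds to a 1% ratio only if the input corpus has one million lines. The following is a minimal sketch of that size-selection logic, not the repository's code: the helper names `pick_sample_size` and `sample_sentences` are hypothetical, and the use of a local `random.Random` instance instead of the script's global `random.seed` call is a design choice assumed here.

```python
import random


def pick_sample_size(n_total, ratio=None, n_sentences=None):
    """Resolve how many sentences to draw; ratio wins when both are given."""
    if ratio is None and n_sentences is None:
        raise ValueError("set either ratio or n_sentences")
    n = int(n_total * ratio) if ratio is not None else n_sentences
    # Never ask random.sample for more items than exist.
    return min(n, n_total)


def sample_sentences(sentences, ratio=None, n_sentences=None, seed=42):
    """Draw a reproducible sample without touching the global RNG state."""
    rng = random.Random(seed)
    n = pick_sample_size(len(sentences), ratio, n_sentences)
    return rng.sample(sentences, n)


# Toy corpus standing in for the real sentence file (hypothetical data).
corpus = ["sentence {}".format(i) for i in range(1000)]
by_ratio = sample_sentences(corpus, ratio=0.01)
by_count = sample_sentences(corpus, n_sentences=5000)
print(len(by_ratio), len(by_count))
```

Because the sample size is clamped to the corpus length, a request for more sentences than exist returns the whole corpus in shuffled order rather than raising an error, which mirrors the `min(n, len(sentences))` line in the script.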