├── stsbenchmark.tsv.gz ├── requirements.txt ├── README.md ├── training.py ├── loss.py └── modules.py /stsbenchmark.tsv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/galsang/SG-BERT/HEAD/stsbenchmark.tsv.gz -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | certifi==2022.5.18.1 2 | charset-normalizer==2.0.12 3 | click==8.1.3 4 | dataclasses==0.6 5 | dill==0.3.5.1 6 | filelock==3.7.1 7 | future==0.18.2 8 | idna==3.3 9 | importlib-metadata==4.11.4 10 | joblib==1.1.0 11 | nltk==3.7 12 | numpy==1.21.6 13 | packaging==21.3 14 | protobuf==3.20.0 15 | pyparsing==3.0.9 16 | regex==2022.6.2 17 | requests==2.27.1 18 | sacremoses==0.0.53 19 | scikit-learn==1.0.2 20 | scipy==1.7.3 21 | sentence-transformers==0.3.9 22 | sentencepiece==0.1.91 23 | six==1.16.0 24 | threadpoolctl==3.1.0 25 | tokenizers==0.9.3 26 | torch==1.7.0 27 | tqdm==4.64.0 28 | transformers==3.5.1 29 | typing_extensions==4.2.0 30 | urllib3==1.26.9 31 | zipp==3.8.0 32 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SG-BERT 2 | 3 | This repository contains the implementation of **Self-Guided Contrastive Learning for BERT Sentence Representations (ACL 2021)**. 4 | (Disclaimer: the code is somewhat cluttered, as this is not a cleaned-up version.) 5 | 6 | When using this code in your work, please cite our paper with the BibTeX entry below. 7 | 8 | @inproceedings{kim-etal-2021-self, 9 | title = "Self-Guided Contrastive Learning for {BERT} Sentence Representations", 10 | author = "Kim, Taeuk and 11 | Yoo, Kang Min and 12 | Lee, Sang-goo", 13 | booktitle = "Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics and the 11th International Joint Conference on Natural Language Processing (Volume 1: Long Papers)", 14 | month = aug, 15 | year = "2021", 16 | address = "Online", 17 | publisher = "Association for Computational Linguistics", 18 | url = "https://aclanthology.org/2021.acl-long.197", 19 | doi = "10.18653/v1/2021.acl-long.197", 20 | pages = "2528--2540", 21 | abstract = "Although BERT and its variants have reshaped the NLP landscape, it still remains unclear how best to derive sentence embeddings from such pre-trained Transformers. In this work, we propose a contrastive learning method that utilizes self-guidance for improving the quality of BERT sentence representations. Our method fine-tunes BERT in a self-supervised fashion, does not rely on data augmentation, and enables the usual [CLS] token embeddings to function as sentence vectors. Moreover, we redesign the contrastive learning objective (NT-Xent) and apply it to sentence representation learning. We demonstrate with extensive experiments that our approach is more effective than competitive baselines on diverse sentence-related tasks. We also show it is efficient at inference and robust to domain shifts.",} 22 | 23 | 24 | 25 | ## Prerequisite Python Libraries 26 | 27 | Please install the libraries specified in **requirements.txt** before running our code.
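For example, with pip from the repository root:

> pip install -r requirements.txt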
28 | 29 | certifi==2022.5.18.1 30 | charset-normalizer==2.0.12 31 | click==8.1.3 32 | dataclasses==0.6 33 | dill==0.3.5.1 34 | filelock==3.7.1 35 | future==0.18.2 36 | idna==3.3 37 | importlib-metadata==4.11.4 38 | joblib==1.1.0 39 | nltk==3.7 40 | numpy==1.21.6 41 | packaging==21.3 42 | protobuf==3.20.0 43 | pyparsing==3.0.9 44 | regex==2022.6.2 45 | requests==2.27.1 46 | sacremoses==0.0.53 47 | scikit-learn==1.0.2 48 | scipy==1.7.3 49 | sentence-transformers==0.3.9 50 | sentencepiece==0.1.91 51 | six==1.16.0 52 | threadpoolctl==3.1.0 53 | tokenizers==0.9.3 54 | torch==1.7.0 55 | tqdm==4.64.0 56 | transformers==3.5.1 57 | typing_extensions==4.2.0 58 | urllib3==1.26.9 59 | zipp==3.8.0 60 | 61 | 62 | ## How to Run Code 63 | 64 | > python training.py 65 | -------------------------------------------------------------------------------- /training.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import gzip 4 | import logging 5 | import math 6 | import random 7 | import time 8 | 9 | import numpy as np 10 | import torch 11 | from sentence_transformers import LoggingHandler, InputExample 12 | from sentence_transformers import models 13 | from sentence_transformers.models import Transformer 14 | from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, SimilarityFunction 15 | from torch.utils.data import DataLoader 16 | 17 | from loss import Loss 18 | from modules import SentencesDataset, SentenceTransformer 19 | 20 | start_time = time.time() 21 | 22 | PRETRAINED_MODELS = ['bert-base-nli-cls-token', 23 | 'bert-base-nli-mean-tokens', 24 | 'bert-large-nli-cls-token', 25 | 'bert-large-nli-mean-tokens', 26 | 'roberta-base-nli-cls-token', 27 | 'roberta-base-nli-mean-tokens', 28 | 'roberta-large-nli-cls-token', 29 | 'roberta-large-nli-mean-tokens'] 30 | 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument('--seed', default=1, type=int) 33 | parser.add_argument('--model_name', default='bert-base-uncased', type=str) 34 | parser.add_argument('--pooling', default='cls', type=str) 35 | parser.add_argument('--pooling2', default='mean', type=str) 36 | parser.add_argument('--eval_step', default=50, type=int) 37 | parser.add_argument('--num_epochs', default=1, type=int) 38 | parser.add_argument('--lr', default=5e-5, type=float) 39 | parser.add_argument('--batch_size', default=16, type=int) 40 | parser.add_argument('--T', default=1e-2, type=float) 41 | parser.add_argument('--eps', default=0.1, type=float) 42 | parser.add_argument('--lmin', default=0, type=int) 43 | parser.add_argument('--lmax', default=-1, type=int) 44 | parser.add_argument('--lamb', default=0.1, type=float) 45 | parser.add_argument('--es', default=10, type=int) 46 | parser.add_argument('--weight_decay', default=0, type=float) 47 | parser.add_argument('--training', default=True, action='store_true') 48 | parser.add_argument('--freeze', default=True, action='store_true') 49 | parser.add_argument('--clone', default=True, action='store_true') 50 | parser.add_argument('--disable_tqdm', default=True, action='store_true') 51 | parser.add_argument('--obj', default='SG-OPT', type=str) 52 | parser.add_argument('--device', default='cuda:0', type=str) 53 | 54 | args = parser.parse_args() 55 | for a in args.__dict__: 56 | print(f'{a}: {args.__dict__[a]}') 57 | 58 | random.seed(args.seed) 59 | np.random.seed(args.seed) 60 | torch.manual_seed(args.seed) 61 | torch.cuda.manual_seed_all(args.seed) 62 | torch.random.manual_seed(args.seed) 63 | 64 | 
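# Note: the boolean flags --training, --freeze, --clone and --disable_tqdm above pair
# default=True with action='store_true', so passing them on the command line has no
# effect and they are effectively always enabled; change their defaults here to disable them.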
logging.basicConfig(format='%(asctime)s - %(message)s', 65 | datefmt='%Y-%m-%d %H:%M:%S', 66 | level=logging.INFO, 67 | handlers=[LoggingHandler()]) 68 | 69 | sts_dataset_path = 'stsbenchmark.tsv.gz' 70 | 71 | args_string = args.model_name + '-' + str(args.seed) + '-' + str(args.eps) + '-' + args.pooling + '-' + str(args.lmin) + '-' + str(args.lmax) 72 | logging.info(f'args_string: {args_string}') 73 | model_save_path = f'output/{args_string}' 74 | 75 | if args.model_name in PRETRAINED_MODELS: 76 | logging.info('Loading from SBERT') 77 | pretrained = SentenceTransformer(args.model_name) 78 | word_embedding_model = pretrained._first_module() 79 | else: 80 | model_args = {'output_hidden_states': True, 'output_attentions': True} 81 | word_embedding_model = Transformer(args.model_name, model_args=model_args) 82 | 83 | pooling_model = models.Pooling( 84 | word_embedding_model.get_word_embedding_dimension(), 85 | pooling_mode_mean_tokens=args.pooling == 'mean' or args.pooling not in ['cls', 'max'], 86 | pooling_mode_cls_token=args.pooling == 'cls', 87 | pooling_mode_max_tokens=args.pooling == 'max') 88 | 89 | modules = [word_embedding_model, pooling_model] 90 | model = SentenceTransformer(modules=modules, name=args.model_name, device=args.device) 91 | 92 | train_samples = [] 93 | with gzip.open(sts_dataset_path, 'rt', encoding='utf-8') as fIn: 94 | reader = csv.DictReader(fIn, delimiter='\t', quoting=csv.QUOTE_NONE) 95 | for row in reader: 96 | train_samples.append(InputExample(texts=[row['sentence1']])) 97 | train_samples.append(InputExample(texts=[row['sentence2']])) 98 | 99 | train_dataset = SentencesDataset(train_samples, model=model) 100 | train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size) 101 | train_loss = Loss(model, args) 102 | 103 | logging.info(f"Read eval dataset") 104 | dev_samples = [] 105 | test_samples = [] 106 | 107 | with gzip.open(sts_dataset_path, 'rt', encoding='utf8') as fIn: 108 | reader = csv.DictReader(fIn, delimiter='\t', quoting=csv.QUOTE_NONE) 109 | label2int = {"contradiction": 0, "entailment": 1, "neutral": 2} 110 | for row in reader: 111 | if row['split'] == 'dev': 112 | score = float(row['score']) / 5.0 #Normalize score to range 0 ... 1 113 | dev_samples.append(InputExample(texts=[row['sentence1'], row['sentence2']], label=score)) 114 | elif row['split'] == 'test': 115 | score = float(row['score']) / 5.0 # Normalize score to range 0 ... 
1 116 | test_samples.append(InputExample(texts=[row['sentence1'], row['sentence2']], label=score)) 117 | 118 | dev_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, batch_size=args.batch_size, name=f'stsb-dev', main_similarity=SimilarityFunction.COSINE) 119 | test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples, batch_size=args.batch_size, name=f'stsb-test', main_similarity=SimilarityFunction.COSINE) 120 | 121 | warmup_steps = math.ceil(len(train_dataset) * args.num_epochs / args.batch_size * 0.1) #10% of train data for warm-up 122 | logging.info("Warmup-steps: {}".format(warmup_steps)) 123 | 124 | model.fit(train_objectives=[(train_dataloader, train_loss)], 125 | dev_evaluator=dev_evaluator, 126 | test_evaluator= None, 127 | epochs=args.num_epochs, 128 | optimizer_params={'lr': args.lr, 'correct_bias': True, 'weight_decay': args.weight_decay, 'betas': (0.9, 0.9)}, 129 | evaluation_steps=args.eval_step, 130 | warmup_steps=warmup_steps, 131 | output_path=model_save_path, 132 | early_stopping_limit=args.es, 133 | disable_tqdm=args.disable_tqdm) 134 | 135 | logging.info('Training finished.') 136 | 137 | dev_score = dev_evaluator(model, output_path=model_save_path) 138 | test_score = test_evaluator(model, output_path=model_save_path) 139 | print(dev_score) 140 | print(test_score) -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn, Tensor 4 | from typing import Union, Tuple, List, Iterable, Dict 5 | 6 | import copy 7 | import logging 8 | 9 | import numpy as np 10 | import scipy.stats as stats 11 | 12 | from sentence_transformers import models 13 | from modules import Transformer, SentenceTransformer 14 | 15 | def compute_entropy(probs): 16 | eps = torch.finfo(probs.dtype).eps 17 | ps_clamped = probs.clamp(min=eps, max=1 - eps) 18 | logits = torch.log(ps_clamped) 19 | min_real = torch.finfo(logits.dtype).min 20 | logits = torch.clamp(logits, min=min_real) 21 | p_log_p = logits * probs 22 | return -p_log_p.sum(-1) 23 | 24 | class NTXentLossOriginal(torch.nn.Module): 25 | 26 | def __init__(self, device, batch_size, temperature, use_cosine_similarity): 27 | super(NTXentLossOriginal, self).__init__() 28 | self.batch_size = batch_size 29 | self.temperature = temperature 30 | self.device = device 31 | self.softmax = torch.nn.Softmax(dim=-1) 32 | self.mask_samples_from_same_repr = self._get_correlated_mask().type(torch.bool) 33 | self.similarity_function = self._get_similarity_function(use_cosine_similarity) 34 | self.criterion = torch.nn.CrossEntropyLoss(reduction="sum") 35 | 36 | def _get_similarity_function(self, use_cosine_similarity): 37 | if use_cosine_similarity: 38 | self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1) 39 | return self._cosine_simililarity 40 | else: 41 | return self._dot_simililarity 42 | 43 | def _get_correlated_mask(self): 44 | diag = np.eye(2 * self.batch_size) 45 | l1 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=-self.batch_size) 46 | l2 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=self.batch_size) 47 | mask = torch.from_numpy((diag + l1 + l2)) 48 | mask = (1 - mask).type(torch.bool) 49 | return mask.to(self.device) 50 | 51 | @staticmethod 52 | def _dot_simililarity(x, y): 53 | v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2) 54 | # x shape: (N, 1, C) 55 | # y shape: (1, C, 2N) 56 | # v 
shape: (N, 2N) 57 | return v 58 | 59 | def _cosine_simililarity(self, x, y): 60 | # x shape: (N, 1, C) 61 | # y shape: (1, 2N, C) 62 | # v shape: (N, 2N) 63 | v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0)) 64 | return v 65 | 66 | def forward(self, zis, zjs): 67 | self.batch_size = zis.size(0) 68 | representations = torch.cat([zjs, zis], dim=0) 69 | 70 | similarity_matrix = self.similarity_function(representations, representations) 71 | 72 | l_pos = torch.diag(similarity_matrix, self.batch_size) 73 | r_pos = torch.diag(similarity_matrix, -self.batch_size) 74 | positives = torch.cat([l_pos, r_pos]).view(2 * self.batch_size, 1) 75 | 76 | negatives = similarity_matrix[self._get_correlated_mask().type(torch.bool)].view(2 * self.batch_size, -1) 77 | 78 | logits = torch.cat((positives, negatives), dim=1) 79 | logits /= self.temperature 80 | 81 | labels = torch.zeros(self.batch_size).to(self.device).long() 82 | loss = self.criterion(logits, labels) 83 | 84 | return loss / self.batch_size 85 | 86 | class NTXentLossOpt1(NTXentLossOriginal): 87 | 88 | def __init__(self, device, batch_size, temperature, use_cosine_similarity): 89 | super(NTXentLossOpt1, self).__init__(device, batch_size, temperature, use_cosine_similarity) 90 | 91 | def forward(self, cls, pooled): 92 | self.batch_size = cls.size(0) 93 | representations = torch.cat([cls, pooled], dim=0) 94 | 95 | similarity_matrix = self.similarity_function(representations, representations) 96 | 97 | pos = torch.diag(similarity_matrix, self.batch_size) 98 | positives = pos.view(self.batch_size, 1) 99 | 100 | negatives = similarity_matrix[self._get_correlated_mask().type(torch.bool)].view(2 * self.batch_size, -1)[:self.batch_size] 101 | 102 | logits = torch.cat((positives, negatives), dim=1) 103 | logits /= self.temperature 104 | 105 | labels = torch.zeros(self.batch_size).to(self.device).long() 106 | loss = self.criterion(logits, labels) 107 | 108 | return loss / self.batch_size 109 | 110 | 111 | class NTXentLossOpt2(NTXentLossOriginal): 112 | 113 | def __init__(self, device, batch_size, temperature, use_cosine_similarity): 114 | super(NTXentLossOpt2, self).__init__(device, batch_size, temperature, use_cosine_similarity) 115 | 116 | def forward(self, cls, pooled): 117 | self.batch_size = cls.size(0) 118 | representations = torch.cat([cls, pooled], dim=0) 119 | 120 | similarity_matrix = self.similarity_function(representations, representations) 121 | 122 | pos = torch.diag(similarity_matrix, self.batch_size) 123 | positives = pos.view(self.batch_size, 1) 124 | 125 | negatives = similarity_matrix[self._get_correlated_mask().type(torch.bool)].view(2 * self.batch_size, -1)[:self.batch_size, self.batch_size:] 126 | 127 | logits = torch.cat((positives, negatives), dim=1) 128 | logits /= self.temperature 129 | 130 | labels = torch.zeros(self.batch_size).to(self.device).long() 131 | loss = self.criterion(logits, labels) 132 | 133 | return loss / self.batch_size 134 | 135 | 136 | class NTXentLoss(torch.nn.Module): 137 | 138 | def __init__(self, device, batch_size, temperature=1, use_cosine_similarity=True): 139 | super(NTXentLoss, self).__init__() 140 | self.batch_size = batch_size 141 | self.temperature = temperature 142 | self.device = device 143 | self.softmax = torch.nn.Softmax(dim=-1) 144 | self.similarity_function = self._get_similarity_function(use_cosine_similarity) 145 | self.criterion = torch.nn.CrossEntropyLoss(reduction="sum") 146 | 147 | def _get_similarity_function(self, use_cosine_similarity): 148 | if use_cosine_similarity: 149 | 
self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1) 150 | return self._cosine_simililarity 151 | else: 152 | return self._dot_simililarity 153 | 154 | def _get_correlated_mask(self, batch_size): 155 | diag = np.eye(2 * batch_size) 156 | l1 = np.eye((2 * batch_size), 2 * batch_size, k=-batch_size) 157 | l2 = np.eye((2 * batch_size), 2 * batch_size, k=batch_size) 158 | mask = torch.from_numpy((diag + l1 + l2)) 159 | mask = (1 - mask).type(torch.bool) 160 | return mask.to(self.device) 161 | 162 | @staticmethod 163 | def _dot_simililarity(x, y): 164 | v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2) 165 | # x shape: (N, 1, C) 166 | # y shape: (1, C, 2N) 167 | # v shape: (N, 2N) 168 | return v 169 | 170 | def _cosine_simililarity(self, x, y): 171 | # x shape: (N, 1, C) 172 | # y shape: (1, 2N, C) 173 | # v shape: (N, 2N) 174 | v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0)) 175 | return v 176 | 177 | def euclidean(self, x, y): 178 | return ((x.unsqueeze(1) - y.unsqueeze(0)) ** 2).sum(dim=-1).sqrt() 179 | 180 | def forward(self, cls, cont): 181 | """ 182 | :param cls: (batch_size, hidden_size) 183 | :param cont: (batch_size, num_layers, hidden_size) 184 | :return: 185 | """ 186 | batch_size = cls.size(0) 187 | num_layers = cont.size(1) 188 | 189 | positives, negatives = [], [] 190 | for i in range(num_layers): 191 | # (batch_size, hidden_size) X (batch_size, hidden_size) -> (batch_size, batch_size) 192 | similarity_matrix = self.similarity_function(cls, cont[:, i]) 193 | # add (batch_size, 1) 194 | positives.append(torch.diag(similarity_matrix)) 195 | # (batch_size, batch_size - 1) 196 | neg_idx = (1 - torch.eye(batch_size)).bool() 197 | negatives.append(similarity_matrix[neg_idx].view(batch_size, -1)) 198 | # (batch_size * num_layers, 1) 199 | positives = torch.cat(positives).view(-1, 1) 200 | 201 | # add other cls embeddings to negative samples 202 | # similarity_matrix = self.similarity_function(cls, cls) 203 | # (batch_size, batch_size - 1) 204 | # cls_negatives = similarity_matrix[(1 - torch.eye(batch_size)).bool()].view(batch_size, -1) 205 | # (batch_size * num_layers, batch_size - 1) 206 | # cls_negatives = torch.cat([cls_negatives] * num_layers, dim=0) 207 | 208 | # (batch_size, (batch_size - 1) * (num_layers (+ 1))) 209 | negatives = torch.cat(negatives, dim=1) 210 | # (batch_size * num_layers, (batch_size - 1) * (num_layers (+ 1))) 211 | negatives = torch.cat([negatives] * num_layers, dim=0) 212 | 213 | # (batch_size * num_layers, 1 + (batch_size - 1) * (num_layers (+ 1))) 214 | logits = torch.cat((positives, negatives), dim=1) 215 | logits /= self.temperature 216 | 217 | labels = torch.zeros(batch_size * num_layers).to(self.device).long() 218 | loss = self.criterion(logits, labels) 219 | 220 | return loss / (batch_size * num_layers) 221 | 222 | 223 | class Loss(nn.Module): 224 | def __init__(self, model, args): 225 | super(Loss, self).__init__() 226 | self.args = args 227 | config = model._first_module().auto_model.config 228 | self.config = config 229 | self.vocab_size = config.vocab_size 230 | 231 | if self.args.lmax == -1: 232 | self.args.lmax = config.num_hidden_layers + 1 233 | 234 | # class: SentenceTransformer 235 | self.model = model 236 | 237 | self.original = copy.deepcopy(model) 238 | self.original[0].eval() 239 | self.original_params = dict(self.original[0].named_parameters()) 240 | for n, p in self.original_params.items(): 241 | p.requires_grad = False 242 | 243 | if args.freeze: 244 | for n, p in 
self.model._first_module().auto_model.embeddings.named_parameters(): 245 | p.requires_grad = False 246 | 247 | ph_hidden_size = 4096 248 | starting_hidden_size = config.hidden_size 249 | self.projection_head = nn.Sequential( 250 | nn.Linear(starting_hidden_size, ph_hidden_size), 251 | nn.GELU(), 252 | nn.Linear(ph_hidden_size, ph_hidden_size), 253 | nn.GELU()) 254 | 255 | self.projection_head[0].weight.data.normal_(mean=0.0, std=config.initializer_range) 256 | self.projection_head[0].bias.data.zero_() 257 | self.projection_head[2].weight.data.normal_(mean=0.0, std=config.initializer_range) 258 | self.projection_head[2].bias.data.zero_() 259 | 260 | if self.args.obj == 'SG-OPT': 261 | self.loss = NTXentLoss 262 | elif self.args.obj == 'OPT1': 263 | self.loss = NTXentLossOpt1 264 | elif self.args.obj == 'OPT2': 265 | self.loss = NTXentLossOpt2 266 | else: 267 | self.loss = NTXentLossOriginal 268 | 269 | self.loss = self.loss( 270 | device=torch.device(self.args.device), 271 | batch_size=args.batch_size, 272 | temperature=args.T, 273 | use_cosine_similarity=True) 274 | 275 | self.sample_cnt = torch.zeros(config.num_hidden_layers + 1, dtype=torch.int) 276 | 277 | def compute_diff(self): 278 | diff = 0.0 279 | for n,p in self.model[0].named_parameters(): 280 | diff += torch.norm(self.original_params[n] - p, p=2) ** 2 281 | return diff 282 | 283 | def mean_pooling(self, t, mask): 284 | return self.sum_pooling(t, mask) / mask.sum(2) 285 | 286 | def sum_pooling(self, t, mask): 287 | t = t * mask 288 | return t.sum(2) 289 | 290 | def max_pooling(self, t, mask): 291 | t[mask == 0] = -1e9 292 | return t.max(dim=2)[0] 293 | 294 | def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels): 295 | reps = [] 296 | for i, sf in enumerate(sentence_features): 297 | if self.args.clone: 298 | ori = self.original(copy.deepcopy(sentence_features[i])) 299 | else: 300 | ori = self.model(copy.deepcopy(sentence_features[i])) 301 | target = self.model(copy.deepcopy(sentence_features[i])) 302 | 303 | sent_emb = target['sentence_embedding'] 304 | batch_size = sent_emb.size(0) 305 | # (batch, n_layers, seq_len, hidden_size) 306 | intermediate = torch.stack([l for l in ori['all_layer_embeddings'][self.args.lmin:self.args.lmax]], dim=1) 307 | mask = ori['attention_mask'].unsqueeze(1).unsqueeze(-1).expand(intermediate.size()).float() 308 | # (batch, n_layers, hidden_size) 309 | pooled = getattr(self, f'{self.args.pooling2}_pooling')(intermediate, mask) 310 | reps.append({'sent_emb': sent_emb, 'pooled': pooled}) 311 | 312 | sent_emb = reps[0]['sent_emb'] 313 | if len(sentence_features) > 1 and self.args.obj == 'BT': 314 | pooled = reps[1]['sent_emb'] 315 | elif self.args.obj in ['SG', 'OPT1', 'OPT2']: 316 | idx = torch.randint(self.args.lmin, self.args.lmax, (batch_size,)) 317 | pooled = pooled[torch.arange(batch_size), idx] 318 | else: 319 | # pooled = torch.cat([reps[0]['pooled'], reps[1]['pooled']], dim=1) 320 | pooled = reps[0]['pooled'] 321 | 322 | sent_emb = self.projection_head(sent_emb) 323 | if self.args.pooling == 'test': 324 | pooled = self.pre_projection_head(pooled) 325 | pooled = self.projection_head(pooled) 326 | loss1 = self.loss(sent_emb, pooled) 327 | if self.args.lamb > 0 : 328 | loss2 = self.compute_diff() 329 | return loss1 + self.args.lamb * loss2 330 | else: 331 | return loss1 332 | 333 | 334 | -------------------------------------------------------------------------------- /modules.py: -------------------------------------------------------------------------------- 1 | import json 2 | 
import logging 3 | import os 4 | import shutil 5 | from collections import OrderedDict 6 | from typing import List, Dict, Optional 7 | from zipfile import ZipFile 8 | 9 | import numpy as np 10 | import requests 11 | import sentence_transformers 12 | import sentence_transformers.models as models 13 | import torch 14 | import torch.nn as nn 15 | import transformers 16 | from sentence_transformers import __DOWNLOAD_SERVER__ 17 | from sentence_transformers import __version__ 18 | from sentence_transformers.datasets.EncodeDataset import EncodeDataset 19 | from sentence_transformers.models import Pooling 20 | from sentence_transformers.readers import InputExample 21 | from sentence_transformers.util import import_from_string, batch_to_device, http_get 22 | from torch.utils.data import DataLoader 23 | from torch.utils.data import Dataset 24 | from tqdm.autonotebook import trange 25 | 26 | class Transformer(models.Transformer): 27 | 28 | __module__ = 'sbert_modules.Transformer' 29 | 30 | def __init__(self, model_name_or_path: str, max_seq_length: int = 128, 31 | model_args: Dict = {}, cache_dir: Optional[str] = None, 32 | tokenizer_args: Dict = {}, do_lower_case: Optional[bool] = None): 33 | super(Transformer, self).__init__(model_name_or_path, max_seq_length, model_args, cache_dir, tokenizer_args, do_lower_case) 34 | 35 | def forward(self, features): 36 | """Returns token_embeddings, cls_token""" 37 | output_states = self.auto_model(**features) 38 | output_tokens = output_states[0] 39 | 40 | cls_tokens = output_tokens[:, 0, :] # CLS token is first token 41 | features.update({'token_embeddings': output_tokens, 42 | 'cls_token_embeddings': cls_tokens, 43 | 'attention_mask': features['attention_mask']}) 44 | 45 | if self.auto_model.config.output_hidden_states: 46 | all_layer_idx = 2 47 | if len(output_states) < 3: # Some models only output last_hidden_states and all_hidden_states 48 | all_layer_idx = 1 49 | 50 | hidden_states = output_states[all_layer_idx] 51 | features.update({'all_layer_embeddings': hidden_states}) 52 | if self.auto_model.config.output_attentions: 53 | attentions = output_states[all_layer_idx+1] 54 | features.update({'attentions': attentions}) 55 | 56 | return features 57 | 58 | @staticmethod 59 | def load(input_path: str): 60 | # Old classes used other config names than 'sentence_bert_config.json' 61 | for config_name in ['sentence_bert_config.json', 62 | 'sentence_roberta_config.json', 63 | 'sentence_distilbert_config.json', 64 | 'sentence_camembert_config.json', 65 | 'sentence_albert_config.json', 66 | 'sentence_xlm-roberta_config.json', 67 | 'sentence_xlnet_config.json']: 68 | sbert_config_path = os.path.join(input_path, config_name) 69 | if os.path.exists(sbert_config_path): 70 | break 71 | 72 | with open(sbert_config_path) as fIn: 73 | config = json.load(fIn) 74 | config['model_args'] = {'output_hidden_states': True, 'output_attentions': True} 75 | 76 | return Transformer(model_name_or_path=input_path, **config) 77 | 78 | 79 | class SentenceTransformer(sentence_transformers.SentenceTransformer): 80 | 81 | __module__ = 'sbert_modules.SentenceTransformer' 82 | 83 | def __init__(self, model_name_or_path=None, modules=None, device=None, name=None): 84 | self.encoder_name = name 85 | 86 | if model_name_or_path is not None and model_name_or_path != "": 87 | logging.info("Load pretrained SentenceTransformer: {}".format(model_name_or_path)) 88 | model_path = model_name_or_path 89 | 90 | if not os.path.isdir(model_path) and not model_path.startswith('http://') and not 
model_path.startswith('https://'): 91 | logging.info("Did not find folder {}".format(model_path)) 92 | 93 | if '\\' in model_path or model_path.count('/') > 1: 94 | raise AttributeError("Path {} not found".format(model_path)) 95 | 96 | model_path = __DOWNLOAD_SERVER__ + model_path + '.zip' 97 | logging.info("Try to download model from server: {}".format(model_path)) 98 | 99 | if model_path.startswith('http://') or model_path.startswith('https://'): 100 | model_url = model_path 101 | folder_name = model_url.replace("https://", "").replace("http://", "").replace("/", "_")[:250].rstrip('.zip') 102 | 103 | try: 104 | from torch.hub import _get_torch_home 105 | torch_cache_home = _get_torch_home() 106 | except ImportError: 107 | torch_cache_home = os.path.expanduser( 108 | os.getenv('TORCH_HOME', os.path.join( 109 | os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch'))) 110 | default_cache_path = os.path.join(torch_cache_home, 'sentence_transformers') 111 | model_path = os.path.join(default_cache_path, folder_name) 112 | 113 | if not os.path.exists(model_path) or not os.listdir(model_path): 114 | if model_url[-1] == "/": 115 | model_url = model_url[:-1] 116 | logging.info("Downloading sentence transformer model from {} and saving it at {}".format(model_url, model_path)) 117 | 118 | model_path_tmp = model_path.rstrip("/").rstrip("\\")+"_part" 119 | try: 120 | zip_save_path = os.path.join(model_path_tmp, 'model.zip') 121 | http_get(model_url, zip_save_path) 122 | with ZipFile(zip_save_path, 'r') as zip: 123 | zip.extractall(model_path_tmp) 124 | os.remove(zip_save_path) 125 | os.rename(model_path_tmp, model_path) 126 | except requests.exceptions.HTTPError as e: 127 | shutil.rmtree(model_path_tmp) 128 | if e.response.status_code == 404: 129 | logging.warning('SentenceTransformer-Model {} not found. Try to create it from scratch'.format(model_url)) 130 | logging.warning('Try to create Transformer Model {} with mean pooling'.format(model_name_or_path)) 131 | 132 | model_path = None 133 | transformer_model = Transformer(model_name_or_path) 134 | pooling_model = Pooling(transformer_model.get_word_embedding_dimension()) 135 | modules = [transformer_model, pooling_model] 136 | 137 | else: 138 | raise e 139 | except Exception as e: 140 | shutil.rmtree(model_path) 141 | raise e 142 | 143 | #### Load from disk 144 | if model_path is not None: 145 | logging.info("Load SentenceTransformer from folder: {}".format(model_path)) 146 | 147 | if os.path.exists(os.path.join(model_path, 'config.json')): 148 | with open(os.path.join(model_path, 'config.json')) as fIn: 149 | config = json.load(fIn) 150 | if config['__version__'] > __version__: 151 | logging.warning("You try to use a model that was created with version {}, however, your version is {}. This might cause unexpected behavior or errors. 
In that case, try to update to the latest version.\n\n\n".format(config['__version__'], __version__)) 152 | 153 | with open(os.path.join(model_path, 'modules.json')) as fIn: 154 | contained_modules = json.load(fIn) 155 | 156 | modules = OrderedDict() 157 | for module_config in contained_modules: 158 | if module_config['type'] in ['sentence_transformers.models.Transformer', 'sentence_transformers.models.BERT']: 159 | module_config['type'] = 'sbert_modules.Transformer' 160 | module_class = import_from_string(module_config['type']) 161 | module = module_class.load(os.path.join(model_path, module_config['path'])) 162 | modules[module_config['name']] = module 163 | 164 | if modules is not None and not isinstance(modules, OrderedDict): 165 | modules = OrderedDict([(str(idx), module) for idx, module in enumerate(modules)]) 166 | 167 | nn.Sequential.__init__(self, modules) 168 | if device is None: 169 | device = "cuda" if torch.cuda.is_available() else "cpu" 170 | logging.info("Use pytorch device: {}".format(device)) 171 | 172 | self._target_device = torch.device(device) 173 | 174 | def forward(self, input): 175 | for i, module in enumerate(self): 176 | if i > 0 and isinstance(module, models.Transformer): 177 | pass 178 | else: 179 | input = module(input) 180 | return input 181 | 182 | def fit(self, train_objectives, dev_evaluator, test_evaluator, epochs=1, 183 | steps_per_epoch=None, scheduler='WarmupLinear', warmup_steps=10000, 184 | optimizer_class=transformers.AdamW, optimizer_params={}, 185 | weight_decay=0.01, evaluation_steps=0, output_path=None, 186 | save_best_model=True, max_grad_norm=1, use_amp=False, callback=None, 187 | output_path_ignore_not_empty=False, early_stopping_limit=5, disable_tqdm=False): 188 | if use_amp: 189 | from torch.cuda.amp import autocast 190 | scaler = torch.cuda.amp.GradScaler() 191 | 192 | self.to(self._target_device) 193 | 194 | if output_path is not None: 195 | os.makedirs(output_path, exist_ok=True) 196 | 197 | dataloaders = [dataloader for dataloader, _ in train_objectives] 198 | 199 | # Use smart batching 200 | for dataloader in dataloaders: 201 | dataloader.collate_fn = self.smart_batching_collate 202 | 203 | loss_models = [loss for _, loss in train_objectives] 204 | for loss_model in loss_models: 205 | loss_model.to(self._target_device) 206 | 207 | self.best_score = -9999999 208 | 209 | if steps_per_epoch is None or steps_per_epoch == 0: 210 | steps_per_epoch = min([len(dataloader) for dataloader in dataloaders]) 211 | 212 | num_train_steps = int(steps_per_epoch * epochs) 213 | 214 | # Prepare optimizers 215 | optimizers = [] 216 | schedulers = [] 217 | for loss_model in loss_models: 218 | param_optimizer = [(n,p) for n, p in list(loss_model.named_parameters()) if p.requires_grad] 219 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 220 | optimizer_grouped_parameters = [ 221 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': weight_decay}, 222 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 223 | ] 224 | 225 | optimizer = optimizer_class(optimizer_grouped_parameters, **optimizer_params) 226 | scheduler_obj = self._get_scheduler(optimizer, scheduler=scheduler, warmup_steps=warmup_steps, t_total=num_train_steps) 227 | 228 | optimizers.append(optimizer) 229 | schedulers.append(scheduler_obj) 230 | 231 | global_step = 0 232 | data_iterators = [iter(dataloader) for dataloader in dataloaders] 233 | 234 | num_train_objectives = len(train_objectives) 
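# Note: the dev evaluator is run once here, before any parameter update, to record a
# baseline score; during training it then runs every `evaluation_steps` batches, saving the
# best model so far, and training stops early after `early_stopping_limit` evaluations whose
# dev score is lower than the immediately preceding one (the counter is never reset).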
235 | 236 | dev_score = self._eval_during_training(dev_evaluator, output_path, False, 0, 0, callback) 237 | 238 | skip_scheduler = False 239 | early_stopping_cnt = 0 240 | last_score = 0 241 | 242 | range_epoch = range(epochs) if disable_tqdm else trange(epochs, desc='Epoch') 243 | range_iter = range(steps_per_epoch) if disable_tqdm else trange(steps_per_epoch, desc="Iteration", smoothing=0.05) 244 | 245 | for epoch in range_epoch: 246 | training_steps = 0 247 | 248 | for loss_model in loss_models: 249 | loss_model.zero_grad() 250 | loss_model.train() 251 | 252 | for _ in range_iter: 253 | for train_idx in range(num_train_objectives): 254 | loss_model = loss_models[train_idx] 255 | optimizer = optimizers[train_idx] 256 | scheduler = schedulers[train_idx] 257 | data_iterator = data_iterators[train_idx] 258 | 259 | try: 260 | data = next(data_iterator) 261 | except StopIteration: 262 | #logging.info("Restart data_iterator") 263 | data_iterator = iter(dataloaders[train_idx]) 264 | data_iterators[train_idx] = data_iterator 265 | data = next(data_iterator) 266 | 267 | features, labels = batch_to_device(data, self._target_device) 268 | 269 | if use_amp: 270 | with autocast(): 271 | loss_value = loss_model(features, labels) 272 | 273 | scale_before_step = scaler.get_scale() 274 | scaler.scale(loss_value).backward() 275 | scaler.unscale_(optimizer) 276 | torch.nn.utils.clip_grad_norm_(loss_model.parameters(), max_grad_norm) 277 | scaler.step(optimizer) 278 | scaler.update() 279 | 280 | skip_scheduler = scaler.get_scale() != scale_before_step 281 | else: 282 | loss_value = loss_model(features, labels) 283 | loss_value.backward() 284 | torch.nn.utils.clip_grad_norm_(loss_model.parameters(), max_grad_norm) 285 | optimizer.step() 286 | 287 | optimizer.zero_grad() 288 | 289 | if not skip_scheduler: 290 | scheduler.step() 291 | 292 | training_steps += 1 293 | global_step += 1 294 | 295 | if evaluation_steps > 0 and training_steps % evaluation_steps == 0: 296 | dev_score = self._eval_during_training(dev_evaluator, output_path, save_best_model, epoch, training_steps, callback) 297 | 298 | for loss_model in loss_models: 299 | loss_model.zero_grad() 300 | loss_model.train() 301 | 302 | if dev_score < last_score: 303 | early_stopping_cnt += 1 304 | if early_stopping_cnt >= early_stopping_limit: 305 | logging.info('Early stopping!') 306 | return 307 | last_score = dev_score 308 | 309 | self._eval_during_training(dev_evaluator, output_path, save_best_model, epoch, training_steps, callback) 310 | if test_evaluator is not None: 311 | self._eval_during_training(test_evaluator, output_path, save_best_model, epoch, training_steps, callback) 312 | 313 | def _eval_during_training(self, evaluator, output_path, save_best_model, epoch, steps, callback): 314 | """Runs evaluation during the training""" 315 | if evaluator is not None: 316 | score = evaluator(self, output_path=output_path, epoch=epoch, steps=steps) 317 | if callback is not None: 318 | callback(score, epoch, steps) 319 | if score > self.best_score: 320 | self.best_score = score 321 | if save_best_model: 322 | self.save(output_path) 323 | 324 | return score 325 | 326 | 327 | class SentencesDataset(Dataset): 328 | def __init__(self, examples: List[InputExample], model): 329 | self.model = model 330 | self.examples = examples 331 | self.n = 0 332 | for m in model: 333 | if isinstance(m, models.Transformer): 334 | self.n += 1 335 | self.label_type = torch.long if isinstance(self.examples[0].label, int) else torch.float 336 | 337 | def __getitem__(self, item): 
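        # Tokenization happens lazily on first access and is cached on the InputExample,
        # so later epochs reuse the tokenized texts instead of re-tokenizing them.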
338 | label = torch.tensor(self.examples[item].label, dtype=self.label_type) 339 | if self.examples[item].texts_tokenized is None: 340 | if self.n > 1: 341 | text = self.examples[item].texts[0] 342 | self.examples[item].texts_tokenized = [self.model[i].tokenize(text) for i in range(self.n)] 343 | else: 344 | self.examples[item].texts_tokenized = [self.model.tokenize(text) for text in self.examples[item].texts] 345 | return self.examples[item].texts_tokenized, label 346 | 347 | def __len__(self): 348 | return len(self.examples) 349 | --------------------------------------------------------------------------------
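A minimal smoke test for the SG-OPT contrastive objective (not part of the repository): it assumes the pinned requirements above are installed and that it is run from the repository root, and the tensor sizes are illustrative stand-ins for bert-base. It feeds random tensors with the shapes documented in NTXentLoss.forward and prints the resulting scalar loss.

import torch

from loss import NTXentLoss

batch_size, num_layers, hidden_size = 4, 13, 768  # 13 = embedding layer + 12 encoder layers for bert-base
cls_emb = torch.randn(batch_size, hidden_size)                   # stand-in for the [CLS]-based sentence embeddings
layer_views = torch.randn(batch_size, num_layers, hidden_size)   # stand-in for the layer-wise pooled views of the frozen encoder

criterion = NTXentLoss(device=torch.device('cpu'), batch_size=batch_size, temperature=1e-2)
loss = criterion(cls_emb, layer_views)
print(loss)  # scalar NT-Xent loss, averaged over batch_size * num_layers positive pairs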