├── .gitignore ├── README.md ├── environment.yml ├── log.py ├── modeling.py ├── ota.py ├── utils.py └── vocab.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | __pycache__/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # One-Token Approximation 2 | 3 | While traditional word embedding algorithms (e.g., Word2Vec, GloVe) assign a **single** embedding to each word, pretrained language models (e.g., BERT, RoBERTa, XLNet, T5) typically represent words as sequences of subword tokens. For example, BERT represents the word `strawberries` as two tokens `straw` and `##berries`. Transferring ideas and algorithms from traditional embeddings to contextualized embeddings may therefore raise questions like the following: 4 | 5 | > What would the embedding of "strawberries" (or any other multi-token word) look like in BERT's embedding space if the word were represented by a single token? 6 | 7 | *One-Token Approximation* (OTA) can be used to answer this question. More details can be found [here](https://arxiv.org/abs/1904.06707). 8 | 9 | ## Dependencies 10 | 11 | All dependencies can be found in `environment.yml`. If you use conda, simply type 12 | ``` 13 | conda env create -f environment.yml 14 | ``` 15 | to create a new environment with all required packages installed. 16 | 17 | ## Usage 18 | 19 | To obtain One-Token Approximations for multi-token words, run the following command: 20 | 21 | ``` 22 | python3 ota.py --words WORDS --output_file OUTPUT_FILE --model_cls MODEL_CLS --model MODEL --iterations ITERATIONS 23 | ``` 24 | where 25 | - `WORDS` is the path to a file containing all words for which one-token approximations should be computed (with each line containing exactly one word); 26 | - `OUTPUT_FILE` is the path to a file where all one-token approximations are saved (each line contains the word followed by its space-separated embedding values; see the loading example below); 27 | - `MODEL_CLS` is either `bert` or `roberta` (the script currently does not support other pretrained language models); 28 | - `MODEL` is either the name of a pretrained model from the [Hugging Face Transformers Library](https://github.com/huggingface/transformers) (e.g., `bert-base-uncased`) or the path to a finetuned model; 29 | - `ITERATIONS` is the number of iterations for which to perform OTA. For BERT, 4000 iterations generally give good results; for RoBERTa, we found that much better results can be obtained by increasing the number of iterations to 8000. 30 | 31 | For additional parameters, check the source code of `ota.py` or run `python3 ota.py --help`. 
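The `OUTPUT_FILE` produced by `ota.py` is a plain-text file in the same spirit as GloVe vectors: each line holds a word followed by the values of its inferred embedding, separated by single spaces (see `write_embeddings` in `utils.py`). The snippet below is a minimal sketch of how such a file could be loaded back into NumPy arrays; the file name `ota.vec` and the helper name `load_ota_embeddings` are placeholders, not part of this repository.

```
import numpy as np

def load_ota_embeddings(path):
    """Read one-token approximations written by ota.py into a dict of numpy arrays."""
    embeddings = {}
    with open(path, 'r', encoding='utf8') as f:
        for line in f:
            # each line: <word> <value_1> ... <value_n>
            word, *values = line.rstrip().split(' ')
            embeddings[word] = np.array([float(v) for v in values])
    return embeddings

vectors = load_ota_embeddings('ota.vec')  # placeholder path
print(vectors['strawberries'].shape)      # e.g. (768,) for bert-base-uncased, assuming 'strawberries' was in WORDS
```

The inferred vectors live in the input embedding space of the chosen model, so they can be compared directly (e.g., via cosine similarity) to the rows of that model's word embedding matrix.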
32 | 33 | ## Citation 34 | 35 | If you make use of One-Token Approximation, please cite the following paper: 36 | 37 | ``` 38 | @inproceedings{schick2020rare, 39 | title={Rare words: A major problem for contextualized representation and how to fix it by attentive mimicking}, 40 | author={Schick, Timo and Sch{\"u}tze, Hinrich}, 41 | url="https://arxiv.org/abs/1904.06707", 42 | booktitle={Proceedings of the Thirty-Fourth AAAI Conference on Artificial Intelligence}, 43 | year={2020} 44 | } 45 | ``` 46 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: ota 2 | 3 | channels: 4 | - pytorch 5 | - defaults 6 | 7 | dependencies: 8 | - python=3.7 9 | - numpy 10 | - pytorch 11 | - torchvision 12 | - scipy 13 | - pip: 14 | - transformers==2.1 -------------------------------------------------------------------------------- /log.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | names = set() 4 | 5 | 6 | def __setup_custom_logger(name: str) -> logging.Logger: 7 | formatter = logging.Formatter(fmt='%(asctime)s - %(levelname)s - %(module)s - %(message)s') 8 | 9 | names.add(name) 10 | 11 | handler = logging.StreamHandler() 12 | handler.setFormatter(formatter) 13 | 14 | logger = logging.getLogger(name) 15 | logger.setLevel(logging.INFO) 16 | logger.addHandler(handler) 17 | return logger 18 | 19 | 20 | def get_logger(name: str) -> logging.Logger: 21 | if name in names: 22 | return logging.getLogger(name) 23 | else: 24 | return __setup_custom_logger(name) 25 | -------------------------------------------------------------------------------- /modeling.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | from typing import Callable, List, Tuple 4 | 5 | from torch import Tensor 6 | from torch.nn import Module, Embedding 7 | 8 | from transformers import PreTrainedTokenizer, BertTokenizer, GPT2Tokenizer 9 | 10 | import log 11 | 12 | logger = log.get_logger('root') 13 | 14 | 15 | def default_filter(x: str) -> bool: 16 | return not x.startswith('[') and not x.startswith('#') 17 | 18 | 19 | class OverwriteableEmbedding(Module): 20 | 21 | def __init__(self, embedding: Embedding, overwrite_fct): 22 | super().__init__() 23 | self.embedding = embedding 24 | self.overwrite_fct = overwrite_fct 25 | 26 | def forward(self, input: Tensor): 27 | embds = self.embedding(input) 28 | if self.overwrite_fct is not None: 29 | embds = self.overwrite_fct(embds) 30 | return embds 31 | 32 | 33 | class OTAInput: 34 | def __init__(self, tokens, segments=None, mask=None): 35 | self.tokens = tokens 36 | self.segments = segments if segments is not None else torch.zeros(self.tokens.shape, dtype=torch.long) 37 | self.mask = mask if mask is not None else torch.ones(self.tokens.shape, dtype=torch.long) 38 | 39 | def get_length(self) -> int: 40 | return self.tokens.shape[0] 41 | 42 | @staticmethod 43 | def stack(inputs: List['OTAInput']) -> 'OTAInput': 44 | max_seq_length = max(x.get_length() for x in inputs) 45 | for inp in inputs: 46 | # zero-pad up to the sequence length 47 | padding = torch.tensor([0] * (max_seq_length - inp.get_length()), dtype=torch.long) 48 | inp.tokens = torch.cat((inp.tokens, padding), dim=0) 49 | inp.segments = torch.cat((inp.segments, padding), dim=0) 50 | inp.mask = torch.cat((inp.mask, padding), dim=0) 51 | 52 | stacked_tokens = torch.stack([x.tokens 
for x in inputs]) 53 | stacked_segments = torch.stack([x.segments for x in inputs]) 54 | stacked_masks = torch.stack([x.mask for x in inputs]) 55 | return OTAInput(stacked_tokens, stacked_segments, stacked_masks) 56 | 57 | 58 | class InputPreparator: 59 | def __init__(self, tokenizer: PreTrainedTokenizer, filter_callable: Callable[[str], bool] = default_filter, 60 | prefix: str = '', suffix: str = ' .', pmin: int = 0, pmax: int = 0, smin: int = 0, smax: int = 0, 61 | seed: int = 1234, eval_sentence: str = None, **_): 62 | self.tokenizer = tokenizer 63 | 64 | if isinstance(tokenizer, BertTokenizer): 65 | vocab = tokenizer.vocab.keys() 66 | elif isinstance(tokenizer, GPT2Tokenizer): 67 | vocab = tokenizer.encoder.keys() 68 | else: 69 | raise ValueError('Access to vocab is currently only implemented for BertTokenizer and GPT2Tokenizer') 70 | 71 | self.words = [x for x in vocab if not filter_callable or filter_callable(x)] 72 | self.prefix = tokenizer.tokenize(prefix) 73 | self.suffix = tokenizer.tokenize(suffix) 74 | self.pmin = pmin 75 | self.pmax = pmax 76 | self.smin = smin 77 | self.smax = smax 78 | self.eval_sentence = eval_sentence 79 | 80 | if seed: 81 | random.seed(seed) 82 | 83 | def generate_random_word(self) -> str: 84 | return random.choice(self.words) 85 | 86 | def prepare_batch(self, batch: List[str]) -> Tuple[OTAInput, OTAInput, int]: 87 | 88 | prefix, suffix = self._create_infixes() 89 | index_to_optimize = len(prefix) + 1 90 | 91 | inputs_gold = [] 92 | inputs_inference = [] 93 | 94 | for word in batch: 95 | if isinstance(self.tokenizer, GPT2Tokenizer): 96 | word_toks = self.tokenizer.tokenize(word, add_prefix_space=True) 97 | else: 98 | word_toks = self.tokenizer.tokenize(word) 99 | inputs_gold.append(self._prepare_input(prefix, word_toks, suffix)) 100 | inputs_inference.append(self._prepare_input(prefix, [self.tokenizer.mask_token], suffix)) 101 | 102 | return OTAInput.stack(inputs_gold), OTAInput.stack(inputs_inference), index_to_optimize 103 | 104 | def _create_infixes(self): 105 | num_prefix_words = random.randint(self.pmin, self.pmax) 106 | num_suffix_words = random.randint(self.smin, self.smax) 107 | 108 | prefix = list(self.prefix) + [self.generate_random_word() for _ in range(num_prefix_words)] 109 | suffix = [self.generate_random_word() for _ in range(num_suffix_words)] + list(self.suffix) 110 | 111 | logger.debug('Randomly sampled template: {} {}'.format(prefix, suffix)) 112 | return prefix, suffix 113 | 114 | def _prepare_input(self, prefix, word, suffix) -> OTAInput: 115 | token_ids = self.tokenizer.encode(prefix + word + suffix, add_special_tokens=True) 116 | tokens_tensor = torch.tensor(token_ids) 117 | return OTAInput(tokens=tokens_tensor) 118 | -------------------------------------------------------------------------------- /ota.py: -------------------------------------------------------------------------------- 1 | import random 2 | import io 3 | import argparse 4 | import time 5 | from collections import defaultdict 6 | from typing import List 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import scipy.spatial.distance 12 | from transformers import BertModel, BertTokenizer, RobertaModel, RobertaTokenizer, GPT2Tokenizer 13 | 14 | import utils 15 | import log 16 | from modeling import InputPreparator, OverwriteableEmbedding 17 | 18 | logger = log.get_logger('root') 19 | 20 | MODELS = { 21 | 'bert': (BertModel, BertTokenizer), 22 | 'roberta': (RobertaModel, RobertaTokenizer) 23 | } 24 | 25 | 26 | def verify_args(args): 27 | if 
args.pmin > args.pmax: 28 | raise ValueError('pmin must be less than pmax, got pmin={}, pmax={}'.format(args.pmin, args.pmax)) 29 | if args.smin > args.smax: 30 | raise ValueError('smin must be less than smax, got smin={}, smax={}'.format(args.smin, args.smax)) 31 | if (not args.word and not args.words) or (args.word and args.words): 32 | raise ValueError('Either a single word or a file containing words must be given via --word or --words') 33 | if args.prefix != '' and not args.prefix.endswith(' '): 34 | raise ValueError('The prefix must either be empty or end with a space, got "{}"'.format(args.prefix)) 35 | if args.suffix != '' and not args.suffix.startswith(' '): 36 | raise ValueError('The suffix must either be empty or start with a space, got "{}"'.format(args.suffix)) 37 | 38 | 39 | def load_words(word, word_file): 40 | if word: 41 | return [word] 42 | else: 43 | with io.open(word_file, 'r', encoding='utf8') as f: 44 | return f.read().splitlines() 45 | 46 | 47 | def initialize_embeddings(tokens: List[List[str]], embeddings: List[List[np.ndarray]], strategy: str) -> torch.Tensor: 48 | # embeddings and tokens are lists of shape batch_size x nr_of_tokens 49 | batch_size = len(embeddings) 50 | embedding_dim = embeddings[0][0].shape[0] 51 | 52 | logger.info('Initializing embeddings of shape {} x {} (strategy={})'.format(batch_size, embedding_dim, strategy)) 53 | 54 | embeddings_sum = torch.zeros(batch_size, embedding_dim) 55 | if len(embeddings[0]) == 1: 56 | return embeddings_sum 57 | 58 | if strategy == 'sum': 59 | for idx, token_embeddings in enumerate(embeddings): 60 | for embedding in token_embeddings: 61 | embeddings_sum[idx] += torch.tensor(embedding) 62 | embeddings_sum[idx] /= len(token_embeddings) 63 | 64 | elif strategy == 'wsum': 65 | for idx, token_embeddings in enumerate(embeddings): 66 | word_tokens = tokens[idx] 67 | for token_idx, token in enumerate(word_tokens): 68 | embeddings_sum[idx] += torch.tensor(token_embeddings[token_idx]) * utils.token_length(token) 69 | embeddings_sum[idx] /= sum(utils.token_length(token) for token in word_tokens) 70 | 71 | return embeddings_sum 72 | 73 | 74 | def main(): 75 | parser = argparse.ArgumentParser() 76 | 77 | parser.add_argument('--output_file', default=None, type=str, required=True) 78 | 79 | parser.add_argument('--word', type=str, required=False) 80 | parser.add_argument('--words', type=str, required=False) 81 | 82 | parser.add_argument('--model', default='bert-base-uncased', type=str) 83 | parser.add_argument('--model_cls', default='bert', type=str, choices=['bert', 'roberta']) 84 | 85 | parser.add_argument('--seed', default=1234, type=int) 86 | 87 | parser.add_argument('--prefix', '-p', default='', type=str) 88 | parser.add_argument('--suffix', '-s', default=' .', type=str) 89 | 90 | parser.add_argument('--smin', default=1, type=int, 91 | help='Minimum number of random tokens to append to the training string as suffix') 92 | 93 | parser.add_argument('--smax', default=1, type=int, 94 | help='Maximum number of random tokens to append to the training string as suffix') 95 | 96 | parser.add_argument('--pmin', default=1, type=int, 97 | help='Minimum number of random tokens to prepend to the training string as prefix') 98 | 99 | parser.add_argument('--pmax', default=1, type=int, 100 | help='Maximum number of random tokens to prepend to the training string as prefix') 101 | 102 | parser.add_argument('--objective', '-o', default='both', 103 | choices=['first', 'left', 'right', 'both'], 104 | help='Training objective: Whether to minimize 
only the distance of ' 105 | 'embeddings for the first token (i.e. [CLS]), for all words to ' 106 | 'the left, to the right or both to the left and right.') 107 | 108 | parser.add_argument('--eval_file', default=None, type=str) 109 | 110 | parser.add_argument('--eval_steps', default=[1, 10] + [100 * i for i in range(1, 51)], type=int, nargs='+', 111 | help='The numbers of training steps after which the average cosine distance over the' 112 | 'entire list of words is computed and stored in the eval file') 113 | 114 | parser.add_argument('--batch_size', default=128, type=int) 115 | parser.add_argument('--iterations', default=1000, type=int) 116 | parser.add_argument('--learning_rate', '-lr', default=1e-3, type=float) 117 | 118 | parser.add_argument('--init', default='sum', choices=['sum', 'wsum', 'zeros'], 119 | help='Initialization strategy for the embedding to be inferred. ' 120 | 'With "zeros", the embedding is initialized as a zero vector. ' 121 | 'With "sum", it is initialized as the sum of all embeddings ' 122 | 'of the BPE of the original word.' 123 | 'With "wsum", it is initialized as the weighted sum of all' 124 | 'token embeddings, where the weight is based on each token\'s length') 125 | 126 | args = parser.parse_args() 127 | verify_args(args) 128 | 129 | words = load_words(args.word, args.words) 130 | uses_randomization = args.smax > 0 or args.pmax > 0 131 | 132 | logger.info('Inferring embeddings for {} words, first 10 are: {}'.format(len(words), words[:10])) 133 | 134 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 135 | n_gpu = torch.cuda.device_count() 136 | logger.info("device: {} n_gpu: {}".format(device, n_gpu)) 137 | 138 | random.seed(args.seed) 139 | np.random.seed(args.seed) 140 | torch.manual_seed(args.seed) 141 | if n_gpu > 0: 142 | torch.cuda.manual_seed_all(args.seed) 143 | 144 | model_cls, tokenizer_cls = MODELS[args.model_cls] 145 | 146 | tokenizer = tokenizer_cls.from_pretrained(args.model) 147 | input_preparator = InputPreparator(tokenizer, **vars(args)) 148 | 149 | model = model_cls.from_pretrained(args.model, output_hidden_states=True) 150 | model.to(device) 151 | model.eval() # we don't want any dropout so we use eval mode 152 | for param in model.parameters(): 153 | param.requires_grad = False 154 | 155 | def tokenize_with_optional_space(word): 156 | if isinstance(tokenizer, GPT2Tokenizer): 157 | return tokenizer.tokenize(word, add_prefix_space=True) 158 | return tokenizer.tokenize(word) 159 | 160 | # group words based on their number of tokens 161 | words_by_token_size = defaultdict(list) 162 | for word in words: 163 | num_tokens = len(tokenize_with_optional_space(word)) 164 | if num_tokens > 0: 165 | words_by_token_size[num_tokens].append(word) 166 | 167 | token_sizes = list(words_by_token_size.keys()) 168 | logger.info('Found words with the following token sizes: {}'.format(token_sizes)) 169 | 170 | dists = [] 171 | token_sizes_idx = -1 172 | words_for_token_size = [] 173 | 174 | batch_idx = 0 175 | 176 | inferred_embeddings = {} 177 | cosine_distances = defaultdict(list) 178 | 179 | word_embeddings = model.embeddings.word_embeddings 180 | model.embeddings.word_embeddings = OverwriteableEmbedding(word_embeddings, overwrite_fct=None) 181 | 182 | while True: 183 | batch = words_for_token_size[batch_idx * args.batch_size: (batch_idx + 1) * args.batch_size] 184 | 185 | if not batch: 186 | if token_sizes_idx < len(token_sizes) - 1: 187 | token_sizes_idx += 1 188 | token_size = token_sizes[token_sizes_idx] 189 | words_for_token_size = 
words_by_token_size[token_size] 190 | batch_idx = 0 191 | logger.info('Processing all words that consist of {} tokens (found {} words, first 10 are: {})'.format( 192 | token_size, len(words_for_token_size), words_for_token_size[:10])) 193 | continue 194 | else: 195 | break 196 | 197 | logger.info( 198 | 'Processing words {} - {} of {}'.format(batch_idx * args.batch_size + 1, 199 | batch_idx * args.batch_size + len(batch), 200 | len(words_for_token_size))) 201 | 202 | tokens = [tokenize_with_optional_space(wrd) for wrd in batch] 203 | embeddings = [[utils.get_word_embedding(wordpart, tokenizer, model, device) for wordpart in 204 | tokenize_with_optional_space(wrd)] for wrd in batch] 205 | 206 | print(len(embeddings[0])) 207 | print([len(x) for x in embeddings]) 208 | 209 | init_vals = initialize_embeddings(tokens, embeddings, args.init) 210 | optim_vars = torch.tensor(init_vals, requires_grad=True) 211 | optimizer = torch.optim.Adam([optim_vars], lr=args.learning_rate) 212 | 213 | input_gold, input_inference, index_to_optimize, layers_gold = None, None, None, None 214 | 215 | if not uses_randomization: 216 | input_gold, input_inference, index_to_optimize = input_preparator.prepare_batch(batch) 217 | 218 | print('IGS:', input_gold.tokens.shape, 'IIS', input_inference.tokens.shape) 219 | 220 | _, _, layers_gold = model(input_gold.tokens.to(device), input_gold.segments.to(device)) 221 | layers_gold = [layer.detach() for layer in layers_gold] 222 | 223 | def overwrite_fct(embds): 224 | for i in range(embds.shape[0]): 225 | embds[i, index_to_optimize, :] = optim_vars[i] 226 | return embds 227 | 228 | start = time.time() 229 | logger.info(' ' * 79 + ' '.join('{:6s}'.format(word[:6]) for word in batch[:10])) 230 | 231 | for iteration in range(1, args.iterations + 1): 232 | 233 | if uses_randomization: 234 | input_gold, input_inference, index_to_optimize = input_preparator.prepare_batch(batch) 235 | with torch.no_grad(): 236 | _, _, layers_gold = model(input_gold.tokens.to(device), input_gold.segments.to(device)) 237 | layers_gold = [layer.detach() for layer in layers_gold] 238 | 239 | model.embeddings.word_embeddings.overwrite_fct = overwrite_fct 240 | _, _, layers = model(input_inference.tokens.to(device), input_inference.segments.to(device)) 241 | model.embeddings.word_embeddings.overwrite_fct = None 242 | 243 | loss = nn.MSELoss() 244 | loss_val = torch.tensor(0, dtype=torch.float).to(device) 245 | 246 | for idx in range(model.config.num_hidden_layers): 247 | 248 | if args.objective == 'first': 249 | loss_val += loss(layers[idx][:, 0, :], layers_gold[idx][:, 0, :]) 250 | 251 | if args.objective == 'left' or args.objective == 'both': 252 | loss_val += loss(layers[idx][:, :index_to_optimize, :], 253 | layers_gold[idx][:, :index_to_optimize, :]) 254 | 255 | if args.objective == 'right' or args.objective == 'both': 256 | loss_val += loss(layers[idx][:, index_to_optimize + 1:, :], 257 | layers_gold[idx][:, index_to_optimize + len(embeddings[0]):, :]) 258 | 259 | loss_val.backward() 260 | optimizer.step() 261 | optimizer.zero_grad() 262 | 263 | now = time.time() 264 | elapsed_time = (now - start) 265 | 266 | do_eval = args.eval_file is not None and iteration in args.eval_steps 267 | 268 | if (iteration == 1 or iteration % 100 == 0) or do_eval: 269 | batch_dists = [] 270 | if len(embeddings[0]) == 1: 271 | for idx, token_embeddings in enumerate(embeddings): 272 | inferred_embedding = optim_vars[idx].cpu().detach().numpy() 273 | if np.linalg.norm(inferred_embedding) > 0: 274 | cosine_distance = 
scipy.spatial.distance.cosine(token_embeddings[0], 275 | inferred_embedding) 276 | batch_dists.append(cosine_distance) 277 | 278 | if do_eval: 279 | cosine_distances[iteration] += batch_dists 280 | 281 | cosine_string = ' '.join(['{:6.2f}'.format(dist) for dist in batch_dists][:10]) 282 | steps_per_second = iteration / elapsed_time 283 | remaining_words = len(words) - len(inferred_embeddings) - len(batch) 284 | remaining_time_for_other_batches = ( 285 | remaining_words / args.batch_size) * args.iterations / steps_per_second 286 | remaining_time_for_this_batch = (args.iterations - iteration) / steps_per_second 287 | remaining_time = int(remaining_time_for_this_batch + remaining_time_for_other_batches) 288 | 289 | logger.info('step: {:4d} loss: {:8.6f} cosine: {:6.2f} steps/s: {:6.2f} ETR: {:8s} sample: {}'.format( 290 | iteration, loss_val.item(), utils.avg(batch_dists), steps_per_second, 291 | utils.get_date_string(remaining_time), cosine_string)) 292 | 293 | for idx, token_embeddings in enumerate(embeddings): 294 | 295 | word = batch[idx] 296 | 297 | inferred_embedding = optim_vars[idx].cpu().detach().numpy() 298 | if len(embeddings[0]) == 1: 299 | final_cosine_distance = scipy.spatial.distance.cosine(token_embeddings[0], inferred_embedding) 300 | dists.append(final_cosine_distance) 301 | inferred_embeddings[word] = inferred_embedding 302 | 303 | batch_idx += 1 304 | 305 | logger.info('Overall average cosine distance: {}'.format(utils.avg(dists))) 306 | 307 | utils.write_embeddings(inferred_embeddings, args.output_file) 308 | if args.eval_file is not None: 309 | utils.write_eval(cosine_distances, args.eval_file) 310 | 311 | 312 | if __name__ == "__main__": 313 | main() 314 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import io 2 | from typing import Dict, List 3 | 4 | import torch 5 | import numpy as np 6 | 7 | import log 8 | 9 | logger = log.get_logger('root') 10 | 11 | 12 | def avg(l: List[float]): 13 | if not l: 14 | return -1 15 | return sum(l) / len(l) 16 | 17 | 18 | def get_word_embedding(word, tokenizer, model, device): 19 | word_id = tokenizer.convert_tokens_to_ids([word]) 20 | embd = model.embeddings.word_embeddings(torch.tensor(word_id).to(device)).cpu().detach().numpy()[0] 21 | return embd 22 | 23 | 24 | def write_embeddings(inferred_embeddings: Dict[str, np.ndarray], output_file: str) -> None: 25 | with io.open(output_file, 'w', encoding='utf8') as f: 26 | for word in inferred_embeddings.keys(): 27 | f.write(word + ' ' + ' '.join(str(x) for x in inferred_embeddings[word]) + '\n') 28 | 29 | 30 | def write_eval(cosine_distances: Dict[int, List[float]], eval_file: str) -> None: 31 | with open(eval_file, 'w', encoding='utf8') as f: 32 | f.write('iterations,avg_cosine_distance\n') 33 | for count in cosine_distances.keys(): 34 | f.write('{},{}\n'.format(count, avg((cosine_distances[count])))) 35 | 36 | 37 | def token_length(token: str) -> int: 38 | """ 39 | Returns the actual length of a BPE token without preceding ## characters (BERT) or Ġ characters (RoBERTa) 40 | :param token: the BPE token 41 | :return: the number of characters in this token 42 | """ 43 | return len(token) - (2 if token.startswith('##') else 0) - (1 if token.startswith('Ġ') else 0) 44 | 45 | 46 | def get_date_string(seconds: int) -> str: 47 | m, s = divmod(seconds, 60) 48 | h, m = divmod(m, 60) 49 | return '{:02d}:{:02d}:{:02d}'.format(h, m, s) 50 | 
-------------------------------------------------------------------------------- /vocab.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from collections import Counter 3 | 4 | from transformers import BertTokenizer, GPT2Tokenizer 5 | from typing import Dict, List, Set 6 | 7 | import log 8 | from ota import MODELS 9 | 10 | logger = log.get_logger('root') 11 | 12 | 13 | def load_vocab(path: str) -> List[str]: 14 | with open(path, 'r', encoding='utf8') as f: 15 | vocab = f.read().splitlines() 16 | return vocab 17 | 18 | 19 | def load_vocab_with_counts(path: str, min_freq=-1, max_freq=-1) -> Dict[str, int]: 20 | logger.info('Loading vocab from {} with minimum count {}'.format(path, min_freq)) 21 | vocab = Counter() 22 | 23 | with open(path, 'r', encoding='utf-8') as f: 24 | for line in f: 25 | word, count = line.split() 26 | if (min_freq <= 0 or int(count) >= min_freq) and (max_freq <= 0 or int(count) <= max_freq): 27 | vocab[word] = int(count) 28 | 29 | logger.info('The loaded vocab contains {} words'.format(len(vocab))) 30 | return vocab 31 | 32 | 33 | def get_difference(vocab_path: str, model: str, model_cls: str) -> List[str]: 34 | """ 35 | Returns the difference between the words in a vocabulary file and the words in the vocabulary of a given 36 | Transformer model. 37 | :param vocab_path: the path to the vocabulary file 38 | :param model: the path to the model 39 | :param model_cls: the model class (currently supported: "bert" or "roberta") 40 | :return: the difference as a list of words 41 | """ 42 | vocab_with_counts = load_vocab_with_counts(vocab_path) 43 | vocab = set(vocab_with_counts.keys()) 44 | 45 | model_cls, tokenizer_cls = MODELS[model_cls] 46 | tokenizer = tokenizer_cls.from_pretrained(model) 47 | 48 | if isinstance(tokenizer, BertTokenizer): 49 | model_vocab = tokenizer.vocab.keys() 50 | elif isinstance(tokenizer, GPT2Tokenizer): 51 | model_vocab = set(tokenizer.encoder.keys()) 52 | model_vocab.update([w[1:] for w in model_vocab if w.startswith('Ġ')]) 53 | else: 54 | raise ValueError('Access to vocab is currently only implemented for BertTokenizer and GPT2Tokenizer') 55 | 56 | logger.info('Vocab sizes: file = {}, model = {}'.format(len(vocab), len(model_vocab))) 57 | vocab -= model_vocab 58 | logger.info('Size of vocab difference = {}'.format(len(vocab))) 59 | 60 | vocab = list(vocab) 61 | vocab.sort(key=lambda x: vocab_with_counts[x], reverse=True) 62 | return vocab 63 | 64 | 65 | def split_vocab(path: str, parts: int): 66 | vocab = load_vocab(path) 67 | part_size = int(len(vocab) / parts) + 1 68 | 69 | logger.info("Splitting vocab with {} words into {} parts with size {}".format(len(vocab), parts, part_size)) 70 | vocab_splitted = [vocab[part_size * i: part_size * (i + 1)] for i in range(parts)] 71 | assert sum(len(x) for x in vocab_splitted) == len(vocab) 72 | 73 | for idx, vocab_part in enumerate(vocab_splitted): 74 | logger.info("Vocab part {} has size {}".format(idx, len(vocab_part))) 75 | write_vocab(vocab_part, path + str(idx)) 76 | 77 | 78 | def write_vocab(lines: List[str], file: str) -> None: 79 | with open(file, 'w', encoding='utf8') as f: 80 | for line in lines: 81 | f.write(line + '\n') 82 | 83 | 84 | def main(): 85 | parser = argparse.ArgumentParser() 86 | parser.add_argument('--input', type=str, required=False) 87 | parser.add_argument('--output', default=None, type=str, required=True) 88 | parser.add_argument('--model', default='bert-base-uncased', type=str) 89 | parser.add_argument('--model_cls', 
default='bert', type=str, choices=['bert', 'roberta']) 90 | parser.add_argument('--parts', default=0, type=int) 91 | args = parser.parse_args() 92 | 93 | vocab = get_difference(args.input, args.model, args.model_cls) 94 | write_vocab(vocab, args.output) 95 | 96 | if args.parts > 0: 97 | split_vocab(args.output, args.parts) 98 | 99 | 100 | if __name__ == "__main__": 101 | main() 102 | --------------------------------------------------------------------------------