├── LICENSE
├── README.md
├── assets
│   └── clip.png
├── clip.py
├── data
│   └── bpe_simple_vocab_16e6.txt
└── tokenizer.py

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2022 Junho Kim
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## CLIP — Simple TensorFlow Implementation [[Link]](https://openai.com/blog/clip/)
2 | 
3 | 
4 | ![clip](./assets/clip.png)
5 | 
6 | 
7 | ## Usage
8 | 
9 | ```python
10 | import tensorflow as tf
11 | from clip import CLIP
12 | 
13 | clip_model = CLIP(
14 |     dim_text = 512,
15 |     dim_image = 512,
16 |     dim_latent = 512,
17 |     num_text_tokens = 10000,
18 |     text_enc_depth = 6,
19 |     text_seq_len = 256,
20 |     text_heads = 8,
21 |     visual_enc_depth = 6,
22 |     visual_image_size = 256,
23 |     visual_patch_size = 32,
24 |     visual_heads = 8,
25 | )
26 | 
27 | # mock data
28 | text = tf.random.uniform([4, 256], minval=0, maxval=10000, dtype=tf.int32)
29 | images = tf.random.normal([4, 256, 256, 3])
30 | 
31 | # train
32 | loss = clip_model(
33 |     text,
34 |     images,
35 |     freeze_image_encoder = False, # freeze the image encoder when it is a pretrained network, as proposed in the LiT paper
36 |     return_loss = True            # needs to be set to True to return the contrastive loss
37 | )
38 | ```
39 | 
40 | ## Reference
41 | * [x-clip](https://github.com/lucidrains/x-clip)
42 | 
43 | ## Author
44 | * [Junho Kim](http://bit.ly/jhkim_resume)
45 | 
--------------------------------------------------------------------------------
/assets/clip.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/CLIP-Tensorflow/941d18e7f4ae8a10293733c29fb2dadbf0714848/assets/clip.png
--------------------------------------------------------------------------------
/clip.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.keras import Model
3 | from tensorflow.keras.layers import Layer
4 | from tensorflow.keras import Sequential
5 | import tensorflow.keras.layers as nn
6 | from tensorflow import einsum
7 | 
8 | 
9 | from contextlib import contextmanager
10 | from functools import partial
11 | from einops import rearrange, repeat
12 | from einops.layers.tensorflow import Rearrange
13 | 
14 | import numpy as np
15 | 
16 | # helper functions
17 | def exists(val):
18 |     return val is not None
19 | 
20 | def default(val, d):
21 |     return val if exists(val) else d
22 | 
23 | @contextmanager
24 | def null_context():
25 |     yield
26 | 
27 | def max_neg_value(dtype):
28 |     return -np.finfo(dtype.as_numpy_dtype).max
29 | 
30 | def cast_tuple(t):
31 |     return t if isinstance(t, (tuple, list)) else (t,)
32 | 
33 | def masked_mean(t, mask, dim = 1, eps = 1e-6):
34 |     t = masked_fill(t, ~mask, 0.0)
35 |     numer = tf.reduce_sum(t, axis=dim)
36 | 
37 |     denom = tf.reduce_sum(tf.cast(mask, t.dtype), axis=dim)  # cast the boolean mask before summing
38 |     denom = tf.clip_by_value(denom, clip_value_min=eps, clip_value_max=tf.reduce_max(denom))
39 | 
40 |     return numer / denom
41 | 
42 | def log(t, eps = 1e-20):
43 |     return tf.math.log(t + eps)
44 | 
45 | def l2norm(t):
46 |     return tf.math.l2_normalize(t, axis=-1)
47 | 
48 | def masked_select(x, mask):
49 |     x = tf.cast(x, tf.float32)
50 |     mask = tf.cast(mask, tf.int32)
51 | 
52 |     x = tf.reshape(x, [-1])
53 |     mask = tf.reshape(mask, [-1])
54 |     mask_true_idx = tf.where(mask)
55 | 
56 |     return tf.gather_nd(x, mask_true_idx)
57 | 
58 | def masked_fill(x, mask, true_val):
59 |     x = tf.where(mask, true_val, x)
60 |     return x
61 | 
62 | def matrix_diag(t):
63 |     # t.shape = [1,4,4]
64 |     i, j = t.shape[-2:]
65 |     num_diag_el = min(i, j)
66 |     i_range = tf.range(i)
67 |     j_range = tf.range(j)
68 |     diag_mask 
= rearrange(i_range, 'i -> i 1') == rearrange(j_range, 'j -> 1 j') # [4,4] 69 | 70 | diag_el = masked_select(t, diag_mask) # [4] 71 | 72 | return rearrange(diag_el, '(b d) -> b d', d = num_diag_el) 73 | 74 | # keyword argument helpers 75 | def pick_and_pop(keys, d): 76 | values = list(map(lambda key: d.pop(key), keys)) 77 | return dict(zip(keys, values)) 78 | 79 | def group_dict_by_key(cond, d): 80 | return_val = [dict(),dict()] 81 | for key in d.keys(): 82 | match = bool(cond(key)) 83 | ind = int(not match) 84 | return_val[ind][key] = d[key] 85 | return (*return_val,) 86 | 87 | def string_begins_with(prefix, str): 88 | return str.startswith(prefix) 89 | 90 | def group_by_key_prefix(prefix, d): 91 | return group_dict_by_key(partial(string_begins_with, prefix), d) 92 | 93 | def groupby_prefix_and_trim(prefix, d): 94 | kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) 95 | kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) 96 | return kwargs_without_prefix, kwargs 97 | 98 | # helper classes 99 | class LayerNorm(Layer): 100 | # bias-less layernorm 101 | def __init__(self, dim, eps=1e-5): 102 | super(LayerNorm, self).__init__() 103 | self.eps = eps 104 | 105 | self.g = tf.Variable(tf.ones([dim])) 106 | 107 | def call(self, x, training=True): 108 | var = tf.math.reduce_variance(x, axis=-1, keepdims=True) 109 | mean = tf.reduce_mean(x, axis=-1, keepdims=True) 110 | 111 | x = (x - mean) / tf.sqrt((var + self.eps)) * self.g 112 | return x 113 | 114 | class PreNorm(Layer): 115 | def __init__(self, dim, fn): 116 | super(PreNorm, self).__init__() 117 | 118 | self.norm = LayerNorm(dim) 119 | self.fn = fn 120 | 121 | def call(self, x, **kwargs): 122 | return self.fn(self.norm(x), **kwargs) 123 | 124 | # rotary positional embedding 125 | class RotaryEmbedding(Layer): 126 | def __init__(self, dim): 127 | super(RotaryEmbedding, self).__init__() 128 | self.inv_freq = 1.0 / (10000 ** (tf.range(0, dim, 2, dtype=tf.float32) / dim)) 129 | 130 | def call(self, seq_len, training=True): 131 | inv_freq = self.inv_freq 132 | t = tf.cast(tf.range(seq_len), dtype=inv_freq.dtype) 133 | freqs = einsum('i , j -> i j', t, inv_freq) 134 | 135 | x = tf.concat([freqs, freqs], axis=-1) 136 | return x 137 | 138 | def rotate_half(x): 139 | x = rearrange(x, '... (j d) -> ... 
j d', j = 2) 140 | x1, x2 = tf.unstack(x, axis=-2) 141 | 142 | x = tf.concat([-x2, x1], axis=-1) 143 | return x 144 | 145 | def apply_rotary_pos_emb(freqs, t): 146 | rot_dim = freqs.shape[-1] 147 | t, t_pass = t[..., :rot_dim], t[..., rot_dim:] 148 | t = (t * tf.math.cos(freqs)) + (rotate_half(t) * tf.math.sin(freqs)) 149 | 150 | x = tf.concat([t, t_pass], axis=-1) 151 | return x 152 | 153 | # transformer 154 | class SwiGLU(Layer): 155 | def __init__(self): 156 | super(SwiGLU, self).__init__() 157 | 158 | def silu(self, x): 159 | return x * tf.sigmoid(x) 160 | 161 | def call(self, x, training=True): 162 | x, gates = tf.split(x, num_or_size_splits=2, axis=-1) 163 | return x * self.silu(gates) 164 | 165 | class MLP(Layer): 166 | def __init__(self, dim, mult=4, dropout=0.0): 167 | super(MLP, self).__init__() 168 | inner_dim = int(dim * mult) 169 | 170 | 171 | self.net = Sequential([ 172 | nn.Dense(units=inner_dim * 2, use_bias=False), 173 | SwiGLU(), 174 | nn.Dropout(rate=dropout), 175 | nn.Dense(units=dim, use_bias=False) 176 | ]) 177 | 178 | def call(self, x, training=True): 179 | return self.net(x, training=training) 180 | 181 | 182 | class Attention(Layer): 183 | def __init__(self, dim, dim_head=64, heads=8, dropout=0.0): 184 | super(Attention, self).__init__() 185 | self.heads = heads 186 | self.scale = dim_head ** -0.5 187 | inner_dim = dim_head * heads 188 | 189 | 190 | self.to_qkv = nn.Dense(units=inner_dim * 3, use_bias=False) 191 | self.to_out = nn.Dense(units=dim, use_bias=False) 192 | self.dropout = nn.Dropout(rate=dropout) 193 | 194 | def call(self, x, mask=None, rotary_pos_emb=None, training=True): 195 | h = self.heads 196 | qkv = self.to_qkv(x) 197 | qkv = tf.split(qkv, num_or_size_splits=3, axis=-1) 198 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv) 199 | 200 | if exists(rotary_pos_emb): 201 | apply_rotary = partial(apply_rotary_pos_emb, rotary_pos_emb) 202 | q, k, v = map(apply_rotary, (q, k, v)) 203 | 204 | sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 205 | 206 | if exists(mask): 207 | mask = rearrange(mask, 'b j -> b 1 1 j') 208 | mask_value = max_neg_value(sim.dtype) 209 | sim = masked_fill(sim, ~mask, mask_value) 210 | 211 | attn = tf.nn.softmax(sim, axis=-1) 212 | attn = self.dropout(attn, training=training) 213 | 214 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 215 | out = rearrange(out, 'b h n d -> b n (h d)') 216 | out = self.to_out(out) 217 | 218 | return out 219 | 220 | class Transformer(Layer): 221 | def __init__(self, dim, depth, dim_head=64, heads=8, attn_dropout=0.0, ff_mult=4): 222 | super(Transformer, self).__init__() 223 | 224 | self.layers = [] 225 | 226 | for _ in range(depth): 227 | self.layers.append([ 228 | PreNorm(dim, Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout)), 229 | PreNorm(dim, MLP(dim=dim, mult=ff_mult)) 230 | ]) 231 | 232 | self.norm_out = LayerNorm(dim) 233 | 234 | def call(self, x, rotary_pos_emb=None, mask=None, training=True): 235 | for attn, ff in self.layers: 236 | x = attn(x, mask=mask, rotary_pos_emb=rotary_pos_emb, training=training) + x 237 | x = ff(x, training=training) + x 238 | 239 | x = self.norm_out(x) 240 | return x 241 | 242 | # text and vision transformers 243 | class TextTransformer(Layer): 244 | def __init__(self, dim, num_tokens, max_seq_len, dim_head, rotary_pos_emb=None, **kwargs): 245 | super(TextTransformer, self).__init__() 246 | 247 | self.token_emb = nn.Embedding(num_tokens, dim) 248 | 249 | self.abs_pos_emb = nn.Embedding(max_seq_len, 
dim) if not rotary_pos_emb else None 250 | self.rotary_pos_emb = RotaryEmbedding(min(dim_head, 32)) if rotary_pos_emb else None 251 | 252 | self.cls_token = tf.Variable(tf.random.normal(shape=[dim])) 253 | 254 | self.transformer = Transformer(dim, dim_head=dim_head, **kwargs) 255 | 256 | def call(self, x, mask=None, training=True): 257 | b, n = x.shape 258 | 259 | x = self.token_emb(x) 260 | 261 | if exists(self.abs_pos_emb): 262 | pos_emb = self.abs_pos_emb(tf.range(n)) 263 | x = x + rearrange(pos_emb, 'n d -> 1 n d') 264 | 265 | rotary_pos_emb = None 266 | if exists(self.rotary_pos_emb): 267 | rotary_pos_emb = self.rotary_pos_emb(n + 1) 268 | 269 | cls_tokens = repeat(self.cls_token, 'd -> b 1 d', b=b) 270 | x = tf.concat([cls_tokens, x], axis=1) 271 | 272 | if exists(mask): 273 | mask = tf.pad(mask, paddings=[[0,0], [1,0]], constant_values=True) 274 | 275 | out = self.transformer(x, mask=mask, rotary_pos_emb=rotary_pos_emb, training=training) 276 | return out 277 | 278 | class VisionTransformer(Layer): 279 | def __init__(self, dim, image_size, patch_size, **kwargs): 280 | super(VisionTransformer, self).__init__() 281 | 282 | assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.' 283 | num_patches = (image_size // patch_size) ** 2 284 | 285 | self.cls_token = tf.Variable(tf.random.normal(shape=[dim])) 286 | 287 | self.to_tokens = Sequential([ 288 | Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size), 289 | nn.Dense(units=dim) 290 | ]) 291 | 292 | self.pos_emb = nn.Embedding(num_patches, dim) 293 | self.transformer = Transformer(dim, **kwargs) 294 | 295 | def call(self, x, training=True): 296 | 297 | x = self.to_tokens(x) 298 | b, n, _ = x.shape 299 | 300 | pos_emb = self.pos_emb(tf.range(n)) 301 | x = x + rearrange(pos_emb, 'n d -> 1 n d') 302 | 303 | cls_tokens = repeat(self.cls_token, 'd -> b 1 d', b=b) 304 | x = tf.concat([cls_tokens, x], axis=1) 305 | 306 | out = self.transformer(x, training=training) 307 | return out 308 | 309 | # contrastive learning functions 310 | def model_forward_with_context(fn, args, freeze): 311 | enc = fn(*args) 312 | if freeze: 313 | enc = tf.stop_gradient(enc) 314 | 315 | return enc 316 | 317 | # contrastive loss function, adapted from 318 | # https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html 319 | def contrastive_loss(logits) : 320 | return tf.math.reduce_mean( 321 | tf.keras.metrics.sparse_categorical_crossentropy( 322 | y_true=tf.range(logits.shape[0]), y_pred=logits, from_logits=True 323 | ) 324 | ) 325 | 326 | def clip_loss(text_embeds, image_embeds, logit_scale) : 327 | # normalized features 328 | image_embeds = image_embeds / tf.norm(tensor=image_embeds, ord="euclidean", axis=-1, keepdims=True) 329 | text_embeds = text_embeds / tf.norm(tensor=text_embeds, ord="euclidean", axis=-1, keepdims=True) 330 | 331 | # cosine similarity as logits 332 | logit_scale = tf.math.exp(logit_scale) 333 | logits_per_text = tf.matmul(text_embeds, image_embeds, transpose_b=True) * logit_scale 334 | logits_per_image = tf.transpose(logits_per_text) 335 | similarity = logits_per_text 336 | 337 | caption_loss = contrastive_loss(similarity) 338 | image_loss = contrastive_loss(tf.transpose(similarity)) 339 | return (caption_loss + image_loss) / 2.0 340 | 341 | # https://github.com/lucidrains/x-clip 342 | def lucidrains_loss(text_latents, image_latents, temperature): 343 | # equal to clip_loss 344 | num_batch_texts = num_batch_images = 1 345 | text_latents, 
image_latents = map(l2norm, (text_latents, image_latents)) 346 | 347 | # get temperature 348 | temp = tf.exp(temperature) 349 | 350 | # split out multiview dimension for text and images 351 | text_latents = rearrange(text_latents, '(m b) ... -> m b ...', m=num_batch_texts) 352 | image_latents = rearrange(image_latents, '(m b) ... -> m b ...', m=num_batch_images) 353 | 354 | # calculate loss 355 | text_to_image = einsum('m t d, n i d -> m n t i', text_latents, image_latents) * temp 356 | image_to_text = rearrange(text_to_image, '... t i -> ... i t') 357 | 358 | text_to_image = rearrange(text_to_image, 'm n ... -> (m n) ...') 359 | image_to_text = rearrange(image_to_text, 'm n ... -> (m n) ...') 360 | 361 | # exponentiate 362 | text_to_image_exp, image_to_text_exp = map(tf.exp, (text_to_image, image_to_text)) 363 | 364 | # numerators 365 | text_to_image_pos, image_to_text_pos = map(matrix_diag, (text_to_image_exp, image_to_text_exp)) 366 | 367 | # denominator 368 | text_to_image_denom, image_to_text_denom = map(lambda t: tf.reduce_sum(t, axis=-1), 369 | (text_to_image_exp, image_to_text_exp)) 370 | 371 | # loss 372 | text_to_image_loss = tf.reduce_mean(-log(text_to_image_pos / text_to_image_denom), axis=-1) 373 | image_to_text_loss = tf.reduce_mean(-log(image_to_text_pos / image_to_text_denom), axis=-1) 374 | 375 | # calculate CL loss 376 | cl_loss = (text_to_image_loss + image_to_text_loss) / 2 377 | 378 | return cl_loss 379 | 380 | # main clip class 381 | class CLIP(Model): 382 | def __init__(self, 383 | image_encoder=None, 384 | text_encoder=None, 385 | dim_text=512, 386 | dim_image=512, 387 | dim_latent=512, 388 | num_text_tokens=10000, 389 | text_enc_depth=6, 390 | text_seq_len=256, 391 | text_heads=8, 392 | text_dim_head=64, 393 | text_has_cls_token=True, 394 | text_pad_id=0, 395 | text_rotary_pos_emb=False, 396 | visual_enc_depth=6, 397 | visual_heads=8, 398 | visual_dim_head=64, 399 | visual_image_size=256, 400 | visual_patch_size=32, 401 | visual_has_cls_token=True 402 | ): 403 | super(CLIP, self).__init__() 404 | assert (visual_has_cls_token or text_has_cls_token), 'CLS token must be included on both vision and text transformers if you are not using fine-grained contrastive learning loss' 405 | # store some parameters for access 406 | self.dim_text = dim_text 407 | self.dim_image = dim_image 408 | self.dim_latent = dim_latent 409 | 410 | # instantiate text transformer 411 | self.text_pad_id = text_pad_id 412 | self.text_has_cls_token = text_has_cls_token 413 | 414 | if exists(text_encoder): 415 | self.text_transformer = text_encoder 416 | else: 417 | self.text_transformer = TextTransformer( 418 | dim=dim_text, 419 | num_tokens=num_text_tokens, 420 | max_seq_len=text_seq_len, 421 | depth=text_enc_depth, 422 | heads=text_heads, 423 | dim_head=text_dim_head, 424 | rotary_pos_emb=text_rotary_pos_emb 425 | ) 426 | 427 | # instantiate image transformer 428 | self.visual_has_cls_token = visual_has_cls_token 429 | 430 | if exists(image_encoder): 431 | self.visual_transformer = image_encoder 432 | else: 433 | self.visual_transformer = VisionTransformer( 434 | dim=dim_image, 435 | image_size=visual_image_size, 436 | patch_size=visual_patch_size, 437 | depth=visual_enc_depth, 438 | heads=visual_heads, 439 | dim_head=visual_dim_head 440 | ) 441 | 442 | # text latent projection 443 | self.to_text_latent = nn.Dense(units=dim_latent, use_bias=False) 444 | 445 | # image latent projection 446 | self.to_visual_latent = nn.Dense(units=dim_latent, use_bias=False) 447 | 448 | # temperature 449 | 
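# note: `temperature` is learned in log space; both clip_loss and lucidrains_loss
# exponentiate it (exp(temperature)) before scaling the logits, so the 1.0 initialization
# below corresponds to an effective logit scale of exp(1.0) ~ 2.72. For comparison, OpenAI's
# CLIP initializes the analogous parameter to log(1/0.07) ~ 2.66, i.e. a scale of ~14.3.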
self.temperature = tf.Variable(tf.constant(1.0, dtype=tf.float32))
450 | 
451 |     def call(self, text, image=None, training=True,
452 |              return_loss=False,
453 |              return_encodings=False,
454 |              freeze_image_encoder=False,  # image encoder is not trained if this is set to True, as proposed by the LiT paper
455 |              freeze_text_encoder=False,   # text encoder is not trained if this is set to True
456 |              **kwargs
457 |              ):
458 | 
459 |         # derive text mask
460 |         text_mask = text != self.text_pad_id
461 | 
462 | 
463 |         assert not (return_loss and not training), 'loss cannot be used if not training'
464 | 
465 |         # get encoded text
466 |         enc_text = model_forward_with_context(
467 |             fn=self.text_transformer,
468 |             args=(text, text_mask, training),
469 |             freeze=freeze_text_encoder
470 |         )
471 | 
472 |         # get encoded image; the encoder can be frozen when a pretrained image net is used, as recommended in LiT
473 |         enc_image = model_forward_with_context(
474 |             fn=self.visual_transformer,
475 |             args=(image, training),
476 |             freeze=freeze_image_encoder
477 |         )
478 | 
479 |         # early return of encodings, if needed (e.g. for DALL-E 2)
480 |         if return_encodings:
481 |             return enc_text, enc_image
482 | 
483 |         # select the CLS token embeddings (fine-grained, all-token contrast is not used here)
484 |         text_embeds = enc_text[:, 0]
485 |         image_embeds = enc_image[:, 0]
486 | 
487 |         # project to latents
488 |         text_latents = self.to_text_latent(text_embeds)
489 |         image_latents = self.to_visual_latent(image_embeds)
490 | 
491 |         # calculate loss
492 |         # cl_loss = lucidrains_loss(text_latents, image_latents, self.temperature)
493 |         cl_loss = clip_loss(text_latents, image_latents, self.temperature)
494 | 
495 |         # loss weight (kept at 1, since only the contrastive loss is used)
496 |         cl_loss_weight = 1
497 | 
498 |         loss = cl_loss * cl_loss_weight
499 | 
500 |         return loss
501 | 
502 | 
503 | 
--------------------------------------------------------------------------------
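A note on inference with the `CLIP` class above: `call` computes the contrastive loss unless `return_encodings=True`, so image-text similarity scores can be obtained by reusing the encoder outputs together with the latent projections. The sketch below is illustrative only: the weights are untrained, the inputs are mock tensors, and the configuration simply mirrors the README usage example.

```python
import tensorflow as tf
from clip import CLIP, l2norm

# same configuration as the README usage example (weights are untrained here,
# so the similarity scores are meaningless placeholders)
clip_model = CLIP(
    dim_text=512, dim_image=512, dim_latent=512, num_text_tokens=10000,
    text_enc_depth=6, text_seq_len=256, text_heads=8,
    visual_enc_depth=6, visual_image_size=256, visual_patch_size=32, visual_heads=8,
)

text = tf.random.uniform([3, 256], minval=0, maxval=10000, dtype=tf.int32)  # 3 candidate captions
images = tf.random.normal([1, 256, 256, 3])                                 # 1 query image

# raw encoder outputs (CLS token first); return_encodings=True skips the loss
enc_text, enc_image = clip_model(text, images, training=False, return_encodings=True)

# take the CLS outputs and project them to the shared latent space (as CLIP.call does),
# then l2-normalize (clip_loss performs this normalization internally before the dot product)
text_latents = l2norm(clip_model.to_text_latent(enc_text[:, 0]))
image_latents = l2norm(clip_model.to_visual_latent(enc_image[:, 0]))

# cosine similarity between the image and each caption, shape [1, 3]
sims = tf.matmul(image_latents, text_latents, transpose_b=True)
best_caption = tf.argmax(sims, axis=-1)
```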
--------------------------------------------------------------------------------
/tokenizer.py:
--------------------------------------------------------------------------------
1 | # taken from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py
2 | # to give users a quick, easy start to training CLIP without doing BPE themselves
3 | 
4 | import html
5 | import os
6 | from functools import lru_cache
7 | from pathlib import Path
8 | import ftfy
9 | import regex as re
10 | import tensorflow as tf
11 | import numpy as np
12 | 
13 | # OpenAI simple tokenizer
14 | 
15 | @lru_cache()
16 | def default_bpe():
17 |     return os.path.join(os.path.dirname(os.path.abspath(__file__)), "data/bpe_simple_vocab_16e6.txt")
18 | 
19 | @lru_cache()
20 | def bytes_to_unicode():
21 |     bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
22 |     cs = bs[:]
23 |     n = 0
24 |     for b in range(2 ** 8):
25 |         if b not in bs:
26 |             bs.append(b)
27 |             cs.append(2 ** 8 + n)
28 |             n += 1
29 |     cs = [chr(n) for n in cs]
30 |     return dict(zip(bs, cs))
31 | 
32 | def get_pairs(word):
33 |     pairs = set()
34 |     prev_char = word[0]
35 |     for char in word[1:]:
36 |         pairs.add((prev_char, char))
37 |         prev_char = char
38 |     return pairs
39 | 
40 | def basic_clean(text):
41 |     text = ftfy.fix_text(text)
42 |     text = html.unescape(html.unescape(text))
43 |     return text.strip()
44 | 
45 | def whitespace_clean(text):
46 |     text = re.sub(r'\s+', ' ', text)
47 |     text = text.strip()
48 |     return text
49 | 
50 | class SimpleTokenizer():
51 |     def __init__(self, bpe_path = default_bpe()):
52 |         self.byte_encoder = bytes_to_unicode()
53 |         self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
54 |         merges = Path(bpe_path).read_text(encoding='utf8').split('\n')
55 |         merges = merges[1:49152 - 256 - 2 + 1]
56 |         merges = [tuple(merge.split()) for merge in merges]
57 |         vocab = list(bytes_to_unicode().values())
58 |         vocab = vocab + [v + '</w>' for v in vocab]
59 |         for merge in merges:
60 |             vocab.append(''.join(merge))
61 |         vocab.extend(['<|startoftext|>', '<|endoftext|>'])
62 | 
63 |         self.vocab_size = 49408
64 | 
65 |         self.encoder = dict(zip(vocab, range(len(vocab))))
66 |         self.decoder = {v: k for k, v in self.encoder.items()}
67 |         self.bpe_ranks = dict(zip(merges, range(len(merges))))
68 |         self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
69 |         self.pat = re.compile(
70 |             r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
71 |             re.IGNORECASE)
72 | 
73 |     def bpe(self, token):
74 |         if token in self.cache:
75 |             return self.cache[token]
76 |         word = tuple(token[:-1]) + (token[-1] + '</w>',)
77 |         pairs = get_pairs(word)
78 | 
79 |         if not pairs:
80 |             return token + '</w>'
81 | 
82 |         while True:
83 |             bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
84 |             if bigram not in self.bpe_ranks:
85 |                 break
86 |             first, second = bigram
87 |             new_word = []
88 |             i = 0
89 |             while i < len(word):
90 |                 try:
91 |                     j = word.index(first, i)
92 |                     new_word.extend(word[i:j])
93 |                     i = j
94 |                 except ValueError:  # `first` no longer occurs in the remainder of the word
95 |                     new_word.extend(word[i:])
96 |                     break
97 | 
98 |                 if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
99 |                     new_word.append(first + second)
100 |                     i += 2
101 |                 else:
102 |                     new_word.append(word[i])
103 |                     i += 1
104 |             new_word = tuple(new_word)
105 |             word = new_word
106 |             if len(word) == 1:
107 |                 break
108 |             else:
109 |                 pairs = get_pairs(word)
110 |         word = ' '.join(word)
111 |         self.cache[token] = word
112 |         return word
113 | 
114 |     def encode(self, text):
115 |         bpe_tokens = []
116 |         text = whitespace_clean(basic_clean(text)).lower()
117 |         for token in re.findall(self.pat, text):
118 |             token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
119 |             bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
120 |         return bpe_tokens
121 | 
122 |     def decode(self, tokens, remove_start_end = True, pad_tokens = {}):
123 |         if tf.is_tensor(tokens):
124 |             tokens = tokens.numpy().tolist()
125 | 
126 |         if remove_start_end:
127 |             tokens = [token for token in tokens if token not in (49406, 49407, 0)]
128 |         text = ''.join([self.decoder[token] for token in tokens if token not in pad_tokens])
129 |         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
130 |         return text
131 | 
132 |     def tokenize(self, texts, context_length = 256, truncate_text = False):
133 |         if tf.is_tensor(texts):
134 |             texts = texts.numpy().tolist()
135 |             t_list = []
136 |             for t in texts:
137 |                 t_list.append(t.decode('utf-8'))
138 | 
139 |             texts = t_list
140 | 
141 |         if isinstance(texts, str):
142 |             texts = [texts]
143 | 
144 |         all_tokens = [self.encode(text) for text in texts]
145 |         result = np.zeros([len(all_tokens), context_length], dtype=np.int64)  # np.long was removed in newer NumPy
146 | 
147 |         for i, tokens in enumerate(all_tokens):
148 |             if len(tokens) > context_length:
149 |                 if truncate_text:
150 |                     tokens = tokens[:context_length]
151 |                 else:
152 |                     raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
153 |             result[i, :len(tokens)] = tf.convert_to_tensor(tokens, dtype=tf.int32)
154 |         return result
155 | 
--------------------------------------------------------------------------------
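Finally, a minimal end-to-end training-step sketch that ties `SimpleTokenizer` and `CLIP` together. Everything below is an assumption made for illustration: the two example captions, the random stand-in images, and the Adam optimizer with a 1e-4 learning rate are not prescribed anywhere in this repository.

```python
import tensorflow as tf
from clip import CLIP
from tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer()  # reads data/bpe_simple_vocab_16e6.txt
clip_model = CLIP(
    num_text_tokens=tokenizer.vocab_size,  # 49408 for the bundled BPE vocab
    text_seq_len=256,
    visual_image_size=256,
    visual_patch_size=32,
)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

captions = ["a photo of a cat", "a photo of a dog"]
text = tf.convert_to_tensor(tokenizer.tokenize(captions, context_length=256))
images = tf.random.normal([len(captions), 256, 256, 3])  # stand-in for a real image batch

# one contrastive training step
with tf.GradientTape() as tape:
    loss = clip_model(text, images, return_loss=True, training=True)

grads = tape.gradient(loss, clip_model.trainable_variables)
optimizer.apply_gradients(zip(grads, clip_model.trainable_variables))
```

In a real setup this step would typically be wrapped in a `tf.function` and driven by a `tf.data` pipeline of paired images and captions.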