├── LICENSE
├── README.md
├── assets
│   └── clip.png
├── clip.py
├── data
│   └── bpe_simple_vocab_16e6.txt
└── tokenizer.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Junho Kim
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## CLIP — Simple TensorFlow Implementation [[Link]](https://openai.com/blog/clip/)
2 |
3 |
4 | ![CLIP](./assets/clip.png)
5 |
6 |
7 | ## Usage
8 |
9 | ```python
10 | import tensorflow as tf
11 | from clip import CLIP
12 |
13 | clip_model = CLIP(
14 |     dim_text = 512,
15 |     dim_image = 512,
16 |     dim_latent = 512,
17 |     num_text_tokens = 10000,
18 |     text_enc_depth = 6,
19 |     text_seq_len = 256,
20 |     text_heads = 8,
21 |     visual_enc_depth = 6,
22 |     visual_image_size = 256,
23 |     visual_patch_size = 32,
24 |     visual_heads = 8,
25 | )
26 |
27 | # mock data
28 | text = tf.random.uniform([4, 256], minval=0, maxval=10000, dtype=tf.int32)
29 | images = tf.random.normal([4, 256, 256, 3])
30 |
31 | # train
32 | loss = clip_model(
33 |     text,
34 |     images,
35 |     freeze_image_encoder = False,  # whether to freeze the image encoder when using a pretrained image model, as proposed in the LiT paper
36 |     return_loss = True             # must be True to return the contrastive loss
37 | )
38 | ```
39 |
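The snippet below is a minimal sketch of a single training step on the mock batch above; the Adam optimizer and learning rate are illustrative assumptions and not part of this repository.

```python
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

# one eager training step: compute the contrastive loss and update all weights
with tf.GradientTape() as tape:
    loss = clip_model(text, images, return_loss=True)

grads = tape.gradient(loss, clip_model.trainable_variables)
optimizer.apply_gradients(zip(grads, clip_model.trainable_variables))
```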
40 | ## Reference
41 | * [x-clip](https://github.com/lucidrains/x-clip)
42 |
43 | ## Author
44 | * [Junho Kim](http://bit.ly/jhkim_resume)
45 |
--------------------------------------------------------------------------------
/assets/clip.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/CLIP-Tensorflow/941d18e7f4ae8a10293733c29fb2dadbf0714848/assets/clip.png
--------------------------------------------------------------------------------
/clip.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.keras import Model
3 | from tensorflow.keras.layers import Layer
4 | from tensorflow.keras import Sequential
5 | import tensorflow.keras.layers as nn
6 | from tensorflow import einsum
7 |
8 |
9 | from contextlib import contextmanager
10 | from functools import partial
11 | from einops import rearrange, repeat
12 | from einops.layers.tensorflow import Rearrange
13 |
14 | import numpy as np
15 |
16 | # helper functions
17 | def exists(val):
18 | return val is not None
19 |
20 | def default(val, d):
21 | return val if exists(val) else d
22 |
23 | @contextmanager
24 | def null_context():
25 | yield
26 |
27 | def max_neg_value(dtype):
28 | return -np.finfo(dtype.as_numpy_dtype).max
29 |
30 | def cast_tuple(t):
31 | return t if isinstance(t, (tuple, list)) else (t,)
32 |
33 | def masked_mean(t, mask, dim = 1, eps = 1e-6):
34 | t = masked_fill(t, ~mask, 0.0)
35 | numer = tf.reduce_sum(t, axis=dim)
36 |
37 |     denom = tf.reduce_sum(tf.cast(mask, t.dtype), axis=dim)  # mask is boolean, cast before summing
38 |     denom = tf.clip_by_value(denom, clip_value_min=eps, clip_value_max=tf.reduce_max(denom))
39 |
40 |     return numer / denom
41 |
42 | def log(t, eps = 1e-20):
43 | return tf.math.log(t + eps)
44 |
45 | def l2norm(t):
46 | return tf.math.l2_normalize(t, axis=-1)
47 |
48 | def masked_select(x, mask):
49 | x = tf.cast(x, tf.float32)
50 | mask = tf.cast(mask, tf.int32)
51 |
52 | x = tf.reshape(x, [-1])
53 | mask = tf.reshape(mask, [-1])
54 | mask_true_idx = tf.where(mask)
55 |
56 | return tf.gather_nd(x, mask_true_idx)
57 |
58 | def masked_fill(x, mask, true_val):
59 | x = tf.where(mask, true_val, x)
60 | return x
61 |
62 | def matrix_diag(t):
63 | # t.shape = [1,4,4]
64 | i, j = t.shape[-2:]
65 | num_diag_el = min(i, j)
66 | i_range = tf.range(i)
67 | j_range = tf.range(j)
68 | diag_mask = rearrange(i_range, 'i -> i 1') == rearrange(j_range, 'j -> 1 j') # [4,4]
69 |
70 | diag_el = masked_select(t, diag_mask) # [4]
71 |
72 | return rearrange(diag_el, '(b d) -> b d', d = num_diag_el)
73 |
74 | # keyword argument helpers
75 | def pick_and_pop(keys, d):
76 | values = list(map(lambda key: d.pop(key), keys))
77 | return dict(zip(keys, values))
78 |
79 | def group_dict_by_key(cond, d):
80 | return_val = [dict(),dict()]
81 | for key in d.keys():
82 | match = bool(cond(key))
83 | ind = int(not match)
84 | return_val[ind][key] = d[key]
85 | return (*return_val,)
86 |
87 | def string_begins_with(prefix, str):
88 | return str.startswith(prefix)
89 |
90 | def group_by_key_prefix(prefix, d):
91 | return group_dict_by_key(partial(string_begins_with, prefix), d)
92 |
93 | def groupby_prefix_and_trim(prefix, d):
94 | kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
95 | kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
96 | return kwargs_without_prefix, kwargs
97 |
98 | # helper classes
99 | class LayerNorm(Layer):
100 | # bias-less layernorm
101 | def __init__(self, dim, eps=1e-5):
102 | super(LayerNorm, self).__init__()
103 | self.eps = eps
104 |
105 | self.g = tf.Variable(tf.ones([dim]))
106 |
107 | def call(self, x, training=True):
108 | var = tf.math.reduce_variance(x, axis=-1, keepdims=True)
109 | mean = tf.reduce_mean(x, axis=-1, keepdims=True)
110 |
111 | x = (x - mean) / tf.sqrt((var + self.eps)) * self.g
112 | return x
113 |
114 | class PreNorm(Layer):
115 | def __init__(self, dim, fn):
116 | super(PreNorm, self).__init__()
117 |
118 | self.norm = LayerNorm(dim)
119 | self.fn = fn
120 |
121 | def call(self, x, **kwargs):
122 | return self.fn(self.norm(x), **kwargs)
123 |
124 | # rotary positional embedding
125 | class RotaryEmbedding(Layer):
126 | def __init__(self, dim):
127 | super(RotaryEmbedding, self).__init__()
128 | self.inv_freq = 1.0 / (10000 ** (tf.range(0, dim, 2, dtype=tf.float32) / dim))
129 |
130 | def call(self, seq_len, training=True):
131 | inv_freq = self.inv_freq
132 | t = tf.cast(tf.range(seq_len), dtype=inv_freq.dtype)
133 | freqs = einsum('i , j -> i j', t, inv_freq)
134 |
135 | x = tf.concat([freqs, freqs], axis=-1)
136 | return x
137 |
138 | def rotate_half(x):
139 | x = rearrange(x, '... (j d) -> ... j d', j = 2)
140 | x1, x2 = tf.unstack(x, axis=-2)
141 |
142 | x = tf.concat([-x2, x1], axis=-1)
143 | return x
144 |
145 | def apply_rotary_pos_emb(freqs, t):
146 | rot_dim = freqs.shape[-1]
147 | t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
148 | t = (t * tf.math.cos(freqs)) + (rotate_half(t) * tf.math.sin(freqs))
149 |
150 | x = tf.concat([t, t_pass], axis=-1)
151 | return x
152 |
153 | # transformer
154 | class SwiGLU(Layer):
155 | def __init__(self):
156 | super(SwiGLU, self).__init__()
157 |
158 | def silu(self, x):
159 | return x * tf.sigmoid(x)
160 |
161 | def call(self, x, training=True):
162 | x, gates = tf.split(x, num_or_size_splits=2, axis=-1)
163 | return x * self.silu(gates)
164 |
165 | class MLP(Layer):
166 | def __init__(self, dim, mult=4, dropout=0.0):
167 | super(MLP, self).__init__()
168 | inner_dim = int(dim * mult)
169 |
170 |
171 | self.net = Sequential([
172 | nn.Dense(units=inner_dim * 2, use_bias=False),
173 | SwiGLU(),
174 | nn.Dropout(rate=dropout),
175 | nn.Dense(units=dim, use_bias=False)
176 | ])
177 |
178 | def call(self, x, training=True):
179 | return self.net(x, training=training)
180 |
181 |
182 | class Attention(Layer):
183 | def __init__(self, dim, dim_head=64, heads=8, dropout=0.0):
184 | super(Attention, self).__init__()
185 | self.heads = heads
186 | self.scale = dim_head ** -0.5
187 | inner_dim = dim_head * heads
188 |
189 |
190 | self.to_qkv = nn.Dense(units=inner_dim * 3, use_bias=False)
191 | self.to_out = nn.Dense(units=dim, use_bias=False)
192 | self.dropout = nn.Dropout(rate=dropout)
193 |
194 | def call(self, x, mask=None, rotary_pos_emb=None, training=True):
195 | h = self.heads
196 | qkv = self.to_qkv(x)
197 | qkv = tf.split(qkv, num_or_size_splits=3, axis=-1)
198 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)
199 |
200 | if exists(rotary_pos_emb):
201 | apply_rotary = partial(apply_rotary_pos_emb, rotary_pos_emb)
202 | q, k, v = map(apply_rotary, (q, k, v))
203 |
204 | sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
205 |
206 | if exists(mask):
207 | mask = rearrange(mask, 'b j -> b 1 1 j')
208 | mask_value = max_neg_value(sim.dtype)
209 | sim = masked_fill(sim, ~mask, mask_value)
210 |
211 | attn = tf.nn.softmax(sim, axis=-1)
212 | attn = self.dropout(attn, training=training)
213 |
214 | out = einsum('b h i j, b h j d -> b h i d', attn, v)
215 | out = rearrange(out, 'b h n d -> b n (h d)')
216 | out = self.to_out(out)
217 |
218 | return out
219 |
220 | class Transformer(Layer):
221 | def __init__(self, dim, depth, dim_head=64, heads=8, attn_dropout=0.0, ff_mult=4):
222 | super(Transformer, self).__init__()
223 |
224 | self.layers = []
225 |
226 | for _ in range(depth):
227 | self.layers.append([
228 | PreNorm(dim, Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout)),
229 | PreNorm(dim, MLP(dim=dim, mult=ff_mult))
230 | ])
231 |
232 | self.norm_out = LayerNorm(dim)
233 |
234 | def call(self, x, rotary_pos_emb=None, mask=None, training=True):
235 | for attn, ff in self.layers:
236 | x = attn(x, mask=mask, rotary_pos_emb=rotary_pos_emb, training=training) + x
237 | x = ff(x, training=training) + x
238 |
239 | x = self.norm_out(x)
240 | return x
241 |
242 | # text and vision transformers
243 | class TextTransformer(Layer):
244 | def __init__(self, dim, num_tokens, max_seq_len, dim_head, rotary_pos_emb=None, **kwargs):
245 | super(TextTransformer, self).__init__()
246 |
247 | self.token_emb = nn.Embedding(num_tokens, dim)
248 |
249 | self.abs_pos_emb = nn.Embedding(max_seq_len, dim) if not rotary_pos_emb else None
250 | self.rotary_pos_emb = RotaryEmbedding(min(dim_head, 32)) if rotary_pos_emb else None
251 |
252 | self.cls_token = tf.Variable(tf.random.normal(shape=[dim]))
253 |
254 | self.transformer = Transformer(dim, dim_head=dim_head, **kwargs)
255 |
256 | def call(self, x, mask=None, training=True):
257 | b, n = x.shape
258 |
259 | x = self.token_emb(x)
260 |
261 | if exists(self.abs_pos_emb):
262 | pos_emb = self.abs_pos_emb(tf.range(n))
263 | x = x + rearrange(pos_emb, 'n d -> 1 n d')
264 |
265 | rotary_pos_emb = None
266 | if exists(self.rotary_pos_emb):
267 | rotary_pos_emb = self.rotary_pos_emb(n + 1)
268 |
269 | cls_tokens = repeat(self.cls_token, 'd -> b 1 d', b=b)
270 | x = tf.concat([cls_tokens, x], axis=1)
271 |
272 | if exists(mask):
273 | mask = tf.pad(mask, paddings=[[0,0], [1,0]], constant_values=True)
274 |
275 | out = self.transformer(x, mask=mask, rotary_pos_emb=rotary_pos_emb, training=training)
276 | return out
277 |
278 | class VisionTransformer(Layer):
279 | def __init__(self, dim, image_size, patch_size, **kwargs):
280 | super(VisionTransformer, self).__init__()
281 |
282 | assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
283 | num_patches = (image_size // patch_size) ** 2
284 |
285 | self.cls_token = tf.Variable(tf.random.normal(shape=[dim]))
286 |
287 | self.to_tokens = Sequential([
288 | Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size),
289 | nn.Dense(units=dim)
290 | ])
291 |
292 | self.pos_emb = nn.Embedding(num_patches, dim)
293 | self.transformer = Transformer(dim, **kwargs)
294 |
295 | def call(self, x, training=True):
296 |
297 | x = self.to_tokens(x)
298 | b, n, _ = x.shape
299 |
300 | pos_emb = self.pos_emb(tf.range(n))
301 | x = x + rearrange(pos_emb, 'n d -> 1 n d')
302 |
303 | cls_tokens = repeat(self.cls_token, 'd -> b 1 d', b=b)
304 | x = tf.concat([cls_tokens, x], axis=1)
305 |
306 | out = self.transformer(x, training=training)
307 | return out
308 |
309 | # contrastive learning functions
310 | def model_forward_with_context(fn, args, freeze):
311 | enc = fn(*args)
312 | if freeze:
313 | enc = tf.stop_gradient(enc)
314 |
315 | return enc
316 |
317 | # contrastive loss function, adapted from
318 | # https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
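# each row i of the logits matrix holds the similarity scores of sample i against every
# sample in the batch, so the correct "class" for row i is simply index i; the loss below
# is a sparse cross-entropy against tf.range(batch_size), and clip_loss applies it
# symmetrically to the text-to-image and image-to-text similarity matrices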
319 | def contrastive_loss(logits):
320 | return tf.math.reduce_mean(
321 | tf.keras.metrics.sparse_categorical_crossentropy(
322 | y_true=tf.range(logits.shape[0]), y_pred=logits, from_logits=True
323 | )
324 | )
325 |
326 | def clip_loss(text_embeds, image_embeds, logit_scale):
327 | # normalized features
328 | image_embeds = image_embeds / tf.norm(tensor=image_embeds, ord="euclidean", axis=-1, keepdims=True)
329 | text_embeds = text_embeds / tf.norm(tensor=text_embeds, ord="euclidean", axis=-1, keepdims=True)
330 |
331 | # cosine similarity as logits
332 | logit_scale = tf.math.exp(logit_scale)
333 | logits_per_text = tf.matmul(text_embeds, image_embeds, transpose_b=True) * logit_scale
334 | logits_per_image = tf.transpose(logits_per_text)
335 | similarity = logits_per_text
336 |
337 | caption_loss = contrastive_loss(similarity)
338 | image_loss = contrastive_loss(tf.transpose(similarity))
339 | return (caption_loss + image_loss) / 2.0
340 |
341 | # https://github.com/lucidrains/x-clip
342 | def lucidrains_loss(text_latents, image_latents, temperature):
343 | # equal to clip_loss
344 | num_batch_texts = num_batch_images = 1
345 | text_latents, image_latents = map(l2norm, (text_latents, image_latents))
346 |
347 | # get temperature
348 | temp = tf.exp(temperature)
349 |
350 | # split out multiview dimension for text and images
351 | text_latents = rearrange(text_latents, '(m b) ... -> m b ...', m=num_batch_texts)
352 | image_latents = rearrange(image_latents, '(m b) ... -> m b ...', m=num_batch_images)
353 |
354 | # calculate loss
355 | text_to_image = einsum('m t d, n i d -> m n t i', text_latents, image_latents) * temp
356 | image_to_text = rearrange(text_to_image, '... t i -> ... i t')
357 |
358 | text_to_image = rearrange(text_to_image, 'm n ... -> (m n) ...')
359 | image_to_text = rearrange(image_to_text, 'm n ... -> (m n) ...')
360 |
361 | # exponentiate
362 | text_to_image_exp, image_to_text_exp = map(tf.exp, (text_to_image, image_to_text))
363 |
364 | # numerators
365 | text_to_image_pos, image_to_text_pos = map(matrix_diag, (text_to_image_exp, image_to_text_exp))
366 |
367 | # denominator
368 | text_to_image_denom, image_to_text_denom = map(lambda t: tf.reduce_sum(t, axis=-1),
369 | (text_to_image_exp, image_to_text_exp))
370 |
371 | # loss
372 | text_to_image_loss = tf.reduce_mean(-log(text_to_image_pos / text_to_image_denom), axis=-1)
373 | image_to_text_loss = tf.reduce_mean(-log(image_to_text_pos / image_to_text_denom), axis=-1)
374 |
375 | # calculate CL loss
376 | cl_loss = (text_to_image_loss + image_to_text_loss) / 2
377 |
378 | return cl_loss
379 |
380 | # main clip class
381 | class CLIP(Model):
382 | def __init__(self,
383 | image_encoder=None,
384 | text_encoder=None,
385 | dim_text=512,
386 | dim_image=512,
387 | dim_latent=512,
388 | num_text_tokens=10000,
389 | text_enc_depth=6,
390 | text_seq_len=256,
391 | text_heads=8,
392 | text_dim_head=64,
393 | text_has_cls_token=True,
394 | text_pad_id=0,
395 | text_rotary_pos_emb=False,
396 | visual_enc_depth=6,
397 | visual_heads=8,
398 | visual_dim_head=64,
399 | visual_image_size=256,
400 | visual_patch_size=32,
401 | visual_has_cls_token=True
402 | ):
403 | super(CLIP, self).__init__()
404 | assert (visual_has_cls_token or text_has_cls_token), 'CLS token must be included on both vision and text transformers if you are not using fine-grained contrastive learning loss'
405 | # store some parameters for access
406 | self.dim_text = dim_text
407 | self.dim_image = dim_image
408 | self.dim_latent = dim_latent
409 |
410 | # instantiate text transformer
411 | self.text_pad_id = text_pad_id
412 | self.text_has_cls_token = text_has_cls_token
413 |
414 | if exists(text_encoder):
415 | self.text_transformer = text_encoder
416 | else:
417 | self.text_transformer = TextTransformer(
418 | dim=dim_text,
419 | num_tokens=num_text_tokens,
420 | max_seq_len=text_seq_len,
421 | depth=text_enc_depth,
422 | heads=text_heads,
423 | dim_head=text_dim_head,
424 | rotary_pos_emb=text_rotary_pos_emb
425 | )
426 |
427 | # instantiate image transformer
428 | self.visual_has_cls_token = visual_has_cls_token
429 |
430 | if exists(image_encoder):
431 | self.visual_transformer = image_encoder
432 | else:
433 | self.visual_transformer = VisionTransformer(
434 | dim=dim_image,
435 | image_size=visual_image_size,
436 | patch_size=visual_patch_size,
437 | depth=visual_enc_depth,
438 | heads=visual_heads,
439 | dim_head=visual_dim_head
440 | )
441 |
442 | # text latent projection
443 | self.to_text_latent = nn.Dense(units=dim_latent, use_bias=False)
444 |
445 | # image latent projection
446 | self.to_visual_latent = nn.Dense(units=dim_latent, use_bias=False)
447 |
448 | # temperature
449 | self.temperature = tf.Variable(tf.constant(1.0, dtype=tf.float32))
450 |
451 | def call(self, text, image=None, training=True,
452 | return_loss=False,
453 | return_encodings=False,
454 | freeze_image_encoder=False, # image encoder is not trained if this is set to True, proposed by LiT paper
455 | freeze_text_encoder=False, # text encoder is not trained if this is set to True
456 | **kwargs
457 | ):
458 |
459 | # derive text mask
460 | text_mask = text != self.text_pad_id
461 |
462 |
463 | assert not (return_loss and not training), 'loss cannot be used if not training'
464 |
465 | # get encoded text
466 | enc_text = model_forward_with_context(
467 | fn=self.text_transformer,
468 | args=(text, text_mask, training),
469 | freeze=freeze_text_encoder
470 | )
471 |
472 |         # encode images; the image encoder can be frozen when it was pretrained, as recommended by the LiT paper
473 | enc_image = model_forward_with_context(
474 | fn=self.visual_transformer,
475 | args=(image, training),
476 | freeze=freeze_image_encoder
477 | )
478 |
479 | # early return of encodings, if needed (for DALL-E2)
480 | if return_encodings:
481 | return enc_text, enc_image
482 |
483 |         # select the CLS token embedding from each encoder (fine-grained CLIP would keep all tokens instead)
484 | text_embeds = enc_text[:, 0]
485 | image_embeds = enc_image[:, 0]
486 |
487 | # project to latents
488 | text_latents = self.to_text_latent(text_embeds)
489 | image_latents = self.to_visual_latent(image_embeds)
490 |
491 | # calculate loss
492 | # cl_loss = lucidrains_loss(text_latents, image_latents, self.temperature)
493 | cl_loss = clip_loss(text_latents, image_latents, self.temperature)
494 |
495 | # calculate weights
496 | cl_loss_weight = 1
497 |
498 | loss = cl_loss * cl_loss_weight
499 |
500 | return loss
501 |
502 |
503 |
--------------------------------------------------------------------------------
/tokenizer.py:
--------------------------------------------------------------------------------
1 | # taken from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py
2 | # gives users a quick, easy start to training CLIP without building a BPE vocabulary themselves
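# example usage (illustrative sketch; the tokenizer is not wired into clip.py in this repo):
#   tokenizer = SimpleTokenizer()
#   text_ids = tokenizer.tokenize(["a photo of a cat"], context_length=256, truncate_text=True)
# note that the model's `num_text_tokens` must be at least `tokenizer.vocab_size` (49408)
# if these ids are fed to CLIP as its `text` input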
3 |
4 | import html
5 | import os
6 | from functools import lru_cache
7 | from pathlib import Path
8 | import ftfy
9 | import regex as re
10 | import tensorflow as tf
11 | import numpy as np
12 |
13 | # OpenAI simple tokenizer
14 |
15 | @lru_cache()
16 | def default_bpe():
17 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "data/bpe_simple_vocab_16e6.txt")
18 |
19 | @lru_cache()
20 | def bytes_to_unicode():
21 | bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
22 | cs = bs[:]
23 | n = 0
24 | for b in range(2 ** 8):
25 | if b not in bs:
26 | bs.append(b)
27 | cs.append(2 ** 8 + n)
28 | n += 1
29 | cs = [chr(n) for n in cs]
30 | return dict(zip(bs, cs))
31 |
32 | def get_pairs(word):
33 | pairs = set()
34 | prev_char = word[0]
35 | for char in word[1:]:
36 | pairs.add((prev_char, char))
37 | prev_char = char
38 | return pairs
39 |
40 | def basic_clean(text):
41 | text = ftfy.fix_text(text)
42 | text = html.unescape(html.unescape(text))
43 | return text.strip()
44 |
45 | def whitespace_clean(text):
46 | text = re.sub(r'\s+', ' ', text)
47 | text = text.strip()
48 | return text
49 |
50 | class SimpleTokenizer():
51 | def __init__(self, bpe_path = default_bpe()):
52 | self.byte_encoder = bytes_to_unicode()
53 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
54 | merges = Path(bpe_path).read_text(encoding='utf8').split('\n')
55 | merges = merges[1:49152 - 256 - 2 + 1]
56 | merges = [tuple(merge.split()) for merge in merges]
57 | vocab = list(bytes_to_unicode().values())
58 |         vocab = vocab + [v + '</w>' for v in vocab]  # '</w>' marks end-of-word variants
59 | for merge in merges:
60 | vocab.append(''.join(merge))
61 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
62 |
63 | self.vocab_size = 49408
64 |
65 | self.encoder = dict(zip(vocab, range(len(vocab))))
66 | self.decoder = {v: k for k, v in self.encoder.items()}
67 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
68 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
69 | self.pat = re.compile(
70 | r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
71 | re.IGNORECASE)
72 |
73 | def bpe(self, token):
74 | if token in self.cache:
75 | return self.cache[token]
76 |         word = tuple(token[:-1]) + (token[-1] + '</w>',)
77 | pairs = get_pairs(word)
78 |
79 | if not pairs:
80 |             return token + '</w>'
81 |
82 | while True:
83 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
84 | if bigram not in self.bpe_ranks:
85 | break
86 | first, second = bigram
87 | new_word = []
88 | i = 0
89 | while i < len(word):
90 | try:
91 | j = word.index(first, i)
92 | new_word.extend(word[i:j])
93 | i = j
94 |                 except ValueError:  # first no longer occurs in word
95 | new_word.extend(word[i:])
96 | break
97 |
98 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
99 | new_word.append(first + second)
100 | i += 2
101 | else:
102 | new_word.append(word[i])
103 | i += 1
104 | new_word = tuple(new_word)
105 | word = new_word
106 | if len(word) == 1:
107 | break
108 | else:
109 | pairs = get_pairs(word)
110 | word = ' '.join(word)
111 | self.cache[token] = word
112 | return word
113 |
114 | def encode(self, text):
115 | bpe_tokens = []
116 | text = whitespace_clean(basic_clean(text)).lower()
117 | for token in re.findall(self.pat, text):
118 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
119 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
120 | return bpe_tokens
121 |
122 | def decode(self, tokens, remove_start_end = True, pad_tokens = {}):
123 | if tf.is_tensor(tokens):
124 | tokens = tokens.numpy().tolist()
125 |
126 | if remove_start_end:
127 |             tokens = [token for token in tokens if token not in (49406, 49407, 0)]  # startoftext, endoftext, pad
128 | text = ''.join([self.decoder[token] for token in tokens if token not in pad_tokens])
129 |         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
130 | return text
131 |
132 | def tokenize(self, texts, context_length = 256, truncate_text = False):
133 | if tf.is_tensor(texts):
134 | texts = texts.numpy().tolist()
135 | t_list = []
136 | for t in texts:
137 | t_list.append(t.decode('utf-8'))
138 |
139 | texts = t_list
140 |
141 | if isinstance(texts, str):
142 | texts = [texts]
143 |
144 | all_tokens = [self.encode(text) for text in texts]
145 |         result = np.zeros([len(all_tokens), context_length], dtype=np.int64)
146 |
147 | for i, tokens in enumerate(all_tokens):
148 | if len(tokens) > context_length:
149 | if truncate_text:
150 | tokens = tokens[:context_length]
151 | else:
152 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
153 |             result[i, :len(tokens)] = tokens
154 | return result
155 |
--------------------------------------------------------------------------------