correctly # noqa
43 | text = text.replace("\n", "↵") # display newline
44 | if replace_whitespace:
45 | text = text.replace(" ", "␣") # display whitespace
46 | if len(text) == 0: # special token
47 | return "␣(null)"
48 | return text
49 |
50 |
51 | def candidate_tokens_html(topk_candidate_tokens: List[Token]) -> str:
52 | """HTML content of a single token's topk candidate tokens."""
53 | template = '<tr><td class="candidate-token">' \
54 | '{token}</td><td class="candidate-prob">{prob}</td></tr>' # row markup and class names assumed; must match main.css
55 |
56 | html_text = "".join([
57 | template.format(token=text_to_html_token(token.text), prob=f"{token.prob:.3%}")
58 | for token in topk_candidate_tokens
59 | ])
60 | html_text = f'<table class="candidate-list">{html_text}</table>' # wrapper tag/class assumed
61 | return html_text
62 |
63 |
64 | def color_token_by_logprob(
65 | log_prob: float, min_log_prob: float, max_log_prob: float,
66 | hue_red: int = 0, hue_green: int = 150, epsilon: float = 1e-5,
67 | ) -> str:
68 | """According to the token's log prob, assign RGB color to the token.
69 | reference: https://twitter.com/thesephist/status/1617909119423500288
70 |
71 | Args:
72 | log_prob (float): log prob of the token.
73 | min_log_prob (float): min log prob of all tokens.
74 | max_log_prob (float): max log prob of all tokens.
75 | hue_red (int, optional): hue value of the red color. Defaults to 0.
76 | hue_green (int, optional): hue value of the green color. Defaults to 150.
77 | epsilon (float, optional): avoid divide by zero. Defaults to 1e-5.
78 | """
79 | # clamp the log_prob and scale to (hsl_red, hsl_green)
80 | if min_log_prob == max_log_prob:
81 | ratio = 1 # set to green color
82 | else:
83 | log_prob = max(min_log_prob, min(log_prob, max_log_prob))
84 | ratio = (log_prob - min_log_prob) / max((max_log_prob - min_log_prob), epsilon)
85 | hue = ratio * (hue_green - hue_red) + hue_red
86 | red, green, blue = colorsys.hls_to_rgb(hue / 360.0, 0.85, 0.6) # hls({hue}deg 85% 60%)
87 | rgb_string = f"rgb({int(red * 255)}, {int(green * 255)}, {int(blue * 255)})"
88 | return rgb_string
89 |
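# Illustrative example (input values assumed): with min_log_prob=-5.0 and
# max_log_prob=0.0, a token at logprob -5.0 gets ratio 0.0 -> hue 0 (red end),
# a token at -2.5 gets ratio 0.5 -> hue 75, and a token at 0.0 gets ratio 1.0
# -> hue 150 (green end); values outside the range are clamped first.
#   color_token_by_logprob(-5.0, -5.0, 0.0)  # reddish rgb(...) string
#   color_token_by_logprob(-2.5, -5.0, 0.0)  # mid-hue rgb(...) string
#   color_token_by_logprob(0.0, -5.0, 0.0)   # greenish rgb(...) string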
90 |
91 | def set_tokens_ppl(tokens: List[Token]):
92 | """Set ppl value for each token in the list of tokens."""
93 | logprob_sum = itertools.accumulate([x.logprob for x in tokens], operator.add)
94 | for num_tokens, (token, logprob) in enumerate(zip(tokens, logprob_sum), 1):
95 | token.ppl = math.exp(-logprob / num_tokens)
96 |
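# Worked example (logprob values assumed): for tokens with logprobs
# [-0.5, -1.0, -1.5], the running sums are [-0.5, -1.5, -3.0], so the loop
# assigns ppl ~= exp(0.5) ~ 1.649, exp(0.75) ~ 2.117 and exp(1.0) ~ 2.718,
# i.e. the perplexity of the prefix ending at each token.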
97 |
98 | def single_token_html(token: Token) -> str:
99 | """HTML text of single token."""
100 | template = '<span class="info-label">{label}</span><span class="info-value">{value}</span><br>' # noqa (class names assumed)
101 |
102 | html_text = template.format(label="prob", value=f"{token.prob:.3%}")
103 | html_text += template.format(label="logprob", value=f"{token.logprob:.4f}")
104 | if token.ppl is not None:
105 | html_text += template.format(label="ppl", value=f"{token.ppl:.4f}")
106 | html_text += candidate_tokens_html(token.top_candidates)
107 | html_text = f"{html_text}
"
108 | return html_text
109 |
110 |
111 | def tokens_min_max_logprob(tokens: List[Token]) -> Tuple[float, float]:
112 | """Calculate the normalized min and max logprob of a list of tokens."""
113 | if len(tokens) == 1:
114 | min_logprob = max_logprob = tokens[0].logprob
115 | else:
116 | logprob_mean = statistics.mean([token.logprob for token in tokens])
117 | logprob_stddev = statistics.stdev([token.logprob for token in tokens])
118 | min_logprob = logprob_mean - (2.5 * logprob_stddev)
119 | max_logprob = logprob_mean + (2.5 * logprob_stddev)
120 | return min_logprob, max_logprob
121 |
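# Illustrative example (logprob values assumed): for logprobs [-1.0, -2.0, -3.0]
# the mean is -2.0 and the sample stddev is 1.0, so the returned bounds are
# (-4.5, 0.5); color_token_by_logprob then clamps every token into this range,
# which keeps a single outlier from washing out the rest of the color scale.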
122 |
123 | def tokens_info_to_html(tokens: List[Token], display_whitespace: bool = True) -> str:
124 | """
125 | Generate HTML for a list of tokens, including token color and hover text.
126 |
127 | Args:
128 | tokens (List[Token]): a list of tokens to generate html for.
129 | """
130 | min_logprob, max_logprob = tokens_min_max_logprob(tokens)
131 | set_tokens_ppl(tokens)
132 |
133 | tokens_html = ""
134 | for token in tokens:
135 | hover_html = single_token_html(token)
136 | rgb = color_token_by_logprob(token.logprob, min_logprob, max_logprob)
137 | is_newline = "\n" in token.text
138 | token_text = text_to_html_token(token.text, replace_whitespace=display_whitespace)
139 | if is_newline:
140 | token_text = f'<span class="newline-token">{token_text}</span>' # span/class assumed
141 | token_html = f'<span class="hover-token" style="background-color: {rgb}">{token_text}<span class="hover-popup">{hover_html}</span></span>' # noqa (tooltip markup/classes assumed)
142 | if is_newline:
143 | token_html += "
"
144 | tokens_html += token_html
145 | tokens_html = f'<div class="tokens-container">{tokens_html}</div>' # noqa (wrapper tag/class assumed)
146 | return tokens_html
147 |
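# Usage sketch (token list and page assembly assumed): given tokens produced by
# one of the model wrappers, a caller would embed the fragment together with
# css_style() so the hover/tooltip rules from main.css apply.
#   body_html = tokens_info_to_html(tokens, display_whitespace=False)
#   page_html = css_style() + body_html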
--------------------------------------------------------------------------------
/token_visualizer/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import os
4 | from typing import Union
5 |
6 | from dotenv import load_dotenv
7 |
8 | import token_visualizer
9 |
10 | __all__ = ["css_style", "ensure_os_env"]
11 |
12 |
13 | def css_style() -> str:
14 | with open(os.path.join(os.path.dirname(token_visualizer.__file__), "main.css")) as f:
15 | css_style = f.read()
16 | css_html = f""""
17 | return css_html
18 |
19 |
20 | def ensure_os_env(env_name: str, default_value: Union[str, None] = None):
21 | if env_name in os.environ:
22 | env_value = os.getenv(env_name)
23 | else:
24 | load_dotenv()
25 | env_value = os.getenv(env_name, default_value)
26 | return env_value
27 |
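# Usage sketch ("MY_TOKEN" is a hypothetical variable name): the helper first
# checks the process environment, and only falls back to loading a local .env
# file via python-dotenv when the variable is not already set.
#   token = ensure_os_env("MY_TOKEN", default_value="dummy-token")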
--------------------------------------------------------------------------------
/visual_tokenizer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import os
4 | from functools import lru_cache
5 | from typing import Dict
6 |
7 | import gradio as gr
8 | from loguru import logger
9 | from transformers import AutoTokenizer
10 |
11 | from sentencepiece import SentencePieceProcessor
12 |
13 | CANDIDATES = [ # model name sorted by alphabet
14 | "baichuan-inc/Baichuan2-13B-Chat",
15 | "bigcode/starcoder2-15b",
16 | "deepseek-ai/deepseek-coder-33b-instruct",
17 | # "google/gemma-7b",
18 | "gpt2",
19 | # "meta-llama/Llama-2-7b-chat-hf",
20 | "mistralai/Mixtral-8x7B-Instruct-v0.1",
21 | "THUDM/chatglm3-6b",
22 | ]
23 | SENTENCE_PIECE_MAPPING = {}
24 | SP_PREFIX = "SentencePiece/"
25 |
26 |
27 | def add_sp_tokenizer(name: str, tokenizer_path: str):
28 | """Add a sentence piece tokenizer to the list of available tokenizers."""
29 | model_key = SP_PREFIX + name
30 | if not os.path.exists(tokenizer_path):
31 | raise FileNotFoundError(f"Tokenizer file not found: {tokenizer_path}")
32 | SENTENCE_PIECE_MAPPING[model_key] = tokenizer_path
33 | CANDIDATES.append(model_key)
34 |
35 |
36 | # add_sp_tokenizer("LLaMa", "llama_tokenizer.model")
37 | logger.info(f"SentencePiece tokenizer: {list(SENTENCE_PIECE_MAPPING.keys())}")
38 |
39 |
40 | @lru_cache
41 | def get_tokenizer_and_vocab(name):
42 | if name.startswith(SP_PREFIX):
43 | local_file_path = SENTENCE_PIECE_MAPPING[name]
44 | tokenizer = SentencePieceProcessor(local_file_path)
45 | rev_vocab = {id_: tokenizer.id_to_piece(id_) for id_ in range(tokenizer.get_piece_size())} # noqa
46 | else:
47 | tokenizer = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
48 | rev_vocab = {v: k for k, v in tokenizer.get_vocab().items()}
49 | return tokenizer, rev_vocab
50 |
51 |
52 | def tokenize(name: str, text: str) -> Dict:
53 | tokenizer, rev_vocab = get_tokenizer_and_vocab(name)
54 |
55 | ids = tokenizer.encode(text)
56 | s, entities = '', []
57 | for i in ids:
58 | entity = str(i)
59 | start = len(s)
60 | s += rev_vocab[i]
61 | end = len(s)
62 | entities.append({"entity": entity, "start": start, "end": end})
63 |
64 | return {
65 | "text": s + f"\n({len(ids)} tokens / {len(text)} characters)",
66 | "entities": entities
67 | }
68 |
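# Output sketch (shape only; ids and spans depend on the chosen tokenizer): the
# dict returned above matches what gr.HighlightedText expects, roughly
#   {"text": "<concatenated pieces>\n(3 tokens / 12 characters)",
#    "entities": [{"entity": "<token id>", "start": 0, "end": 5}, ...]}
# so each token id is rendered as a colored span over its character range.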
69 |
70 | @logger.catch(reraise=True)
71 | def make_demo():
72 | logger.info("Creating Interface..")
73 |
74 | DEFAULT_TOKENIZER = CANDIDATES[0]
75 | DEFAULT_INPUTTEXT = "Hello world."
76 |
77 | demo = gr.Interface(
78 | fn=tokenize,
79 | inputs=[
80 | gr.Dropdown(
81 | CANDIDATES, value=DEFAULT_TOKENIZER,
82 | label="Tokenizer", allow_custom_value=False
83 | ),
84 | gr.TextArea(value=DEFAULT_INPUTTEXT, label="Input text"),
85 | ],
86 | outputs=[
87 | gr.HighlightedText(
88 | value=tokenize(DEFAULT_TOKENIZER, DEFAULT_INPUTTEXT),
89 | label="Tokenized results"
90 | )
91 | ],
92 | title="Tokenzier Visualizer",
93 | description="If you want to try more tokenizers, please contact the author@wangfeng", # noqa
94 | examples=[
95 | [DEFAULT_TOKENIZER, "乒乓球拍卖完了,无线电法国别研究,我一把把把把住了"],
96 | ["bigcode/starcoder2-15b", "def print():\n print('Hello')"],
97 | ],
98 | cache_examples=True,
99 | live=True,
100 | )
101 | return demo
102 |
103 |
104 | if __name__ == "__main__":
105 | demo = make_demo()
106 | demo.launch(server_name="0.0.0.0")
107 |
--------------------------------------------------------------------------------
/visualizer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import functools
4 | from argparse import ArgumentParser
5 | from typing import Tuple, Optional
6 |
7 | import gradio as gr
8 | from loguru import logger
9 |
10 | import token_visualizer
11 | from token_visualizer import TopkTokenModel, css_style, ensure_os_env
12 |
13 |
14 | def make_parser() -> ArgumentParser:
15 | parser = ArgumentParser("Inference process visualizer")
16 | parser.add_argument(
17 | "-t", "--type",
18 | choices=["llm", "tgi", "oai", "oai-proxy"],
19 | default="oai-proxy",
20 | help="Type of model to use, default to openai-proxy"
21 | )
22 | parser.add_argument(
23 | "--hf-repo", type=str, default=None,
24 | help="Huggingface model repository, used when type is 'llm'. Default to None"
25 | )
26 | parser.add_argument(
27 | "--oai-model", type=str, default="gpt-4-turbo-2024-04-09",
28 | help="OpenAI model name, used when type is 'oai'/'oai-proxy'. "
29 | "Check https://platform.openai.com/docs/models for more details. "
30 | "Default to `gpt-4-turbo-2024-04-09`."
31 | )
32 | parser.add_argument(
33 | "--oai-key", type=str, default=None,
34 | help="OpenAI api key, used when type is 'oai'/'oai-proxy'. "
35 | "If provided, will override OPENAI_KEY env variable.",
36 | )
37 | parser.add_argument(
38 | "--tgi-url", type=str, default=None,
39 | help="Service url of TGI model, used when type is 'tgi'. "
40 | "If provided, will override TGI_URL env variable.",
41 | )
42 | parser.add_argument(
43 | "-s", "--share", action="store_true",
44 | help="Share service to the internet.",
45 | )
46 | parser.add_argument(
47 | "-p", "--port", type=int, default=12123,
48 | help="Port to run the service, default to 12123."
49 | )
50 | return parser
51 |
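# CLI sketch (URLs and ports below are placeholders; flags as defined above):
#   python visualizer.py -t oai --oai-model gpt-4-turbo-2024-04-09
#   python visualizer.py -t tgi --tgi-url http://localhost:8080 -p 12123 --share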
52 |
53 | def build_model_by_args(args) -> token_visualizer.TopkTokenModel:
54 | BASE_URL = ensure_os_env("BASE_URL")
55 | OPENAI_API_KEY = ensure_os_env("OPENAI_KEY")
56 | TGI_URL = ensure_os_env("TGI_URL")
57 |
58 | model: Optional[token_visualizer.TopkTokenModel] = None
59 |
60 | if args.type == "llm":
61 | model = token_visualizer.TransformerModel(repo=args.hf_repo)
62 | elif args.type == "tgi":
63 | if args.tgi_url:
64 | TGI_URL = args.tgi_url
65 | model = token_visualizer.TGIModel(url=TGI_URL, details=True)
66 | elif args.type == "oai":
67 | model = token_visualizer.OpenAIModel(
68 | base_url=BASE_URL,
69 | api_key=OPENAI_API_KEY,
70 | model_name=args.oai_model,
71 | )
72 | elif args.type == "oai-proxy":
73 | model = token_visualizer.OpenAIProxyModel(
74 | base_url=BASE_URL,
75 | api_key=OPENAI_API_KEY,
76 | model_name="gpt-4-turbo-2024-04-09",
77 | )
78 | else:
79 | raise ValueError(f"Unknown model type {args.type}")
80 |
81 | return model
82 |
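# Environment sketch (values are placeholders): build_model_by_args resolves its
# endpoints and keys through ensure_os_env, so they may come from the process
# environment or from a local .env file, e.g.
#   BASE_URL=https://api.openai.com/v1
#   OPENAI_KEY=sk-xxxx
#   TGI_URL=http://localhost:8080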
83 |
84 | @logger.catch(reraise=True)
85 | def text_analysis(
86 | text: str,
87 | display_whitespace: bool,
88 | do_sample: bool,
89 | temperature: float,
90 | max_tokens: int,
91 | repetition_penalty: float,
92 | num_beams: int,
93 | topk: int,
94 | topp: float,
95 | topk_per_token: int,
96 | model: TopkTokenModel, # model should be built in the interface
97 | ) -> Tuple[str, str]:
98 | model.display_whitespace = display_whitespace
99 | model.do_sample = do_sample
100 | model.temperature = temperature
101 | model.max_tokens = max_tokens
102 | model.repetition_penalty = repetition_penalty
103 | model.num_beams = num_beams
104 | model.topk = topk
105 | model.topp = topp
106 | model.topk_per_token = topk_per_token
107 |
108 | tokens = model.generate_topk_per_token(text)
109 | html = model.html_to_visualize(tokens)
110 |
111 | html += "
"
112 | if isinstance(model, token_visualizer.TGIModel) and model.num_prefill_tokens:
113 | html += f"input tokens: {model.num_prefill_tokens}
"
114 | html += f"output tokens: {len(tokens)}
"
115 |
116 | return model.generated_answer, html
117 |
118 |
119 | def build_inference_analysis_demo(args):
120 | model = build_model_by_args(args)
121 | inference_func = functools.partial(text_analysis, model=model)
122 |
123 | interface = gr.Interface(
124 | inference_func,
125 | [
126 | gr.TextArea(placeholder="Please input text here"),
127 | gr.Checkbox(value=False, label="display whitespace in output"),
128 | gr.Checkbox(value=True, label="do_sample"),
129 | gr.Slider(minimum=0, maximum=1, step=0.05, value=1.0, label="temperature"),
130 | gr.Slider(minimum=1, maximum=4096, step=1, value=512, label="max tokens"),
131 | gr.Slider(minimum=1, maximum=2, step=0.1, value=1.0, label="repetition penalty"),
132 | gr.Slider(minimum=1, maximum=10, step=1, value=1, label="num beams"),
133 | gr.Slider(minimum=1, maximum=100, step=1, value=50, label="topk"),
134 | gr.Slider(minimum=0, maximum=1, step=0.05, value=1.0, label="topp"),
135 | gr.Slider(minimum=1, maximum=10, step=1, value=5, label="per-token topk"),
136 | ],
137 | [
138 | gr.TextArea(label="LLM answer"),
139 | "html",
140 | ],
141 | examples=[
142 | ["Who are Hannah Quinlivan's child?"],
143 | ["Write python code to read a file and print its content."],
144 | ],
145 | title="LLM inference analysis",
146 | )
147 | return interface
148 |
149 |
150 | @logger.catch(reraise=True)
151 | def ppl_from_model(
152 | text: str,
153 | url: str,
154 | bos: str,
155 | eos: str,
156 | display_whitespace: bool,
157 | model,
158 | ) -> str:
159 | """Generate PPL visualization from model.
160 |
161 | Args:
162 | text (str): input text to visualize.
163 | url (str): tgi url.
164 | bos (str): begin of sentence token.
165 | eos (str): end of sentence token.
166 | display_whitespace (bool): whether to display whitespace for output text.
167 | If set to True, whitespace will be displayed as "␣".
168 | """
169 | url = url.strip()
170 | assert url, f"Please provide url of your tgi model. Current url: {url}"
171 | logger.info(f"Set url to {url}")
172 | model.url = url
173 | model.display_whitespace = display_whitespace
174 | model.max_tokens = 1
175 |
176 | text = bos + text + eos
177 | tokens = model.generate_inputs_prob(text)
178 | html = model.html_to_visualize(tokens)
179 |
180 | # display tokens and ppl at the end
181 | html += "
"
182 | html += f"total tokens: {len(tokens)}
"
183 | ppl = tokens[-1].ppl
184 | html += f"ppl: {ppl:.4f}
"
185 | return html
186 |
187 |
188 | def build_ppl_visualizer_demo(args):
189 | model = build_model_by_args(args)
190 | ppl_func = functools.partial(ppl_from_model, model=model)
191 |
192 | ppl_interface = gr.Interface(
193 | ppl_func,
194 | [
195 | gr.TextArea(placeholder="Please input text to visualize here"),
196 | gr.TextArea(
197 | placeholder="Please input tgi url here (Error if not provided)",
198 | lines=1,
199 | ),
200 | gr.TextArea(placeholder="BOS token, default to empty string", lines=1),
201 | gr.TextArea(placeholder="EOS token, default to empty string", lines=1),
202 | gr.Checkbox(value=False, label="display whitespace in output, default to False"),
203 | ],
204 | "html",
205 | title="PPL Visualizer",
206 | )
207 | return ppl_interface
208 |
209 |
210 | def demo():
211 | args = make_parser().parse_args()
212 | logger.info(f"Args: {args}")
213 |
214 | demo = gr.Blocks(css=css_style())
215 | with demo:
216 | with gr.Tab("Inference"):
217 | build_inference_analysis_demo(args)
218 | with gr.Tab("PPL"):
219 | build_ppl_visualizer_demo(args)
220 |
221 | demo.launch(
222 | server_name="0.0.0.0",
223 | share=args.share,
224 | server_port=args.port,
225 | )
226 |
227 |
228 | if __name__ == "__main__":
229 | demo()
230 |
--------------------------------------------------------------------------------