├── .env ├── .gitignore ├── README.md ├── assets ├── inference.gif ├── ppl.gif └── tokenizer.gif ├── setup.cfg ├── setup.py ├── token_visualizer ├── __init__.py ├── main.css ├── models.py ├── token_html.py └── utils.py ├── visual_tokenizer.py └── visualizer.py /.env: -------------------------------------------------------------------------------- 1 | BASE_URL="http://47.236.144.103/v1/chat/completions" 2 | OPENAI_KEY="sk-7FrI14KGfbZPk7aR6155Dc668095493d895bCd177473276e" 3 | TGI_URL="http://tgi_example_url:port" 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ### Linux ### 2 | *~ 3 | 4 | # temporary files which can be created if a process still has a handle open of a deleted file 5 | .fuse_hidden* 6 | 7 | # KDE directory preferences 8 | .directory 9 | 10 | # Linux trash folder which might appear on any partition or disk 11 | .Trash-* 12 | 13 | # .nfs files are created when an open file is removed but is still being accessed 14 | .nfs* 15 | 16 | ### PyCharm ### 17 | # User-specific stuff 18 | .idea 19 | 20 | # CMake 21 | cmake-build-*/ 22 | 23 | # Mongo Explorer plugin 24 | .idea/**/mongoSettings.xml 25 | 26 | # File-based project format 27 | *.iws 28 | 29 | # IntelliJ 30 | out/ 31 | 32 | # mpeltonen/sbt-idea plugin 33 | .idea_modules/ 34 | 35 | # JIRA plugin 36 | atlassian-ide-plugin.xml 37 | 38 | # Cursive Clojure plugin 39 | .idea/replstate.xml 40 | 41 | # Crashlytics plugin (for Android Studio and IntelliJ) 42 | com_crashlytics_export_strings.xml 43 | crashlytics.properties 44 | crashlytics-build.properties 45 | fabric.properties 46 | 47 | # Editor-based Rest Client 48 | .idea/httpRequests 49 | 50 | # Android studio 3.1+ serialized cache file 51 | .idea/caches/build_file_checksums.ser 52 | 53 | # JetBrains templates 54 | **___jb_tmp___ 55 | 56 | ### Python ### 57 | # Byte-compiled / optimized / DLL files 58 | __pycache__/ 59 | *.py[cod] 60 | *$py.class 61 | 62 | # C extensions 63 | *.so 64 | 65 | # Distribution / packaging 66 | .Python 67 | build/ 68 | develop-eggs/ 69 | dist/ 70 | downloads/ 71 | eggs/ 72 | .eggs/ 73 | lib/ 74 | lib64/ 75 | parts/ 76 | sdist/ 77 | var/ 78 | wheels/ 79 | pip-wheel-metadata/ 80 | share/python-wheels/ 81 | *.egg-info/ 82 | .installed.cfg 83 | *.egg 84 | MANIFEST 85 | 86 | # PyInstaller 87 | # Usually these files are written by a python script from a template 88 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 89 | *.manifest 90 | *.spec 91 | 92 | # Installer logs 93 | pip-log.txt 94 | pip-delete-this-directory.txt 95 | 96 | # Unit test / coverage reports 97 | htmlcov/ 98 | .tox/ 99 | .nox/ 100 | .coverage 101 | .coverage.* 102 | .cache 103 | report.xml 104 | nosetests.xml 105 | coverage.xml 106 | *.cover 107 | .hypothesis/ 108 | .pytest_cache/ 109 | 110 | # Translations 111 | *.mo 112 | *.pot 113 | 114 | # Django stuff: 115 | *.log 116 | local_settings.py 117 | db.sqlite3 118 | 119 | # Flask stuff: 120 | instance/ 121 | .webassets-cache 122 | 123 | # Scrapy stuff: 124 | .scrapy 125 | 126 | # Sphinx documentation 127 | docs/_build/ 128 | docs/build/ 129 | 130 | # PyBuilder 131 | target/ 132 | 133 | # Jupyter Notebook 134 | .ipynb_checkpoints 135 | 136 | # IPython 137 | profile_default/ 138 | ipython_config.py 139 | 140 | # pyenv 141 | .python-version 142 | 143 | # pipenv 144 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
145 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 146 | # having no cross-platform support, pipenv may install dependencies that don’t work, or not 147 | # install all needed dependencies. 148 | #Pipfile.lock 149 | 150 | # celery beat schedule file 151 | celerybeat-schedule 152 | 153 | # SageMath parsed files 154 | *.sage.py 155 | 156 | # Environments 157 | .env 158 | .venv 159 | env/ 160 | venv/ 161 | ENV/ 162 | env.bak/ 163 | venv.bak/ 164 | 165 | # Spyder project settings 166 | .spyderproject 167 | .spyproject 168 | 169 | # Rope project settings 170 | .ropeproject 171 | 172 | # mkdocs documentation 173 | /site 174 | 175 | # mypy 176 | .mypy_cache/ 177 | .dmypy.json 178 | dmypy.json 179 | 180 | # Pyre type checker 181 | .pyre/ 182 | 183 | ### Vim ### 184 | # Swap 185 | [._]*.s[a-v][a-z] 186 | [._]*.sw[a-p] 187 | [._]s[a-rt-v][a-z] 188 | [._]ss[a-gi-z] 189 | [._]sw[a-p] 190 | 191 | # Session 192 | Session.vim 193 | 194 | # Temporary 195 | .netrwhist 196 | # Auto-generated tag files 197 | tags 198 | # Persistent undo 199 | [._]*.un~ 200 | 201 | .code-workspace.code-workspace 202 | output 203 | instant_test_output 204 | inference_test_output 205 | *.pkl 206 | *.txt 207 | *.json 208 | *.mge 209 | *.jpg 210 | *.png 211 | *.npy 212 | *.pth 213 | events.out.tfevents* 214 | **/log 215 | 216 | # vscode 217 | *.code-workspace 218 | .vscode 219 | 220 | # vim 221 | .vim 222 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | ## Introduction 3 | `token_visualizer` is a token-level visualization tool for large language models (LLMs). 4 | 5 | ## Quick start 6 | 7 | ### Installation 8 | 9 | #### Install from source 10 | 11 | Run the following command to install the package. 12 | ```shell 13 | git clone git@github.com:FateScript/token_visualizer.git 14 | cd token_visualizer 15 | pip install -v -e . # or python3 setup.py develop 16 | ``` 17 | 18 | #### Check installation 19 | 20 | The installation succeeded if you can see the install path of `token_visualizer` by running 21 | ```shell 22 | python3 -c "import token_visualizer; print(token_visualizer.__file__)" 23 | ``` 24 | 25 | ## Visualization demo 26 | 27 | ### Inference 28 | 29 | #### Start demo 30 | Run the following command to start the inference visualizer. 31 | ```shell 32 | python3 visualizer.py 33 | ``` 34 | 35 | The command starts an `OpenAIProxyModel` by default; to use it without errors, fill in the values of `BASE_URL` and `OPENAI_KEY` in your `.env` file (or as environment variables). 36 | 37 | `token_visualizer` also supports `OpenAIModel` and HuggingFace `TransformerModel` in [models.py](https://github.com/FateScript/token_visualizer/blob/main/token_visualizer/models.py), feel free to modify the code. 38 | 39 | #### Demo gif 40 | After inputting your prompt, you will see the large language model's answer and the answer's visualization result. 41 | 42 | <img src="assets/inference.gif"> 43 | 44 | **The redder the color of the token, the lower the corresponding probability. The greener the color of the token, the higher the corresponding probability.** 45 | 46 | ### Perplexity 47 | 48 | #### Start demo 49 | Run the following command to start the perplexity visualizer, then click the `PPL` tab. 50 | 51 | ```shell 52 | python3 visualizer.py 53 | ``` 54 | 55 | #### Demo gif 56 | After inputting your text, you will see its perplexity and the visualization result.
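The `ppl` value reported at the end of the visualization (see the demo gif below) is the standard length-normalized perplexity of the text, accumulated token by token by `token_visualizer.token_html.set_tokens_ppl`. A minimal sketch of the computation, using made-up token probabilities:

```python
import math

# Toy probabilities for three tokens (illustrative values only).
token_probs = [0.9, 0.25, 0.6]

# Perplexity is the exponential of the negative mean token log-probability.
ppl = math.exp(-sum(math.log(p) for p in token_probs) / len(token_probs))
print(f"ppl: {ppl:.4f}")
```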
57 | 58 | <img src="assets/ppl.gif"> 59 | 60 | ### Tokenizer 61 | 62 | #### Start demo 63 | Run the following command to start the interactive tokenizer encoding web demo. 64 | 65 | ```shell 66 | python3 visual_tokenizer.py 67 | ``` 68 | 69 | #### Demo gif 70 | Users can select a tokenizer to interact with and the text to encode, including special strings. 71 | 72 | <img src="assets/tokenizer.gif"> 73 | 74 | ## TODO 75 | - [x] Support ppl visualization. 76 | - [x] Select transformers/openai/TGI with cli. 77 | - [ ] Support OpenAI tokenizer visualization. 78 | - [x] Support TGI inference visualization. 79 | - [ ] Support multi-turn chat visualization. 80 | - [ ] Support dark mode. 81 | 82 | 83 | ## Related projects/websites 84 | 85 | * [LLM architecture visualization](https://bbycroft.net/llm) 86 | * [perplexity visualization](https://bbycroft.net/ppl) 87 | 88 | 89 | ## Acknowledgement 90 | 91 | * Front-end styling from [https://perplexity.vercel.app/](https://perplexity.vercel.app/) 92 | * Color algorithm from [post](https://twitter.com/thesephist/status/1617909119423500288) by [thesephist](https://twitter.com/thesephist). 93 | -------------------------------------------------------------------------------- /assets/inference.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FateScript/token_visualizer/a0c2cff3f7ff8d58e8ebe026aa57b86fcc785dac/assets/inference.gif -------------------------------------------------------------------------------- /assets/ppl.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FateScript/token_visualizer/a0c2cff3f7ff8d58e8ebe026aa57b86fcc785dac/assets/ppl.gif -------------------------------------------------------------------------------- /assets/tokenizer.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FateScript/token_visualizer/a0c2cff3f7ff8d58e8ebe026aa57b86fcc785dac/assets/tokenizer.gif -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = token_visualizer 3 | version = 0.0.1 4 | description = token level visualization tools for large language models 5 | author = Feng Wang 6 | author_email = wangfeng19950315@163.com 7 | long_description = file: README.md 8 | long_description_content_type = text/markdown 9 | url = https://github.com/FateScript/token_visualizer 10 | 11 | [options] 12 | packages = find: 13 | python_requires = >=3.7 14 | install_requires = 15 | numpy 16 | torch 17 | gradio 18 | transformers==4.34.0 19 | retrying 20 | loguru 21 | sentencepiece 22 | openai>=1.0 23 | python-dotenv 24 | 25 | [flake8] 26 | max-line-length = 100 27 | max-complexity = 18 28 | exclude = __init__.py 29 | 30 | [isort] 31 | line_length = 100 32 | multi_line_output = 3 33 | include_trailing_comma = true 34 | balanced_wrapping = true 35 | known_thirdparty = numpy, loguru, gradio, torch, openai, transformers, retrying, sentencepiece, dotenv 36 | KNOWN_MYSELF = token_visualizer 37 | sections = FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,MYSELF,LOCALFOLDER 38 | no_lines_before=STDLIB 39 | default_section = FIRSTPARTY 40 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from setuptools import setup 4 | 5 | setup() 6 |
-------------------------------------------------------------------------------- /token_visualizer/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from .models import ( 4 | OpenAIModel, 5 | OpenAIProxyModel, 6 | TGIModel, 7 | TopkTokenModel, 8 | TransformerModel, 9 | generate_topk_token_prob, 10 | load_model_tokenizer, 11 | openai_payload, 12 | ) 13 | from .token_html import ( 14 | Token, 15 | candidate_tokens_html, 16 | color_token_by_logprob, 17 | set_tokens_ppl, 18 | single_token_html, 19 | tokens_info_to_html, 20 | tokens_min_max_logprob, 21 | ) 22 | from .utils import css_style, ensure_os_env 23 | -------------------------------------------------------------------------------- /token_visualizer/main.css: -------------------------------------------------------------------------------- 1 | .visual-tokens { 2 | line-height: 1.5em; 3 | max-width: 72ch; 4 | word-wrap: break-word; 5 | white-space: pre-wrap; 6 | tab-size: 4; 7 | } 8 | 9 | .flex-spacer { 10 | width: 0; 11 | height: 0; 12 | flex-grow: 1; 13 | } 14 | 15 | .flex-row { 16 | display: flex; 17 | flex-direction: row; 18 | align-items: center; 19 | flex-grow: 1; 20 | gap: 16px; 21 | } 22 | 23 | .flex-col { 24 | display: flex; 25 | flex-direction: column; 26 | align-items: center; 27 | flex-grow: 1; 28 | gap: 16px; 29 | } 30 | 31 | .ppl-visualization-tokens { 32 | line-height: 1.5em; 33 | max-width: 72ch; 34 | word-wrap: break-word; 35 | white-space: pre-wrap; 36 | tab-size: 4; 37 | } 38 | 39 | .ppl-visualization-tokens:empty::after { 40 | content: 'Nothing visualized yet.'; 41 | color: #9b9b9b; 42 | font-style: italic; 43 | } 44 | 45 | .ppl-token { 46 | position: relative; 47 | } 48 | 49 | .ppl-token:hover { 50 | color: #fdfeff; 51 | background: #111111 !important; 52 | } 53 | 54 | .ppl-pseudo-token { 55 | color: #9b9b9b; 56 | } 57 | 58 | .ppl-hud { 59 | position: absolute; 60 | top: calc(100% + 4px); 61 | left: 0; 62 | color: #111111; 63 | background: #fdfeff; 64 | box-shadow: 0 2px 4px rgba(0, 0, 0, .36); 65 | padding: 4px 8px; 66 | border-radius: 4px; 67 | z-index: 5; 68 | opacity: 0; 69 | visibility: hidden; 70 | } 71 | 72 | .ppl-hud-row { 73 | display: flex; 74 | flex-direction: row; 75 | align-items: flex-end; 76 | justify-content: space-between; 77 | gap: 8px; 78 | } 79 | .ppl-hud-label { 80 | color: #9b9b9b !important; 81 | font-size: calc(1em - 2px); 82 | } 83 | .ppl-predictions { 84 | margin-top: 8px; 85 | } 86 | .ppl-token:hover .ppl-hud { 87 | opacity: 1; 88 | visibility: visible; 89 | } 90 | -------------------------------------------------------------------------------- /token_visualizer/models.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import json 4 | import math 5 | import os 6 | from dataclasses import dataclass 7 | from typing import Dict, List, Optional, Tuple, Union 8 | 9 | import requests 10 | import torch 11 | from loguru import logger 12 | from openai import OpenAI 13 | from requests.adapters import HTTPAdapter 14 | from retrying import retry 15 | from urllib3.util import Retry 16 | 17 | from .token_html import Token, tokens_info_to_html 18 | 19 | __all__ = [ 20 | "TopkTokenModel", 21 | "TransformerModel", 22 | "TGIModel", 23 | "OpenAIModel", 24 | "OpenAIProxyModel", 25 | "generate_topk_token_prob", 26 | "load_model_tokenizer", 27 | "openai_payload", 28 | ] 29 | 30 | 31 | def load_model_tokenizer(repo): 32 | from transformers import AutoModelForCausalLM, 
AutoTokenizer 33 | model = AutoModelForCausalLM.from_pretrained(repo, device_map="auto", trust_remote_code=True) 34 | tokenizer = AutoTokenizer.from_pretrained(repo, use_fast=True, trust_remote_code=True) 35 | return model, tokenizer 36 | 37 | 38 | def format_reverse_vocab(tokenizer) -> Dict[int, str]: 39 | """ 40 | Format the vocab to make it more human-readable, return a token_id to token_value mapping. 41 | """ 42 | rev_vocab = {v: k for k, v in tokenizer.get_vocab().items()} 43 | sp_space = b"\xe2\x96\x81".decode() # reference link below in sentencepiece: 44 | # https://github.com/google/sentencepiece/blob/8cbdf13794284c30877936f91c6f31e2c1d5aef7/src/sentencepiece_processor.cc#L41-L42 45 | 46 | for idx, token in rev_vocab.items(): 47 | if sp_space in token: 48 | rev_vocab[idx] = token.replace(sp_space, "␣") 49 | elif token.isspace(): # token like \n, \t or multiple spaces 50 | rev_vocab[idx] = repr(token)[1:-1] # 1:-1 to strip ', it will convert \n to \\n 51 | elif token.startswith("<") and token.endswith(">"): # tokens like 52 | # NOTE: string like
<s> is better, but <|s|> is simple, better-looking 53 | # rev_vocab[idx] = f"<{token[1:-1]}>
" 54 | rev_vocab[idx] = f"<|{token[1:-1]}|>" 55 | 56 | return rev_vocab 57 | 58 | 59 | def generate_topk_token_prob( 60 | inputs: str, model, tokenizer, 61 | num_topk_tokens: int = 10, 62 | inputs_device: str = "cuda:0", 63 | **kwargs 64 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 65 | """ 66 | Generate topk token and it's prob for each token of auto regressive model. 67 | """ 68 | if not torch.cuda.is_available(): 69 | inputs_device = "cpu" 70 | model = model.to(inputs_device) 71 | logger.warning(f"CUDA not available, switch to {inputs_device}.") 72 | 73 | logger.info(f"generate response for:\n{inputs}") 74 | inputs = tokenizer(inputs, return_tensors='pt') 75 | inputs = inputs.to(inputs_device) 76 | outputs = model.generate( 77 | **inputs, 78 | return_dict_in_generate=True, 79 | output_scores=True, 80 | **kwargs 81 | ) 82 | logits = torch.stack(outputs.scores) 83 | probs = torch.softmax(logits, dim=-1) 84 | topk_tokens = torch.topk(logits, k=num_topk_tokens).indices 85 | topk_probs = torch.gather(probs, -1, topk_tokens) 86 | return topk_tokens, topk_probs, outputs.sequences 87 | 88 | 89 | def openai_top_response_tokens(response: Dict) -> List[Token]: 90 | token_logprobs = response["choices"][0]["logprobs"]["content"] 91 | tokens = [] 92 | for token_prob in token_logprobs: 93 | prob = math.exp(token_prob["logprob"]) 94 | candidate_tokens = [ 95 | Token(t["token"], math.exp(t["logprob"])) 96 | for t in token_prob["top_logprobs"] 97 | ] 98 | token = Token(token_prob["token"], prob, top_candidates=candidate_tokens) 99 | tokens.append(token) 100 | return tokens 101 | 102 | 103 | def openai_payload( 104 | prompt: Union[List[str], str], 105 | model_name: str, 106 | system_prompt: str = "", 107 | **kwargs 108 | ) -> Dict: 109 | """Generate payload for openai api call.""" 110 | messages = [] 111 | if system_prompt: 112 | messages.append({"role": "system", "content": system_prompt}) 113 | if isinstance(prompt, str): 114 | prompt = [prompt] 115 | for idx, p in enumerate(prompt): 116 | role = "user" if idx % 2 == 0 else "assistant" 117 | messages.append({"role": role, "content": p}) 118 | 119 | payload = {"model": model_name, "messages": messages, **kwargs} 120 | return payload 121 | 122 | 123 | @dataclass 124 | class TopkTokenModel: 125 | do_sample: bool = False 126 | temperature: float = 1.0 127 | max_tokens: int = 4096 128 | repetition_penalty: float = 1.0 129 | num_beams: int = 1 130 | topk: int = 50 131 | topp: float = 1.0 132 | 133 | topk_per_token: int = 5 # number of topk tokens to generate for each token 134 | generated_answer: Optional[str] = None # generated answer from model, to display in frontend 135 | display_whitespace: bool = False 136 | 137 | def generate_topk_per_token(self, text: str) -> List[Token]: 138 | """ 139 | Generate prob, text and candidates for each token of the model's output. 140 | This function is used to visualize the inference process. 141 | """ 142 | raise NotImplementedError 143 | 144 | def generate_inputs_prob(self, text: str) -> List[Token]: 145 | """ 146 | Generate prob and text for each token of the input text. 147 | This function is used to visualize the ppl. 
148 | """ 149 | raise NotImplementedError 150 | 151 | def html_to_visualize(self, tokens: List[Token]) -> str: 152 | """Generate html to visualize the tokens.""" 153 | return tokens_info_to_html(tokens, display_whitespace=self.display_whitespace) 154 | 155 | 156 | @dataclass 157 | class TransformerModel(TopkTokenModel): 158 | 159 | repo: Optional[str] = None 160 | model = None 161 | tokenizer = None 162 | rev_vocab = None 163 | 164 | def get_model_tokenizer(self): 165 | assert self.repo, "Please provide repo name to load model and tokenizer." 166 | if self.model is None or self.tokenizer is None: 167 | self.model, self.tokenizer = load_model_tokenizer(self.repo) 168 | if self.rev_vocab is None: 169 | self.rev_vocab = format_reverse_vocab(self.tokenizer) 170 | return self.model, self.tokenizer 171 | 172 | def generate_topk_per_token(self, text: str) -> List[Token]: 173 | model, tokenizer = self.get_model_tokenizer() 174 | rev_vocab = self.rev_vocab 175 | assert rev_vocab, f"Reverse vocab not loaded for {self.repo} model" 176 | 177 | topk_tokens, topk_probs, sequences = generate_topk_token_prob( 178 | text, model, tokenizer, num_topk_tokens=self.topk_per_token, 179 | do_sample=self.do_sample, 180 | temperature=max(self.temperature, 0.01), 181 | max_new_tokens=self.max_tokens, 182 | repetition_penalty=self.repetition_penalty, 183 | num_beams=self.num_beams, 184 | top_k=self.topk, 185 | top_p=self.topp, 186 | ) 187 | self.generated_answer = tokenizer.decode(sequences[0]) 188 | 189 | seq_length = topk_tokens.shape[0] 190 | np_seq = sequences[0, -seq_length:].cpu().numpy() 191 | gen_tokens = [] 192 | for seq_id, token, prob in zip(np_seq, topk_tokens.cpu().numpy(), topk_probs.cpu().numpy()): 193 | candidate_tokens = [Token(f"{rev_vocab[idx]}", float(p)) for idx, p in zip(token[0], prob[0])] # noqa 194 | seq_id_prob = float(prob[0][token[0] == seq_id]) 195 | display_token = Token(f"{rev_vocab[seq_id]}", seq_id_prob, candidate_tokens) 196 | gen_tokens.append(display_token) 197 | return gen_tokens 198 | 199 | 200 | def tgi_response( # type: ignore[return] 201 | input_text: str, 202 | url: str, 203 | max_new_tokens: int = 2048, 204 | repetition_penalty: float = 1.1, 205 | temperature: float = 0.01, 206 | top_k: int = 5, 207 | top_p: float = 0.85, 208 | do_sample: bool = True, 209 | topk_logits: Optional[int] = None, 210 | details: bool = False, 211 | **kwargs 212 | ) -> Dict: 213 | headers = {"Content-Type": "application/json"} 214 | params = { 215 | "max_new_tokens": max_new_tokens, 216 | "repetition_penalty": repetition_penalty, 217 | "do_sample": do_sample, 218 | "temperature": temperature, 219 | "top_n_tokens": topk_logits, 220 | "details": details, 221 | **kwargs, 222 | } 223 | if do_sample: # tgi use or logic for top_k/top_p with do_sample 224 | params.update({"top_k": top_k, "top_p": top_p}) 225 | 226 | data = {"inputs": input_text, "parameters": params} 227 | 228 | response = requests.post(url, json=data, headers=headers) 229 | 230 | if response.status_code != 200: 231 | logger.error(f"Error {response.status_code}: {response.text}") 232 | return response.json() 233 | 234 | 235 | @dataclass 236 | class TGIModel(TopkTokenModel): 237 | 238 | url: Optional[str] = None 239 | system_prompt = "" 240 | details: bool = False 241 | decoder_input_details: bool = False # input logprobs 242 | num_prefill_tokens: Optional[int] = None 243 | 244 | # tgi support top_n_tokens, reference below: 245 | # 
https://github.com/huggingface/text-generation-inference/blob/7dbaf9e9013060af52024ea1a8b361b107b50a69/proto/generate.proto#L108-L109 246 | 247 | def response_to_inputs(self, inputs: str) -> Dict: 248 | assert self.url, f"Please provide url to access tgi api. url: {self.url}" 249 | json_response = tgi_response( 250 | inputs, url=self.url, 251 | max_new_tokens=self.max_tokens, 252 | repetition_penalty=self.repetition_penalty, 253 | temperature=self.temperature, 254 | top_k=self.topk, 255 | top_p=min(self.topp, 0.99), 256 | do_sample=self.do_sample, 257 | decoder_input_details=self.decoder_input_details, 258 | topk_logits=self.topk_per_token, 259 | details=self.details, 260 | ) 261 | response = json_response[0] 262 | self.generated_answer = response["generated_text"] 263 | return response 264 | 265 | def generate_topk_per_token(self, text: str) -> List[Token]: 266 | assert self.details, "Please set details to True." 267 | response = self.response_to_inputs(text) 268 | tokens: List[Token] = [] 269 | 270 | token_details = response["details"]["tokens"] 271 | topk_tokens = response["details"]["top_tokens"] 272 | 273 | for details, candidate in zip(token_details, topk_tokens): 274 | candidate_tokens = [Token(x["text"], math.exp(x["logprob"])) for x in candidate] 275 | token = Token( 276 | details["text"], 277 | math.exp(details["logprob"]), 278 | top_candidates=candidate_tokens, 279 | ) 280 | tokens.append(token) 281 | 282 | return tokens 283 | 284 | def generate_inputs_prob(self, text: str) -> List[Token]: 285 | assert self.decoder_input_details, "Please set decoder_input_details to True." 286 | response = self.response_to_inputs(text) 287 | token_details = response["details"]["prefill"] 288 | tokens = [] 289 | for token in token_details: 290 | logprob = token.get("logprob", None) 291 | if logprob is None: 292 | continue 293 | tokens.append(Token(token["text"], math.exp(logprob))) 294 | return tokens 295 | 296 | def set_num_prefill_tokens(self, response): 297 | if self.details: 298 | self.num_prefill_tokens = response["details"]["prefill_tokens"] 299 | else: 300 | self.num_prefill_tokens = None 301 | 302 | 303 | @dataclass 304 | class OpenAIModel(TopkTokenModel): 305 | api_key: Optional[str] = None 306 | base_url: Optional[str] = None 307 | 308 | system_prompt: str = "" 309 | model_name: str = "gpt-4-0125-preview" 310 | # choices for model_name: see https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo 311 | json_mode: bool = False 312 | seed: Optional[int] = None 313 | 314 | def __post_init__(self): 315 | assert self.api_key is not None, "Please provide api key to access openai api." 
316 | self.client = OpenAI(api_key=self.api_key, base_url=self.base_url) 317 | 318 | def generate_topk_per_token(self, text: str, **kwargs) -> List[Token]: 319 | kwargs = { 320 | "temperature": self.temperature, 321 | "top_p": self.topp, 322 | } 323 | if self.seed: 324 | kwargs["seed"] = self.seed 325 | if self.json_mode: 326 | kwargs["response_format"] = {"type": "json_object"} 327 | if self.topk_per_token > 0: 328 | kwargs["logprobs"] = True 329 | kwargs["top_logprobs"] = self.topk_per_token 330 | 331 | payload = openai_payload(text, self.model_name, system_prompt=self.system_prompt, **kwargs) 332 | completion = self.client.completions.create(payload) 333 | self.generated_answer = completion.choices[0].message.content 334 | return openai_top_response_tokens(completion.dict()) 335 | 336 | 337 | @dataclass 338 | class OpenAIProxyModel(TopkTokenModel): 339 | api_key: Optional[str] = None 340 | base_url: Optional[str] = None 341 | 342 | system_prompt = "" 343 | model_name: str = "gpt-4-0125-preview" 344 | # choices for model_name: see https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo 345 | json_mode: bool = False 346 | seed: Optional[int] = None 347 | 348 | def __post_init__(self): 349 | assert self.base_url is not None, "Please provide url to access openai api." 350 | assert self.api_key is not None, "Please provide api key to access openai api." 351 | retry_strategy = Retry( 352 | total=1, # max retry times 353 | backoff_factor=1, # time interval between retries 354 | status_forcelist=[429, 500, 502, 503, 504], # retry when these status code 355 | allowed_methods=["POST"], # retry only when POST 356 | ) 357 | adapter = HTTPAdapter(max_retries=retry_strategy) 358 | self.session = requests.Session() 359 | self.session.mount("https://", adapter) 360 | self.session.mount("http://", adapter) 361 | if self.api_key is None: 362 | self.api_key = os.environ.get("OPENAI_API_KEY") 363 | 364 | @retry(stop_max_attempt_number=3) 365 | def openai_api_call(self, payload): 366 | headers = { 367 | "Content-Type": "application/json", 368 | "Authorization": "Bearer " + self.api_key, 369 | } 370 | response = self.session.post(self.base_url, headers=headers, data=json.dumps(payload)) 371 | if response.status_code != 200: 372 | err_msg = f"Access openai error, status code: {response.status_code}, errmsg: {response.text}" # noqa 373 | raise ValueError(err_msg, response.status_code) 374 | 375 | data = json.loads(response.text) 376 | return data 377 | 378 | def generate_topk_per_token(self, text: str, **kwargs) -> List[Token]: 379 | kwargs = { 380 | "temperature": self.temperature, 381 | "top_p": self.topp, 382 | } 383 | if self.seed: 384 | kwargs["seed"] = self.seed 385 | if self.json_mode: 386 | kwargs["response_format"] = {"type": "json_object"} 387 | if self.topk_per_token > 0: 388 | kwargs["logprobs"] = True 389 | kwargs["top_logprobs"] = self.topk_per_token 390 | 391 | payload = openai_payload(text, self.model_name, system_prompt=self.system_prompt, **kwargs) 392 | response = self.openai_api_call(payload) 393 | self.generated_answer = response["choices"][0]["message"]["content"] 394 | return openai_top_response_tokens(response) 395 | -------------------------------------------------------------------------------- /token_visualizer/token_html.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import colorsys 4 | import itertools 5 | import math 6 | import operator 7 | import statistics 8 | from dataclasses import dataclass, field 
9 | from typing import List, Tuple, Union 10 | 11 | __all__ = [ 12 | "Token", 13 | "candidate_tokens_html", 14 | "color_token_by_logprob", 15 | "set_tokens_ppl", 16 | "single_token_html", 17 | "tokens_min_max_logprob", 18 | "tokens_info_to_html", 19 | ] 20 | 21 | 22 | @dataclass 23 | class Token: 24 | text: str 25 | prob: float 26 | top_candidates: List = field(default_factory=list) 27 | ppl: Union[float, None] = field(default=None) 28 | 29 | @property 30 | def logprob(self) -> float: 31 | return math.log(self.prob) 32 | 33 | 34 | def text_to_html_token(text: str, replace_whitespace: bool = True) -> str: 35 | """Convert a raw text to a token that could be displayed in html. 36 | For example, "<func_call>" will be converted to "&lt;func_call&gt;". 37 | 38 | Args: 39 | text (str): raw text. 40 | replace_whitespace (bool, optional): whether to replace whitespace. Defaults to True. 41 | """ 42 | text = text.replace("<", "&lt;").replace(">", "&gt;") # display token like <s> correctly # noqa 43 | text = text.replace("\n", "↵") # display newline 44 | if replace_whitespace: 45 | text = text.replace(" ", "␣") # display whitespace 46 | if len(text) == 0: # special token 47 | return "␣(null)" 48 | return text 49 | 50 | 51 | def candidate_tokens_html(topk_candidate_tokens: List[Token]) -> str: 52 | """HTML content of a single token's topk candidate tokens.""" 53 | template = '<div class="ppl-hud-row">' \ 54 | '<span>{token}</span><span class="ppl-hud-label">{prob}</span></div>' 55 | 56 | html_text = "".join([ 57 | template.format(token=text_to_html_token(token.text), prob=f"{token.prob:.3%}") 58 | for token in topk_candidate_tokens 59 | ]) 60 | html_text = f'
<div class="ppl-predictions">{html_text}</div>
' 61 | return html_text 62 | 63 | 64 | def color_token_by_logprob( 65 | log_prob: float, min_log_prob: float, max_log_prob: float, 66 | hue_red: int = 0, hue_green: int = 150, epsilon: float = 1e-5, 67 | ) -> str: 68 | """According to the token's log prob, assign RGB color to the token. 69 | reference: https://twitter.com/thesephist/status/1617909119423500288 70 | 71 | Args: 72 | log_prob (float): log prob of the token. 73 | min_log_prob (float): min log prob of all tokens. 74 | max_log_prob (float): max log prob of all tokens. 75 | hue_red (int, optional): hue value of the red color. Defaults to 0. 76 | hue_green (int, optional): hue value of the green color. Defaults to 150. 77 | epsilon (float, optional): avoid divide by zero. Defaults to 1e-5. 78 | """ 79 | # clamp the log_prob and scale to (hsl_red, hsl_green) 80 | if min_log_prob == max_log_prob: 81 | ratio = 1 # set to green color 82 | else: 83 | log_prob = max(min_log_prob, min(log_prob, max_log_prob)) 84 | ratio = (log_prob - min_log_prob) / max((max_log_prob - min_log_prob), epsilon) 85 | hue = ratio * (hue_green - hue_red) + hue_red 86 | red, green, blue = colorsys.hls_to_rgb(hue / 360.0, 0.85, 0.6) # hls({hue}deg 85% 60%) 87 | rgb_string = f"rgb({int(red * 255)}, {int(green * 255)}, {int(blue * 255)})" 88 | return rgb_string 89 | 90 | 91 | def set_tokens_ppl(tokens: List[Token]): 92 | """Set ppl value for each token in the list of tokens.""" 93 | logprob_sum = itertools.accumulate([x.logprob for x in tokens], operator.add) 94 | for num_tokens, (token, logprob) in enumerate(zip(tokens, logprob_sum), 1): 95 | token.ppl = math.exp(-logprob / num_tokens) 96 | 97 | 98 | def single_token_html(token: Token) -> str: 99 | """HTML text of single token.""" 100 | template = '
<div class="ppl-hud-row"><span class="ppl-hud-label">{label}</span><span>{value}</span></div>
' # noqa 101 | 102 | html_text = template.format(label="prob", value=f"{token.prob:.3%}") 103 | html_text += template.format(label="logprob", value=f"{token.logprob:.4f}") 104 | if token.ppl is not None: 105 | html_text += template.format(label="ppl", value=f"{token.ppl:.4f}") 106 | html_text += candidate_tokens_html(token.top_candidates) 107 | html_text = f"
<div class='ppl-hud'>{html_text}</div>" 108 | return html_text 109 | 110 | 111 | def tokens_min_max_logprob(tokens: List[Token]) -> Tuple[float, float]: 112 | """Calculate the normalized min and max logprob of a list of tokens.""" 113 | if len(tokens) == 1: 114 | min_logprob = max_logprob = tokens[0].logprob 115 | else: 116 | logprob_mean = statistics.mean([token.logprob for token in tokens]) 117 | logprob_stddev = statistics.stdev([token.logprob for token in tokens]) 118 | min_logprob = logprob_mean - (2.5 * logprob_stddev) 119 | max_logprob = logprob_mean + (2.5 * logprob_stddev) 120 | return min_logprob, max_logprob 121 | 122 | 123 | def tokens_info_to_html(tokens: List[Token], display_whitespace: bool = True) -> str: 124 | """ 125 | Generate html for a list of tokens, including token color and hover text. 126 | 127 | Args: 128 | tokens (List[Token]): a list of tokens to generate html for. 129 | """ 130 | min_logprob, max_logprob = tokens_min_max_logprob(tokens) 131 | set_tokens_ppl(tokens) 132 | 133 | tokens_html = "" 134 | for token in tokens: 135 | hover_html = single_token_html(token) 136 | rgb = color_token_by_logprob(token.logprob, min_logprob, max_logprob) 137 | is_newline = "\n" in token.text 138 | token_text = text_to_html_token(token.text, replace_whitespace=display_whitespace) 139 | if is_newline: 140 | token_text = f'<span class="ppl-pseudo-token">{token_text}</span>' 141 | token_html = f'<span class="ppl-token" style="background: {rgb}">{token_text}{hover_html}</span>' # noqa 142 | if is_newline: 143 | token_html += "<br>" 144 | tokens_html += token_html 145 | tokens_html = f'<div class="visual-tokens">{tokens_html}</div>
' # noqa 146 | return tokens_html 147 | -------------------------------------------------------------------------------- /token_visualizer/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | from typing import Union 5 | 6 | from dotenv import load_dotenv 7 | 8 | import token_visualizer 9 | 10 | __all__ = ["css_style", "ensure_os_env"] 11 | 12 | 13 | def css_style() -> str: 14 | with open(os.path.join(os.path.dirname(token_visualizer.__file__), "main.css")) as f: 15 | css_style = f.read() 16 | css_html = f"""" 17 | return css_html 18 | 19 | 20 | def ensure_os_env(env_name: str, default_value: Union[str, None] = None): 21 | if env_name in os.environ: 22 | env_value = os.getenv(env_name) 23 | else: 24 | load_dotenv() 25 | env_value = os.getenv(env_name, default_value) 26 | return env_value 27 | -------------------------------------------------------------------------------- /visual_tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | from functools import lru_cache 5 | from typing import Dict 6 | 7 | import gradio as gr 8 | from loguru import logger 9 | from transformers import AutoTokenizer 10 | 11 | from sentencepiece import SentencePieceProcessor 12 | 13 | CANDIDATES = [ # model name sorted by alphabet 14 | "baichuan-inc/Baichuan2-13B-Chat", 15 | "bigcode/starcoder2-15b", 16 | "deepseek-ai/deepseek-coder-33b-instruct", 17 | # "google/gemma-7b", 18 | "gpt2", 19 | # "meta-llama/Llama-2-7b-chat-hf", 20 | "mistralai/Mixtral-8x7B-Instruct-v0.1", 21 | "THUDM/chatglm3-6b", 22 | ] 23 | SENTENCE_PIECE_MAPPING = {} 24 | SP_PREFIX = "SentencePiece/" 25 | 26 | 27 | def add_sp_tokenizer(name: str, tokenizer_path: str): 28 | """Add a sentence piece tokenizer to the list of available tokenizers.""" 29 | model_key = SP_PREFIX + name 30 | if not os.path.exists(tokenizer_path): 31 | raise FileNotFoundError(f"Tokenizer file not found: {tokenizer_path}") 32 | SENTENCE_PIECE_MAPPING[model_key] = tokenizer_path 33 | CANDIDATES.append(model_key) 34 | 35 | 36 | # add_sp_tokenizer("LLaMa", "llama_tokenizer.model") 37 | logger.info(f"SentencePiece tokenizer: {list(SENTENCE_PIECE_MAPPING.keys())}") 38 | 39 | 40 | @lru_cache 41 | def get_tokenizer_and_vocab(name): 42 | if name.startswith(SP_PREFIX): 43 | local_file_path = SENTENCE_PIECE_MAPPING[name] 44 | tokenizer = SentencePieceProcessor(local_file_path) 45 | rev_vocab = {id_: tokenizer.id_to_piece(id_) for id_ in range(tokenizer.get_piece_size())} # noqa 46 | else: 47 | tokenizer = AutoTokenizer.from_pretrained(name, trust_remote_code=True) 48 | rev_vocab = {v: k for k, v in tokenizer.get_vocab().items()} 49 | return tokenizer, rev_vocab 50 | 51 | 52 | def tokenize(name: str, text: str) -> Dict: 53 | tokenizer, rev_vocab = get_tokenizer_and_vocab(name) 54 | 55 | ids = tokenizer.encode(text) 56 | s, entities = '', [] 57 | for i in ids: 58 | entity = str(i) 59 | start = len(s) 60 | s += rev_vocab[i] 61 | end = len(s) 62 | entities.append({"entity": entity, "start": start, "end": end}) 63 | 64 | return { 65 | "text": s + f"\n({len(ids)} tokens / {len(text)} characters)", 66 | "entities": entities 67 | } 68 | 69 | 70 | @logger.catch(reraise=True) 71 | def make_demo(): 72 | logger.info("Creating Interface..") 73 | 74 | DEFAULT_TOKENIZER = CANDIDATES[0] 75 | DEFAULT_INPUTTEXT = "Hello world." 
76 | 77 | demo = gr.Interface( 78 | fn=tokenize, 79 | inputs=[ 80 | gr.Dropdown( 81 | CANDIDATES, value=DEFAULT_TOKENIZER, 82 | label="Tokenizer", allow_custom_value=False 83 | ), 84 | gr.TextArea(value=DEFAULT_INPUTTEXT, label="Input text"), 85 | ], 86 | outputs=[ 87 | gr.HighlightedText( 88 | value=tokenize(DEFAULT_TOKENIZER, DEFAULT_INPUTTEXT), 89 | label="Tokenized results" 90 | ) 91 | ], 92 | title="Tokenzier Visualizer", 93 | description="If you want to try more tokenizers, please contact the author@wangfeng", # noqa 94 | examples=[ 95 | [DEFAULT_TOKENIZER, "乒乓球拍卖完了,无线电法国别研究,我一把把把把住了"], 96 | ["bigcode/starcoder2-15b", "def print():\n print('Hello')"], 97 | ], 98 | cache_examples=True, 99 | live=True, 100 | ) 101 | return demo 102 | 103 | 104 | if __name__ == "__main__": 105 | demo = make_demo() 106 | demo.launch(server_name="0.0.0.0") 107 | -------------------------------------------------------------------------------- /visualizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import functools 4 | from argparse import ArgumentParser 5 | from typing import Tuple, Optional 6 | 7 | import gradio as gr 8 | from loguru import logger 9 | 10 | import token_visualizer 11 | from token_visualizer import TopkTokenModel, css_style, ensure_os_env 12 | 13 | 14 | def make_parser() -> ArgumentParser: 15 | parser = ArgumentParser("Inference process visualizer") 16 | parser.add_argument( 17 | "-t", "--type", 18 | choices=["llm", "tgi", "oai", "oai-proxy"], 19 | default="oai-proxy", 20 | help="Type of model to use, default to openai-proxy" 21 | ) 22 | parser.add_argument( 23 | "--hf-repo", type=str, default=None, 24 | help="Huggingface model repository, used when type is 'llm'. Default to None" 25 | ) 26 | parser.add_argument( 27 | "--oai-model", type=str, default="gpt-4-turbo-2024-04-09", 28 | help="OpenAI model name, used when type is 'oai'/'oai-proxy'. " 29 | "Check https://platform.openai.com/docs/models for more details. " 30 | "Default to `gpt-4-turbo-2024-04-09`." 31 | ) 32 | parser.add_argument( 33 | "--oai-key", type=str, default=None, 34 | help="OpenAI api key, used when type is 'oai'/'oai-proxy'. " 35 | "If provided, will override OPENAI_KEY env variable.", 36 | ) 37 | parser.add_argument( 38 | "--tgi-url", type=str, default=None, 39 | help="Service url of TGI model, used when type is 'tgi'. " 40 | "If provided, will override TGI_URL env variable.", 41 | ) 42 | parser.add_argument( 43 | "-s", "--share", action="store_true", 44 | help="Share service to the internet.", 45 | ) 46 | parser.add_argument( 47 | "-p", "--port", type=int, default=12123, 48 | help="Port to run the service, default to 12123." 
49 | ) 50 | return parser 51 | 52 | 53 | def build_model_by_args(args) -> token_visualizer.TopkTokenModel: 54 | BASE_URL = ensure_os_env("BASE_URL") 55 | OPENAI_API_KEY = ensure_os_env("OPENAI_KEY") 56 | TGI_URL = ensure_os_env("TGI_URL") 57 | 58 | model: Optional[token_visualizer.TopkTokenModel] = None 59 | 60 | if args.type == "llm": 61 | model = token_visualizer.TransformerModel(repo=args.hf_repo) 62 | elif args.type == "tgi": 63 | if args.tgi_url: 64 | TGI_URL = args.tgi_url 65 | model = token_visualizer.TGIModel(url=TGI_URL, details=True) 66 | elif args.type == "oai": 67 | model = token_visualizer.OpenAIModel( 68 | base_url=BASE_URL, 69 | api_key=OPENAI_API_KEY, 70 | model_name=args.oai_model, 71 | ) 72 | elif args.type == "oai-proxy": 73 | model = token_visualizer.OpenAIProxyModel( 74 | base_url=BASE_URL, 75 | api_key=OPENAI_API_KEY, 76 | model_name="gpt-4-turbo-2024-04-09", 77 | ) 78 | else: 79 | raise ValueError(f"Unknown model type {args.type}") 80 | 81 | return model 82 | 83 | 84 | @logger.catch(reraise=True) 85 | def text_analysis( 86 | text: str, 87 | display_whitespace: bool, 88 | do_sample: bool, 89 | temperature: float, 90 | max_tokens: int, 91 | repetition_penalty: float, 92 | num_beams: int, 93 | topk: int, 94 | topp: float, 95 | topk_per_token: int, 96 | model: TopkTokenModel, # model should be built in the interface 97 | ) -> Tuple[str, str]: 98 | model.display_whitespace = display_whitespace 99 | model.do_sample = do_sample 100 | model.temperature = temperature 101 | model.max_tokens = max_tokens 102 | model.repetition_penalty = repetition_penalty 103 | model.num_beams = num_beams 104 | model.topk = topk 105 | model.topp = topp 106 | model.topk_per_token = topk_per_token 107 | 108 | tokens = model.generate_topk_per_token(text) 109 | html = model.html_to_visualize(tokens) 110 | 111 | html += "
" 112 | if isinstance(model, token_visualizer.TGIModel) and model.num_prefill_tokens: 113 | html += f"
<div>input tokens: {model.num_prefill_tokens}</div>
" 114 | html += f"
<div>output tokens: {len(tokens)}</div>
" 115 | 116 | return model.generated_answer, html 117 | 118 | 119 | def build_inference_analysis_demo(args): 120 | model = build_model_by_args(args) 121 | inference_func = functools.partial(text_analysis, model=model) 122 | 123 | interface = gr.Interface( 124 | inference_func, 125 | [ 126 | gr.TextArea(placeholder="Please input text here"), 127 | gr.Checkbox(value=False, label="display whitespace in output"), 128 | gr.Checkbox(value=True, label="do_sample"), 129 | gr.Slider(minimum=0, maximum=1, step=0.05, value=1.0, label="temperature"), 130 | gr.Slider(minimum=1, maximum=4096, step=1, value=512, label="max tokens"), 131 | gr.Slider(minimum=1, maximum=2, step=0.1, value=1.0, label="repetition penalty"), 132 | gr.Slider(minimum=1, maximum=10, step=1, value=1, label="num beams"), 133 | gr.Slider(minimum=1, maximum=100, step=1, value=50, label="topk"), 134 | gr.Slider(minimum=0, maximum=1, step=0.05, value=1.0, label="topp"), 135 | gr.Slider(minimum=1, maximum=10, step=1, value=5, label="per-token topk"), 136 | ], 137 | [ 138 | gr.TextArea(label="LLM answer"), 139 | "html", 140 | ], 141 | examples=[ 142 | ["Who are Hannah Quinlivan's child?"], 143 | ["Write python code to read a file and print its content."], 144 | ], 145 | title="LLM inference analysis", 146 | ) 147 | return interface 148 | 149 | 150 | @logger.catch(reraise=True) 151 | def ppl_from_model( 152 | text: str, 153 | url: str, 154 | bos: str, 155 | eos: str, 156 | display_whitespace: bool, 157 | model, 158 | ) -> str: 159 | """Generate PPL visualization from model. 160 | 161 | Args: 162 | text (str): input text to visualize. 163 | url (str): tgi url. 164 | bos (str): begin of sentence token. 165 | eos (str): end of sentence token. 166 | display_whitespace (bool): whether to display whitespace for output text. 167 | If set to True, whitespace will be displayed as "␣". 168 | """ 169 | url = url.strip() 170 | assert url, f"Please provide url of your tgi model. Current url: {url}" 171 | logger.info(f"Set url to {url}") 172 | model.url = url 173 | model.display_whitespace = display_whitespace 174 | model.max_tokens = 1 175 | 176 | text = bos + text + eos 177 | tokens = model.generate_inputs_prob(text) 178 | html = model.html_to_visualize(tokens) 179 | 180 | # display tokens and ppl at the end 181 | html += "
" 182 | html += f"
<div>total tokens: {len(tokens)}</div>
" 183 | ppl = tokens[-1].ppl 184 | html += f"
<div>ppl: {ppl:.4f}</div>
" 185 | return html 186 | 187 | 188 | def build_ppl_visualizer_demo(args): 189 | model = build_model_by_args(args) 190 | ppl_func = functools.partial(ppl_from_model, model=model) 191 | 192 | ppl_interface = gr.Interface( 193 | ppl_func, 194 | [ 195 | gr.TextArea(placeholder="Please input text to visualize here"), 196 | gr.TextArea( 197 | placeholder="Please input tgi url here (Error if not provided)", 198 | lines=1, 199 | ), 200 | gr.TextArea(placeholder="BOS token, default to empty string", lines=1), 201 | gr.TextArea(placeholder="EOS token, default to empty string", lines=1), 202 | gr.Checkbox(value=False, label="display whitespace in output, default to False"), 203 | ], 204 | "html", 205 | title="PPL Visualizer", 206 | ) 207 | return ppl_interface 208 | 209 | 210 | def demo(): 211 | args = make_parser().parse_args() 212 | logger.info(f"Args: {args}") 213 | 214 | demo = gr.Blocks(css=css_style()) 215 | with demo: 216 | with gr.Tab("Inference"): 217 | build_inference_analysis_demo(args) 218 | with gr.Tab("PPL"): 219 | build_ppl_visualizer_demo(args) 220 | 221 | demo.launch( 222 | server_name="0.0.0.0", 223 | share=args.share, 224 | server_port=args.port, 225 | ) 226 | 227 | 228 | if __name__ == "__main__": 229 | demo() 230 | --------------------------------------------------------------------------------