├── moe_visualizer
│   ├── __init__.py
│   └── plot_histogram.py
├── requirements.txt
├── images
│   └── demo.jpg
├── LICENSE
├── README.md
├── .gitignore
└── qwen1_5_moe.py

/moe_visualizer/__init__.py:
--------------------------------------------------------------------------------
# How to visualize the expert outputs of an MoE model during inference
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
gradio
torch
transformers
accelerate
plotly
pandas

--------------------------------------------------------------------------------
/images/demo.jpg:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/ChenZiHong-Gavin/MoE-Visualizer/HEAD/images/demo.jpg
--------------------------------------------------------------------------------
/moe_visualizer/plot_histogram.py:
--------------------------------------------------------------------------------
import plotly.express as px
import pandas as pd

def plot_histogram(expert_counts: dict):
    if not isinstance(expert_counts, dict):
        raise ValueError("expert_counts must be a dictionary")
    data = []
    for layer_idx, counts in expert_counts.items():
        for expert in counts:
            data.append({
                "layer_idx": layer_idx,
                "expert": expert,
                "count": counts[expert]
            })
    df = pd.DataFrame(data)

    max_layer_idx = int(df["layer_idx"].max())
    max_expert = int(df["expert"].max())

    fig = px.density_heatmap(
        df,
        x="expert",
        y="layer_idx",
        z="count",
        nbinsx=max_expert + 1,
        nbinsy=max_layer_idx + 1,
        histfunc="sum",
    )
    return fig

--------------------------------------------------------------------------------
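`plot_histogram` expects a nested mapping of the form `{layer_idx: {expert_idx: activation_count}}` and returns a Plotly figure. A minimal usage sketch (the counts below are made up, and it assumes the repository root is on the Python path so the package can be imported):

```python
from moe_visualizer.plot_histogram import plot_histogram

# Hypothetical counts, only to illustrate the expected input structure.
expert_counts = {
    0: {0: 12, 3: 7, 5: 1},   # layer 0: expert index -> number of activations
    1: {2: 9, 3: 4},          # layer 1
}

fig = plot_histogram(expert_counts)
fig.show()                               # open the heatmap in a browser
# fig.write_html("expert_usage.html")    # or save it as a standalone HTML file
```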
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2025 chenzihong

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# MoE-Visualizer

![demo](images/demo.jpg)

## Introduction

This project is a visualizer for Mixture of Experts (MoE) models. We aim to provide a visual tool that helps users understand how the experts in an MoE model are used.

We provide a forward hook that can be mounted on the MoE layers of the model. During inference it records which experts are selected for each token, so the usage of every expert can be counted and visualized.

The result is a plug-and-play module that can be used with any MoE model; [Qwen1.5-MoE-A2.7B](https://huggingface.co/Qwen/Qwen1.5-MoE-A2.7B) is provided as an example.

## What we have done
- [x] Visualize the usage of experts in the prefill and generation phases
- [x] Support batch processing
- [x] Support downloading the collected data

## Models we support
- [x] [Qwen1.5-MoE-A2.7B](https://huggingface.co/Qwen/Qwen1.5-MoE-A2.7B)

## How to use

### Step 1: Install the dependencies
```bash
pip install -r requirements.txt
```

### Step 2: Run the demo
```bash
python qwen1_5_moe.py
```
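You can also collect and plot expert usage from your own script, without the Gradio demo. The sketch below mirrors what `qwen1_5_moe.py` does internally; the model name, prompt, and token budget are only example values:

```python
# Minimal sketch of the hook-based expert counting (example values throughout).
from collections import defaultdict

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

from moe_visualizer.plot_histogram import plot_histogram

model_name = "Qwen/Qwen1.5-MoE-A2.7B-Chat"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype="auto", device_map="auto"
).eval()

expert_counts = defaultdict(lambda: defaultdict(int))  # {layer_idx: {expert_idx: count}}

def make_hook(layer_idx):
    def hook(module, inputs, output):
        hidden = inputs[0].reshape(-1, inputs[0].shape[-1])       # (tokens, hidden_dim)
        weights = F.softmax(module.gate(hidden), dim=1, dtype=torch.float)
        _, selected = torch.topk(weights, module.top_k, dim=-1)   # (tokens, top_k)
        for expert in selected.flatten().tolist():
            expert_counts[layer_idx][expert] += 1
    return hook

# Attach one hook per MoE block, i.e. every decoder layer whose MLP has a router ("gate").
handles = [
    layer.mlp.register_forward_hook(make_hook(i))
    for i, layer in enumerate(model.model.layers)
    if hasattr(layer.mlp, "gate")
]

inputs = tokenizer("Why is the sky blue?", return_tensors="pt").to(model.device)
with torch.no_grad():
    model.generate(**inputs, max_new_tokens=32)

plot_histogram({k: dict(v) for k, v in expert_counts.items()}).show()

for handle in handles:  # detach the hooks when you are done
    handle.remove()
```

The bundled demo additionally splits these counts into prefill and generation phases and exposes them, together with a JSON download, in a web UI.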

If this project helps you, please give us a star. 🌟

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

# Ruff stuff:
.ruff_cache/

# PyPI configuration file
.pypirc

--------------------------------------------------------------------------------
/qwen1_5_moe.py:
--------------------------------------------------------------------------------
# Adapted from https://github.com/QwenLM/Qwen2.5/blob/main/examples/demo/web_demo.py

from argparse import ArgumentParser
from threading import Thread

import gradio as gr
import torch
import json
import tempfile
import torch.nn.functional as F
from collections import defaultdict
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

from moe_visualizer.plot_histogram import plot_histogram

DEFAULT_CKPT_PATH = "Qwen/Qwen1.5-MoE-A2.7B-Chat"

def _get_args():
    parser = ArgumentParser(description="Qwen1.5-MoE Visualizer Demo")
    parser.add_argument(
        "-c",
        "--checkpoint-path",
        type=str,
        default=DEFAULT_CKPT_PATH,
        help="Checkpoint name or path, default to %(default)r",
    )
    parser.add_argument(
        "--cpu-only", action="store_true", help="Run demo with CPU only"
    )
    parser.add_argument(
        "--share",
        action="store_true",
        default=False,
        help="Create a publicly shareable link for the interface.",
    )
    parser.add_argument(
        "--inbrowser",
        action="store_true",
        default=False,
        help="Automatically launch the interface in a new tab on the default browser.",
    )
    parser.add_argument(
        "--server-port", type=int, default=8000, help="Demo server port."
    )
    parser.add_argument(
        "--server-name", type=str, default="127.0.0.1", help="Demo server name."
    )

    args = parser.parse_args()
    return args


class ExpertActivationTracker:
    def __init__(self):
        self.activations = defaultdict(list)

    def add_activation(self, layer_idx, activation):
        self.activations[layer_idx].append(activation)

    def clear(self):
        self.activations.clear()


expert_tracker = ExpertActivationTracker()
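# Illustration of what the tracker accumulates (hypothetical sizes): for a prompt of
# 12 tokens, the prefill forward pass appends one array of shape (12, top_k) per MoE
# layer to expert_tracker.activations[layer_idx]; each subsequently generated token
# (batch size 1 in the demo) appends an array of shape (1, top_k). The function
# count_expert_activations() below relies on exactly this difference (arr.shape[0] > 1)
# to separate prefill counts from generation counts.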

@torch.no_grad()
def moe_activation_hook_factory(layer_idx: int):
    def hook_fn(module, input, output):
        hidden_states = input[0]
        _, _, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)

        router_logits = module.gate(hidden_states)
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        _, selected_experts = torch.topk(routing_weights, module.top_k, dim=-1)  # (batch_size * seq_len, top_k)

        expert_tracker.add_activation(layer_idx, selected_experts.cpu().numpy())

    return hook_fn


def _register_hooks(model):
    hooks = []
    for layer_idx, layer in enumerate(model.model.layers):
        if hasattr(layer.mlp, "gate"):
            hook = layer.mlp.register_forward_hook(moe_activation_hook_factory(layer_idx))
            hooks.append(hook)
    return hooks


def count_expert_activations():
    prefill_expert_counts = defaultdict(lambda: defaultdict(int))
    generate_expert_counts = defaultdict(lambda: defaultdict(int))

    for layer_idx, activations in expert_tracker.activations.items():
        for i, arr in enumerate(activations):
            if arr.shape[0] > 1:
                counts = prefill_expert_counts[layer_idx]
            else:
                counts = generate_expert_counts[layer_idx]
            for token in arr.flatten():
                counts[int(token)] += 1
    return prefill_expert_counts, generate_expert_counts


def prepare_data():
    prefill_expert_counts, generate_expert_counts = count_expert_activations()

    data = {
        "prefill_expert_counts": prefill_expert_counts,
        "generate_expert_counts": generate_expert_counts
    }

    with tempfile.NamedTemporaryFile(mode="w", delete=False) as f:
        json.dump(data, f)
    return f.name


def generate_plots():
    prefill_expert_counts, generate_expert_counts = count_expert_activations()

    prefill_fig = plot_histogram(prefill_expert_counts)
    generate_fig = plot_histogram(generate_expert_counts)
    return prefill_fig, generate_fig


def _load_model_tokenizer(args):
    tokenizer = AutoTokenizer.from_pretrained(
        args.checkpoint_path,
        resume_download=True,
    )

    if args.cpu_only:
        device_map = "cpu"
    else:
        device_map = "auto"

    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint_path,
        torch_dtype="auto",
        device_map=device_map,
        resume_download=True,
        trust_remote_code=True
    ).eval()

    hooks = _register_hooks(model)
    model.hooks = hooks

    model.generation_config.max_new_tokens = 2048

    return model, tokenizer
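# Note: the handles stored in model.hooks can later be detached with handle.remove()
# if inference without tracking is needed; the demo keeps them attached for its whole
# lifetime.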
"user", "content": query}) 162 | input_text = tokenizer.apply_chat_template( 163 | conversation, 164 | add_generation_prompt=True, 165 | tokenize=False, 166 | ) 167 | 168 | inputs = tokenizer([input_text], return_tensors="pt").to(model.device) 169 | 170 | streamer = TextIteratorStreamer( 171 | tokenizer=tokenizer, skip_prompt=True, timeout=60.0, skip_special_tokens=True 172 | ) 173 | generation_kwargs = { 174 | **inputs, 175 | "streamer": streamer, 176 | } 177 | thread = Thread(target=model.generate, kwargs=generation_kwargs) 178 | thread.start() 179 | 180 | for new_text in streamer: 181 | yield new_text 182 | 183 | 184 | def _gc(): 185 | import gc 186 | 187 | gc.collect() 188 | if torch.cuda.is_available(): 189 | torch.cuda.empty_cache() 190 | 191 | 192 | def _launch_demo(args, model, tokenizer): 193 | def predict(_query, _chatbot): 194 | reset_state(_chatbot) 195 | 196 | print(f"User: {_query}") 197 | _chatbot.append({"role": "user", "content": _query}) 198 | _chatbot.append({"role": "assistant", "content": ""}) 199 | full_response = "" 200 | response = "" 201 | for new_text in _chat_stream(model, tokenizer, _query, history=[]): 202 | response += new_text 203 | _chatbot[-1] = {"role": "assistant", "content": response} 204 | yield _chatbot 205 | full_response = response 206 | 207 | print(f"Qwen: {full_response}") 208 | 209 | def process_batch(batch_file, progress=gr.Progress()): 210 | if not batch_file: 211 | raise gr.Error("No file uploaded") 212 | 213 | try: 214 | questions = [] 215 | with open(batch_file, "r") as f: 216 | data = json.load(f) 217 | for item in data: 218 | if "question" in item: 219 | questions.append(item["question"]) 220 | except Exception as e: 221 | raise gr.Error(f"Failed to parse file: {e}") 222 | 223 | if not questions: 224 | raise gr.Error("No question found in the file") 225 | 226 | progress(0, "Processing...") 227 | for i, question in enumerate(tqdm(questions)): 228 | for _ in _chat_stream(model, tokenizer, question, history=[]): 229 | pass 230 | progress((i + 1) / len(questions)) 231 | 232 | def reset_user_input(): 233 | return gr.update(value="") 234 | 235 | def reset_state(_chatbot): 236 | _chatbot.clear() 237 | expert_tracker.clear() 238 | _gc() 239 | return _chatbot 240 | 241 | with gr.Blocks() as demo: 242 | gr.HTML("