├── figures ├── TextBench.png ├── MiniMaxLogo-Dark.png └── MiniMaxLogo-Light.png ├── MiniMax_M1_tech_report.pdf ├── .github └── ISSUE_TEMPLATE │ ├── Feature request.yml │ ├── Model Inquiry.yml │ ├── Bug Report for MCP.yml │ └── Bad case about the model.yml ├── tokenizer_config.json ├── config.json ├── docs ├── transformers_deployment_guide_cn.md ├── transformers_deployment_guide.md ├── transformers_deployment_guide_pt-br.md ├── vllm_deployment_guide_cn.md ├── vllm_deployment_guide.md ├── vllm_deployment_guide_pt-br.md ├── function_call_guide_cn.md ├── function_call_guide.md └── function_call_guide_pt-br.md ├── main.py ├── configuration_minimax_m1.py ├── LICENSE ├── README.md └── modeling_minimax_m1.py /figures/TextBench.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MiniMax-AI/MiniMax-M1/HEAD/figures/TextBench.png -------------------------------------------------------------------------------- /MiniMax_M1_tech_report.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MiniMax-AI/MiniMax-M1/HEAD/MiniMax_M1_tech_report.pdf -------------------------------------------------------------------------------- /figures/MiniMaxLogo-Dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MiniMax-AI/MiniMax-M1/HEAD/figures/MiniMaxLogo-Dark.png -------------------------------------------------------------------------------- /figures/MiniMaxLogo-Light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MiniMax-AI/MiniMax-M1/HEAD/figures/MiniMaxLogo-Light.png -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/Feature request.yml: -------------------------------------------------------------------------------- 1 | name: Feature Request 2 | description: Propose a new feature or enhancement for the project. 3 | title: "[request]: " 4 | labels: ["enhancement", "feature-request", "triage"] 5 | body: 6 | - type: markdown 7 | attributes: 8 | value: | 9 | Thank you for suggesting a new feature! Please provide the following details to help us understand your proposal. 10 | 11 | - type: input 12 | id: feature-about 13 | attributes: 14 | label: Basic Information - Feature about 15 | description: "Please briefly describe the feature, including the type of use and the framework, e.g., support Minimax-M1 in Ollama." 16 | placeholder: "e.g., support Minimax-M1 in Ollama." 17 | validations: 18 | required: true 19 | 20 | - type: textarea 21 | id: proposal 22 | attributes: 23 | label: Proposal 24 | description: | 25 | Please describe the feature you have requested and the rationale behind it. 26 | The following template is recommended. Feel free to modify it as you needed. 27 | value: | 28 | #### Introduction 29 | I would like that ... 30 | 31 | #### Rational 32 | Implementation of this feature will help the following usecase: 33 | - ... 34 | - ... 35 | 36 | #### Anything else 37 | I find ... has this feature and xxx can serve as a reference for implementation. 38 | validations: 39 | required: true 40 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/Model Inquiry.yml: -------------------------------------------------------------------------------- 1 | name: Model Inquiry 2 | description: Ask a question about the open source models. 
3 | title: "[Inquiry]: " 4 | labels: ["question", "triage"] 5 | body: 6 | - type: markdown 7 | attributes: 8 | value: | 9 | Thank you for reaching out! Please provide the following details to help us understand and address your inquiry about models. 10 | 11 | - type: input 12 | attributes: 13 | label: Basic Information - Models Used 14 | description: | 15 | Please list the model used, e.g., MiniMax-M1, speech-02-hd, etc. 16 | Our models can be referred at [HuggingFace](https://huggingface.co/MiniMaxAI) or [the official site](https://www.minimax.io/platform_overview). 17 | placeholder: "ex: MiniMax-M1" 18 | validations: 19 | required: true 20 | 21 | - type: checkboxes 22 | id: problem-validation 23 | attributes: 24 | label: Is this information known and solvable? 25 | options: 26 | - label: "I have checked [Minimax documentation](https://www.minimax.io/platform_overview) and found no solution." 27 | required: true 28 | - label: "I have searched existing issues and found no duplicates." 29 | required: true 30 | 31 | 32 | - type: textarea 33 | id: detailed-description 34 | attributes: 35 | label: Description 36 | description: "Please describe your question in detail here. If available, please paste relevant screenshots directly into this box." 37 | placeholder: | 38 | - Your detailed question or issue description. 39 | - Relevant context or background information. 40 | - (Paste screenshots directly below this text) 41 | validations: 42 | required: true 43 | -------------------------------------------------------------------------------- /tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_prefix_space": false, 3 | "bos_token": "", 4 | "clean_up_tokenization_spaces": false, 5 | "eos_token": "", 6 | "model_max_length": 40960000, 7 | "tokenizer_class": "GPT2Tokenizer", 8 | "unk_token": "", 9 | "chat_template": "{{ '' -}}{% set ns = namespace(system_prompt='') -%}{% for message in messages -%}{% if message['role'] == 'system' -%}{% set ns.system_prompt = ns.system_prompt + message['content'][0]['text'] -%}{% endif -%}{%- endfor -%}{% if ns.system_prompt != '' -%}{{ 'system ai_setting=assistant\n' + ns.system_prompt + '\n' -}}{%- endif -%}{% if tools -%}{{ 'system tool_setting=tools\nYou are provided with these tools:\n\n' -}}{% for tool in tools -%}{{ tool | tojson ~ '\n' -}}{%- endfor -%}{{ '\n\nIf you need to call tools, please respond with XML tags, and provide tool-name and json-object of arguments, following the format below:\n\n{''name'': , ''arguments'': }\n...\n\n' -}}{%- endif -%}{% for message in messages -%}{% if message['role'] == 'user' -%}{{ 'user name=user\n' + message['content'][0]['text'] + '\n' -}}{% elif message['role'] == 'assistant' -%}{{ 'ai name=assistant\n' -}}{% for content in message['content'] | selectattr('type', 'equalto', 'text') -%}{{ content['text'] -}}{%- endfor -%}{{ '\n' -}}{% elif message['role'] == 'tool' -%}{{ 'tool name=tools\n' }} {%- for content in message['content'] -%}{{- 'tool name: ' + content['name'] + '\n' + 'tool result: ' + content['text'] + '\n\n' -}} {%- endfor -%}{{- '\n' -}}{% endif -%}{%- endfor -%}{% if add_generation_prompt -%}{{ 'ai name=assistant\n' -}}{%- endif -%}" 10 | } 11 | -------------------------------------------------------------------------------- /config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "MiniMaxM1ForCausalLM" 4 | ], 5 | "attention_dropout": 0.0, 6 | "attn_type_list": [ 7 | 0, 
8 | 0, 9 | 0, 10 | 0, 11 | 0, 12 | 0, 13 | 0, 14 | 1, 15 | 0, 16 | 0, 17 | 0, 18 | 0, 19 | 0, 20 | 0, 21 | 0, 22 | 1, 23 | 0, 24 | 0, 25 | 0, 26 | 0, 27 | 0, 28 | 0, 29 | 0, 30 | 1, 31 | 0, 32 | 0, 33 | 0, 34 | 0, 35 | 0, 36 | 0, 37 | 0, 38 | 1, 39 | 0, 40 | 0, 41 | 0, 42 | 0, 43 | 0, 44 | 0, 45 | 0, 46 | 1, 47 | 0, 48 | 0, 49 | 0, 50 | 0, 51 | 0, 52 | 0, 53 | 0, 54 | 1, 55 | 0, 56 | 0, 57 | 0, 58 | 0, 59 | 0, 60 | 0, 61 | 0, 62 | 1, 63 | 0, 64 | 0, 65 | 0, 66 | 0, 67 | 0, 68 | 0, 69 | 0, 70 | 1, 71 | 0, 72 | 0, 73 | 0, 74 | 0, 75 | 0, 76 | 0, 77 | 0, 78 | 1, 79 | 0, 80 | 0, 81 | 0, 82 | 0, 83 | 0, 84 | 0, 85 | 0, 86 | 1 87 | ], 88 | "auto_map": { 89 | "AutoConfig": "configuration_minimax_m1.MiniMaxM1Config", 90 | "AutoModelForCausalLM": "modeling_minimax_m1.MiniMaxM1ForCausalLM" 91 | }, 92 | "bos_token_id": null, 93 | "eos_token_id": null, 94 | "head_dim": 128, 95 | "hidden_act": "silu", 96 | "hidden_size": 6144, 97 | "initializer_range": 0.02, 98 | "intermediate_size": 9216, 99 | "layernorm_full_attention_alpha": 3.5565588200778455, 100 | "layernorm_full_attention_beta": 1.0, 101 | "layernorm_linear_attention_alpha": 3.5565588200778455, 102 | "layernorm_linear_attention_beta": 1.0, 103 | "layernorm_mlp_alpha": 3.5565588200778455, 104 | "layernorm_mlp_beta": 1.0, 105 | "max_position_embeddings": 10240000, 106 | "model_type": "minimax_m1", 107 | "num_attention_heads": 64, 108 | "num_experts_per_tok": 2, 109 | "num_hidden_layers": 80, 110 | "num_key_value_heads": 8, 111 | "num_local_experts": 32, 112 | "output_router_logits": false, 113 | "postnorm": true, 114 | "rms_norm_eps": 1e-05, 115 | "rope_theta": 10000000, 116 | "rotary_dim": 64, 117 | "router_aux_loss_coef": 0.001, 118 | "router_jitter_noise": 0.0, 119 | "shared_intermediate_size": 0, 120 | "shared_moe_mode": "sigmoid", 121 | "sliding_window": null, 122 | "tie_word_embeddings": false, 123 | "transformers_version": "4.45.2", 124 | "use_cache": true, 125 | "vocab_size": 200064 126 | } 127 | 128 | -------------------------------------------------------------------------------- /docs/transformers_deployment_guide_cn.md: -------------------------------------------------------------------------------- 1 | # 🚀 MiniMax 模型 Transformers 部署指南 2 | 3 | ## 📖 简介 4 | 5 | 本指南将帮助您使用 [Transformers](https://huggingface.co/docs/transformers/index) 库部署 MiniMax-M1 模型。Transformers 是一个广泛使用的深度学习库,提供了丰富的预训练模型和灵活的模型操作接口。 6 | 7 | ## 🛠️ 环境准备 8 | 9 | ### 安装 Transformers 10 | 11 | ```bash 12 | pip install transformers torch accelerate 13 | ``` 14 | 15 | ## 📋 基本使用示例 16 | 17 | 预训练模型可以按照以下方式使用: 18 | 19 | ```python 20 | from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig 21 | 22 | MODEL_PATH = "{MODEL_PATH}" 23 | model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto", trust_remote_code=True) 24 | tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) 25 | 26 | messages = [ 27 | {"role": "user", "content": [{"type": "text", "text": "What is your favourite condiment?"}]}, 28 | {"role": "assistant", "content": [{"type": "text", "text": "Well, I'm quite partial to a good squeeze of fresh lemon juice. 
It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"}]}, 29 | {"role": "user", "content": [{"type": "text", "text": "Do you have mayonnaise recipes?"}]} 30 | ] 31 | 32 | text = tokenizer.apply_chat_template( 33 | messages, 34 | tokenize=False, 35 | add_generation_prompt=True 36 | ) 37 | 38 | model_inputs = tokenizer(text, return_tensors="pt").to(model.device) 39 | 40 | generation_config = GenerationConfig( 41 | max_new_tokens=20, 42 | eos_token_id=tokenizer.eos_token_id, 43 | use_cache=True, 44 | ) 45 | 46 | generated_ids = model.generate(**model_inputs, generation_config=generation_config) 47 | 48 | generated_ids = [ 49 | output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) 50 | ] 51 | 52 | response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] 53 | print(response) 54 | ``` 55 | 56 | ## ⚡ 性能优化 57 | 58 | ### 使用 Flash Attention 加速 59 | 60 | 上面的代码片段展示了不使用任何优化技巧的推理过程。但通过利用 [Flash Attention](../perf_train_gpu_one#flash-attention-2),可以大幅加速模型,因为它提供了模型内部使用的注意力机制的更快实现。 61 | 62 | 首先,确保安装最新版本的 Flash Attention 2: 63 | 64 | ```bash 65 | pip install -U flash-attn --no-build-isolation 66 | ``` 67 | 68 | 还要确保您拥有与 Flash-Attention 2 兼容的硬件。在[Flash Attention 官方仓库](https://github.com/Dao-AILab/flash-attention)的官方文档中了解更多信息。此外,请确保以半精度(例如 `torch.float16`)加载模型。 69 | 70 | 要使用 Flash Attention-2 加载和运行模型,请参考以下代码片段: 71 | 72 | ```python 73 | import torch 74 | from transformers import AutoModelForCausalLM, AutoTokenizer 75 | 76 | MODEL_PATH = "{MODEL_PATH}" 77 | model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, trust_remote_code=True, torch_dtype=torch.float16, attn_implementation="flash_attention_2", device_map="auto") 78 | tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) 79 | 80 | prompt = "My favourite condiment is" 81 | 82 | model_inputs = tokenizer([prompt], return_tensors="pt").to("cuda") 83 | generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True) 84 | response = tokenizer.batch_decode(generated_ids)[0] 85 | print(response) 86 | ``` 87 | 88 | ## 📮 获取支持 89 | 90 | 如果您在部署 MiniMax-M1 模型过程中遇到任何问题: 91 | - 请查看我们的官方文档 92 | - 通过官方渠道联系我们的技术支持团队 93 | - 在我们的 GitHub 仓库提交 Issue 94 | 95 | 我们会持续优化 Transformers 上的部署体验,欢迎您的反馈! 96 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/Bug Report for MCP.yml: -------------------------------------------------------------------------------- 1 | name: Bug Report for MCP&API 2 | description: Report a bug related to MCP and API tasks to help us reproduce and fix the problem. 3 | title: "[Bug for MCP&API]: " 4 | labels: ["bug", "triage"] 5 | body: 6 | - type: markdown 7 | attributes: 8 | value: | 9 | Thank you for contributing to our project by reporting a bug! To help us understand and resolve the issue as quickly as possible, please provide the following details. 10 | 11 | - type: input 12 | attributes: 13 | label: Basic Information - Models Used 14 | description: | 15 | Please list the model used, e.g., MiniMax-M1, speech-02-hd, etc. 16 | Our models can be referred at [HuggingFace](https://huggingface.co/MiniMaxAI) or [the official site](https://www.minimax.io/platform_overview). 
17 | placeholder: "ex: MiniMax-M1" 18 | validations: 19 | required: true 20 | 21 | - type: input 22 | id: scenario-description 23 | attributes: 24 | label: Basic Information - Scenario Description 25 | description: | 26 | Please briefly describe the scenario, including the framework or the platform, 27 | placeholder: "ex: Minimax-M1 cannot be called as MCP tools. " 28 | validations: 29 | required: false 30 | 31 | - type: checkboxes 32 | id: problem-validation 33 | attributes: 34 | label: Is this bug known and solvable? 35 | options: 36 | - label: "I have followed the GitHub READMEs for [`Minimax-MCP`](https://github.com/MiniMax-AI/MiniMax-MCP) and [`Minimax-MCP-JS`](https://github.com/MiniMax-AI/MiniMax-MCP-JS)." 37 | required: true 38 | - label: "I have checked the [official Minimax documentation](https://www.minimax.io/platform_overview) and [existing GitHub issues](https://github.com/MiniMax-AI/MiniMax-MCP/issues),but found no solution." 39 | required: true 40 | 41 | - type: textarea 42 | attributes: 43 | label: Information about environment 44 | description: | 45 | Please provide information about you environment, 46 | e.g., the software versions and the information on the OS, GPUs, python packages(from pip list) if available. 47 | placeholder: 48 | "For example: 49 | - OS: Ubuntu 24.04 50 | - Python: Python 3.11 51 | - PyTorch: 2.6.0+cu124" 52 | 53 | validations: 54 | required: true 55 | 56 | - type: input 57 | id: trace-id 58 | attributes: 59 | label: Trace-ID in the request head 60 | description: "Please copy and paste the trace-ID of the problematic request." 61 | validations: 62 | required: true 63 | 64 | - type: textarea 65 | attributes: 66 | label: Description 67 | description: | 68 | Please **describe the bug** you have encountered when using the MCP tools or API, and **paste the screenshots** of the error or unexpected behaviour here. 69 | The following template is recommended. 70 | Feel free to modify as you needed. 71 | value: | 72 | #### Steps to reproduce 73 | 74 | This happens to Minimax_M1 and xxx. 75 | The bug can be reproduced with the following steps: 76 | 1. ... 77 | 2. ... 78 | 79 | The following example input & output can be used: 80 | ``` 81 | system: ... 82 | user: ... 83 | ... 84 | ``` 85 | 86 | #### Expected results 87 | 88 | The results are expected to be ... 89 | 90 | #### Actual behaviours 91 | 92 | The actual outputs are as follows: ... 93 | 94 | #### Error logs 95 | 96 | The error logs are as follows: ... 97 | 98 | ### The screenshots are as belows: 99 | validations: 100 | required: true 101 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/Bad case about the model.yml: -------------------------------------------------------------------------------- 1 | name: Bad Case Report of the model 2 | description: Report a bug related to the model to help us reproduce and fix the problem. 3 | title: "[BadCase about the model]: " 4 | 5 | body: 6 | - type: markdown 7 | attributes: 8 | value: | 9 | Thank you for contributing to our project by reporting a bad case! To help us understand and resolve the issue as quickly as possible, please provide the following details. 10 | 11 | - type: input 12 | id: models-used 13 | attributes: 14 | label: Basic Information - Models Used 15 | description: | 16 | Please list the model used, e.g., MiniMax-M1, speech-02-hd, etc. 
17 | (Note: You can refer to our models at [HuggingFace](https://huggingface.co/MiniMaxAI) or [the official site](https://www.minimax.io/platform_overview) for more details.) 18 | placeholder: "ex: MiniMax-M1" 19 | validations: 20 | required: true 21 | 22 | - type: input 23 | id: scenario-description 24 | attributes: 25 | label: Basic Information - Scenario Description 26 | description: | 27 | Please briefly describe the scenario, including the framework or the platform. 28 | placeholder: "ex: Minimax-M1 return the error related to xxx." 29 | validations: 30 | required: false 31 | 32 | - type: checkboxes 33 | id: problem-validation 34 | attributes: 35 | label: Is this badcase known and solvable? 36 | options: 37 | - label: "I have followed the [GitHub README](https://github.com/MiniMax-AI) of the model and found no duplicates in existing issues." 38 | required: true 39 | - label: "I have checked [Minimax documentation](https://www.minimax.io/platform_overview) and found no solution." 40 | required: true 41 | 42 | - type: textarea 43 | id: environment-info 44 | attributes: 45 | label: Information about environment 46 | description: | 47 | (Include software versions, OS, GPUs if applicable) 48 | placeholder: | 49 | For example: 50 | - OS: Ubuntu 24.04 51 | - Python: Python 3.11 52 | - PyTorch: 2.6.0+cu124 53 | validations: 54 | required: true 55 | 56 | - type: textarea 57 | id: call-execution-info # Consolidated field for call type and details 58 | attributes: 59 | label: Call & Execution Information 60 | description: | 61 | Please describe how you are interacting with the model and provide the relevant details in the box below: 62 | **Call Type**: (e.g., API Call, Deployment Call) 63 | **If API Call**: Please provide the `trace-ID` of the problematic request. 64 | **If Deployment Call**: Please provide the command used for deployment or inference. 65 | placeholder: | 66 | # Example for API Call: 67 | Call Type: API Call 68 | Trace-ID: abcdef1234567890 69 | 70 | # Example for Deployment Call: 71 | Call Type: Deployment Call 72 | Deployment Command: python run_inference.py --model my_model --config config.yaml 73 | validations: 74 | required: true 75 | 76 | - type: textarea 77 | id: description-of-bug 78 | attributes: 79 | label: Description 80 | description: | 81 | Please **describe the bad case** you have encountered and **paste the screenshots** if available. 82 | The following template is recommended (modify as needed): 83 | value: | 84 | ### Steps to reproduce 85 | The bug can be reproduced with the following steps: 86 | 1. ... 87 | 2. ... 88 | 89 | ### Expected behavior 90 | The results are expected to be: ... 91 | 92 | ### Actual behavior 93 | The actual outputs are as follows: ... 94 | 95 | ### Error logs 96 | The error logs are as follows: 97 | ``` 98 | # Paste the related screenshots here 99 | ``` 100 | validations: 101 | required: true 102 | -------------------------------------------------------------------------------- /docs/transformers_deployment_guide.md: -------------------------------------------------------------------------------- 1 | # 🚀 MiniMax Model Transformers Deployment Guide 2 | 3 | [Transformers中文版部署指南](./transformers_deployment_guide_cn.md) 4 | 5 | ## 📖 Introduction 6 | 7 | This guide will help you deploy the MiniMax-M1 model using the [Transformers](https://huggingface.co/docs/transformers/index) library. Transformers is a widely used deep learning library that provides a rich collection of pre-trained models and flexible model operation interfaces. 
8 | 9 | ## 🛠️ Environment Setup 10 | 11 | ### Installing Transformers 12 | 13 | ```bash 14 | pip install transformers torch accelerate 15 | ``` 16 | 17 | ## 📋 Basic Usage Example 18 | 19 | The pre-trained model can be used as follows: 20 | 21 | ```python 22 | from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig 23 | 24 | MODEL_PATH = "{MODEL_PATH}" 25 | model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto", trust_remote_code=True) 26 | tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) 27 | 28 | messages = [ 29 | {"role": "user", "content": [{"type": "text", "text": "What is your favourite condiment?"}]}, 30 | {"role": "assistant", "content": [{"type": "text", "text": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"}]}, 31 | {"role": "user", "content": [{"type": "text", "text": "Do you have mayonnaise recipes?"}]} 32 | ] 33 | 34 | text = tokenizer.apply_chat_template( 35 | messages, 36 | tokenize=False, 37 | add_generation_prompt=True 38 | ) 39 | 40 | model_inputs = tokenizer(text, return_tensors="pt").to(model.device) 41 | 42 | generation_config = GenerationConfig( 43 | max_new_tokens=20, 44 | eos_token_id=tokenizer.eos_token_id, 45 | use_cache=True, 46 | ) 47 | 48 | generated_ids = model.generate(**model_inputs, generation_config=generation_config) 49 | 50 | generated_ids = [ 51 | output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) 52 | ] 53 | 54 | response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] 55 | print(response) 56 | ``` 57 | 58 | ## ⚡ Performance Optimization 59 | 60 | ### Speeding up with Flash Attention 61 | 62 | The code snippet above showcases inference without any optimization tricks. However, one can drastically speed up the model by leveraging [Flash Attention](../perf_train_gpu_one#flash-attention-2), which is a faster implementation of the attention mechanism used inside the model. 63 | 64 | First, make sure to install the latest version of Flash Attention 2: 65 | 66 | ```bash 67 | pip install -U flash-attn --no-build-isolation 68 | ``` 69 | 70 | Also make sure that you have hardware that is compatible with Flash-Attention 2. Read more about it in the official documentation of the [Flash Attention repository](https://github.com/Dao-AILab/flash-attention). Additionally, ensure you load your model in half-precision (e.g. `torch.float16`). 
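
Before hard-coding `attn_implementation="flash_attention_2"`, it can help to confirm at runtime that both the `flash-attn` package and a compatible GPU are actually available. The helper below is a minimal sketch rather than part of this repository (the function name `pick_attn_implementation` is ours); it assumes Flash Attention 2 needs an Ampere-or-newer GPU (compute capability >= 8.0) and falls back to PyTorch's built-in SDPA attention otherwise:

```python
import importlib.util

import torch


def pick_attn_implementation() -> str:
    # Use Flash Attention 2 only when the package is installed and the GPU
    # is Ampere-class or newer; otherwise fall back to PyTorch SDPA.
    has_flash_attn = importlib.util.find_spec("flash_attn") is not None
    if has_flash_attn and torch.cuda.is_available():
        major, _minor = torch.cuda.get_device_capability(0)
        if major >= 8:
            return "flash_attention_2"
    return "sdpa"


# Pass the result to from_pretrained(..., attn_implementation=pick_attn_implementation())
```
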
71 | 72 | To load and run a model using Flash Attention-2, refer to the snippet below: 73 | 74 | ```python 75 | import torch 76 | from transformers import AutoModelForCausalLM, AutoTokenizer 77 | 78 | MODEL_PATH = "{MODEL_PATH}" 79 | model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, trust_remote_code=True, torch_dtype=torch.float16, attn_implementation="flash_attention_2", device_map="auto") 80 | tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) 81 | 82 | prompt = "My favourite condiment is" 83 | 84 | model_inputs = tokenizer([prompt], return_tensors="pt").to("cuda") 85 | generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True) 86 | response = tokenizer.batch_decode(generated_ids)[0] 87 | print(response) 88 | ``` 89 | 90 | ## 📮 Getting Support 91 | 92 | If you encounter any issues while deploying the MiniMax-M1 model: 93 | - Please check our official documentation 94 | - Contact our technical support team through official channels 95 | - Submit an Issue on our GitHub repository 96 | 97 | We continuously optimize the deployment experience on Transformers and welcome your feedback! 98 | -------------------------------------------------------------------------------- /docs/transformers_deployment_guide_pt-br.md: -------------------------------------------------------------------------------- 1 | # 🚀 Guia de Deploy do Modelo MiniMax com Transformers 2 | 3 | [Transformers中文版部署指南](./transformers_deployment_guide_cn.md) 4 | 5 | ## 📖 Introdução 6 | 7 | Este guia irá te ajudar a fazer o deploy do modelo MiniMax-M1 utilizando a biblioteca [Transformers](https://huggingface.co/docs/transformers/index). O Transformers é uma biblioteca de deep learning amplamente utilizada, que oferece uma vasta coleção de modelos pré-treinados e interfaces flexíveis para operação dos modelos. 8 | 9 | ## 🛠️ Configuração do Ambiente 10 | 11 | ### Instalando o Transformers 12 | 13 | ```bash 14 | pip install transformers torch accelerate 15 | ``` 16 | 17 | ## 📋 Exemplo de Uso Básico 18 | 19 | O modelo pré-treinado pode ser utilizado da seguinte maneira: 20 | 21 | ```python 22 | from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig 23 | 24 | MODEL_PATH = "{MODEL_PATH}" 25 | model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto", trust_remote_code=True) 26 | tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) 27 | 28 | messages = [ 29 | {"role": "user", "content": [{"type": "text", "text": "What is your favourite condiment?"}]}, 30 | {"role": "assistant", "content": [{"type": "text", "text": "Well, I'm quite partial to a good squeeze of fresh lemon juice. 
It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"}]}, 31 | {"role": "user", "content": [{"type": "text", "text": "Do you have mayonnaise recipes?"}]} 32 | ] 33 | 34 | text = tokenizer.apply_chat_template( 35 | messages, 36 | tokenize=False, 37 | add_generation_prompt=True 38 | ) 39 | 40 | model_inputs = tokenizer(text, return_tensors="pt").to(model.device) 41 | 42 | generation_config = GenerationConfig( 43 | max_new_tokens=20, 44 | eos_token_id=tokenizer.eos_token_id, 45 | use_cache=True, 46 | ) 47 | 48 | generated_ids = model.generate(**model_inputs, generation_config=generation_config) 49 | 50 | generated_ids = [ 51 | output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) 52 | ] 53 | 54 | response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] 55 | print(response) 56 | ``` 57 | 58 | ## ⚡ Otimização de Desempenho 59 | 60 | ### Acelerando com Flash Attention 61 | 62 | O exemplo acima mostra uma inferência sem nenhum tipo de otimização. No entanto, é possível acelerar significativamente o modelo utilizando [Flash Attention](../perf_train_gpu_one#flash-attention-2), que é uma implementação mais rápida do mecanismo de atenção usado no modelo. 63 | 64 | Primeiro, certifique-se de instalar a versão mais recente do Flash Attention 2: 65 | 66 | ```bash 67 | pip install -U flash-attn --no-build-isolation 68 | ``` 69 | 70 | Além disso, é necessário que seu hardware seja compatível com o Flash Attention 2. Consulte mais informações na documentação oficial do [repositório do Flash Attention](https://github.com/Dao-AILab/flash-attention). Também é recomendado carregar seu modelo em meia precisão (por exemplo, `torch.float16`). 71 | 72 | Para carregar e executar um modelo utilizando Flash Attention 2, utilize o exemplo abaixo: 73 | 74 | ```python 75 | import torch 76 | from transformers import AutoModelForCausalLM, AutoTokenizer 77 | 78 | MODEL_PATH = "{MODEL_PATH}" 79 | model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, trust_remote_code=True, torch_dtype=torch.float16, attn_implementation="flash_attention_2", device_map="auto") 80 | tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) 81 | 82 | prompt = "My favourite condiment is" 83 | 84 | model_inputs = tokenizer([prompt], return_tensors="pt").to("cuda") 85 | generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True) 86 | response = tokenizer.batch_decode(generated_ids)[0] 87 | print(response) 88 | ``` 89 | 90 | ## 📮 Suporte 91 | 92 | Se você encontrar qualquer problema durante o deploy do modelo MiniMax-M1: 93 | 94 | * Verifique nossa documentação oficial 95 | * Entre em contato com nossa equipe de suporte técnico pelos canais oficiais 96 | * Abra uma Issue no nosso repositório no GitHub 97 | 98 | Estamos continuamente otimizando a experiência de deploy no Transformers e valorizamos muito seu feedback! 
99 | -------------------------------------------------------------------------------- /docs/vllm_deployment_guide_cn.md: -------------------------------------------------------------------------------- 1 | # 🚀 MiniMax 模型 vLLM 部署指南 2 | 3 | ## 📖 简介 4 | 5 | 我们推荐使用 [vLLM](https://docs.vllm.ai/en/latest/) 来部署 [MiniMax-M1](https://huggingface.co/MiniMaxAI/MiniMax-M1-40k) 模型。经过我们的测试,vLLM 在部署这个模型时表现出色,具有以下特点: 6 | 7 | - 🔥 卓越的服务吞吐量性能 8 | - ⚡ 高效智能的内存管理机制 9 | - 📦 强大的批量请求处理能力 10 | - ⚙️ 深度优化的底层性能 11 | 12 | MiniMax-M1 模型可在单台配备8个H800或8个H20 GPU的服务器上高效运行。在硬件配置方面,搭载8个H800 GPU的服务器可处理长达200万token的上下文输入,而配备8个H20 GPU的服务器则能够支持高达500万token的超长上下文处理能力。 13 | 14 | ## 💾 获取 MiniMax 模型 15 | 16 | ### MiniMax-M1 模型获取 17 | 18 | 您可以从我们的官方 HuggingFace 仓库下载模型:[MiniMax-M1-40k](https://huggingface.co/MiniMaxAI/MiniMax-M1-40k)、[MiniMax-M1-80k](https://huggingface.co/MiniMaxAI/MiniMax-M1-80k) 19 | 20 | 下载命令: 21 | ``` 22 | pip install -U huggingface-hub 23 | huggingface-cli download MiniMaxAI/MiniMax-M1-40k 24 | # huggingface-cli download MiniMaxAI/MiniMax-M1-80k 25 | 26 | # 如果遇到网络问题,可以设置代理 27 | export HF_ENDPOINT=https://hf-mirror.com 28 | ``` 29 | 30 | 或者使用 git 下载: 31 | 32 | ```bash 33 | git lfs install 34 | git clone https://huggingface.co/MiniMaxAI/MiniMax-M1-40k 35 | git clone https://huggingface.co/MiniMaxAI/MiniMax-M1-80k 36 | ``` 37 | 38 | ⚠️ **重要提示**:请确保系统已安装 [Git LFS](https://git-lfs.github.com/),这对于完整下载模型权重文件是必需的。 39 | 40 | ## 🛠️ 部署方案 41 | 42 | ### 方案:使用 Docker 部署(推荐) 43 | 44 | 为确保部署环境的一致性和稳定性,我们推荐使用 Docker 进行部署。 45 | 46 | ⚠️ **版本要求**: 47 | - 基础要求:vLLM 版本必须 ≥ 0.9.2,以确保对 MiniMax-M1 模型的完整支持 48 | - 特殊说明:如果使用低于 0.9.2 的 vLLM 版本,会遇见无法支持该模型或者精度不正确的情况: 49 | - 详情见:[Fix minimax model cache & lm_head precision #19592](https://github.com/vllm-project/vllm/pull/19592) 50 | 51 | 1. 获取容器镜像: 52 | 53 | 目前 vLLM 官方还未推出v0.9.2版本 docker,我们以 v0.8.3 为例子进行手动编译 vLLM: 54 | ```bash 55 | docker pull vllm/vllm-openai:v0.8.3 56 | ``` 57 | 58 | 2. 运行容器: 59 | ```bash 60 | # 设置环境变量 61 | IMAGE=vllm/vllm-openai:v0.8.3 62 | MODEL_DIR=<模型存放路径> 63 | CODE_DIR=<代码路径> 64 | NAME=MiniMaxImage 65 | 66 | # Docker运行配置 67 | DOCKER_RUN_CMD="--network=host --privileged --ipc=host --ulimit memlock=-1 --shm-size=2gb --rm --gpus all --ulimit stack=67108864" 68 | 69 | # 启动容器 70 | sudo docker run -it \ 71 | -v $MODEL_DIR:$MODEL_DIR \ 72 | -v $CODE_DIR:$CODE_DIR \ 73 | --name $NAME \ 74 | $DOCKER_RUN_CMD \ 75 | $IMAGE /bin/bash 76 | 77 | # 编译 vLLM 78 | cd $CODE_DIR 79 | git clone https://github.com/vllm-project/vllm.git 80 | cd vllm 81 | pip install -e . 
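# 可选:编译完成后快速验证 vLLM 是否安装成功(示例命令)
python3 -c "import vllm; print(vllm.__version__)"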
82 | ``` 83 | 84 | 💡 如果您使用其他环境配置,请参考 [vLLM 安装指南](https://docs.vllm.ai/en/latest/getting_started/installation.html) 85 | 86 | ## 🚀 启动服务 87 | 88 | ### 启动 MiniMax-M1 服务 89 | 90 | ```bash 91 | export SAFETENSORS_FAST_GPU=1 92 | export VLLM_USE_V1=0 93 | python3 -m vllm.entrypoints.openai.api_server \ 94 | --model <模型存放路径> \ 95 | --tensor-parallel-size 8 \ 96 | --trust-remote-code \ 97 | --quantization experts_int8 \ 98 | --max_model_len 4096 \ 99 | --dtype bfloat16 100 | ``` 101 | 102 | ### API 调用示例 103 | 104 | ```bash 105 | curl http://localhost:8000/v1/chat/completions \ 106 | -H "Content-Type: application/json" \ 107 | -d '{ 108 | "model": "MiniMaxAI/MiniMax-M1", 109 | "messages": [ 110 | {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, 111 | {"role": "user", "content": [{"type": "text", "text": "Who won the world series in 2020?"}]} 112 | ] 113 | }' 114 | ``` 115 | 116 | ## ❗ 常见问题 117 | 118 | ### 模块加载问题 119 | 如果遇到以下错误: 120 | ``` 121 | import vllm._C # noqa 122 | ModuleNotFoundError: No module named 'vllm._C' 123 | ``` 124 | 125 | 或 126 | 127 | ``` 128 | 当前并不支持 MiniMax-M1 模型 129 | ``` 130 | 131 | 我们提供两种解决方案: 132 | 133 | #### 解决方案一:复制依赖文件 134 | ```bash 135 | cd <工作目录> 136 | git clone https://github.com/vllm-project/vllm.git 137 | cd vllm 138 | cp /usr/local/lib/python3.12/dist-packages/vllm/*.so vllm 139 | cp -r /usr/local/lib/python3.12/dist-packages/vllm/vllm_flash_attn/* vllm/vllm_flash_attn 140 | ``` 141 | 142 | #### 解决方案二:从源码安装 143 | ```bash 144 | cd <工作目录> 145 | git clone https://github.com/vllm-project/vllm.git 146 | 147 | cd vllm/ 148 | pip install -e . 149 | ``` 150 | 151 | ## 📮 获取支持 152 | 153 | 如果您在部署 MiniMax-M1 模型过程中遇到任何问题: 154 | - 请查看我们的官方文档 155 | - 通过官方渠道联系我们的技术支持团队 156 | - 在我们的 GitHub 仓库提交 [Issue](https://github.com/MiniMax-AI/MiniMax-M1/issues) 157 | 158 | 我们会持续优化模型的部署体验,欢迎您的反馈! 
159 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, QuantoConfig, GenerationConfig 2 | import torch 3 | import argparse 4 | 5 | """ 6 | usage: 7 | export SAFETENSORS_FAST_GPU=1 8 | python main.py --quant_type int8 --world_size 8 --model_id 9 | """ 10 | 11 | def generate_quanto_config(hf_config: AutoConfig, quant_type: str): 12 | QUANT_TYPE_MAP = { 13 | "default": None, 14 | "int8": QuantoConfig( 15 | weights="int8", 16 | modules_to_not_convert=[ 17 | "lm_head", 18 | "embed_tokens", 19 | ] + [f"model.layers.{i}.coefficient" for i in range(hf_config.num_hidden_layers)] 20 | + [f"model.layers.{i}.block_sparse_moe.gate" for i in range(hf_config.num_hidden_layers)] 21 | ), 22 | } 23 | return QUANT_TYPE_MAP[quant_type] 24 | 25 | 26 | def parse_args(): 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument("--quant_type", type=str, default="default", choices=["default", "int8"]) 29 | parser.add_argument("--model_id", type=str, required=True) 30 | parser.add_argument("--world_size", type=int, required=True) 31 | return parser.parse_args() 32 | 33 | 34 | def check_params(args, hf_config: AutoConfig): 35 | if args.quant_type == "int8": 36 | assert args.world_size >= 8, "int8 weight-only quantization requires at least 8 GPUs" 37 | 38 | assert hf_config.num_hidden_layers % args.world_size == 0, f"num_hidden_layers({hf_config.num_hidden_layers}) must be divisible by world_size({args.world_size})" 39 | 40 | 41 | @torch.no_grad() 42 | def main(): 43 | args = parse_args() 44 | print("\n=============== Argument ===============") 45 | for key in vars(args): 46 | print(f"{key}: {vars(args)[key]}") 47 | print("========================================") 48 | 49 | model_id = args.model_id 50 | 51 | hf_config = AutoConfig.from_pretrained(model_id, trust_remote_code=True) 52 | check_params(args, hf_config) 53 | quantization_config = generate_quanto_config(hf_config, args.quant_type) 54 | 55 | device_map = { 56 | 'model.embed_tokens': 'cuda:0', 57 | 'model.norm': f'cuda:{args.world_size - 1}', 58 | 'lm_head': f'cuda:{args.world_size - 1}' 59 | } 60 | layers_per_device = hf_config.num_hidden_layers // args.world_size 61 | for i in range(args.world_size): 62 | for j in range(layers_per_device): 63 | device_map[f'model.layers.{i * layers_per_device + j}'] = f'cuda:{i}' 64 | 65 | tokenizer = AutoTokenizer.from_pretrained(model_id) 66 | message = [ 67 | {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, 68 | {"role": "user", "content": [{"type": "text", "text": "Hello, what is the weather today?"}]} 69 | ] 70 | tools = [ 71 | {"name": "get_location", "description": "Get the location of the user.", "parameters": {"type": "object", "properties": {}}}, 72 | {"name": "get_weather", "description": "Get the weather of a city.", "parameters": {"type": "object", "properties": {"city": {"type": "string", "description": "The name of the city"}}}}, 73 | {"name": "get_news", "description": "Get the news.", "parameters": {"type": "object", "properties": {"domain": {"type": "string", "description": "The domain of the news"}}}} 74 | ] 75 | text = tokenizer.apply_chat_template( 76 | message, 77 | tools, 78 | tokenize=False, 79 | add_generation_prompt=True 80 | ) 81 | model_inputs = tokenizer(text, return_tensors="pt").to("cuda") 82 | quantized_model = 
AutoModelForCausalLM.from_pretrained( 83 | model_id, 84 | torch_dtype="bfloat16", 85 | device_map=device_map, 86 | quantization_config=quantization_config, 87 | trust_remote_code=True, 88 | offload_buffers=True, 89 | ) 90 | generation_config = GenerationConfig( 91 | max_new_tokens=20, 92 | eos_token_id=200020, 93 | use_cache=True, 94 | ) 95 | generated_ids = quantized_model.generate(**model_inputs, generation_config=generation_config) 96 | print(f"generated_ids: {generated_ids}") 97 | generated_ids = [ 98 | output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) 99 | ] 100 | response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] 101 | print(response) 102 | 103 | if __name__ == "__main__": 104 | main() 105 | 106 | 107 | -------------------------------------------------------------------------------- /docs/vllm_deployment_guide.md: -------------------------------------------------------------------------------- 1 | # 🚀 MiniMax Models vLLM Deployment Guide 2 | 3 | [vLLM中文版部署指南](./vllm_deployment_guide_cn.md) 4 | 5 | ## 📖 Introduction 6 | 7 | We recommend using [vLLM](https://docs.vllm.ai/en/latest/) to deploy [MiniMax-M1](https://huggingface.co/MiniMaxAI/MiniMax-M1-40k) model. Based on our testing, vLLM performs excellently when deploying this model, with the following features: 8 | 9 | - 🔥 Outstanding service throughput performance 10 | - ⚡ Efficient and intelligent memory management 11 | - 📦 Powerful batch request processing capability 12 | - ⚙️ Deeply optimized underlying performance 13 | 14 | The MiniMax-M1 model can run efficiently on a single server equipped with 8 H800 or 8 H20 GPUs. In terms of hardware configuration, a server with 8 H800 GPUs can process context inputs up to 2 million tokens, while a server equipped with 8 H20 GPUs can support ultra-long context processing capabilities of up to 5 million tokens. 15 | 16 | ## 💾 Obtaining MiniMax Models 17 | 18 | ### MiniMax-M1 Model Obtaining 19 | 20 | You can download the model from our official HuggingFace repository: [MiniMax-M1-40k](https://huggingface.co/MiniMaxAI/MiniMax-M1-40k), [MiniMax-M1-80k](https://huggingface.co/MiniMaxAI/MiniMax-M1-80k) 21 | 22 | Download command: 23 | ``` 24 | pip install -U huggingface-hub 25 | huggingface-cli download MiniMaxAI/MiniMax-M1-40k 26 | # huggingface-cli download MiniMaxAI/MiniMax-M1-80k 27 | 28 | # If you encounter network issues, you can set a proxy 29 | export HF_ENDPOINT=https://hf-mirror.com 30 | ``` 31 | 32 | Or download using git: 33 | 34 | ```bash 35 | git lfs install 36 | git clone https://huggingface.co/MiniMaxAI/MiniMax-M1-40k 37 | git clone https://huggingface.co/MiniMaxAI/MiniMax-M1-80k 38 | ``` 39 | 40 | ⚠️ **Important Note**: Please ensure that [Git LFS](https://git-lfs.github.com/) is installed on your system, which is necessary for completely downloading the model weight files. 41 | 42 | ## 🛠️ Deployment Options 43 | 44 | ### Option 1: Deploy Using Docker (Recommended) 45 | 46 | To ensure consistency and stability of the deployment environment, we recommend using Docker for deployment. 47 | 48 | ⚠️ **Version Requirements**: 49 | - MiniMax-M1 model requires vLLM version 0.9.2 or later for full support 50 | - Special Note: Using vLLM versions below 0.9.2 may result in incompatibility or incorrect precision for the model: 51 | - For details, see: [Fix minimax model cache & lm_head precision #19592](https://github.com/vllm-project/vllm/pull/19592) 52 | 53 | 1. 
Get the container image: 54 | 55 | Currently, the official vLLM Docker image for version v0.9.2 has not been released yet. 56 | As an example, we will demonstrate how to manually build vLLM using version v0.8.3. 57 | ```bash 58 | docker pull vllm/vllm-openai:v0.8.3 59 | ``` 60 | 61 | 2. Run the container: 62 | ```bash 63 | # Set environment variables 64 | IMAGE=vllm/vllm-openai:v0.8.3 65 | MODEL_DIR= 66 | CODE_DIR= 67 | NAME=MiniMaxImage 68 | 69 | # Docker run configuration 70 | DOCKER_RUN_CMD="--network=host --privileged --ipc=host --ulimit memlock=-1 --shm-size=2gb --rm --gpus all --ulimit stack=67108864" 71 | 72 | # Start the container 73 | sudo docker run -it \ 74 | -v $MODEL_DIR:$MODEL_DIR \ 75 | -v $CODE_DIR:$CODE_DIR \ 76 | --name $NAME \ 77 | $DOCKER_RUN_CMD \ 78 | $IMAGE /bin/bash 79 | 80 | # install vLLM 81 | cd $CODE_DIR 82 | git clone https://github.com/vllm-project/vllm.git 83 | cd vllm 84 | pip install -e . 85 | ``` 86 | 87 | 💡 If you are using other environment configurations, please refer to the [vLLM Installation Guide](https://docs.vllm.ai/en/latest/getting_started/installation.html) 88 | 89 | ## 🚀 Starting the Service 90 | 91 | ### Launch MiniMax-M1 Service 92 | 93 | ```bash 94 | export SAFETENSORS_FAST_GPU=1 95 | export VLLM_USE_V1=0 96 | python3 -m vllm.entrypoints.openai.api_server \ 97 | --model \ 98 | --tensor-parallel-size 8 \ 99 | --trust-remote-code \ 100 | --quantization experts_int8 \ 101 | --max_model_len 4096 \ 102 | --dtype bfloat16 103 | ``` 104 | 105 | ### API Call Example 106 | 107 | ```bash 108 | curl http://localhost:8000/v1/chat/completions \ 109 | -H "Content-Type: application/json" \ 110 | -d '{ 111 | "model": "MiniMaxAI/MiniMax-M1", 112 | "messages": [ 113 | {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, 114 | {"role": "user", "content": [{"type": "text", "text": "Who won the world series in 2020?"}]} 115 | ] 116 | }' 117 | ``` 118 | 119 | ## ❗ Common Issues 120 | 121 | ### Module Loading Problems 122 | If you encounter the following error: 123 | ``` 124 | import vllm._C # noqa 125 | ModuleNotFoundError: No module named 'vllm._C' 126 | ``` 127 | 128 | Or 129 | 130 | ``` 131 | MiniMax-M1 model is not currently supported 132 | ``` 133 | 134 | We provide two solutions: 135 | 136 | #### Solution 1: Copy Dependency Files 137 | ```bash 138 | cd 139 | git clone https://github.com/vllm-project/vllm.git 140 | cd vllm 141 | cp /usr/local/lib/python3.12/dist-packages/vllm/*.so vllm 142 | cp -r /usr/local/lib/python3.12/dist-packages/vllm/vllm_flash_attn/* vllm/vllm_flash_attn 143 | ``` 144 | 145 | #### Solution 2: Install from Source 146 | ```bash 147 | cd 148 | git clone https://github.com/vllm-project/vllm.git 149 | 150 | cd vllm/ 151 | pip install -e . 152 | ``` 153 | 154 | ## 📮 Getting Support 155 | 156 | If you encounter any issues while deploying MiniMax-M1 model: 157 | - Please check our official documentation 158 | - Contact our technical support team through official channels 159 | - Submit an [Issue](https://github.com/MiniMax-AI/MiniMax-M1/issues) on our GitHub repository 160 | 161 | We will continuously optimize the deployment experience of this model and welcome your feedback! 
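
The curl example in the "API Call Example" section above can also be driven from Python through the OpenAI-compatible interface that the vLLM server exposes. The snippet below is a minimal illustration rather than an official client: it assumes the `openai` Python package is installed, that the server started in the previous section is listening on `localhost:8000`, and that the `model` field matches the name the server reports (for example via the `/v1/models` endpoint).

```python
from openai import OpenAI  # pip install openai

# vLLM's OpenAI-compatible server accepts any key unless --api-key is set.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="MiniMaxAI/MiniMax-M1",  # must match the served model name
    messages=[
        {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
        {"role": "user", "content": [{"type": "text", "text": "Who won the world series in 2020?"}]},
    ],
)
print(response.choices[0].message.content)
```
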
162 | -------------------------------------------------------------------------------- /docs/vllm_deployment_guide_pt-br.md: -------------------------------------------------------------------------------- 1 | # 🚀 Guia de Deploy dos Modelos MiniMax com vLLM 2 | 3 | [vLLM中文版部署指南](./vllm_deployment_guide_cn.md) 4 | 5 | ## 📖 Introdução 6 | 7 | Recomendamos utilizar o [vLLM](https://docs.vllm.ai/en/latest/) para fazer o deploy do modelo [MiniMax-M1](https://huggingface.co/MiniMaxAI/MiniMax-M1-40k). Com base nos nossos testes, o vLLM apresenta excelente desempenho ao executar este modelo, oferecendo as seguintes vantagens: 8 | 9 | - 🔥 Desempenho excepcional em throughput de serviço 10 | - ⚡ Gerenciamento de memória eficiente e inteligente 11 | - 📦 Capacidade robusta de processamento de requisições em lote 12 | - ⚙️ Otimização profunda de desempenho em baixo nível 13 | 14 | O modelo MiniMax-M1 pode ser executado de forma eficiente em um servidor único equipado com 8 GPUs H800 ou 8 GPUs H20. Em termos de configuração de hardware, um servidor com 8 GPUs H800 consegue processar entradas de contexto com até 2 milhões de tokens, enquanto um servidor equipado com 8 GPUs H20 suporta contextos ultra longos de até 5 milhões de tokens. 15 | 16 | ## 💾 Obtendo os Modelos MiniMax 17 | 18 | ### Download do Modelo MiniMax-M1 19 | 20 | Você pode baixar o modelo diretamente do nosso repositório oficial no HuggingFace: [MiniMax-M1-40k](https://huggingface.co/MiniMaxAI/MiniMax-M1-40k) ou [MiniMax-M1-80k](https://huggingface.co/MiniMaxAI/MiniMax-M1-80k). 21 | 22 | Comando para download: 23 | ``` 24 | pip install -U huggingface-hub 25 | huggingface-cli download MiniMaxAI/MiniMax-M1-40k 26 | 27 | # huggingface-cli download MiniMaxAI/MiniMax-M1-80k 28 | 29 | # Se você encontrar problemas de rede, pode configurar um proxy 30 | 31 | export HF\_ENDPOINT=[https://hf-mirror.com](https://hf-mirror.com) 32 | ``` 33 | 34 | Ou faça o download usando git: 35 | 36 | ```bash 37 | git lfs install 38 | git clone https://huggingface.co/MiniMaxAI/MiniMax-M1-40k 39 | git clone https://huggingface.co/MiniMaxAI/MiniMax-M1-80k 40 | ``` 41 | 42 | ⚠️ **Atenção Importante**: Certifique-se de que o [Git LFS](https://git-lfs.github.com/) está instalado no seu sistema, pois ele é necessário para baixar completamente os arquivos de pesos do modelo. 43 | 44 | ## 🛠️ Opções de Deploy 45 | 46 | ### Opção 1: Deploy Utilizando Docker (Recomendado) 47 | 48 | Para garantir consistência e estabilidade no ambiente de deployment, recomendamos utilizar Docker. 49 | 50 | ⚠️ **Requisitos de Versão**: 51 | 52 | * O modelo MiniMax-M1 requer vLLM na versão 0.9.2 ou superior para suporte completo. 53 | * Nota especial: Si se utiliza una versión de vLLM inferior a 0.9.2, pueden surgir problemas de incompatibilidad o precisión incorrecta del modelo: 54 | 55 | * Para más detalles, consulta: [Fix minimax model cache & lm_head precision #19592](https://github.com/vllm-project/vllm/pull/19592) 56 | 57 | 1. Obtenha a imagem do container: 58 | 59 | ```bash 60 | docker pull vllm/vllm-openai:v0.8.3 61 | ``` 62 | 63 | 2. 
Execute o container: 64 | 65 | ```bash 66 | # Defina variáveis de ambiente 67 | IMAGE=vllm/vllm-openai:v0.8.3 68 | MODEL_DIR= 69 | CODE_DIR= 70 | NAME=MiniMaxImage 71 | 72 | # Configuração do Docker run 73 | DOCKER_RUN_CMD="--network=host --privileged --ipc=host --ulimit memlock=-1 --shm-size=2gb --rm --gpus all --ulimit stack=67108864" 74 | 75 | # Inicie o container 76 | sudo docker run -it \ 77 | -v $MODEL_DIR:$MODEL_DIR \ 78 | -v $CODE_DIR:$CODE_DIR \ 79 | --name $NAME \ 80 | $DOCKER_RUN_CMD \ 81 | $IMAGE /bin/bash 82 | 83 | cd $CODE_DIR 84 | git clone https://github.com/vllm-project/vllm.git 85 | cd vllm 86 | pip install -e . 87 | ``` 88 | 89 | 💡 Se você estiver utilizando outra configuração de ambiente, consulte o [Guia de Instalação do vLLM](https://docs.vllm.ai/en/latest/getting_started/installation.html). 90 | 91 | ## 🚀 Inicializando o Serviço 92 | 93 | ### Iniciando o Serviço com MiniMax-M1 94 | 95 | ```bash 96 | export SAFETENSORS_FAST_GPU=1 97 | export VLLM_USE_V1=0 98 | python3 -m vllm.entrypoints.openai.api_server \ 99 | --model \ 100 | --tensor-parallel-size 8 \ 101 | --trust-remote-code \ 102 | --quantization experts_int8 \ 103 | --max_model_len 4096 \ 104 | --dtype bfloat16 105 | ``` 106 | 107 | ### Exemplo de Chamada via API 108 | 109 | ```bash 110 | curl http://localhost:8000/v1/chat/completions \ 111 | -H "Content-Type: application/json" \ 112 | -d '{ 113 | "model": "MiniMaxAI/MiniMax-M1", 114 | "messages": [ 115 | {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, 116 | {"role": "user", "content": [{"type": "text", "text": "Who won the world series in 2020?"}]} 117 | ] 118 | }' 119 | ``` 120 | 121 | ## ❗ Problemas Comuns 122 | 123 | ### Problemas ao Carregar Módulos 124 | 125 | Se você encontrar o erro: 126 | 127 | ``` 128 | import vllm._C # noqa 129 | ModuleNotFoundError: No module named 'vllm._C' 130 | ``` 131 | 132 | Ou 133 | 134 | ``` 135 | MiniMax-M1 model is not currently supported 136 | ``` 137 | 138 | Disponibilizamos duas soluções: 139 | 140 | #### Solução 1: Copiar Arquivos de Dependência 141 | 142 | ```bash 143 | cd 144 | git clone https://github.com/vllm-project/vllm.git 145 | cd vllm 146 | cp /usr/local/lib/python3.12/dist-packages/vllm/*.so vllm 147 | cp -r /usr/local/lib/python3.12/dist-packages/vllm/vllm_flash_attn/* vllm/vllm_flash_attn 148 | ``` 149 | 150 | #### Solução 2: Instalar a partir do Código-Fonte 151 | 152 | ```bash 153 | cd 154 | git clone https://github.com/vllm-project/vllm.git 155 | 156 | cd vllm/ 157 | pip install -e . 158 | ``` 159 | 160 | ## 📮 Suporte 161 | 162 | Se você tiver qualquer problema durante o deploy do modelo MiniMax-M1: 163 | 164 | * Consulte nossa documentação oficial 165 | * Entre em contato com nossa equipe de suporte técnico pelos canais oficiais 166 | * Abra uma [Issue](https://github.com/MiniMax-AI/MiniMax-M1/issues) no nosso repositório do GitHub 167 | 168 | Estamos constantemente otimizando a experiência de deploy deste modelo e valorizamos muito seu feedback! 
169 | -------------------------------------------------------------------------------- /configuration_minimax_m1.py: -------------------------------------------------------------------------------- 1 | """ MiniMaxM1 model configuration""" 2 | 3 | from transformers.configuration_utils import PretrainedConfig 4 | from transformers.utils import logging 5 | 6 | 7 | logger = logging.get_logger(__name__) 8 | 9 | 10 | class MiniMaxM1Config(PretrainedConfig): 11 | r""" 12 | This is the configuration class to store the configuration of a [`MiniMaxM1Model`]. It is used to instantiate an 13 | MiniMaxM1 model according to the specified arguments, defining the model architecture. Instantiating a configuration 14 | with the defaults will yield a similar configuration to that of the MiniMaxM1. 15 | 16 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 17 | documentation from [`PretrainedConfig`] for more information. 18 | 19 | 20 | Args: 21 | vocab_size (`int`, *optional*, defaults to 32000): 22 | Vocabulary size of the MiniMaxM1 model. Defines the number of different tokens that can be represented by the 23 | `inputs_ids` passed when calling [`MiniMaxM1Model`] 24 | hidden_size (`int`, *optional*, defaults to 4096): 25 | Dimension of the hidden representations. 26 | intermediate_size (`int`, *optional*, defaults to 14336): 27 | Dimension of the MLP representations. 28 | num_hidden_layers (`int`, *optional*, defaults to 32): 29 | Number of hidden layers in the Transformer encoder. 30 | num_attention_heads (`int`, *optional*, defaults to 32): 31 | Number of attention heads for each attention layer in the Transformer encoder. 32 | num_key_value_heads (`int`, *optional*, defaults to 8): 33 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If 34 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if 35 | `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When 36 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed 37 | by meanpooling all the original heads within that group. For more details checkout [this 38 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`. 39 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): 40 | The non-linear activation function (function or string) in the decoder. 41 | max_position_embeddings (`int`, *optional*, defaults to `4096*32`): 42 | The maximum sequence length that this model might ever be used with. MiniMaxM1's sliding window attention 43 | allows sequence of up to 4096*32 tokens. 44 | initializer_range (`float`, *optional*, defaults to 0.02): 45 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 46 | rms_norm_eps (`float`, *optional*, defaults to 1e-05): 47 | The epsilon used by the rms normalization layers. 48 | use_cache (`bool`, *optional*, defaults to `True`): 49 | Whether or not the model should return the last key/values attentions (not used by all models). Only 50 | relevant if `config.is_decoder=True`. 51 | pad_token_id (`int`, *optional*): 52 | The id of the padding token. 53 | bos_token_id (`int`, *optional*, defaults to 1): 54 | The id of the "beginning-of-sequence" token. 55 | eos_token_id (`int`, *optional*, defaults to 2): 56 | The id of the "end-of-sequence" token. 
57 | tie_word_embeddings (`bool`, *optional*, defaults to `False`): 58 | Whether the model's input and output word embeddings should be tied. 59 | rope_theta (`float`, *optional*, defaults to 1000000.0): 60 | The base period of the RoPE embeddings. 61 | sliding_window (`int`, *optional*): 62 | Sliding window attention window size. If not specified, will default to `4096`. 63 | attention_dropout (`float`, *optional*, defaults to 0.0): 64 | The dropout ratio for the attention probabilities. 65 | num_experts_per_tok (`int`, *optional*, defaults to 2): 66 | The number of experts to route per-token, can be also interpreted as the `top-k` routing 67 | parameter 68 | num_local_experts (`int`, *optional*, defaults to 8): 69 | Number of experts per Sparse MLP layer. 70 | output_router_logits (`bool`, *optional*, defaults to `False`): 71 | Whether or not the router logits should be returned by the model. Enabeling this will also 72 | allow the model to output the auxiliary loss. See [here]() for more details 73 | router_aux_loss_coef (`float`, *optional*, defaults to 0.001): 74 | The aux loss factor for the total loss. 75 | router_jitter_noise (`float`, *optional*, defaults to 0.0): 76 | Amount of noise to add to the router. 77 | 78 | ```python 79 | >>> from transformers import MiniMaxM1Model, MiniMaxM1Config 80 | 81 | >>> # Initializing a MiniMaxM1 style configuration 82 | >>> configuration = MiniMaxM1Config() 83 | 84 | >>> # Initializing a model from the MiniMaxM1 style configuration 85 | >>> model = MiniMaxM1Model(configuration) 86 | 87 | >>> # Accessing the model configuration 88 | >>> configuration = model.config 89 | ```""" 90 | 91 | model_type = "MiniMaxM1" 92 | keys_to_ignore_at_inference = ["past_key_values"] 93 | 94 | def __init__( 95 | self, 96 | vocab_size=32000, 97 | hidden_size=4096, 98 | intermediate_size=14336, 99 | num_hidden_layers=32, 100 | num_attention_heads=32, 101 | num_key_value_heads=8, 102 | hidden_act="silu", 103 | max_position_embeddings=4096 * 32, 104 | initializer_range=0.02, 105 | rms_norm_eps=1e-5, 106 | use_cache=True, 107 | pad_token_id=None, 108 | bos_token_id=None, 109 | eos_token_id=None, 110 | tie_word_embeddings=False, 111 | rope_theta=1e6, 112 | sliding_window=None, 113 | attention_dropout=0.0, 114 | num_experts_per_tok=2, 115 | num_local_experts=8, 116 | output_router_logits=False, 117 | router_aux_loss_coef=0.001, 118 | router_jitter_noise=0.0, 119 | **kwargs, 120 | ): 121 | self.vocab_size = vocab_size 122 | self.max_position_embeddings = max_position_embeddings 123 | self.hidden_size = hidden_size 124 | self.intermediate_size = intermediate_size 125 | self.num_hidden_layers = num_hidden_layers 126 | self.num_attention_heads = num_attention_heads 127 | self.sliding_window = sliding_window 128 | 129 | # for backward compatibility 130 | if num_key_value_heads is None: 131 | num_key_value_heads = num_attention_heads 132 | 133 | self.num_key_value_heads = num_key_value_heads 134 | self.hidden_act = hidden_act 135 | self.initializer_range = initializer_range 136 | self.rms_norm_eps = rms_norm_eps 137 | self.use_cache = use_cache 138 | self.rope_theta = rope_theta 139 | self.attention_dropout = attention_dropout 140 | 141 | self.num_experts_per_tok = num_experts_per_tok 142 | self.num_local_experts = num_local_experts 143 | self.output_router_logits = output_router_logits 144 | self.router_aux_loss_coef = router_aux_loss_coef 145 | self.router_jitter_noise = router_jitter_noise 146 | super().__init__( 147 | pad_token_id=pad_token_id, 148 | 
bos_token_id=bos_token_id, 149 | eos_token_id=eos_token_id, 150 | tie_word_embeddings=tie_word_embeddings, 151 | **kwargs, 152 | ) 153 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2025 MiniMax 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /docs/function_call_guide_cn.md: -------------------------------------------------------------------------------- 1 | # MiniMax-M1 函数调用(Function Call)功能指南 2 | 3 | ## 📖 简介 4 | 5 | MiniMax-M1 模型支持函数调用功能,使模型能够识别何时需要调用外部函数,并以结构化格式输出函数调用参数。本文档详细介绍了如何使用 MiniMax-M1 的函数调用功能。 6 | 7 | ## 🚀 快速开始 8 | 9 | ### 使用 vLLM 进行 Function Calls(推荐) 10 | 11 | 在实际部署过程中,为了支持类似 OpenAI API 的原生 Function Calling(工具调用)能力,MiniMax-M1 模型集成了专属 `tool_call_parser=minimax` 解析器,从而避免对模型输出结果进行额外的正则解析处理。 12 | 13 | #### 环境准备与重新编译 vLLM 14 | 15 | 由于该功能尚未正式发布在 PyPI 版本中,需基于源码进行编译。以下为基于 vLLM 官方 Docker 镜像 `vllm/vllm-openai:v0.8.3` 的示例流程: 16 | 17 | ```bash 18 | IMAGE=vllm/vllm-openai:v0.8.3 19 | DOCKER_RUN_CMD="--network=host --privileged --ipc=host --ulimit memlock=-1 --shm-size=32gb --rm --gpus all --ulimit stack=67108864" 20 | 21 | # 运行 docker 22 | sudo docker run -it -v $MODEL_DIR:$MODEL_DIR \ 23 | -v $CODE_DIR:$CODE_DIR \ 24 | --name vllm_function_call \ 25 | $DOCKER_RUN_CMD \ 26 | --entrypoint /bin/bash \ 27 | $IMAGE 28 | ``` 29 | 30 | #### 编译 vLLM 源码 31 | 32 | 进入容器后,执行以下命令以获取源码并重新安装: 33 | 34 | ```bash 35 | cd $CODE_DIR 36 | git clone https://github.com/vllm-project/vllm.git 37 | cd vllm 38 | pip install -e . 39 | ``` 40 | 41 | #### 启动 vLLM API 服务 42 | 43 | ```bash 44 | export SAFETENSORS_FAST_GPU=1 45 | export VLLM_USE_V1=0 46 | 47 | python3 -m vllm.entrypoints.openai.api_server \ 48 | --model MiniMax-M1-80k \ 49 | --tensor-parallel-size 8 \ 50 | --trust-remote-code \ 51 | --quantization experts_int8 \ 52 | --enable-auto-tool-choice \ 53 | --tool-call-parser minimax \ 54 | --chat-template vllm/examples/tool_chat_template_minimax_m1.jinja \ 55 | --max_model_len 4096 \ 56 | --dtype bfloat16 \ 57 | --gpu-memory-utilization 0.85 58 | ``` 59 | 60 | **⚠️ 注意:** 61 | - `--tool-call-parser minimax` 为关键参数,用于启用 MiniMax-M1 自定义解析器 62 | - `--enable-auto-tool-choice` 启用自动工具选择 63 | - `--chat-template` 模板文件需要适配 tool calling 格式 64 | 65 | #### Function Call 测试脚本示例 66 | 67 | 以下 Python 脚本基于 OpenAI SDK 实现了一个天气查询函数的调用示例: 68 | 69 | ```python 70 | from openai import OpenAI 71 | import json 72 | 73 | client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy") 74 | 75 | def get_weather(location: str, unit: str): 76 | return f"Getting the weather for {location} in {unit}..." 
77 | 
78 | tool_functions = {"get_weather": get_weather}
79 | 
80 | tools = [{
81 | "type": "function",
82 | "function": {
83 | "name": "get_weather",
84 | "description": "Get the current weather in a given location",
85 | "parameters": {
86 | "type": "object",
87 | "properties": {
88 | "location": {"type": "string", "description": "City and state, e.g., 'San Francisco, CA'"},
89 | "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}
90 | },
91 | "required": ["location", "unit"]
92 | }
93 | }
94 | }]
95 | 
96 | response = client.chat.completions.create(
97 | model=client.models.list().data[0].id,
98 | messages=[{"role": "user", "content": "What's the weather like in San Francisco? use celsius."}],
99 | tools=tools,
100 | tool_choice="auto"
101 | )
102 | 
103 | print(response)
104 | 
105 | tool_call = response.choices[0].message.tool_calls[0].function
106 | print(f"Function called: {tool_call.name}")
107 | print(f"Arguments: {tool_call.arguments}")
108 | print(f"Result: {get_weather(**json.loads(tool_call.arguments))}")
109 | ```
110 | 
111 | **输出示例:**
112 | ```
113 | Function called: get_weather
114 | Arguments: {"location": "San Francisco, CA", "unit": "celsius"}
115 | Result: Getting the weather for San Francisco, CA in celsius...
116 | ```
117 | 
118 | ### 手动解析模型输出
119 | 
120 | 如果您无法使用 vLLM 的内置解析器,或者需要使用其他推理框架(如 transformers、TGI 等),可以使用以下方法手动解析模型的原始输出。这种方法需要您自己解析模型输出的 XML 标签格式。
121 | 
122 | #### 使用 Transformers 的示例
123 | 
124 | 以下是使用 transformers 库的完整示例:
125 | 
126 | ```python
127 | from transformers import AutoTokenizer
128 | 
129 | def get_default_tools():
130 | return [
131 | {
132 | "name": "get_current_weather",
133 | "description": "Get the latest weather for a location",
134 | "parameters": {
135 | "type": "object",
136 | "properties": {
137 | "location": {
138 | "type": "string",
139 | "description": "A certain city, such as Beijing, Shanghai"
140 | }
141 | },
142 | "required": ["location"]
143 | }
144 | }
145 | ]
146 | 
147 | # 加载模型和分词器
148 | model_id = "MiniMaxAI/MiniMax-M1-40k"
149 | tokenizer = AutoTokenizer.from_pretrained(model_id)
150 | prompt = "What's the weather like in Shanghai today?"
151 | messages = [ 152 | {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant created by Minimax based on MiniMax-M1 model."}]}, 153 | {"role": "user", "content": [{"type": "text", "text": prompt}]}, 154 | ] 155 | 156 | # 启用函数调用工具 157 | tools = get_default_tools() 158 | 159 | # 应用聊天模板,并加入工具定义 160 | text = tokenizer.apply_chat_template( 161 | messages, 162 | tokenize=False, 163 | add_generation_prompt=True, 164 | tools=tools 165 | ) 166 | 167 | # 发送请求(这里使用任何推理服务) 168 | import requests 169 | payload = { 170 | "model": "MiniMaxAI/MiniMax-M1-40k", 171 | "prompt": text, 172 | "max_tokens": 4000 173 | } 174 | response = requests.post( 175 | "http://localhost:8000/v1/completions", 176 | headers={"Content-Type": "application/json"}, 177 | json=payload, 178 | stream=False, 179 | ) 180 | 181 | # 模型输出需要手动解析 182 | raw_output = response.json()["choices"][0]["text"] 183 | print("原始输出:", raw_output) 184 | 185 | # 使用下面的解析函数处理输出 186 | function_calls = parse_function_calls(raw_output) 187 | ``` 188 | 189 | ## 🛠️ 函数调用的定义 190 | 191 | ### 函数结构体 192 | 193 | 函数调用需要在请求体中定义 `tools` 字段,每个函数由以下部分组成: 194 | 195 | ```json 196 | { 197 | "tools": [ 198 | { 199 | "name": "search_web", 200 | "description": "搜索函数。", 201 | "parameters": { 202 | "properties": { 203 | "query_list": { 204 | "description": "进行搜索的关键词,列表元素个数为1。", 205 | "items": { "type": "string" }, 206 | "type": "array" 207 | }, 208 | "query_tag": { 209 | "description": "query的分类", 210 | "items": { "type": "string" }, 211 | "type": "array" 212 | } 213 | }, 214 | "required": [ "query_list", "query_tag" ], 215 | "type": "object" 216 | } 217 | } 218 | ] 219 | } 220 | ``` 221 | 222 | **字段说明:** 223 | - `name`: 函数名称 224 | - `description`: 函数功能描述 225 | - `parameters`: 函数参数定义 226 | - `properties`: 参数属性定义,key 是参数名,value 包含参数的详细描述 227 | - `required`: 必填参数列表 228 | - `type`: 参数类型(通常为 "object") 229 | 230 | ### 模型内部处理格式 231 | 232 | 在模型内部处理时,函数定义会被转换为特殊格式并拼接到输入文本中: 233 | 234 | ``` 235 | system ai_setting=MiniMax AI 236 | MiniMax AI是由上海稀宇科技有限公司(MiniMax)自主研发的AI助理。 237 | system tool_setting=tools 238 | You are provided with these tools: 239 | 240 | {"name": "search_web", "description": "搜索函数。", "parameters": {"properties": {"query_list": {"description": "进行搜索的关键词,列表元素个数为1。", "items": {"type": "string"}, "type": "array"}, "query_tag": {"description": "query的分类", "items": {"type": "string"}, "type": "array"}}, "required": ["query_list", "query_tag"], "type": "object"}} 241 | 242 | If you need to call tools, please respond with XML tags, and provide tool-name and json-object of arguments, following the format below: 243 | 244 | {"name": , "arguments": } 245 | ... 246 | 247 | user name=用户 248 | OpenAI 和 Gemini 的最近一次发布会都是什么时候? 249 | ai name=MiniMax AI 250 | ``` 251 | 252 | ### 模型输出格式 253 | 254 | 模型会以以下格式输出函数调用: 255 | 256 | ```xml 257 | 258 | Okay, I will search for the OpenAI and Gemini latest release. 
259 | 260 | 261 | {"name": "search_web", "arguments": {"query_tag": ["technology", "events"], "query_list": ["\"OpenAI\" \"latest\" \"release\""]}} 262 | {"name": "search_web", "arguments": {"query_tag": ["technology", "events"], "query_list": ["\"Gemini\" \"latest\" \"release\""]}} 263 | 264 | ``` 265 | 266 | ## 📥 手动解析函数调用结果 267 | 268 | ### 解析函数调用 269 | 270 | 当需要手动解析时,您需要解析模型输出的 XML 标签格式: 271 | 272 | ```python 273 | import re 274 | import json 275 | def parse_function_calls(content: str): 276 | """ 277 | 解析模型输出中的函数调用 278 | """ 279 | function_calls = [] 280 | 281 | # 匹配 标签内的内容 282 | tool_calls_pattern = r"(.*?)" 283 | tool_calls_match = re.search(tool_calls_pattern, content, re.DOTALL) 284 | 285 | if not tool_calls_match: 286 | return function_calls 287 | 288 | tool_calls_content = tool_calls_match.group(1).strip() 289 | 290 | # 解析每个函数调用(每行一个JSON对象) 291 | for line in tool_calls_content.split('\n'): 292 | line = line.strip() 293 | if not line: 294 | continue 295 | 296 | try: 297 | # 解析JSON格式的函数调用 298 | call_data = json.loads(line) 299 | function_name = call_data.get("name") 300 | arguments = call_data.get("arguments", {}) 301 | 302 | function_calls.append({ 303 | "name": function_name, 304 | "arguments": arguments 305 | }) 306 | 307 | print(f"调用函数: {function_name}, 参数: {arguments}") 308 | 309 | except json.JSONDecodeError as e: 310 | print(f"参数解析失败: {line}, 错误: {e}") 311 | 312 | return function_calls 313 | 314 | # 示例:处理天气查询函数 315 | def execute_function_call(function_name: str, arguments: dict): 316 | """ 317 | 执行函数调用并返回结果 318 | """ 319 | if function_name == "get_current_weather": 320 | location = arguments.get("location", "未知位置") 321 | # 构建函数执行结果 322 | return { 323 | "role": "tool", 324 | "content": [ 325 | { 326 | "name": function_name, 327 | "type": "text", 328 | "text": json.dumps({ 329 | "location": location, 330 | "temperature": "25", 331 | "unit": "celsius", 332 | "weather": "晴朗" 333 | }, ensure_ascii=False) 334 | } 335 | ] 336 | } 337 | elif function_name == "search_web": 338 | query_list = arguments.get("query_list", []) 339 | query_tag = arguments.get("query_tag", []) 340 | # 模拟搜索结果 341 | return { 342 | "role": "tool", 343 | "content": [ 344 | { 345 | "name": function_name, 346 | "type": "text", 347 | "text": f"搜索关键词: {query_list}, 分类: {query_tag}\n搜索结果: 相关信息已找到" 348 | } 349 | ] 350 | } 351 | 352 | return None 353 | ``` 354 | 355 | ### 将函数执行结果返回给模型 356 | 357 | 成功解析函数调用后,您应将函数执行结果添加到对话历史中,以便模型在后续交互中能够访问和利用这些信息。 358 | 359 | #### 单个结果 360 | 361 | 假如模型调用了 `search_web` 函数,您可以参考如下格式添加执行结果,`name` 字段为具体的函数名称。 362 | 363 | ```json 364 | { 365 | "role": "tool", 366 | "content": [ 367 | { 368 | "name": "search_web", 369 | "type": "text", 370 | "text": "test_result" 371 | } 372 | ] 373 | } 374 | ``` 375 | 376 | 对应如下的模型输入格式: 377 | ``` 378 | tool name=tools 379 | tool name: search_web 380 | tool result: test_result 381 | 382 | ``` 383 | 384 | #### 多个结果 385 | 386 | 假如模型同时调用了 `search_web` 和 `get_current_weather` 函数,您可以参考如下格式添加执行结果,`content`包含多个结果。 387 | 388 | ```json 389 | { 390 | "role": "tool", 391 | "content": [ 392 | { 393 | "name": "search_web", 394 | "type": "text", 395 | "text": "test_result1" 396 | }, 397 | { 398 | "name": "get_current_weather", 399 | "type": "text", 400 | "text": "test_result2" 401 | } 402 | ] 403 | } 404 | ``` 405 | 406 | 对应如下的模型输入格式: 407 | ``` 408 | tool name=tools 409 | tool name: search_web 410 | tool result: test_result1 411 | tool name: get_current_weather 412 | tool result: test_result2 413 | ``` 414 | 415 | 虽然我们建议您参考以上格式,但只要返回给模型的输入易于理解,`name` 和 `text` 
的具体内容完全由您自主决定。 416 | 417 | ## 📚 参考资料 418 | 419 | - [MiniMax-M1 模型仓库](https://github.com/MiniMaxAI/MiniMax-M1) 420 | - [vLLM 项目主页](https://github.com/vllm-project/vllm) 421 | - [vLLM Function Calling PR](https://github.com/vllm-project/vllm/pull/20297) 422 | - [OpenAI Python SDK](https://github.com/openai/openai-python) 423 | -------------------------------------------------------------------------------- /docs/function_call_guide.md: -------------------------------------------------------------------------------- 1 | # MiniMax-M1 Function Call Guide 2 | 3 | [FunctionCall中文使用指南](./function_call_guide_cn.md) 4 | 5 | ## 📖 Introduction 6 | 7 | The MiniMax-M1 model supports function calling capabilities, enabling the model to identify when external functions need to be called and output function call parameters in a structured format. This document provides detailed instructions on how to use the function calling feature of MiniMax-M1. 8 | 9 | ## 🚀 Quick Start 10 | 11 | ### Using vLLM for Function Calls (Recommended) 12 | 13 | In actual deployment, to support native Function Calling (tool calling) capabilities similar to OpenAI API, the MiniMax-M1 model integrates a dedicated `tool_call_parser=minimax` parser, avoiding additional regex parsing of model output. 14 | 15 | #### Environment Setup and vLLM Recompilation 16 | 17 | Since this feature has not been officially released in the PyPI version, compilation from source code is required. The following is an example process based on the official vLLM Docker image `vllm/vllm-openai:v0.8.3`: 18 | 19 | ```bash 20 | IMAGE=vllm/vllm-openai:v0.8.3 21 | DOCKER_RUN_CMD="--network=host --privileged --ipc=host --ulimit memlock=-1 --shm-size=32gb --rm --gpus all --ulimit stack=67108864" 22 | 23 | # Run docker 24 | sudo docker run -it -v $MODEL_DIR:$MODEL_DIR \ 25 | -v $CODE_DIR:$CODE_DIR \ 26 | --name vllm_function_call \ 27 | $DOCKER_RUN_CMD \ 28 | --entrypoint /bin/bash \ 29 | $IMAGE 30 | ``` 31 | 32 | #### Compiling vLLM Source Code 33 | 34 | After entering the container, execute the following commands to get the source code and reinstall: 35 | 36 | ```bash 37 | cd $CODE_DIR 38 | git clone https://github.com/vllm-project/vllm.git 39 | cd vllm 40 | pip install -e . 41 | ``` 42 | 43 | #### Starting vLLM API Service 44 | 45 | ```bash 46 | export SAFETENSORS_FAST_GPU=1 47 | export VLLM_USE_V1=0 48 | 49 | python3 -m vllm.entrypoints.openai.api_server \ 50 | --model MiniMax-M1-80k \ 51 | --tensor-parallel-size 8 \ 52 | --trust-remote-code \ 53 | --quantization experts_int8 \ 54 | --enable-auto-tool-choice \ 55 | --tool-call-parser minimax \ 56 | --chat-template vllm/examples/tool_chat_template_minimax_m1.jinja \ 57 | --max_model_len 4096 \ 58 | --dtype bfloat16 \ 59 | --gpu-memory-utilization 0.85 60 | ``` 61 | 62 | **⚠️ Note:** 63 | - `--tool-call-parser minimax` is a key parameter for enabling the MiniMax-M1 custom parser 64 | - `--enable-auto-tool-choice` enables automatic tool selection 65 | - `--chat-template` template file needs to be adapted for tool calling format 66 | 67 | #### Function Call Test Script Example 68 | 69 | The following Python script implements a weather query function call example based on OpenAI SDK: 70 | 71 | ```python 72 | from openai import OpenAI 73 | import json 74 | 75 | client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy") 76 | 77 | def get_weather(location: str, unit: str): 78 | return f"Getting the weather for {location} in {unit}..." 
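# tool_functions maps each tool name to a local Python callable, so a tool call
# parsed from the model response can be dispatched by name to its implementation.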
79 | 
80 | tool_functions = {"get_weather": get_weather}
81 | 
82 | tools = [{
83 | "type": "function",
84 | "function": {
85 | "name": "get_weather",
86 | "description": "Get the current weather in a given location",
87 | "parameters": {
88 | "type": "object",
89 | "properties": {
90 | "location": {"type": "string", "description": "City and state, e.g., 'San Francisco, CA'"},
91 | "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}
92 | },
93 | "required": ["location", "unit"]
94 | }
95 | }
96 | }]
97 | 
98 | response = client.chat.completions.create(
99 | model=client.models.list().data[0].id,
100 | messages=[{"role": "user", "content": "What's the weather like in San Francisco? use celsius."}],
101 | tools=tools,
102 | tool_choice="auto"
103 | )
104 | 
105 | print(response)
106 | 
107 | tool_call = response.choices[0].message.tool_calls[0].function
108 | print(f"Function called: {tool_call.name}")
109 | print(f"Arguments: {tool_call.arguments}")
110 | print(f"Result: {get_weather(**json.loads(tool_call.arguments))}")
111 | ```
112 | 
113 | **Output Example:**
114 | ```
115 | Function called: get_weather
116 | Arguments: {"location": "San Francisco, CA", "unit": "celsius"}
117 | Result: Getting the weather for San Francisco, CA in celsius...
118 | ```
119 | 
120 | ### Manual Parsing of Model Output
121 | 
122 | If you cannot use vLLM's built-in parser, or need to use other inference frameworks (such as transformers, TGI, etc.), you can use the following method to manually parse the model's raw output. This method requires you to parse the XML tag format of the model output yourself.
123 | 
124 | #### Using Transformers Example
125 | 
126 | The following is a complete example using the transformers library:
127 | 
128 | ```python
129 | from transformers import AutoTokenizer
130 | 
131 | def get_default_tools():
132 | return [
133 | {
134 | "name": "get_current_weather",
135 | "description": "Get the latest weather for a location",
136 | "parameters": {
137 | "type": "object",
138 | "properties": {
139 | "location": {
140 | "type": "string",
141 | "description": "A certain city, such as Beijing, Shanghai"
142 | }
143 | },
144 | "required": ["location"]
145 | }
146 | }
147 | ]
148 | 
149 | # Load model and tokenizer
150 | model_id = "MiniMaxAI/MiniMax-M1-40k"
151 | tokenizer = AutoTokenizer.from_pretrained(model_id)
152 | prompt = "What's the weather like in Shanghai today?"
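# Note: MiniMax-M1's chat template expects each message's "content" to be a list of
# typed parts such as [{"type": "text", "text": ...}], rather than a plain string.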
153 | messages = [ 154 | {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant created by Minimax based on MiniMax-M1 model."}]}, 155 | {"role": "user", "content": [{"type": "text", "text": prompt}]}, 156 | ] 157 | 158 | # Enable function call tools 159 | tools = get_default_tools() 160 | 161 | # Apply chat template and add tool definitions 162 | text = tokenizer.apply_chat_template( 163 | messages, 164 | tokenize=False, 165 | add_generation_prompt=True, 166 | tools=tools 167 | ) 168 | 169 | # Send request (using any inference service here) 170 | import requests 171 | payload = { 172 | "model": "MiniMaxAI/MiniMax-M1-40k", 173 | "prompt": text, 174 | "max_tokens": 4000 175 | } 176 | response = requests.post( 177 | "http://localhost:8000/v1/completions", 178 | headers={"Content-Type": "application/json"}, 179 | json=payload, 180 | stream=False, 181 | ) 182 | 183 | # Model output needs manual parsing 184 | raw_output = response.json()["choices"][0]["text"] 185 | print("Raw output:", raw_output) 186 | 187 | # Use the parsing function below to process the output 188 | function_calls = parse_function_calls(raw_output) 189 | ``` 190 | 191 | ## 🛠️ Function Call Definition 192 | 193 | ### Function Structure 194 | 195 | Function calls need to be defined in the `tools` field of the request body. Each function consists of the following components: 196 | 197 | ```json 198 | { 199 | "tools": [ 200 | { 201 | "name": "search_web", 202 | "description": "Search function.", 203 | "parameters": { 204 | "properties": { 205 | "query_list": { 206 | "description": "Keywords for search, with list element count of 1.", 207 | "items": { "type": "string" }, 208 | "type": "array" 209 | }, 210 | "query_tag": { 211 | "description": "Classification of the query", 212 | "items": { "type": "string" }, 213 | "type": "array" 214 | } 215 | }, 216 | "required": [ "query_list", "query_tag" ], 217 | "type": "object" 218 | } 219 | } 220 | ] 221 | } 222 | ``` 223 | 224 | **Field Descriptions:** 225 | - `name`: Function name 226 | - `description`: Function description 227 | - `parameters`: Function parameter definition 228 | - `properties`: Parameter property definitions, where key is the parameter name and value contains detailed parameter description 229 | - `required`: List of required parameters 230 | - `type`: Parameter type (usually "object") 231 | 232 | ### Internal Model Processing Format 233 | 234 | When processed internally by the model, function definitions are converted to a special format and concatenated to the input text: 235 | 236 | ``` 237 | system ai_setting=MiniMax AI 238 | MiniMax AI是由上海稀宇科技有限公司(MiniMax)自主研发的AI助理。 239 | system tool_setting=tools 240 | You are provided with these tools: 241 | 242 | {"name": "search_web", "description": "搜索函数。", "parameters": {"properties": {"query_list": {"description": "进行搜索的关键词,列表元素个数为1。", "items": {"type": "string"}, "type": "array"}, "query_tag": {"description": "query的分类", "items": {"type": "string"}, "type": "array"}}, "required": ["query_list", "query_tag"], "type": "object"}} 243 | 244 | If you need to call tools, please respond with XML tags, and provide tool-name and json-object of arguments, following the format below: 245 | 246 | {"name": , "arguments": } 247 | ... 248 | 249 | user name=用户 250 | OpenAI 和 Gemini 的最近一次发布会都是什么时候? 
251 | ai name=MiniMax AI 252 | ``` 253 | 254 | ### Model Output Format 255 | 256 | The model outputs function calls in the following format: 257 | 258 | ```xml 259 | 260 | Okay, I will search for the OpenAI and Gemini latest release. 261 | 262 | 263 | {"name": "search_web", "arguments": {"query_tag": ["technology", "events"], "query_list": ["\"OpenAI\" \"latest\" \"release\""]}} 264 | {"name": "search_web", "arguments": {"query_tag": ["technology", "events"], "query_list": ["\"Gemini\" \"latest\" \"release\""]}} 265 | 266 | ``` 267 | 268 | ## 📥 Manual Parsing of Function Call Results 269 | 270 | ### Parsing Function Calls 271 | 272 | When manual parsing is required, you need to parse the XML tag format of the model output: 273 | 274 | ```python 275 | import re 276 | import json 277 | def parse_function_calls(content: str): 278 | """ 279 | Parse function calls from model output 280 | """ 281 | function_calls = [] 282 | 283 | # Match content within tags 284 | tool_calls_pattern = r"(.*?)" 285 | tool_calls_match = re.search(tool_calls_pattern, content, re.DOTALL) 286 | 287 | if not tool_calls_match: 288 | return function_calls 289 | 290 | tool_calls_content = tool_calls_match.group(1).strip() 291 | 292 | # Parse each function call (one JSON object per line) 293 | for line in tool_calls_content.split('\n'): 294 | line = line.strip() 295 | if not line: 296 | continue 297 | 298 | try: 299 | # Parse JSON format function call 300 | call_data = json.loads(line) 301 | function_name = call_data.get("name") 302 | arguments = call_data.get("arguments", {}) 303 | 304 | function_calls.append({ 305 | "name": function_name, 306 | "arguments": arguments 307 | }) 308 | 309 | print(f"Function call: {function_name}, Arguments: {arguments}") 310 | 311 | except json.JSONDecodeError as e: 312 | print(f"Parameter parsing failed: {line}, Error: {e}") 313 | 314 | return function_calls 315 | 316 | # Example: Handle weather query function 317 | def execute_function_call(function_name: str, arguments: dict): 318 | """ 319 | Execute function call and return result 320 | """ 321 | if function_name == "get_current_weather": 322 | location = arguments.get("location", "Unknown location") 323 | # Build function execution result 324 | return { 325 | "role": "tool", 326 | "content": [ 327 | { 328 | "name": function_name, 329 | "type": "text", 330 | "text": json.dumps({ 331 | "location": location, 332 | "temperature": "25", 333 | "unit": "celsius", 334 | "weather": "Sunny" 335 | }, ensure_ascii=False) 336 | } 337 | ] 338 | } 339 | elif function_name == "search_web": 340 | query_list = arguments.get("query_list", []) 341 | query_tag = arguments.get("query_tag", []) 342 | # Simulate search results 343 | return { 344 | "role": "tool", 345 | "content": [ 346 | { 347 | "name": function_name, 348 | "type": "text", 349 | "text": f"Search keywords: {query_list}, Categories: {query_tag}\nSearch results: Relevant information found" 350 | } 351 | ] 352 | } 353 | 354 | return None 355 | ``` 356 | 357 | ### Returning Function Execution Results to the Model 358 | 359 | After successfully parsing function calls, you should add the function execution results to the conversation history so that the model can access and utilize this information in subsequent interactions. 360 | 361 | #### Single Result 362 | 363 | If the model calls the `search_web` function, you can refer to the following format to add execution results, with the `name` field being the specific function name. 
364 | 365 | ```json 366 | { 367 | "role": "tool", 368 | "content": [ 369 | { 370 | "name": "search_web", 371 | "type": "text", 372 | "text": "test_result" 373 | } 374 | ] 375 | } 376 | ``` 377 | 378 | Corresponding model input format: 379 | ``` 380 | tool name=tools 381 | tool name: search_web 382 | tool result: test_result 383 | 384 | ``` 385 | 386 | #### Multiple Results 387 | 388 | If the model calls both `search_web` and `get_current_weather` functions simultaneously, you can refer to the following format to add execution results, with `content` containing multiple results. 389 | 390 | ```json 391 | { 392 | "role": "tool", 393 | "content": [ 394 | { 395 | "name": "search_web", 396 | "type": "text", 397 | "text": "test_result1" 398 | }, 399 | { 400 | "name": "get_current_weather", 401 | "type": "text", 402 | "text": "test_result2" 403 | } 404 | ] 405 | } 406 | ``` 407 | 408 | Corresponding model input format: 409 | ``` 410 | tool name=tools 411 | tool name: search_web 412 | tool result: test_result1 413 | tool name: get_current_weather 414 | tool result: test_result2 415 | ``` 416 | 417 | While we recommend following the above formats, as long as the input returned to the model is easy to understand, the specific content of `name` and `text` is entirely up to you. 418 | 419 | ## 📚 References 420 | 421 | - [MiniMax-M1 Model Repository](https://github.com/MiniMaxAI/MiniMax-M1) 422 | - [vLLM Project Homepage](https://github.com/vllm-project/vllm) 423 | - [vLLM Function Calling PR](https://github.com/vllm-project/vllm/pull/20297) 424 | - [OpenAI Python SDK](https://github.com/openai/openai-python) -------------------------------------------------------------------------------- /docs/function_call_guide_pt-br.md: -------------------------------------------------------------------------------- 1 | # Guia de Uso de Function Call no MiniMax-M1 2 | 3 | [FunctionCall中文使用指南](./function_call_guide_cn.md) 4 | 5 | ## 📖 Introdução 6 | 7 | O modelo MiniMax-M1 possui suporte para chamadas de funções (Function Call), permitindo que o modelo identifique quando funções externas precisam ser chamadas e gere os parâmetros dessas chamadas em um formato estruturado. Este documento fornece instruções detalhadas sobre como utilizar o recurso de chamadas de funções do MiniMax-M1. 8 | 9 | ## 🚀 Início Rápido 10 | 11 | ### Usando vLLM para Function Calls (Recomendado) 12 | 13 | Na implantação real, para suportar capacidades nativas de Function Calling (chamada de ferramentas) semelhantes à API OpenAI, o modelo MiniMax-M1 integra um parser dedicado `tool_call_parser=minimax`, evitando análise regex adicional da saída do modelo. 14 | 15 | #### Configuração do Ambiente e Recompilação do vLLM 16 | 17 | Como este recurso ainda não foi oficialmente lançado na versão PyPI, é necessária compilação a partir do código fonte. 
O seguinte é um processo de exemplo baseado na imagem oficial do Docker vLLM `vllm/vllm-openai:v0.8.3`: 18 | 19 | ```bash 20 | IMAGE=vllm/vllm-openai:v0.8.3 21 | DOCKER_RUN_CMD="--network=host --privileged --ipc=host --ulimit memlock=-1 --shm-size=32gb --rm --gpus all --ulimit stack=67108864" 22 | 23 | # Executar docker 24 | sudo docker run -it -v $MODEL_DIR:$MODEL_DIR \ 25 | -v $CODE_DIR:$CODE_DIR \ 26 | --name vllm_function_call \ 27 | $DOCKER_RUN_CMD \ 28 | --entrypoint /bin/bash \ 29 | $IMAGE 30 | ``` 31 | 32 | #### Compilando o Código Fonte do vLLM 33 | 34 | Após entrar no container, execute os seguintes comandos para obter o código fonte e reinstalar: 35 | 36 | ```bash 37 | cd $CODE_DIR 38 | git clone https://github.com/vllm-project/vllm.git 39 | cd vllm 40 | pip install -e . 41 | ``` 42 | 43 | #### Iniciando o Serviço API vLLM 44 | 45 | ```bash 46 | export SAFETENSORS_FAST_GPU=1 47 | export VLLM_USE_V1=0 48 | 49 | python3 -m vllm.entrypoints.openai.api_server \ 50 | --model MiniMax-M1-80k \ 51 | --tensor-parallel-size 8 \ 52 | --trust-remote-code \ 53 | --quantization experts_int8 \ 54 | --enable-auto-tool-choice \ 55 | --tool-call-parser minimax \ 56 | --chat-template vllm/examples/tool_chat_template_minimax_m1.jinja \ 57 | --max_model_len 4096 \ 58 | --dtype bfloat16 \ 59 | --gpu-memory-utilization 0.85 60 | ``` 61 | 62 | **⚠️ Nota:** 63 | - `--tool-call-parser minimax` é um parâmetro chave para habilitar o parser personalizado MiniMax-M1 64 | - `--enable-auto-tool-choice` habilita a seleção automática de ferramentas 65 | - `--chat-template` arquivo de template precisa ser adaptado para o formato de chamada de ferramentas 66 | 67 | #### Exemplo de Script de Teste de Function Call 68 | 69 | O seguinte script Python implementa um exemplo de chamada de função de consulta meteorológica baseado no SDK OpenAI: 70 | 71 | ```python 72 | from openai import OpenAI 73 | import json 74 | 75 | client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy") 76 | 77 | def get_weather(location: str, unit: str): 78 | return f"Getting the weather for {location} in {unit}..." 79 | 80 | tool_functions = {"get_weather": get_weather} 81 | 82 | tools = [{ 83 | "type": "function", 84 | "function": { 85 | "name": "get_weather", 86 | "description": "Get the current weather in a given location", 87 | "parameters": { 88 | "type": "object", 89 | "properties": { 90 | "location": {"type": "string", "description": "City and state, e.g., 'San Francisco, CA'"}, 91 | "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]} 92 | }, 93 | "required": ["location", "unit"] 94 | } 95 | } 96 | }] 97 | 98 | response = client.chat.completions.create( 99 | model=client.models.list().data[0].id, 100 | messages=[{"role": "user", "content": "What's the weather like in San Francisco? use celsius."}], 101 | tools=tools, 102 | tool_choice="auto" 103 | ) 104 | 105 | print(response) 106 | 107 | tool_call = response.choices[0].message.tool_calls[0].function 108 | print(f"Function called: {tool_call.name}") 109 | print(f"Arguments: {tool_call.arguments}") 110 | print(f"Result: {get_weather(**json.loads(tool_call.arguments))}") 111 | ``` 112 | 113 | **Exemplo de Saída:** 114 | ``` 115 | Function called: get_weather 116 | Arguments: {"location": "San Francisco, CA", "unit": "celsius"} 117 | Result: Getting the weather for San Francisco, CA in celsius... 
118 | ```
119 | 
120 | ### Análise Manual da Saída do Modelo
121 | 
122 | Se você não puder usar o parser integrado do vLLM, ou precisar usar outros frameworks de inferência (como transformers, TGI, etc.), você pode usar o seguinte método para analisar manualmente a saída bruta do modelo. Este método requer que você analise o formato de tags XML da saída do modelo.
123 | 
124 | #### Exemplo Usando Transformers
125 | 
126 | O seguinte é um exemplo completo usando a biblioteca transformers:
127 | 
128 | ```python
129 | from transformers import AutoTokenizer
130 | 
131 | def get_default_tools():
132 | return [
133 | {
134 | "name": "get_current_weather",
135 | "description": "Get the latest weather for a location",
136 | "parameters": {
137 | "type": "object",
138 | "properties": {
139 | "location": {
140 | "type": "string",
141 | "description": "A certain city, such as Beijing, Shanghai"
142 | }
143 | },
144 | "required": ["location"]
145 | }
146 | }
147 | ]
148 | 
149 | # Carregar modelo e tokenizador
150 | model_id = "MiniMaxAI/MiniMax-M1-40k"
151 | tokenizer = AutoTokenizer.from_pretrained(model_id)
152 | prompt = "What's the weather like in Shanghai today?"
153 | messages = [
154 | {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant created by Minimax based on MiniMax-M1 model."}]},
155 | {"role": "user", "content": [{"type": "text", "text": prompt}]},
156 | ]
157 | 
158 | # Habilitar ferramentas de chamada de função
159 | tools = get_default_tools()
160 | 
161 | # Aplicar template de chat e adicionar definições de ferramentas
162 | text = tokenizer.apply_chat_template(
163 | messages,
164 | tokenize=False,
165 | add_generation_prompt=True,
166 | tools=tools
167 | )
168 | 
169 | # Enviar requisição (usando qualquer serviço de inferência aqui)
170 | import requests
171 | payload = {
172 | "model": "MiniMaxAI/MiniMax-M1-40k",
173 | "prompt": text,
174 | "max_tokens": 4000
175 | }
176 | response = requests.post(
177 | "http://localhost:8000/v1/completions",
178 | headers={"Content-Type": "application/json"},
179 | json=payload,
180 | stream=False,
181 | )
182 | 
183 | # Saída do modelo precisa de análise manual
184 | raw_output = response.json()["choices"][0]["text"]
185 | print("Saída bruta:", raw_output)
186 | 
187 | # Use a função de análise abaixo para processar a saída
188 | function_calls = parse_function_calls(raw_output)
189 | ```
190 | 
191 | ## 🛠️ Definição de Function Call
192 | 
193 | ### Estrutura da Função
194 | 
195 | As funções precisam ser definidas no campo `tools` do corpo da requisição.
Cada função é composta pelos seguintes elementos: 196 | 197 | ```json 198 | { 199 | "tools": [ 200 | { 201 | "name": "search_web", 202 | "description": "Função de busca.", 203 | "parameters": { 204 | "properties": { 205 | "query_list": { 206 | "description": "Palavras-chave para busca, com contagem de elementos da lista de 1.", 207 | "items": { "type": "string" }, 208 | "type": "array" 209 | }, 210 | "query_tag": { 211 | "description": "Classificação da consulta", 212 | "items": { "type": "string" }, 213 | "type": "array" 214 | } 215 | }, 216 | "required": [ "query_list", "query_tag" ], 217 | "type": "object" 218 | } 219 | } 220 | ] 221 | } 222 | ``` 223 | 224 | **Descrição dos Campos:** 225 | 226 | * `name`: Nome da função 227 | * `description`: Descrição da função 228 | * `parameters`: Definição dos parâmetros da função 229 | 230 | * `properties`: Definições dos parâmetros, onde a chave é o nome do parâmetro e o valor contém a descrição detalhada do parâmetro 231 | * `required`: Lista de parâmetros obrigatórios 232 | * `type`: Tipo de parâmetro (geralmente "object") 233 | 234 | ### Formato de Processamento Interno do Modelo 235 | 236 | Quando processadas internamente pelo modelo, as definições de função são convertidas para um formato especial e concatenadas ao texto de entrada: 237 | 238 | ``` 239 | system ai_setting=MiniMax AI 240 | MiniMax AI是由上海稀宇科技有限公司(MiniMax)自主研发的AI助理。 241 | system tool_setting=tools 242 | You are provided with these tools: 243 | 244 | {"name": "search_web", "description": "搜索函数。", "parameters": {"properties": {"query_list": {"description": "进行搜索的关键词,列表元素个数为1。", "items": {"type": "string"}, "type": "array"}, "query_tag": {"description": "query的分类", "items": {"type": "string"}, "type": "array"}}, "required": ["query_list", "query_tag"], "type": "object"}} 245 | 246 | If you need to call tools, please respond with XML tags, and provide tool-name and json-object of arguments, following the format below: 247 | 248 | {"name": , "arguments": } 249 | ... 250 | 251 | user name=用户 252 | OpenAI 和 Gemini 的最近一次发布会都是什么时候? 253 | ai name=MiniMax AI 254 | ``` 255 | 256 | ### Formato de Saída do Modelo 257 | 258 | O modelo gera chamadas de função no seguinte formato: 259 | 260 | ```xml 261 | 262 | Ok, vou procurar a versão mais recente do OpenAI e do Gemini. 
263 | 264 | 265 | {"name": "search_web", "arguments": {"query_tag": ["technology", "events"], "query_list": ["\"OpenAI\" \"latest\" \"release\""]}} 266 | {"name": "search_web", "arguments": {"query_tag": ["technology", "events"], "query_list": ["\"Gemini\" \"latest\" \"release\""]}} 267 | 268 | ``` 269 | 270 | ## 📥 Análise Manual dos Resultados de Function Call 271 | 272 | ### Fazendo o Parse das Chamadas de Função 273 | 274 | Quando a análise manual é necessária, você precisa analisar o formato de tags XML da saída do modelo: 275 | 276 | ```python 277 | import re 278 | import json 279 | def parse_function_calls(content: str): 280 | """ 281 | Analisar chamadas de função da saída do modelo 282 | """ 283 | function_calls = [] 284 | 285 | # Corresponder conteúdo dentro das tags 286 | tool_calls_pattern = r"(.*?)" 287 | tool_calls_match = re.search(tool_calls_pattern, content, re.DOTALL) 288 | 289 | if not tool_calls_match: 290 | return function_calls 291 | 292 | tool_calls_content = tool_calls_match.group(1).strip() 293 | 294 | # Analisar cada chamada de função (um objeto JSON por linha) 295 | for line in tool_calls_content.split('\n'): 296 | line = line.strip() 297 | if not line: 298 | continue 299 | 300 | try: 301 | # Analisar chamada de função em formato JSON 302 | call_data = json.loads(line) 303 | function_name = call_data.get("name") 304 | arguments = call_data.get("arguments", {}) 305 | 306 | function_calls.append({ 307 | "name": function_name, 308 | "arguments": arguments 309 | }) 310 | 311 | print(f"Chamada de função: {function_name}, Argumentos: {arguments}") 312 | 313 | except json.JSONDecodeError as e: 314 | print(f"Falha na análise de parâmetros: {line}, Erro: {e}") 315 | 316 | return function_calls 317 | 318 | # Exemplo: Manipular função de consulta de clima 319 | def execute_function_call(function_name: str, arguments: dict): 320 | """ 321 | Executar chamada de função e retornar resultado 322 | """ 323 | if function_name == "get_current_weather": 324 | location = arguments.get("location", "Localização desconhecida") 325 | # Construir resultado da execução da função 326 | return { 327 | "role": "tool", 328 | "content": [ 329 | { 330 | "name": function_name, 331 | "type": "text", 332 | "text": json.dumps({ 333 | "location": location, 334 | "temperature": "25", 335 | "unit": "celsius", 336 | "weather": "Ensolarado" 337 | }, ensure_ascii=False) 338 | } 339 | ] 340 | } 341 | elif function_name == "search_web": 342 | query_list = arguments.get("query_list", []) 343 | query_tag = arguments.get("query_tag", []) 344 | # Simular resultados de pesquisa 345 | return { 346 | "role": "tool", 347 | "content": [ 348 | { 349 | "name": function_name, 350 | "type": "text", 351 | "text": f"Palavras-chave de busca: {query_list}, Categorias: {query_tag}\nResultados da busca: Informações relevantes encontradas" 352 | } 353 | ] 354 | } 355 | 356 | return None 357 | ``` 358 | 359 | ### Retornando os Resultados da Execução de Função para o Modelo 360 | 361 | Após analisar com sucesso as chamadas de função, você deve adicionar os resultados da execução da função ao histórico da conversa para que o modelo possa acessar e utilizar essas informações em interações subsequentes. 362 | 363 | #### Resultado Único 364 | 365 | Se o modelo chamar a função `search_web`, você pode se referir ao seguinte formato para adicionar resultados de execução, com o campo `name` sendo o nome específico da função. 
366 | 367 | ```json 368 | { 369 | "role": "tool", 370 | "content": [ 371 | { 372 | "name": "search_web", 373 | "type": "text", 374 | "text": "test_result" 375 | } 376 | ] 377 | } 378 | ``` 379 | 380 | Formato de entrada correspondente do modelo: 381 | ``` 382 | tool name=tools 383 | tool name: search_web 384 | tool result: test_result 385 | 386 | ``` 387 | 388 | #### Vários Resultados 389 | 390 | Se o modelo chamar simultaneamente as funções `search_web` e `get_current_weather`, você pode se referir ao seguinte formato para adicionar resultados de execução, com `content` contendo vários resultados. 391 | 392 | ```json 393 | { 394 | "role": "tool", 395 | "content": [ 396 | { 397 | "name": "search_web", 398 | "type": "text", 399 | "text": "test_result1" 400 | }, 401 | { 402 | "name": "get_current_weather", 403 | "type": "text", 404 | "text": "test_result2" 405 | } 406 | ] 407 | } 408 | ``` 409 | 410 | Formato de entrada correspondente do modelo: 411 | ``` 412 | tool name=tools 413 | tool name: search_web 414 | tool result: test_result1 415 | tool name: get_current_weather 416 | tool result: test_result2 417 | ``` 418 | 419 | Embora recomendemos seguir os formatos acima, desde que a entrada retornada ao modelo seja fácil de entender, o conteúdo específico de `name` e `text` é inteiramente de sua escolha. 420 | 421 | ## 📚 Referências 422 | 423 | - [Repositório do Modelo MiniMax-M1](https://github.com/MiniMaxAI/MiniMax-M1) 424 | - [Página Principal do Projeto vLLM](https://github.com/vllm-project/vllm) 425 | - [PR de Function Calling do vLLM](https://github.com/vllm-project/vllm/pull/20297) 426 | - [SDK Python OpenAI](https://github.com/openai/openai-python) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | 4 | MiniMax 5 | 6 | 7 |
8 |
9 | 10 | 27 | 44 | 45 | # MiniMax-M1 46 | 47 | ## 1. Model Overview 48 | 49 | We introduce MiniMax-M1, the world's first open-weight, large-scale hybrid-attention reasoning model. 50 | MiniMax-M1 is powered by a hybrid Mixture-of-Experts (MoE) architecture combined with a lightning 51 | attention mechanism. The model is developed based on our previous [MiniMax-Text-01 model](https://huggingface.co/MiniMaxAI/MiniMax-Text-01), 52 | which contains a total of 456 billion parameters with 45.9 billion parameters activated 53 | per token. Consistent with MiniMax-Text-01, the M1 model natively supports a context length of 1 54 | million tokens, 8x the context size of DeepSeek R1. Furthermore, the lightning attention mechanism 55 | in MiniMax-M1 enables efficient scaling of test-time compute – For example, compared to DeepSeek 56 | R1, M1 consumes 25% of the FLOPs at a generation length of 100K tokens. These properties make M1 57 | particularly suitable for complex tasks that require processing long inputs and thinking extensively. 58 | MiniMax-M1 is trained using large-scale reinforcement learning (RL) on diverse problems ranging from 59 | traditional mathematical reasoning to sandbox-based, real-world software engineering environments. 60 | We develop an efficient RL scaling framework for M1 highlighting two perspectives: (1) We propose 61 | CISPO, a novel algorithm that clips importance sampling weights instead of token updates, which 62 | outperforms other competitive RL variants; (2) Our hybrid-attention design naturally enhances the 63 | efficiency of RL, where we address unique challenges when scaling RL with the hybrid architecture. We 64 | train two versions of MiniMax-M1 models with [40K](https://huggingface.co/MiniMaxAI/MiniMax-M1-40k) and 65 | [80K](https://huggingface.co/MiniMaxAI/MiniMax-M1-80k) thinking budgets respectively. Experiments 66 | on standard benchmarks show that our models outperform other strong open-weight models such as 67 | the original DeepSeek-R1 and Qwen3-235B, particularly on complex software engineering, tool using, 68 | and long context tasks. With efficient scaling of test-time compute, MiniMax-M1 serves as a strong 69 | foundation for next-generation language model agents to reason and tackle real-world challenges. 70 | 71 |
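The exact CISPO objective is specified in the accompanying tech report. Purely as an illustrative sketch of the idea of clipping and detaching the per-token importance sampling weight (so that no token's update is dropped, unlike a PPO-style clip), a token-level policy-gradient loss could look like the following; the function and hyperparameter names here are assumptions for illustration, not the released training code:

```python
import torch

def clipped_is_weight_loss(logp_new, logp_old, advantages, eps_low=0.2, eps_high=0.2):
    """Illustrative only: clip and detach the per-token importance-sampling
    weight, so every token still contributes a gradient through its log-prob."""
    ratio = torch.exp(logp_new - logp_old)                        # importance-sampling weight per token
    weight = torch.clamp(ratio, 1.0 - eps_low, 1.0 + eps_high).detach()
    # A PPO-style clip would zero the update for tokens outside the trust region;
    # here the clipped weight only rescales the REINFORCE term for each token.
    return -(weight * advantages * logp_new).mean()
```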

72 | 73 |
74 | Benchmark performance comparison of leading commercial and open-weight models across competition-level mathematics, coding, software engineering, agentic tool use, and long-context understanding tasks. We use the MiniMax-M1-80k model here for MiniMax-M1. 75 |

76 | 77 | 78 | ## 2. Evaluation 79 | 80 | **Performance of MiniMax-M1 on core benchmarks.** 81 | 82 | 83 | | **Category** | **Task** | **MiniMax-M1-80K** | **MiniMax-M1-40K** | **Qwen3-235B-A22B** | **DeepSeek-R1-0528** | **DeepSeek-R1** | **Seed-Thinking-v1.5** | **Claude 4 Opus** | **Gemini 2.5 Pro (06-05)** | **OpenAI-o3** | 84 | |:---|:---|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:| 85 | | | *Extended Thinking* | *80K* | *40K* | *32k* | *64k* | *32k* | *32k* | *64k* | *64k* | *100k* | 86 | | ***Mathematics*** | AIME 2024 | 86.0 | 83.3 | 85.7 | 91.4 | 79.8 | 86.7 | 76.0 | 92.0 | 91.6 | 87 | | | AIME 2025 | 76.9 | 74.6 | 81.5 | 87.5 | 70.0 | 74.0 | 75.5 | 88.0 | 88.9 | 88 | | | MATH-500 | 96.8 | 96.0 | 96.2 | 98.0 | 97.3 | 96.7 | 98.2 | 98.8 | 98.1 | 89 | | ***General Coding*** | LiveCodeBench *(24/8~25/5)* | 65.0 | 62.3 | 65.9 | 73.1 | 55.9 | 67.5 | 56.6 | 77.1 | 75.8 | 90 | | | FullStackBench | 68.3 | 67.6 | 62.9 | 69.4 | 70.1 | 69.9 | 70.3 | -- | 69.3 | 91 | | ***Reasoning & Knowledge***| GPQA Diamond | 70.0 | 69.2 | 71.1 | 81.0 | 71.5 | 77.3 | 79.6 | 86.4 | 83.3 | 92 | | | HLE *(no tools)* | 8.4\* | 7.2\* | 7.6\* | 17.7\* | 8.6\* | 8.2 | 10.7 | 21.6 | 20.3 | 93 | | | ZebraLogic | 86.8 | 80.1 | 80.3 | 95.1 | 78.7 | 84.4 | 95.1 | 91.6 | 95.8 | 94 | | | MMLU-Pro | 81.1 | 80.6 | 83.0 | 85.0 | 84.0 | 87.0 | 85.0 | 86.0 | 85.0 | 95 | | ***Software Engineering***| SWE-bench Verified| 56.0 | 55.6 | 34.4 | 57.6 | 49.2 | 47.0 | 72.5 | 67.2 | 69.1 | 96 | | ***Long Context*** | OpenAI-MRCR *(128k)* | 73.4 | 76.1 | 27.7 | 51.5 | 35.8 | 54.3 | 48.9 | 76.8 | 56.5 | 97 | | | OpenAI-MRCR *(1M)* | 56.2 | 58.6 | -- | -- | -- | -- | -- | 58.8 | -- | 98 | | | LongBench-v2 | 61.5 | 61.0 | 50.1 | 52.1 | 58.3 | 52.5 | 55.6 | 65.0 | 58.8 | 99 | | ***Agentic Tool Use***| TAU-bench *(airline)* | 62.0 | 60.0 | 34.7 | 53.5 | -- | 44.0 | 59.6 | 50.0 | 52.0 | 100 | | | TAU-bench *(retail)* | 63.5 | 67.8 | 58.6 | 63.9 | -- | 55.7 | 81.4 | 67.0 | 73.9 | 101 | | ***Factuality*** | SimpleQA | 18.5 | 17.9 | 11.0 | 27.8 | 30.1 | 12.9 | -- | 54.0 | 49.4 | 102 | | ***General Assistant***| MultiChallenge | 44.7 | 44.7 | 40.0 | 45.0 | 40.7 | 43.0 | 45.8 | 51.8 | 56.5 | 103 | 104 | \* conducted on the text-only HLE subset. 105 | 106 | Our models are evaluated with `temperature=1.0`, `top_p=0.95`. 107 | 108 | ### SWE-bench methodology 109 | We report results derived from the Agentless scaffold. Departing from the original pipeline, our methodology employs a two-stage localization process (without any embedding-based retrieval mechanisms): initial coarse-grained file localization followed by fine-grained localization to specific files and code elements. The values for our models are calculated on the subset of n=486 verified tasks which work on our infrastructure. The excluded 14 test cases that were incompatible with our internal infrastructure are: 110 | `"astropy__astropy-7606"`, 111 | `"astropy__astropy-8707"`, 112 | `"astropy__astropy-8872"`, 113 | `"django__django-10097"`, 114 | `"matplotlib__matplotlib-20488"`, 115 | `"psf__requests-2317"`, 116 | `"psf__requests-2931"`, 117 | `"psf__requests-5414"`, 118 | `"pylint-dev__pylint-6528"`, 119 | `"pylint-dev__pylint-7277"`, 120 | `"sphinx-doc__sphinx-10435"`, 121 | `"sphinx-doc__sphinx-7985"`, 122 | `"sphinx-doc__sphinx-8269"`, 123 | `"sphinx-doc__sphinx-8475"` 124 | 125 | ### TAU-bench methodology 126 | We evaluate TAU-Bench with GPT-4.1 as user model and without any custom tools. The maximum number of interaction steps is 40. 
127 | Our general system prompt is:
128 | ```
129 | - In each round, you need to carefully examine the tools provided to you to determine if any can be used.
130 | - You must adhere to all of the policies. Pay attention to the details in the terms. Solutions for most situations can be found within these policies.
131 | ```
132 | 
133 | ## 3. Recommendations for MiniMax-M1 Model Usage
134 | 
135 | To achieve the best results with the MiniMax-M1 model, we suggest focusing on two key points: Inference Parameters and the System Prompt.
136 | 
137 | ### 3.1. Inference Parameters
138 | - Temperature: **`1.0`**
139 | - Top_p: **`0.95`**
140 | 
141 | This setting is optimal for encouraging creativity and diversity in the model's responses. It allows the model to explore a wider range of linguistic possibilities, preventing outputs that are too rigid or repetitive, while still maintaining strong logical coherence.
142 | 
143 | ### 3.2. System Prompt
144 | Tailoring your system prompt to the specific task is crucial for guiding the model effectively. Below are suggested settings for different scenarios.
145 | 
146 | #### A. General-Purpose Scenarios
147 | For common tasks like summarization, translation, Q&A, or creative writing:
148 | ```
149 | You are a helpful assistant.
150 | ```
151 | #### B. Web Development Scenarios
152 | For complex tasks like generating code for web pages:
153 | ```
154 | You are a web development engineer, writing web pages according to the instructions below. You are a powerful code editing assistant capable of writing code and creating artifacts in conversations with users, or modifying and updating existing artifacts as requested by users.
155 | All code is written in a single code block to form a complete code file for display, without separating HTML and JavaScript code. An artifact refers to a runnable complete code snippet, you prefer to integrate and output such complete runnable code rather than breaking it down into several code blocks. For certain types of code, they can render graphical interfaces in a UI window. After generation, please check the code execution again to ensure there are no errors in the output.
156 | Output only the HTML, without any additional descriptive text. Make the UI looks modern and beautiful.
157 | ```
158 | #### C. Mathematical Scenarios
159 | When dealing with problems that require calculation or logical deduction:
160 | ```
161 | Please reason step by step, and put your final answer within \boxed{}.
162 | ```
163 | 
164 | ## 4. Deployment Guide
165 | 
166 | Download the model from the HuggingFace repository:
167 | - [MiniMax-M1-40k](https://huggingface.co/MiniMaxAI/MiniMax-M1-40k)
168 | - [MiniMax-M1-80k](https://huggingface.co/MiniMaxAI/MiniMax-M1-80k)
169 | 
170 | For production deployment, we recommend using [vLLM](https://docs.vllm.ai/en/latest/) to serve MiniMax-M1. vLLM provides excellent performance for serving large language models with the following features:
171 | - 🔥 Outstanding serving throughput
172 | - ⚡ Efficient and intelligent memory management
173 | - 📦 Powerful batch request processing capability
174 | - ⚙️ Deeply optimized underlying performance
175 | 
176 | For detailed vLLM deployment instructions, please refer to our [vLLM Deployment Guide](./docs/vllm_deployment_guide.md).
177 | Alternatively, you can also deploy using Transformers directly. For detailed Transformers deployment instructions, you can see our [MiniMax-M1 Transformers Deployment Guide](./docs/transformers_deployment_guide.md).
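Once a vLLM endpoint is up, a minimal OpenAI-compatible client call using the inference parameters recommended in Section 3.1 might look like the sketch below; the base URL and served model name are assumptions and should match your own deployment:

```python
from openai import OpenAI

# Assumes an OpenAI-compatible vLLM server, e.g. one started as described in the vLLM Deployment Guide.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")

response = client.chat.completions.create(
    model="MiniMax-M1-80k",  # the model name you served with vLLM
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Summarize the main idea of lightning attention in three sentences."},
    ],
    temperature=1.0,  # recommended inference parameters from Section 3.1
    top_p=0.95,
)
print(response.choices[0].message.content)
```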
178 | 
179 | 
180 | ## 5. Function Calling
181 | 
182 | The MiniMax-M1 model supports function calling, enabling it to identify when external functions need to be called and to output the call parameters in a structured format. The [MiniMax-M1 Function Call Guide](./docs/function_call_guide.md) provides detailed instructions on how to use this feature.
183 | 
184 | 
185 | ## 6. Chatbot & API
186 | For general use and evaluation, we provide a [Chatbot](https://chat.minimax.io/) with online search capabilities and the [online API](https://www.minimax.io/platform/) for developers. We also provide the [MiniMax MCP Server](https://github.com/MiniMax-AI/MiniMax-MCP), which offers video generation, image generation, speech synthesis, and voice cloning for developers.
187 | 
188 | 
189 | ## 7. Citation
190 | ```
191 | @misc{minimax2025minimaxm1scalingtesttimecompute,
192 |       title={MiniMax-M1: Scaling Test-Time Compute Efficiently with Lightning Attention}, 
193 |       author={MiniMax},
194 |       year={2025},
195 |       eprint={2506.13585},
196 |       archivePrefix={arXiv},
197 |       primaryClass={cs.CL},
198 |       url={https://arxiv.org/abs/2506.13585}, 
199 | }
200 | ```
201 | 
202 | ## 8. Contact Us
203 | Contact us at [model@minimax.io](mailto:model@minimax.io).
--------------------------------------------------------------------------------
/modeling_minimax_m1.py:
--------------------------------------------------------------------------------
1 | """ PyTorch MiniMaxM1 model."""
2 | import inspect
3 | import math
4 | import warnings
5 | from typing import List, Optional, Tuple, Union
6 | import os
7 | import copy
8 | import torch
9 | import torch.nn.functional as F
10 | import torch.utils.checkpoint
11 | from torch import nn
12 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
13 | from einops import rearrange, repeat
14 | from transformers.activations import ACT2FN
15 | from transformers.cache_utils import Cache, DynamicCache
16 | from transformers.modeling_attn_mask_utils import (
17 |     _prepare_4d_causal_attention_mask,
18 | )
19 | from transformers.modeling_outputs import (
20 |     MoeCausalLMOutputWithPast,
21 |     MoeModelOutputWithPast,
22 |     SequenceClassifierOutputWithPast,
23 | )
24 | from transformers.modeling_utils import PreTrainedModel
25 | from transformers.utils import (
26 |     add_start_docstrings,
27 |     add_start_docstrings_to_model_forward,
28 |     is_flash_attn_2_available,
29 |     is_flash_attn_greater_or_equal_2_10,
30 |     logging,
31 |     replace_return_docstrings,
32 | )
33 | from transformers.utils.import_utils import is_torch_fx_available
34 | from .configuration_minimax_m1 import MiniMaxM1Config
35 | 
36 | if is_flash_attn_2_available():
37 |     from flash_attn import flash_attn_func, flash_attn_varlen_func
38 |     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
39 | 
40 |     _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
41 | 
42 | # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
43 | # It means that the function will not be traced through and simply appear as a node in the graph.
44 | if is_torch_fx_available(): 45 | _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) 46 | 47 | use_triton = eval(os.environ.get("use_triton", default="False")) 48 | debug = eval(os.environ.get("debug", default="False")) 49 | do_eval = eval(os.environ.get("do_eval", default="False")) 50 | eval_and_not_generate = eval(os.environ.get("eval_and_not_generate", default="False")) 51 | BLOCK = 256 52 | 53 | logger = logging.get_logger(__name__) 54 | 55 | _CONFIG_FOR_DOC = "MiniMaxM1Config" 56 | 57 | 58 | def get_activation_fn(activation): 59 | if debug: 60 | logger.info(f"activation: {activation}") 61 | if activation == "gelu": 62 | return F.gelu 63 | elif activation == "relu": 64 | return F.relu 65 | elif activation == "elu": 66 | return F.elu 67 | elif activation == "sigmoid": 68 | return F.sigmoid 69 | elif activation == "exp": 70 | 71 | def f(x): 72 | with torch.no_grad(): 73 | x_max = torch.max(x, dim=-1, keepdims=True).values 74 | y = torch.exp(x - x_max) 75 | 76 | return y 77 | 78 | return f 79 | elif activation == "leak": 80 | return F.leaky_relu 81 | elif activation == "1+elu": 82 | 83 | def f(x): 84 | return 1 + F.elu(x) 85 | 86 | return f 87 | elif activation == "2+elu": 88 | 89 | def f(x): 90 | return 2 + F.elu(x) 91 | 92 | return f 93 | elif activation == "silu" or activation == "swish": 94 | return F.silu 95 | elif activation == "sine": 96 | return torch.sin 97 | else: 98 | logger.info( 99 | f"activation: does not support {activation}, use Identity!!!") 100 | return lambda x: x 101 | 102 | 103 | def load_balancing_loss_func( 104 | gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, 105 | attention_mask: Optional[torch.Tensor] = None 106 | ) -> float: 107 | r""" 108 | Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. 109 | 110 | See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss 111 | function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between 112 | experts is too unbalanced. 113 | 114 | Args: 115 | gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]): 116 | Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of 117 | shape [batch_size X sequence_length, num_experts]. 118 | attention_mask (`torch.Tensor`, None): 119 | The attention_mask used in forward function 120 | shape [batch_size X sequence_length] if not None. 121 | num_experts (`int`, *optional*): 122 | Number of experts 123 | 124 | Returns: 125 | The auxiliary loss. 
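
    Illustrative usage (a minimal sketch; the shapes below are arbitrary
    assumptions, not values taken from any MiniMax-M1 configuration):

        # two layers, 4 tokens each, 8 experts, top-2 routing
        gate_logits = tuple(torch.randn(4, 8) for _ in range(2))
        aux_loss = load_balancing_loss_func(gate_logits, num_experts=8, top_k=2)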
126 | """ 127 | if gate_logits is None or not isinstance(gate_logits, tuple): 128 | return 0 129 | 130 | if isinstance(gate_logits, tuple): 131 | compute_device = gate_logits[0].device 132 | concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) 133 | 134 | routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) 135 | 136 | _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) 137 | 138 | expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) 139 | 140 | if attention_mask is None: 141 | # Compute the percentage of tokens routed to each experts 142 | tokens_per_expert = torch.mean(expert_mask.float(), dim=0) 143 | 144 | # Compute the average probability of routing to these experts 145 | router_prob_per_expert = torch.mean(routing_weights, dim=0) 146 | else: 147 | batch_size, sequence_length = attention_mask.shape 148 | num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) 149 | 150 | # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask 151 | expert_attention_mask = ( 152 | attention_mask[None, :, :, None, None] 153 | .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) 154 | .reshape(-1, top_k, num_experts) 155 | .to(compute_device) 156 | ) 157 | 158 | # Compute the percentage of tokens routed to each experts 159 | tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( 160 | expert_attention_mask, dim=0 161 | ) 162 | 163 | # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert 164 | router_per_expert_attention_mask = ( 165 | attention_mask[None, :, :, None] 166 | .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) 167 | .reshape(-1, num_experts) 168 | .to(compute_device) 169 | ) 170 | 171 | # Compute the average probability of routing to these experts 172 | router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( 173 | router_per_expert_attention_mask, dim=0 174 | ) 175 | 176 | overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) 177 | return overall_loss * num_experts 178 | 179 | 180 | # Copied from transformers.models.llama.modeling_llama._get_unpad_data 181 | def _get_unpad_data(attention_mask): 182 | seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) 183 | indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() 184 | max_seqlen_in_batch = seqlens_in_batch.max().item() 185 | cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) 186 | return ( 187 | indices, 188 | cu_seqlens, 189 | max_seqlen_in_batch, 190 | ) 191 | 192 | 193 | class GLU(nn.Module): 194 | 195 | def __init__(self, d1, d2, bias=False): 196 | super().__init__() 197 | 198 | self.l1 = nn.Linear(d1, d2, bias=bias) 199 | self.l2 = nn.Linear(d1, d2, bias=bias) 200 | self.l3 = nn.Linear(d2, d1, bias=bias) 201 | 202 | def forward(self, x): 203 | o1 = self.l1(x) 204 | o2 = self.l2(x) 205 | output = o1 * o2 206 | output = self.l3(output) 207 | return output 208 | 209 | 210 | class MiniMaxM1LightningAttention(nn.Module): 211 | def __init__(self, config: MiniMaxM1Config, layer_idx: Optional[int] = None): 212 | super().__init__() 213 | bias = False 214 | self.hidden_size = config.hidden_size 215 | self.num_heads = config.num_attention_heads 216 | self.head_dim = getattr(config, 'head_dim', 
self.hidden_size // self.num_heads) 217 | 218 | self.out_proj = nn.Linear(self.head_dim * self.num_heads, self.hidden_size, bias=bias) 219 | self.act = get_activation_fn(config.hidden_act) 220 | self.norm = MiniMaxM1RMSNorm(self.head_dim * self.num_heads) 221 | 222 | self.qkv_proj = nn.Linear(self.hidden_size, 3 * self.head_dim * self.num_heads, bias=bias) 223 | self.output_gate = nn.Linear(self.hidden_size, self.head_dim * self.num_heads, bias=bias) 224 | 225 | # for inference only 226 | self.offset = 0 227 | self.layer_idx = layer_idx 228 | 229 | def forward( 230 | self, 231 | hidden_states, 232 | attn_mask: Optional[torch.Tensor] = None, # (b, h, n, m) 233 | output_attentions: bool = False, 234 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 235 | use_cache: bool = False, 236 | slope_rate: Optional[torch.Tensor] = None, 237 | **kwargs 238 | ): 239 | if (not self.training) and (not do_eval): 240 | return self.inference( 241 | hidden_states, 242 | attn_mask, 243 | output_attentions, 244 | past_key_value, 245 | use_cache, 246 | slope_rate, 247 | ) 248 | 249 | def inference( 250 | self, 251 | x, 252 | attn_mask: Optional[torch.Tensor] = None, # (b, n) 253 | output_attentions: bool = False, 254 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 255 | use_cache: bool = False, 256 | slope_rate: Optional[torch.Tensor] = None, # (h, 1, 1) 257 | ): 258 | # x: b n d 259 | b, n, d = x.shape 260 | # linear map 261 | qkv = self.act(self.qkv_proj(x)) 262 | new_shape = qkv.size()[:-1] + (self.num_heads, -1) 263 | qkv = qkv.view(*new_shape) 264 | q, k, v = torch.split(qkv, [self.head_dim] * 3, dim=3) 265 | q = q.transpose(1, 2) 266 | k = k.transpose(1, 2) 267 | v = v.transpose(1, 2) 268 | 269 | if past_key_value is None: 270 | self.offset = q.shape[-2] 271 | else: 272 | self.offset += 1 273 | 274 | # for align with metaseq 275 | ratio = torch.exp(-slope_rate) 276 | 277 | # only use for the first time 278 | if past_key_value is None: 279 | slope_rate = slope_rate.to(torch.float32) 280 | if attn_mask is not None: 281 | v = v.masked_fill((1 - attn_mask).unsqueeze(1).unsqueeze(-1).to(torch.bool), 0) 282 | NUM_BLOCK = (n + BLOCK - 1) // BLOCK 283 | b, h, n, d = q.shape 284 | e = v.shape[-1] 285 | # other 286 | array = torch.arange(BLOCK).to(q) + 1 287 | q_decay = torch.exp(-slope_rate * array.reshape(-1, 1)) 288 | k_decay = torch.exp(-slope_rate * (BLOCK - array.reshape(-1, 1))) 289 | index = array[:, None] - array[None, :] 290 | s_index = slope_rate * index[ 291 | None, 292 | None, 293 | ] 294 | s_index = torch.where(index >= 0, -s_index, float("-inf")) 295 | diag_decay = torch.exp(s_index) 296 | 297 | kv = torch.zeros(b, h, d, e).to(torch.float32).to(q.device) 298 | output = torch.empty((b, h, n, e), dtype=q.dtype, device=q.device) 299 | for i in range(NUM_BLOCK): 300 | si = i * BLOCK 301 | ei = min(si + BLOCK, n) 302 | m = ei - si 303 | qi = q[:, :, si:ei].contiguous() 304 | ki = k[:, :, si:ei].contiguous() 305 | vi = v[:, :, si:ei].contiguous() 306 | qkv_none_diag = torch.matmul(qi * q_decay[:, :m], kv).to(torch.float32) 307 | 308 | # diag 309 | qk = torch.matmul(qi, ki.transpose(-1, -2)).to(torch.float32) * diag_decay[:, :, :m, :m] 310 | qkv_diag = torch.matmul(qk, vi.to(torch.float32)) 311 | block_decay = torch.exp(-slope_rate * m) 312 | output[:, :, si:ei] = qkv_none_diag + qkv_diag 313 | kv = block_decay * kv + torch.matmul((ki * k_decay[:, -m:]).transpose(-1, -2).to(vi.dtype), vi) 314 | 315 | else: 316 | kv = past_key_value 317 | output = [] 318 | for i in range(n): 319 | kv = ratio * 
kv + torch.einsum( 320 | "... n d, ... n e -> ... d e", 321 | k[:, :, i:i + 1], 322 | v[:, :, i:i + 1], 323 | ) 324 | qkv = torch.einsum("... n e, ... e d -> ... n d", q[:, :, i:i + 1], kv.to(q.dtype)) 325 | output.append(qkv) 326 | output = torch.concat(output, dim=-2) 327 | # reshape 328 | output = rearrange(output, "b h n d -> b n (h d)") 329 | # normalize 330 | output = self.norm(output) 331 | # gate 332 | output = F.sigmoid(self.output_gate(x)) * output 333 | # outproj 334 | output = self.out_proj(output) 335 | 336 | attn_weights = None 337 | 338 | return output, attn_weights, kv 339 | 340 | 341 | # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MiniMaxM1 342 | class MiniMaxM1RMSNorm(nn.Module): 343 | def __init__(self, hidden_size, eps=1e-6): 344 | """ 345 | MiniMaxM1RMSNorm is equivalent to T5LayerNorm 346 | """ 347 | super().__init__() 348 | self.weight = nn.Parameter(torch.ones(hidden_size)) 349 | self.variance_epsilon = eps 350 | 351 | def forward(self, hidden_states): 352 | input_dtype = hidden_states.dtype 353 | hidden_states = hidden_states.to(torch.float32) 354 | variance = hidden_states.pow(2).mean(-1, keepdim=True) 355 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) 356 | return self.weight * hidden_states.to(input_dtype) 357 | 358 | 359 | # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->MiniMaxM1 360 | class MiniMaxM1RotaryEmbedding(nn.Module): 361 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): 362 | super().__init__() 363 | 364 | self.dim = dim 365 | self.max_position_embeddings = max_position_embeddings 366 | self.base = base 367 | inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) 368 | self.register_buffer("inv_freq", inv_freq, persistent=False) 369 | 370 | # Build here to make `torch.jit.trace` work. 
371 | self._set_cos_sin_cache( 372 | seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.float32 373 | ) 374 | 375 | def _set_cos_sin_cache(self, seq_len, device, dtype): 376 | self.max_seq_len_cached = seq_len 377 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) 378 | 379 | freqs = torch.outer(t, self.inv_freq) 380 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 381 | emb = torch.cat((freqs, freqs), dim=-1) 382 | self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) 383 | self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) 384 | 385 | def forward(self, x, seq_len=None): 386 | # x: [bs, num_attention_heads, seq_len, head_size] 387 | if seq_len > self.max_seq_len_cached: 388 | self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=torch.float32) 389 | 390 | return ( 391 | self.cos_cached[:seq_len].to(dtype=torch.float32), 392 | self.sin_cached[:seq_len].to(dtype=torch.float32), 393 | ) 394 | 395 | 396 | # Copied from transformers.models.llama.modeling_llama.rotate_half 397 | def rotate_half(x): 398 | """Rotates half the hidden dims of the input.""" 399 | x1 = x[..., : x.shape[-1] // 2] 400 | x2 = x[..., x.shape[-1] // 2:] 401 | return torch.cat((-x2, x1), dim=-1) 402 | 403 | 404 | # Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb 405 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): 406 | """Applies Rotary Position Embedding to the query and key tensors. 407 | 408 | Args: 409 | q (`torch.Tensor`): The query tensor. 410 | k (`torch.Tensor`): The key tensor. 411 | cos (`torch.Tensor`): The cosine part of the rotary embedding. 412 | sin (`torch.Tensor`): The sine part of the rotary embedding. 413 | position_ids (`torch.Tensor`): 414 | The position indices of the tokens corresponding to the query and key tensors. For example, this can be 415 | used to pass offsetted position ids when working with a KV-cache. 416 | unsqueeze_dim (`int`, *optional*, defaults to 1): 417 | The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and 418 | sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note 419 | that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and 420 | k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes 421 | cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have 422 | the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. 423 | Returns: 424 | `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
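
    Illustrative usage (a minimal sketch; all shapes and the rotary dimension
    below are assumptions for demonstration only):

        rope = MiniMaxM1RotaryEmbedding(dim=64, max_position_embeddings=2048)
        q = torch.randn(1, 8, 16, 64)  # [batch, heads, seq_len, head_dim]
        k = torch.randn(1, 8, 16, 64)
        cos, sin = rope(k, seq_len=16)
        position_ids = torch.arange(16).unsqueeze(0)
        q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)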
425 | """ 426 | dtype = q.dtype 427 | rot_dim = cos.shape[-1] 428 | q_, q_pass = q[..., :rot_dim], q[..., rot_dim:] 429 | k_, k_pass = k[..., :rot_dim], k[..., rot_dim:] 430 | cos = cos[position_ids].unsqueeze(unsqueeze_dim) 431 | sin = sin[position_ids].unsqueeze(unsqueeze_dim) 432 | q_embed = (q_ * cos) + (rotate_half(q_) * sin) 433 | k_embed = (k_ * cos) + (rotate_half(k_) * sin) 434 | return torch.cat((q_embed, q_pass), dim=-1).to(dtype), torch.cat((k_embed, k_pass), dim=-1).to(dtype) 435 | 436 | 437 | # Copied from transformers.models.llama.modeling_llama.repeat_kv 438 | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: 439 | """ 440 | This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, 441 | num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) 442 | """ 443 | batch, num_key_value_heads, slen, head_dim = hidden_states.shape 444 | if n_rep == 1: 445 | return hidden_states 446 | hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) 447 | return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) 448 | 449 | 450 | # Copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->MiniMaxM1 451 | class MiniMaxM1Attention(nn.Module): 452 | """ 453 | Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer 454 | and "Generating Long Sequences with Sparse Transformers". 455 | """ 456 | 457 | def __init__(self, config: MiniMaxM1Config, layer_idx: Optional[int] = None): 458 | super().__init__() 459 | self.config = config 460 | self.layer_idx = layer_idx 461 | if layer_idx is None: 462 | logger.warning_once( 463 | f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " 464 | "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " 465 | "when creating this class." 
466 | ) 467 | 468 | self.hidden_size = config.hidden_size 469 | self.num_heads = config.num_attention_heads 470 | self.head_dim = getattr(config, 'head_dim', self.hidden_size // self.num_heads) 471 | self.num_key_value_heads = config.num_key_value_heads 472 | self.num_key_value_groups = self.num_heads // self.num_key_value_heads 473 | self.max_position_embeddings = config.max_position_embeddings 474 | self.rope_theta = config.rope_theta 475 | self.is_causal = True 476 | self.attention_dropout = config.attention_dropout 477 | 478 | self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) 479 | self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) 480 | self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) 481 | self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) 482 | self.rotary_dim = getattr(config, 'rotary_dim', self.head_dim) 483 | 484 | self.rotary_emb = MiniMaxM1RotaryEmbedding( 485 | self.rotary_dim, 486 | max_position_embeddings=self.max_position_embeddings, 487 | base=self.rope_theta, 488 | ) 489 | 490 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 491 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() 492 | 493 | def forward( 494 | self, 495 | hidden_states: torch.Tensor, 496 | attention_mask: Optional[torch.Tensor] = None, 497 | position_ids: Optional[torch.LongTensor] = None, 498 | past_key_value: Optional[Cache] = None, 499 | output_attentions: bool = False, 500 | use_cache: bool = False, 501 | **kwargs, 502 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 503 | if "padding_mask" in kwargs: 504 | warnings.warn( 505 | "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" 506 | ) 507 | bsz, q_len, _ = hidden_states.size() 508 | 509 | query_states = self.q_proj(hidden_states) 510 | key_states = self.k_proj(hidden_states) 511 | value_states = self.v_proj(hidden_states) 512 | 513 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 514 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 515 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 516 | 517 | kv_seq_len = key_states.shape[-2] 518 | if past_key_value is not None: 519 | if self.layer_idx is None: 520 | raise ValueError( 521 | f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " 522 | "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " 523 | "with a layer index." 
524 | ) 525 | kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) 526 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 527 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 528 | 529 | if past_key_value is not None: 530 | cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models 531 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) 532 | 533 | # repeat k/v heads if n_kv_heads < n_heads 534 | key_states = repeat_kv(key_states, self.num_key_value_groups) 535 | value_states = repeat_kv(value_states, self.num_key_value_groups) 536 | 537 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) 538 | 539 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 540 | raise ValueError( 541 | f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" 542 | f" {attn_weights.size()}" 543 | ) 544 | 545 | if attention_mask is not None: 546 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 547 | raise ValueError( 548 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 549 | ) 550 | 551 | attn_weights = attn_weights + attention_mask 552 | 553 | # upcast attention to fp32 554 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) 555 | attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) 556 | attn_output = torch.matmul(attn_weights, value_states) 557 | 558 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 559 | raise ValueError( 560 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 561 | f" {attn_output.size()}" 562 | ) 563 | 564 | attn_output = attn_output.transpose(1, 2).contiguous() 565 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 566 | 567 | attn_output = self.o_proj(attn_output) 568 | 569 | if not output_attentions: 570 | attn_weights = None 571 | 572 | return attn_output, attn_weights, past_key_value 573 | 574 | 575 | # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->MiniMaxM1 576 | class MiniMaxM1FlashAttention2(MiniMaxM1Attention): 577 | """ 578 | MiniMaxM1 flash attention module. This module inherits from `MiniMaxM1Attention` as the weights of the module stays 579 | untouched. The only required change would be on the forward pass where it needs to correctly call the public API of 580 | flash attention and deal with padding tokens in case the input contains any of them. 581 | """ 582 | 583 | # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ 584 | def __init__(self, *args, **kwargs): 585 | super().__init__(*args, **kwargs) 586 | 587 | # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. 588 | # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. 589 | # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
590 | self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() 591 | 592 | def forward( 593 | self, 594 | hidden_states: torch.Tensor, 595 | attention_mask: Optional[torch.Tensor] = None, 596 | position_ids: Optional[torch.LongTensor] = None, 597 | past_key_value: Optional[Union[Cache, Tuple[torch.Tensor]]] = None, 598 | output_attentions: bool = False, 599 | use_cache: bool = False, 600 | **kwargs, 601 | ): 602 | if "padding_mask" in kwargs: 603 | warnings.warn( 604 | "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" 605 | ) 606 | 607 | # overwrite attention_mask with padding_mask 608 | attention_mask = kwargs.pop("padding_mask") 609 | bsz, q_len, _ = hidden_states.size() 610 | 611 | query_states = self.q_proj(hidden_states) 612 | key_states = self.k_proj(hidden_states) 613 | value_states = self.v_proj(hidden_states) 614 | 615 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 616 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 617 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 618 | 619 | kv_seq_len = key_states.shape[-2] 620 | if past_key_value is not None: 621 | kv_seq_len += past_key_value[0].shape[-3] 622 | 623 | # Because the input can be padded, the absolute sequence length depends on the max position id. 624 | rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 625 | cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) 626 | 627 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 628 | 629 | use_sliding_windows = ( 630 | _flash_supports_window_size 631 | and getattr(self.config, "sliding_window", None) is not None 632 | and kv_seq_len > self.config.sliding_window 633 | ) 634 | 635 | if not _flash_supports_window_size: 636 | logger.warning_once( 637 | "The current flash attention version does not support sliding window attention, for a more memory efficient implementation" 638 | " make sure to upgrade flash-attn library." 639 | ) 640 | 641 | dropout_rate = 0.0 if not self.training else self.attention_dropout 642 | 643 | # In PEFT, usually we cast the layer norms in float32 for training stability reasons 644 | # therefore the input hidden states gets silently casted in float32. Hence, we need 645 | # cast them back in float16 just to be sure everything works as expected. 646 | input_dtype = query_states.dtype 647 | if input_dtype == torch.float32: 648 | if torch.is_autocast_enabled(): 649 | target_dtype = torch.get_autocast_gpu_dtype() 650 | # Handle the case where the model is quantized 651 | elif hasattr(self.config, "_pre_quantization_dtype"): 652 | target_dtype = self.config._pre_quantization_dtype 653 | else: 654 | target_dtype = self.q_proj.weight.dtype 655 | 656 | logger.warning_once( 657 | f"The input hidden states seems to be silently casted in float32, this might be related to" 658 | f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" 659 | f" {target_dtype}." 
660 | ) 661 | 662 | query_states = query_states.to(target_dtype) 663 | key_states = key_states.to(target_dtype) 664 | value_states = value_states.to(target_dtype) 665 | 666 | # Reshape to the expected shape for Flash Attention 667 | query_states = query_states.transpose(1, 2) 668 | key_states = key_states.transpose(1, 2) 669 | value_states = value_states.transpose(1, 2) 670 | 671 | if past_key_value is not None: 672 | # reuse k, v, for evaluation only 673 | key_states = torch.cat([past_key_value[0], key_states], dim=-3) 674 | value_states = torch.cat([past_key_value[1], value_states], dim=-3) 675 | 676 | past_key_value = (key_states, value_states) if use_cache else None 677 | 678 | attn_output = self._flash_attention_forward( 679 | query_states, 680 | key_states, 681 | value_states, 682 | attention_mask, 683 | q_len, 684 | dropout=dropout_rate, 685 | use_sliding_windows=use_sliding_windows, 686 | ) 687 | 688 | attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() 689 | attn_output = self.o_proj(attn_output) 690 | 691 | if not output_attentions: 692 | attn_weights = None 693 | 694 | return attn_output, attn_weights, past_key_value 695 | 696 | def _flash_attention_forward( 697 | self, 698 | query_states, 699 | key_states, 700 | value_states, 701 | attention_mask, 702 | query_length, 703 | dropout=0.0, 704 | softmax_scale=None, 705 | use_sliding_windows=False, 706 | ): 707 | """ 708 | Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token 709 | first unpad the input, then computes the attention scores and pad the final attention scores. 710 | 711 | Args: 712 | query_states (`torch.Tensor`): 713 | Input query states to be passed to Flash Attention API 714 | key_states (`torch.Tensor`): 715 | Input key states to be passed to Flash Attention API 716 | value_states (`torch.Tensor`): 717 | Input value states to be passed to Flash Attention API 718 | attention_mask (`torch.Tensor`): 719 | The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the 720 | position of padding tokens and 1 for the position of non-padding tokens. 721 | dropout (`float`): 722 | Attention dropout 723 | softmax_scale (`float`, *optional*): 724 | The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) 725 | use_sliding_windows (`bool`, *optional*): 726 | Whether to activate sliding window attention. 727 | """ 728 | if not self._flash_attn_uses_top_left_mask: 729 | causal = self.is_causal 730 | else: 731 | # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. 
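            # With flash_attn < 2.1 the causal mask is top-left aligned, so during
            # single-token decoding (query_length == 1) passing causal=True would
            # incorrectly mask out past key positions; disable it for that case.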
732 | causal = self.is_causal and query_length != 1 733 | 734 | # Contains at least one padding token in the sequence 735 | if attention_mask is not None: 736 | batch_size = query_states.shape[0] 737 | query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( 738 | query_states, key_states, value_states, attention_mask, query_length 739 | ) 740 | 741 | cu_seqlens_q, cu_seqlens_k = cu_seq_lens 742 | max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens 743 | 744 | if not use_sliding_windows: 745 | attn_output_unpad = flash_attn_varlen_func( 746 | query_states, 747 | key_states, 748 | value_states, 749 | cu_seqlens_q=cu_seqlens_q, 750 | cu_seqlens_k=cu_seqlens_k, 751 | max_seqlen_q=max_seqlen_in_batch_q, 752 | max_seqlen_k=max_seqlen_in_batch_k, 753 | dropout_p=dropout, 754 | softmax_scale=softmax_scale, 755 | causal=causal, 756 | ) 757 | else: 758 | attn_output_unpad = flash_attn_varlen_func( 759 | query_states, 760 | key_states, 761 | value_states, 762 | cu_seqlens_q=cu_seqlens_q, 763 | cu_seqlens_k=cu_seqlens_k, 764 | max_seqlen_q=max_seqlen_in_batch_q, 765 | max_seqlen_k=max_seqlen_in_batch_k, 766 | dropout_p=dropout, 767 | softmax_scale=softmax_scale, 768 | causal=causal, 769 | window_size=(self.config.sliding_window, self.config.sliding_window), 770 | ) 771 | 772 | attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) 773 | else: 774 | if not use_sliding_windows: 775 | attn_output = flash_attn_func( 776 | query_states, 777 | key_states, 778 | value_states, 779 | dropout, 780 | softmax_scale=softmax_scale, 781 | causal=causal, 782 | ) 783 | else: 784 | attn_output = flash_attn_func( 785 | query_states, 786 | key_states, 787 | value_states, 788 | dropout, 789 | softmax_scale=softmax_scale, 790 | causal=causal, 791 | window_size=(self.config.sliding_window, self.config.sliding_window), 792 | ) 793 | 794 | return attn_output 795 | 796 | def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): 797 | batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape 798 | 799 | # On the first iteration we need to properly re-create the padding mask 800 | # by slicing it on the proper place 801 | if kv_seq_len != attention_mask.shape[-1]: 802 | attention_mask_num_tokens = attention_mask.shape[-1] 803 | attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len:] 804 | 805 | indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) 806 | 807 | key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) 808 | value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) 809 | 810 | if query_length == kv_seq_len: 811 | query_layer = index_first_axis( 812 | query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k 813 | ) 814 | cu_seqlens_q = cu_seqlens_k 815 | max_seqlen_in_batch_q = max_seqlen_in_batch_k 816 | indices_q = indices_k 817 | elif query_length == 1: 818 | max_seqlen_in_batch_q = 1 819 | cu_seqlens_q = torch.arange( 820 | batch_size + 1, dtype=torch.int32, device=query_layer.device 821 | ) # There is a memcpy here, that is very bad. 822 | indices_q = cu_seqlens_q[:-1] 823 | query_layer = query_layer.squeeze(1) 824 | else: 825 | # The -q_len: slice assumes left padding. 
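            # Keeping only the last `query_length` columns of the mask selects the
            # positions that correspond to the current queries, which sit at the
            # right edge of the sequence when inputs are left-padded.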
826 | attention_mask = attention_mask[:, -query_length:] 827 | query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) 828 | 829 | return ( 830 | query_layer, 831 | key_layer, 832 | value_layer, 833 | indices_q, 834 | (cu_seqlens_q, cu_seqlens_k), 835 | (max_seqlen_in_batch_q, max_seqlen_in_batch_k), 836 | ) 837 | 838 | 839 | class MiniMaxM1MLP(nn.Module): 840 | def __init__(self, config): 841 | super().__init__() 842 | self.config = config 843 | self.hidden_size = config.hidden_size 844 | self.intermediate_size = config.intermediate_size 845 | self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 846 | self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 847 | self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) 848 | self.act_fn = ACT2FN[config.hidden_act] 849 | 850 | def forward(self, x): 851 | down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) 852 | return down_proj 853 | 854 | 855 | class MiniMaxM1BlockSparseTop2MLP(nn.Module): 856 | def __init__(self, config: MiniMaxM1Config): 857 | super().__init__() 858 | self.ffn_dim = config.intermediate_size 859 | self.hidden_dim = config.hidden_size 860 | 861 | self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) 862 | self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) 863 | self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) 864 | 865 | self.act_fn = ACT2FN[config.hidden_act] 866 | 867 | def forward(self, hidden_states): 868 | current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) 869 | current_hidden_states = self.w2(current_hidden_states) 870 | return current_hidden_states 871 | 872 | 873 | class MiniMaxM1BLockSparseTop2MLP(MiniMaxM1BlockSparseTop2MLP): 874 | def __init__(self, *args, **kwargs): 875 | logger.warning_once( 876 | "MiniMaxM1BLockSparseTop2MLP is deprecated by MiniMaxM1BlockSparseTop2MLP and will be removed in v4.40." 877 | ) 878 | super().__init__(*args, **kwargs) 879 | 880 | 881 | class MiniMaxM1SparseMoeBlock(nn.Module): 882 | """ 883 | This implementation is 884 | strictly equivalent to standard MoE with full capacity (no 885 | dropped tokens). It's faster since it formulates MoE operations 886 | in terms of block-sparse operations to accomodate imbalanced 887 | assignments of tokens to experts, whereas standard MoE either 888 | (1) drop tokens at the cost of reduced performance or (2) set 889 | capacity factor to number of experts and thus waste computation 890 | and memory on padding. 
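
    Routing sketch (illustrative only; the tensor names and the top-2 choice
    below mirror the forward pass but are not tied to any specific configuration):

        # logits: [num_tokens, num_experts]
        weights = F.softmax(logits, dim=-1)
        weights, experts = torch.topk(weights, k=2, dim=-1)    # pick top-2 experts per token
        weights = weights / weights.sum(dim=-1, keepdim=True)  # renormalise per token
        # each token's output = sum over its selected experts of weight * expert_mlp(token)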
891 | """ 892 | 893 | def __init__(self, config): 894 | super().__init__() 895 | self.hidden_dim = config.hidden_size 896 | self.ffn_dim = config.intermediate_size 897 | self.num_experts = config.num_local_experts 898 | self.top_k = config.num_experts_per_tok 899 | 900 | # gating 901 | self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) 902 | 903 | self.experts = nn.ModuleList([MiniMaxM1BlockSparseTop2MLP(config) for _ in range(self.num_experts)]) 904 | 905 | # Jitter parameters 906 | self.jitter_noise = config.router_jitter_noise 907 | 908 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 909 | """ """ 910 | batch_size, sequence_length, hidden_dim = hidden_states.shape 911 | if self.training and self.jitter_noise > 0: 912 | hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise) 913 | hidden_states = hidden_states.view(-1, hidden_dim) 914 | # router_logits: (batch * sequence_length, n_experts) 915 | router_logits = self.gate(hidden_states) 916 | 917 | routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) 918 | routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) 919 | routing_weights /= routing_weights.sum(dim=-1, keepdim=True) 920 | # we cast back to the input dtype 921 | routing_weights = routing_weights.to(hidden_states.dtype) 922 | 923 | final_hidden_states = torch.zeros( 924 | (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device 925 | ) 926 | 927 | # One hot encode the selected experts to create an expert mask 928 | # this will be used to easily index which expert is going to be sollicitated 929 | expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) 930 | 931 | # Loop over all available experts in the model and perform the computation on each expert 932 | for expert_idx in range(self.num_experts): 933 | expert_layer = self.experts[expert_idx] 934 | idx, top_x = torch.where(expert_mask[expert_idx]) 935 | 936 | # Index the correct hidden states and compute the expert hidden state for 937 | # the current expert. We need to make sure to multiply the output hidden 938 | # states by `routing_weights` on the corresponding tokens (top-1 and top-2) 939 | current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) 940 | current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] 941 | 942 | # However `index_add_` only support torch tensors for indexing so we'll use 943 | # the `top_x` tensor here. 
944 | final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) 945 | final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) 946 | return final_hidden_states, router_logits 947 | 948 | 949 | class MiniMaxM1DecoderLayer(nn.Module): 950 | def __init__(self, config: MiniMaxM1Config, layer_idx: int): 951 | super().__init__() 952 | self.config = config 953 | self.hidden_size = config.hidden_size 954 | 955 | self.self_attn = self.build_attn(config, layer_idx) 956 | 957 | self.layer_idx = layer_idx 958 | 959 | self.block_sparse_moe = MiniMaxM1SparseMoeBlock(config) 960 | self.input_layernorm = MiniMaxM1RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 961 | self.post_attention_layernorm = MiniMaxM1RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 962 | 963 | self.postnorm = getattr(config, 'postnorm', False) 964 | self.layernorm_attention_alpha = getattr(config, 'layernorm_linear_attention_alpha', 1) \ 965 | if config.attention_type == 0 else getattr(config, 'layernorm_full_attention_alpha', 1) 966 | self.layernorm_attention_beta = getattr(config, 'layernorm_linear_attention_beta', 1) \ 967 | if config.attention_type == 0 else getattr(config, 'layernorm_full_attention_beta', 1) 968 | self.layernorm_mlp_alpha = getattr(config, 'layernorm_mlp_alpha', 1) 969 | self.layernorm_mlp_beta = getattr(config, 'layernorm_mlp_beta', 1) 970 | 971 | shared_intermediate = getattr(config, 'shared_intermediate_size', 0) 972 | self.shared_moe = False 973 | if shared_intermediate > 0: 974 | self.shared_moe = True 975 | self.shared_mlp = MiniMaxM1MLP(config) 976 | self.coefficient = torch.nn.Linear(self.hidden_size, 1, bias=False) 977 | 978 | def build_attn(self, config, layer_idx): 979 | if config.attention_type == 0: 980 | Attention_module = MiniMaxM1LightningAttention 981 | else: 982 | Attention_module = MiniMaxM1FlashAttention2 983 | 984 | return Attention_module( 985 | config, 986 | layer_idx 987 | ) 988 | 989 | def forward( 990 | self, 991 | hidden_states: torch.Tensor, 992 | attention_mask: Optional[torch.Tensor] = None, 993 | position_ids: Optional[torch.LongTensor] = None, 994 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 995 | output_attentions: Optional[bool] = False, 996 | output_router_logits: Optional[bool] = False, 997 | use_cache: Optional[bool] = False, 998 | slope_rate: Optional[float] = None, 999 | **kwargs, 1000 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: 1001 | if "padding_mask" in kwargs: 1002 | warnings.warn( 1003 | "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" 1004 | ) 1005 | """ 1006 | Args: 1007 | hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` 1008 | attention_mask (`torch.FloatTensor`, *optional*): attention mask of size 1009 | `(batch, sequence_length)` where padding elements are indicated by 0. 1010 | past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states 1011 | output_attentions (`bool`, *optional*): 1012 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 1013 | returned tensors for more detail. 1014 | output_router_logits (`bool`, *optional*): 1015 | Whether or not to return the logits of all the routers. They are useful for computing the router loss, and 1016 | should not be returned during inference. 
1017 | use_cache (`bool`, *optional*): 1018 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding 1019 | (see `past_key_values`). 1020 | """ 1021 | 1022 | residual = hidden_states 1023 | 1024 | hidden_states = self.input_layernorm(hidden_states) 1025 | if self.postnorm: 1026 | residual = hidden_states 1027 | 1028 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 1029 | hidden_states=hidden_states, 1030 | position_ids=position_ids, 1031 | attn_mask=attention_mask, 1032 | past_key_value=past_key_value, 1033 | output_attentions=output_attentions, 1034 | use_cache=use_cache, 1035 | slope_rate=slope_rate, 1036 | ) 1037 | 1038 | hidden_states = residual * self.layernorm_attention_alpha \ 1039 | + hidden_states * self.layernorm_attention_beta 1040 | 1041 | # Fully Connected 1042 | residual = hidden_states 1043 | hidden_states = self.post_attention_layernorm(hidden_states) 1044 | if self.postnorm: 1045 | residual = hidden_states 1046 | 1047 | moe_hidden_states, router_logits = self.block_sparse_moe(hidden_states) 1048 | if self.shared_moe: 1049 | output_mlp = self.shared_mlp(hidden_states) 1050 | weight_fp32 = self.coefficient.weight.float() 1051 | coef = hidden_states.to(torch.float32) @ weight_fp32.T 1052 | coef = torch.nn.functional.sigmoid(coef).to(hidden_states.dtype) 1053 | hidden_states = moe_hidden_states * (1 - coef) + output_mlp * coef 1054 | else: 1055 | hidden_states = moe_hidden_states 1056 | 1057 | hidden_states = residual * self.layernorm_mlp_alpha \ 1058 | + hidden_states * self.layernorm_mlp_beta 1059 | 1060 | outputs = (hidden_states,) 1061 | 1062 | if output_attentions: 1063 | outputs += (self_attn_weights,) 1064 | 1065 | if use_cache: 1066 | outputs += (present_key_value,) 1067 | 1068 | if output_router_logits: 1069 | outputs += (router_logits,) 1070 | 1071 | return outputs 1072 | 1073 | 1074 | MIXTRAL_START_DOCSTRING = r""" 1075 | This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the 1076 | library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads 1077 | etc.) 1078 | 1079 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 1080 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage 1081 | and behavior. 1082 | 1083 | Parameters: 1084 | config ([`MiniMaxM1Config`]): 1085 | Model configuration class with all the parameters of the model. Initializing with a config file does not 1086 | load the weights associated with the model, only the configuration. Check out the 1087 | [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
1088 | """ 1089 | 1090 | 1091 | @add_start_docstrings( 1092 | "The bare MiniMaxM1 Model outputting raw hidden-states without any specific head on top.", 1093 | MIXTRAL_START_DOCSTRING, 1094 | ) 1095 | # Copied from transformers.models.mistral.modeling_mistral.MistralPreTrainedModel with Mistral->MiniMaxM1 1096 | class MiniMaxM1PreTrainedModel(PreTrainedModel): 1097 | config_class = MiniMaxM1Config 1098 | base_model_prefix = "model" 1099 | supports_gradient_checkpointing = True 1100 | _no_split_modules = ["MiniMaxM1DecoderLayer"] 1101 | _skip_keys_device_placement = "past_key_values" 1102 | _supports_flash_attn_2 = True 1103 | _supports_sdpa = True 1104 | 1105 | def _init_weights(self, module): 1106 | std = self.config.initializer_range 1107 | if isinstance(module, nn.Linear): 1108 | module.weight.data.normal_(mean=0.0, std=std) 1109 | if module.bias is not None: 1110 | module.bias.data.zero_() 1111 | elif isinstance(module, nn.Embedding): 1112 | module.weight.data.normal_(mean=0.0, std=std) 1113 | if module.padding_idx is not None: 1114 | module.weight.data[module.padding_idx].zero_() 1115 | 1116 | 1117 | MIXTRAL_INPUTS_DOCSTRING = r""" 1118 | Args: 1119 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 1120 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide 1121 | it. 1122 | 1123 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 1124 | [`PreTrainedTokenizer.__call__`] for details. 1125 | 1126 | [What are input IDs?](../glossary#input-ids) 1127 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 1128 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 1129 | 1130 | - 1 for tokens that are **not masked**, 1131 | - 0 for tokens that are **masked**. 1132 | 1133 | [What are attention masks?](../glossary#attention-mask) 1134 | 1135 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 1136 | [`PreTrainedTokenizer.__call__`] for details. 1137 | 1138 | If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see 1139 | `past_key_values`). 1140 | 1141 | If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] 1142 | and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more 1143 | information on the default strategy. 1144 | 1145 | - 1 indicates the head is **not masked**, 1146 | - 0 indicates the head is **masked**. 1147 | position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 1148 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, 1149 | config.n_positions - 1]`. 1150 | 1151 | [What are position IDs?](../glossary#position-ids) 1152 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): 1153 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape 1154 | `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape 1155 | `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. 
1156 | 1157 | Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention 1158 | blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 1159 | 1160 | If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that 1161 | don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all 1162 | `decoder_input_ids` of shape `(batch_size, sequence_length)`. 1163 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 1164 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This 1165 | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the 1166 | model's internal embedding lookup matrix. 1167 | use_cache (`bool`, *optional*): 1168 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see 1169 | `past_key_values`). 1170 | output_attentions (`bool`, *optional*): 1171 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 1172 | tensors for more detail. 1173 | output_hidden_states (`bool`, *optional*): 1174 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 1175 | more detail. 1176 | output_router_logits (`bool`, *optional*): 1177 | Whether or not to return the logits of all the routers. They are useful for computing the router loss, and 1178 | should not be returned during inference. 1179 | return_dict (`bool`, *optional*): 1180 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 1181 | """ 1182 | 1183 | 1184 | @add_start_docstrings( 1185 | "The bare MiniMaxM1 Model outputting raw hidden-states without any specific head on top.", 1186 | MIXTRAL_START_DOCSTRING, 1187 | ) 1188 | # Copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->MiniMaxM1 1189 | class MiniMaxM1Model(MiniMaxM1PreTrainedModel): 1190 | """ 1191 | Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`MiniMaxM1DecoderLayer`] 1192 | 1193 | Args: 1194 | config: MiniMaxM1Config 1195 | """ 1196 | 1197 | def __init__(self, config: MiniMaxM1Config): 1198 | super().__init__(config) 1199 | self.padding_idx = config.pad_token_id 1200 | self.vocab_size = config.vocab_size 1201 | 1202 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) 1203 | self.attn_type_list = config.attn_type_list 1204 | config_copy = copy.deepcopy(config) 1205 | 1206 | self.layers = nn.ModuleList([]) 1207 | for i in range(config.num_hidden_layers): 1208 | _config = copy.deepcopy(config) 1209 | if self.attn_type_list[i] == 0: 1210 | _config._attn_implementation = 'linear_attention' 1211 | _config.attention_type = 0 1212 | else: 1213 | _config._attn_implementation = config_copy._attn_implementation 1214 | _config.attention_type = 1 1215 | self.layers.append(MiniMaxM1DecoderLayer(_config, i)) 1216 | 1217 | self._attn_implementation = config_copy._attn_implementation 1218 | self.norm = MiniMaxM1RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 1219 | 1220 | self.gradient_checkpointing = False 1221 | self.slopes = self._build_slope_tensor(config.num_attention_heads) 1222 | # mask 1223 | self._linear_attn_mask = torch.empty(0) 1224 | 1225 | # Initialize weights and apply final processing 1226 | self.post_init() 1227 | 1228 | def get_input_embeddings(self): 1229 | return self.embed_tokens 1230 | 1231 | def set_input_embeddings(self, value): 1232 | self.embed_tokens = value 1233 | 1234 | @staticmethod 1235 | def _build_slope_tensor(n_attention_heads: int): 1236 | 1237 | def get_slopes(n): 1238 | 1239 | def get_slopes_power_of_2(n): 1240 | start = 2 ** (-(2 ** -(math.log2(n) - 3))) 1241 | ratio = start 1242 | return [start * ratio ** i for i in range(n)] 1243 | 1244 | if math.log2(n).is_integer(): 1245 | return get_slopes_power_of_2( 1246 | n) # In the paper, we only train models that have 2^a heads for some a. This function has 1247 | else: # some good properties that only occur when the input is a power of 2. To maintain that even 1248 | closest_power_of_2 = 2 ** math.floor( 1249 | math.log2(n)) # when the number of heads is not a power of 2, we use this workaround. 
1250 | return (get_slopes_power_of_2(closest_power_of_2) 1251 | + get_slopes(2 * closest_power_of_2)[0::2][:n - closest_power_of_2]) 1252 | 1253 | # h, 1, 1 1254 | slopes = torch.tensor(get_slopes(n_attention_heads), dtype=torch.float32).reshape(n_attention_heads, 1, 1) 1255 | 1256 | return slopes 1257 | 1258 | # Ignore copy 1259 | @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) 1260 | def forward( 1261 | self, 1262 | input_ids: torch.LongTensor = None, 1263 | attention_mask: Optional[torch.Tensor] = None, 1264 | position_ids: Optional[torch.LongTensor] = None, 1265 | past_key_values: Optional[List[torch.FloatTensor]] = None, 1266 | inputs_embeds: Optional[torch.FloatTensor] = None, 1267 | use_cache: Optional[bool] = None, 1268 | output_attentions: Optional[bool] = None, 1269 | output_hidden_states: Optional[bool] = None, 1270 | output_router_logits: Optional[bool] = None, 1271 | return_dict: Optional[bool] = None, 1272 | ) -> Union[Tuple, MoeModelOutputWithPast]: 1273 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 1274 | output_router_logits = ( 1275 | output_router_logits if output_router_logits is not None else self.config.output_router_logits 1276 | ) 1277 | output_hidden_states = ( 1278 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 1279 | ) 1280 | use_cache = use_cache if use_cache is not None else self.config.use_cache 1281 | 1282 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1283 | 1284 | # retrieve input_ids and inputs_embeds 1285 | if input_ids is not None and inputs_embeds is not None: 1286 | raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") 1287 | elif input_ids is not None: 1288 | batch_size, seq_length = input_ids.shape 1289 | default_device = input_ids.device 1290 | elif inputs_embeds is not None: 1291 | batch_size, seq_length, _ = inputs_embeds.shape 1292 | default_device = inputs_embeds.device 1293 | else: 1294 | raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") 1295 | 1296 | past_key_values_length = 0 1297 | 1298 | if self.gradient_checkpointing and self.training: 1299 | if use_cache: 1300 | logger.warning_once( 1301 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
1302 | ) 1303 | use_cache = False 1304 | 1305 | seq_length_with_past = seq_length 1306 | if past_key_values is not None: 1307 | for idx in range(len(past_key_values)): 1308 | if self.attn_type_list[idx] == 1: 1309 | past_key_values_length = past_key_values[idx][0].shape[-3] 1310 | seq_length_with_past = seq_length_with_past + past_key_values_length 1311 | break 1312 | 1313 | if position_ids is None: 1314 | device = input_ids.device if input_ids is not None else inputs_embeds.device 1315 | position_ids = torch.arange( 1316 | past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device 1317 | ) 1318 | position_ids = position_ids.unsqueeze(0).view(-1, seq_length) 1319 | else: 1320 | position_ids = position_ids.view(-1, seq_length).long() 1321 | 1322 | if inputs_embeds is None: 1323 | inputs_embeds = self.embed_tokens(input_ids) 1324 | 1325 | if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: 1326 | is_padding_right = attention_mask[:, -1].sum().item() != batch_size 1327 | if is_padding_right: 1328 | raise ValueError( 1329 | "You are attempting to perform batched generation with padding_side='right'" 1330 | " this may lead to unexpected behaviour for Flash Attention version of MiniMaxM1. Make sure to " 1331 | " call `tokenizer.padding_side = 'left'` before tokenizing the input. " 1332 | ) 1333 | slope_rates = [self.slopes.to(default_device) for _ in range(len(self.layers))] 1334 | hidden_states = inputs_embeds 1335 | # decoder layers 1336 | all_hidden_states = () if output_hidden_states else None 1337 | all_self_attns = () if output_attentions else None 1338 | all_router_logits = () if output_router_logits else None 1339 | next_decoder_cache = () if use_cache else None 1340 | 1341 | for idx, decoder_layer in enumerate(self.layers): 1342 | if output_hidden_states: 1343 | all_hidden_states += (hidden_states,) 1344 | 1345 | past_key_value = (past_key_values[idx] if past_key_values is not None else None) 1346 | attn_mask = attention_mask 1347 | slope_rate = slope_rates[idx] 1348 | slope_rate = slope_rate * (1 - idx / (len(self.layers) - 1) + 1e-5) 1349 | if self.gradient_checkpointing and self.training: 1350 | layer_outputs = self._gradient_checkpointing_func( 1351 | decoder_layer.__call__, 1352 | hidden_states, 1353 | attention_mask, 1354 | position_ids, 1355 | past_key_values, 1356 | output_attentions, 1357 | output_router_logits, 1358 | use_cache, 1359 | ) 1360 | else: 1361 | layer_outputs = decoder_layer( 1362 | hidden_states, 1363 | attention_mask=attn_mask, 1364 | position_ids=position_ids, 1365 | past_key_value=past_key_value, 1366 | output_attentions=output_attentions, 1367 | output_router_logits=output_router_logits, 1368 | use_cache=use_cache, 1369 | slope_rate=slope_rate 1370 | ) 1371 | 1372 | hidden_states = layer_outputs[0] 1373 | 1374 | if use_cache: 1375 | next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) 1376 | 1377 | if output_attentions: 1378 | all_self_attns += (layer_outputs[1],) 1379 | 1380 | if output_router_logits: 1381 | all_router_logits += (layer_outputs[-1],) 1382 | 1383 | hidden_states = self.norm(hidden_states) 1384 | 1385 | # add hidden states from the last decoder layer 1386 | if output_hidden_states: 1387 | all_hidden_states += (hidden_states,) 1388 | next_cache = next_decoder_cache if use_cache else None 1389 | if not return_dict: 1390 | return tuple( 1391 | v 1392 | for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] 1393 | if v 
is not None 1394 | ) 1395 | return MoeModelOutputWithPast( 1396 | last_hidden_state=hidden_states, 1397 | past_key_values=next_cache, 1398 | hidden_states=all_hidden_states, 1399 | attentions=all_self_attns, 1400 | router_logits=all_router_logits, 1401 | ) 1402 | 1403 | 1404 | class MiniMaxM1ForCausalLM(MiniMaxM1PreTrainedModel): 1405 | _tied_weights_keys = ["lm_head.weight"] 1406 | 1407 | def __init__(self, config): 1408 | super().__init__(config) 1409 | self.model = MiniMaxM1Model(config) 1410 | self.vocab_size = config.vocab_size 1411 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 1412 | self.router_aux_loss_coef = config.router_aux_loss_coef 1413 | self.num_experts = config.num_local_experts 1414 | self.num_experts_per_tok = config.num_experts_per_tok 1415 | # Initialize weights and apply final processing 1416 | self.post_init() 1417 | 1418 | def get_input_embeddings(self): 1419 | return self.model.embed_tokens 1420 | 1421 | def set_input_embeddings(self, value): 1422 | self.model.embed_tokens = value 1423 | 1424 | def get_output_embeddings(self): 1425 | return self.lm_head 1426 | 1427 | def set_output_embeddings(self, new_embeddings): 1428 | self.lm_head = new_embeddings 1429 | 1430 | def set_decoder(self, decoder): 1431 | self.model = decoder 1432 | 1433 | def get_decoder(self): 1434 | return self.model 1435 | 1436 | @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) 1437 | @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) 1438 | # Ignore copy 1439 | def forward( 1440 | self, 1441 | input_ids: torch.LongTensor = None, 1442 | attention_mask: Optional[torch.Tensor] = None, 1443 | position_ids: Optional[torch.LongTensor] = None, 1444 | past_key_values: Optional[List[torch.FloatTensor]] = None, 1445 | inputs_embeds: Optional[torch.FloatTensor] = None, 1446 | labels: Optional[torch.LongTensor] = None, 1447 | use_cache: Optional[bool] = None, 1448 | output_attentions: Optional[bool] = None, 1449 | output_hidden_states: Optional[bool] = None, 1450 | output_router_logits: Optional[bool] = None, 1451 | return_dict: Optional[bool] = None, 1452 | ) -> Union[Tuple, MoeCausalLMOutputWithPast]: 1453 | r""" 1454 | Args: 1455 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 1456 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., 1457 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored 1458 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 1459 | 1460 | Returns: 1461 | 1462 | Example: 1463 | 1464 | ```python 1465 | >>> from transformers import AutoTokenizer, MiniMaxM1ForCausalLM 1466 | 1467 | >>> model = MiniMaxM1ForCausalLM.from_pretrained(PATH_TO_WEIGHTS) 1468 | >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_WEIGHTS) 1469 | 1470 | >>> prompt = "Hey, are you conscious? Can you talk to me?" 1471 | >>> inputs = tokenizer(prompt, return_tensors="pt") 1472 | 1473 | >>> # Generate 1474 | >>> generate_ids = model.generate(inputs.input_ids, max_length=30) 1475 | >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] 1476 | "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
1477 | ```""" 1478 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 1479 | output_router_logits = ( 1480 | output_router_logits if output_router_logits is not None else self.config.output_router_logits 1481 | ) 1482 | 1483 | output_hidden_states = ( 1484 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 1485 | ) 1486 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1487 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 1488 | outputs = self.model( 1489 | input_ids=input_ids, 1490 | attention_mask=attention_mask, 1491 | position_ids=position_ids, 1492 | past_key_values=past_key_values, 1493 | inputs_embeds=inputs_embeds, 1494 | use_cache=use_cache, 1495 | output_attentions=output_attentions, 1496 | output_hidden_states=output_hidden_states, 1497 | output_router_logits=output_router_logits, 1498 | return_dict=return_dict, 1499 | ) 1500 | 1501 | hidden_states = outputs[0] 1502 | logits = self.lm_head(hidden_states) 1503 | logits = logits.float() 1504 | 1505 | loss = None 1506 | if labels is not None: 1507 | # Shift so that tokens < n predict n 1508 | shift_logits = logits[..., :-1, :].contiguous() 1509 | shift_labels = labels[..., 1:].contiguous() 1510 | # Flatten the tokens 1511 | loss_fct = CrossEntropyLoss() 1512 | shift_logits = shift_logits.view(-1, self.config.vocab_size) 1513 | shift_labels = shift_labels.view(-1) 1514 | # Enable model parallelism 1515 | shift_labels = shift_labels.to(shift_logits.device) 1516 | loss = loss_fct(shift_logits, shift_labels) 1517 | 1518 | aux_loss = None 1519 | if output_router_logits: 1520 | aux_loss = load_balancing_loss_func( 1521 | outputs.router_logits if return_dict else outputs[-1], 1522 | self.num_experts, 1523 | self.num_experts_per_tok, 1524 | attention_mask, 1525 | ) 1526 | if labels is not None: 1527 | loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device 1528 | 1529 | if not return_dict: 1530 | output = (logits,) + outputs[1:] 1531 | if output_router_logits: 1532 | output = (aux_loss,) + output 1533 | return (loss,) + output if loss is not None else output 1534 | 1535 | torch.cuda.empty_cache() 1536 | return MoeCausalLMOutputWithPast( 1537 | loss=loss, 1538 | aux_loss=aux_loss, 1539 | logits=logits, 1540 | past_key_values=outputs.past_key_values, 1541 | hidden_states=outputs.hidden_states, 1542 | attentions=outputs.attentions, 1543 | router_logits=outputs.router_logits, 1544 | ) 1545 | 1546 | def prepare_inputs_for_generation( 1547 | self, 1548 | input_ids, 1549 | past_key_values=None, 1550 | attention_mask=None, 1551 | inputs_embeds=None, 1552 | **kwargs, 1553 | ): 1554 | if past_key_values: 1555 | input_ids = input_ids[:, -1:] 1556 | 1557 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 1558 | if inputs_embeds is not None and past_key_values is None: 1559 | model_inputs = {"inputs_embeds": inputs_embeds} 1560 | else: 1561 | model_inputs = {"input_ids": input_ids} 1562 | 1563 | model_inputs.update({ 1564 | "past_key_values": past_key_values, 1565 | "use_cache": kwargs.get("use_cache"), 1566 | "attention_mask": attention_mask, 1567 | }) 1568 | return model_inputs 1569 | 1570 | @staticmethod 1571 | def _reorder_cache(past_key_values, beam_idx): 1572 | reordered_past = () 1573 | for layer_past in past_key_values: 1574 | reordered_past += ( 1575 | tuple(past_state.index_select(0, 
beam_idx.to(past_state.device)) for past_state in layer_past), 1576 | ) 1577 | return reordered_past 1578 | 1579 | 1580 | @add_start_docstrings( 1581 | """ 1582 | The MiniMaxM1 Model transformer with a sequence classification head on top (linear layer). 1583 | 1584 | [`MiniMaxM1ForSequenceClassification`] uses the last token in order to do the classification, as other causal models 1585 | (e.g. GPT-2) do. 1586 | 1587 | Since it does classification on the last token, it requires to know the position of the last token. If a 1588 | `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If 1589 | no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the 1590 | padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in 1591 | each row of the batch). 1592 | """, 1593 | MIXTRAL_START_DOCSTRING, 1594 | ) 1595 | # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->MiniMaxM1, LLAMA->MIXTRAL 1596 | class MiniMaxM1ForSequenceClassification(MiniMaxM1PreTrainedModel): 1597 | def __init__(self, config): 1598 | super().__init__(config) 1599 | self.num_labels = config.num_labels 1600 | self.model = MiniMaxM1Model(config) 1601 | self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) 1602 | 1603 | # Initialize weights and apply final processing 1604 | self.post_init() 1605 | 1606 | def get_input_embeddings(self): 1607 | return self.model.embed_tokens 1608 | 1609 | def set_input_embeddings(self, value): 1610 | self.model.embed_tokens = value 1611 | 1612 | @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) 1613 | def forward( 1614 | self, 1615 | input_ids: torch.LongTensor = None, 1616 | attention_mask: Optional[torch.Tensor] = None, 1617 | position_ids: Optional[torch.LongTensor] = None, 1618 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, 1619 | inputs_embeds: Optional[torch.FloatTensor] = None, 1620 | labels: Optional[torch.LongTensor] = None, 1621 | use_cache: Optional[bool] = None, 1622 | output_attentions: Optional[bool] = None, 1623 | output_hidden_states: Optional[bool] = None, 1624 | return_dict: Optional[bool] = None, 1625 | ) -> Union[Tuple, SequenceClassifierOutputWithPast]: 1626 | r""" 1627 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 1628 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., 1629 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If 1630 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
1631 | """ 1632 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1633 | 1634 | transformer_outputs = self.model( 1635 | input_ids, 1636 | attention_mask=attention_mask, 1637 | position_ids=position_ids, 1638 | past_key_values=past_key_values, 1639 | inputs_embeds=inputs_embeds, 1640 | use_cache=use_cache, 1641 | output_attentions=output_attentions, 1642 | output_hidden_states=output_hidden_states, 1643 | return_dict=return_dict, 1644 | ) 1645 | hidden_states = transformer_outputs[0] 1646 | logits = self.score(hidden_states) 1647 | 1648 | if input_ids is not None: 1649 | batch_size = input_ids.shape[0] 1650 | else: 1651 | batch_size = inputs_embeds.shape[0] 1652 | 1653 | if self.config.pad_token_id is None and batch_size != 1: 1654 | raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") 1655 | if self.config.pad_token_id is None: 1656 | sequence_lengths = -1 1657 | else: 1658 | if input_ids is not None: 1659 | # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility 1660 | sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 1661 | sequence_lengths = sequence_lengths % input_ids.shape[-1] 1662 | sequence_lengths = sequence_lengths.to(logits.device) 1663 | else: 1664 | sequence_lengths = -1 1665 | 1666 | pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] 1667 | 1668 | loss = None 1669 | if labels is not None: 1670 | labels = labels.to(logits.device) 1671 | if self.config.problem_type is None: 1672 | if self.num_labels == 1: 1673 | self.config.problem_type = "regression" 1674 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): 1675 | self.config.problem_type = "single_label_classification" 1676 | else: 1677 | self.config.problem_type = "multi_label_classification" 1678 | 1679 | if self.config.problem_type == "regression": 1680 | loss_fct = MSELoss() 1681 | if self.num_labels == 1: 1682 | loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) 1683 | else: 1684 | loss = loss_fct(pooled_logits, labels) 1685 | elif self.config.problem_type == "single_label_classification": 1686 | loss_fct = CrossEntropyLoss() 1687 | loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) 1688 | elif self.config.problem_type == "multi_label_classification": 1689 | loss_fct = BCEWithLogitsLoss() 1690 | loss = loss_fct(pooled_logits, labels) 1691 | if not return_dict: 1692 | output = (pooled_logits,) + transformer_outputs[1:] 1693 | return ((loss,) + output) if loss is not None else output 1694 | 1695 | return SequenceClassifierOutputWithPast( 1696 | loss=loss, 1697 | logits=pooled_logits, 1698 | past_key_values=transformer_outputs.past_key_values, 1699 | hidden_states=transformer_outputs.hidden_states, 1700 | attentions=transformer_outputs.attentions, 1701 | ) 1702 | --------------------------------------------------------------------------------
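
The decoder defined above mixes linear-attention and softmax-attention layers according to `config.attn_type_list` (an entry of 0 switches that layer's `_attn_implementation` to `linear_attention`, an entry of 1 keeps the configured softmax implementation), and every layer receives the same ALiBi-style slope tensor scaled by a depth-dependent factor of `1 - idx / (num_layers - 1) + 1e-5`. The following is a minimal sketch that reproduces `_build_slope_tensor` and that decay outside the model so the per-layer schedule can be printed and inspected; the 16-layer, 8-head setup and the "softmax attention every eighth layer" pattern are illustrative assumptions, not values read from `config.json`.

```python
import math

import torch


def build_slope_tensor(n_attention_heads: int) -> torch.Tensor:
    """Standalone copy of MiniMaxM1Model._build_slope_tensor for inspection."""

    def get_slopes(n):
        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            return [start * start ** i for i in range(n)]

        if math.log2(n).is_integer():
            return get_slopes_power_of_2(n)
        closest = 2 ** math.floor(math.log2(n))
        return get_slopes_power_of_2(closest) + get_slopes(2 * closest)[0::2][: n - closest]

    # Shape (num_heads, 1, 1), as expected by the linear-attention layers.
    return torch.tensor(get_slopes(n_attention_heads), dtype=torch.float32).reshape(n_attention_heads, 1, 1)


# Illustrative (assumed) schedule: 0 -> linear attention, 1 -> softmax attention on every 8th layer.
num_layers, num_heads = 16, 8
attn_type_list = [1 if (i + 1) % 8 == 0 else 0 for i in range(num_layers)]
slopes = build_slope_tensor(num_heads)

for idx, attn_type in enumerate(attn_type_list):
    # Same depth-dependent decay applied per layer in MiniMaxM1Model.forward.
    slope_rate = slopes * (1 - idx / (num_layers - 1) + 1e-5)
    kind = "linear_attention" if attn_type == 0 else "softmax_attention"
    print(f"layer {idx:2d}: {kind:18s} slope_rate[0] = {slope_rate[0].item():.5f}")
```

In the actual `forward`, only layers whose `attn_type_list` entry is 1 carry a softmax KV cache, which is why `past_key_values_length` is taken from the first such layer when computing `seq_length_with_past`.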
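
As the class docstring above explains, `MiniMaxM1ForSequenceClassification` pools the hidden state of the last non-padding token in each row (falling back to the final position when no `pad_token_id` is defined or when only `inputs_embeds` are passed). The snippet below is a small standalone sketch of that pooling rule on dummy tensors; the shapes, token ids, and `pad_token_id` value are made up purely for illustration.

```python
import torch

# Dummy batch of 2 sequences, length 6, 4 "labels"; right-padded with pad_token_id = 0.
pad_token_id = 0
input_ids = torch.tensor([
    [11, 12, 13, 0, 0, 0],    # 3 real tokens, then padding
    [21, 22, 23, 24, 25, 26], # no padding
])
logits = torch.randn(2, 6, 4)  # stand-in for self.score(hidden_states)

# Same ONNX-friendly rule as in forward(): index of the first pad token minus one,
# wrapped with a modulo so a row without padding resolves to the last position.
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]

pooled_logits = logits[torch.arange(input_ids.shape[0]), sequence_lengths]
print(sequence_lengths)     # tensor([2, 5])
print(pooled_logits.shape)  # torch.Size([2, 4])
```

This mirrors the branch of `forward` that runs when both `input_ids` and `config.pad_token_id` are available; when only `inputs_embeds` are provided, the model simply takes the last position of every row.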