├── figures
│   ├── ant-bailing.png
│   └── data-accuracy-efficiency.png
├── .gitattributes
├── .editorconfig
├── examples
│   ├── deepspeed
│   │   ├── ds_z0_config.json
│   │   ├── ds_z2_config.json
│   │   ├── ds_z2_offload_config.json
│   │   ├── ds_z3_config.json
│   │   └── ds_z3_offload_config.json
│   ├── sft
│   │   ├── ling_lora_sft.yaml
│   │   └── ling_full_sft.yaml
│   └── dpo
│       ├── ling_lora_dpo.yaml
│       └── ling_full_dpo.yaml
├── LICENCE
├── inference
│   ├── mindie
│   │   ├── patch_atb_llm.sh
│   │   ├── lite
│   │   │   ├── model_chat_config.json
│   │   │   ├── model_base_config.json
│   │   │   ├── config.json
│   │   │   └── config.base.json
│   │   ├── plus
│   │   │   ├── model_chat_config.json
│   │   │   ├── model_base_config.json
│   │   │   ├── config.json
│   │   │   └── config.base.json
│   │   ├── atb_llm
│   │   │   ├── config_deepseek.py
│   │   │   ├── input_builder_deepseek.py
│   │   │   ├── utils-layers-__init__.py
│   │   │   ├── atb_llm-models-__init__.py
│   │   │   ├── router_deepseek.py
│   │   │   ├── utils-file_utils.py
│   │   │   ├── atb_llm-models-base-router.py
│   │   │   ├── modeling_deepseek.py
│   │   │   ├── utils-layers-linear-__init__.py
│   │   │   ├── flash_causal_deepseek.py
│   │   │   └── utils-weights.py
│   │   ├── convert_bin_to_safetensor.py
│   │   └── convert_bin_to_safetensor_base.py
│   └── vllm
│       └── bailing_moe.patch
├── .gitignore
├── models
│   └── configuration_bailing_moe.py
└── README.md
/figures/ant-bailing.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/Ling-Coder-Lite/master/figures/ant-bailing.png
--------------------------------------------------------------------------------
/figures/data-accuracy-efficiency.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/Ling-Coder-Lite/master/figures/data-accuracy-efficiency.png
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | [attr]rust text eol=lf whitespace=tab-in-indent,trailing-space,tabwidth=4
2 |
3 | * text=auto eol=lf
4 | *.cpp rust
5 | *.h rust
6 | *.rs rust diff=rust
7 | *.fixed linguist-language=Rust
8 | *.mir linguist-language=Rust
9 | src/etc/installer/gfx/* binary
10 | src/vendor/** -text
11 | Cargo.lock linguist-generated=false
12 |
13 | # Older git versions try to fix line endings on images and fonts, this prevents it.
14 | *.png binary
15 | *.ico binary
16 | *.woff binary
17 | *.woff2 binary
18 |
--------------------------------------------------------------------------------
/.editorconfig:
--------------------------------------------------------------------------------
1 | # EditorConfig helps developers define and maintain consistent
2 | # coding styles between different editors and IDEs
3 | # editorconfig.org
4 |
5 | root = true
6 |
7 | [*]
8 | end_of_line = lf
9 | charset = utf-8
10 | trim_trailing_whitespace = true
11 | insert_final_newline = true
12 |
13 |
14 | [*.md]
15 | # double whitespace at end of line
16 | # denotes a line break in Markdown
17 | trim_trailing_whitespace = false
18 |
19 | [*.yml]
20 | indent_size = 2
21 |
22 | [Makefile]
23 | indent_style = tab
24 |
--------------------------------------------------------------------------------
/examples/deepspeed/ds_z0_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_batch_size": "auto",
3 | "train_micro_batch_size_per_gpu": "auto",
4 | "gradient_accumulation_steps": "auto",
5 | "gradient_clipping": "auto",
6 | "zero_allow_untested_optimizer": true,
7 | "fp16": {
8 | "enabled": "auto",
9 | "loss_scale": 0,
10 | "loss_scale_window": 1000,
11 | "initial_scale_power": 16,
12 | "hysteresis": 2,
13 | "min_loss_scale": 1
14 | },
15 | "bf16": {
16 | "enabled": "auto"
17 | },
18 | "zero_optimization": {
19 | "stage": 0,
20 | "allgather_partitions": true,
21 | "allgather_bucket_size": 5e8,
22 | "overlap_comm": true,
23 | "reduce_scatter": true,
24 | "reduce_bucket_size": 5e8,
25 | "contiguous_gradients": true,
26 | "round_robin_gradients": true
27 | }
28 | }
29 |
--------------------------------------------------------------------------------
/examples/deepspeed/ds_z2_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_batch_size": "auto",
3 | "train_micro_batch_size_per_gpu": "auto",
4 | "gradient_accumulation_steps": "auto",
5 | "gradient_clipping": "auto",
6 | "zero_allow_untested_optimizer": true,
7 | "fp16": {
8 | "enabled": "auto",
9 | "loss_scale": 0,
10 | "loss_scale_window": 1000,
11 | "initial_scale_power": 16,
12 | "hysteresis": 2,
13 | "min_loss_scale": 1
14 | },
15 | "bf16": {
16 | "enabled": "auto"
17 | },
18 | "zero_optimization": {
19 | "stage": 2,
20 | "allgather_partitions": true,
21 | "allgather_bucket_size": 5e8,
22 | "overlap_comm": true,
23 | "reduce_scatter": true,
24 | "reduce_bucket_size": 5e8,
25 | "contiguous_gradients": true,
26 | "round_robin_gradients": true
27 | }
28 | }
29 |
--------------------------------------------------------------------------------
/examples/deepspeed/ds_z2_offload_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_batch_size": "auto",
3 | "train_micro_batch_size_per_gpu": "auto",
4 | "gradient_accumulation_steps": "auto",
5 | "gradient_clipping": "auto",
6 | "zero_allow_untested_optimizer": true,
7 | "fp16": {
8 | "enabled": "auto",
9 | "loss_scale": 0,
10 | "loss_scale_window": 1000,
11 | "initial_scale_power": 16,
12 | "hysteresis": 2,
13 | "min_loss_scale": 1
14 | },
15 | "bf16": {
16 | "enabled": "auto"
17 | },
18 | "zero_optimization": {
19 | "stage": 2,
20 | "offload_optimizer": {
21 | "device": "cpu",
22 | "pin_memory": true
23 | },
24 | "allgather_partitions": true,
25 | "allgather_bucket_size": 5e8,
26 | "overlap_comm": true,
27 | "reduce_scatter": true,
28 | "reduce_bucket_size": 5e8,
29 | "contiguous_gradients": true,
30 | "round_robin_gradients": true
31 | }
32 | }
33 |
--------------------------------------------------------------------------------
/examples/deepspeed/ds_z3_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_batch_size": "auto",
3 | "train_micro_batch_size_per_gpu": "auto",
4 | "gradient_accumulation_steps": "auto",
5 | "gradient_clipping": "auto",
6 | "zero_allow_untested_optimizer": true,
7 | "fp16": {
8 | "enabled": "auto",
9 | "loss_scale": 0,
10 | "loss_scale_window": 1000,
11 | "initial_scale_power": 16,
12 | "hysteresis": 2,
13 | "min_loss_scale": 1
14 | },
15 | "bf16": {
16 | "enabled": "auto"
17 | },
18 | "zero_optimization": {
19 | "stage": 3,
20 | "overlap_comm": true,
21 | "contiguous_gradients": true,
22 | "sub_group_size": 1e9,
23 | "reduce_bucket_size": "auto",
24 | "stage3_prefetch_bucket_size": "auto",
25 | "stage3_param_persistence_threshold": "auto",
26 | "stage3_max_live_parameters": 1e9,
27 | "stage3_max_reuse_distance": 1e9,
28 | "stage3_gather_16bit_weights_on_model_save": true
29 | }
30 | }
31 |
--------------------------------------------------------------------------------
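
These ZeRO configs leave most numeric fields set to "auto" so that the training launcher can fill them in from its own arguments. A minimal sketch follows, assuming Hugging Face Transformers' DeepSpeed integration (which resolves the "auto" entries from TrainingArguments); the output path is a placeholder taken from the SFT recipes further below.

# Hedged sketch: wiring one of the ZeRO configs above into a Hugging Face Trainer.
# The "auto" entries (batch sizes, gradient accumulation, fp16/bf16, bucket sizes)
# are resolved by the Transformers DeepSpeed integration from these arguments.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="saves/ling-moe-lite/full/sft",          # placeholder output path
    per_device_train_batch_size=1,
    gradient_accumulation_steps=2,
    bf16=True,                                          # fills "bf16.enabled": "auto"
    deepspeed="examples/deepspeed/ds_z3_config.json",   # pick the ZeRO stage here
)
# `args` would then be passed to Trainer(...) together with a model and dataset;
# launch with `deepspeed` or `torchrun` so that ZeRO partitioning is actually active.
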
/examples/sft/ling_lora_sft.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: inclusionAI/Ling-lite
3 | trust_remote_code: true
4 |
5 | ### method
6 | stage: sft
7 | do_train: true
8 | finetuning_type: lora
9 | lora_rank: 8
10 | lora_target: all
11 |
12 | ### dataset
13 | dataset: identity,alpaca_en_demo
14 | template: bailing
15 | cutoff_len: 2048
16 | max_samples: 1000
17 | overwrite_cache: true
18 | preprocessing_num_workers: 16
19 |
20 | ### output
21 | output_dir: saves/ling-moe-lite/lora/sft
22 | logging_steps: 10
23 | save_steps: 500
24 | plot_loss: true
25 | overwrite_output_dir: true
26 |
27 | ### train
28 | per_device_train_batch_size: 1
29 | gradient_accumulation_steps: 8
30 | learning_rate: 1.0e-4
31 | num_train_epochs: 3.0
32 | lr_scheduler_type: cosine
33 | warmup_ratio: 0.1
34 | bf16: true
35 | ddp_timeout: 180000000
36 |
37 | ### eval
38 | # val_size: 0.1
39 | # per_device_eval_batch_size: 1
40 | # eval_strategy: steps
41 | # eval_steps: 500
42 |
--------------------------------------------------------------------------------
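
The LoRA run above writes adapter weights, not a full checkpoint, to saves/ling-moe-lite/lora/sft. A hedged sketch of loading that adapter on top of the base model and merging it for deployment, assuming the output directory follows the standard PEFT adapter layout; model name and paths are taken from the YAML.

# Hedged sketch: attach and merge the LoRA adapter produced by the SFT recipe above.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained(
    "inclusionAI/Ling-lite", trust_remote_code=True, torch_dtype=torch.bfloat16
)
model = PeftModel.from_pretrained(base, "saves/ling-moe-lite/lora/sft")  # adapter dir = output_dir
merged = model.merge_and_unload()            # fold the LoRA deltas back into the base weights
merged.save_pretrained("saves/ling-moe-lite/lora/sft-merged")
AutoTokenizer.from_pretrained("inclusionAI/Ling-lite", trust_remote_code=True) \
    .save_pretrained("saves/ling-moe-lite/lora/sft-merged")
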
/examples/dpo/ling_lora_dpo.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: inclusionAI/Ling-lite
3 | trust_remote_code: true
4 |
5 | ### method
6 | stage: dpo
7 | do_train: true
8 | finetuning_type: lora
9 | lora_rank: 8
10 | lora_target: all
11 | pref_beta: 0.1
12 | pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo]
13 |
14 | ### dataset
15 | dataset: dpo_en_demo
16 | template: bailing
17 | cutoff_len: 2048
18 | max_samples: 1000
19 | overwrite_cache: true
20 | preprocessing_num_workers: 16
21 |
22 | ### output
23 | output_dir: saves/ling-moe-lite/lora/dpo
24 | logging_steps: 10
25 | save_steps: 500
26 | plot_loss: true
27 | overwrite_output_dir: true
28 |
29 | ### train
30 | per_device_train_batch_size: 1
31 | gradient_accumulation_steps: 2
32 | learning_rate: 5.0e-6
33 | num_train_epochs: 3.0
34 | lr_scheduler_type: cosine
35 | warmup_ratio: 0.1
36 | bf16: true
37 | ddp_timeout: 180000000
38 |
39 | ### eval
40 | # val_size: 0.1
41 | # per_device_eval_batch_size: 1
42 | # eval_strategy: steps
43 | # eval_steps: 500
44 |
--------------------------------------------------------------------------------
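
For reference, `pref_loss: sigmoid` is the standard DPO objective and `pref_beta` is its temperature. A small illustrative sketch of that loss on top of per-sequence log-probabilities; the function and tensor names here are made up for illustration, not the trainer's internals.

# Hedged sketch of the sigmoid (DPO) preference loss controlled by pref_beta.
import torch
import torch.nn.functional as F

def dpo_sigmoid_loss(policy_chosen_logps, policy_rejected_logps,
                     ref_chosen_logps, ref_rejected_logps, beta=0.1):
    # log-ratio of policy vs. reference model for the chosen and rejected responses
    chosen_ratio = policy_chosen_logps - ref_chosen_logps
    rejected_ratio = policy_rejected_logps - ref_rejected_logps
    # -log sigmoid(beta * margin): pushes the policy to prefer chosen over rejected
    return -F.logsigmoid(beta * (chosen_ratio - rejected_ratio)).mean()

# toy tensors standing in for summed token log-probs of each response
loss = dpo_sigmoid_loss(torch.tensor([-12.0]), torch.tensor([-15.0]),
                        torch.tensor([-12.5]), torch.tensor([-14.0]), beta=0.1)
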
/examples/sft/ling_full_sft.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: inclusionAI/Ling-lite
3 | trust_remote_code: true
4 |
5 | ### method
6 | stage: sft
7 | do_train: true
8 | finetuning_type: full
9 | deepspeed: examples/deepspeed/ds_z3_config.json # choices: [ds_z0_config.json, ds_z2_config.json, ds_z3_config.json]
10 |
11 | ### dataset
12 | dataset: identity,alpaca_en_demo
13 | template: bailing
14 | cutoff_len: 2048
15 | max_samples: 1000
16 | overwrite_cache: true
17 | preprocessing_num_workers: 16
18 |
19 | ### output
20 | output_dir: saves/ling-moe-lite/full/sft
21 | logging_steps: 10
22 | save_steps: 500
23 | plot_loss: true
24 | overwrite_output_dir: true
25 |
26 | ### train
27 | per_device_train_batch_size: 1
28 | gradient_accumulation_steps: 2
29 | learning_rate: 1.0e-5
30 | num_train_epochs: 3.0
31 | lr_scheduler_type: cosine
32 | warmup_ratio: 0.1
33 | bf16: true
34 | ddp_timeout: 180000000
35 |
36 | ### eval
37 | # val_size: 0.1
38 | # per_device_eval_batch_size: 1
39 | # eval_strategy: steps
40 | # eval_steps: 500
41 |
--------------------------------------------------------------------------------
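
These YAML files follow the LLaMA-Factory recipe format (stage / finetuning_type / template / dataset fields), so a run is presumably launched by handing the file to that CLI. A hedged sketch, assuming `llamafactory-cli` is installed in the environment and the `bailing` template is registered:

# Hedged sketch: launching the full-parameter SFT recipe above via the LLaMA-Factory CLI.
import subprocess

subprocess.run(
    ["llamafactory-cli", "train", "examples/sft/ling_full_sft.yaml"],
    check=True,  # raise if the training process exits with a non-zero status
)
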
/LICENCE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 inclusionAI
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/examples/deepspeed/ds_z3_offload_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_batch_size": "auto",
3 | "train_micro_batch_size_per_gpu": "auto",
4 | "gradient_accumulation_steps": "auto",
5 | "gradient_clipping": "auto",
6 | "zero_allow_untested_optimizer": true,
7 | "fp16": {
8 | "enabled": "auto",
9 | "loss_scale": 0,
10 | "loss_scale_window": 1000,
11 | "initial_scale_power": 16,
12 | "hysteresis": 2,
13 | "min_loss_scale": 1
14 | },
15 | "bf16": {
16 | "enabled": "auto"
17 | },
18 | "zero_optimization": {
19 | "stage": 3,
20 | "offload_optimizer": {
21 | "device": "cpu",
22 | "pin_memory": true
23 | },
24 | "offload_param": {
25 | "device": "cpu",
26 | "pin_memory": true
27 | },
28 | "overlap_comm": true,
29 | "contiguous_gradients": true,
30 | "sub_group_size": 1e9,
31 | "reduce_bucket_size": "auto",
32 | "stage3_prefetch_bucket_size": "auto",
33 | "stage3_param_persistence_threshold": "auto",
34 | "stage3_max_live_parameters": 1e9,
35 | "stage3_max_reuse_distance": 1e9,
36 | "stage3_gather_16bit_weights_on_model_save": true
37 | }
38 | }
39 |
--------------------------------------------------------------------------------
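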
/inference/mindie/patch_atb_llm.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
 3 | # Patch the installed Ascend atb-models package with the Bailing/Ling MoE support files in ./atb_llm.
 4 | set -e
 5 | 
 6 | SCRIPT_DIR=$(cd "$(dirname "$0")"; pwd)
 7 | cd "${SCRIPT_DIR}/atb_llm"
 8 | 
 9 | # Overwrite the stock DeepSeek-family files so the router recognizes model_type "bailing_moe".
10 | cp -f ./atb_llm-models-__init__.py /usr/local/Ascend/atb-models/atb_llm/models/__init__.py
11 | cp -f ./atb_llm-models-base-router.py /usr/local/Ascend/atb-models/atb_llm/models/base/router.py
12 | cp -f ./config_deepseek.py /usr/local/Ascend/atb-models/atb_llm/models/deepseek/config_deepseek.py
13 | cp -f ./flash_causal_deepseek.py /usr/local/Ascend/atb-models/atb_llm/models/deepseek/flash_causal_deepseek.py
14 | cp -f ./input_builder_deepseek.py /usr/local/Ascend/atb-models/atb_llm/models/deepseek/input_builder_deepseek.py
15 | cp -f ./modeling_deepseek.py /usr/local/Ascend/atb-models/atb_llm/models/deepseek/modeling_deepseek.py
16 | cp -f ./router_deepseek.py /usr/local/Ascend/atb-models/atb_llm/models/deepseek/router_deepseek.py
17 | cp -f ./utils-file_utils.py /usr/local/Ascend/atb-models/atb_llm/utils/file_utils.py
18 | cp -f ./utils-layers-__init__.py /usr/local/Ascend/atb-models/atb_llm/utils/layers/__init__.py
19 | cp -f ./utils-layers-linear-__init__.py /usr/local/Ascend/atb-models/atb_llm/utils/layers/linear/__init__.py
20 | cp -f ./utils-weights.py /usr/local/Ascend/atb-models/atb_llm/utils/weights.py
--------------------------------------------------------------------------------
/examples/dpo/ling_full_dpo.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: inclusionAI/Ling-lite
3 | trust_remote_code: true
4 | # flash_attn: fa2
5 |
6 | ### method
7 | stage: dpo
8 | do_train: true
9 | finetuning_type: full
10 | pref_beta: 0.1
11 | pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo]
12 | deepspeed: examples/deepspeed/ds_z2_offload_config.json  # choices: [ds_z0_config.json, ds_z2_config.json, ds_z2_offload_config.json, ds_z3_config.json, ds_z3_offload_config.json]
13 |
14 | ### dataset
15 | dataset: dpo_zh_demo
16 | template: bailing
17 | cutoff_len: 2048
18 | overwrite_cache: true
19 | preprocessing_num_workers: 16
20 |
21 | ### output
22 | output_dir: saves/ling-moe-lite/full/dpo
23 | report_to: tensorboard
24 | logging_dir: saves/ling-moe-lite/full/dpo/run
25 | logging_steps: 1
26 | save_steps: 50000
27 | plot_loss: true
28 | overwrite_output_dir: true
29 |
30 | ### train
31 | per_device_train_batch_size: 1
32 | gradient_accumulation_steps: 1
33 | learning_rate: 1.0e-5
34 | num_train_epochs: 4.0
35 | lr_scheduler_type: cosine
36 | warmup_ratio: 0.1
37 | bf16: true
38 | ddp_timeout: 180000000
39 | pure_bf16: true
40 |
41 | ### eval
42 | # val_size: 0.1
43 | # per_device_eval_batch_size: 1
44 | # eval_strategy: steps
45 | # eval_steps: 500
--------------------------------------------------------------------------------
/inference/mindie/lite/model_chat_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "BailingMoeForCausalLM"
4 | ],
5 | "attention_dropout": 0.0,
6 | "auto_map": {
7 | "AutoConfig": "configuration_bailing_moe.BailingMoeConfig",
8 | "AutoModel": "modeling_bailing_moe.BailingMoeModel",
9 | "AutoModelForCausalLM": "modeling_bailing_moe.BailingMoeForCausalLM"
10 | },
11 | "eos_token_id": 126081,
12 | "pad_token_id": 126081,
13 | "first_k_dense_replace": 0,
14 | "hidden_act": "silu",
15 | "hidden_size": 2048,
16 | "initializer_range": 0.006,
17 | "intermediate_size": 5632,
18 | "max_position_embeddings": 16384,
19 | "model_type": "bailing_moe",
20 | "moe_intermediate_size": 1408,
21 | "num_experts": 64,
22 | "num_shared_experts": 2,
23 | "norm_topk_prob": true,
24 | "num_attention_heads": 16,
25 | "num_experts_per_tok": 6,
26 | "num_hidden_layers": 28,
27 | "num_key_value_heads": 4,
28 | "pretraining_tp": 1,
29 | "rms_norm_eps": 1e-06,
30 | "rope_scaling": null,
31 | "rope_theta": 600000,
32 | "tie_word_embeddings": false,
33 | "torch_dtype": "bfloat16",
34 | "transformers_version": "4.36.0",
35 | "use_cache": true,
36 | "use_bias": false,
37 | "use_qkv_bias": false,
38 | "vocab_size": 126464,
39 | "embedding_dropout": 0.0,
40 | "norm_head": true,
41 | "norm_softmax": false,
42 | "output_dropout": 0.0
43 | }
44 |
--------------------------------------------------------------------------------
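
The MoE shape of Ling-lite can be read straight from this config: 28 layers, hidden size 2048, 64 routed experts with 6 active per token plus 2 shared experts. A hedged sketch of inspecting those fields through the Transformers config loader, assuming the published checkpoint ships the configuration_bailing_moe.py / modeling_bailing_moe.py files referenced in auto_map:

# Hedged sketch: read the MoE hyper-parameters above via AutoConfig.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("inclusionAI/Ling-lite", trust_remote_code=True)
print(cfg.num_hidden_layers, cfg.hidden_size)             # 28, 2048
print(cfg.num_experts, cfg.num_experts_per_tok)           # 64 routed experts, 6 active per token
print(cfg.num_shared_experts, cfg.moe_intermediate_size)  # 2 shared experts, 1408 expert FFN width
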
/inference/mindie/lite/model_base_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "BailingMoeForCausalLM"
4 | ],
5 | "attention_dropout": 0.0,
6 | "auto_map": {
7 | "AutoConfig": "configuration_bailing_moe.BailingMoeConfig",
8 | "AutoModel": "modeling_bailing_moe.BailingMoeModel",
9 | "AutoModelForCausalLM": "modeling_bailing_moe.BailingMoeForCausalLM"
10 | },
11 | "eos_token_id": 126081,
12 | "pad_token_id": 126081,
13 | "first_k_dense_replace": 0,
14 | "hidden_act": "silu",
15 | "hidden_size": 2048,
16 | "initializer_range": 0.006,
17 | "intermediate_size": 5632,
18 | "max_position_embeddings": 16384,
19 | "model_type": "bailing_moe",
20 | "moe_intermediate_size": 1408,
21 | "num_experts": 64,
22 | "num_shared_experts": 2,
23 | "norm_topk_prob": true,
24 | "num_attention_heads": 16,
25 | "num_experts_per_tok": 6,
26 | "num_hidden_layers": 28,
27 | "num_key_value_heads": 4,
28 | "pretraining_tp": 1,
29 | "rms_norm_eps": 1e-05,
30 | "rope_scaling": null,
31 | "rope_theta": 600000,
32 | "tie_word_embeddings": false,
33 | "torch_dtype": "bfloat16",
34 | "transformers_version": "4.36.0",
35 | "use_cache": true,
36 | "use_bias": false,
37 | "use_qkv_bias": false,
38 | "vocab_size": 126464,
39 | "output_router_logits": false,
40 | "embedding_dropout": 0.1,
41 | "norm_head": true,
42 | "norm_softmax": false,
43 | "output_dropout": 0.1
44 | }
45 |
--------------------------------------------------------------------------------
/inference/mindie/plus/model_chat_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "BailingMoeForCausalLM"
4 | ],
5 | "attention_dropout": 0.0,
6 | "auto_map": {
7 | "AutoConfig": "configuration_bailing_moe.BailingMoeConfig",
8 | "AutoModel": "modeling_bailing_moe.BailingMoeModel",
9 | "AutoModelForCausalLM": "modeling_bailing_moe.BailingMoeForCausalLM"
10 | },
11 | "eos_token_id": 126081,
12 | "pad_token_id": 126081,
13 | "first_k_dense_replace": 0,
14 | "hidden_act": "silu",
15 | "hidden_size": 5376,
16 | "initializer_range": 0.006,
17 | "intermediate_size": 12288,
18 | "max_position_embeddings": 16384,
19 | "model_type": "bailing_moe",
20 | "moe_intermediate_size": 3072,
21 | "num_experts": 64,
22 | "num_shared_experts": 1,
23 | "norm_topk_prob": true,
24 | "num_attention_heads": 42,
25 | "num_experts_per_tok": 4,
26 | "num_hidden_layers": 88,
27 | "num_key_value_heads": 6,
28 | "pretraining_tp": 1,
29 | "rms_norm_eps": 1e-06,
30 | "rope_scaling": null,
31 | "rope_theta": 600000,
32 | "tie_word_embeddings": false,
33 | "torch_dtype": "bfloat16",
34 | "transformers_version": "4.36.0",
35 | "use_cache": true,
36 | "use_bias": false,
37 | "use_qkv_bias": false,
38 | "vocab_size": 126464,
39 | "output_router_logits": false,
40 | "embedding_dropout": 0.0,
41 | "norm_head": true,
42 | "norm_softmax": false,
43 | "output_dropout": 0.0
44 | }
45 |
--------------------------------------------------------------------------------
/inference/mindie/plus/model_base_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "BailingMoeForCausalLM"
4 | ],
5 | "attention_dropout": 0.0,
6 | "auto_map": {
7 | "AutoConfig": "configuration_bailing_moe.BailingMoeConfig",
8 | "AutoModel": "modeling_bailing_moe.BailingMoeModel",
9 | "AutoModelForCausalLM": "modeling_bailing_moe.BailingMoeForCausalLM"
10 | },
11 | "eos_token_id": 126081,
12 | "pad_token_id": 126081,
13 | "first_k_dense_replace": 0,
14 | "hidden_act": "silu",
15 | "hidden_size": 5376,
16 | "initializer_range": 0.006,
17 | "intermediate_size": 12288,
18 | "max_position_embeddings": 16384,
19 | "model_type": "bailing_moe",
20 | "moe_intermediate_size": 3072,
21 | "n_routed_experts": 64,
22 | "n_shared_experts": 1,
23 | "norm_topk_prob": true,
24 | "num_attention_heads": 42,
25 | "num_experts_per_tok": 4,
26 | "num_hidden_layers": 88,
27 | "num_key_value_heads": 6,
28 | "pretraining_tp": 1,
29 | "rms_norm_eps": 1e-06,
30 | "rope_scaling": null,
31 | "rope_theta": 600000,
32 | "tie_word_embeddings": false,
33 | "torch_dtype": "bfloat16",
34 | "transformers_version": "4.36.0",
35 | "use_cache": true,
36 | "use_bias": false,
37 | "use_qkv_bias": false,
38 | "vocab_size": 126464,
39 | "output_router_logits": false,
40 | "embedding_dropout": 0.0,
41 | "norm_head": true,
42 | "norm_softmax": false,
43 | "output_dropout": 0.0
44 | }
45 |
--------------------------------------------------------------------------------
/inference/mindie/atb_llm/config_deepseek.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Optional
3 | import torch
4 | from ..base.config import BaseConfig
5 |
6 |
7 | @dataclass
8 | class DeepseekConfig(BaseConfig):
9 | model_type: str = "deepseek"
10 | attention_bias: bool = False
11 | attention_dropout: float = 0.0
12 | aux_loss_alpha: float = 0.001
13 | bos_token_id: int = 100000
14 | eos_token_id: int = 100001
15 | first_k_dense_replace: int = 1
16 | hidden_act: str = "silu"
17 | hidden_size: int = 2048
18 | initializer_range: float = 0.02
19 | intermediate_size: int = 10944
20 | max_position_embeddings: int = 4096
21 | moe_intermediate_size: int = 1408
22 | moe_layer_freq: int = 1
23 | #n_routed_experts: int = 64
24 | #n_shared_experts: int = 2
25 | num_experts: int = 64
26 | num_shared_experts: int = 2
27 | norm_topk_prob: bool = False
28 | num_attention_heads: int = 16
29 | num_experts_per_tok: int = 6
30 | num_hidden_layers: int = 28
31 | num_key_value_heads: int = 16
32 | rms_norm_eps: float = 1e-06
33 | rope_scaling: Optional[int] = None
34 | rope_theta: float = 10000.0
35 | scoring_func: str = "softmax"
36 | seq_aux: bool = True
37 | tie_word_embedding: bool = False
38 | use_cache: bool = True
39 | vocab_size: int = 102400
40 |
41 |
42 | def __init__(self, **kwargs):
43 | super().__init__(**kwargs)
44 | if 'world_size' not in kwargs:
45 | self.world_size = 8
46 | if 'tp' not in kwargs:
47 | self.tp = True
48 | self.torch_dtype = torch.bfloat16
--------------------------------------------------------------------------------
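
DeepseekConfig.__init__ above encodes the inference defaults the patch relies on: unless overridden, world_size is 8 and tensor parallelism is enabled, and torch_dtype is always forced to bfloat16. A hedged sketch of loading it after patch_atb_llm.sh has been applied; the config dict here is trimmed for illustration, and from_dict is the BaseConfig loader used by router_deepseek.py below.

# Hedged sketch: the defaults applied by DeepseekConfig.__init__ above.
# Assumes patch_atb_llm.sh has copied this file into the installed atb-models package.
from atb_llm.models.deepseek.config_deepseek import DeepseekConfig

# trimmed, illustrative config dict; real values come from the model's config.json
config_dict = {"model_type": "deepseek", "hidden_size": 2048, "num_hidden_layers": 28}
cfg = DeepseekConfig.from_dict(config_dict)

print(cfg.world_size)    # 8 unless 'world_size' is passed explicitly
print(cfg.tp)            # tensor parallelism enabled by default
print(cfg.torch_dtype)   # always forced to torch.bfloat16
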
/inference/mindie/atb_llm/input_builder_deepseek.py:
--------------------------------------------------------------------------------
1 | from atb_llm.models.base.input_builder import InputBuilder
2 | from atb_llm.utils.log import logger
3 | from atb_llm.utils.log.error_code import ErrorCode
4 |
5 |
6 | class DeepseekInputBuilder(InputBuilder):
7 | def __init__(self, tokenizer, model_version, **kwargs):
8 | self.model_version = model_version
9 | super().__init__(tokenizer, **kwargs)
10 |
11 | def apply_chat_template_default(self, conversation, **kwargs):
12 | role_field = "role"
13 | content_field = "content"
14 | bos_token = "<|begin▁of▁sentence|>"
15 | eos_token = "<|end▁of▁sentence|>"
16 | formatted = bos_token
17 | for message in conversation:
18 | if message[role_field] == "user":
19 | formatted += "User: " + message[content_field] + "\n\n"
20 | elif message[role_field] == "assistant":
21 | formatted += "Assistant: " + message[content_field] + eos_token
22 | elif message[role_field] == "system":
23 | formatted += message[content_field] + "\n\n"
24 | else:
25 | msg = "Only user/assistant/system roles are supported!"
26 | logger.error(msg, ErrorCode.ATB_MODELS_PARAM_OUT_OF_RANGE)
27 | raise ValueError(msg)
28 | if "add_generation_prompt" in kwargs and kwargs["add_generation_prompt"]:
29 | formatted += "Assistant:"
30 | return self.tokenizer.encode(formatted, add_special_tokens=False)
31 |
32 | def _apply_chat_template(self, conversation, **kwargs):
33 | if hasattr(self.tokenizer, "chat_template") and self.tokenizer.chat_template:
34 | return super()._apply_chat_template(conversation, **kwargs)
35 | return self.apply_chat_template_default(conversation, **kwargs)
--------------------------------------------------------------------------------
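
When the tokenizer ships no chat_template, apply_chat_template_default above falls back to a hand-rolled DeepSeek-style prompt. A standalone sketch that mirrors that formatting logic, re-implemented here for illustration so it runs without a tokenizer:

# Hedged sketch: what the default template above produces for a short conversation.
BOS, EOS = "<|begin▁of▁sentence|>", "<|end▁of▁sentence|>"

def format_conversation(conversation, add_generation_prompt=False):
    text = BOS
    for message in conversation:
        if message["role"] == "user":
            text += "User: " + message["content"] + "\n\n"
        elif message["role"] == "assistant":
            text += "Assistant: " + message["content"] + EOS
        elif message["role"] == "system":
            text += message["content"] + "\n\n"
        else:
            raise ValueError("Only user/assistant/system roles are supported!")
    if add_generation_prompt:
        text += "Assistant:"
    return text

print(format_conversation(
    [{"role": "system", "content": "You are a helpful assistant."},
     {"role": "user", "content": "Hello!"}],
    add_generation_prompt=True,
))
# -> <|begin▁of▁sentence|>You are a helpful assistant.\n\nUser: Hello!\n\nAssistant:
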
/inference/mindie/convert_bin_to_safetensor.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | import os
4 | from safetensors import safe_open
5 |
6 | from safetensors.torch import save_file
7 |
8 | # filename = '/mnt/nas_acr89/nanxiao/mm'
9 | # tensors = {}
10 | # for i in range(1000):
11 | # tensors[str(i)] = torch.rand(256,5120,dtype=torch.bfloat16,device='cuda:0')
12 | # save_file(tensors, f"{filename}/embs.safetensors", metadata={'format': 'pt'})
13 |
14 | # with safe_open(f"{filename}/embs.safetensors", framework="pt", device=0) as f:
15 | # for ok in f.keys():
16 | # pass
17 |
18 | src_dir = '/home/HwHiAiUser/Ascend/Ling_lite'
19 | dst_dir = '/home/HwHiAiUser/Ascend/Ling_lite_safetensor'
20 | total_size = 0
21 | sd = torch.load(f'{src_dir}/pytorch_model.bin',weights_only=True,map_location="cpu")
22 | n_shard = 4
23 | block_size = 8
24 | weight_map = {}
25 |
26 | os.makedirs(dst_dir, exist_ok=True)
27 | for i in range(n_shard):
28 | ts = str(100000+n_shard)[1:]
29 | cs = str(100000+i+1)[1:]
30 | tensors = {}
31 | filename = f'model-{cs}-of-{ts}.safetensors'
32 | for k,v in sd.items():
33 |         try:
34 |             layer_idx = int(k.split('layers.')[1].split('.')[0])
35 |             block_idx = layer_idx // block_size
36 |         except (IndexError, ValueError):
37 |             block_idx = n_shard - 1  # non-layer weights (embeddings, final norm, lm_head) go into the last shard
38 | if block_idx != i:
39 | continue
40 | print(k,v.shape,v.dtype)
41 | weight_map[k] = filename
42 | total_size += v.numel()*v.element_size()
43 | tensors[k] = v.contiguous()
44 | save_file(tensors, f"{dst_dir}/{filename}", metadata={'format': 'pt'})
45 |
46 |
47 | meta = {
48 | "metadata": {
49 | "total_size": total_size
50 | },
51 | "weight_map": dict(sorted(weight_map.items(), key=lambda x:x[1]+x[0] ))
52 | }
53 |
54 |
55 | with open(f'{dst_dir}/model.safetensors.index.json', 'w') as f:
56 | json.dump(meta, f,indent=4)
57 |
--------------------------------------------------------------------------------
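
The script above splits pytorch_model.bin into four safetensors shards, grouping every eight transformer layers per shard and routing non-layer weights to the last one, then writes model.safetensors.index.json. A hedged sketch of verifying that output (same dst_dir as the script; safe_open is the reader counterpart of save_file):

# Hedged sketch: sanity-check the sharded safetensors output written above.
import json
import os
from collections import defaultdict
from safetensors import safe_open

dst_dir = '/home/HwHiAiUser/Ascend/Ling_lite_safetensor'   # same dst_dir as the script above

with open(os.path.join(dst_dir, 'model.safetensors.index.json')) as f:
    index = json.load(f)

# group the expected tensor names by the shard the index points them to
shards = defaultdict(set)
for name, shard in index['weight_map'].items():
    shards[shard].add(name)

# every tensor named in the index should be present in its shard
for shard, names in shards.items():
    with safe_open(os.path.join(dst_dir, shard), framework='pt', device='cpu') as sf:
        missing = names - set(sf.keys())
        assert not missing, f'missing from {shard}: {sorted(missing)[:3]}'

print('tensors indexed:', len(index['weight_map']),
      'total bytes:', index['metadata']['total_size'])
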
/inference/mindie/convert_bin_to_safetensor_base.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | import os
4 | from safetensors import safe_open
5 |
6 | from safetensors.torch import save_file
7 |
8 | # filename = '/mnt/nas_acr89/nanxiao/mm'
9 | # tensors = {}
10 | # for i in range(1000):
11 | # tensors[str(i)] = torch.rand(256,5120,dtype=torch.bfloat16,device='cuda:0')
12 | # save_file(tensors, f"{filename}/embs.safetensors", metadata={'format': 'pt'})
13 |
14 | # with safe_open(f"{filename}/embs.safetensors", framework="pt", device=0) as f:
15 | # for ok in f.keys():
16 | # pass
17 |
18 | src_dir = '/home/HwHiAiUser/Ascend/Ling_lite_base'
19 | dst_dir = '/home/HwHiAiUser/Ascend/Ling_lite_base_safetensor'
20 | total_size = 0
21 | sd = torch.load(f'{src_dir}/pytorch_model.bin',weights_only=True,map_location="cpu")
22 | n_shard = 4
23 | block_size = 8
24 | weight_map = {}
25 |
26 | os.makedirs(dst_dir, exist_ok=True)
27 | for i in range(n_shard):
28 | ts = str(100000+n_shard)[1:]
29 | cs = str(100000+i+1)[1:]
30 | tensors = {}
31 | filename = f'model-{cs}-of-{ts}.safetensors'
32 | for k,v in sd.items():
33 |         try:
34 |             layer_idx = int(k.split('layers.')[1].split('.')[0])
35 |             block_idx = layer_idx // block_size
36 |         except (IndexError, ValueError):
37 |             block_idx = n_shard - 1  # non-layer weights (embeddings, final norm, lm_head) go into the last shard
38 | if block_idx != i:
39 | continue
40 | print(k,v.shape,v.dtype)
41 | weight_map[k] = filename
42 | total_size += v.numel()*v.element_size()
43 | tensors[k] = v.contiguous()
44 | save_file(tensors, f"{dst_dir}/{filename}", metadata={'format': 'pt'})
45 |
46 |
47 | meta = {
48 | "metadata": {
49 | "total_size": total_size
50 | },
51 | "weight_map": dict(sorted(weight_map.items(), key=lambda x:x[1]+x[0] ))
52 | }
53 |
54 |
55 | with open(f'{dst_dir}/model.safetensors.index.json', 'w') as f:
56 | json.dump(meta, f,indent=4)
57 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # This file should only ignore things that are generated during a `x.py` build,
2 | # generated by common IDEs, and optional files controlled by the user that
3 | # affect the build (such as config.toml).
4 | # In particular, things like `mir_dump` should not be listed here; they are only
5 | # created during manual debugging and many people like to clean up instead of
6 | # having git ignore such leftovers. You can use `.git/info/exclude` to
7 | # configure your local ignore list.
8 |
9 | ## File system
10 | .DS_Store
11 | desktop.ini
12 |
13 | ## Editor
14 | *.swp
15 | *.swo
16 | Session.vim
17 | .cproject
18 | .idea
19 | *.iml
20 | .vscode
21 | .project
22 | .vim/
23 | .helix/
24 | .zed/
25 | .favorites.json
26 | .settings/
27 | .vs/
28 | .dir-locals.el
29 |
30 | ## Tool
31 | .valgrindrc
32 | .cargo
33 | # Included because it is part of the test case
34 | !/tests/run-make/thumb-none-qemu/example/.cargo
35 |
36 | ## Configuration
37 | /config.toml
38 | /Makefile
39 | config.mk
40 | config.stamp
41 | no_llvm_build
42 |
43 | ## Build
44 | /dl/
45 | /doc/
46 | /inst/
47 | /llvm/
48 | /mingw-build/
49 | /build
50 | /build-rust-analyzer/
51 | /dist/
52 | /unicode-downloads
53 | /target
54 | /library/target
55 | /src/bootstrap/target
56 | /src/tools/x/target
57 | # Created by `x vendor`
58 | /vendor
59 | # Created by default with `src/ci/docker/run.sh`
60 | /obj/
61 | # Created by nix dev shell / .envrc
62 | src/tools/nix-dev-shell/flake.lock
63 |
64 | ## ICE reports
65 | rustc-ice-*.txt
66 |
67 | ## Temporary files
68 | *~
69 | \#*
70 | \#*\#
71 | .#*
72 |
73 | ## Tags
74 | tags
75 | tags.*
76 | TAGS
77 | TAGS.*
78 |
79 | ## Python
80 | __pycache__/
81 | *.py[cod]
82 | *$py.class
83 |
84 | ## Node
85 | node_modules
86 | package-lock.json
87 | package.json
88 | /src/doc/rustc-dev-guide/mermaid.min.js
89 |
90 | ## Rustdoc GUI tests
91 | tests/rustdoc-gui/src/**.lock
92 |
93 | ## direnv
94 | /.envrc
95 | /.direnv/
96 |
97 | ## nix
98 | /flake.nix
99 | flake.lock
100 | /default.nix
101 |
102 | # Before adding new lines, see the comment at the top.
103 |
--------------------------------------------------------------------------------
/models/configuration_bailing_moe.py:
--------------------------------------------------------------------------------
1 | """ Bailing MoE model configuration """
2 |
3 | from transformers.configuration_utils import PretrainedConfig
4 |
5 |
6 | class BailingMoeConfig(PretrainedConfig):
7 | model_type = "bailing_moe"
8 |
9 | def __init__(
10 | self,
11 | vocab_size=30592,
12 | hidden_size=1024,
13 | intermediate_size=None,
14 | num_hidden_layers=24,
15 | num_attention_heads=16,
16 | num_key_value_heads=0,
17 | hidden_act="silu",
18 | use_qkv_bias=False, # bailing only
19 | use_bias=True, # bailing only
20 | rms_norm_eps=1e-05,
21 | norm_head=False, # bailing only
22 | tie_word_embeddings=False, # PretrainedConfig key, here change default value.
23 | embedding_dropout=0.1,
24 | attention_dropout=0.1,
25 | output_dropout=0.1,
26 | initializer_range=0.02,
27 | max_position_embeddings=16384,
28 | rope_theta=10000.0,
29 | use_cache=True,
30 | use_sliding_window=False,
31 | sliding_window=4096,
32 | max_window_layers=28,
33 | rope_scaling=None,
34 | pad_token_id=126081,
35 | num_experts=16,
36 | num_shared_experts=0,
37 | num_experts_per_tok=2,
38 | norm_topk_prob=True,
39 | moe_intermediate_size=None,
40 | first_k_dense_replace=0,
41 | head_dim=None,
42 | output_router_logits=False,
43 | **kwargs,
44 | ):
45 | self.num_hidden_layers = num_hidden_layers
46 | self.vocab_size = vocab_size
47 | self.hidden_size = hidden_size
48 | self.intermediate_size = intermediate_size
49 | self.num_attention_heads = num_attention_heads
50 | self.num_key_value_heads = num_key_value_heads
51 | self.hidden_act = hidden_act
52 | self.use_qkv_bias = use_qkv_bias
53 | self.use_bias = use_bias
54 | self.norm_head = norm_head
55 | self.rms_norm_eps = rms_norm_eps
56 | self.embedding_dropout = embedding_dropout
57 | self.attention_dropout = attention_dropout
58 | self.output_dropout = output_dropout
59 | self.initializer_range = initializer_range
60 | self.max_position_embeddings = max_position_embeddings
61 | self.rope_theta = rope_theta
62 | self.use_cache = use_cache
63 | self.use_sliding_window = use_sliding_window
64 | self.sliding_window = sliding_window
65 | self.max_window_layers = max_window_layers
66 | self.head_dim = head_dim
67 | self.rope_scaling = rope_scaling
68 |
69 | # MoE configs
70 | self.num_experts = num_experts
71 | self.num_shared_experts = num_shared_experts
72 | self.num_experts_per_tok = num_experts_per_tok
73 | self.norm_topk_prob = norm_topk_prob
74 | self.moe_intermediate_size = moe_intermediate_size
75 | self.first_k_dense_replace = first_k_dense_replace
76 | self.output_router_logits = output_router_logits
77 |
78 | super().__init__(pad_token_id=pad_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs)
79 |
--------------------------------------------------------------------------------
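
BailingMoeConfig is a standard PretrainedConfig subclass, so it round-trips through save_pretrained/from_pretrained like any Transformers config. A minimal sketch instantiating it with the Ling-lite values shown in inference/mindie/lite/model_chat_config.json (illustrative only; the released checkpoints already ship their own config.json, and the import path assumes the repository root is on sys.path):

# Hedged sketch: building a Ling-lite style configuration from the values shown earlier.
from models.configuration_bailing_moe import BailingMoeConfig

config = BailingMoeConfig(
    vocab_size=126464,
    hidden_size=2048,
    intermediate_size=5632,
    num_hidden_layers=28,
    num_attention_heads=16,
    num_key_value_heads=4,
    num_experts=64,
    num_shared_experts=2,
    num_experts_per_tok=6,
    moe_intermediate_size=1408,
    norm_head=True,
    use_bias=False,
    rope_theta=600000,
    max_position_embeddings=16384,
)
config.save_pretrained("./ling-lite-config")        # writes config.json with model_type "bailing_moe"
reloaded = BailingMoeConfig.from_pretrained("./ling-lite-config")
print(reloaded.num_experts_per_tok)                 # 6
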
/inference/mindie/atb_llm/utils-layers-__init__.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import torch
4 | import torch_npu
5 | import torch.distributed
6 | from torch.nn import functional as F
7 |
8 | from atb_llm.utils.log import logger
9 | from .attention import AttentionMask, flash_attn, paged_attn, reshape_and_cache, KvCache, FA3
10 | from .embedding.position_rotary_embedding import PositionRotaryEmbedding
11 | from .embedding.tensor_embedding import TensorEmbedding, TensorParallelEmbedding
12 | from .linear import (
13 | get_linear,
14 | TensorParallelRowLinear,
15 | TensorParallelColumnLinear,
16 | TensorReplicatedLinear,
17 | TensorParallelHead,
18 | TensorHead
19 | )
20 | from .linear.reduce_quant import ReduceQuant
21 | from .norm.fast_layer_norm import RMSNorm, RMSNormBias, RMSNormWrapper, RMSNormAntiOutlierWrapper
22 |
23 |
24 | def _load_gqa(config, prefix: str, weights):
25 | hidden_size, num_attention_heads = config.hidden_size, config.num_attention_heads
26 | process_group_size = weights.process_group.size()
27 | if not hidden_size % num_attention_heads == 0:
28 | logger.error(f'{hidden_size} % {num_attention_heads} != 0')
29 | if not num_attention_heads % process_group_size == 0:
30 | logger.error(f'{num_attention_heads} % {process_group_size} != 0')
31 |
32 | weight_prefixes = [f"{prefix}.{proj}" for proj in ["q_proj", "k_proj", "v_proj"]]
33 | weight = weights.get_multi_weights_col(prefixes=weight_prefixes, quantize=config.quantize, dim=0)
34 |
35 | return TensorParallelColumnLinear(
36 | get_linear(weight, bias=None, quantize=config.quantize, prefixes=weight_prefixes,
37 | num_linear_before_pack=len(weight_prefixes), tensor_parallel_dim=0, align_size=1)
38 | )
39 |
40 |
41 | def load_column_multi(
42 | config, prefixes: List[str], weights, head_size, lm_head: bool = False, \
43 | norm: bool = False, bias: bool = False, dim: int = 0, norm_head: bool = False,
44 | ):
45 | soc_version = torch_npu._C._npu_get_soc_version()
46 | quantize = None if lm_head else config.quantize
47 | weight = weights.get_multi_weights_col(prefixes, quantize=quantize, dim=0, gqa_size=head_size, norm_head=norm_head)
48 | if bias:
49 | b = [weights.get_sharded(f"{p}.bias", dim=0, gqa_size=head_size) for p in prefixes]
50 | bias = torch.cat(b, dim=dim)
51 | else:
52 | bias = None
53 | if lm_head:
54 | weight_type = weight.dtype
55 | weight = weight.float()
56 | weight = weight if not norm else torch.nan_to_num(F.normalize(weight))
57 | if soc_version == 240:
58 | weight = weight.to(dtype=weight_type)
59 | weight = weight.npu()
60 | else:
61 | weight = weight.to(dtype=weight_type).npu()
62 | linear = get_linear(weight, bias, quantize, prefixes=prefixes,
63 | num_linear_before_pack=len(prefixes), tensor_parallel_dim=0, align_size=head_size)
64 |
65 | process_group = weights.process_group
66 | should_gather = weights.process_group.size() != 1
67 | if lm_head:
68 | return TensorParallelHead(linear, process_group=process_group, should_gather=should_gather)
69 | else:
70 | return TensorParallelColumnLinear(linear)
71 |
72 |
73 | def load_row(config, prefix: str, weights, head_size):
74 | weight = weights.get_sharded(f"{prefix}.weight", dim=1, gqa_size=head_size)
75 | linear = get_linear(weight, None, quantize=config.quantize, prefixes=[prefix],
76 | tensor_parallel_dim=1, align_size=head_size)
77 | return TensorParallelRowLinear(linear, process_group=weights.process_group)
--------------------------------------------------------------------------------
/inference/mindie/atb_llm/atb_llm-models-__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | from enum import Enum
3 | import importlib
4 | import os
5 |
6 | from atb_llm.models.base.model_utils import safe_get_config_dict
7 | from atb_llm.utils import file_utils
8 |
9 |
10 | class InferenceMode(int, Enum):
11 | REGRESSION = 0
12 | SPECULATE = 1
13 | SPLITFUSE = 2
14 | PREFIXCACHE = 3
15 |
16 |
17 | def get_model(model_name_or_path: str,
18 | is_flash_causal_lm: bool = True,
19 | load_tokenizer: bool = True,
20 | max_position_embeddings: Optional[int] = None,
21 | revision: Optional[str] = None,
22 | tokenizer_path: Optional[str] = None,
23 | trust_remote_code: bool = False,
24 | enable_atb_torch: bool = False):
25 | model_name_or_path = file_utils.standardize_path(model_name_or_path, check_link=False)
26 | file_utils.check_path_permission(model_name_or_path)
27 | model_type_key = 'model_type'
28 | config_dict = safe_get_config_dict(model_name_or_path)
29 | config_dict[model_type_key] = config_dict[model_type_key].lower()
30 | model_type = config_dict[model_type_key]
31 | if model_type == "kclgpt":
32 | model_type = "codeshell"
33 | elif model_type == "internvl_chat":
34 | model_type = "internvl"
35 | elif model_type == "llava_next_video":
36 | model_type = "llava_next"
37 | elif model_type == "llava" and "_name_or_path" in config_dict.keys():
38 | if "yi-vl" in config_dict["_name_or_path"].lower():
39 | model_type = config_dict[model_type_key] = "yivl"
40 | elif model_type == "minicpmv" and "MiniCPM-Llama3-V-2_5" in model_name_or_path:
41 | model_type = "minicpm_llama3_v2"
42 | elif "clip" in model_type:
43 | model_type = "clip"
44 | elif model_type == "bunny-qwen2" or model_type == "bunny-minicpm":
45 | model_type = "bunny"
46 | elif model_type == "chatglm" and "vision_config" in config_dict:
47 | model_type = "glm4v"
48 | elif model_type == "bailing_moe":
49 | model_type = "deepseek"
50 |
51 |
52 |     # Security check: only accept model types that map to a folder under atb_llm/models
53 | current_path = os.path.dirname(os.path.abspath(__file__))
54 | supported_models = []
55 | for foldername in file_utils.safe_listdir(current_path):
56 | is_folder = os.path.isdir(os.path.join(current_path, foldername))
57 | skip_base_folder = foldername != "base"
58 | skip_invalid_folder = not foldername.startswith("_")
59 | if is_folder and skip_base_folder and skip_invalid_folder:
60 | supported_models.append(foldername)
61 |
62 | if model_type not in supported_models:
63 | raise NotImplementedError(
64 |             f"unsupported model type: {model_type}; "
65 |             f"please confirm that a folder named '{model_type}' exists under atb_llm.models."
66 | )
67 |
68 | router_path = f"atb_llm.models.{model_type}.router_{model_type}"
69 | if model_type == "qwen2_moe" or model_type == "minicpm_llama3_v2":
70 | model_type = model_type.replace('_', '')
71 | if model_type == "qwen2_audio":
72 | model_type = model_type.replace('_', '')
73 | if model_type == "qwen2_vl":
74 | model_type = model_type.replace('_', '')
75 | router = importlib.import_module(router_path)
76 | router_cls = getattr(router, f"{model_type.capitalize()}Router")
77 | router_ins = router_cls(
78 | model_name_or_path,
79 | config_dict,
80 | is_flash_causal_lm,
81 | load_tokenizer,
82 | max_position_embeddings,
83 | revision,
84 | tokenizer_path,
85 | trust_remote_code,
86 | enable_atb_torch)
87 | return router_ins
--------------------------------------------------------------------------------
/inference/mindie/atb_llm/router_deepseek.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from ..base.router import BaseRouter
3 | from .flash_causal_deepseek import DeepseekConfig
4 | from .input_builder_deepseek import DeepseekInputBuilder
5 | from ..base.model_utils import safe_get_tokenizer_from_pretrained
6 | from ...utils.log import logger
7 | from ...utils.log.error_code import ErrorCode
8 |
9 |
10 | @dataclass
11 | class DeepseekRouter(BaseRouter):
12 | def get_config(self):
13 | config = DeepseekConfig.from_dict(self.config_dict)
14 | self.check_config_deepseek(config)
15 | return config
16 |
17 | def get_tokenizer(self):
18 | return safe_get_tokenizer_from_pretrained(
19 | self.tokenizer_path,
20 | padding_side="left",
21 | trust_remote_code=False,
22 | use_fast=False,
23 | pad_token='[PAD]'
24 | )
25 |
26 | def check_config_deepseek(self, config):
27 | super().check_config(config)
28 | attribute_ranges = {
29 | "moe_intermediate_size" : (0, 2147483647),
30 | "attention_dropout" : (0, 2147483647),
31 | "initializer_range" : (0, 2147483647),
32 | "num_attention_heads" : (0, 2147483647),
33 | "num_experts_per_tok" : (1, 128),
34 | "num_shared_experts" : (0, 128),
35 | "moe_layer_freq" : (1, 128),
36 | "first_k_dense_replace" : (0, 2147483647),
37 | "num_key_value_heads" : (1, 2147483647),
38 |             "num_experts" : (2, 128),
39 | "rope_theta" : (0, 2147483647),
40 | "router_aux_loss_coef" : (0, 2147483647),
41 | "rms_norm_eps" : (0, 2147483647),
42 | "aux_loss_alpha": (0, 2147483647),
43 | }
44 | for attr, (min_val, max_val) in attribute_ranges.items():
45 | if not hasattr(config, attr) or getattr(config, attr) is None:
46 | continue
47 | value = getattr(config, attr)
48 | if value < min_val or value > max_val:
49 | msg = f"self._config.{attr} must be between {min_val} and {max_val}"
50 | logger.error(msg, ErrorCode.ATB_MODELS_PARAM_OUT_OF_RANGE)
51 | raise ValueError(msg)
52 |
53 | if getattr(config, "num_experts_per_tok", 0) > getattr(config, "num_experts", 0):
54 | msg = "self._config.num_experts_per_tok must be smaller than or equal to self._config.num_experts"
55 | logger.error(msg, ErrorCode.ATB_MODELS_PARAM_OUT_OF_RANGE)
56 | raise ValueError(msg)
57 |
58 | hidden_act = getattr(config, "hidden_act")
59 | if hidden_act != "silu":
60 | msg = "self._config.hidden_act must be silu"
61 | logger.error(msg, ErrorCode.ATB_MODELS_PARAM_OUT_OF_RANGE)
62 | raise ValueError(msg)
63 |
64 | if not isinstance(getattr(config, "use_cache", False), bool):
65 | msg = "self._config.use_cache must be a boolean"
66 | logger.error(msg, ErrorCode.ATB_MODELS_PARAM_OUT_OF_RANGE)
67 | raise ValueError(msg)
68 |
69 | if not isinstance(getattr(config, "seq_aux", False), bool):
70 | msg = "self._config.seq_aux must be a boolean"
71 | logger.error(msg, ErrorCode.ATB_MODELS_PARAM_OUT_OF_RANGE)
72 | raise ValueError(msg)
73 |
74 | if not isinstance(getattr(config, "norm_topk_prob", False), bool):
75 | msg = "self._config.norm_topk_prob must be a boolean"
76 | logger.error(msg, ErrorCode.ATB_MODELS_PARAM_OUT_OF_RANGE)
77 | raise ValueError(msg)
78 |
79 | def get_input_builder(self):
80 | return DeepseekInputBuilder(self.tokenizer, self.model_version)
--------------------------------------------------------------------------------
/inference/mindie/plus/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "Version" : "1.1.0",
3 | "LogConfig" :
4 | {
5 | "logLevel" : "Info",
6 | "logFileSize" : 20,
7 | "logFileNum" : 20,
8 | "logPath" : "logs/mindservice.log"
9 | },
10 |
11 | "ServerConfig" :
12 | {
13 | "ipAddress" : "127.0.0.1",
14 | "managementIpAddress" : "127.0.0.2",
15 | "port" : 1025,
16 | "managementPort" : 1026,
17 | "metricsPort" : 1027,
18 | "allowAllZeroIpListening" : false,
19 | "maxLinkNum" : 1000,
20 | "httpsEnabled" : false,
21 | "fullTextEnabled" : false,
22 | "tlsCaPath" : "security/ca/",
23 | "tlsCaFile" : ["ca.pem"],
24 | "tlsCert" : "security/certs/server.pem",
25 | "tlsPk" : "security/keys/server.key.pem",
26 | "tlsPkPwd" : "security/pass/key_pwd.txt",
27 | "tlsCrlPath" : "security/certs/",
28 | "tlsCrlFiles" : ["server_crl.pem"],
29 | "managementTlsCaFile" : ["management_ca.pem"],
30 | "managementTlsCert" : "security/certs/management/server.pem",
31 | "managementTlsPk" : "security/keys/management/server.key.pem",
32 | "managementTlsPkPwd" : "security/pass/management/key_pwd.txt",
33 | "managementTlsCrlPath" : "security/management/certs/",
34 | "managementTlsCrlFiles" : ["server_crl.pem"],
35 | "kmcKsfMaster" : "tools/pmt/master/ksfa",
36 | "kmcKsfStandby" : "tools/pmt/standby/ksfb",
37 | "inferMode" : "standard",
38 | "interCommTLSEnabled" : false,
39 | "interCommPort" : 1121,
40 | "interCommTlsCaPath" : "security/grpc/ca/",
41 | "interCommTlsCaFiles" : ["ca.pem"],
42 | "interCommTlsCert" : "security/grpc/certs/server.pem",
43 | "interCommPk" : "security/grpc/keys/server.key.pem",
44 | "interCommPkPwd" : "security/grpc/pass/key_pwd.txt",
45 | "interCommTlsCrlPath" : "security/grpc/certs/",
46 | "interCommTlsCrlFiles" : ["server_crl.pem"],
47 | "openAiSupport" : "vllm"
48 | },
49 |
50 | "BackendConfig" : {
51 | "backendName" : "mindieservice_llm_engine",
52 | "modelInstanceNumber" : 1,
53 | "npuDeviceIds" : [[0,1,2,3,4,5,6,7]],
54 | "tokenizerProcessNumber" : 8,
55 | "multiNodesInferEnabled" : true,
56 | "multiNodesInferPort" : 1120,
57 | "interNodeTLSEnabled" : false,
58 | "interNodeTlsCaPath" : "security/grpc/ca/",
59 | "interNodeTlsCaFiles" : ["ca.pem"],
60 | "interNodeTlsCert" : "security/grpc/certs/server.pem",
61 | "interNodeTlsPk" : "security/grpc/keys/server.key.pem",
62 | "interNodeTlsPkPwd" : "security/grpc/pass/mindie_server_key_pwd.txt",
63 | "interNodeTlsCrlPath" : "security/grpc/certs/",
64 | "interNodeTlsCrlFiles" : ["server_crl.pem"],
65 | "interNodeKmcKsfMaster" : "tools/pmt/master/ksfa",
66 | "interNodeKmcKsfStandby" : "tools/pmt/standby/ksfb",
67 | "ModelDeployConfig" :
68 | {
69 | "maxSeqLen" : 14000,
70 | "maxInputTokenLen" : 8192,
71 | "truncation" : false,
72 | "ModelConfig" : [
73 | {
74 | "modelInstanceType" : "Standard",
75 | "modelName" : "bailing_moe",
76 | "modelWeightPath" : "/home/HwHiAiUser/Ascend/Ling_plus",
77 | "worldSize" : 16,
78 | "cpuMemSize" : 5,
79 | "npuMemSize" : -1,
80 | "backendType" : "atb",
81 | "trustRemoteCode" : false
82 | }
83 | ]
84 | },
85 |
86 | "ScheduleConfig" :
87 | {
88 | "templateType" : "Standard",
89 | "templateName" : "Standard_LLM",
90 | "cacheBlockSize" : 128,
91 |
92 | "maxPrefillBatchSize" : 50,
93 | "maxPrefillTokens" : 8192,
94 | "prefillTimeMsPerReq" : 150,
95 | "prefillPolicyType" : 0,
96 |
97 | "decodeTimeMsPerReq" : 50,
98 | "decodePolicyType" : 0,
99 |
100 | "maxBatchSize" : 200,
101 | "maxIterTimes" : 4096,
102 | "maxPreemptCount" : 0,
103 | "supportSelectBatch" : false,
104 | "maxQueueDelayMicroseconds" : 5000
105 | }
106 | }
107 | }
108 |
--------------------------------------------------------------------------------
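
With "openAiSupport" set to "vllm", the MindIE service configured above is intended to expose a vLLM-style, OpenAI-compatible HTTP API on the configured ipAddress and port (127.0.0.1:1025 here). A hedged client sketch; the exact route and payload shape should be checked against the deployed MindIE Service version.

# Hedged sketch: querying the service above, assuming a vLLM/OpenAI-compatible
# /v1/chat/completions route is exposed on the configured ipAddress and port.
import requests

resp = requests.post(
    "http://127.0.0.1:1025/v1/chat/completions",
    json={
        "model": "bailing_moe",                      # modelName from ModelConfig above
        "messages": [{"role": "user", "content": "Hello!"}],
        "max_tokens": 128,
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
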
/inference/mindie/lite/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "Version" : "1.1.0",
3 | "LogConfig" :
4 | {
5 | "logLevel" : "Info",
6 | "logFileSize" : 20,
7 | "logFileNum" : 20,
8 | "logPath" : "logs/mindservice.log"
9 | },
10 |
11 | "ServerConfig" :
12 | {
13 | "ipAddress" : "127.0.0.1",
14 | "managementIpAddress" : "127.0.0.2",
15 | "port" : 1025,
16 | "managementPort" : 1026,
17 | "metricsPort" : 1027,
18 | "allowAllZeroIpListening" : false,
19 | "maxLinkNum" : 1000,
20 | "httpsEnabled" : false,
21 | "fullTextEnabled" : false,
22 | "tlsCaPath" : "security/ca/",
23 | "tlsCaFile" : ["ca.pem"],
24 | "tlsCert" : "security/certs/server.pem",
25 | "tlsPk" : "security/keys/server.key.pem",
26 | "tlsPkPwd" : "security/pass/key_pwd.txt",
27 | "tlsCrlPath" : "security/certs/",
28 | "tlsCrlFiles" : ["server_crl.pem"],
29 | "managementTlsCaFile" : ["management_ca.pem"],
30 | "managementTlsCert" : "security/certs/management/server.pem",
31 | "managementTlsPk" : "security/keys/management/server.key.pem",
32 | "managementTlsPkPwd" : "security/pass/management/key_pwd.txt",
33 | "managementTlsCrlPath" : "security/management/certs/",
34 | "managementTlsCrlFiles" : ["server_crl.pem"],
35 | "kmcKsfMaster" : "tools/pmt/master/ksfa",
36 | "kmcKsfStandby" : "tools/pmt/standby/ksfb",
37 | "inferMode" : "standard",
38 | "interCommTLSEnabled" : true,
39 | "interCommPort" : 1121,
40 | "interCommTlsCaPath" : "security/grpc/ca/",
41 | "interCommTlsCaFiles" : ["ca.pem"],
42 | "interCommTlsCert" : "security/grpc/certs/server.pem",
43 | "interCommPk" : "security/grpc/keys/server.key.pem",
44 | "interCommPkPwd" : "security/grpc/pass/key_pwd.txt",
45 | "interCommTlsCrlPath" : "security/grpc/certs/",
46 | "interCommTlsCrlFiles" : ["server_crl.pem"],
47 | "openAiSupport" : "vllm"
48 | },
49 |
50 | "BackendConfig" : {
51 | "backendName" : "mindieservice_llm_engine",
52 | "modelInstanceNumber" : 1,
53 | "npuDeviceIds" : [[0,1,2,3,4,5,6,7]],
54 | "tokenizerProcessNumber" : 8,
55 | "multiNodesInferEnabled" : false,
56 | "multiNodesInferPort" : 1120,
57 | "interNodeTLSEnabled" : true,
58 | "interNodeTlsCaPath" : "security/grpc/ca/",
59 | "interNodeTlsCaFiles" : ["ca.pem"],
60 | "interNodeTlsCert" : "security/grpc/certs/server.pem",
61 | "interNodeTlsPk" : "security/grpc/keys/server.key.pem",
62 | "interNodeTlsPkPwd" : "security/grpc/pass/mindie_server_key_pwd.txt",
63 | "interNodeTlsCrlPath" : "security/grpc/certs/",
64 | "interNodeTlsCrlFiles" : ["server_crl.pem"],
65 | "interNodeKmcKsfMaster" : "tools/pmt/master/ksfa",
66 | "interNodeKmcKsfStandby" : "tools/pmt/standby/ksfb",
67 | "ModelDeployConfig" :
68 | {
69 | "maxSeqLen" : 14000,
70 | "maxInputTokenLen" : 8192,
71 | "truncation" : false,
72 | "ModelConfig" : [
73 | {
74 | "modelInstanceType" : "Standard",
75 | "modelName" : "bailing_moe",
76 | "modelWeightPath" : "/home/HwHiAiUser/Ascend/Ling_lite_safetensor",
77 | "worldSize" : 8,
78 | "cpuMemSize" : 5,
79 | "npuMemSize" : -1,
80 | "backendType" : "atb",
81 | "trustRemoteCode" : false
82 | }
83 | ]
84 | },
85 |
86 | "ScheduleConfig" :
87 | {
88 | "templateType" : "Standard",
89 | "templateName" : "Standard_LLM",
90 | "cacheBlockSize" : 128,
91 |
92 | "maxPrefillBatchSize" : 50,
93 | "maxPrefillTokens" : 8192,
94 | "prefillTimeMsPerReq" : 150,
95 | "prefillPolicyType" : 0,
96 |
97 | "decodeTimeMsPerReq" : 50,
98 | "decodePolicyType" : 0,
99 |
100 | "maxBatchSize" : 200,
101 | "maxIterTimes" : 4096,
102 | "maxPreemptCount" : 0,
103 | "supportSelectBatch" : false,
104 | "maxQueueDelayMicroseconds" : 5000
105 | }
106 | }
107 | }
108 |
--------------------------------------------------------------------------------
/inference/mindie/plus/config.base.json:
--------------------------------------------------------------------------------
1 | {
2 | "Version" : "1.1.0",
3 | "LogConfig" :
4 | {
5 | "logLevel" : "Info",
6 | "logFileSize" : 20,
7 | "logFileNum" : 20,
8 | "logPath" : "logs/mindservice.log"
9 | },
10 |
11 | "ServerConfig" :
12 | {
13 | "ipAddress" : "127.0.0.1",
14 | "managementIpAddress" : "127.0.0.2",
15 | "port" : 1025,
16 | "managementPort" : 1026,
17 | "metricsPort" : 1027,
18 | "allowAllZeroIpListening" : false,
19 | "maxLinkNum" : 1000,
20 | "httpsEnabled" : false,
21 | "fullTextEnabled" : false,
22 | "tlsCaPath" : "security/ca/",
23 | "tlsCaFile" : ["ca.pem"],
24 | "tlsCert" : "security/certs/server.pem",
25 | "tlsPk" : "security/keys/server.key.pem",
26 | "tlsPkPwd" : "security/pass/key_pwd.txt",
27 | "tlsCrlPath" : "security/certs/",
28 | "tlsCrlFiles" : ["server_crl.pem"],
29 | "managementTlsCaFile" : ["management_ca.pem"],
30 | "managementTlsCert" : "security/certs/management/server.pem",
31 | "managementTlsPk" : "security/keys/management/server.key.pem",
32 | "managementTlsPkPwd" : "security/pass/management/key_pwd.txt",
33 | "managementTlsCrlPath" : "security/management/certs/",
34 | "managementTlsCrlFiles" : ["server_crl.pem"],
35 | "kmcKsfMaster" : "tools/pmt/master/ksfa",
36 | "kmcKsfStandby" : "tools/pmt/standby/ksfb",
37 | "inferMode" : "standard",
38 | "interCommTLSEnabled" : false,
39 | "interCommPort" : 1121,
40 | "interCommTlsCaPath" : "security/grpc/ca/",
41 | "interCommTlsCaFiles" : ["ca.pem"],
42 | "interCommTlsCert" : "security/grpc/certs/server.pem",
43 | "interCommPk" : "security/grpc/keys/server.key.pem",
44 | "interCommPkPwd" : "security/grpc/pass/key_pwd.txt",
45 | "interCommTlsCrlPath" : "security/grpc/certs/",
46 | "interCommTlsCrlFiles" : ["server_crl.pem"],
47 | "openAiSupport" : "vllm"
48 | },
49 |
50 | "BackendConfig" : {
51 | "backendName" : "mindieservice_llm_engine",
52 | "modelInstanceNumber" : 1,
53 | "npuDeviceIds" : [[0,1,2,3,4,5,6,7]],
54 | "tokenizerProcessNumber" : 8,
55 | "multiNodesInferEnabled" : true,
56 | "multiNodesInferPort" : 1120,
57 | "interNodeTLSEnabled" : false,
58 | "interNodeTlsCaPath" : "security/grpc/ca/",
59 | "interNodeTlsCaFiles" : ["ca.pem"],
60 | "interNodeTlsCert" : "security/grpc/certs/server.pem",
61 | "interNodeTlsPk" : "security/grpc/keys/server.key.pem",
62 | "interNodeTlsPkPwd" : "security/grpc/pass/mindie_server_key_pwd.txt",
63 | "interNodeTlsCrlPath" : "security/grpc/certs/",
64 | "interNodeTlsCrlFiles" : ["server_crl.pem"],
65 | "interNodeKmcKsfMaster" : "tools/pmt/master/ksfa",
66 | "interNodeKmcKsfStandby" : "tools/pmt/standby/ksfb",
67 | "ModelDeployConfig" :
68 | {
69 | "maxSeqLen" : 14000,
70 | "maxInputTokenLen" : 8192,
71 | "truncation" : false,
72 | "ModelConfig" : [
73 | {
74 | "modelInstanceType" : "Standard",
75 | "modelName" : "bailing_moe",
76 | "modelWeightPath" : "/home/HwHiAiUser/Ascend/Ling_plus_base",
77 | "worldSize" : 16,
78 | "cpuMemSize" : 5,
79 | "npuMemSize" : -1,
80 | "backendType" : "atb",
81 | "trustRemoteCode" : false
82 | }
83 | ]
84 | },
85 |
86 | "ScheduleConfig" :
87 | {
88 | "templateType" : "Standard",
89 | "templateName" : "Standard_LLM",
90 | "cacheBlockSize" : 128,
91 |
92 | "maxPrefillBatchSize" : 50,
93 | "maxPrefillTokens" : 8192,
94 | "prefillTimeMsPerReq" : 150,
95 | "prefillPolicyType" : 0,
96 |
97 | "decodeTimeMsPerReq" : 50,
98 | "decodePolicyType" : 0,
99 |
100 | "maxBatchSize" : 200,
101 | "maxIterTimes" : 4096,
102 | "maxPreemptCount" : 0,
103 | "supportSelectBatch" : false,
104 | "maxQueueDelayMicroseconds" : 5000
105 | }
106 | }
107 | }
108 |
--------------------------------------------------------------------------------
/inference/mindie/lite/config.base.json:
--------------------------------------------------------------------------------
1 | {
2 | "Version" : "1.1.0",
3 | "LogConfig" :
4 | {
5 | "logLevel" : "Info",
6 | "logFileSize" : 20,
7 | "logFileNum" : 20,
8 | "logPath" : "logs/mindservice.log"
9 | },
10 |
11 | "ServerConfig" :
12 | {
13 | "ipAddress" : "127.0.0.1",
14 | "managementIpAddress" : "127.0.0.2",
15 | "port" : 1025,
16 | "managementPort" : 1026,
17 | "metricsPort" : 1027,
18 | "allowAllZeroIpListening" : false,
19 | "maxLinkNum" : 1000,
20 | "httpsEnabled" : false,
21 | "fullTextEnabled" : false,
22 | "tlsCaPath" : "security/ca/",
23 | "tlsCaFile" : ["ca.pem"],
24 | "tlsCert" : "security/certs/server.pem",
25 | "tlsPk" : "security/keys/server.key.pem",
26 | "tlsPkPwd" : "security/pass/key_pwd.txt",
27 | "tlsCrlPath" : "security/certs/",
28 | "tlsCrlFiles" : ["server_crl.pem"],
29 | "managementTlsCaFile" : ["management_ca.pem"],
30 | "managementTlsCert" : "security/certs/management/server.pem",
31 | "managementTlsPk" : "security/keys/management/server.key.pem",
32 | "managementTlsPkPwd" : "security/pass/management/key_pwd.txt",
33 | "managementTlsCrlPath" : "security/management/certs/",
34 | "managementTlsCrlFiles" : ["server_crl.pem"],
35 | "kmcKsfMaster" : "tools/pmt/master/ksfa",
36 | "kmcKsfStandby" : "tools/pmt/standby/ksfb",
37 | "inferMode" : "standard",
38 | "interCommTLSEnabled" : true,
39 | "interCommPort" : 1121,
40 | "interCommTlsCaPath" : "security/grpc/ca/",
41 | "interCommTlsCaFiles" : ["ca.pem"],
42 | "interCommTlsCert" : "security/grpc/certs/server.pem",
43 | "interCommPk" : "security/grpc/keys/server.key.pem",
44 | "interCommPkPwd" : "security/grpc/pass/key_pwd.txt",
45 | "interCommTlsCrlPath" : "security/grpc/certs/",
46 | "interCommTlsCrlFiles" : ["server_crl.pem"],
47 | "openAiSupport" : "vllm"
48 | },
49 |
50 | "BackendConfig" : {
51 | "backendName" : "mindieservice_llm_engine",
52 | "modelInstanceNumber" : 1,
53 | "npuDeviceIds" : [[0,1,2,3,4,5,6,7]],
54 | "tokenizerProcessNumber" : 8,
55 | "multiNodesInferEnabled" : false,
56 | "multiNodesInferPort" : 1120,
57 | "interNodeTLSEnabled" : true,
58 | "interNodeTlsCaPath" : "security/grpc/ca/",
59 | "interNodeTlsCaFiles" : ["ca.pem"],
60 | "interNodeTlsCert" : "security/grpc/certs/server.pem",
61 | "interNodeTlsPk" : "security/grpc/keys/server.key.pem",
62 | "interNodeTlsPkPwd" : "security/grpc/pass/mindie_server_key_pwd.txt",
63 | "interNodeTlsCrlPath" : "security/grpc/certs/",
64 | "interNodeTlsCrlFiles" : ["server_crl.pem"],
65 | "interNodeKmcKsfMaster" : "tools/pmt/master/ksfa",
66 | "interNodeKmcKsfStandby" : "tools/pmt/standby/ksfb",
67 | "ModelDeployConfig" :
68 | {
69 | "maxSeqLen" : 14000,
70 | "maxInputTokenLen" : 8192,
71 | "truncation" : false,
72 | "ModelConfig" : [
73 | {
74 | "modelInstanceType" : "Standard",
75 | "modelName" : "bailing_moe",
76 | "modelWeightPath" : "/home/HwHiAiUser/Ascend/Ling_lite_base_safetensor",
77 | "worldSize" : 8,
78 | "cpuMemSize" : 5,
79 | "npuMemSize" : -1,
80 | "backendType" : "atb",
81 | "trustRemoteCode" : false
82 | }
83 | ]
84 | },
85 |
86 | "ScheduleConfig" :
87 | {
88 | "templateType" : "Standard",
89 | "templateName" : "Standard_LLM",
90 | "cacheBlockSize" : 128,
91 |
92 | "maxPrefillBatchSize" : 50,
93 | "maxPrefillTokens" : 8192,
94 | "prefillTimeMsPerReq" : 150,
95 | "prefillPolicyType" : 0,
96 |
97 | "decodeTimeMsPerReq" : 50,
98 | "decodePolicyType" : 0,
99 |
100 | "maxBatchSize" : 200,
101 | "maxIterTimes" : 4096,
102 | "maxPreemptCount" : 0,
103 | "supportSelectBatch" : false,
104 | "maxQueueDelayMicroseconds" : 5000
105 | }
106 | }
107 | }
108 |
--------------------------------------------------------------------------------
/inference/mindie/atb_llm/utils-file_utils.py:
--------------------------------------------------------------------------------
1 | from functools import reduce
2 | import os
3 | import stat
4 |
5 | MAX_PATH_LENGTH = 4096
6 | MAX_FILE_SIZE = 100 * 1024 * 1024
7 | MAX_FILENUM_PER_DIR = 1024
8 | MAX_LINENUM_PER_FILE = 10 * 1024 * 1024
9 |
10 | FLAG_OS_MAP = {
11 | 'r': os.O_RDONLY, 'r+': os.O_RDWR,
12 | 'w': os.O_CREAT | os.O_TRUNC | os.O_WRONLY,
13 | 'w+': os.O_CREAT | os.O_TRUNC | os.O_RDWR,
14 | 'a': os.O_CREAT | os.O_APPEND | os.O_WRONLY,
15 | 'a+': os.O_CREAT | os.O_APPEND | os.O_RDWR,
16 | 'x': os.O_CREAT | os.O_EXCL,
17 | "b": getattr(os, "O_BINARY", 0)
18 | }
19 |
20 |
21 | def safe_open(file_path: str, mode='r', encoding=None, permission_mode=0o600, is_exist_ok=True, **kwargs):
22 | """
23 | :param file_path: path of the file to open
24 | :param mode: file open mode
25 | :param encoding: file encoding
26 | :param permission_mode: permission mode used when creating the file
27 | :param is_exist_ok: whether the file is allowed to already exist
28 | :param max_path_length: maximum allowed length of the file path
29 | :param max_file_size: maximum allowed file size in bytes, default 100 MB
30 | :param check_link: whether to reject symbolic links
31 | :param kwargs:
32 | :return:
33 | """
34 | max_path_length = kwargs.get('max_path_length', MAX_PATH_LENGTH)
35 | max_file_size = kwargs.get('max_file_size', MAX_FILE_SIZE)
36 | check_link = kwargs.get('check_link', True)
37 |
38 | file_path = standardize_path(file_path, max_path_length, check_link)
39 | check_file_safety(file_path, mode, is_exist_ok, max_file_size)
40 |
41 | flags = []
42 | for item in list(mode):
43 | if item == "+" and flags:
44 | flags[-1] = f"{flags[-1]}+"
45 | continue
46 | flags.append(item)
47 | flags = [FLAG_OS_MAP.get(mode, os.O_RDONLY) for mode in flags]
48 | total_flag = reduce(lambda a, b: a | b, flags)
49 |
50 | return os.fdopen(os.open(file_path, total_flag, permission_mode),
51 | mode, encoding=encoding)
52 |
53 |
54 | def standardize_path(path: str, max_path_length=MAX_PATH_LENGTH, check_link=True):
55 | """
56 | check path
57 | param: path
58 | return: data real path after check
59 | """
60 | check_path_is_none(path)
61 | check_path_length_lt(path, max_path_length)
62 | if check_link:
63 | check_path_is_link(path)
64 | path = os.path.realpath(path)
65 | return path
66 |
67 |
68 | def is_path_exists(path: str):
69 | return os.path.exists(path)
70 |
71 |
72 | def check_path_is_none(path: str):
73 | if path is None:
74 | raise TypeError("The file path should not be None.")
75 |
76 |
77 | def check_path_is_link(path: str):
78 | if os.path.islink(os.path.normpath(path)):
79 | raise ValueError("The path should not be a symbolic link file. "
80 | f"Please check the input path:{path}.")
81 |
82 |
83 | def check_path_length_lt(path: str, max_path_length=MAX_PATH_LENGTH):
84 | path_length = path.__len__()
85 | if path_length > max_path_length:
86 | raise ValueError(f"The length of path should not be greater than {max_path_length}, but got {path_length}. "
87 | f"Please check the input path within the valid length range:{path[:max_path_length]}.")
88 |
89 |
90 | def check_file_size_lt(path: str, max_file_size=MAX_FILE_SIZE):
91 | file_size = os.path.getsize(path)
92 | if file_size > max_file_size:
93 | raise ValueError(f"The size of file should not be greater than {max_file_size}, but got {file_size}. "
94 | f"Please check the input path:{path}.")
95 |
96 |
97 | def check_owner(path: str):
98 | """
99 | check the path owner
100 | param: the input path
101 | """
102 | path_stat = os.stat(path)
103 | path_owner, path_gid = path_stat.st_uid, path_stat.st_gid
104 | cur_uid = os.geteuid()
105 | cur_gid = os.getgid()
106 | if not (cur_uid == 0 or cur_uid == path_owner or path_gid == cur_gid):
107 | raise PermissionError(f"The current user does not have permission to access the path:{path}. "
108 | "Because he is not root or the path owner, "
109 | "and not in the same user group with the path owner. "
110 | "Please check and make sure to satisfy at least one of the conditions above.")
111 |
112 |
113 | def check_other_write_permission(file_path: str):
114 | """
115 | check if the specified file is writable by others who are neither the owner nor in the group
116 | param: the path to the file to be checked
117 | """
118 | # Get the status of the file
119 | file_stat = os.stat(file_path)
120 | # Get the mode (permission) of the file
121 | mode = file_stat.st_mode
122 | # check the write permission for others
123 | if mode & stat.S_IWOTH:
124 | raise PermissionError("The file should not be writable by others who are neither the owner nor in the group. "
125 | f"Please check the input path:{file_path}, and change mode to {mode & ~stat.S_IWOTH}.")
126 |
127 |
128 | def check_path_permission(file_path: str, is_internal_file=False):
129 | check_inputfiles_permission = os.getenv("MINDIE_CHECK_INPUTFILES_PERMISSION", "1") != "0"
130 | check_permission_flag = is_internal_file or check_inputfiles_permission
131 | if check_permission_flag:
132 | check_owner(file_path)
133 | #check_other_write_permission(file_path)
134 |
135 |
136 | def check_file_safety(file_path: str, mode='r', is_exist_ok=True,
137 | max_file_size=MAX_FILE_SIZE, is_check_file_size=True):
138 | if is_path_exists(file_path):
139 | if not is_exist_ok:
140 | raise FileExistsError("The file is expected not to exist, but it already does. "
141 | f"Please check the input path:{file_path}.")
142 | if is_check_file_size:
143 | check_file_size_lt(file_path, max_file_size)
144 | file_dir = file_path
145 | else:
146 | if mode == 'r' or mode == 'r+':
147 | raise FileNotFoundError("The file is expected to exist, but it does not. "
148 | f"Please check the input path:{file_path}.")
149 | file_dir = os.path.dirname(file_path)
150 |
151 | check_path_permission(file_dir)
152 |
153 |
154 | def safe_listdir(file_path: str, max_file_num=MAX_FILENUM_PER_DIR):
155 | filenames = os.listdir(file_path)
156 | file_num = len(filenames)
157 | if file_num > max_file_num:
158 | raise ValueError(f"The file num in dir is {file_num}, which exceeds the limit {max_file_num}. "
159 | f"Please check the input path:{file_path}.")
160 | return filenames
161 |
162 |
163 | def safe_chmod(file_path: str, permission_mode):
164 | standard_path = standardize_path(file_path)
165 | check_path_permission(standard_path)
166 | os.chmod(file_path, permission_mode)
167 |
168 |
169 | def has_owner_write_permission(file_path: str):
170 | st = os.stat(file_path)
171 | return st.st_mode & stat.S_IWUSR
172 |
173 |
174 | def safe_readlines(file_obj, max_line_num=MAX_LINENUM_PER_FILE):
175 | lines = file_obj.readlines()
176 | line_num = len(lines)
177 | if line_num > max_line_num:
178 | raise ValueError(f"The file line num is {line_num}, which exceeds the limit {max_line_num}. "
179 | f"Please check the input file:{file_obj.name}.")
180 | return lines
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Ling-Coder-Lite
2 |
3 |
4 |
5 |
6 |
7 |
8 | 🤗 Hugging Face   |   🤖 ModelScope   |   🖥️ GitHub
9 |
10 | ## Introduction
11 |
12 | Ling-Coder-Lite is a MoE LLM provided and open-sourced by InclusionAI, which has 16.8B parameters with 2.75B activated parameters. This model demonstrates state-of-the-art performance on 12 coding benchmarks, while simultaneously offering competitive latency and throughput compared to code LLMs of similar size. In addition to open-sourcing the model itself, we also release a substantial amount of code-related data, including synthetic QA, SFT and DPO datasets.
13 |
14 |
15 |
16 |
17 |
18 | ## Model Downloads
19 |
20 | The following table lists the available models and their parameters so you can choose the one that best fits your use case. If you are located in mainland China, we also provide the model on ModelScope.cn to speed up the download process.
21 |
22 |
23 |
24 | | **Model** | **#Total Params** | **#Activated Params** | **Context Length** | **Download** |
25 | | :------------: | :---------------: | :-------------------: | :----------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------: |
26 | | Ling-Coder-lite-base | 16.8B | 2.75B | 16K | [🤗 HuggingFace](https://huggingface.co/inclusionAI/Ling-Coder-lite-base) <br>[🤖 ModelScope](https://modelscope.cn/models/inclusionAI/Ling-Coder-lite-base) |
27 | | Ling-Coder-lite | 16.8B | 2.75B | 16K | [🤗 HuggingFace](https://huggingface.co/inclusionAI/Ling-Coder-lite) <br>[🤖 ModelScope](https://modelscope.cn/models/inclusionAI/Ling-Coder-lite) |
28 | | Ling-Coder-lite-GPTQ-Int8 | 16.8B | 2.75B | 16K | [🤗 HuggingFace](https://huggingface.co/inclusionAI/Ling-Coder-lite-GPTQ-Int8) <br>[🤖 ModelScope](https://modelscope.cn/models/inclusionAI/Ling-Coder-lite-GPTQ-Int8) |
29 |
30 |
31 |
32 | ## Dataset Downloads
33 |
34 |
35 |
36 | | **Dataset** | **Samples** | **Download** |
37 | | :------------: | :----------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------: |
38 | | Ling-Coder-SyntheticQA | 24M | [🤗 HuggingFace](https://huggingface.co/datasets/inclusionAI/Ling-Coder-SyntheticQA) <br>[🤖 ModelScope](https://modelscope.cn/datasets/inclusionAI/Ling-Coder-SyntheticQA) |
39 | | Ling-Coder-SFT | 5M | [🤗 HuggingFace](https://huggingface.co/datasets/inclusionAI/Ling-Coder-SFT) <br>[🤖 ModelScope](https://modelscope.cn/datasets/inclusionAI/Ling-Coder-SFT) |
40 | | Ling-Coder-DPO | 250K | [🤗 HuggingFace](https://huggingface.co/datasets/inclusionAI/Ling-Coder-DPO) <br>[🤖 ModelScope](https://modelscope.cn/datasets/inclusionAI/Ling-Coder-DPO) |
41 |
42 |
43 |
44 | ## Evaluation
45 |
46 | Detailed evaluation results are reported in our [technical report](https://arxiv.org/abs/2503.17793). For detailed evaluation code, please refer to the evaluation method of Ling-Coder-Lite in [CodeFuse-Evaluation](https://github.com/codefuse-ai/codefuse-evaluation).
47 |
48 | ## Quickstart
49 |
50 | ### 🤗 Hugging Face Transformers
51 |
52 | Here is a code snippet to show you how to use the chat model with `transformers`:
53 |
54 | ```python
55 | from transformers import AutoModelForCausalLM, AutoTokenizer
56 |
57 | model_name = "inclusionAI/Ling-Coder-lite"
58 |
59 | model = AutoModelForCausalLM.from_pretrained(
60 | model_name,
61 | torch_dtype="auto",
62 | device_map="auto",
63 | trust_remote_code=True
64 | )
65 | tokenizer = AutoTokenizer.from_pretrained(
66 | model_name,
67 | trust_remote_code=True
68 | )
69 |
70 | prompt = "Write a quick sort algorithm in python."
71 | messages = [
72 | {"role": "user", "content": prompt}
73 | ]
74 | text = tokenizer.apply_chat_template(
75 | messages,
76 | tokenize=False,
77 | add_generation_prompt=True
78 | )
79 | model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
80 |
81 | generated_ids = model.generate(
82 | **model_inputs,
83 | max_new_tokens=512
84 | )
85 | generated_ids = [
86 | output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
87 | ]
88 |
89 | response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
90 | print(response)
91 | ```
92 |
93 | ### 🤖 ModelScope
94 |
95 | If you're in mainland China, we strongly recommend using our model from 🤖 ModelScope.
96 |
97 | ## Deployment
98 |
99 | ### vLLM
100 |
101 | vLLM supports offline batched inference or launching an OpenAI-Compatible API Service for online inference.
102 |
103 | #### Environment Preparation
104 |
105 | Since the Pull Request (PR) has not yet been submitted to the vLLM community, please prepare the environment by following the steps below:
106 |
107 | ```bash
108 | git clone -b v0.7.3 https://github.com/vllm-project/vllm.git
109 | cd vllm
110 | git apply Ling-Coder-Lite/inference/vllm/bailing_moe.patch
111 | pip install -e .
112 | ```
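
Before moving on, you may optionally verify that the patched build registered the new architecture. The snippet below is a minimal sanity check, under two assumptions: that your vLLM version exposes `vllm.ModelRegistry.get_supported_archs()`, and that the patch registers a `BailingMoe`-style architecture name.

```python
from vllm import ModelRegistry

# List all architectures known to this vLLM build and keep the Bailing-related ones.
# An empty list suggests the patch was not applied (or uses a different architecture name).
bailing_archs = [arch for arch in ModelRegistry.get_supported_archs() if "Bailing" in arch]
print(bailing_archs or "bailing_moe architecture not found - re-check the patch step")
```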
113 |
114 | #### Offline Inference:
115 |
116 | ```python
117 | from transformers import AutoTokenizer
118 | from vllm import LLM, SamplingParams
119 |
120 | tokenizer = AutoTokenizer.from_pretrained("inclusionAI/Ling-Coder-lite")
121 |
122 | sampling_params = SamplingParams(temperature=0.7, top_p=0.8, repetition_penalty=1.05, max_tokens=512)
123 |
124 | llm = LLM(model="inclusionAI/Ling-Coder-lite", dtype='bfloat16')
125 | prompt = "Give me a short introduction to large language models."
126 | messages = [
127 | {"role": "system", "content": "You are Ling-Coder-Lite, an assistant created by CodeFuse-AI"},
128 | {"role": "user", "content": prompt}
129 | ]
130 |
131 | text = tokenizer.apply_chat_template(
132 | messages,
133 | tokenize=False,
134 | add_generation_prompt=True
135 | )
136 | outputs = llm.generate([text], sampling_params)
137 | for output in outputs:
138 |     print(output.outputs[0].text)
139 | ```
140 |
141 | #### Online Inference:
142 |
143 | ```bash
144 | vllm serve inclusionAI/Ling-Coder-lite \
145 | --tensor-parallel-size 2 \
146 |     --pipeline-parallel-size 1 \
147 | --use-v2-block-manager \
148 | --gpu-memory-utilization 0.90
149 | ```
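
Once the server is up, any OpenAI-compatible client can query it. The snippet below is a minimal sketch that assumes the server is reachable at vLLM's default address `http://localhost:8000/v1` and was started with the model name shown above; adjust both if your deployment differs.

```python
from openai import OpenAI

# vLLM exposes an OpenAI-compatible endpoint; the API key is a placeholder
# unless the server was started with --api-key.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="inclusionAI/Ling-Coder-lite",  # must match the model name passed to `vllm serve`
    messages=[{"role": "user", "content": "Write a quick sort algorithm in python."}],
    max_tokens=512,
)
print(response.choices[0].message.content)
```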
150 |
151 | For detailed guidance, please refer to the vLLM [`instructions`](https://docs.vllm.ai/en/latest/).
152 |
153 | ### vLLM GPTQ Int8
154 |
155 | #### Environment Preparation
156 |
157 | Requirement: `vllm==0.6.3.post1`.
158 |
159 | Patch `ling_gptq.patch` onto vLLM by executing:
160 | ```bash
161 | patch -p1 < ling_gptq.patch -d $(python -c "from importlib.util import find_spec; print(find_spec('vllm').submodule_search_locations[0])")
162 | ```
163 |
164 | #### Inference Example
165 |
166 | ```python
167 | from vllm import LLM
168 | from vllm.sampling_params import SamplingParams
169 | from transformers import AutoTokenizer
170 |
171 | model_name = "inclusionAI/Ling-Coder-lite-GPTQ-Int8"
172 |
173 | model = LLM(model_name, trust_remote_code=True, gpu_memory_utilization=0.80, max_model_len=4096)
174 |
175 | tokenizer = AutoTokenizer.from_pretrained(
176 | model_name,
177 | trust_remote_code=True
178 | )
179 |
180 | prompt = "Write a quick sort algorithm in python."
181 | messages = [
182 | {"role": "user", "content": prompt}
183 | ]
184 | text = tokenizer.apply_chat_template(
185 | messages,
186 | tokenize=False,
187 | add_generation_prompt=True
188 | )
189 |
190 | sample_params = SamplingParams(max_tokens=1024, ignore_eos=False)
191 | outputs = model.generate(text, sampling_params=sample_params, prompt_token_ids=None)
192 |
193 | for output in outputs:
194 | generated_text = output.outputs[0].text
195 | print(generated_text)
196 | ```
197 |
198 | Note: this GPTQ Int8 quantized model requires no extra parameters for vLLM online serving.
199 |
200 | ## Finetuning
201 |
202 | We recommend using [Llama-Factory](https://github.com/hiyouga/LLaMA-Factory) to finetune Ling with SFT, DPO, etc.
203 |
204 | We use the [`identity`](https://github.com/hiyouga/LLaMA-Factory/blob/main/data/identity.json) dataset to demonstrate how to finetune our Ling models, replacing `name` with `Ling` and `author` with `inclusionAI`.
205 |
206 | ```json
207 | {
208 | "instruction": "hi",
209 | "input": "",
210 | "output": "Hello! I am Ling-Coder-Lite, an AI assistant developed by CodeFuse-AI. How can I assist you today?"
211 | }
212 | ```
213 |
214 | We provide a demo `Llama-Factory` configuration for full-parameter SFT of Ling models, which can be launched as follows:
215 |
216 | ```bash
217 | llamafactory-cli train examples/sft/ling_full_sft.yaml
218 | ```
219 |
220 | ## License
221 |
222 | This code repository is licensed under [the MIT License](https://github.com/codefuse-ai/Ling-Coder-Lite/blob/master/LICENCE).
223 |
224 | ## Citation
225 | If you find our work useful or helpful, please feel free to cite our paper as below.
226 |
227 | ```
228 | @misc{codefuse2025samplemattersleveragingmixtureofexperts,
229 | title={Every Sample Matters: Leveraging Mixture-of-Experts and High-Quality Data for Efficient and Accurate Code LLM},
230 | author={Codefuse and Ling Team},
231 | year={2025},
232 | eprint={2503.17793},
233 | archivePrefix={arXiv},
234 | primaryClass={cs.LG},
235 | url={https://arxiv.org/abs/2503.17793},
236 | }
237 | ```
238 |
--------------------------------------------------------------------------------
/inference/mindie/atb_llm/atb_llm-models-base-router.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 | from dataclasses import dataclass
4 | from typing import Any, Dict, List, Optional, Union
5 |
6 | import numpy as np
7 | from transformers.configuration_utils import PretrainedConfig
8 | from transformers.generation.utils import GenerationConfig
9 |
10 | from .input_builder import InputBuilder
11 | from .model_utils import safe_get_tokenizer_from_pretrained, safe_get_config_dict
12 | from .postprocessor import Postprocessor
13 | from ...utils.env import ENV
14 | from ...utils.log import logger
15 |
16 |
17 | def remove_part_of_generation_config(generation_config):
18 | """Using the transformers' GenerationConfig class, update the generation configuration with the default value."""
19 | ori_gen = GenerationConfig()
20 | for key in generation_config:
21 | if key.endswith("_id"):
22 | continue
23 | ori_value = getattr(ori_gen, key, None)
24 | if ori_value is not None:
25 | generation_config[key] = ori_value
26 | return generation_config
27 |
28 |
29 | @dataclass
30 | class BaseRouter:
31 | """The base class of router.
32 |
33 | This class should be inherited by the corresponding router subclasses of specified models. A specified model can use
34 | a subclass router to find its custom properties.
35 | """
36 | model_name_or_path: str = ""
37 |
38 | config_dict: Any = None
39 | is_flash_causal_lm: bool = True
40 | load_tokenizer: bool = True
41 | max_position_embeddings: Optional[int] = None
42 | revision: Optional[str] = None
43 | tokenizer_path: Optional[str] = None
44 | trust_remote_code: bool = False
45 | enable_atb_torch: bool = False
46 |
47 | # By default, the auto config is read during initialization; individual models may customize it. self.config returns the config used afterwards. Be careful to avoid circular dependencies.
48 | _config: Any = None
49 | _generation_config: Any = None
50 | _input_builder: Any = None
51 | _model_cls: Any = None
52 | _postprocessor: Any = None
53 | _tokenizer: Any = None
54 | is_inited: bool = False
55 | _tool_call_processor: Any = None
56 |
57 | def __post_init__(self):
58 | self.model_type = self.config_dict['model_type']
59 | if self.model_type == "chatglm" and "vision_config" in self.config_dict:
60 | self.model_type = "glm4v"
61 | if self.model_type == "internvl_chat":
62 | self.model_type = "internvl"
63 | if self.model_type == "llava_next_video":
64 | self.model_type = "llava_next"
65 | if self.model_type == "minicpmv" and "MiniCPM-Llama3-V-2_5" in self.model_name_or_path:
66 | self.model_type = "minicpm_llama3_v2"
67 | if self.model_type == "bunny-qwen2" or self.model_type == "bunny-minicpm":
68 | self.model_type = "bunny"
69 | if self.model_type == "bailing_moe":
70 | self.model_type = "deepseek"
71 | self.model_type_cap = self.model_type.capitalize()
72 | if self.model_type_cap == "Qwen2_moe" or self.model_type_cap == "Minicpm_llama3_v2":
73 | self.model_type_cap = self.model_type_cap.replace('_', '')
74 | if self.model_type_cap == "Qwen2_audio":
75 | self.model_type_cap = self.model_type_cap.replace('_', '')
76 | if self.model_type_cap == "Qwen2_vl":
77 | self.model_type_cap = self.model_type_cap.replace('_', '')
78 | if not self.tokenizer_path:
79 | self.tokenizer_path = self.model_name_or_path
80 |
81 | @property
82 | def config(self):
83 | """The config property, which should not be overridden.
84 |
85 | It uses generation config to update config dictionary at first, and then uses the `get_config` method to get a
86 | config object. Note that the `get_config` method should use `config_dict` to initialize the config object.
87 | """
88 | if self._config is None:
89 | self._generation_config = self.generation_config
90 | if ENV.remove_generation_config_dict:
91 | self._generation_config = remove_part_of_generation_config(self._generation_config)
92 | self.config_dict.update(self._generation_config)
93 | self._config = self.get_config()
94 | if not hasattr(self._config, 'quantize'):
95 | setattr(self._config, 'quantize', None)
96 | if self.max_position_embeddings is not None:
97 | setattr(self._config, 'max_position_embeddings', self.max_position_embeddings)
98 | return self._config
99 |
100 | @property
101 | def generation_config(self):
102 | """The generation config property, which should not be overridden."""
103 | if self._generation_config is None:
104 | self._generation_config = self.get_generation_config()
105 | return self._generation_config
106 |
107 | @property
108 | def input_builder(self):
109 | """The input builder property, which should not be overridden."""
110 | if self._input_builder is None:
111 | self._input_builder = self.get_input_builder()
112 | return self._input_builder
113 |
114 | @property
115 | def model_cls(self):
116 | """The model class property, which should not be overridden."""
117 | if self._model_cls is None:
118 | self._model_cls = self.get_model_cls()
119 | return self._model_cls
120 |
121 | @property
122 | def model_version(self):
123 | """The model version property, which should not be overridden."""
124 | return ""
125 |
126 | @property
127 | def embedding_model_name(self):
128 | """The model name property, which should not be overridden."""
129 | return ""
130 |
131 | @property
132 | def postprocessor(self):
133 | """The postprocessor property, which should not be overridden."""
134 | if self._postprocessor is None:
135 | self._postprocessor = self.get_postprocessor()
136 | return self._postprocessor
137 |
138 | @property
139 | def tokenizer(self):
140 | """The tokenizer property, which should not be overridden."""
141 | if self._tokenizer is None and self.load_tokenizer:
142 | self._tokenizer = self.get_tokenizer()
143 | return self._tokenizer
144 |
145 | @property
146 | def toolscallprocessor(self):
147 | """The tools call processor property, which should not be overridden."""
148 | if self._tool_call_processor is None:
149 | self._tool_call_processor = self.get_toolscallprocessor()
150 | return self._tool_call_processor
151 |
152 | @staticmethod
153 | def check_config(config):
154 | """The validation of values in config."""
155 | eos_token_id_field = 'eos_token_id'
156 |
157 | vocab_size = 0
158 | vocab_size_field = 'vocab_size'
159 | if hasattr(config, vocab_size_field):
160 | vocab_size = getattr(config, vocab_size_field)
161 | attribute_ranges = {
162 | vocab_size_field: (1, 2147483647),
163 | 'max_position_embeddings': (1, 2147483647),
164 | 'hidden_size': (1, 2147483647),
165 | 'intermediate_size': (1, 2147483647),
166 | 'num_hidden_layers': (1, 1000),
167 | 'num_attention_heads': (1, 10000),
168 | 'initializer_range': (0, 2147483647),
169 | 'rms_norm_eps': (0, 1),
170 | 'pad_token_id': (-1, vocab_size),
171 | 'bos_token_id': (0, vocab_size - 1),
172 | eos_token_id_field: (0, vocab_size - 1),
173 | 'temperature': (0, 2),
174 | 'top_k': (-1, vocab_size),
175 | 'top_p': (0, 1),
176 | 'repetition_penalty': (0, 2),
177 | 'frequency_penalty': (-2, 2),
178 | 'presence_penalty': (-2, 2)
179 | }
180 | if hasattr(config, "head_dim"):
181 | attribute_ranges['head_dim'] = (1, 1000)
182 | if hasattr(config, "num_key_value_heads"):
183 | attribute_ranges['num_key_value_heads'] = (1, 10000)
184 |
185 | def check_value(attr_ins, value_ins):
186 | if value_ins < min_val or value_ins > max_val:
187 | raise ValueError(f"self._config.{attr_ins} must be between {min_val} and {max_val}")
188 |
189 | def check_eos(eos_value):
190 | if isinstance(eos_value, int):
191 | check_value(eos_token_id_field, eos_value)
192 | elif isinstance(eos_value, list):
193 | for eos_v in eos_value:
194 | if isinstance(eos_v, int):
195 | check_value(eos_token_id_field, eos_v)
196 | elif isinstance(eos_v, list):
197 | for v in eos_v:
198 | check_value(eos_token_id_field, v)
199 | else:
200 | raise ValueError("eos_token_id must be Union[int, List[Union[int, List[int]]]].")
201 | else:
202 | raise ValueError("eos_token_id must be Union[int, List[Union[int, List[int]]]].")
203 |
204 | for attr, (min_val, max_val) in attribute_ranges.items():
205 | if not hasattr(config, attr) or getattr(config, attr) is None:
206 | continue
207 | value = getattr(config, attr)
208 | if attr == eos_token_id_field:
209 | check_eos(value)
210 | continue
211 | check_value(attr, value)
212 |
213 | if getattr(config, 'repetition_penalty', None) == 0:
214 | raise ValueError("repetition_penalty should not be 0.")
215 | if not isinstance(getattr(config, 'do_sample', None), bool):
216 | raise ValueError("do_sample must be bool.")
217 |
218 | def tokenize(self, inputs: List[Union[str, Dict[str, str]]], **kwargs) -> np.ndarray:
219 | """Transfer text input or multimodal input to token ids.
220 |
221 | Args:
222 | inputs: List | List[Dict], when it's List, it means the input for LLM.
223 | When it's List[Dict], it means the multimodal inputs in interleaved style,
224 | for example:
225 | [
226 | {'text': 'Let me show you two pictures'},
227 | {'image': 'image_url_or_path'},
228 | {'image': 'image_url_or_path'},
229 | {'text': 'can you show the differences?'}
230 | ]
231 |
232 | Returns:
233 | numpy.ndarray: The expanded input_ids whose dimension is 1.
234 | """
235 | return self.tokenizer([inputs[0]["text"]], return_tensors="np")["input_ids"][0]
236 |
237 | def get_config(self):
238 | """The default method to get config.
239 |
240 | A subclass router can override it to define a custom method getting config. Note that the `get_config` method
241 | should use `self.config_dict` instead of the model weight path to construct a config object.
242 | """
243 | try:
244 | config_cls = self.get_config_cls()
245 | config = config_cls.from_dict(self.config_dict)
246 | except Exception as e:
247 | logger.warning(str(e))
248 | config = PretrainedConfig.from_dict(self.config_dict)
249 | self.check_config(config)
250 | return config
251 |
252 | def get_generation_config(self):
253 | """The default method to get generation config."""
254 | generation_config_path = os.path.join(self.model_name_or_path, "generation_config.json")
255 | generation_config = {}
256 | if os.path.exists(generation_config_path):
257 | generation_config = safe_get_config_dict(generation_config_path)
258 | return generation_config
259 |
260 | def get_config_cls(self):
261 | """The default method to get config class."""
262 | model_file_dir_name = f"atb_llm.models.{self.model_type}."
263 | if self.model_version:
264 | model_file_dir_name = model_file_dir_name + \
265 | f"{self.model_version}."
266 | config_file_name = f'config_{self.model_type}'
267 | module_path = f"{model_file_dir_name}{config_file_name}"
268 | module = importlib.import_module(module_path)
269 | config_cls_name = f"{self.model_type_cap}Config"
270 | return getattr(module, config_cls_name)
271 |
272 | def get_input_builder(self):
273 | """The default method to get input builder."""
274 | if hasattr(self.config, "max_position_embeddings") and self.config.max_position_embeddings:
275 | return InputBuilder(self.tokenizer, max_length=self.config.max_position_embeddings)
276 | return InputBuilder(self.tokenizer)
277 |
278 | def get_model_cls(self):
279 | """The default method to get model class.
280 |
281 | This is a basic router method to find model class, which is usually not necessary to be overridden.
282 | """
283 | model_file_dir_name = f"atb_llm.models.{self.model_type}."
284 | if self.model_version:
285 | model_file_dir_name = model_file_dir_name + \
286 | f"{self.model_version}."
287 | model_file_name = 'flash_causal' if self.is_flash_causal_lm else 'causal'
288 | if self.embedding_model_name: # for embedding model, example: gte-qwen2
289 | module_path = f"{model_file_dir_name}{model_file_name}_{self.model_type}_{self.embedding_model_name}"
290 | else:
291 | module_path = f"{model_file_dir_name}{model_file_name}_{self.model_type}"
292 | if self.enable_atb_torch:
293 | module_path += "_atb"
294 | module = importlib.import_module(module_path)
295 | model_cls_name = f"{self.model_type_cap}ForCausalLM"
296 | if self.enable_atb_torch:
297 | model_cls_name += "ATB"
298 | if self.is_flash_causal_lm:
299 | model_cls_name = "Flash" + model_cls_name
300 | return getattr(module, model_cls_name)
301 |
302 | def get_postprocessor(self):
303 | """The default method to get postprocessor."""
304 | return Postprocessor(self.tokenizer, self.generation_config)
305 |
306 | def get_tokenizer(self):
307 | """The default method to get tokenizer."""
308 | return safe_get_tokenizer_from_pretrained(
309 | self.tokenizer_path,
310 | revision=self.revision,
311 | padding_side="left",
312 | truncation_side="left",
313 | trust_remote_code=self.trust_remote_code,
314 | use_fast=True
315 | )
316 |
317 | def get_toolscallprocessor(self):
318 | """The default method to get tools call processor."""
319 | return ToolsCallProcessor(self.model_version)
320 |
321 |
322 | class ToolsCallProcessor:
323 | """Base class for tools call processor."""
324 | def __init__(self, model_version):
325 | self.model_version = model_version
326 |
327 | @staticmethod
328 | def decode(content):
329 | """Parse model output to extract tools call output."""
330 | return content
--------------------------------------------------------------------------------
/inference/mindie/atb_llm/modeling_deepseek.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List, Tuple
2 | import torch
3 | from torch import nn
4 | from atb_llm.models.base.modeling import FlashAttention, MLP
5 | from atb_llm.utils.layers.linear import FastLinear, TensorReplicatedLinear
6 | from atb_llm.utils.layers import (
7 | TensorParallelRowLinear,
8 | TensorParallelColumnLinear,
9 | TensorEmbedding,
10 | load_column_multi,
11 | RMSNorm
12 | )
13 | from atb_llm.utils.moe_utils import assign
14 | from atb_llm.utils.quantize.pack_type import get_pack_type
15 | from atb_llm.utils.quantize.pack_type import PackType, calc_linear_pack_type
16 | from atb_llm.utils.log import logger
17 | from atb_llm.utils.log.error_code import ErrorCode
18 |
19 |
20 | class DeepseekMLP(MLP):
21 | def __init__(self, prefix, config, weights, **kwargs):
22 | super().__init__(prefix, config, weights, **kwargs)
23 | self.load_weights(**kwargs)
24 |
25 | def get_suffix(tensor_name: str) -> str:
26 | """Get the suffix of a tensor name."""
27 | return tensor_name.split(".")[-1]
28 |
29 | class FlashDeepseekAttention(FlashAttention):
30 | def __init__(self, prefix: str, config, weights, **kwargs):
31 | super().__init__(prefix, config, weights, **kwargs)
32 | self.qkv_names = [f'{self.prefix}.query_key_value']
33 | self.dense_name = f'{self.prefix}.dense'
34 | self.pack_type = get_pack_type(self.weights, self.qkv_names, self.norm_name, self.pack_name)
35 | '''
36 | super().load_qkv_weights(**kwargs)
37 |
38 | dense_linear = TensorParallelRowLinear.load_att_dense(
39 | self.config,
40 | prefix=self.dense_name,
41 | weights=self.weights,
42 | bias=self.dense_bias,
43 | gqa_size=self.head_size,
44 | bias_pre_add=self.bias_pre_add
45 | )
46 | setattr(self, get_suffix(self.dense_name), dense_linear)
47 | '''
48 | if config.model_type == "bailing_moe" and config.num_hidden_layers == 88:
49 | padding = True
50 |
51 | query_key_value_linear = TensorParallelColumnLinear.load_qkv(
52 | self.config,
53 | prefix=self.qkv_names[0],
54 | weights=self.weights,
55 | bias=self.qkv_bias,
56 | hidden_size=self.hidden_size,
57 | num_heads=self.num_heads,
58 | num_kv_heads=self.num_kv_heads,
59 | padding=padding,
60 | )
61 | setattr(self, get_suffix(self.pack_name), query_key_value_linear)
62 |
63 | dense_linear = TensorParallelRowLinear.load_att_dense(
64 | self.config,
65 | prefix=self.dense_name,
66 | weights=self.weights,
67 | bias=self.dense_bias,
68 | gqa_size=self.head_size,
69 | bias_pre_add=self.bias_pre_add
70 | )
71 | setattr(self, get_suffix(self.dense_name), dense_linear)
72 | elif config.model_type == "bailing_moe" and config.num_hidden_layers == 28:
73 | padding = False
74 | query_key_value_linear = TensorParallelColumnLinear.load_qkv(
75 | self.config,
76 | prefix=self.qkv_names[0],
77 | weights=self.weights,
78 | bias=self.qkv_bias,
79 | hidden_size=self.hidden_size,
80 | num_heads=self.num_heads,
81 | num_kv_heads=self.num_kv_heads,
82 | padding=padding,
83 | )
84 | setattr(self, get_suffix(self.pack_name), query_key_value_linear)
85 |
86 | dense_linear = TensorParallelRowLinear.load(
87 | self.config,
88 | prefix=self.dense_name,
89 | weights=self.weights,
90 | bias=self.dense_bias,
91 | gqa_size=self.head_size,
92 | bias_pre_add=self.bias_pre_add,
93 | )
94 | setattr(self, get_suffix(self.dense_name), dense_linear)
95 |
96 |
97 | class FlashDeepseekLayer(nn.Module):
98 | class ForwardInputArgs:
99 | def __init__(self,
100 | hidden_states: torch.tensor,
101 | residual: torch.tensor,
102 | cos: torch.tensor,
103 | sin: torch.tensor,
104 | cu_seqlen_prefill: torch.tensor,
105 | kv_cache: Tuple[torch.tensor, torch.tensor],
106 | block_tables: List[torch.tensor],
107 | slots: torch.tensor,
108 | input_lengths: torch.tensor,
109 | max_s: torch.tensor):
110 | self.hidden_states = hidden_states
111 | self.residual = residual
112 | self.cos = cos
113 | self.sin = sin
114 | self.cu_seqlen_prefill = cu_seqlen_prefill
115 | self.kv_cache = kv_cache
116 | self.block_tables = block_tables
117 | self.slots = slots
118 | self.input_lengths = input_lengths
119 | self.max_s = max_s
120 |
121 | def __init__(self, layer_id, config, weights):
122 | super().__init__()
123 | prefix = f"model.layers.{layer_id}"
124 | self.self_attn = FlashDeepseekAttention(
125 | prefix=f"{prefix}.attention", config=config, weights=weights
126 | )
127 | if (config.num_experts is not None and
128 | layer_id >= config.first_k_dense_replace and
129 | layer_id % config.moe_layer_freq == 0):
130 | self.mlp = DeepseekMoE(prefix=f"{prefix}.mlp", config=config, weights=weights, shared_mlp_cls=DeepseekMLP)
131 | else:
132 | self.mlp = DeepseekMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
133 | self.input_layernorm = RMSNorm(
134 | prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
135 | )
136 | self.post_attention_layernorm = RMSNorm(
137 | prefix=f"{prefix}.post_attention_layernorm",
138 | weights=weights,
139 | eps=config.rms_norm_eps,
140 | )
141 |
142 | def forward(
143 | self,
144 | input_args: ForwardInputArgs
145 | ):
146 | hidden_states = input_args.hidden_states
147 | residual = input_args.residual
148 | cos = input_args.cos
149 | sin = input_args.sin
150 | cu_seqlen_prefill = input_args.cu_seqlen_prefill
151 | kv_cache = input_args.kv_cache
152 | block_tables = input_args.block_tables
153 | slots = input_args.slots
154 | input_lengths = input_args.input_lengths
155 | max_s = input_args.max_s
156 | normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
157 |
158 | # Self Attention
159 | attn_output = self.self_attn(
160 | normed_hidden_states,
161 | cos,
162 | sin,
163 | cu_seqlen_prefill,
164 | kv_cache,
165 | block_tables,
166 | slots,
167 | input_lengths,
168 | max_s,
169 | )
170 |
171 | # faster post attention rms norm
172 | normed_attn_res_output, attn_res = self.post_attention_layernorm(
173 | attn_output, res
174 | )
175 |
176 | mlp_output = self.mlp(normed_attn_res_output)
177 |
178 | return mlp_output, attn_res
179 |
180 |
181 | class FlashDeepseekModel(torch.nn.Module):
182 | class ForwardInputArgs:
183 | def __init__(self,
184 | input_ids: torch.Tensor,
185 | position_ids: torch.Tensor,
186 | cu_seqlen_prefill: Optional[torch.Tensor],
187 | kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
188 | block_tables: torch.Tensor,
189 | slots: torch.Tensor,
190 | input_lengths: torch.Tensor,
191 | max_s: int,
192 | lm_head_indices: Optional[torch.Tensor] = None):
193 | self.input_ids = input_ids
194 | self.position_ids = position_ids
195 | self.cu_seqlen_prefill = cu_seqlen_prefill
196 | self.kv_cache = kv_cache
197 | self.block_tables = block_tables
198 | self.slots = slots
199 | self.input_lengths = input_lengths
200 | self.max_s = max_s
201 | self.lm_head_indices = lm_head_indices
202 |
203 | def __init__(self, config, weights):
204 | super().__init__()
205 |
206 | process_group = weights.process_group
207 | self.tp_rank = process_group.rank()
208 | self.tp_world_size = process_group.size()
209 | self.embed_tokens = TensorEmbedding(
210 | prefix="model.word_embeddings", weights=weights
211 | )
212 | self.layers = nn.ModuleList(
213 | [
214 | FlashDeepseekLayer(
215 | layer_id,
216 | config,
217 | weights,
218 | )
219 | for layer_id in range(config.num_hidden_layers)
220 | ]
221 | )
222 | self.norm = RMSNorm(
223 | prefix="model.norm", weights=weights, eps=config.rms_norm_eps
224 | )
225 |
226 |
227 | class DeepseekMoE(nn.Module):
228 | """
229 | A mixed expert module containing shared experts.
230 | """
231 |
232 | def __init__(self, prefix, config, weights, shared_mlp_cls):
233 | super().__init__()
234 | process_group = weights.process_group
235 | self.tp_rank = process_group.rank()
236 | self.tp_world_size = process_group.size()
237 | self.tp = True # defaulting the model to tensor parallel
238 | self.expert_parallel_degree = 1 if self.tp else self.tp_world_size
239 | if self.expert_parallel_degree == 0:
240 | msg = "expert parallel degree should not be 0!"
241 | logger.error(msg, ErrorCode.ATB_MODELS_PARAM_OUT_OF_RANGE)
242 | raise ValueError(msg)
243 | self.expert_lists = []
244 | if self.tp:
245 | self.expert_lists = [[i for i in range(config.num_experts)] for j in range(self.tp_world_size)]
246 | else:
247 | self.expert_lists = assign(config.num_experts, self.tp_world_size)
248 | self.config = config
249 | self.hidden_dim = self.config.hidden_size
250 | self.num_experts_per_tok = config.num_experts_per_tok
251 | self.num_experts = config.num_experts
252 | expert_prefix = f"{prefix}.experts"
253 | self.gate = FastLinear.load(
254 | prefix=f"{prefix}.gate",
255 | weights=weights,
256 | bias=False,
257 | )
258 | linear_names = [f'{expert_prefix}.0.up_proj', f'{expert_prefix}.0.gate_proj']
259 | pack_name = f'{expert_prefix}.0.gate_up_proj'
260 | layer_prefix = '.'.join(prefix.split('.')[:-1])
261 | norm_name = f'{layer_prefix}.post_attention_layernorm'
262 | self.pack_type = calc_linear_pack_type(weights, linear_names, norm_name, pack_name)
263 | if self.tp:
264 | if self.pack_type in [
265 | PackType.ALL_FP, PackType.ALL_W8A8, PackType.ALL_W8A8_ANTI, PackType.ALL_W4A16,
266 | PackType.ALL_W4A16_ANTI, PackType.ALL_W8A16, PackType.ALL_W8A16_ANTI,
267 | PackType.MIX_W8A8_DYNAMIC, PackType.MIX_W8A8_DYNAMIC_ANTI,
268 | PackType.ALL_W8A8_DYNAMIC, PackType.ALL_W8A8_DYNAMIC_ANTI
269 | ]:
270 | self.gate_up_proj = nn.ModuleList()
271 | for i in range(self.num_experts):
272 | self.gate_up_proj.append(load_column_multi(
273 | config,
274 | prefixes=[f"{expert_prefix}.{i}.gate_proj", f"{expert_prefix}.{i}.up_proj"],
275 | weights=weights,
276 | head_size=1,
277 | ))
278 | elif self.pack_type in [PackType.ALL_W8A8SC, PackType.ALL_W8A8SC_ANTI]:
279 | self.gate_up_proj = nn.ModuleList()
280 | for i in range(self.num_experts):
281 | self.gate_up_proj.append(TensorParallelColumnLinear.load(
282 | config,
283 | prefix=f"{expert_prefix}.{i}.gate_up_proj",
284 | weights=weights,
285 | bias=False,
286 | ))
287 | else:
288 | self.gate_proj = nn.ModuleList()
289 | for i in range(self.num_experts):
290 | self.gate_proj.append(TensorParallelColumnLinear.load(
291 | config,
292 | prefix=f"{expert_prefix}.{i}.gate_proj",
293 | weights=weights,
294 | bias=False,
295 | ))
296 | self.up_proj = nn.ModuleList()
297 | for i in range(self.num_experts):
298 | self.up_proj.append(TensorParallelColumnLinear.load(
299 | config,
300 | prefix=f"{expert_prefix}.{i}.up_proj",
301 | weights=weights,
302 | bias=False,
303 | ))
304 | self.down_proj = nn.ModuleList()
305 | for i in range(self.num_experts):
306 | self.down_proj.append(TensorParallelRowLinear.load(
307 | config,
308 | prefix=f"{expert_prefix}.{i}.down_proj",
309 | weights=weights,
310 | bias=False,
311 | ))
312 | self.intermediate_size = (
313 | (config.intermediate_size + weights.process_group.size() - 1) // weights.process_group.size()
314 | )
315 | else:
316 | if self.pack_type in [
317 | PackType.ALL_FP, PackType.ALL_W8A8, PackType.ALL_W8A8_ANTI, PackType.ALL_W4A16,
318 | PackType.ALL_W4A16_ANTI, PackType.ALL_W8A16, PackType.ALL_W8A16_ANTI,
319 | PackType.MIX_W8A8_DYNAMIC, PackType.MIX_W8A8_DYNAMIC_ANTI,
320 | PackType.ALL_W8A8_DYNAMIC, PackType.ALL_W8A8_DYNAMIC_ANTI
321 | ]:
322 | self.gate_up_proj = nn.ModuleList()
323 | for i in self.expert_lists[self.tp_rank]:
324 | self.gate_up_proj.append(TensorReplicatedLinear.load(
325 | config,
326 | prefixes=[f"{expert_prefix}.{i}.gate_proj", f"{expert_prefix}.{i}.up_proj"],
327 | weights=weights,
328 | head_size=1,
329 | ))
330 | self.down_proj = nn.ModuleList()
331 | for i in self.expert_lists[self.tp_rank]:
332 | self.down_proj.append(TensorReplicatedLinear.load(
333 | config,
334 | prefix=f"{expert_prefix}.{i}.down_proj",
335 | weights=weights,
336 | bias=False,
337 | ))
338 |
339 | if config.num_shared_experts is not None:
340 | intermediate_size = config.moe_intermediate_size * config.num_shared_experts
341 | shared_expert_prefix = f"{prefix}.shared_experts"
342 | self.shared_experts = shared_mlp_cls(
343 | prefix=shared_expert_prefix,
344 | config=config,
345 | weights=weights,
346 | intermediate_size=intermediate_size
347 | )
348 |
349 | def forward(self, hidden_states):
350 | identity = hidden_states
351 | orig_shape = hidden_states.shape
352 | topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
353 | hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
354 | flat_topk_idx = topk_idx.view(-1)
355 | y = self.moe_infer(hidden_states, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
356 | if self.config.num_shared_experts is not None:
357 | y = y + self.shared_experts(identity)
358 | return y
359 |
360 | @torch.no_grad()
361 | def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
362 | expert_cache = torch.zeros_like(x)
363 | idxs = flat_expert_indices.argsort()
364 | tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
365 | token_idxs = idxs // self.num_experts_per_tok
366 | for i, end_idx in enumerate(tokens_per_expert):
367 | start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
368 | if start_idx == end_idx:
369 | continue
370 | expert = self.experts[i]
371 | exp_token_idx = token_idxs[start_idx:end_idx]
372 | expert_tokens = x[exp_token_idx]
373 | expert_out = expert(expert_tokens)
374 | expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
375 | device = expert_cache.device
376 | expert_cache_cpu = expert_cache.cpu()
377 | expert_cache_cpu.scatter_reduce_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]).cpu(),
378 | expert_out.cpu(), reduce='sum')
379 | expert_cache = expert_cache_cpu.to(device=device)
380 | return expert_cache
--------------------------------------------------------------------------------
/inference/mindie/atb_llm/utils-layers-linear-__init__.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import torch
4 | from atb_llm.utils.log import logger
5 | from torch import nn
6 |
7 | from .fast_linear import FastLinear
8 | from ...quantize.quant_type import QuantType
9 | from ...quantize.w4a16 import W4A16LinearStatic
10 | from ...quantize.w8a16 import W8A16LinearStatic
11 | from ...quantize.w8a8 import W8A8LinearStatic
12 | from ...quantize.w8a8sc import W8A8SparseCompressedLinear
13 | from ...quantize.w8a8_dynamic import W8A8DynamicLinearStatic
14 |
15 |
16 | def get_linear(weight, bias, quantize, is_norm=False, **kwargs):
17 | if quantize is None:
18 | linear = FastLinear(weight, bias, is_norm)
19 | elif quantize in [QuantType.W8A8, QuantType.W8A8S]:
20 | if isinstance(weight, torch.Tensor):
21 | linear = FastLinear(weight, bias, is_norm)
22 | else:
23 | try:
24 | qweight, deq_scale, quant_bias, input_scale, input_offset = weight
25 | except Exception as err:
26 | logger.error(
27 | "The passed weight is not `w8a8` compatible, loader needs to be updated."
28 | )
29 | raise AssertionError from err
30 | linear = W8A8LinearStatic(
31 | weight=qweight,
32 | deq_scale=deq_scale,
33 | input_scale=input_scale,
34 | quant_bias=quant_bias,
35 | input_offset=input_offset,
36 | bias=bias
37 | )
38 | elif quantize == QuantType.W4A16:
39 | if isinstance(weight, torch.Tensor):
40 | linear = FastLinear(weight, bias, is_norm)
41 | else:
42 | try:
43 | qweight, weight_scale, weight_offset = weight
44 | except Exception as err:
45 | logger.error(
46 | "The passed weight is not `w4a16` compatible, loader needs to be updated."
47 | )
48 | raise AssertionError from err
49 | linear = W4A16LinearStatic(
50 | weight=qweight,
51 | weight_scale=weight_scale,
52 | weight_offset=weight_offset,
53 | bias=bias
54 | )
55 | elif quantize == QuantType.W8A16:
56 | if isinstance(weight, torch.Tensor):
57 | linear = FastLinear(weight, bias, is_norm)
58 | else:
59 | try:
60 | qweight, weight_scale, weight_offset = weight
61 | except Exception as err:
62 | logger.error(
63 | "The passed weight is not `w8a16` compatible, loader needs to be updated."
64 | )
65 | raise AssertionError from err
66 | linear = W8A16LinearStatic(
67 | weight=qweight,
68 | weight_scale=weight_scale,
69 | weight_offset=weight_offset,
70 | bias=bias
71 | )
72 | elif quantize == QuantType.W8A8SC:
73 | if isinstance(weight, torch.Tensor):
74 | linear = FastLinear(weight, bias, is_norm)
75 | else:
76 | try:
77 | qweight, deq_scale, quant_bias, input_scale, input_offset, index = weight
78 | except Exception as err:
79 | logger.error(
80 | "The passed weight is not `w8a8sc` compatible, loader needs to be updated."
81 | )
82 | raise AssertionError from err
83 | linear = W8A8SparseCompressedLinear(
84 | weight=qweight,
85 | deq_scale=deq_scale,
86 | input_scale=input_scale,
87 | quant_bias=quant_bias,
88 | input_offset=input_offset,
89 | index=index
90 | )
91 | elif quantize == QuantType.W8A8_DYNAMIC:
92 | if isinstance(weight, torch.Tensor):
93 | linear = FastLinear(weight, bias, is_norm)
94 | else:
95 | try:
96 | qweight, weight_scale, weight_offset = weight
97 | except Exception as err:
98 | logger.error(
99 | "The passed weight is not `w8a8 dynamic` compatible, loader needs to be updated."
100 | )
101 | raise AssertionError from err
102 | linear = W8A8DynamicLinearStatic(
103 | weight=qweight,
104 | weight_scale=weight_scale,
105 | weight_offset=weight_offset,
106 | bias=bias
107 | )
108 | else:
109 | raise AssertionError(
110 | f"Quantization `{quantize}` is not implemented yet. "
111 | f"此类型从权重文件config.json中的`quantize`字段中获取。"
112 | f"若非量化权重,config.json中无需配置此字段;"
113 | f"若为量化权重,当前支持的量化类型为`w4a16`,`w8a16`,`w8a8`,`w8a8s`和`w8a8sc`。"
114 | )
115 |
116 | # Update the linear layer's meta information
117 | linear.prefixes = kwargs.get("prefixes", [])
118 | linear.num_linear_before_pack = kwargs.get("num_linear_before_pack", 1)
119 | linear.tensor_parallel_dim = kwargs.get("tensor_parallel_dim", 0)
120 | linear.align_size = kwargs.get("align_size", 1)
121 | return linear
122 |
123 |
124 | class SuperLayer(nn.Module):
125 | def __init__(self, linear):
126 | super().__init__()
127 | self.linear = linear
128 |
129 | def forward(self, input_tensor):
130 | return self.linear.forward(input_tensor)
131 |
132 |
133 | class TensorHead(SuperLayer):
134 | def __init__(self, linear):
135 | super().__init__(linear)
136 |
137 | @staticmethod
138 | def load_weight(config, prefix: str, weights, is_norm=False):
139 | weight = weights.get_whole_tensor(f"{prefix}.weight", dim=0)
140 |
141 | # GPTQ doesn't quantize heads (nor embeddings)
142 | if config.quantize == "gptq":
143 | quantize = None
144 | else:
145 | quantize = config.quantize
146 | return TensorHead(
147 | get_linear(weight, bias=None, quantize=quantize, is_norm=is_norm),
148 | )
149 |
150 | def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
151 | # Plain (non-parallel) head: project the hidden states with a single matmul
152 | # against the full, unsharded weight matrix.
153 | out = torch.mm(input_tensor, self.linear.weight.T)
154 | return out
155 |
156 |
157 | class TensorParallelHead(SuperLayer):
158 | def __init__(self, linear, process_group, should_gather: bool):
159 | super().__init__(linear)
160 | self.process_group = process_group
161 | self.should_gather = should_gather
162 |
163 | @staticmethod
164 | def load_weight(config, prefix: str, weights, is_norm=False):
165 | weight = weights.get_tensor(f"{prefix}.weight")
166 | should_gather = False
167 | # GPTQ doesn't quantize heads (nor embeddings)
168 | quantize = None if config.quantize == "gptq" else config.quantize
169 | return TensorParallelHead(
170 | get_linear(weight, bias=None, quantize=quantize, is_norm=is_norm),
171 | process_group=weights.process_group,
172 | should_gather=should_gather,
173 | )
174 |
175 | @staticmethod
176 | def load(config, prefix: str, weights, is_norm=False):
177 | should_gather = True
178 | if weights.process_group.size() > 1:
179 | try:
180 | weight = weights.get_sharded(f"{prefix}.weight", dim=0)
181 | except AssertionError:
182 | # If the vocab size is not divisible by number of shards
183 | # just load the entire thing.
184 | weight = weights.get_tensor(f"{prefix}.weight")
185 | should_gather = False
186 | else:
187 | weight = weights.get_tensor(f"{prefix}.weight")
188 | should_gather = False
189 |
190 | # GPTQ doesn't quantize heads (nor embeddings)
191 | quantize = None if config.quantize == "gptq" else config.quantize
192 | return TensorParallelHead(
193 | get_linear(weight, bias=None, quantize=quantize, is_norm=is_norm),
194 | process_group=weights.process_group,
195 | should_gather=should_gather,
196 | )
197 |
198 | def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
199 | if not self.should_gather:
200 | return super().forward(input_tensor)
201 |
202 | world_size = self.process_group.size()
203 | if len(input_tensor.shape) == 2 and isinstance(self.linear, FastLinear):
204 | out_dim = self.linear.weight.shape[0]
205 | if input_tensor.shape[0] == 1:
206 | world_out = input_tensor.new_empty(1, out_dim * world_size)
207 | local_out = input_tensor.new_empty(1, out_dim)
208 | gather_input = local_out
209 | else:
210 | world_out = input_tensor.new_empty(out_dim * world_size, input_tensor.shape[0])
211 | gather_input = input_tensor.new_empty(out_dim, input_tensor.shape[0])
212 | local_out = gather_input.T
213 |
214 | torch.mm(input_tensor, self.linear.weight.T, out=local_out)
215 | torch.distributed.all_gather_into_tensor(
216 | world_out, gather_input, group=self.process_group
217 | )
218 |
219 | if input_tensor.shape[0] == 1:
220 | return world_out
221 | return world_out.T
222 |
223 | output = super().forward(input_tensor)
224 | world_output = [
225 | torch.empty_like(output)
226 | for _ in range(self.process_group.size())
227 | ]
228 | torch.distributed.all_gather(world_output, output, group=self.process_group)
229 | world_output = torch.cat(world_output, dim=-1)
230 | return world_output
231 |
232 |
233 | class TensorParallelColumnLinear(SuperLayer):
234 | @classmethod
235 | def load_qkv(cls, config, prefix: str, weights, bias: bool, hidden_size, num_heads, num_kv_heads=None, dim=0, padding=False):
236 | """Specific method when the QKV was joined after the fact"""
237 | if num_kv_heads is None:
238 | num_kv_heads = num_heads
239 | weight = weights.get_weights_col_packed_qkv(
240 | prefix, quantize=config.quantize, hidden_size=hidden_size,
241 | num_heads=num_heads, num_kv_heads=num_kv_heads, dim=dim, padding=padding
242 | )
243 | if bias:
244 | bias = weights.get_tensor_col_packed_qkv(
245 | f"{prefix}.bias", hidden_size=hidden_size, num_heads=num_heads, num_kv_heads=num_kv_heads
246 | )
247 | else:
248 | bias = None
249 | linear = get_linear(
250 | weight, bias, config.quantize, prefixes=[prefix], num_linear_before_pack=3,
251 | tensor_parallel_dim=dim, align_size=hidden_size // num_heads,
252 | )
253 | return cls(linear)
254 |
255 | @classmethod
256 | def load_gate_up(cls, config, prefix: str, weights, bias: bool):
257 | """Specific method when the QKV was joined after the fact"""
258 | weight = weights.get_weights_col_packed_mlp(
259 | prefix, quantize=config.quantize
260 | )
261 | if bias:
262 | bias = weights.get_tensor_col_packed_mlp(f"{prefix}.bias")
263 | else:
264 | bias = None
265 | linear = get_linear(
266 | weight, bias, config.quantize, prefixes=[prefix], num_linear_before_pack=2,
267 | tensor_parallel_dim=1, align_size=1
268 | )
269 | return cls(linear)
270 |
271 | @classmethod
272 | def load(cls, config, prefix: str, weights, bias: bool, dim=0):
273 | return cls.load_multi(config, [prefix], weights, bias, dim=dim)
274 |
275 | @classmethod
276 | def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
277 | weight = weights.get_multi_weights_col(
278 | prefixes, quantize=config.quantize, dim=dim
279 | )
280 |
281 | if bias:
282 | if config.quantize == QuantType.W8A8SC:
283 | b = [weights.get_tensor(f"{p}.bias") for p in prefixes]
284 | else:
285 | b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
286 | bias = torch.cat(b, dim=0)
287 | else:
288 | bias = None
289 | linear = get_linear(
290 | weight, bias, config.quantize, prefixes=prefixes, num_linear_before_pack=len(prefixes),
291 | tensor_parallel_dim=dim, align_size=1
292 | )
293 | return cls(linear)
294 |
295 | @classmethod
296 | def load_o(cls, config, prefix: str, weights, bias: bool, hidden_size, num_heads, num_kv_heads=None):
297 | """Specific method when the QKV was joined after the fact"""
298 | if num_kv_heads is None:
299 | num_kv_heads = num_heads
300 | weight = weights.get_weights_col_packed_o(
301 | prefix, quantize=config.quantize, hidden_size=hidden_size, num_heads=num_heads, num_kv_heads=num_kv_heads
302 | )
303 | if bias:
304 | bias = weights.get_tensor_col_packed_o(
305 | f"{prefix}.bias", hidden_size=hidden_size, num_heads=num_heads, num_kv_heads=num_kv_heads
306 | )
307 | else:
308 | bias = None
309 | linear = get_linear(
310 | weight, bias, config.quantize, prefixes=[prefix], num_linear_before_pack=3,
311 | tensor_parallel_dim=0, align_size=hidden_size // num_heads,
312 | )
313 | return cls(linear)
314 |
315 | @classmethod
316 | def load_column_multi_c(cls, config, prefixes: List[str], weights, hidden_size, num_heads, num_kv_heads=None):
317 |         """Load separate Q, K and V projections and pack them into one column-parallel linear"""
318 | if num_kv_heads is None:
319 | num_kv_heads = num_heads
320 | weight_q = weights.get_weights_col_packed_q(
321 | prefixes[0], quantize=config.quantize, hidden_size=hidden_size, num_heads=num_heads,
322 | num_kv_heads=num_kv_heads)
323 | bias_q = weights.get_tensor_col_packed_q(
324 | f"{prefixes[0]}.bias", hidden_size=hidden_size, num_heads=num_heads, num_kv_heads=num_kv_heads
325 | )
326 | weight_k = weights.get_weights_col_packed_k(
327 | prefixes[1], quantize=config.quantize, hidden_size=hidden_size, num_heads=num_heads,
328 | num_kv_heads=num_kv_heads)
329 | bias_k = weights.get_tensor_col_packed_k(
330 | f"{prefixes[1]}.bias", hidden_size=hidden_size, num_heads=num_heads, num_kv_heads=num_kv_heads
331 | )
332 | weight_v = weights.get_weights_col_packed_v(
333 | prefixes[2], quantize=config.quantize, hidden_size=hidden_size, num_heads=num_heads,
334 | num_kv_heads=num_kv_heads)
335 | bias_v = weights.get_tensor_col_packed_v(
336 | f"{prefixes[2]}.bias", hidden_size=hidden_size, num_heads=num_heads, num_kv_heads=num_kv_heads
337 | )
338 | weight = torch.cat([weight_q, weight_k, weight_v], dim=0)
339 | bias = torch.cat([bias_q, bias_k, bias_v], dim=0)
340 | linear = get_linear(
341 | weight, bias, config.quantize, prefixes=prefixes, num_linear_before_pack=len(prefixes),
342 | tensor_parallel_dim=0, align_size=hidden_size // num_heads
343 | )
344 | return cls(linear)
345 |
346 |
347 | class TensorParallelRowLinear(SuperLayer):
348 | def __init__(self, linear, process_group):
349 | super().__init__(linear)
350 | self.process_group = process_group
351 |
352 | @classmethod
353 | def load(cls, config, prefix: str, weights, bias: bool, bias_pre_add=False, gqa_size=1, dim=1):
354 | weight = weights.get_multi_weights_row(prefix, quantize=config.quantize, gqa_size=gqa_size, dim=dim)
355 | if bias and bias_pre_add:
356 | bias = weights.get_tensor(f"{prefix}.bias")
357 | elif bias and weights.process_group.rank() == 0:
358 |             # Add the bias on the first rank only, so it is not double-counted after the all-reduce
359 | bias = weights.get_tensor(f"{prefix}.bias")
360 | else:
361 | bias = None
362 | return cls(
363 | get_linear(
364 | weight, bias, config.quantize, prefixes=[prefix],
365 | tensor_parallel_dim=dim, align_size=gqa_size
366 | ),
367 | process_group=weights.process_group,
368 | )
369 |
370 | @classmethod
371 | def load_att_dense(cls, config, prefix: str, weights, bias: bool, bias_pre_add=False, gqa_size=1, dim=1):
372 | weight = weights.get_multi_weights_row_att_dense(prefix, quantize=config.quantize, gqa_size=gqa_size, dim=dim)
373 | if bias and bias_pre_add:
374 | bias = weights.get_tensor(f"{prefix}.bias")
375 | elif bias and weights.process_group.rank() == 0:
376 |             # Add the bias on the first rank only, so it is not double-counted after the all-reduce
377 | bias = weights.get_tensor(f"{prefix}.bias")
378 | else:
379 | bias = None
380 | return cls(
381 | get_linear(
382 | weight, bias, config.quantize, prefixes=[prefix],
383 | tensor_parallel_dim=dim, align_size=gqa_size
384 | ),
385 | process_group=weights.process_group,
386 | )
387 |
388 | def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
389 | out = super().forward(input_tensor)
390 | if self.process_group.size() > 1:
391 | torch.distributed.all_reduce(out, group=self.process_group)
392 | return out
393 |
394 |
395 | class TensorReplicatedLinear(SuperLayer):
396 | def __init__(self, linear):
397 | super().__init__(linear)
398 |
399 | @classmethod
400 | def load(cls, config, prefix: str, weights, bias: bool):
401 | weight = weights.get_replicated_weights(prefix, quantize=config.quantize)
402 |         if bias:
403 | bias = weights.get_tensor(f"{prefix}.bias")
404 | else:
405 | bias = None
406 |
407 | return cls(get_linear(weight, bias, config.quantize, prefixes=[prefix]))
--------------------------------------------------------------------------------
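
Note on the layer classes above: they follow the usual Megatron-style pairing. A column-parallel linear (`TensorParallelColumnLinear`) shards its weight along the output dimension so each rank produces a slice of the activations, while the row-parallel linear (`TensorParallelRowLinear`) shards along the input dimension and finishes with a single `all_reduce`, which is also why its bias is added on rank 0 only. A minimal single-process sketch of that arithmetic, written in plain PyTorch for illustration only (the helper name and shapes below are not part of this repository):

    import torch

    def column_then_row(x, w_col, w_row, world_size):
        # Emulate the column-parallel -> row-parallel pair on one process:
        # shard w_col by output features and w_row by input features, then
        # sum the per-rank partial results as all_reduce would.
        col_shards = w_col.chunk(world_size, dim=0)   # per rank: [out/ws, in]
        row_shards = w_row.chunk(world_size, dim=1)   # per rank: [out2, in/ws]
        partials = []
        for wc, wr in zip(col_shards, row_shards):
            h = x @ wc.T                # local slice of the intermediate activations
            partials.append(h @ wr.T)   # local partial of the final output
        return sum(partials)            # stands in for torch.distributed.all_reduce

    x = torch.randn(4, 16)
    w_col = torch.randn(32, 16)
    w_row = torch.randn(16, 32)
    reference = (x @ w_col.T) @ w_row.T
    assert torch.allclose(column_then_row(x, w_col, w_row, 4), reference, atol=1e-4)
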
/inference/mindie/atb_llm/flash_causal_deepseek.py:
--------------------------------------------------------------------------------
1 | import json
2 | import math
3 | from typing import Optional, List, Tuple
4 |
5 | import torch
6 | import torch_npu
7 | from atb_llm.models.base.flash_causal_lm import FlashForCausalLM
8 | from atb_llm.models.deepseek.config_deepseek import DeepseekConfig
9 | from atb_llm.models.deepseek.modeling_deepseek import FlashDeepseekModel
10 | from atb_llm.utils.data.weight_wrapper import AttnWrapper
11 | from atb_llm.utils.data.moe_weight_wrapper import MoeMlpWrapper, MoeWeightWrapper
12 | from atb_llm.utils.env import ENV
13 | from atb_llm.utils.log import logger
14 | from atb_llm.utils.log.error_code import ErrorCode
15 |
16 | from atb_llm.utils.layers import (
17 | TensorEmbedding,
18 | load_column_multi,
19 | )
20 | from atb_llm.utils.layers.norm.fast_layer_norm import NormType
21 |
22 | _COMMON_EXPERTS_NUM = 64
23 | _DECODER_MODEL = "deepseek_DecoderModel"
24 | _SUPPORT_LCOC = "enableLcoc"
25 | _SUPPORT_SPECULATE = "enableSpeculate"
26 | _IS_PREFILL = "isPrefill"
27 |
28 |
29 | class FlashDeepseekForCausalLM(FlashForCausalLM):
30 | def __init__(self, config, weights, **kwargs):
31 | self.acl_encoder_operation = None
32 | self.acl_decoder_operation = None
33 | self.acl_decoder_regression_operation = None
34 | super().__init__(config, weights, **kwargs)
35 | self.model = FlashDeepseekModel(config, weights)
36 | self.config = config
37 | self.lm_head = load_column_multi(
38 | config,
39 | prefixes=["lm_head"],
40 | weights=weights,
41 | head_size=1,
42 | lm_head=True,
43 | norm_head=True,
44 | )
45 | self.config = config
46 | self.in_tensor_length = 16
47 | self.acl_encoder_operation_inputs = []
48 | self.acl_decoder_operation_inputs = []
49 | self.ascend_kcache_id = None
50 | self.ascend_vcache_id = None
51 |
52 | self.placeholder = torch.zeros(1, dtype=self.dtype, device=self.device)
53 | self.lm_head_indices_fake = torch.tensor([0], dtype=torch.int64, device=self.device)
54 |
55 | self.transdata_operation = torch.classes.OperationTorch.OperationTorch("TransdataOperation")
56 | self.transdata_param = json.dumps({})
57 | self.transdata_operation.set_param(self.transdata_param)
58 |
59 | self.padding_idx = config.pad_token_id
60 | self.embed_tokens = TensorEmbedding(
61 | prefix="model.word_embeddings", weights=weights
62 | )
63 | self.hidden_dim = config.hidden_size
64 | self.expert_array = []
65 | self.expert_group = torch.tensor([0], dtype=torch.int32).npu()
66 | self.one_hot = torch.tensor([1], dtype=torch.int32).npu()
67 | self.zero_hot = torch.tensor([0], dtype=torch.int32).npu()
68 | self.final_bias = torch.zeros([self.config.num_experts, self.config.hidden_size], dtype=self.dtype).npu()
69 | self.num_of_experts = config.num_experts
70 | self.num_of_selected_experts = [config.num_experts_per_tok]
71 | self.tp = config.tp if config.tp else True # Defaulting the model to tensor parallel
72 | self.first_k_dense_replace = config.first_k_dense_replace if config.first_k_dense_replace else 0
73 | self.n_shared_experts = config.num_shared_experts if config.num_shared_experts else 0
74 | self.norm_topk_prob = config.norm_topk_prob if config.norm_topk_prob else False
75 | if self.tp:
76 | self.expert_parallel_degree = 1
77 | else:
78 | self.expert_parallel_degree = self.tp_world_size
79 | self.ascend_weight = None
80 |
81 | if self.prefix_cache_enable:
82 | self.acl_decoder_regression_operation_inputs = []
83 |         self.enable_fused_routing = not self.soc_info.need_nz
84 | if config.model_type == "bailing_moe" and config.num_hidden_layers == 88:
85 | self.num_attention_heads = 7
86 |
87 | # called by super().prepare_inputs_for_ascend
88 | def init_position_rotary_embedding(self,
89 | position_ids: torch.Tensor,
90 | max_seq_len: int):
91 | self.rotary_embedding.update_cos_sin_cache_total(self.dtype, position_ids.device, max_seq_len)
92 | self.cos_embed = self.rotary_embedding.get_cos_cached_total()
93 | self.sin_embed = self.rotary_embedding.get_sin_cached_total()
94 |
95 | def init_ascend_operations(self, config: DeepseekConfig):
96 | self.acl_encoder_operation = torch.classes.ModelTorch.ModelTorch(_DECODER_MODEL)
97 | self.acl_decoder_operation = torch.classes.ModelTorch.ModelTorch(_DECODER_MODEL)
98 | if self.prefix_cache_enable:
99 | self.acl_decoder_regression_operation = torch.classes.ModelTorch.ModelTorch(_DECODER_MODEL)
100 |
101 | def get_weights(self):
102 | attn_wrapper = AttnWrapper(
103 | norm_name='input_layernorm',
104 | wrapper_name='self_attn',
105 | pack_name='query_key_value',
106 | sep_names=['q_proj', 'k_proj', 'v_proj'],
107 | o_name='dense'
108 | )
109 | moe_mlp_wrapper = MoeMlpWrapper(
110 | norm_name='post_attention_layernorm',
111 | router_name='gate',
112 | wrapper_name='mlp',
113 | pack_name='gate_up_proj',
114 | sep_names=['gate_proj', 'up_proj'],
115 | down_name='down_proj',
116 | shared_experts=(self.n_shared_experts > 0)
117 | )
118 | weight_wrapper = MoeWeightWrapper(self.soc_info, self.tp_rank,
119 | attn_wrapper, moe_mlp_wrapper,
120 | self.num_of_experts)
121 | weight_wrapper.register_embedding(self.model.embed_tokens)
122 | for i in range(self.num_layers):
123 | layer = self.model.layers[i]
124 | if i < self.first_k_dense_replace:
125 | weight_wrapper.register_moe_layer(layer, self.quantize, dense_layer=True)
126 | else:
127 | if self.tp:
128 | weight_wrapper.register_moe_layer(layer, self.quantize, dense_layer=False)
129 | del layer.mlp
130 | torch.npu.empty_cache()
131 |
132 | else:
133 | msg = "ERROR: DeepSeek does not support expert parallel!"
134 | logger.error(msg, ErrorCode.ATB_MODELS_EXECUTION_FAILURE)
135 | raise ValueError(msg)
136 | if self.soc_info.need_nz:
137 | del layer.self_attn
138 | del layer.post_attention_layernorm
139 | torch.npu.empty_cache()
140 | weight_wrapper.register_model_norm(self.model.norm)
141 | weight_wrapper.register_model_lmhead(self.lm_head)
142 | return weight_wrapper
143 |
144 | def init_ascend_weight(self):
145 | weight_wrapper = self.get_weights()
146 | self.ascend_weight = weight_wrapper.weights
147 | pack_quant_configs = weight_wrapper.pack_quant_type
148 |
149 | attn_linear_types = weight_wrapper.attn_linear_types
150 | mlp_linear_types = weight_wrapper.mlp_linear_types
151 | moe_linear_types = weight_wrapper.moe_linear_types
152 |
153 | attn_linear_transpose_types = weight_wrapper.attn_linear_transpose_types
154 | mlp_linear_transpose_types = weight_wrapper.mlp_linear_transpose_types
155 | moe_linear_transpose_types = weight_wrapper.moe_linear_transpose_types
156 |
157 | # compatible with linearQuantType
158 | for i in range(self.num_layers):
159 | attn_linear_types[i].append(attn_linear_types[i][-1])
160 | attn_linear_transpose_types[i].append(-1)
161 |
162 | coder_param = {
163 | "normEps": self.config.rms_norm_eps,
164 | "normType": NormType.RMS_NORM,
165 | "numAttentionHeadsPerRank": self.num_attention_heads,
166 | "hiddenSizePerAttentionHead": self.head_size,
167 | "numHiddenLayers": self.config.num_hidden_layers,
168 | "numKeyValueHeadsPerRank": self.num_key_value_heads,
169 | "isUnpadInputs": True,
170 | "isFA": False,
171 | "isBF16": self.dtype == torch.bfloat16,
172 | "packQuantType": pack_quant_configs,
173 | "isEmbeddingParallel": False,
174 | "isLmHeadParallel": True,
175 | "linearQuantType": attn_linear_types,
176 | "mlpLinearQuantType": mlp_linear_types,
177 | "moeLinearQuantType": moe_linear_types,
178 | "linearTransposeType": attn_linear_transpose_types,
179 | "mlpLinearTransposeType": mlp_linear_transpose_types,
180 | "moeLinearTransposeType": moe_linear_transpose_types,
181 | "lmHeadTransposeType": self.lm_head.linear.trans_flag,
182 |             "enableSwiGLU": False,
183 |             "hasSharedExpert": self.n_shared_experts > 0,
184 |             "hasSharedExpertGate": False,
185 | "rank": self.tp_rank,
186 | "expertParallelDegree": self.expert_parallel_degree,
187 | "numOfExperts": self.num_of_experts,
188 | "numOfGroups": 8,
189 | "routingMethod": 'softMaxTopK' if self.soc_info.need_nz else 'integratedSoftmaxTopK',
190 | "processLogits": 'normalization' if self.norm_topk_prob else 'none',
191 | "firstKDenseReplace": self.first_k_dense_replace,
192 | "numOfSharedExperts": self.n_shared_experts,
193 | "numOfSelectedExperts": self.num_of_selected_experts,
194 | "numOfSelectedGroups": 3,
195 | "worldSize": self.tp_world_size,
196 | "backend": self.soc_info.communication_backend,
197 | "rankTableFile": ENV.rank_table_file,
198 | "enableAddNorm": False,
199 | "normHasBias": False,
200 | "enableFusedRouting": self.enable_fused_routing
201 | }
202 | if coder_param["routingMethod"] not in ['softMaxTopK', 'integratedSoftmaxTopK']:
203 |             msg = ("The routingMethod chosen is not valid, please choose among the following:\n"
204 |                    "'softMaxTopK': regular routing method with softmax and topk-sort operators\n"
205 |                    "'integratedSoftmaxTopK': routing method with the integration of softmax and topk-sort operators\n"
206 |                    "'deviceLimited': device-limited routing method (e.g. deepseekv2); "
207 |                    "invalid for Mixtral MoE and Deepseekv1")
208 | logger.error(msg, ErrorCode.ATB_MODELS_EXECUTION_FAILURE)
209 | raise ValueError(msg)
210 | encoder_param = {
211 | **coder_param, _IS_PREFILL: True, _SUPPORT_LCOC: self.lcoc_enable,
212 | _SUPPORT_SPECULATE: False, "enableSplitFuse": self.split_fuse_enable
213 | }
214 | decoder_param = {
215 | **coder_param, _IS_PREFILL: False, _SUPPORT_LCOC: False,
216 | _SUPPORT_SPECULATE: self.speculate_enable, "enablePrefixCache": self.prefix_cache_enable
217 | }
218 | self.acl_encoder_operation.set_param(json.dumps({**encoder_param}))
219 | self.acl_decoder_operation.set_param(json.dumps({**decoder_param}))
220 | self.acl_encoder_operation.set_weight(self.ascend_weight)
221 | self.acl_decoder_operation.set_weight(self.ascend_weight)
222 |
223 | if self.prefix_cache_enable:
224 | decoder_regression_param = {
225 | **coder_param, _IS_PREFILL: False, _SUPPORT_LCOC: False,
226 | _SUPPORT_SPECULATE: False
227 | }
228 | self.acl_decoder_regression_operation.set_param(json.dumps({**decoder_regression_param}))
229 | self.acl_decoder_regression_operation.set_weight(self.ascend_weight)
230 |
231 | # called by super().forward()
232 | def prepare_inputs_for_ascend(self, input_ids: torch.Tensor,
233 | position_ids: torch.Tensor,
234 | is_prefill: bool,
235 | kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
236 | block_tables: torch.Tensor,
237 | slots: torch.Tensor,
238 | input_lengths: torch.Tensor,
239 | max_seq_len: int,
240 | lm_head_indices: Optional[torch.Tensor] = None,
241 | **kwargs):
242 | self.rotary_embedding.update_cos_sin_cache_total(self.dtype,
243 | self.device,
244 | self.max_position_embeddings)
245 | self.cos_embed = self.rotary_embedding.get_cos_cached_total()
246 | self.sin_embed = self.rotary_embedding.get_sin_cached_total()
247 | q_lens = kwargs.get('q_lens', [])
248 | spec_mask = kwargs.get('spec_mask', None)
249 |
250 | input_length = len(input_ids)
251 | self.expert_array = self.placeholder
252 | if not self.enable_fused_routing:
253 | self.expert_array = torch.arange(self.config.num_experts_per_tok * input_length,
254 | dtype=torch.int32, device=input_ids.device)
255 |
256 | if lm_head_indices is None:
257 |             lm_head_indices = torch.arange(input_ids.shape[0],
258 | dtype=torch.int64, device=input_ids.device)
259 | if is_prefill:
260 | if self.soc_info.need_nz:
261 | pad_maxs = math.ceil(self.max_position_embeddings / 16) * 16
262 | atten_mask = self.attn_mask.get_attn_mask(pad_maxs, kv_cache[0][0].dtype,
263 | kv_cache[0][0].device)
264 | atten_mask = self.transdata_operation.execute([atten_mask])[0]
265 | else:
266 | atten_mask = self.attn_mask.get_attn_mask(max_seq_len if self.split_fuse_enable else self.max_base_len,
267 | self.dtype, self.device)
268 | if self.split_fuse_enable and self.dtype == torch.bfloat16:
269 | atten_mask = atten_mask * -10000.0
270 | self.acl_param = json.dumps({
271 | "seqLen": input_lengths.tolist(),
272 | "qLen": q_lens
273 | })
274 | self.acl_encoder_operation_inputs = [
275 | input_ids,
276 | position_ids.to(torch.int64),
277 | self.cos_embed,
278 | self.sin_embed,
279 | torch.where(atten_mask == -torch.inf, 1, atten_mask) if self.dtype == torch.bfloat16 else atten_mask,
280 | block_tables.to(torch.int32),
281 | slots.to(torch.int32),
282 | self.placeholder,
283 | self.placeholder,
284 | self.placeholder,
285 | input_lengths.to(torch.int32),
286 | lm_head_indices.to(torch.int64),
287 | self.expert_array,
288 | self.expert_group,
289 | self.one_hot,
290 | self.zero_hot
291 | ]
292 |
293 | if self.split_fuse_enable:
294 | self.acl_encoder_operation_inputs.append(torch.tensor(q_lens).to(self.device).to(torch.int32))
295 |
296 | return self.acl_encoder_operation_inputs, self.acl_param
297 | else:
298 | use_regression = False
299 | if self.prefix_cache_enable and q_lens == []:
300 | use_regression = True
301 | q_lens = []
302 |
303 | self.acl_param = json.dumps({
304 | "seqLen": input_lengths.tolist(),
305 | "qLen": q_lens
306 | })
307 | if self.prefix_cache_enable and use_regression:
308 | self.acl_decoder_regression_operation_inputs = [
309 | input_ids,
310 | position_ids.to(torch.int64),
311 | self.cos_embed,
312 | self.sin_embed,
313 | self.attn_mask_fake,
314 | block_tables.to(torch.int32),
315 | slots.to(torch.int32),
316 | self.placeholder,
317 | self.placeholder,
318 | self.placeholder,
319 | input_lengths.to(torch.int32),
320 | self.lm_head_indices_fake,
321 | self.expert_array,
322 | self.expert_group,
323 | self.one_hot,
324 | self.zero_hot,
325 | ]
326 | return self.acl_decoder_regression_operation_inputs, self.acl_param
327 |
328 | self.acl_decoder_operation_inputs = [
329 | input_ids,
330 | position_ids.to(torch.int64),
331 | self.cos_embed,
332 | self.sin_embed,
333 | spec_mask if self.speculate_enable or self.prefix_cache_enable else self.attn_mask_fake,
334 | block_tables.to(torch.int32),
335 | slots.to(torch.int32),
336 | self.placeholder,
337 | self.placeholder,
338 | self.placeholder,
339 | input_lengths.to(torch.int32),
340 | lm_head_indices.to(torch.int64) if self.prefix_cache_enable else self.lm_head_indices_fake,
341 | self.expert_array,
342 | self.expert_group,
343 | self.one_hot,
344 | self.zero_hot
345 | ]
346 |
347 | if self.split_fuse_enable or self.speculate_enable or self.prefix_cache_enable:
348 | self.acl_decoder_operation_inputs.append(torch.tensor(q_lens).to(self.device).to(torch.int32))
349 |
350 | return self.acl_decoder_operation_inputs, self.acl_param
351 |
352 | def execute_ascend_operator(self,
353 | acl_inputs,
354 | acl_param,
355 | is_prefill):
356 | if is_prefill:
357 | acl_model_out = self.acl_encoder_operation.execute(acl_inputs, acl_param)
358 | else:
359 | acl_param_dict = json.loads(acl_param)
360 | if self.prefix_cache_enable and acl_param_dict["qLen"] == []:
361 | model_operation = self.acl_decoder_regression_operation
362 | else:
363 | model_operation = self.acl_decoder_operation
364 | acl_model_out = model_operation.execute(acl_inputs, acl_param)
365 | try:
366 | acl_hidden_state = acl_model_out[0]
367 | except IndexError as e:
368 | msg = "Runtime Error, please refer to the logs for more info"
369 | logger.error(msg, ErrorCode.ATB_MODELS_EXECUTION_FAILURE)
370 | raise RuntimeError(msg) from e
371 | return acl_hidden_state
372 |
373 | def init_kvcache(self, kv_cache):
374 | kcache_id = not self.ascend_kcache_id or self.ascend_kcache_id != id(kv_cache[0][0])
375 | vcache_id = not self.ascend_vcache_id or self.ascend_vcache_id != id(kv_cache[0][1])
376 | if kcache_id or vcache_id:
377 | k_caches, v_caches = map(lambda x: list(x), zip(*kv_cache))
378 | if self.soc_info.need_nz:
379 | k_caches = [torch_npu.npu_format_cast_(k_cache, 29) for k_cache in k_caches]
380 | v_caches = [torch_npu.npu_format_cast_(v_cache, 29) for v_cache in v_caches]
381 | self.acl_encoder_operation.set_kv_cache(k_caches, v_caches)
382 | self.acl_decoder_operation.set_kv_cache(k_caches, v_caches)
383 | if self.prefix_cache_enable:
384 | self.acl_decoder_regression_operation.set_kv_cache(k_caches, v_caches)
385 | self.ascend_kcache_id = id(kv_cache[0][0])
386 | self.ascend_vcache_id = id(kv_cache[0][1])
387 |
388 | if __name__ == "__main__":
389 | test_config = DeepseekConfig()
390 |     test_weights = None  # placeholder: a real Weights object is required to actually build the model
391 | model = FlashDeepseekForCausalLM(test_config, test_weights)
--------------------------------------------------------------------------------
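
For orientation, the routing fields that `init_ascend_weight` above feeds into the ATB graph (`numOfSelectedExperts`, `routingMethod`, and `processLogits: 'normalization'` when `norm_topk_prob` is set) describe standard softmax top-k expert selection with optional renormalization of the selected weights. A small illustrative PyTorch sketch of that selection step, independent of the ATB operators (the function name and shapes are assumptions for illustration only):

    import torch

    def softmax_topk_routing(router_logits, top_k, norm_topk_prob):
        # router_logits: [num_tokens, num_experts]
        probs = torch.softmax(router_logits, dim=-1)
        weights, expert_ids = torch.topk(probs, top_k, dim=-1)   # top-k experts per token
        if norm_topk_prob:
            # corresponds to processLogits='normalization': rescale the kept weights to sum to 1
            weights = weights / weights.sum(dim=-1, keepdim=True)
        return weights, expert_ids

    logits = torch.randn(5, 64)                      # 5 tokens, 64 routed experts
    w, ids = softmax_topk_routing(logits, top_k=6, norm_topk_prob=True)
    print(w.shape, ids.shape)                        # torch.Size([5, 6]) torch.Size([5, 6])
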
/inference/vllm/bailing_moe.patch:
--------------------------------------------------------------------------------
1 | From 2e4640391b87ad5489e3383972f703d4f2814bb3 Mon Sep 17 00:00:00 2001
2 | From: "serina.wzq@antgroup.com"
3 | Date: Thu, 27 Feb 2025 11:29:42 +0800
4 | Subject: [PATCH] support BailingMoeForCausalLM
5 |
6 | ---
7 | vllm/model_executor/models/bailing_moe.py | 535 ++++++++++++++++++
8 | vllm/model_executor/models/registry.py | 1 +
9 | vllm/transformers_utils/configs/__init__.py | 3 +
10 | .../transformers_utils/configs/bailing_moe.py | 76 +++
11 | 4 files changed, 615 insertions(+)
12 | create mode 100644 vllm/model_executor/models/bailing_moe.py
13 | create mode 100644 vllm/transformers_utils/configs/bailing_moe.py
14 |
15 | diff --git a/vllm/model_executor/models/bailing_moe.py b/vllm/model_executor/models/bailing_moe.py
16 | new file mode 100644
17 | index 000000000..774580c94
18 | --- /dev/null
19 | +++ b/vllm/model_executor/models/bailing_moe.py
20 | @@ -0,0 +1,535 @@
21 | +# coding=utf-8
22 | +""" PyTorch Bailing model. """
23 | +
24 | +from typing import Iterable, List, Optional, Tuple, Union, Set
25 | +
26 | +import torch
27 | +from torch import nn
28 | +
29 | +from vllm.model_executor.layers.activation import get_act_fn, SiluAndMul
30 | +from vllm.attention import Attention, AttentionMetadata
31 | +from vllm.config import CacheConfig, VllmConfig
32 | +from vllm.model_executor.layers.fused_moe import fused_moe, FusedMoE
33 | +from vllm.model_executor.layers.layernorm import RMSNorm
34 | +from vllm.model_executor.layers.linear import (ColumnParallelLinear,
35 | + MergedColumnParallelLinear,
36 | + ReplicatedLinear,
37 | + QKVParallelLinear,
38 | + RowParallelLinear)
39 | +from vllm.model_executor.layers.quantization.base_config import (
40 | + QuantizationConfig)
41 | +from vllm.model_executor.layers.rotary_embedding import get_rope
42 | +from vllm.model_executor.layers.sampler import Sampler
43 | +from vllm.model_executor.layers.vocab_parallel_embedding import (
44 | + ParallelLMHead, VocabParallelEmbedding)
45 | +from vllm.distributed import (get_pp_group,
46 | + get_tensor_model_parallel_rank,
47 | + get_tensor_model_parallel_world_size,
48 | + tensor_model_parallel_all_reduce)
49 | +from vllm.model_executor.sampling_metadata import SamplingMetadata
50 | +from vllm.model_executor.model_loader.weight_utils import default_weight_loader
51 | +from vllm.model_executor.utils import set_weight_attrs
52 | +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
53 | +from vllm.sequence import IntermediateTensors
54 | +from vllm.transformers_utils.configs.bailing_moe import BailingMoeConfig
55 | +from vllm.model_executor.layers.logits_processor import LogitsProcessor
56 | +from vllm.config import LoRAConfig
57 | +
58 | +from .interfaces import SupportsLoRA, SupportsPP
59 | +from .utils import (PPMissingLayer,
60 | + is_pp_missing_parameter,
61 | + make_empty_intermediate_tensors_factory,
62 | + make_layers,
63 | + maybe_prefix)
64 | +
65 | +KVCache = Tuple[torch.Tensor, torch.Tensor]
66 | +
67 | +
68 | +class BailingAttention(nn.Module):
69 | +
70 | + def __init__(
71 | + self,
72 | + config: BailingMoeConfig,
73 | + cache_config: Optional[CacheConfig] = None,
74 | + quant_config: Optional[QuantizationConfig] = None,
75 | + prefix: str = "",
76 | + ):
77 | + super().__init__()
78 | + self.hidden_size = config.hidden_size
79 | + self.total_num_heads = config.num_attention_heads
80 | + self.total_kv_heads = config.num_key_value_heads
81 | + tp_size = get_tensor_model_parallel_world_size()
82 | +
83 | + assert self.total_num_heads % tp_size == 0
84 | + assert self.total_kv_heads % tp_size == 0
85 | + assert self.total_num_heads >= self.total_kv_heads
86 | +
87 | + self.num_heads = self.total_num_heads // tp_size
88 | + self.head_dim = config.head_dim or (self.hidden_size // self.total_num_heads)
89 | + self.q_size_per_rank = self.head_dim * self.num_heads
90 | +
91 | + self.num_kv_heads = self.total_kv_heads // tp_size
92 | + self.kv_size_per_rank = self.num_kv_heads * self.head_dim
93 | +
94 | + self.scale = self.head_dim ** -0.5
95 | +
96 | + self.query_key_value = QKVParallelLinear(
97 | + self.hidden_size,
98 | + self.head_dim,
99 | + self.total_num_heads,
100 | + self.total_kv_heads,
101 | + bias=(config.use_bias or config.use_qkv_bias),
102 | + quant_config=quant_config,
103 | + prefix=f"{prefix}.query_key_value",
104 | + )
105 | +
106 | + self.dense = RowParallelLinear(self.total_num_heads * self.head_dim,
107 | + self.hidden_size,
108 | + bias=config.use_bias,
109 | + quant_config=quant_config,
110 | + prefix=f"{prefix}.dense",)
111 | +
112 | + self.attn = Attention(self.num_heads,
113 | + self.head_dim,
114 | + self.scale,
115 | + num_kv_heads=self.num_kv_heads,
116 | + cache_config=cache_config,
117 | + prefix=f"{prefix}.attn")
118 | +
119 | +
120 | + self.rotary_emb = get_rope(
121 | + self.head_dim,
122 | + rotary_dim=self.head_dim,
123 | + max_position=config.max_position_embeddings,
124 | + base=config.rope_theta,
125 | + is_neox_style=True,
126 | + rope_scaling=config.rope_scaling,
127 | + )
128 | +
129 | + def forward(
130 | + self,
131 | + hidden_states: torch.Tensor,
132 | + position_ids: torch.Tensor,
133 | + kv_cache: KVCache,
134 | + attn_metadata: AttentionMetadata,
135 | + ) -> torch.Tensor:
136 | +
137 | + qkv, _ = self.query_key_value(hidden_states)
138 | + q, k, v = qkv.split(
139 | + [self.q_size_per_rank, self.kv_size_per_rank, self.kv_size_per_rank],
140 | + dim=-1
141 | + )
142 | +
143 | +
144 | + q, k = self.rotary_emb(position_ids, q, k)
145 | +
146 | + context_layer = self.attn(
147 | + q,
148 | + k,
149 | + v,
150 | + kv_cache,
151 | + attn_metadata,
152 | + )
153 | +
154 | + attn_output, _ = self.dense(context_layer)
155 | + return attn_output
156 | +
157 | +
158 | +class BailingMLP(nn.Module):
159 | +
160 | + def __init__(
161 | + self,
162 | + intermediate_size: int,
163 | + config: BailingMoeConfig,
164 | + quant_config: Optional[QuantizationConfig] = None,
165 | + reduce_results: Optional[bool] = True,
166 | + prefix: str = "",
167 | + ) -> None:
168 | + super().__init__()
169 | + self.gate_up_proj = MergedColumnParallelLinear(
170 | + config.hidden_size, [intermediate_size] * 2,
171 | + bias=config.use_bias,
172 | + quant_config=quant_config,
173 | + prefix=f"{prefix}.gate_up_proj",
174 | + )
175 | + self.down_proj = RowParallelLinear(
176 | + intermediate_size,
177 | + config.hidden_size,
178 | + bias=config.use_bias,
179 | + quant_config=quant_config,
180 | + reduce_results=reduce_results,
181 | + prefix=f"{prefix}.down_proj",
182 | + )
183 | + self.act_fn = SiluAndMul()
184 | +
185 | + def forward(self, x):
186 | + x, _ = self.gate_up_proj(x)
187 | + x = self.act_fn(x)
188 | + x, _ = self.down_proj(x)
189 | + return x
190 | +
191 | +class BailingMoE(nn.Module):
192 | +
193 | + def __init__(
194 | + self,
195 | + intermediate_size: int,
196 | + config: BailingMoeConfig,
197 | + quant_config: Optional[QuantizationConfig] = None,
198 | + reduce_results: Optional[bool] = True,
199 | + prefix: str = "",
200 | + ):
201 | + super().__init__()
202 | +
203 | + self.tp_size = get_tensor_model_parallel_world_size()
204 | + self.tp_rank = get_tensor_model_parallel_rank()
205 | + self.num_experts = config.num_experts
206 | + self.top_k = config.num_experts_per_tok
207 | + self.norm_expert_prob = config.norm_topk_prob
208 | + self.hidden_size = config.hidden_size
209 | + self.quant_config = quant_config
210 | + self.num_shared_experts = config.num_shared_experts
211 | + # Gate always runs at half / full precision for now.
212 | + self.gate = ReplicatedLinear(self.hidden_size,
213 | + self.num_experts,
214 | + bias=False,
215 | + quant_config=None)
216 | +
217 | + self.experts = FusedMoE(
218 | + num_experts=self.num_experts,
219 | + top_k=self.top_k,
220 | + hidden_size=self.hidden_size,
221 | + intermediate_size=config.moe_intermediate_size,
222 | + reduce_results=False,
223 | + renormalize=self.norm_expert_prob,
224 | + quant_config=quant_config,
225 | + prefix=f"{prefix}.experts"
226 | + )
227 | +
228 | + if self.num_shared_experts > 0:
229 | + intermediate_size = (config.moe_intermediate_size *
230 | + self.num_shared_experts)
231 | + self.shared_experts = BailingMLP(
232 | + intermediate_size=intermediate_size,
233 | + config=config,
234 | + quant_config=quant_config,
235 | + reduce_results=False,
236 | + prefix=f"{prefix}.shared_experts"
237 | + )
238 | +
239 | + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
240 | + num_tokens, hidden_size = hidden_states.shape
241 | + hidden_states = hidden_states.view(-1, hidden_size)
242 | + if self.num_shared_experts > 0:
243 | + shared_output = self.shared_experts(hidden_states)
244 | + # router_logits: (num_tokens, n_experts)
245 | + router_logits, _ = self.gate(hidden_states)
246 | + final_hidden_states = self.experts(
247 | + hidden_states=hidden_states, router_logits=router_logits
248 | + )
249 | +
250 | + if self.num_shared_experts > 0:
251 | + final_hidden_states = final_hidden_states + shared_output
252 | +
253 | + if self.tp_size > 1:
254 | + final_hidden_states = tensor_model_parallel_all_reduce(
255 | + final_hidden_states)
256 | + return final_hidden_states.view(num_tokens, hidden_size)
257 | +
258 | +class BailingMoeBlock(nn.Module):
259 | +
260 | + def __init__(
261 | + self,
262 | + config: BailingMoeConfig,
263 | + cache_config: Optional[CacheConfig] = None,
264 | + quant_config: Optional[QuantizationConfig] = None,
265 | + prefix: str = "",
266 | + ):
267 | + super().__init__()
268 | + hidden_size = config.hidden_size
269 | + intermediate_size = config.intermediate_size
270 | + self.input_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps)
271 | + self.attention = BailingAttention(config,
272 | + cache_config,
273 | + quant_config,
274 | + prefix=f"{prefix}.attention")
275 | + self.post_attention_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps)
276 | + self.mlp = BailingMoE(intermediate_size, config, quant_config, True, prefix=f"{prefix}.mlp")
277 | +
278 | + def forward(
279 | + self,
280 | + hidden_states: torch.Tensor,
281 | + position_ids: torch.Tensor,
282 | + kv_cache: KVCache,
283 | + attn_metadata: AttentionMetadata,
284 | + residual: Optional[torch.Tensor],
285 | +    ) -> Tuple[torch.Tensor, torch.Tensor]:
286 | + if residual is None:
287 | + residual = hidden_states
288 | + hidden_states = self.input_layernorm(hidden_states)
289 | + else:
290 | + hidden_states, residual = self.input_layernorm(
291 | + hidden_states, residual)
292 | +
293 | + hidden_states = self.attention(
294 | + hidden_states=hidden_states,
295 | + position_ids=position_ids,
296 | + kv_cache=kv_cache,
297 | + attn_metadata=attn_metadata
298 | + )
299 | +
300 | + hidden_states, residual = self.post_attention_layernorm(
301 | + hidden_states, residual)
302 | + hidden_states = self.mlp(hidden_states)
303 | + return hidden_states, residual
304 | +
305 | +
306 | +class BailingMoeModel(nn.Module):
307 | +
308 | + def __init__(
309 | + self,
310 | + *,
311 | + vllm_config: VllmConfig,
312 | + prefix: str = "",
313 | + ):
314 | + super().__init__()
315 | + config = vllm_config.model_config.hf_config
316 | + cache_config = vllm_config.cache_config
317 | + quant_config = vllm_config.quant_config
318 | +
319 | + self.config = config
320 | + self.vocab_size = config.vocab_size
321 | + self.embed_dim = config.hidden_size
322 | +
323 | + if get_pp_group().is_first_rank or (config.tie_word_embeddings
324 | + and get_pp_group().is_last_rank):
325 | + self.word_embeddings = VocabParallelEmbedding(self.vocab_size, self.embed_dim)
326 | + else:
327 | + self.word_embeddings = PPMissingLayer()
328 | +
329 | + self.embedding_dropout = torch.nn.Dropout(config.embedding_dropout)
330 | +
331 | + self.start_layer, self.end_layer, self.layers = make_layers(
332 | + config.num_hidden_layers,
333 | + lambda prefix: BailingMoeBlock(
334 | + config=config,
335 | + cache_config=cache_config,
336 | + quant_config=quant_config,
337 | + prefix=prefix,
338 | + ),
339 | + prefix=f"{prefix}.layers"
340 | + )
341 | +
342 | + self.make_empty_intermediate_tensors = (
343 | + make_empty_intermediate_tensors_factory(
344 | + ["hidden_states", "residual"], config.hidden_size
345 | + )
346 | + )
347 | +
348 | + if get_pp_group().is_last_rank:
349 | + self.norm = RMSNorm(self.embed_dim, eps=config.rms_norm_eps)
350 | + else:
351 | + self.norm = PPMissingLayer()
352 | +
353 | + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
354 | + return self.word_embeddings(input_ids)
355 | +
356 | + def forward(
357 | + self,
358 | + input_ids: torch.Tensor,
359 | + position_ids: torch.Tensor,
360 | + kv_caches: List[KVCache],
361 | + attn_metadata: AttentionMetadata,
362 | + intermediate_tensors: Optional[IntermediateTensors],
363 | + inputs_embeds: Optional[torch.Tensor] = None,
364 | + ) -> Union[torch.Tensor, IntermediateTensors]:
365 | + if get_pp_group().is_first_rank:
366 | + if inputs_embeds is not None:
367 | + hidden_states = inputs_embeds
368 | + else:
369 | + hidden_states = self.get_input_embeddings(input_ids)
370 | + residual = None
371 | + else:
372 | + assert intermediate_tensors is not None
373 | + hidden_states = intermediate_tensors["hidden_states"]
374 | + residual = intermediate_tensors["residual"]
375 | +
376 | + for i in range(self.start_layer, self.end_layer):
377 | + layer = self.layers[i]
378 | + hidden_states, residual = layer(
379 | + hidden_states,
380 | + position_ids,
381 | + kv_caches[i - self.start_layer],
382 | + attn_metadata,
383 | + residual
384 | + )
385 | +
386 | + if not get_pp_group().is_last_rank:
387 | + return IntermediateTensors({
388 | + "hidden_states": hidden_states,
389 | + "residual": residual
390 | + })
391 | +
392 | + hidden_states, _ = self.norm(hidden_states, residual)
393 | + return hidden_states
394 | +
395 | +
396 | +class BailingMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
397 | +
398 | + packed_modules_mapping = {
399 | + "query_key_value": ["query_key_value"],
400 | + "dense_h_to_4h": ["dense_h_to_4h"],
401 | + "gate_up_proj": [
402 | + "gate_proj",
403 | + "up_proj",
404 | + ],
405 | + }
406 | +
407 | + # LoRA specific attributes
408 | + supported_lora_modules = [
409 | + "query_key_value",
410 | + "dense",
411 | + "dense_h_to_4h",
412 | + "dense_4h_to_h",
413 | + "gate_up_proj",
414 | + "down_proj",
415 | + ]
416 | + embedding_modules = {}
417 | + embedding_padding_modules = []
418 | +
419 | + def __init__(
420 | + self,
421 | + *,
422 | + vllm_config: VllmConfig,
423 | + prefix: str = "",
424 | + ) -> None:
425 | + super().__init__()
426 | +
427 | + config = vllm_config.model_config.hf_config
428 | + quant_config = vllm_config.quant_config
429 | + lora_config = vllm_config.lora_config
430 | +
431 | + self.config = config
432 | + self.lora_config = lora_config
433 | + self.quant_config = quant_config
434 | + self.max_position_embeddings = config.max_position_embeddings
435 | + self.model = BailingMoeModel(
436 | + vllm_config=vllm_config,
437 | + prefix=maybe_prefix(prefix, "model")
438 | + )
439 | + if get_pp_group().is_last_rank:
440 | +            self.lm_head = self.model.word_embeddings if config.tie_word_embeddings \
441 | + else ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config)
442 | + self.logits_processor = LogitsProcessor(config.vocab_size)
443 | + else:
444 | + self.lm_head = PPMissingLayer()
445 | +
446 | + self.sampler = get_sampler()
447 | + self.make_empty_intermediate_tensors = (
448 | + self.model.make_empty_intermediate_tensors
449 | + )
450 | +
451 | + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
452 | + return self.model.get_input_embeddings(input_ids)
453 | +
454 | + def forward(
455 | + self,
456 | + input_ids: torch.Tensor,
457 | + positions: torch.Tensor,
458 | + kv_caches: List[KVCache],
459 | + attn_metadata: AttentionMetadata,
460 | + intermediate_tensors: Optional[IntermediateTensors] = None,
461 | + inputs_embeds: Optional[torch.Tensor] = None,
462 | + ) -> Union[torch.Tensor, IntermediateTensors]:
463 | + model_output = self.model(input_ids, positions, kv_caches,
464 | + attn_metadata, intermediate_tensors,
465 | + inputs_embeds)
466 | + return model_output
467 | +
468 | + def compute_logits(
469 | + self,
470 | + hidden_states: torch.Tensor,
471 | + sampling_metadata: SamplingMetadata,
472 | + ) -> Optional[torch.Tensor]:
473 | + logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata)
474 | + return logits
475 | +
476 | + def sample(
477 | + self,
478 | + logits: torch.Tensor,
479 | + sampling_metadata: SamplingMetadata,
480 | + ) -> Optional[SamplerOutput]:
481 | + next_tokens = self.sampler(logits, sampling_metadata)
482 | + return next_tokens
483 | +
484 | + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
485 | + stacked_params_mapping = [
486 | + # (param_name, shard_name, shard_id)
487 | + ("gate_up_proj", "gate_proj", 0),
488 | + ("gate_up_proj", "up_proj", 1),
489 | + ]
490 | + expert_params_mapping = FusedMoE.make_expert_params_mapping(
491 | + ckpt_gate_proj_name="gate_proj",
492 | + ckpt_down_proj_name="down_proj",
493 | + ckpt_up_proj_name="up_proj",
494 | + num_experts=self.config.num_experts)
495 | +
496 | + params_dict = dict(self.named_parameters(remove_duplicate=False))
497 | + loaded_params: Set[str] = set()
498 | + for name, loaded_weight in weights:
499 | + if (("v_head" in name) or ("inv_freq" in name) or
500 | + (self.config.tie_word_embeddings and "lm_head" in name)):
501 | + continue
502 | + if self.config.norm_head and "lm_head.weight" in name:
503 | + import torch.nn.functional as F
504 | + loaded_weight = F.normalize(loaded_weight, dim=0, p=2, eps=1e-7)
505 | +
506 | + for (param_name, weight_name, shard_id) in stacked_params_mapping:
507 | + if weight_name not in name:
508 | + continue
509 | + if "mlp.experts" in name:
510 | + continue
511 | + name = name.replace(weight_name, param_name)
512 | + # Skip loading extra bias for GPTQ models.
513 | + if name.endswith(".bias") and name not in params_dict:
514 | + continue
515 | + if name not in params_dict:
516 | + continue
517 | +
518 | + if is_pp_missing_parameter(name, self):
519 | + continue
520 | +
521 | + param = params_dict[name]
522 | + weight_loader = param.weight_loader
523 | + weight_loader(param, loaded_weight, shard_id)
524 | + break
525 | + else:
526 | + for mapping in expert_params_mapping:
527 | + param_name, weight_name, expert_id, shard_id = mapping
528 | + if weight_name not in name:
529 | + continue
530 | + name = name.replace(weight_name, param_name)
531 | +
532 | + if is_pp_missing_parameter(name, self):
533 | + continue
534 | + param = params_dict[name]
535 | + weight_loader = param.weight_loader
536 | + weight_loader(param,
537 | + loaded_weight,
538 | + name,
539 | + shard_id=shard_id,
540 | + expert_id=expert_id)
541 | + break
542 | + else:
543 | + if name.endswith(".bias") and name not in params_dict:
544 | + continue
545 | + if name not in params_dict:
546 | + continue
547 | +
548 | + if is_pp_missing_parameter(name, self):
549 | + continue
550 | +
551 | + param = params_dict[name]
552 | + weight_loader = getattr(param, "weight_loader", default_weight_loader)
553 | + weight_loader(param, loaded_weight)
554 | + loaded_params.add(name)
555 | + return loaded_params
556 | diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
557 | index 81623defd..8e434f624 100644
558 | --- a/vllm/model_executor/models/registry.py
559 | +++ b/vllm/model_executor/models/registry.py
560 | @@ -39,6 +39,7 @@ _TEXT_GENERATION_MODELS = {
561 | "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
562 | "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
563 | "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
564 | + "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
565 | "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
566 | "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
567 | "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
568 | diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py
569 | index 906056559..45fab6fcf 100644
570 | --- a/vllm/transformers_utils/configs/__init__.py
571 | +++ b/vllm/transformers_utils/configs/__init__.py
572 | @@ -23,6 +23,8 @@ from vllm.transformers_utils.configs.olmo2 import Olmo2Config
573 | from vllm.transformers_utils.configs.solar import SolarConfig
574 | from vllm.transformers_utils.configs.telechat2 import Telechat2Config
575 | from vllm.transformers_utils.configs.ultravox import UltravoxConfig
576 | +from vllm.transformers_utils.configs.bailing_moe import BailingMoeConfig
577 | +
578 |
579 | __all__ = [
580 | "ChatGLMConfig",
581 | @@ -45,4 +47,5 @@ __all__ = [
582 | "SolarConfig",
583 | "Telechat2Config",
584 | "UltravoxConfig",
585 | + "BailingMoeConfig",
586 | ]
587 | diff --git a/vllm/transformers_utils/configs/bailing_moe.py b/vllm/transformers_utils/configs/bailing_moe.py
588 | new file mode 100644
589 | index 000000000..4379368cf
590 | --- /dev/null
591 | +++ b/vllm/transformers_utils/configs/bailing_moe.py
592 | @@ -0,0 +1,76 @@
593 | +""" Bailing MoE model configuration """
594 | +
595 | +from transformers.configuration_utils import PretrainedConfig
596 | +
597 | +
598 | +class BailingMoeConfig(PretrainedConfig):
599 | + model_type = "bailing_moe"
600 | +
601 | + def __init__(
602 | + self,
603 | + vocab_size=30592,
604 | + hidden_size=1024,
605 | + intermediate_size=None,
606 | + num_hidden_layers=24,
607 | + num_attention_heads=16,
608 | + num_key_value_heads=0,
609 | + hidden_act="silu",
610 | + use_qkv_bias=False, # bailing only
611 | + use_bias=True, # bailing only
612 | + rms_norm_eps=1e-05,
613 | + norm_head=False, # bailing only
614 | + tie_word_embeddings=False, # PretrainedConfig key, here change default value.
615 | + embedding_dropout=0.1,
616 | + attention_dropout=0.1,
617 | + output_dropout=0.1,
618 | + initializer_range=0.02,
619 | + max_position_embeddings=16384,
620 | + rope_theta=10000.0,
621 | + use_cache=True,
622 | + use_sliding_window=False,
623 | + sliding_window=4096,
624 | + max_window_layers=28,
625 | + rope_scaling=None,
626 | + pad_token_id=126081,
627 | + num_experts=16,
628 | + num_shared_experts=0,
629 | + num_experts_per_tok=2,
630 | + norm_topk_prob=True,
631 | + moe_intermediate_size=None,
632 | + first_k_dense_replace=0,
633 | + head_dim=None,
634 | + **kwargs,
635 | + ):
636 | + self.num_hidden_layers = num_hidden_layers
637 | + self.vocab_size = vocab_size
638 | + self.hidden_size = hidden_size
639 | + self.intermediate_size = intermediate_size
640 | + self.num_attention_heads = num_attention_heads
641 | + self.num_key_value_heads = num_key_value_heads
642 | + self.hidden_act = hidden_act
643 | + self.use_qkv_bias = use_qkv_bias
644 | + self.use_bias = use_bias
645 | + self.norm_head = norm_head
646 | + self.rms_norm_eps = rms_norm_eps
647 | + self.embedding_dropout = embedding_dropout
648 | + self.attention_dropout = attention_dropout
649 | + self.output_dropout = output_dropout
650 | + self.initializer_range = initializer_range
651 | + self.max_position_embeddings = max_position_embeddings
652 | + self.rope_theta = rope_theta
653 | + self.use_cache = use_cache
654 | + self.use_sliding_window = use_sliding_window
655 | + self.sliding_window = sliding_window
656 | + self.max_window_layers = max_window_layers
657 | + self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
658 | + self.rope_scaling = rope_scaling
659 | +
660 | + # MoE configs
661 | + self.num_experts = num_experts
662 | + self.num_shared_experts = num_shared_experts
663 | + self.num_experts_per_tok = num_experts_per_tok
664 | + self.norm_topk_prob = norm_topk_prob
665 | + self.moe_intermediate_size = moe_intermediate_size
666 | + self.first_k_dense_replace = first_k_dense_replace
667 | +
668 | + super().__init__(pad_token_id=pad_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs)
669 | --
670 | 2.39.2 (Apple Git-143)
671 |
672 |
--------------------------------------------------------------------------------
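
Once this patch is applied to a vLLM source tree (for example with `git apply inference/vllm/bailing_moe.patch`), the registry entry it adds makes the `BailingMoeForCausalLM` architecture resolvable like any built-in model. A hedged usage sketch; the checkpoint path is a placeholder and the sampling settings are arbitrary:

    from vllm import LLM, SamplingParams

    # Placeholder path: point this at a checkpoint whose config.json declares
    # "architectures": ["BailingMoeForCausalLM"] and "model_type": "bailing_moe".
    llm = LLM(model="/path/to/bailing-moe-checkpoint", trust_remote_code=True)

    outputs = llm.generate(
        ["Write a Python function that reverses a string."],
        SamplingParams(temperature=0.2, max_tokens=128),
    )
    print(outputs[0].outputs[0].text)
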
/inference/mindie/atb_llm/utils-weights.py:
--------------------------------------------------------------------------------
1 | import json
2 | import math
3 | import os
4 | from typing import List, Dict, Optional, Tuple
5 |
6 | import torch
7 | from safetensors import safe_open, SafetensorError
8 |
9 | from .hub import weight_files
10 | from .log import logger, print_log
11 | from .quantize.quant_type import QuantType, LinearTypeV2, QUANTIZE_DESC_REQUIRED_LIST
12 | from . import file_utils
13 |
14 | QUANTIZE_DTYPE_LIST = [torch.int8, torch.int32, torch.int64]
15 |
16 |
17 | class Weights:
18 | def __init__(
19 | self,
20 | model_name_or_path,
21 | device,
22 | dtype,
23 | process_group,
24 | quantize=None,
25 | extension: Optional[str] = ".safetensors",
26 | aliases: Optional[Dict[str, List[str]]] = None,
27 | **kwargs
28 | ):
29 | if quantize == QuantType.W8A8SC:
30 | model_name_or_path = os.path.join(model_name_or_path,
31 | f'part{process_group.rank()}-of-{process_group.size()}'
32 | )
33 | model_name_or_path = file_utils.standardize_path(model_name_or_path, check_link=False)
34 | file_utils.check_path_permission(model_name_or_path)
35 | self.filenames = weight_files(model_name_or_path, extension=extension)
36 | self.quantize = quantize
37 | routing = self.load_routing(process_group)
38 | if aliases is None:
39 | aliases = {}
40 | self.aliases = aliases
41 | self.routing = routing
42 | self.device = device
43 | self.dtype = dtype
44 | self.process_group = process_group
45 | self._handles = {}
46 | self.gptq_bits = None
47 | self.gptq_groupsize = None
48 | self.quant_desc = None
49 |
50 | self.init_quant_params(quantize, model_name_or_path)
51 |
52 | def release_file_handler(self):
53 | del self._handles
54 | self._handles = {}
55 |
56 | def load_routing(self, process_group):
57 | routing = {}
58 | for filename in self.filenames:
59 | filename = file_utils.standardize_path(str(filename), check_link=False)
60 | file_utils.check_path_permission(filename)
61 | with safe_open(filename, framework="pytorch") as f:
62 | for k in f.keys():
63 | if k in routing:
64 | print_log(
65 | process_group.rank(),
66 | logger.error,
67 | f"Key {k} was found in multiple files: {filename} and {routing[k]}",
68 | need_filter=True
69 | )
70 | raise AssertionError
71 | routing[k] = filename
72 | return routing
73 |
74 | def get_linear_quant_type(self, key):
75 | if self.quant_desc is None:
76 | return LinearTypeV2.FLOAT16 if self.dtype == torch.float16 else LinearTypeV2.BFLOAT16
77 | if self.quant_desc.get(key, LinearTypeV2.INVALID) == "FLOAT":
78 | return LinearTypeV2.FLOAT16 if self.dtype == torch.float16 else LinearTypeV2.BFLOAT16
79 | return LinearTypeV2[self.quant_desc.get(key, LinearTypeV2.INVALID)]
80 |
81 | def correct_tensor_dtype(self, tensor, tensor_name):
82 | if tensor_name.endswith("deq_scale") and self.dtype == torch.bfloat16:
83 |             # In the BF16 case the deq_scale field is stored as FP32, so keep its dtype unchanged
84 | return tensor
85 | if tensor.dtype not in [torch.int8, torch.int32, torch.int64]:
86 | tensor = tensor.to(dtype=self.dtype)
87 | return tensor
88 |
89 | def init_quant_params(self, quantize, model_name_or_path):
90 | if quantize in QUANTIZE_DESC_REQUIRED_LIST:
91 | self._set_quant_params(model_name_or_path)
92 |
93 |     def get_filename(self, tensor_name: str) -> Tuple[str, str]:
94 | filename = self.routing.get(tensor_name, None)
95 | if filename is None:
96 | aliases = self.aliases.get(tensor_name, [])
97 | for alias in aliases:
98 | filename = self.routing.get(alias, None)
99 | if filename is not None:
100 | return str(filename), alias
101 | raise AssertionError(f"weight {tensor_name} does not exist")
102 | return str(filename), tensor_name
103 |
104 | def get_shape(self, tensor_name: str):
105 | return self._get_slice(tensor_name).get_shape()
106 |
107 | def get_tensor(self, tensor_name: str):
108 | filename, tensor_name = self.get_filename(tensor_name)
109 | f = self._get_handle(filename)
110 | tensor = f.get_tensor(tensor_name)
111 | return self.correct_tensor_dtype(tensor, tensor_name)
112 |
113 | def get_whole_tensor(self, tensor_name: str, dim: int):
114 | slice_ = self._get_slice(tensor_name)
115 |
116 | start = 0
117 | stop = slice_.get_shape()[dim]
118 |
119 | if dim == 0:
120 | tensor = slice_[start:stop]
121 | elif dim == 1:
122 | tensor = slice_[:, start:stop]
123 | else:
124 | logger.error("Let's make that generic when needed")
125 | raise AssertionError
126 | return self.correct_tensor_dtype(tensor, tensor_name)
127 |
128 | def get_partial_sharded_mid_dim(self, tensor_name: str, dim: int, index: int = 1, gqa_size: int = 1):
129 |
130 | world_size = self.process_group.size()
131 | rank = self.process_group.rank()
132 |
133 | slice_ = self._get_slice(tensor_name)
134 | size = slice_.get_shape()[dim]
135 |
136 | block_size = size // 16
137 | start = block_size * index + rank * block_size // world_size
138 | stop = block_size * index + (rank + 1) * block_size // world_size
139 |
140 | if dim == 0:
141 | tensor = slice_[start:stop, :]
142 | elif dim == 1:
143 | tensor = slice_[:, start:stop]
144 | else:
145 | logger.error("Let's make that generic when needed")
146 | raise AssertionError
147 | return self.correct_tensor_dtype(tensor, tensor_name)
148 |
149 | def get_partial_sharded(self, tensor_name: str, dim: int, gqa_size: int = 1):
150 | world_size = self.process_group.size()
151 | rank = self.process_group.rank()
152 |
153 | slice_ = self._get_slice(tensor_name)
154 | size = slice_.get_shape()[dim]
155 | group_size = size // gqa_size
156 | if group_size >= world_size:
157 | block_size = size // world_size
158 | start = rank * block_size
159 | stop = (rank + 1) * block_size
160 | else:
161 | block_size = gqa_size
162 | start = (rank // (world_size // group_size)) * block_size
163 | stop = ((rank // (world_size // group_size)) + 1) * block_size
164 |
165 | if "c_attn.bias" in tensor_name:
166 | b = slice_[:]
167 | single_size = b.shape[0] // 3
168 | head_size = 128
169 | head_num = single_size // head_size
170 | rank_heads = math.ceil(head_num / world_size)
171 | if rank != world_size - 1:
172 | start = rank * (rank_heads * head_size)
173 | stop = (rank + 1) * (rank_heads * head_size)
174 | bq = slice_[start:stop]
175 | bk = slice_[start + single_size:stop + single_size]
176 | bv = slice_[start + 2 * single_size:stop + 2 * single_size]
177 | else:
178 | # last rank
179 | start = rank * (rank_heads * head_size)
180 | stop = head_num * head_size
181 | bq = slice_[start:stop]
182 | bk = slice_[start + single_size:stop + single_size]
183 | bv = slice_[start + 2 * single_size:stop + 2 * single_size]
184 | b_ = torch.cat([bq, bk, bv], dim=0)
185 | return b_
186 |
187 | if dim == 0:
188 | tensor = slice_[start:stop]
189 | elif dim == 1:
190 | tensor = slice_[:, start:stop]
191 | else:
192 | logger.error("Let's make that generic when needed")
193 | raise AssertionError
194 | return self.correct_tensor_dtype(tensor, tensor_name)
195 |
196 | def get_partial_sharded_padding(self, tensor_name: str, dim: int, gqa_size=1):
197 | world_size = self.process_group.size()
198 | rank = self.process_group.rank()
199 |
200 | slice_ = self._get_slice(tensor_name)
201 | size = slice_.get_shape()[dim]
202 |
203 | head_num = size // gqa_size
204 | block_head_num = (head_num + world_size - 1) // world_size
205 |
206 | block_size = block_head_num * gqa_size
207 |
208 | start = rank * block_size
209 | stop = (rank + 1) * block_size
210 |
211 | if rank != world_size - 1:
212 | if dim == 0:
213 | tensor = slice_[start:stop]
214 | elif dim == 1:
215 | tensor = slice_[:, start:stop]
216 | else:
217 | logger.error("Let's make that generic when needed")
218 | raise AssertionError
219 | else:
220 | if dim == 0:
221 | tensor = slice_[start:]
222 | elif dim == 1:
223 | tensor = slice_[:, start:]
224 | else:
225 | logger.error("Let's make that generic when needed")
226 | raise AssertionError
227 |
228 | if len(tensor.shape) == 1:
229 | tensor_zeros = torch.zeros(size=(block_size,), dtype=tensor.dtype, device=tensor.device)
230 | tensor_zeros[:tensor.shape[0]] = tensor
231 | tensor = tensor_zeros
232 | else:
233 | dim0, dim1 = tensor.shape
234 | if dim == 0:
235 | dim0 = block_size
236 | else:
237 | dim1 = block_size
238 | tensor_zeros = torch.zeros(size=(dim0, dim1), dtype=tensor.dtype, device=tensor.device)
239 | tensor_zeros[:tensor.shape[0], :tensor.shape[1]] = tensor
240 | tensor = tensor_zeros
241 |
242 | return self.correct_tensor_dtype(tensor, tensor_name)
243 |
244 | def get_sharded(self, tensor_name: str, dim: int, gqa_size: int = 1):
245 | slice_ = self._get_slice(tensor_name)
246 | world_size = self.process_group.size()
247 | size = slice_.get_shape()[dim]
248 | if (size // gqa_size) % world_size == 0 or world_size % (size // gqa_size) == 0:
249 | return self.get_partial_sharded(tensor_name, dim, gqa_size)
250 | else:
251 | return self.get_partial_sharded_padding(tensor_name, dim, gqa_size)
252 |
253 | def get_per_tensor_sharded(self, prefixes, dim, tensor_name):
254 | tensor = torch.cat(
255 | [self.get_whole_tensor(f"{p}.{tensor_name}", dim=0) for p in prefixes], dim=dim
256 | )
257 | if torch.allclose(tensor, tensor[0]):
258 | tensor = tensor[:1]
259 | else:
260 | raise ValueError(f"`{tensor_name}` are not equal: {tensor}")
261 | return tensor
262 |
263 | def get_tensor_col_packed_qkv_mha(self, tensor_name: str, head_size: int = None, dim=0):
264 | slice_ = self._get_slice(tensor_name)
265 | total_size = slice_.get_shape()[-1 if dim == 1 else 0]
266 | if total_size % 3 != 0:
267 | raise ValueError("Prepacked qkv is not divisible by 3")
268 | single_size = total_size // 3
269 | world_size = self.process_group.size()
270 | rank = self.process_group.rank()
271 | if dim == 1:
272 | if head_size is None:
273 | if single_size % world_size != 0:
274 | raise RuntimeError(f"Prepacked qkv cannot be sharded across {world_size} shards")
275 | try:
276 | block_size = single_size // world_size
277 | except ZeroDivisionError as e:
278 | raise ZeroDivisionError from e
279 | start = rank * block_size
280 | stop = (rank + 1) * block_size
281 | if len(slice_.get_shape()) <= 1:
282 | q = slice_[start:stop]
283 | k = slice_[start + single_size:stop + single_size]
284 | v = slice_[start + 2 * single_size:stop + 2 * single_size]
285 | tensor = torch.cat([q, k, v], dim=0)
286 | else:
287 | q = slice_[:, start:stop]
288 | k = slice_[:, start + single_size:stop + single_size]
289 | v = slice_[:, start + 2 * single_size:stop + 2 * single_size]
290 | tensor = torch.cat([q, k, v], dim=1)
291 | else:
292 |                 raise ValueError("Prepacked qkv with an explicit head_size is not supported for dim=1")
293 | else:
294 | if head_size is None:
295 | if single_size % world_size != 0:
296 | raise RuntimeError(f"Prepacked qkv cannot be sharded across {world_size} shards")
297 | try:
298 | block_size = single_size // world_size
299 | except ZeroDivisionError as e:
300 | raise ZeroDivisionError from e
301 | start = rank * block_size
302 | stop = (rank + 1) * block_size
303 | q = slice_[start:stop]
304 | k = slice_[start + single_size:stop + single_size]
305 | v = slice_[start + 2 * single_size:stop + 2 * single_size]
306 | tensor = torch.cat([q, k, v], dim=0)
307 | else:
308 | try:
309 | head_num = single_size // head_size
310 | rank_heads = math.ceil(head_num / world_size)
311 | except ZeroDivisionError as e:
312 | raise ZeroDivisionError from e
313 | if rank != world_size - 1:
314 | start = rank * (rank_heads * head_size)
315 | stop = (rank + 1) * (rank_heads * head_size)
316 | q = slice_[start:stop]
317 | k = slice_[start + single_size:stop + single_size]
318 | v = slice_[start + 2 * single_size:stop + 2 * single_size]
319 | tensor = torch.cat([q, k, v], dim=0)
320 | else:
321 | # last rank
322 | start = rank * (rank_heads * head_size)
323 | stop = head_num * head_size
324 | q = slice_[start:stop]
325 | k = slice_[start + single_size:stop + single_size]
326 | v = slice_[start + 2 * single_size:stop + 2 * single_size]
327 |
328 | # padding
329 | q_zero = torch.zeros(size=(rank_heads * head_size, slice_.get_shape()[1]))
330 | k_zero = torch.zeros(size=(rank_heads * head_size, slice_.get_shape()[1]))
331 | v_zero = torch.zeros(size=(rank_heads * head_size, slice_.get_shape()[1]))
332 | q_zero[:q.shape[0], :q.shape[1]] = q
333 | k_zero[:k.shape[0], :k.shape[1]] = k
334 | v_zero[:v.shape[0], :v.shape[1]] = v
335 | tensor = torch.cat([q_zero, k_zero, v_zero], dim=0)
336 | return self.correct_tensor_dtype(tensor, tensor_name)
337 |
338 | def get_tensor_col_packed_o_gqa(self, tensor_name: str, hidden_size, num_heads, num_kv_heads):
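# Shards the attention output projection along its input (head) dimension when the heads do not divide evenly; ranks where rank % group_rank == 0 take one fewer head and are zero-padded to the common width.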
339 | num_o_heads = num_heads
340 | head_size = hidden_size // num_heads
341 |
342 | slice_ = self.get_tensor(tensor_name)
343 | world_size = self.process_group.size()
344 | rank = self.process_group.rank()
345 | num_heads = math.ceil(num_heads / world_size)
346 | odd_rank_hidden_size = head_size * (num_heads - 1)
347 | even_rank_hidden_size = head_size * num_heads
348 | shape = list(slice_.shape)
349 | shape[0] = head_size
350 | group_rank = world_size // (num_heads * world_size - num_o_heads)
351 | padding_zero = torch.zeros(shape, dtype=slice_.dtype, device=slice_.device)
352 | if rank % group_rank == 0:
353 | start = (rank // group_rank) * ((group_rank - 1) * even_rank_hidden_size + odd_rank_hidden_size)
354 | indices = torch.arange(start, start + odd_rank_hidden_size, dtype=torch.int32)
355 | part_tensor = torch.index_select(slice_, 1, indices)
356 | part_tensor = torch.cat((part_tensor, padding_zero.T), dim=1)
357 | else:
358 | start = (rank // group_rank) * ((group_rank - 1) * even_rank_hidden_size + odd_rank_hidden_size) + (
359 | (rank % group_rank - 1) * even_rank_hidden_size + odd_rank_hidden_size)
360 | start = int(start)
361 | indices = torch.arange(start, start + even_rank_hidden_size, dtype=torch.int32)
362 | part_tensor = torch.index_select(slice_, 1, indices)
363 | return part_tensor
364 |
365 | def get_tensor_col_packed_q_gqa(self, tensor_name: str, hidden_size, num_heads, num_kv_heads):
366 | num_q_heads = num_heads
367 | head_size = hidden_size // num_heads
368 |
369 | slice_ = self.get_tensor(tensor_name)
370 | world_size = self.process_group.size()
371 | num_heads = math.ceil(num_heads / world_size)
372 | rank = self.process_group.rank()
373 | odd_rank_hidden_size = head_size * (num_heads - 1)
374 | even_rank_hidden_size = head_size * num_heads
375 | shape = list(slice_.shape)
376 | shape[0] = head_size
377 | group_rank = world_size // (num_heads * world_size - num_q_heads)
378 | padding_zero = torch.zeros(shape, dtype=slice_.dtype, device=slice_.device)
379 | if rank % group_rank == 0:
380 | start = (rank // group_rank) * ((group_rank - 1) * even_rank_hidden_size + odd_rank_hidden_size)
381 | indices = torch.arange(start, start + odd_rank_hidden_size, dtype=torch.int32)
382 | part_tensor = torch.index_select(slice_, 0, indices)
383 | part_tensor = torch.cat((part_tensor, padding_zero), dim=0)
384 | else:
385 | start = (rank // group_rank) * ((group_rank - 1) * even_rank_hidden_size + odd_rank_hidden_size) + (
386 | (rank % group_rank - 1) * even_rank_hidden_size + odd_rank_hidden_size)
387 | start = int(start)
388 | indices = torch.arange(start, start + even_rank_hidden_size, dtype=torch.int32)
389 | part_tensor = torch.index_select(slice_, 0, indices)
390 | return part_tensor
391 |
392 | def get_tensor_col_packed_k_gqa(self, tensor_name: str, hidden_size, num_heads, num_kv_heads):
393 | slice_ = self.get_tensor(tensor_name)
394 | world_size = self.process_group.size()
395 | rank = self.process_group.rank()
396 | kv_tp_size = min(world_size, num_kv_heads)
397 | key_list = torch.chunk(slice_, kv_tp_size, dim=0)
398 | tensor = key_list[rank * kv_tp_size // world_size]
399 | return self.correct_tensor_dtype(tensor, tensor_name)
400 |
401 | def get_tensor_col_packed_v_gqa(self, tensor_name: str, hidden_size, num_heads, num_kv_heads):
402 | slice_ = self.get_tensor(tensor_name)
403 | world_size = self.process_group.size()
404 | rank = self.process_group.rank()
405 | kv_tp_size = min(world_size, num_kv_heads)
406 | value_list = torch.chunk(slice_, kv_tp_size, dim=0)
407 | tensor = value_list[rank * kv_tp_size // world_size]
408 | return self.correct_tensor_dtype(tensor, tensor_name)
409 |
410 | def get_tensor_col_packed_qkv_gqa(self, tensor_name: str, num_heads, num_kv_heads):
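# GQA packing: the query block is split across all ranks, while the K/V blocks are split across at most num_kv_heads ranks, so several ranks may share the same K/V shard.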
411 | slice_ = self.get_tensor(tensor_name)
412 | total_size = slice_.shape[0]
413 | if total_size % (num_heads + num_kv_heads * 2) != 0:
414 | raise AssertionError("Prepacked qkv is not divisible by q,k,v")
415 | q_single_size = total_size * num_heads // (num_heads + num_kv_heads * 2)
416 | kv_single_size = total_size * num_kv_heads // (num_heads + num_kv_heads * 2)
417 | world_size = self.process_group.size()
418 | rank = self.process_group.rank()
419 | if q_single_size % world_size != 0:
420 | raise AssertionError(f"Prepacked qkv cannot be sharded across {world_size} shards")
421 | query_layer, key_layer, value_layer = slice_.split((q_single_size, kv_single_size, kv_single_size), dim=0)
422 | kv_tp_size = min(world_size, num_kv_heads)
423 | query_list = torch.chunk(query_layer, world_size, dim=0)
424 | key_list = torch.chunk(key_layer, kv_tp_size, dim=0)
425 | value_list = torch.chunk(value_layer, kv_tp_size, dim=0)
426 | tensor = torch.cat([query_list[rank],
427 | key_list[rank * kv_tp_size // world_size],
428 | value_list[rank * kv_tp_size // world_size]], dim=0)
429 | return self.correct_tensor_dtype(tensor, tensor_name)
430 |
431 | def get_tensor_col_packed_qkv_gqa_padding(self, tensor_name: str, num_heads, num_kv_heads):
432 | # NOTE: the sizes below (q=5376, k=v=768; 896/128 rows per rank) appear to be hard-coded for one specific model and tensor-parallel configuration.
433 | rank = self.process_group.rank()
434 | if rank >= 6:
435 | tensor = torch.zeros(size=(1152, 5376), dtype=self.dtype, device=self.device)
436 | return tensor
437 | slice_ = self.get_whole_tensor(tensor_name, dim=0)
438 | q_size = 5376
439 | k_size = 768
440 | v_size = 768
441 | q_layer, k_layer, v_layer = slice_.split((q_size, k_size, v_size), dim=0)
442 | q_layer = q_layer[rank * 896:(rank + 1) * 896]
443 | k_layer = k_layer[rank * 128:(rank + 1) * 128]
444 | v_layer = v_layer[rank * 128:(rank + 1) * 128]
445 | tensor = torch.cat([q_layer, k_layer, v_layer], dim=0)
446 | del slice_
447 | del q_layer
448 | del k_layer
449 | del v_layer
450 | return tensor
451 |
452 | def get_tensor_col_packed_kv_mha(self, tensor_name: str, hidden_size, head_size: int = None):
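# Splits a fused KV weight whose K and V rows are interleaved per head; each rank returns its K block followed by its V block.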
453 | slice_ = self._get_slice(tensor_name)
454 | total_size = slice_.get_shape()[0]
455 | if total_size % 2 != 0:
456 | raise ValueError("Prepacked qkv is not divisible by 2")
457 | single_size = total_size // 2
458 | world_size = self.process_group.size()
459 | rank = self.process_group.rank()
460 |
461 | if head_size is None:
462 | raise RuntimeError("head_size is neccessary")
463 | else:
464 | try:
465 | head_num = single_size // head_size
466 | rank_heads = math.ceil(head_num / world_size)
467 | except ZeroDivisionError as e:
468 | raise ZeroDivisionError from e
469 |
470 | start = rank * (rank_heads * head_size * 2)
471 | stop = (rank + 1) * (rank_heads * head_size * 2)
472 | kv = slice_[start:stop]
473 | kv_new = kv.reshape(rank_heads, 2, head_size, -1)
474 | k, v = torch.chunk(kv_new, 2, dim=1)
475 | if len(slice_.get_shape()) == 1:
476 | k = k.reshape(head_size * rank_heads)
477 | v = v.reshape(head_size * rank_heads)
478 | else:
479 | k = k.reshape(head_size * rank_heads, hidden_size)
480 | v = v.reshape(head_size * rank_heads, hidden_size)
481 | tensor = torch.cat([k, v], dim=0)
482 |
483 | return self.correct_tensor_dtype(tensor, tensor_name)
484 |
485 |
486 | def get_tensor_col_packed_o(self, tensor_name: str, hidden_size, num_heads, num_kv_heads=None):
487 | return self.get_tensor_col_packed_o_gqa(tensor_name, hidden_size, num_heads, num_kv_heads)
488 |
489 | def get_tensor_col_packed_q(self, tensor_name: str, hidden_size, num_heads, num_kv_heads=None):
490 | return self.get_tensor_col_packed_q_gqa(tensor_name, hidden_size, num_heads, num_kv_heads)
491 |
492 | def get_tensor_col_packed_k(self, tensor_name: str, hidden_size, num_heads, num_kv_heads=None):
493 | return self.get_tensor_col_packed_k_gqa(tensor_name, hidden_size, num_heads, num_kv_heads)
494 |
495 | def get_tensor_col_packed_v(self, tensor_name: str, hidden_size, num_heads, num_kv_heads=None):
496 | return self.get_tensor_col_packed_v_gqa(tensor_name, hidden_size, num_heads, num_kv_heads)
497 |
498 | def get_w8a8sc_weight(self, prefix: str):
499 | qweight = self.get_tensor(f"{prefix}.weight")
500 | if qweight.dtype in [torch.float16, torch.bfloat16]:
501 | return qweight
502 | deq_scale = self.get_tensor(f"{prefix}.deq_scale")
503 | quant_bias = self.get_tensor(f"{prefix}.quant_bias")
504 | input_scale = self.get_tensor(f"{prefix}.input_scale")
505 | input_offset = self.get_tensor(f"{prefix}.input_offset")
506 | index = self.get_tensor(f"{prefix}.index")
507 | weight = (qweight, deq_scale, quant_bias, input_scale, input_offset, index)
508 | return weight
509 |
510 | def get_weights_col_packed_kv(self, prefix: str, quantize: str, hidden_size, head_size, num_kv_heads=None):
511 | if quantize in [QuantType.W8A8, QuantType.W8A8S]:
512 | qweight = self.get_tensor_col_packed_kv_mha(f"{prefix}.weight", hidden_size, head_size)
513 | if qweight.dtype in [torch.float16, torch.bfloat16]:
514 | return qweight
515 | deq_scale = self.get_tensor_col_packed_kv_mha(f"{prefix}.deq_scale", hidden_size, head_size)
516 | quant_bias = self.get_tensor_col_packed_kv_mha(f"{prefix}.quant_bias", hidden_size, head_size)
517 | input_scale = self.get_per_tensor_sharded([prefix], dim=0, tensor_name='input_scale')
518 | input_offset = self.get_per_tensor_sharded([prefix], dim=0, tensor_name='input_offset')
519 | weight = (qweight, deq_scale, quant_bias, input_scale, input_offset)
520 | elif quantize in [QuantType.W4A16, QuantType.W8A16, QuantType.W8A8_DYNAMIC]:
521 | qweight = self.get_tensor_col_packed_kv_mha(f"{prefix}.weight", hidden_size, head_size)
522 | if qweight.dtype in [torch.float16, torch.bfloat16]:
523 | return qweight
524 | weight_scale = self.get_tensor_col_packed_kv_mha(f"{prefix}.weight_scale", hidden_size, head_size)
525 | weight_offset = self.get_tensor_col_packed_kv_mha(f"{prefix}.weight_offset", hidden_size, head_size)
526 | weight = (qweight, weight_scale, weight_offset)
527 | elif quantize == QuantType.W8A8SC:
528 | return self.get_w8a8sc_weight(prefix)
529 | else:
530 | weight = self.get_tensor_col_packed_kv_mha(f"{prefix}.weight", hidden_size, head_size)
531 | return weight
532 |
533 |
534 | def get_weights_col_packed_o(self, prefix: str, quantize: str, hidden_size, num_heads, num_kv_heads=None):
535 | weight = self.get_tensor_col_packed_o(f"{prefix}.weight", hidden_size, num_heads, num_kv_heads)
536 | return weight
537 |
538 | def get_weights_col_packed_q(self, prefix: str, quantize: str, hidden_size, num_heads, num_kv_heads=None):
539 | weight = self.get_tensor_col_packed_q(f"{prefix}.weight", hidden_size, num_heads, num_kv_heads)
540 | return weight
541 |
542 | def get_weights_col_packed_k(self, prefix: str, quantize: str, hidden_size, num_heads, num_kv_heads=None):
543 | weight = self.get_tensor_col_packed_k(f"{prefix}.weight", hidden_size, num_heads, num_kv_heads)
544 | return weight
545 |
546 | def get_weights_col_packed_v(self, prefix: str, quantize: str, hidden_size, num_heads, num_kv_heads=None):
547 | weight = self.get_tensor_col_packed_v(f"{prefix}.weight", hidden_size, num_heads, num_kv_heads)
548 | return weight
549 |
550 |
551 | def get_tensor_col_packed_qkv(self, tensor_name: str, hidden_size, num_heads, num_kv_heads=None, dim=0, padding=False):
552 | if not num_kv_heads:
553 | num_kv_heads = num_heads
554 | if num_heads == num_kv_heads:
555 | if num_heads % self.process_group.size() == 0:
556 | return self.get_tensor_col_packed_qkv_mha(tensor_name, dim=dim)
557 | else:
558 | return self.get_tensor_col_packed_qkv_mha(tensor_name, hidden_size // num_heads, dim=dim)
559 | else:
560 | #return self.get_tensor_col_packed_qkv_gqa(tensor_name, num_heads, num_kv_heads)
561 | if padding:
562 | return self.get_tensor_col_packed_qkv_gqa_padding(tensor_name, num_heads, num_kv_heads)
563 | else:
564 | return self.get_tensor_col_packed_qkv_gqa(tensor_name, num_heads, num_kv_heads)
565 |
566 | def get_weights_col_packed_qkv(self, prefix: str, quantize: str, hidden_size, num_heads, num_kv_heads=None, dim=0, padding=False):
567 | if quantize in [QuantType.W8A8, QuantType.W8A8S]:
568 | qweight = self.get_tensor_col_packed_qkv(f"{prefix}.weight", hidden_size, num_heads, num_kv_heads)
569 | if qweight.dtype in [torch.float16, torch.bfloat16]:
570 | return qweight
571 | deq_scale = self.get_tensor_col_packed_qkv(f"{prefix}.deq_scale", hidden_size, num_heads, num_kv_heads)
572 | quant_bias = self.get_tensor_col_packed_qkv(f"{prefix}.quant_bias", hidden_size, num_heads, num_kv_heads)
573 | input_scale = self.get_per_tensor_sharded([prefix], dim=0, tensor_name='input_scale')
574 | input_offset = self.get_per_tensor_sharded([prefix], dim=0, tensor_name='input_offset')
575 | weight = (qweight, deq_scale, quant_bias, input_scale, input_offset)
576 | elif quantize in [QuantType.W4A16, QuantType.W8A16, QuantType.W8A8_DYNAMIC]:
577 | qweight = self.get_tensor_col_packed_qkv(f"{prefix}.weight", hidden_size, num_heads, num_kv_heads)
578 | if qweight.dtype in [torch.float16, torch.bfloat16]:
579 | return qweight
580 | weight_scale = self.get_tensor_col_packed_qkv(f"{prefix}.weight_scale", hidden_size, num_heads,
581 | num_kv_heads)
582 | weight_offset = self.get_tensor_col_packed_qkv(f"{prefix}.weight_offset", hidden_size, num_heads,
583 | num_kv_heads)
584 | weight = (qweight, weight_scale, weight_offset)
585 | elif quantize == QuantType.W8A8SC:
586 | return self.get_w8a8sc_weight(prefix)
587 | else:
588 | weight = self.get_tensor_col_packed_qkv(f"{prefix}.weight", hidden_size, num_heads, num_kv_heads, dim=dim, padding=padding)
589 | return weight
590 |
591 | def get_tensor_col_packed_mlp(self, tensor_name, head_types=2):
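# Splits a fused gate/up MLP weight packed as [gate; up]; each rank keeps its gate shard followed by its up shard.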
592 | slice_ = self.get_tensor(tensor_name)
593 | total_size = slice_.shape[0]
594 | if total_size % head_types != 0:
595 | raise AssertionError("Prepacked mlp is not divisible by up,gate")
596 | up_single_size = total_size // head_types
597 | gate_single_size = total_size // head_types
598 | world_size = self.process_group.size()
599 | rank = self.process_group.rank()
600 | if up_single_size % world_size != 0:
601 | raise AssertionError(f"Prepacked mlp cannot be sharded across {world_size} shards")
602 | gate_layer, up_layer = slice_.split((up_single_size, gate_single_size), dim=0)
603 | gate_list = torch.chunk(gate_layer, world_size, dim=0)
604 | up_list = torch.chunk(up_layer, world_size, dim=0)
605 | tensor = torch.cat([gate_list[rank], up_list[rank]], dim=0)
606 | return self.correct_tensor_dtype(tensor, tensor_name)
607 |
608 | def get_weights_col_packed_mlp(self, prefix: str, quantize: str):
609 | if quantize in [QuantType.W8A8, QuantType.W8A8S]:
610 | qweight = self.get_tensor_col_packed_mlp(f"{prefix}.weight")
611 | if qweight.dtype in [torch.float16, torch.bfloat16]:
612 | return qweight
613 | deq_scale = self.get_tensor_col_packed_mlp(f"{prefix}.deq_scale")
614 | quant_bias = self.get_tensor_col_packed_mlp(f"{prefix}.quant_bias")
615 | input_scale = self.get_per_tensor_sharded([prefix], dim=0, tensor_name='input_scale')
616 | input_offset = self.get_per_tensor_sharded([prefix], dim=0, tensor_name='input_offset')
617 | weight = (qweight, deq_scale, quant_bias, input_scale, input_offset)
618 | elif quantize in [QuantType.W4A16, QuantType.W8A16, QuantType.W8A8_DYNAMIC]:
619 | qweight = self.get_tensor_col_packed_mlp(f"{prefix}.weight")
620 | if qweight.dtype in [torch.float16, torch.bfloat16]:
621 | return qweight
622 | weight_scale = self.get_tensor_col_packed_mlp(f"{prefix}.weight_scale")
623 | weight_offset = self.get_tensor_col_packed_mlp(f"{prefix}.weight_offset")
624 | weight = (qweight, weight_scale, weight_offset)
625 | elif quantize == QuantType.W8A8SC:
626 | return self.get_w8a8sc_weight(prefix)
627 | else:
628 | weight = self.get_tensor_col_packed_mlp(f"{prefix}.weight")
629 | return weight
630 |
631 | def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int, gqa_size: int = 1, norm_head: bool = False):
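# Concatenates column-parallel shards for one or more prefixes; quantized formats return the full tuple of quantization parameters, and norm_head L2-normalizes each column of the weight before sharding.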
632 |
633 | if quantize == "gptq":
634 | try:
635 | qweight = torch.cat(
636 | [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
637 | )
638 | except RuntimeError as err:
639 | logger.error(
640 | "Cannot load `gptq` weight, make sure the model is already quantized"
641 | )
642 | raise AssertionError from err
643 |
644 | qzeros = torch.cat(
645 | [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
646 | )
647 | scales = torch.cat(
648 | [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
649 | )
650 | w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
651 | for w2 in w[1:]:
652 | torch.testing.assert_close(w2, w[0])
653 | g_idx = w[0]
654 |
655 | bits, groupsize = self._get_gptq_params()
656 | weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
657 | elif quantize in [QuantType.W8A8, QuantType.W8A8S]:
658 | qweight = torch.cat(
659 | [self.get_sharded(f"{p}.weight", dim=0, gqa_size=gqa_size) for p in prefixes], dim=dim
660 | )
661 | if qweight.dtype in [torch.float16, torch.bfloat16]:
662 | return qweight
663 | deq_scale = torch.cat(
664 | [self.get_sharded(f"{p}.deq_scale", dim=0, gqa_size=gqa_size) for p in prefixes], dim=dim
665 | )
666 | quant_bias = torch.cat(
667 | [self.get_sharded(f"{p}.quant_bias", dim=0, gqa_size=gqa_size) for p in prefixes], dim=dim
668 | )
669 | input_scale = self.get_per_tensor_sharded(prefixes, dim, 'input_scale')
670 | input_offset = self.get_per_tensor_sharded(prefixes, dim, 'input_offset')
671 | weight = (qweight, deq_scale, quant_bias, input_scale, input_offset)
672 | elif quantize in [QuantType.W4A16, QuantType.W8A16, QuantType.W8A8_DYNAMIC]:
673 | qweight = torch.cat(
674 | [self.get_sharded(f"{p}.weight", dim=0, gqa_size=gqa_size) for p in prefixes], dim=dim
675 | )
676 | if qweight.dtype in [torch.float16, torch.bfloat16]:
677 | return qweight
678 | weight_scale = torch.cat(
679 | [self.get_sharded(f"{p}.weight_scale", dim=0, gqa_size=gqa_size) for p in prefixes], dim=dim
680 | )
681 | weight_offset = torch.cat(
682 | [self.get_sharded(f"{p}.weight_offset", dim=0, gqa_size=gqa_size) for p in prefixes], dim=dim
683 | )
684 | weight = (qweight, weight_scale, weight_offset)
685 | elif quantize == QuantType.W8A8SC:
686 | qweight = torch.cat([self.get_tensor(f"{p}.weight") for p in prefixes], dim=dim)
687 | if qweight.dtype in [torch.float16, torch.bfloat16]:
688 | return qweight
689 | deq_scale = torch.cat([self.get_tensor(f"{p}.deq_scale") for p in prefixes], dim=dim)
690 | quant_bias = torch.cat([self.get_tensor(f"{p}.quant_bias") for p in prefixes], dim=dim)
691 | input_scale = torch.cat([self.get_tensor(f"{p}.input_scale") for p in prefixes], dim=dim)
692 | input_offset = torch.cat([self.get_tensor(f"{p}.input_offset") for p in prefixes], dim=dim)
693 | index = torch.cat([self.get_tensor(f"{p}.index") for p in prefixes], dim=dim)
694 | weight = (qweight, deq_scale, quant_bias, input_scale, input_offset, index)
695 | else:
696 | if norm_head:
697 | w = []
698 | for p in prefixes:
699 | world_size = self.process_group.size()
700 | rank = self.process_group.rank()
701 | head_weight = self.get_whole_tensor(f"{p}.weight", dim=0).npu()
702 | dim_1 = head_weight.shape[1]
703 | input_matrix = torch.eye(dim_1, dtype=torch.float16, device=head_weight.device)
704 | unnormed_head = torch.mm(head_weight, input_matrix)
705 | head_norm = unnormed_head.norm(dim=0, keepdim=True, p=2)
706 | normed_head = unnormed_head / (head_norm + 1e-7)
707 |
708 | size = head_weight.shape[dim]
709 | group_size = size // gqa_size
710 | if group_size >= world_size:
711 | block_size = size // world_size
712 | start = rank * block_size
713 | stop = (rank + 1) * block_size
714 | else:
715 | block_size = gqa_size
716 | start = (rank // (world_size // group_size)) * block_size
717 | stop = ((rank // (world_size // group_size)) + 1) * block_size
718 |
719 | if dim == 0:
720 | tensor = normed_head[start:stop]
721 | elif dim == 1:
722 | tensor = normed_head[:, start:stop]
723 | else:
724 | logger.error("Let's make that generic when needed")
725 | raise AssertionError
726 | w.append(tensor)
727 | weight = torch.cat(w, dim=dim)
728 | return weight
729 | w = [self.get_sharded(f"{p}.weight", dim=dim, gqa_size=gqa_size) for p in prefixes]
730 | weight = torch.cat(w, dim=dim)
731 | return weight
732 |
733 | def get_multi_weights_row(self, prefix: str, quantize: str, gqa_size=1, dim=1):
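# Row-parallel shard for a single prefix; for W8A8 only rank 0 keeps quant_bias, presumably so the bias is applied once when the partial results are reduced.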
734 | if quantize in [QuantType.W8A8, QuantType.W8A8S]:
735 | qweight = self.get_sharded(f"{prefix}.weight", dim=dim, gqa_size=gqa_size)
736 | if qweight.dtype in [torch.float16, torch.bfloat16]:
737 | return qweight
738 | deq_scale = self.get_tensor(f"{prefix}.deq_scale")
739 | quant_bias = self.get_tensor(f"{prefix}.quant_bias")
740 | if self.process_group.rank() == 0:
741 | pass  # rank 0 keeps the original quant_bias
742 | else:
743 | quant_bias = torch.zeros_like(quant_bias, dtype=quant_bias.dtype, device=quant_bias.device)
744 | input_scale = self.get_per_tensor_sharded([prefix], dim=0, tensor_name='input_scale')
745 | input_offset = self.get_per_tensor_sharded([prefix], dim=0, tensor_name='input_offset')
746 | weight = (qweight, deq_scale, quant_bias, input_scale, input_offset)
747 | elif quantize in [QuantType.W4A16, QuantType.W8A16, QuantType.W8A8_DYNAMIC]:
748 | qweight = self.get_sharded(f"{prefix}.weight", dim=1, gqa_size=gqa_size)
749 | if qweight.dtype in [torch.float16, torch.bfloat16]:
750 | return qweight
751 | weight_scale = self.get_sharded(f"{prefix}.weight_scale", dim=dim, gqa_size=1)
752 | weight_offset = self.get_sharded(f"{prefix}.weight_offset", dim=dim, gqa_size=1)
753 | weight = (qweight, weight_scale, weight_offset)
754 | elif quantize == QuantType.W8A8SC:
755 | return self.get_w8a8sc_weight(prefix)
756 | else:
757 | weight = self.get_sharded(f"{prefix}.weight", dim=dim, gqa_size=gqa_size)
758 | return weight
759 |
760 | def get_replicated_weights(self, prefix: str, quantize: str):
761 | if quantize in [QuantType.W8A8, QuantType.W8A8S]:
762 | qweight = self.get_tensor(f"{prefix}.weight")
763 | if qweight.dtype in [torch.float16, torch.bfloat16]:
764 | return qweight
765 | deq_scale = self.get_tensor(f"{prefix}.deq_scale")
766 | quant_bias = self.get_tensor(f"{prefix}.quant_bias")
767 | input_scale = self.get_tensor(f"{prefix}.input_scale")
768 | input_offset = self.get_tensor(f"{prefix}.input_offset")
769 | weight = (qweight, deq_scale, quant_bias, input_scale, input_offset)
770 | elif quantize in [QuantType.W4A16, QuantType.W8A16, QuantType.W8A8_DYNAMIC]:
771 | qweight = self.get_tensor(f"{prefix}.weight")
772 | if qweight.dtype in [torch.float16, torch.bfloat16]:
773 | return qweight
774 | weight_scale = self.get_tensor(f"{prefix}.weight_scale")
775 | weight_offset = self.get_tensor(f"{prefix}.weight_offset")
776 | weight = (qweight, weight_scale, weight_offset)
777 | elif quantize == QuantType.W8A8SC:
778 | return self.get_w8a8sc_weight(prefix)
779 | else:
780 | weight = self.get_tensor(f"{prefix}.weight")
781 | return weight
782 |
783 | def _get_handle(self, filename):
784 | if filename not in self._handles:
785 | f = safe_open(filename, framework="pytorch")
786 | self._handles[filename] = f
787 | return self._handles[filename]
788 |
789 | def _get_slice(self, tensor_name: str):
790 | filename, tensor_name = self.get_filename(tensor_name)
791 | f = self._get_handle(filename)
792 | slice_ = f.get_slice(tensor_name)
793 | return slice_
794 |
795 | def _get_gptq_params(self) -> Tuple[int, int]:
796 | try:
797 | bits = self.get_tensor("gptq_bits").item()
798 | groupsize = self.get_tensor("gptq_groupsize").item()
799 | except (SafetensorError, RuntimeError) as _:
800 | try:
801 | bits = self.gptq_bits
802 | groupsize = self.gptq_groupsize
803 | except Exception as err:
804 | raise AssertionError from err
805 |
806 | return bits, groupsize
807 |
808 | def _set_quant_params(self, model_id):
809 | try:
810 | filename = os.path.join(model_id, f'quant_model_description_{self.quantize}.json')
811 | with file_utils.safe_open(filename, 'r') as f:
812 | data = json.load(f)
813 | self.quant_desc = data
814 | except Exception as err:
815 | raise AssertionError from err
816 |
817 | def get_multi_weights_row_att_dense(self, prefix: str, quantize: str, gqa_size=1, dim=1):
818 | if quantize in [QuantType.W8A8, QuantType.W8A8S]:
819 | qweight = self.get_sharded(f"{prefix}.weight", dim=1, gqa_size=gqa_size)
820 | if qweight.dtype in [torch.float16, torch.bfloat16]:
821 | return qweight
822 | deq_scale = self.get_tensor(f"{prefix}.deq_scale")
823 | quant_bias = self.get_tensor(f"{prefix}.quant_bias")
824 | if self.process_group.rank() == 0:
825 | pass  # rank 0 keeps the original quant_bias
826 | else:
827 | quant_bias = torch.zeros_like(quant_bias, dtype=quant_bias.dtype, device=quant_bias.device)
828 | input_scale = self.get_per_tensor_sharded([prefix], dim=0, tensor_name='input_scale')
829 | input_offset = self.get_per_tensor_sharded([prefix], dim=0, tensor_name='input_offset')
830 | weight = (qweight, deq_scale, quant_bias, input_scale, input_offset)
831 | elif quantize in [QuantType.W4A16, QuantType.W8A16]:
832 | qweight = self.get_sharded(f"{prefix}.weight", dim=1, gqa_size=gqa_size)
833 | if qweight.dtype in [torch.float16, torch.bfloat16]:
834 | return qweight
835 | weight_scale = self.get_sharded(f"{prefix}.weight_scale", dim=1, gqa_size=1)
836 | weight_offset = self.get_sharded(f"{prefix}.weight_offset", dim=1, gqa_size=1)
837 | weight = (qweight, weight_scale, weight_offset)
838 | elif quantize == QuantType.W8A8SC:
839 | qweight = self.get_tensor(f"{prefix}.weight")
840 | if qweight.dtype in [torch.float16, torch.bfloat16]:
841 | return qweight
842 | deq_scale = self.get_tensor(f"{prefix}.deq_scale")
843 | quant_bias = self.get_tensor(f"{prefix}.quant_bias")
844 | input_scale = self.get_tensor(f"{prefix}.input_scale")
845 | input_offset = self.get_tensor(f"{prefix}.input_offset")
846 | index = self.get_tensor(f"{prefix}.index")
847 | weight = (qweight, deq_scale, quant_bias, input_scale, input_offset, index)
848 | else:
849 | weight = self.get_sharded_att(f"{prefix}.weight", dim=1, gqa_size=gqa_size)
850 | return weight
851 |
852 | def get_sharded_att(self, tensor_name: str, dim: int, gqa_size: int = 1):
853 | slice_ = self._get_slice(tensor_name)
854 | world_size = self.process_group.size()
855 | size = slice_.get_shape()[dim]
856 | return self.get_partial_sharded_att(tensor_name, dim, gqa_size)
857 | '''
858 | if (size // gqa_size) % world_size == 0 or world_size % (size // gqa_size) == 0:
859 | return self.get_partial_sharded_att(tensor_name, dim, gqa_size)
860 | else:
861 | return self.get_partial_sharded_padding_att(tensor_name, dim, gqa_size)
862 | '''
863 |
864 | def get_partial_sharded_att(self, tensor_name: str, dim: int, gqa_size: int = 1):
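# Hard-coded attention sharding: ranks >= 6 return an all-zero placeholder of shape (5376, 7 * 128); lower ranks take their 896-column slice of the full tensor.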
865 | world_size = self.process_group.size()
866 | rank = self.process_group.rank()
867 | if rank >= 6:
868 | tensor = torch.zeros(size=(5376, 7 * 128), dtype=self.dtype, device=self.device)
869 | return tensor
870 | slice_ = self.get_whole_tensor(tensor_name, dim=0)
871 | slice_part = slice_[:, rank * 896:(rank + 1) * 896]
872 | return slice_part
--------------------------------------------------------------------------------