├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── cfgs ├── 2-norm-cfgs │ └── 2-norm-SmolLM1-135M.yml ├── Llama2-7B-4GPU.yml ├── Llama2-7B-8GPU.yml ├── Qwen1_5-0_5B-2GPU.yml ├── Qwen1_5-0_5B-4GPU.yml ├── Qwen1_5-0_5B-8GPU.yml ├── Qwen3-0_6B-4GPU.yml ├── Qwen3-30B-4GPU.yml ├── SmolLM1-135M-2GPU.yml ├── SmolLM1-135M-4GPU-high.yml ├── SmolLM1-135M-4GPU-low.yml ├── SmolLM1-135M-4GPU-uniform.yml ├── SmolLM1-135M-4GPU.yml ├── SmolLM1-135M-8GPU.yml ├── SmolLM1-1B7-4GPU.yml ├── SmolLM1-1B7-8GPU.yml ├── SmolLM1-360M-2GPU.yml ├── SmolLM1-360M-4GPU.yml ├── ds_zero_1.json ├── ds_zero_2.json └── ds_zero_2_offload.json ├── eval ├── __init__.py ├── benchmark_latency_memory.py ├── eval.py ├── lb_table.py ├── longbench.py ├── math_utils.py ├── smollm1_base.txt ├── tasks.py └── test_import.py ├── img └── overview.png ├── requirements.txt ├── src └── mha2mla │ ├── 2_norm.py │ ├── __init__.py │ ├── arguments.py │ ├── helpers.py │ ├── mla_triton_kernel.py │ ├── patch_func.py │ ├── patching_llama.py │ ├── patching_model_load.py │ ├── patching_qwen2.py │ ├── patching_qwen3.py │ ├── patching_qwen3_moe.py │ └── run_train.py └── utils ├── 135M_2norm_rank.pth ├── llama2_13B-2_norm_rank.pth ├── llama2_7B-2_norm_rank.pth ├── qwen1.5_0.5B-2_norm_rank.pth ├── qwen2_0.5B-2_norm_rank.pth ├── smollm1_135M-2_norm_rank.pth ├── smollm1_135M-2_norm_rank_q.pth ├── smollm1_1B7-2_norm_rank.pth ├── smollm1_360M-2_norm_rank.pth └── test_load_tensor.py /.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/.gitignore -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/.gitmodules -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/LICENSE -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/README.md -------------------------------------------------------------------------------- /cfgs/2-norm-cfgs/2-norm-SmolLM1-135M.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/cfgs/2-norm-cfgs/2-norm-SmolLM1-135M.yml -------------------------------------------------------------------------------- /cfgs/Llama2-7B-4GPU.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/cfgs/Llama2-7B-4GPU.yml -------------------------------------------------------------------------------- /cfgs/Llama2-7B-8GPU.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/cfgs/Llama2-7B-8GPU.yml -------------------------------------------------------------------------------- /cfgs/Qwen1_5-0_5B-2GPU.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/cfgs/Qwen1_5-0_5B-2GPU.yml -------------------------------------------------------------------------------- /cfgs/Qwen1_5-0_5B-4GPU.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/cfgs/Qwen1_5-0_5B-4GPU.yml -------------------------------------------------------------------------------- /cfgs/Qwen1_5-0_5B-8GPU.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/cfgs/Qwen1_5-0_5B-8GPU.yml -------------------------------------------------------------------------------- /cfgs/Qwen3-0_6B-4GPU.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/cfgs/Qwen3-0_6B-4GPU.yml -------------------------------------------------------------------------------- /cfgs/Qwen3-30B-4GPU.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/cfgs/Qwen3-30B-4GPU.yml -------------------------------------------------------------------------------- /cfgs/SmolLM1-135M-2GPU.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/cfgs/SmolLM1-135M-2GPU.yml -------------------------------------------------------------------------------- /cfgs/SmolLM1-135M-4GPU-high.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/cfgs/SmolLM1-135M-4GPU-high.yml -------------------------------------------------------------------------------- /cfgs/SmolLM1-135M-4GPU-low.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/cfgs/SmolLM1-135M-4GPU-low.yml -------------------------------------------------------------------------------- /cfgs/SmolLM1-135M-4GPU-uniform.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/cfgs/SmolLM1-135M-4GPU-uniform.yml -------------------------------------------------------------------------------- /cfgs/SmolLM1-135M-4GPU.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/cfgs/SmolLM1-135M-4GPU.yml -------------------------------------------------------------------------------- /cfgs/SmolLM1-135M-8GPU.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/cfgs/SmolLM1-135M-8GPU.yml -------------------------------------------------------------------------------- /cfgs/SmolLM1-1B7-4GPU.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/cfgs/SmolLM1-1B7-4GPU.yml -------------------------------------------------------------------------------- /cfgs/SmolLM1-1B7-8GPU.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/cfgs/SmolLM1-1B7-8GPU.yml -------------------------------------------------------------------------------- /cfgs/SmolLM1-360M-2GPU.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/cfgs/SmolLM1-360M-2GPU.yml -------------------------------------------------------------------------------- /cfgs/SmolLM1-360M-4GPU.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/cfgs/SmolLM1-360M-4GPU.yml -------------------------------------------------------------------------------- /cfgs/ds_zero_1.json: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/cfgs/ds_zero_1.json -------------------------------------------------------------------------------- /cfgs/ds_zero_2.json: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/cfgs/ds_zero_2.json -------------------------------------------------------------------------------- /cfgs/ds_zero_2_offload.json: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/cfgs/ds_zero_2_offload.json -------------------------------------------------------------------------------- /eval/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /eval/benchmark_latency_memory.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/eval/benchmark_latency_memory.py -------------------------------------------------------------------------------- /eval/eval.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/eval/eval.py -------------------------------------------------------------------------------- /eval/lb_table.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/eval/lb_table.py -------------------------------------------------------------------------------- /eval/longbench.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/eval/longbench.py -------------------------------------------------------------------------------- /eval/math_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/eval/math_utils.py -------------------------------------------------------------------------------- /eval/smollm1_base.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/eval/smollm1_base.txt -------------------------------------------------------------------------------- /eval/tasks.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/eval/tasks.py -------------------------------------------------------------------------------- /eval/test_import.py: -------------------------------------------------------------------------------- 1 | from src.mha2mla.run_train import main -------------------------------------------------------------------------------- /img/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/img/overview.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/requirements.txt -------------------------------------------------------------------------------- /src/mha2mla/2_norm.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/src/mha2mla/2_norm.py -------------------------------------------------------------------------------- /src/mha2mla/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/mha2mla/arguments.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/src/mha2mla/arguments.py -------------------------------------------------------------------------------- /src/mha2mla/helpers.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/src/mha2mla/helpers.py -------------------------------------------------------------------------------- /src/mha2mla/mla_triton_kernel.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/src/mha2mla/mla_triton_kernel.py -------------------------------------------------------------------------------- /src/mha2mla/patch_func.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/src/mha2mla/patch_func.py -------------------------------------------------------------------------------- /src/mha2mla/patching_llama.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/src/mha2mla/patching_llama.py -------------------------------------------------------------------------------- /src/mha2mla/patching_model_load.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/src/mha2mla/patching_model_load.py -------------------------------------------------------------------------------- /src/mha2mla/patching_qwen2.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/src/mha2mla/patching_qwen2.py -------------------------------------------------------------------------------- /src/mha2mla/patching_qwen3.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/src/mha2mla/patching_qwen3.py -------------------------------------------------------------------------------- /src/mha2mla/patching_qwen3_moe.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/src/mha2mla/patching_qwen3_moe.py -------------------------------------------------------------------------------- /src/mha2mla/run_train.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/src/mha2mla/run_train.py -------------------------------------------------------------------------------- /utils/135M_2norm_rank.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/utils/135M_2norm_rank.pth -------------------------------------------------------------------------------- /utils/llama2_13B-2_norm_rank.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/utils/llama2_13B-2_norm_rank.pth -------------------------------------------------------------------------------- /utils/llama2_7B-2_norm_rank.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/utils/llama2_7B-2_norm_rank.pth -------------------------------------------------------------------------------- /utils/qwen1.5_0.5B-2_norm_rank.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/utils/qwen1.5_0.5B-2_norm_rank.pth -------------------------------------------------------------------------------- /utils/qwen2_0.5B-2_norm_rank.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/utils/qwen2_0.5B-2_norm_rank.pth -------------------------------------------------------------------------------- /utils/smollm1_135M-2_norm_rank.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/utils/smollm1_135M-2_norm_rank.pth -------------------------------------------------------------------------------- /utils/smollm1_135M-2_norm_rank_q.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/utils/smollm1_135M-2_norm_rank_q.pth -------------------------------------------------------------------------------- /utils/smollm1_1B7-2_norm_rank.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/utils/smollm1_1B7-2_norm_rank.pth -------------------------------------------------------------------------------- /utils/smollm1_360M-2_norm_rank.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/utils/smollm1_360M-2_norm_rank.pth -------------------------------------------------------------------------------- /utils/test_load_tensor.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/HEAD/utils/test_load_tensor.py --------------------------------------------------------------------------------