├── VERSION
├── llmexport
│   ├── utils
│   │   ├── __init__.py
│   │   ├── spinner.py
│   │   ├── onnx.py
│   │   ├── onnx_rebuilder.py
│   │   ├── custom_op.py
│   │   ├── torch_utils.py
│   │   ├── lora.py
│   │   ├── mnn_utils.py
│   │   ├── gptq.py
│   │   ├── hqq_quantizer.py
│   │   ├── talker.py
│   │   ├── mtp.py
│   │   ├── eagle.py
│   │   ├── audio.py
│   │   ├── smooth_quantizer.py
│   │   ├── token2wav.py
│   │   └── mnn_converter.py
│   ├── __init__.py
│   ├── gguf
│   │   └── gguf_reader.py
│   └── gguf2mnn.py
├── MANIFEST.in
├── .github
│   └── workflows
│       └── release.yml
├── setup.py
├── README.md
├── README_en.md
└── LICENSE
/VERSION:
--------------------------------------------------------------------------------
1 | 0.0.4
2 |
--------------------------------------------------------------------------------
/llmexport/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include VERSION
2 | recursive-exclude tests *
3 |
--------------------------------------------------------------------------------
/llmexport/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = '0.0.1'
2 |
3 | from .llmexport import export
4 | from .llmexport import main
5 | from .version import __version__
--------------------------------------------------------------------------------
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
1 | name: release
2 |
3 | on:
4 | release:
5 | types: [published]
6 |
7 | permissions:
8 | contents: read
9 |
10 | jobs:
11 | pypi-release:
12 | name: upload release to pypi
13 | runs-on: ubuntu-latest
14 | environment: release
15 | permissions:
16 | id-token: write
17 |
18 | steps:
19 | - uses: actions/checkout@v3
20 | - name: Set up Python
21 | uses: actions/setup-python@v3
22 | with:
23 | python-version: "3.x"
24 |
25 | - name: install
26 | run: |
27 | python -m pip install --upgrade pip wheel setuptools
28 | pip install build
29 |
30 | - name: build
31 | run: python -m build
32 |
33 | - name: upload
34 | uses: pypa/gh-action-pypi-publish@release/v1
--------------------------------------------------------------------------------
/llmexport/utils/spinner.py:
--------------------------------------------------------------------------------
1 | import time
2 | import functools
3 | import traceback
4 | from yaspin import yaspin
5 |
6 | RESET = "\033[0m"
7 | GREEN = "\033[32;1m"
8 | YELLOW = "\033[33;4m"
9 |
10 | def spinner_run(text='Processing...', hide=False):
11 | def decorator(func):
12 | @functools.wraps(func)
13 | def wrapper(*args, **kwargs):
14 | with yaspin(text=text, color="cyan") as spinner:
15 | start = time.time()
16 | try:
17 | if hide: spinner.hide()
18 | result = func(*args, **kwargs)
19 | if hide: spinner.show()
20 | except Exception as e:
21 | spinner.fail("💥 Failed")
22 | traceback.print_exc()
23 | exit(1)
24 | end = time.time()
25 | during = f'[{end-start:05.2f} s]'.replace('[0', '[ ')
26 | padding = ' ' * (64 - len(spinner.text) - len(result))
27 | spinner.text = f'{spinner.text}{YELLOW}{result}{RESET}{padding}{GREEN}{during}{RESET}'
28 | spinner.ok("✅ Done")
29 | return result
30 | return wrapper
31 | return decorator
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 |
3 | with open("VERSION", "r") as f:
4 | version = f.read().strip()
5 |
6 | with open("llmexport/version.py", "w") as f:
7 | f.write(f'__version__ = "{version}"\n')
8 |
9 | setup(
10 | name="llmexport",
11 | version=version,
12 | description="llmexport: A toolkit to export llm to onnx or mnn.",
13 | long_description=open("README.md", "r", encoding="utf-8").read(),
14 | long_description_content_type="text/markdown",
15 | url="https://github.com/wangzhaode/llm-export",
16 | author="wangzhaode",
17 | author_email="hi@zhaode.wang",
18 | project_urls={
19 | "Bug Tracker": "https://github.com/wangzhaode/llm-export/issues",
20 | },
21 | classifiers=[
22 | "Programming Language :: Python :: 3",
23 | "License :: OSI Approved :: MIT License",
24 | "Intended Audience :: Developers",
25 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
26 | ],
27 | license="MIT",
28 | install_requires=["yaspin", "torch", "numpy", "transformers", "sentencepiece", "onnx", "onnxslim", "onnxruntime", "MNN"],
29 | packages=find_packages(exclude=("tests", "tests.*")),
30 | entry_points={"console_scripts": ["llmexport=llmexport:main"]},
31 | zip_safe=True,
32 | python_requires=">=3.6",
33 | )
--------------------------------------------------------------------------------
/llmexport/utils/onnx.py:
--------------------------------------------------------------------------------
1 | import os
2 | import onnx
3 |
4 | class OnnxRebuilder:
5 | def __init__(self, onnx_path, weight_ops):
6 | self.weight_ops = weight_ops
7 | self.onnx_model = onnx.load(onnx_path)
8 | self.dst_path = onnx_path
9 | self.onnx_weight_path = f'{onnx_path}.data'
10 | self.onnx_weight_offset = 0
11 |
12 | def make_external(self, name, data, shape):
13 | # write to external weight
14 | length = self.onnx_weight.write(data.tobytes())
15 | location = os.path.basename(self.onnx_weight_path)
16 | offset = self.onnx_weight_offset
17 | self.onnx_weight_offset += length
18 | tensor = onnx.TensorProto()
19 | tensor.name = name
20 | tensor.data_type = onnx.TensorProto.FLOAT
21 | tensor.dims.extend(shape)
22 | # external info
23 | tensor.data_location = onnx.TensorProto.EXTERNAL
24 | for k, v in { "location": location, "offset": offset, "length": length }.items():
25 | entry = tensor.external_data.add()
26 | entry.key = k
27 | entry.value = str(v)
28 | self.onnx_model.graph.initializer.append(tensor)
29 |
30 | def build_weight(self, name, has_bias, ic, oc):
31 | assert(name in self.weight_ops)
32 | linear = self.weight_ops[name]
33 | assert(linear.in_features == ic and
34 | linear.out_features == oc and
35 | (linear.bias is not None) == has_bias)
36 | weight_name, bias_name = f'{name}_weight', f'{name}_bias'
37 | weight = linear.weight.data.transpose(1, 0).flatten().float().numpy()
38 | self.make_external(weight_name, weight, [ic, oc])
39 | if has_bias:
40 | bias = linear.bias.data.flatten().float().numpy()
41 | self.make_external(bias_name, bias, [oc])
42 | return weight_name, bias_name
43 |
44 | def rebuild(self):
45 | from onnx import helper
46 | new_nodes = []
47 | self.onnx_weight = open(self.onnx_weight_path, 'wb')
48 | for node in self.onnx_model.graph.node:
49 | if node.op_type == 'FakeLinear':
50 | attributes = {a.name: a for a in node.attribute}
51 | name = attributes.get('name').s.decode('utf-8')
52 | has_bias = attributes.get('has_bias').i
53 | ic = attributes.get('in_features').i
54 | oc = attributes.get('out_features').i
55 | weight, bias = self.build_weight(name, has_bias, ic, oc)
56 | if has_bias:
57 | # fakelinear -> matmul + add
58 | middle_tensor = f'{name}_matmul'
59 | new_nodes.append(helper.make_node('MatMul', [node.input[0], weight], [middle_tensor], name))
60 | new_nodes.append(helper.make_node('Add', [middle_tensor, bias], node.output, f'{name}/Add'))
61 | else:
62 | # fakelinear -> matmul
63 | new_nodes.append(helper.make_node('MatMul', [node.input[0], weight], node.output, name))
64 | else:
65 | new_nodes.append(node)
66 | self.onnx_weight.close()
67 | del self.onnx_model.graph.node[:]
68 | self.onnx_model.graph.node.extend(new_nodes)
69 | onnx.save(self.onnx_model, self.dst_path)
70 | return self.onnx_weight_path
--------------------------------------------------------------------------------
/llmexport/utils/onnx_rebuilder.py:
--------------------------------------------------------------------------------
1 | import os
2 | import onnx
3 |
4 | class OnnxRebuilder:
5 | def __init__(self, onnx_path, weight_ops):
6 | self.weight_ops = weight_ops
7 | self.onnx_model = onnx.load(onnx_path)
8 | self.dst_path = onnx_path
9 | self.onnx_weight_path = f'{onnx_path}.data'
10 | self.onnx_weight_offset = 0
11 |
12 | def make_external(self, name, data, shape):
13 | # write to external weight
14 | length = self.onnx_weight.write(data.tobytes())
15 | location = os.path.basename(self.onnx_weight_path)
16 | offset = self.onnx_weight_offset
17 | self.onnx_weight_offset += length
18 | tensor = onnx.TensorProto()
19 | tensor.name = name
20 | tensor.data_type = onnx.TensorProto.FLOAT
21 | tensor.dims.extend(shape)
22 | # external info
23 | tensor.data_location = onnx.TensorProto.EXTERNAL
24 | for k, v in { "location": location, "offset": offset, "length": length }.items():
25 | entry = tensor.external_data.add()
26 | entry.key = k
27 | entry.value = str(v)
28 | self.onnx_model.graph.initializer.append(tensor)
29 |
30 | def build_weight(self, name, has_bias, ic, oc):
31 | assert(name in self.weight_ops)
32 | linear = self.weight_ops[name]
33 | assert(linear.in_features == ic and
34 | linear.out_features == oc and
35 | (linear.bias is not None) == has_bias)
36 | weight_name, bias_name = f'{name}_weight', f'{name}_bias'
37 | weight = linear.weight.data.transpose(1, 0).flatten().float().numpy()
38 | self.make_external(weight_name, weight, [ic, oc])
39 | if has_bias:
40 | bias = linear.bias.data.flatten().float().numpy()
41 | self.make_external(bias_name, bias, [oc])
42 | return weight_name, bias_name
43 |
44 | def rebuild(self):
45 | from onnx import helper
46 | new_nodes = []
47 | self.onnx_weight = open(self.onnx_weight_path, 'wb')
48 | for node in self.onnx_model.graph.node:
49 | if node.op_type == 'FakeLinear':
50 | attributes = {a.name: a for a in node.attribute}
51 | name = attributes.get('name').s.decode('utf-8')
52 | has_bias = attributes.get('has_bias').i
53 | ic = attributes.get('in_features').i
54 | oc = attributes.get('out_features').i
55 | weight, bias = self.build_weight(name, has_bias, ic, oc)
56 | if has_bias:
57 | # fakelinear -> matmul + add
58 | middle_tensor = f'{name}_matmul'
59 | new_nodes.append(helper.make_node('MatMul', [node.input[0], weight], [middle_tensor], name))
60 | new_nodes.append(helper.make_node('Add', [middle_tensor, bias], node.output, f'{name}/Add'))
61 | else:
62 | # fakelinear -> matmul
63 | new_nodes.append(helper.make_node('MatMul', [node.input[0], weight], node.output, name))
64 | else:
65 | new_nodes.append(node)
66 | self.onnx_weight.close()
67 | del self.onnx_model.graph.node[:]
68 | self.onnx_model.graph.node.extend(new_nodes)
69 | onnx.save(self.onnx_model, self.dst_path)
70 | return self.onnx_weight_path
--------------------------------------------------------------------------------
/llmexport/utils/custom_op.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class FakeLinearOp(torch.autograd.Function):
4 | @staticmethod
5 | def symbolic(g, input, in_features, out_features, has_bias, name):
6 | # These become the operator attributes.
7 | kwargs = {
8 | "in_features_i": in_features,
9 | "out_features_i": out_features,
10 | "has_bias_i": has_bias,
11 | "name_s": name
12 | }
13 | from torch.onnx.symbolic_helper import _get_tensor_sizes
14 | out_sizes = _get_tensor_sizes(input)[:-1] + [out_features]
15 | output_type = input.type().with_sizes(out_sizes)
16 | return g.op("LlmExporter::FakeLinear", input, **kwargs).setType(output_type)
17 |
18 | @staticmethod
19 | def forward(ctx, input, in_features, out_features, has_bias, name):
20 | out_shape = list(input.shape)[:-1] + [out_features]
21 | return input.new_zeros(out_shape)
22 |
23 | class FakeLinear(torch.nn.Module):
24 | def __init__(self, in_features, out_features, has_bias, name):
25 | super(FakeLinear, self).__init__()
26 | self.in_features = in_features
27 | self.out_features = out_features
28 | self.has_bias = has_bias
29 | self.name = name
30 |
31 | def forward(self, x):
32 | return FakeLinearOp.apply(x, self.in_features, self.out_features, self.has_bias, self.name)
33 |
34 | class FusedAttentionOp(torch.autograd.Function):
35 | @staticmethod
36 | def symbolic(g, query, key, value, attention_mask, hidden_size, name):
37 | # These become the operator attributes.
38 | kwargs = {
39 | "hidden_size_i": hidden_size,
40 | "name_s": name
41 | }
42 | from torch.onnx.symbolic_helper import _get_tensor_sizes
43 | out_sizes = _get_tensor_sizes(query)
44 | output_type = query.type().with_sizes(out_sizes)
45 | return g.op("LlmExporter::FusedAttention", query, key, value, attention_mask, **kwargs).setType(output_type)
46 |
47 | @staticmethod
48 | def forward(ctx, query, key, value, attention_mask, hidden_size, name):
49 | out_shape = list(query.shape)[:2] + [hidden_size]
50 | return query.new_zeros(out_shape)
51 |
52 | class FusedAttention(torch.nn.Module):
53 | def __init__(self, hidden_size, name):
54 | super(FusedAttention, self).__init__()
55 | self.hidden_size = hidden_size
56 | self.name = name
57 |
58 | def forward(self, query, key, value, attention_mask):
59 | return FusedAttentionOp.apply(query, key, value, attention_mask, self.hidden_size, self.name)
60 |
61 | class MoEOp(torch.autograd.Function):
62 | @staticmethod
63 | def symbolic(g, hidden_states, routing_weights, selected_experts, num_experts, top_k, layer_id):
64 | kwargs = {
65 | "num_experts_i": num_experts,
66 | "top_k_i": top_k,
67 | "layer_id_i": layer_id
68 | }
69 | from torch.onnx.symbolic_helper import _get_tensor_sizes
70 | out_sizes = _get_tensor_sizes(hidden_states)
71 | output_type = hidden_states.type().with_sizes(out_sizes)
72 | return g.op("LlmExporter::MoE", hidden_states, routing_weights, selected_experts, **kwargs).setType(output_type)
73 |
74 | @staticmethod
75 | def forward(ctx, hidden_states, routing_weights, selected_experts, num_experts, top_k, layer_id):
76 | return hidden_states
77 |
78 | class MoE(torch.nn.Module):
79 | def __init__(self, num_experts, top_k, layer_id):
80 | super(MoE, self).__init__()
81 | self.num_experts = num_experts
82 | self.top_k = top_k
83 | self.layer_id = layer_id
84 |
85 | def forward(self, hidden_states, routing_weights, selected_experts):
86 | return MoEOp.apply(hidden_states, routing_weights, selected_experts, self.num_experts, self.top_k, self.layer_id)
--------------------------------------------------------------------------------
/llmexport/utils/torch_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .hqq_quantizer import HQQQuantizer
3 |
4 | def repack_low_bits(x, iNeedBits, block_size):  # pack sub-byte quantized values (iNeedBits bits each) into uint8 bytes, block by block
5 | v = []
6 | device = x.device
7 | block_number = x.shape[0]
8 | count = block_size * iNeedBits // 8
9 | for i in range(0, count):
10 | v.append(torch.zeros([block_number, 1], dtype=torch.uint8, device=device))
11 | iOffset = 0
12 | cMask = (1 << iNeedBits) - 1
13 | index = 0
14 | for i in range(0, block_size):
15 | p0 = x[:, i:i+1]
16 | uShift = 8 - iNeedBits - (iOffset % 8)
17 | if uShift < 0:
18 | v[index+iOffset // 8] |= ((p0 & cMask) >> (0 - uShift))
19 | v[index+(iOffset // 8) + 1] |= ((p0 & cMask) << (8 + uShift))
20 | else:
21 | v[index+iOffset // 8] |= ((p0 & cMask) << uShift)
22 | iOffset += iNeedBits
23 | if iOffset % 8 == 0:
24 | index += iOffset // 8
25 | iOffset = 0
26 | return torch.cat(v, axis=1)
27 |
28 | def quant(weight, quant_bit, quant_block, symmetric, awq, hqq):
29 |
30 | try:
31 | if torch.cuda.is_available():
32 | weight = weight.cuda()
33 | if torch.backends.mps.is_available():
34 | weight = weight.to('mps')
35 |     except Exception:
36 |         print('Failed to move weight to GPU, falling back to CPU')
37 |
38 | oc, ic = weight.shape
39 | if quant_block == 0:
40 | block_size = ic
41 | else:
42 | block_size = quant_block
43 | while ic % block_size != 0:
44 | block_size /= 2
45 | block_size = int(block_size)
46 | block_num = ic // block_size
47 |
48 | offset = 1 << (quant_bit - 1)
49 | clip_max = offset - 1
50 |
51 | if hqq:
52 | hqq_quantizer = HQQQuantizer(weight, quant_bit, block_size, symmetric, weight.dtype, weight.device)
53 | hqq_quantizer.quant()
54 | if not symmetric:
55 | q_weight = hqq_quantizer.W_q.flatten().to(torch.uint8)
56 | scale = hqq_quantizer.meta['scale'].flatten()
57 | zeros = scale * offset - scale * hqq_quantizer.meta['zero'].flatten()
58 |
59 | alpha = torch.stack([zeros.flatten(), scale.flatten()], axis=-1).flatten()
60 | else:
61 | q_weight = (hqq_quantizer.W_q.flatten() + offset).to(torch.uint8)
62 | scale = hqq_quantizer.meta['scale'].flatten()
63 | alpha = scale.flatten()
64 | else:
65 | weight = weight.reshape(oc, block_num, block_size)
66 | if symmetric:
67 | clip_min = -clip_max
68 | abs_max, _ = torch.max(torch.abs(weight), axis=-1, keepdims=True)
69 | scale = abs_max / clip_max
70 | q_weight = torch.round(weight / scale)
71 | q_weight = (torch.clamp(q_weight.flatten(), clip_min, clip_max) + offset).to(torch.uint8)
72 | alpha = scale.flatten()
73 |
74 | else:
75 | clip_min = -offset
76 | max_val, _ = torch.max(weight, axis=-1, keepdims=True)
77 | min_val, _ = torch.min(weight, axis=-1, keepdims=True)
78 | scale = (max_val - min_val) / (clip_max - clip_min)
79 |
80 | if awq:
81 | q_weight = torch.round(weight / scale) - torch.round(min_val / scale) + clip_min
82 | zeros = (torch.round(min_val / scale) - clip_min) * scale
83 | else:
84 | q_weight = torch.round((weight - min_val) / scale) + clip_min
85 | zeros = min_val - scale * clip_min
86 | q_weight = (torch.clamp(q_weight.flatten(), clip_min, clip_max) + offset).to(torch.uint8)
87 | alpha = torch.stack([zeros.flatten(), scale.flatten()], axis=-1).flatten()
88 |
89 | if quant_bit < 8 and 8 % quant_bit == 0:
90 | group_size = 8 // quant_bit
91 | q_weight = q_weight.reshape(-1, group_size)
92 | multipliers = [2 ** (quant_bit * (group_size - 1 - i)) for i in range(group_size)]
93 | multipliers = torch.tensor(multipliers).to(q_weight.device)
94 | q_weight = (q_weight * multipliers).sum(axis=1).to(torch.uint8)
95 | elif quant_bit < 8:
96 | q_weight = repack_low_bits(q_weight.reshape((block_num * oc, block_size)), quant_bit, block_size)
97 |
98 |     if q_weight.device.type != 'cpu':
99 | return q_weight.cpu(), alpha.float().cpu()
100 | return q_weight, alpha.float()
101 |
102 | def onnx_export(model, inputs, onnx_model, input_names, output_names, dynamic_axes=None):
103 | torch.onnx.export(
104 | model, inputs,
105 | onnx_model,
106 | input_names=input_names,
107 | output_names=output_names,
108 | dynamic_axes=dynamic_axes,
109 | do_constant_folding=True,
110 | verbose=False,
111 | dynamo=False,
112 | opset_version=15)
--------------------------------------------------------------------------------
/llmexport/utils/lora.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from safetensors import safe_open
4 |
5 | class LoRA:
6 | def __init__(self, lora_path, scale = 4.0):
7 | self.lora_A = {}
8 | self.lora_B = {}
9 | self.lora_keys = set()
10 | self.scale = scale
11 | self.load(lora_path)
12 |
13 | def __str__(self):
14 | return str(self.lora_keys)
15 |
16 | def has_lora(self, op_name):
17 | if op_name[0] != '/':
18 | return False
19 | for key in self.lora_keys:
20 | if key in op_name:
21 | return True
22 | return False
23 |
24 | def get_lora(self, tag):
25 | lora_a, lora_b = self.lora_A[tag], self.lora_B[tag]
26 | return lora_a, lora_b
27 |
28 | def load(self, path):
29 | if os.path.isdir(path):
30 | base_dir = path
31 | config = json.load(open(os.path.join(base_dir, 'adapter_config.json'), 'rt'))
32 | lora_alpha = config['lora_alpha']
33 | r = config['r']
34 | self.scale = float(lora_alpha) / r
35 | path = os.path.join(base_dir, 'adapter_model.safetensors')
36 | with safe_open(path, framework="pt") as f:
37 | for k in f.keys():
38 | names = k.split('.')
39 | layer, key, name = names[4], names[6], names[7]
40 | tag = layer + key
41 | tensor = f.get_tensor(k).float()
42 | self.lora_keys.add(key)
43 | if 'lora_A' == name:
44 | self.lora_A[tag] = tensor
45 | else:
46 | self.lora_B[tag] = tensor * self.scale
47 |
48 | def build_conv(self, input_index, output_name, dims, weight):
49 | output_index = len(self.base_model['tensorName'])
50 | oc, ic = dims
51 | bias = [0.0 for i in range(oc)]
52 | op = {
53 | 'type': 'Convolution',
54 | 'name': output_name,
55 | 'inputIndexes': [input_index],
56 | 'outputIndexes': [ output_index ],
57 | 'main_type': 'Convolution2D',
58 | 'main': {
59 | 'common': {
60 | 'dilateX': 1, 'dilateY': 1, 'strideX': 1, 'strideY': 1,
61 | 'kernelX': 1, 'kernelY': 1, 'padX': 0, 'padY': 0, 'group': 1,
62 | 'outputCount': oc, 'relu': False, 'padMode': 'CAFFE',
63 | 'relu6': False, 'inputCount': ic, 'hasOutputShape': False
64 | },
65 | "weight": weight,
66 | "bias": bias
67 | },
68 | 'defaultDimentionFormat': 'NHWC'
69 | }
70 | self.new_ops.append(op)
71 | self.base_model['tensorName'].append(output_name)
72 | return output_index
73 |
74 | def build_binary(self, op_type, input_indexes, output_name):
75 | # 0: Add, 2: Mul
76 | output_index = len(self.base_model['tensorName'])
77 | op = {
78 | "type": "BinaryOp",
79 | "name": output_name,
80 | "inputIndexes": input_indexes,
81 | "outputIndexes": [ output_index ],
82 | "main_type": "BinaryOp",
83 | "main": { "opType": op_type, "T": "DT_FLOAT", "activationType": 0 },
84 | "defaultDimentionFormat": "NHWC"
85 | }
86 | self.new_ops.append(op)
87 | self.base_model['tensorName'].append(output_name)
88 | return output_index
89 |
90 | def replace_input(self, origin_idx, new_idx):
91 | for op in self.base_model['oplists']:
92 | if op['type'] == 'ConvertTensor' and origin_idx in op['inputIndexes']:
93 | op['inputIndexes'] = [new_idx]
94 |
95 | def apply_lora(self, op):
96 | names = op['name'].split('/')
97 | tag = names[1].split('.')[1] + names[3]
98 | lora_a, lora_b = self.get_lora(tag)
99 | input_index = op['inputIndexes'][0]
100 |         output_index = op['outputIndexes'][0]
101 |         # lora_B @ lora_A @ x -> lora_B @ (lora_A @ x)
102 |         a_out = self.build_conv(input_index, f'{tag}_A', list(lora_a.shape), lora_a.flatten().tolist())
103 |         b_out = self.build_conv(a_out, f'{tag}_B', list(lora_b.shape), lora_b.flatten().tolist())
104 |         n_out = self.build_binary(0, [output_index, b_out], f'{tag}_add')
105 |         self.replace_input(output_index, n_out)
106 |
107 | def apply(self, base_path, out):
108 | self.base_model = json.load(open(base_path, 'rt'))
109 | self.new_ops = []
110 | for i in range(len(self.base_model['oplists'])):
111 | op = self.base_model['oplists'][i]
112 | self.new_ops.append(op)
113 | if op['type'] == 'Convolution':
114 | if self.has_lora(op['name']):
115 | self.apply_lora(op)
116 | self.base_model['oplists'] = self.new_ops
117 | with open(out, 'w', encoding='utf-8') as file:
118 | json.dump(self.base_model, file, ensure_ascii=False, indent=4)
119 | return out
120 |
--------------------------------------------------------------------------------
/llmexport/utils/mnn_utils.py:
--------------------------------------------------------------------------------
1 | import numpy
2 | import json
3 | def write_quant_header(file, ic, oc, quant_bit):
4 | dim_num = file.write(b'\x02')
5 | shape_dtype = numpy.int16
6 | if oc > 65535 or ic > 65535:
7 | shape_dtype = numpy.int32
8 | dim_length = file.write(numpy.array([oc, ic]).astype(shape_dtype))
9 | offset = 1 << (quant_bit - 1)
10 | weight_map = [i for i in range(-offset, offset)]
11 | if len(weight_map) == 256:
12 | weight_map.insert(0, 0)
13 | else:
14 | weight_map.insert(0, len(weight_map))
15 | map_length = file.write(numpy.array(weight_map, dtype=numpy.int8))
16 | header_length = dim_num + dim_length + map_length
17 | return header_length, shape_dtype == numpy.int32
18 |
19 |
20 | def repack_low_bits(x, iNeedBits, block_size):
21 | v = []
22 | block_number = x.shape[0]
23 | count = block_size * iNeedBits // 8
24 | for i in range(0, count):
25 | v.append(numpy.zeros([block_number, 1]).astype(numpy.uint8))
26 | iOffset = 0
27 | cMask = (1 << iNeedBits) - 1
28 | index = 0
29 | for i in range(0, block_size):
30 | p0 = x[:, i:i+1]
31 | uShift = 8 - iNeedBits - (iOffset % 8)
32 | if uShift < 0:
33 | v[index+iOffset // 8] |= ((p0 & cMask) >> (0 - uShift))
34 | v[index+(iOffset // 8) + 1] |= ((p0 & cMask) << (8 + uShift))
35 | else:
36 | v[index+iOffset // 8] |= ((p0 & cMask) << uShift)
37 | iOffset += iNeedBits
38 | if iOffset % 8 == 0:
39 | index += iOffset // 8
40 | iOffset = 0
41 | return numpy.concatenate(v, axis=1)
42 |
43 | class Block:
44 | def __init__(self):
45 | self.conv = []
46 | self.layernorm = []
47 |
48 | def load_mnn(filename):
49 | mnn = {}
50 | with open(filename) as f:
51 | mnn = json.load(f)
52 | conv_indexes = []
53 | layernorm_indexes = []
54 | blockops = []
55 | for op in mnn["oplists"]:
56 | if op['type'] == 'LayerNorm':
57 | if 'external' in op['main']:
58 | del op['main']['external']
59 | if 'gamma' in op['main']:
60 | del op['main']['gamma']
61 | if 'beta' in op['main']:
62 | del op['main']['beta']
63 | layernorm_indexes.append(len(blockops))
64 | blockops.append(op)
65 | continue
66 | if op['type'] == 'Convolution':
67 | conv_indexes.append(len(blockops))
68 | blockops.append(op)
69 | block = None
70 | blockes = []
71 | conv_order = ['attn_q', 'attn_k', 'attn_v', 'attn_output', 'ffn_gate', 'ffn_up', 'ffn_down']
72 | blockNumber = len(conv_indexes) // len(conv_order)
73 | print("Layers number: ", blockNumber, ", conv number: ", len(conv_indexes), ", layernorm number:", len(layernorm_indexes))
74 | block_layernorms = len(layernorm_indexes) // blockNumber
75 | assert(len(layernorm_indexes) == block_layernorms * blockNumber + 1)
76 | for i in range(0, blockNumber):
77 | block = Block()
78 | sta_conv = len(conv_order) * i
79 | for j in range(0, len(conv_order)):
80 | index = conv_indexes[sta_conv + j]
81 | block.conv.append(blockops[index])
82 | sta_layernorm = block_layernorms * i
83 | for j in range(0, block_layernorms):
84 | index = layernorm_indexes[sta_layernorm + j]
85 | block.layernorm.append(blockops[index])
86 | blockes.append(block)
87 | # Last layernorm and lm
88 | output_norm = blockops[layernorm_indexes[len(layernorm_indexes)-1]]
89 | lm = blockops[conv_indexes[len(conv_indexes)-1]]
90 | lm['name'] = 'output'
91 | opmap = {}
92 | opmap['output_norm'] = output_norm
93 | convs = []
94 | for i in range(0, len(blockes)):
95 | _block = blockes[i]
96 | if len(_block.layernorm) == 2:
97 | opmap['blk.%d' %i + '.attn_norm']= _block.layernorm[0]
98 | opmap['blk.%d' %i + '.ffn_norm']= _block.layernorm[1]
99 | elif len(_block.layernorm) == 6:
100 | names = ['attn_norm', 'attn_q_norm', 'attn_k_norm', 'post_attention_norm', 'ffn_norm', 'post_ffw_norm']
101 | for j in range(0, len(_block.layernorm)):
102 | opmap['blk.%d' %i + '.%s' %names[j]]= _block.layernorm[j]
103 | else:
104 | assert(False)
105 | for j in range(0, 7):
106 | newname = 'blk.%d' %i + '.' + conv_order[j]
107 | _block.conv[j]['name'] = newname
108 | convs.append(_block.conv[j])
109 | convs.append(lm)
110 |
111 | return mnn, opmap, convs, blockes, block
112 |
113 |
114 | def write_quant_parameters(quant_bit, asymc, mnn_weight_file, ic, oc, weight_main, scalebias, mnn_weight_offset, need_scale_treat = True):
115 | conv = {}
116 | aMin = 0
117 | readType = 0
118 | if asymc:
119 | # Avoid aMin post treat for bias
120 | offset = -(1 << (quant_bit - 1))
121 | aMin = 1
122 | if need_scale_treat:
123 | scalebias = scalebias.reshape([-1, 2])
124 | bias = scalebias[:, 0:1]
125 | scale = scalebias[:, 1:2]
126 | bias = bias - offset * scale
127 | scalebias = numpy.concatenate([bias, scale], axis=1).astype(numpy.float32)
128 | readType = 1
129 | header_len, shape_int32 = write_quant_header(mnn_weight_file, ic, oc, quant_bit)
130 | weight_len = mnn_weight_file.write(weight_main.tobytes()) + header_len
131 | alpha_len = mnn_weight_file.write(scalebias.tobytes())
132 | conv['quanParameter'] = {
133 | "quantScale": 1.0, "scaleIn": 0.0, "scaleOut": 0.0,
134 | "useInt32": False, "has_scaleInt": False, "shapeInt32": shape_int32,
135 | "type": 1, "aMaxOrBits": quant_bit, "aMin": aMin, "readType": readType, "weightSize": 0
136 | }
137 | conv['external'] = [mnn_weight_offset, weight_len, alpha_len, oc * 4, 0]
138 | mnn_weight_offset += (weight_len + alpha_len)
139 | return conv, header_len, mnn_weight_offset
--------------------------------------------------------------------------------
/llmexport/utils/gptq.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import json
3 | import torch
4 | from safetensors import safe_open
5 |
6 | class GPTQWeight:
7 | def __init__(self, name):
8 | self.name = name
9 |
10 | def __repr__(self) -> str:
11 | if hasattr(self, 'qweight'):
12 | return f'{self.name}, {self.qweight.shape}, {self.scales.shape}'
13 | return 'None'
14 |
15 | def add(self, name, tensor):
16 | setattr(self, name, tensor)
17 |
18 | def weight(self, idx):
19 | shape = self.qweight.shape
20 | if len(shape) == 2:
21 | ic, oc = shape
22 | self.qweight = self.qweight.reshape(ic//16, 16, oc)
23 | return self.qweight[idx]
24 |
25 | def scale(self, idx):
26 | return self.scales[idx]
27 |
28 | class MNNWeight:
29 | def __init__(self, name, external, weight_elements):
30 | self.name = name
31 | self.external = external
32 | self.quant_bits = 4
33 | if round(weight_elements / external[1]) == 2:
34 | self.quant_bits = 4
35 | self.a_min = -8
36 | else:
37 | self.quant_bits = 8
38 | self.a_min = -128
39 | self.parse_name()
40 |
41 | def __repr__(self) -> str:
42 | return f'{self.layer_id}.{self.op_id}.{self.block_id}, {self.external}'
43 |
44 | def parse_name(self):
45 | parts = self.name.split('/')
46 | if len(parts) > 4:
47 | self.layer_id = parts[1].split('.')[1]
48 | self.op_id = parts[2] + '.' + parts[3]
49 | self.block_id = parts[-1].split('__')[-1]
50 | else:
51 | self.layer_id = -1
52 | self.op_id = parts[2]
53 | self.block_id = parts[-1].split('__')[-1]
54 |
55 | def key(self):
56 | if self.layer_id == -1: return self.op_id
57 | return f'{self.layer_id}.{self.op_id}'
58 | def offset(self): return self.external[0]
59 | def weight_size(self): return self.external[1]
60 | def scale_size(self): return self.external[2]
61 |
62 | class GPTQ:
63 | def __init__(self, gptq_path):
64 | self.load(gptq_path)
65 |
66 | def load(self, path):
67 | for tensor in glob.glob(f'{path}/*.safetensors'):
68 | self.load_safetensor(tensor)
69 |
70 | def prefix(self, name):
71 | splits = name.split('.')
72 | if 'lm_head' in splits[0] and len(splits) == 2:
73 | return splits[0], splits[1]
74 | if len(splits) < 5:
75 | return None, None
76 | pre = f'{splits[2]}.{splits[3]}.{splits[4]}'
77 | suf = splits[-1]
78 | return pre, suf
79 |
80 | def get(self, key : str):
81 | if key in self.weight_dict:
82 | return self.weight_dict[key]
83 | return None
84 |
85 | def load_safetensor(self, tensor):
86 | self.weight_dict = dict()
87 | with safe_open(tensor, framework="pt") as f:
88 | for k in f.keys():
89 | p, s = self.prefix(k)
90 | if p is None: continue
91 | if s not in ['qweight', 'scales']: continue
92 | if p not in self.weight_dict:
93 | self.weight_dict[p] = GPTQWeight(p)
94 | self.weight_dict[p].add(s, f.get_tensor(k))
95 |
96 | @staticmethod
97 | def weight_reorder(qweight, bits=4, group_size=128):
98 | oc = qweight.shape[-1]
99 | wf = torch.tensor(list(range(0, 32, bits)), dtype=torch.int32).unsqueeze(0)
100 | weight = torch.bitwise_right_shift(torch.unsqueeze(qweight, 1).expand(-1, 32 // bits, -1), wf.unsqueeze(-1)).to(torch.int16 if bits == 8 else torch.int8)
101 | torch.bitwise_and(weight, (2 ** bits) - 1, out=weight)
102 | weight = weight.reshape(-1, oc).transpose(1, 0)
103 | if bits == 8:
104 | weight = weight.to(torch.uint8)
105 | return weight
106 | weight = weight.reshape(-1, 2).to(torch.uint8)
107 | weight = weight[:, 0] * 16 + weight[:, 1]
108 | return weight
109 |
110 | def apply(self, graph_path, weight_path):
111 | # parse mnn graph
112 | mnn_weights = []
113 | mnn_graph = json.load(open(graph_path, 'rt'))
114 | for op in mnn_graph['oplists']:
115 | if op['type'] == 'Convolution':
116 | name = op['name']
117 | external = op['main']['external']
118 | weight_elements = op['main']['common']['outputCount'] * op['main']['common']['inputCount']
119 | mnn_weights.append(MNNWeight(name, external, weight_elements))
120 | # load mnn weight
121 | external_weight = open(weight_path, 'r+b')
122 | for mnn_weight in mnn_weights:
123 | gptq_weight = self.get(mnn_weight.key())
124 | if gptq_weight is None: continue
125 | # print(f'write {mnn_weight.key()} ... ', end='')
126 | weight = gptq_weight.qweight
127 | scale = gptq_weight.scales.float().transpose(1, 0)
128 | # write weight data
129 | weight = GPTQ.weight_reorder(weight, mnn_weight.quant_bits)
130 | weight_bytes = weight.numpy().tobytes()
131 | weight_size = mnn_weight.weight_size()
132 | header_len = weight_size - len(weight_bytes)
133 | assert(header_len > 0)
134 | external_weight.seek(mnn_weight.offset() + header_len)
135 | external_weight.write(weight_bytes)
136 | scale_size = mnn_weight.scale_size()
137 | is_asy = scale.numel() * scale.element_size() < scale_size
138 | # write scale data
139 | if is_asy:
140 | # zeros = mnn_weight.a_min * scale
141 | zeros = torch.zeros_like(scale)
142 | scale = torch.stack([zeros, scale], axis=-1)
143 | scale_bytes = scale.numpy().tobytes()
144 | assert(scale_size == len(scale_bytes))
145 | external_weight.write(scale_bytes)
146 | # print('Done!')
147 | external_weight.close()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LLM-Export
2 |
3 | [](https://badge.fury.io/py/llmexport)
4 | [](https://www.python.org/downloads/release/python-380/)
5 | [](https://opensource.org/licenses/MIT)
6 |
7 | [English](./README_en.md) | 中文
8 |
9 | 一个高效的大语言模型导出工具,能够将 LLM 模型导出为 ONNX 和 MNN 格式,支持量化优化和多模态模型。
10 |
11 | ## ✨ 主要特性
12 |
13 | - 🚀 **动态形状支持**:优化原始代码,支持动态输入形状
14 | - 🚀 **模型优化**:减少常量部分,提升推理性能
15 | - 🚀 **自动优化**:集成 [OnnxSlim](https://github.com/inisis/OnnxSlim) 优化 ONNX 模型,性能提升约 5% (感谢 [@inisis](https://github.com/inisis))
16 | - 🚀 **LoRA 支持**:支持 LoRA 权重的合并/分离导出
17 | - 🚀 **量化技术**:支持 AWQ、GPTQ、HQQ等多种量化方法
18 | - 🚀 **EAGLE 支持**:支持 EAGLE 推理加速技术
19 | - 🚀 **多模态支持**:支持文本、图像、音频等多模态模型
20 | - 🚀 **推理框架**:提供 [MNN](https://github.com/wangzhaode/mnn-llm) 和 [ONNX](https://github.com/wangzhaode/onnx-llm) 推理代码
21 |
22 | ## 📜 快速开始
23 |
24 | ### 安装
25 |
26 | ```bash
27 | # 从 PyPI 安装(推荐)
28 | pip install llmexport
29 |
30 | # 从 GitHub 安装最新版本
31 | pip install git+https://github.com/wangzhaode/llm-export@master
32 |
33 | # 本地开发安装
34 | git clone https://github.com/wangzhaode/llm-export
35 | cd llm-export
36 | pip install -e .
37 | ```
38 |
39 | ### 基本用法
40 |
41 | #### 1. 下载模型
42 |
43 | ```bash
44 | # 使用 Hugging Face CLI
45 | huggingface-cli download Qwen/Qwen2.5-1.5B-Instruct --local-dir Qwen2.5-1.5B-Instruct
46 |
47 | # 或使用 ModelScope(国内用户推荐)
48 | modelscope download Qwen/Qwen2.5-1.5B-Instruct --local_dir Qwen2.5-1.5B-Instruct
49 | ```
50 |
51 | #### 2. 模型测试
52 |
53 | ```bash
54 | # 文本对话测试
55 | llmexport --path Qwen2.5-1.5B-Instruct --test "你好,请介绍一下你自己"
56 |
57 | # 多模态测试(图像+文本)
58 | llmexport --path Qwen2-VL-2B-Instruct --test "<img>image_url</img>描述一下这张图片"
59 | ```
60 |
61 | #### 3. 模型导出
62 |
63 | ```bash
64 | # 导出为 ONNX 格式
65 | llmexport --path Qwen2.5-1.5B-Instruct --export onnx
66 |
67 | # 导出为 MNN 格式(默认 4bit 量化)
68 | llmexport --path Qwen2.5-1.5B-Instruct --export mnn
69 |
70 | # 自定义量化参数
71 | llmexport --path Qwen2.5-1.5B-Instruct --export mnn --quant_bit 8 --quant_block 128
72 |
73 | # 导出 EAGLE 模型
74 | llmexport --path Qwen2.5-1.5B-Instruct --export mnn --eagle_path path/to/eagle
75 | ```
76 |
77 | ## 🔧 高级功能
78 |
79 | ### 模型导出选项
80 |
81 | - **ONNX 导出**:使用 `--export onnx` 导出为 ONNX 格式
82 | - **MNN 导出**:使用 `--export mnn` 导出为 MNN 格式
83 | - **模型优化**:默认启用 OnnxSlim 优化,使用 `--onnx_slim` 显式启用
84 | - **EAGLE 导出**:使用 `--eagle_path` 导出 EAGLE 加速模型
85 |
86 | ### 量化配置
87 |
88 | - **量化位数**:`--quant_bit 4/8` (默认 4bit)
89 | - **量化块大小**:`--quant_block 64/128` (默认 64)
90 | - **LM Head 量化**:`--lm_quant_bit` 单独设置输出层量化
91 | - **对称量化**:`--sym` 启用对称量化(无零点)
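
下面给出一个组合上述量化参数的示例命令(参数取值仅作示意,请根据实际模型和精度需求调整):

```bash
llmexport --path Qwen2.5-1.5B-Instruct --export mnn --quant_bit 8 --quant_block 128 --lm_quant_bit 8 --sym
```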
92 |
93 | ### 量化算法支持
94 |
95 | - **AWQ 量化**:`--awq` 启用 AWQ 量化
96 | - **HQQ 量化**:`--hqq` 启用 HQQ 量化
97 | - **GPTQ 量化**:`--gptq_path` 加载 GPTQ 量化模型
98 | - **Smooth 量化**:`--smooth` 启用 Smooth 量化
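
下面是启用不同量化算法的示例命令(仅作示意,各算法开关通常单独使用;GPTQ 路径为占位符):

```bash
# 启用 HQQ 量化
llmexport --path Qwen2.5-1.5B-Instruct --export mnn --hqq

# 加载已有的 GPTQ 量化模型
llmexport --path Qwen2.5-1.5B-Instruct --export mnn --gptq_path path/to/gptq
```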
99 |
100 | ### LoRA 支持
101 |
102 | - **LoRA 合并**:`--lora_path` 指定 LoRA 权重路径
103 | - **LoRA 分离**:`--lora_split` 分离导出 LoRA 权重
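
一个 LoRA 导出的示例命令(LoRA 路径为占位符):

```bash
# 合并 LoRA 权重导出;追加 --lora_split 可分离导出 LoRA 权重
llmexport --path Qwen2.5-1.5B-Instruct --export mnn --lora_path path/to/lora
```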
104 |
105 | ### 多模态支持
106 |
107 | - **视觉量化**:`--visual_quant_bit`、`--visual_quant_block` 设置视觉模块量化
108 | - **视觉对称**:`--visual_sym` 视觉模块对称量化
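
一个设置视觉模块量化的示例命令(参数取值仅作示意):

```bash
llmexport --path Qwen2-VL-2B-Instruct --export mnn --visual_quant_bit 8 --visual_quant_block 128 --visual_sym
```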
109 |
110 | ### 其他选项
111 |
112 | - **详细输出**:`--verbose` 显示详细日志
113 | - **性能评估**:`--ppl` 获取所有 token 的 logits
114 | - **自定义输出**:`--dst_path` 指定输出目录(默认 `./model`)
115 | - **EAGLE 支持**:`--eagle_path` 指定 EAGLE 模型路径
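
一个组合上述选项的示例命令(输出目录仅作示意):

```bash
llmexport --path Qwen2.5-1.5B-Instruct --export mnn --dst_path ./qwen2.5-mnn --verbose
```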
116 |
117 | ## 📎 命令行参数
118 |
119 | ### 基本参数
120 |
121 | | 参数 | 类型 | 说明 |
122 | |------|------|------|
123 | | `--path` | 必需 | 模型路径,支持本地目录或 Hugging Face 模型 ID |
124 | | `--export` | 可选 | 导出格式:`onnx` 或 `mnn` |
125 | | `--test` | 可选 | 测试查询字符串 |
126 | | `--dst_path` | 可选 | 输出目录(默认 `./model`) |
127 | | `--verbose` | 开关 | 显示详细日志 |
128 |
129 | ### 量化参数
130 |
131 | | 参数 | 默认值 | 说明 |
132 | |------|--------|------|
133 | | `--quant_bit` | 4 | 量化位数(4 或 8) |
134 | | `--quant_block` | 64 | 量化块大小(0 表示通道级量化) |
135 | | `--lm_quant_bit` | 同 `quant_bit` | LM Head 层量化位数 |
136 | | `--visual_quant_bit` | 模型相关 | 视觉模块量化位数 |
137 | | `--visual_quant_block` | 模型相关 | 视觉模块量化块大小 |
138 |
139 | ### 量化算法
140 |
141 | | 参数 | 说明 |
142 | |------|------|
143 | | `--awq` | 启用 AWQ 量化 |
144 | | `--hqq` | 启用 HQQ 量化 |
145 | | `--smooth` | 启用 Smooth 量化 |
146 | | `--sym` | 启用对称量化(无零点) |
147 | | `--visual_sym` | 视觉模块对称量化 |
148 |
149 | ### LoRA 支持
150 |
151 | | 参数 | 说明 |
152 | |------|------|
153 | | `--lora_path` | LoRA 权重路径 |
154 | | `--lora_split` | 分离导出 LoRA 权重 |
155 |
156 | ### EAGLE 支持
157 |
158 | | 参数 | 说明 |
159 | |------|------|
160 | | `--eagle_path` | EAGLE 模型路径 |
161 |
162 | ### 其他选项
163 |
164 | | 参数 | 说明 |
165 | |------|------|
166 | | `--tokenizer_path` | 分词器路径(默认使用 `--path`) |
167 | | `--gptq_path` | GPTQ 量化模型路径 |
168 | | `--mnnconvert` | 本地 MNNConvert 路径 |
169 | | `--onnx_slim` | 启用 ONNX-Slim 优化 |
170 | | `--ppl` | 获取所有 token 的 logits |
171 | | `--seperate_embed` | 分离嵌入层以避免量化 |
172 | | `--calib_data` | 校准数据路径 |
173 |
174 | ## 📋 支持模型
175 |
176 | 目前支持以下模型类型:
177 |
178 | ### 文本模型
179 | - **Qwen 系列**:Qwen3、Qwen2.5、Qwen2、Qwen1.5、Qwen-VL 等
180 | - **LLaMA 系列**:Llama-3.2、Llama-3、Llama-2 等
181 | - **ChatGLM 系列**:ChatGLM4、ChatGLM3、ChatGLM2 等
182 | - **Baichuan 系列**:Baichuan2-7B-Chat 等
183 | - **Yi 系列**:Yi-6B-Chat 等
184 | - **其他**:InternLM、DeepSeek、Phi、Gemma、TinyLlama、SmolLM 等
185 |
186 | ### 多模态模型
187 | - **视觉模型**:Qwen2-VL、Qwen2.5-VL、Qwen3-VL、Llama-3.2-Vision、InternVL 等
188 | - **音频模型**:Qwen2-Audio、Qwen2.5-Omni 等
189 |
190 | ### 嵌入模型
191 | - **文本嵌入**:bge-large-zh、gte-multilingual 等
192 |
193 | ## 💾 模型下载
194 |
195 | 我们提供了已经优化的模型下载:
196 |
197 | - **Hugging Face**:[taobao-mnn](https://huggingface.co/taobao-mnn)
198 | - **ModelScope**:[MNN](https://modelscope.cn/organization/MNN)
199 |
200 | 部分热门模型:
201 |
202 | | 模型 | Hugging Face | ModelScope |
203 | |------|-------------|------------|
204 | | DeepSeek-R1-1.5B-Qwen | [Q4_1](https://huggingface.co/taobao-mnn/DeepSeek-R1-1.5B-Qwen-MNN) | [Q4_1](https://modelscope.cn/models/MNN/DeepSeek-R1-1.5B-Qwen-MNN) |
205 | | Qwen2.5-0.5B-Instruct | [Q4_1](https://huggingface.co/taobao-mnn/Qwen2.5-0.5B-Instruct-MNN) | [Q4_1](https://modelscope.cn/models/MNN/Qwen2.5-0.5B-Instruct-MNN) |
206 | | Qwen2.5-1.5B-Instruct | [Q4_1](https://huggingface.co/taobao-mnn/Qwen2.5-1.5B-Instruct-MNN) | [Q4_1](https://modelscope.cn/models/MNN/Qwen2.5-1.5B-Instruct-MNN) |
207 | | GPT-OSS-20B | [Q4_1](https://huggingface.co/taobao-mnn/gpt-oss-20b-MNN) | [Q4_1](https://modelscope.cn/models/MNN/gpt-oss-20b-MNN) |
208 | | Qwen3-4B-Instruct-2507 | [Q4_1](https://huggingface.co/taobao-mnn/Qwen3-4B-Instruct-2507-MNN) | [Q4_1](https://modelscope.cn/models/MNN/Qwen3-4B-Instruct-2507-MNN) |
209 |
210 | 更多模型请查看上述 Hugging Face / ModelScope 组织主页中的完整列表。
211 |
212 | ## 🔗 相关项目
213 |
214 | - **MNN 推理**:[mnn-llm](https://github.com/wangzhaode/mnn-llm) - MNN 框架的 LLM 推理库
215 | - **ONNX 推理**:[onnx-llm](https://github.com/wangzhaode/onnx-llm)、[OnnxLLM](https://github.com/inisis/OnnxLLM) - ONNX 格式推理库
216 | - **模型优化**:[OnnxSlim](https://github.com/inisis/OnnxSlim) - ONNX 模型优化工具
217 |
218 | ## 📄 许可证
219 |
220 | 本项目采用 [Apache 2.0 许可证](./LICENSE)。
--------------------------------------------------------------------------------
/llmexport/utils/hqq_quantizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class HQQQuantizer:
6 | def __init__(self,
7 | weight,
8 | bit,
9 | group_size,
10 | sym=False,
11 | compute_dtype: torch.dtype = torch.float32,
12 | device: torch.device = torch.device("cpu"),
13 | quant_config: dict = None):
14 | self.weight = weight
15 | self.bit = bit
16 | self.group_size = group_size
17 | self.sym = sym
18 | self.compute_dtype = compute_dtype
19 | self.device = device
20 |
21 | def quant(self):
22 | self._quantize()
23 |
24 | @torch.inference_mode()
25 | def _quantize(
26 | self,
27 | channel_wise: bool = True,
28 | axis: int = 1,
29 | ) -> tuple:
30 |
31 | if self.group_size is not None:
32 | assert self.weight.numel() % self.group_size == 0, (
33 |                 "the total number of weight elements should be divisible by group_size. shape: "
34 | + str(self.weight.shape)
35 | + ", group_size: "
36 | + str(self.group_size)
37 | )
38 |
39 | W = self.weight.to(self.compute_dtype).float()
40 | shape = W.shape
41 |
42 | # Reshape for grouping
43 | if (self.group_size is not None) and channel_wise:
44 | W = (
45 | W.reshape([-1, self.group_size])
46 | if (axis == 1)
47 | else W.reshape([self.group_size, -1])
48 | )
49 |
50 | # Get min/max values
51 | if not channel_wise:
52 | _min, _max = W.min(), W.max()
53 | else:
54 | _min = W.min(axis=axis, keepdim=True)[0]
55 | _max = W.max(axis=axis, keepdim=True)[0]
56 |
57 | if self.sym:
58 | max_v = 2**(self.bit-1) - 1 # 4bit: 7
59 | min_v = -2**(self.bit-1) # 4bit: -8
60 | min_max = [min_v, max_v] # [-8, 7]
61 |
62 | max_abs = torch.max(torch.abs(_min), torch.abs(_max))
63 | scale = max_v / max_abs
64 | scale = torch.where(max_abs <= 1e-4, torch.full_like(scale, 1.0), scale)
65 | scale = scale.clamp(max=2e4)
66 | zero = None
67 | else:
68 | max_v = round(2**self.bit - 1) # 4bit: 15
69 | min_v = 0 # 4bit: 0
70 | min_max = [min_v, max_v] # [0, 15]
71 |
72 | denom = (_max - _min)
73 | scale = (max_v / denom)
74 | scale = torch.where(denom.abs() <= 1e-4, torch.full_like(scale, 1.0), scale)
75 | scale = scale.clamp(max=2e4)
76 | zero = -_min * scale
77 | zero = torch.round(zero)
78 |
79 | W_q, scale, zero = self._optimize_weights(
80 | W,
81 | scale,
82 | zero,
83 | min_max=min_max,
84 | axis=axis,
85 | )
86 | #W_q = (W * scale).round_().clamp_(min_max[0], min_max[1])
87 | # cleanup
88 | del W, _min, _max
89 |
90 | # Store meta-data (we invert the scale for dequantization)
91 | scale = 1.0 / scale
92 | meta = {
93 | "nbits": self.bit,
94 | "group_size": self.group_size,
95 | "shape": shape,
96 | "scale": scale,
97 | "zero": zero,
98 | "axis": axis,
99 | }
100 |
101 | W_q = W_q.to(self.weight.dtype)
102 |
103 |         if self.device.type == 'cuda':
104 |             torch.cuda.empty_cache()
105 |         elif self.device.type == 'mps':
106 |             torch.mps.empty_cache()
107 |
108 | self.W_q = W_q
109 | self.meta = meta
110 |
111 | @torch.inference_mode()
112 | def _optimize_weights(
113 | self,
114 | W: torch.Tensor,
115 | scale: torch.Tensor,
116 | zero: torch.Tensor,
117 | min_max: list,
118 | axis: int = 0,
119 | opt_params: dict = {"lp_norm": 0.7, "beta": 1e1, "kappa": 1.01, "iters": 20},
120 | verbose: bool = False,
121 | ) -> tuple:
122 | lp_norm, beta, kappa, iters = (
123 | opt_params["lp_norm"],
124 | opt_params["beta"],
125 | opt_params["kappa"],
126 | opt_params["iters"],
127 | )
128 |
129 | dtype = torch.float32
130 | W_f = W.to(dtype=dtype, device=self.device)
131 | scale = scale.to(dtype=dtype, device=self.device)
132 | if not self.sym:
133 | zero = zero.to(dtype=dtype, device=self.device)
134 |
135 | best_error = torch.tensor(torch.inf, dtype=torch.float32, device=self.device)
136 | W_q = torch.empty_like(W_f)
137 | W_r = torch.empty_like(W_f)
138 | W_e = torch.empty_like(W_f)
139 | W_prime = torch.empty_like(W_f) if self.sym else None
140 | for i in range(iters):
141 | if not self.sym:
142 | self._optimize_weights_proximal_legacy_step(W_f, scale, zero, min_max, beta, lp_norm, axis, W_q, W_r, W_e)
143 | else:
144 | self._optimize_weights_proximal_scale_only(W_f, scale, min_max, beta, lp_norm, axis, W_q, W_r, W_e, W_prime)
145 | current_error = torch.abs(W_f - W_r).mean().float()
146 | if verbose:
147 | print(i, current_error.cpu())
148 |
149 | if current_error < best_error:
150 | best_error = current_error
151 | else:
152 | break
153 |
154 | scale = scale.to(W.device)
155 | if not self.sym:
156 | zero = zero.to(W.device)
157 | del W_f, W_q, W_r
158 | if self.device.type == 'cuda':
159 | torch.cuda.empty_cache()
160 | elif self.device.type == 'mps':
161 | torch.mps.empty_cache()
162 | if not self.sym:
163 | W_q = torch.round(W * scale + zero).clamp_(min_max[0], min_max[1])
164 | else:
165 | W_q = torch.round(W * scale).clamp_(min_max[0], min_max[1])
166 | return W_q, scale, zero
167 |
168 | @torch.inference_mode()
169 | def _optimize_weights_proximal_legacy_step(self, W_f, scale, zero, min_max, beta, lp_norm, axis, W_q, W_r, W_e):
170 | torch.mul(W_f, scale, out=W_q)
171 | torch.add(W_q, zero, out=W_q)
172 | torch.round(W_q, out=W_q).clamp_(min_max[0], min_max[1])
173 | torch.sub(W_q, zero, out=W_r)
174 | torch.div(W_r, scale, out=W_r)
175 | torch.sub(W_f, W_r, out=W_e)
176 | self._shrink_lp_op(W_e, beta, lp_norm, out=W_e)
177 | torch.sub(W_f, W_e, out=W_r)
178 | torch.mul(W_r, scale, out=W_r)
179 | torch.sub(W_q, W_r, out=W_r)
180 | torch.mean(W_r, axis=axis, keepdim=True, out=zero)
181 |
182 | @torch.inference_mode()
183 | def _optimize_weights_proximal_scale_only(self, W_f, scale, min_max, beta, lp_norm, axis, W_q, W_r, W_e, W_prime, eps=1e-8):
184 | torch.mul(W_f, scale, out=W_q)
185 | torch.round(W_q, out=W_q).clamp_(min_max[0], min_max[1])
186 | torch.div(W_q, scale, out=W_r)
187 | torch.sub(W_f, W_r, out=W_e)
188 | self._shrink_lp_op(W_e, beta, lp_norm, out=W_e)
189 | torch.sub(W_f, W_e, out=W_prime)
190 | w_prime_dot_w_q = torch.sum(W_prime * W_q, axis=axis, keepdim=True)
191 | w_q_norm_sq = torch.sum(W_q**2, axis=axis, keepdim=True)
192 | torch.add(w_prime_dot_w_q, eps, out=w_prime_dot_w_q)
193 | torch.div(w_q_norm_sq, w_prime_dot_w_q, out=scale)
194 |
195 | # Shrinking operator
196 | @torch.inference_mode()
197 | def _shrink_lp_op(self, x: torch.Tensor, beta: float, lp_norm: float, out: torch.Tensor) -> torch.Tensor:
198 | if lp_norm == 1:
199 | #torch.sign(x) * torch.nn.functional.relu(torch.abs(x) - 1.0 / beta)
200 | torch.abs(x, out=out)
201 | out.sub_(1.0 / beta).clamp_min_(0.0)
202 | out.mul_(torch.sign(x))
203 | return out
204 | else:
205 | #torch.sign(x) * torch.nn.functional.relu(torch.abs(x) - (1.0 / beta) * torch.pow(torch.abs(x), lp_norm - 1))
206 | torch.abs(x, out=out)
207 | out.sub_((1.0 / beta) * out.pow(lp_norm - 1)).clamp_min_(0.0)
208 | out.mul_(torch.sign(x))
209 | return out
--------------------------------------------------------------------------------
/llmexport/utils/talker.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | torch.set_printoptions(precision=4, sci_mode=False)
4 | from .model_mapper import ModelMapper
5 | from .transformers import Rotary, Embedding, Decoder
6 | from .token2wav import Qwen2_5OmniToken2Wav
7 | from .spinner import spinner_run
8 | from .torch_utils import onnx_export
9 |
10 | class Talker(torch.nn.Module):
11 | def __init__(self, talker, token2wav, base):
12 | super().__init__()
13 | self.model_type = base.model_type
14 | self.thinker_embed = base.embed
15 | self.args = base.args
16 | self.talker = talker.float()
17 | self.token2wav = Qwen2_5OmniToken2Wav(token2wav, base)
18 | self.config = base.config
19 | self.hidden_size = base.hidden_size
20 | self.llm_config = base.llm_config
21 | self.rope_ratio = 1.0
22 | self.quant_bit = 4
23 | if self.hidden_size <= 2048:
24 | # Qwen2.5-Omni-3B using 8 bit quantization
25 | self.quant_bit = 8
26 | self.init_config()
27 | self.load()
28 |
29 | @staticmethod
30 | def get_talker(model_type):
31 | audio_models = {
32 | 'qwen2_5_omni': Qwen2_5OmniTalker,
33 | }
34 | if model_type in audio_models:
35 | return audio_models[model_type]
36 | return None
37 |
38 | def init_config(self):
39 | self.llm_config['has_talker'] = True
40 |
41 | def load(self):
42 | raise NotImplementedError
43 |
44 | def add_token_embeds(self, thinker_embeds):
45 | raise NotImplementedError
46 |
47 | def add_hidden_states(self, thinker_hidden_states):
48 | raise NotImplementedError
49 |
50 | def add_generate_ids(self, token_id):
51 | raise NotImplementedError
52 |
53 | def forward(self, inputs_embeds, attention_mask, position_ids, past_key_values = None):
54 | raise NotImplementedError
55 |
56 | def export(self, onnx_path):
57 | raise NotImplementedError
58 |
59 | def export_embed(self):
60 | import ctypes
61 | tensor_data = self.embed.weight.data.bfloat16()
62 | data_ptr = tensor_data.untyped_storage().data_ptr()
63 | buffer = (ctypes.c_byte * (tensor_data.numel() * 2)).from_address(data_ptr)
64 | embedding_file = f'{self.args.dst_path}/talker_embeddings_bf16.bin'
65 | with open(embedding_file, 'wb') as f:
66 | f.write(buffer)
67 | return embedding_file
68 |
69 | class OmniRotary(Rotary):
70 | def __init__(self, model):
71 | super().__init__(model)
72 | self.mrope_section = model.mrope_section
73 | self.theta_sections = self.theta.unsqueeze(0).split(self.mrope_section, dim=-1)
74 |
75 | def forward(self, position_ids):
76 | position_ids = position_ids.float().unsqueeze(-1)
77 | idx_theta = torch.concat([
78 | position_ids[0] * self.theta_sections[0],
79 | position_ids[1] * self.theta_sections[1],
80 | position_ids[2] * self.theta_sections[2]
81 | ], dim=-1)
82 | rotary_pos_emb = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)])
83 | rotary_pos_emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
84 | rotary_pos_emb = rotary_pos_emb.unsqueeze(3)
85 | return rotary_pos_emb
86 |
87 | class Qwen2_5OmniTalker(Talker):
88 | def __init__(self, talker, token2wav, base):
89 | super().__init__(talker, token2wav, base)
90 | self.input_hidden_size = base.hidden_size
91 | self.seq_len = 0
92 | self.token_len = 0
93 | self.talker_embeds = []
94 |
95 | def load(self):
96 | # load talker model
97 | self.model_map = {
98 | 'config': {
99 | 'hidden_size': 'hidden_size',
100 | 'head_dim': 'head_dim',
101 | 'num_attention_heads': 'num_attention_heads',
102 | 'num_hidden_layers': 'num_hidden_layers',
103 | 'num_key_value_heads': 'num_key_value_heads',
104 | 'rope_theta': 'rope_theta',
105 | 'rope_scaling': 'rope_scaling'
106 | },
107 | 'decoder': {
108 | 'self_attn': 'self_attn',
109 | 'mlp': 'mlp',
110 | 'input_layernorm': 'input_layernorm',
111 | 'post_attention_layernorm': 'post_attention_layernorm'
112 | },
113 | 'attention': {
114 | 'q_proj': 'q_proj',
115 | 'k_proj': 'k_proj',
116 | 'v_proj': 'v_proj',
117 | 'o_proj': 'o_proj'
118 | }
119 | }
120 | ModelMapper.do_map(self, self.talker.config, self.model_map['config'])
121 | self.mrope_section = self.rope_scaling['mrope_section']
122 | self.embed = self.talker.model.embed_tokens
123 | self.rotary = OmniRotary(self)
124 | # self.rotary = Rotary(self)
125 | self.blocks = []
126 | for block in self.talker.model.layers:
127 | layer_id = len(self.blocks)
128 | self.blocks.append(Decoder(block, layer_id, self))
129 |
130 | def forward(self, inputs_embeds, attention_mask, position_ids, past_key_values = None):
131 | hidden_states = self.talker.thinker_to_talker_proj(inputs_embeds)
132 | rotary_pos_emb = self.rotary(position_ids)
133 | presents = [None for i in range(self.num_hidden_layers)]
134 |
135 | for i in range(self.num_hidden_layers):
136 | hidden_states, kv = self.blocks[i](hidden_states, rotary_pos_emb, attention_mask, past_key_values[i])
137 | presents[i] = kv
138 |
139 | hidden_states = hidden_states[:, -1, :]
140 | hidden_states = self.talker.model.norm(hidden_states)
141 | logits = self.talker.codec_head(hidden_states)
142 | presents = torch.stack(presents)
143 | return logits, presents
144 |
145 | def get_position_ids(self) -> torch.Tensor:
146 | if self.token_len:
147 | position_ids = torch.tensor([[self.seq_len - 1]], dtype=torch.int)
148 | else:
149 | position_ids = torch.arange(self.seq_len, dtype=torch.int).unsqueeze(0)
150 | position_ids = torch.stack([position_ids] * 3)
151 | return position_ids
152 |
153 | def get_attention_mask(self) -> torch.Tensor:
154 | if self.token_len:
155 | return torch.zeros([1, 1, 1, self.seq_len], dtype=torch.float32)
156 | return (1 - torch.tril(torch.ones([1, 1, self.seq_len, self.seq_len]))) * torch.finfo(torch.float32).min
157 |
158 | def generate(self):
159 | talker_text_bos_token = 151872
160 | talker_inputs_embeds = torch.cat(
161 | [
162 | self.talker_embeds[0],
163 | self.thinker_embed(torch.tensor([[talker_text_bos_token]], dtype=torch.long)) + \
164 | self.embed(torch.LongTensor([self.talker.codec_pad_token])),
165 | self.talker_embeds[1] + self.embed(torch.LongTensor([self.talker.codec_bos_token])),
166 | ],
167 | dim=1,
168 | )
169 | thinker_reply_part = torch.cat(self.talker_embeds[2:], dim=1)
170 | thinker_reply_part = torch.cat(
171 | [
172 | thinker_reply_part,
173 | self.thinker_embed(
174 | torch.tensor([[self.talker.text_eos_token]], dtype=torch.long)
175 | ),
176 | self.thinker_embed(
177 | torch.tensor([[self.talker.text_pad_token]], dtype=torch.long)
178 | ),
179 | ],
180 | dim=1,
181 | )
182 |
183 | _, self.seq_len, _ = talker_inputs_embeds.shape
184 | _, reply_len, _ = thinker_reply_part.shape
185 | past_key_values = [None for i in range(self.num_hidden_layers)]
186 |
187 | inputs_embeds = talker_inputs_embeds.float()
188 | self.token_len = 0
189 | self.stop_ids = [8292, 8294]
190 | token_id = None
191 | tokens = []
192 | while self.token_len < 256:
193 | attention_mask = self.get_attention_mask()
194 | position_ids = self.get_position_ids()
195 | if self.token_len > 0:
196 | inputs_embeds = self.embed(token_id)
197 | if self.token_len <= reply_len:
198 | inputs_embeds = inputs_embeds + thinker_reply_part[:, self.token_len - 1, :]
199 | else:
200 | inputs_embeds = inputs_embeds + thinker_reply_part[:, -1, :]
201 | logits, past_key_values = self.forward(inputs_embeds=inputs_embeds,
202 | attention_mask=attention_mask,
203 | position_ids=position_ids,
204 | past_key_values=past_key_values)
205 | token_id = torch.argmax(logits)
206 | self.token_len += 1
207 | self.seq_len += 1
208 | tokens.append(int(token_id))
209 | if int(token_id) in self.stop_ids:
210 | break
211 | talker_generate_codes = torch.tensor(tokens, dtype=torch.long).unsqueeze(0)
212 | # 3. Generate wavs from code
213 | wav = self.token2wav.generate(talker_generate_codes,)
214 | import soundfile as sf
215 | sf.write(
216 | "output.wav",
217 | wav.reshape(-1).detach().cpu().numpy(),
218 | samplerate=24000,
219 | )
220 |
221 | def add_talker_embeds(self, talker_embed):
222 | self.talker_embeds.append(talker_embed)
223 |
224 | @spinner_run(f'export talker to ')
225 | def export(self, onnx_path):
226 | self.export_embed()
227 | self.seq_len = 3
228 | self.token_len = 0
229 | inputs_embeds = torch.randn([1, self.seq_len, self.input_hidden_size])
230 |         position_ids = self.get_position_ids()
231 | attention_mask = self.get_attention_mask()
232 | past_key_values = torch.zeros([self.num_hidden_layers, 2, 1, 0, self.num_key_value_heads, self.head_dim])
233 | talker_onnx = f'{onnx_path}/talker.onnx'
234 |         onnx_export(self, (inputs_embeds, attention_mask, position_ids, past_key_values),
235 | talker_onnx,
236 | input_names=['inputs_embeds', 'attention_mask', 'position_ids', 'past_key_values'],
237 | output_names=['logits'],
238 | dynamic_axes={
239 | "inputs_embeds": { 1: "size" },
240 | "attention_mask": { 2: "size", 3: "size" },
241 | "position_ids": { 2: "size" },
242 | "past_key_values": { 3: "size" }
243 | })
244 | return talker_onnx
--------------------------------------------------------------------------------
/README_en.md:
--------------------------------------------------------------------------------
1 | # LLM-Export
2 |
3 | [](https://badge.fury.io/py/llmexport)
4 | [](https://www.python.org/downloads/release/python-380/)
5 | [](https://opensource.org/licenses/MIT)
6 |
7 | English | [中文](./README.md)
8 |
9 | An efficient Large Language Model export tool that converts LLM models to ONNX and MNN formats, supporting quantization optimization and multimodal models.
10 |
11 | ## ✨ Key Features
12 |
13 | - 🚀 **Dynamic Shape Support**: Optimized original code with dynamic input shape support
14 | - 🚀 **Model Optimization**: Reduced constant parts for improved inference performance
15 | - 🚀 **Automatic Optimization**: Integrated [OnnxSlim](https://github.com/inisis/OnnxSlim) for ONNX model optimization, ~5% performance improvement (Thanks [@inisis](https://github.com/inisis))
16 | - 🚀 **LoRA Support**: Support for LoRA weight merging/splitting export
17 | - 🚀 **Quantization Methods**: Support for AWQ, GPTQ, HQQ, and other quantization methods
18 | - 🚀 **Multimodal Support**: Support for text, image, audio, and other multimodal models
19 | - 🚀 **Inference Frameworks**: Provides [MNN](https://github.com/wangzhaode/mnn-llm) and [ONNX](https://github.com/wangzhaode/onnx-llm) inference code
20 |
21 | ## 📖 Quick Start
22 |
23 | ### Installation
24 |
25 | ```bash
26 | # Install from PyPI (Recommended)
27 | pip install llmexport
28 |
29 | # Install latest version from GitHub
30 | pip install git+https://github.com/wangzhaode/llm-export@master
31 |
32 | # Local development installation
33 | git clone https://github.com/wangzhaode/llm-export
34 | cd llm-export
35 | pip install -e .
36 | ```
37 |
38 | ### Basic Usage
39 |
40 | #### 1. Download Model
41 |
42 | ```bash
43 | # Using Hugging Face CLI
44 | huggingface-cli download Qwen/Qwen2.5-1.5B-Instruct --local-dir Qwen2.5-1.5B-Instruct
45 |
46 | # Or using ModelScope (Recommended for users in China)
47 | modelscope download Qwen/Qwen2.5-1.5B-Instruct --local_dir Qwen2.5-1.5B-Instruct
48 | ```
49 |
50 | #### 2. Model Testing
51 |
52 | ```bash
53 | # Text conversation testing
54 | llmexport --path Qwen2.5-1.5B-Instruct --test "Hello, please introduce yourself"
55 |
56 | # Multimodal testing (Image + Text)
57 | llmexport --path Qwen2-VL-2B-Instruct --test "<img>image_url</img>Describe this image"
58 | ```
59 |
60 | #### 3. Model Export
61 |
62 | ```bash
63 | # Export to ONNX format
64 | llmexport --path Qwen2.5-1.5B-Instruct --export onnx
65 |
66 | # Export to MNN format (Default 4bit quantization)
67 | llmexport --path Qwen2.5-1.5B-Instruct --export mnn
68 |
69 | # Custom quantization parameters
70 | llmexport --path Qwen2.5-1.5B-Instruct --export mnn --quant_bit 8 --quant_block 128
71 | ```
72 |
73 | ## 🔧 Advanced Features
74 |
75 | ### Model Export Options
76 |
77 | - **ONNX Export**: Use `--export onnx` to export to ONNX format
78 | - **MNN Export**: Use `--export mnn` to export to MNN format
79 | - **Model Optimization**: OnnxSlim optimization enabled by default, use `--onnx_slim` to explicitly enable
80 |
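For example, a plain ONNX export with the OnnxSlim pass requested explicitly might look like this (a minimal sketch; the model path is the one downloaded above):

```bash
# Export to ONNX and explicitly request OnnxSlim optimization
llmexport --path Qwen2.5-1.5B-Instruct --export onnx --onnx_slim
```
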
81 | ### Quantization Configuration
82 |
83 | - **Quantization Bits**: `--quant_bit 4/8` (Default 4bit)
84 | - **Quantization Block Size**: `--quant_block 64/128` (Default 64)
85 | - **LM Head Quantization**: `--lm_quant_bit` separate setting for output layer quantization
86 | - **Symmetric Quantization**: `--sym` enable symmetric quantization (no zero point)
87 |
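As an illustrative combination of these options (not an official recipe), an 8-bit LM head on top of 4-bit weights could be requested like this:

```bash
# 4-bit weights, block size 128, 8-bit lm_head, symmetric quantization
llmexport --path Qwen2.5-1.5B-Instruct --export mnn \
    --quant_bit 4 --quant_block 128 --lm_quant_bit 8 --sym
```
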
88 | ### Quantization Algorithm Support
89 |
90 | - **AWQ Quantization**: `--awq` enable AWQ quantization
91 | - **HQQ Quantization**: `--hqq` enable HQQ quantization
92 | - **GPTQ Quantization**: `--gptq_path` load GPTQ quantized model
93 | - **Smooth Quantization**: `--smooth` enable Smooth quantization
94 |
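For example, the algorithm flags are passed alongside the MNN export (a sketch; pick one algorithm per run):

```bash
# AWQ quantization
llmexport --path Qwen2.5-1.5B-Instruct --export mnn --awq

# HQQ quantization
llmexport --path Qwen2.5-1.5B-Instruct --export mnn --hqq
```
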
95 | ### LoRA Support
96 |
97 | - **LoRA Merging**: `--lora_path` specify LoRA weight path
98 | - **LoRA Splitting**: `--lora_split` export LoRA weights separately
99 |
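A hypothetical LoRA export might look like this (`./my-lora-adapter` is a placeholder path):

```bash
# Merge LoRA weights into the base model before export
llmexport --path Qwen2.5-1.5B-Instruct --export mnn --lora_path ./my-lora-adapter

# Export LoRA weights as a separate artifact instead of merging
llmexport --path Qwen2.5-1.5B-Instruct --export mnn --lora_path ./my-lora-adapter --lora_split
```
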
100 | ### Multimodal Support
101 |
102 | - **Visual Quantization**: `--visual_quant_bit`, `--visual_quant_block` set visual module quantization
103 | - **Visual Symmetric**: `--visual_sym` visual module symmetric quantization
104 |
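For a vision-language model, the visual module can be quantized with its own settings, independent of the text decoder; a sketch using the flags above:

```bash
# 4-bit text decoder, 8-bit visual module with block size 128
llmexport --path Qwen2-VL-2B-Instruct --export mnn \
    --quant_bit 4 --visual_quant_bit 8 --visual_quant_block 128 --visual_sym
```
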
105 | ### Other Options
106 |
107 | - **Verbose Output**: `--verbose` show detailed logs
108 | - **Performance Evaluation**: `--ppl` get logits for all tokens
109 | - **Custom Output**: `--dst_path` specify output directory (default `./model`)
110 |
111 | ## 📎 Command Line Parameters
112 |
113 | ### Basic Parameters
114 |
115 | | Parameter | Type | Description |
116 | |-----------|------|-------------|
117 | | `--path` | Required | Model path, supports local directory or Hugging Face model ID |
118 | | `--export` | Optional | Export format: `onnx` or `mnn` |
119 | | `--test` | Optional | Test query string |
120 | | `--dst_path` | Optional | Output directory (default `./model`) |
121 | | `--verbose` | Flag | Show detailed logs |
122 |
123 | ### Quantization Parameters
124 |
125 | | Parameter | Default | Description |
126 | |-----------|---------|-------------|
127 | | `--quant_bit` | 4 | Quantization bits (4 or 8) |
128 | | `--quant_block` | 64 | Quantization block size (0 means channel-wise) |
129 | | `--lm_quant_bit` | Same as `quant_bit` | LM Head layer quantization bits |
130 | | `--visual_quant_bit` | Model dependent | Visual module quantization bits |
131 | | `--visual_quant_block` | Model dependent | Visual module quantization block size |
132 |
133 | ### Quantization Algorithms
134 |
135 | | Parameter | Description |
136 | |-----------|-------------|
137 | | `--awq` | Enable AWQ quantization |
138 | | `--hqq` | Enable HQQ quantization |
139 | | `--smooth` | Enable Smooth quantization |
140 | | `--sym` | Enable symmetric quantization (no zero point) |
141 | | `--visual_sym` | Visual module symmetric quantization |
142 |
143 | ### LoRA Support
144 |
145 | | Parameter | Description |
146 | |-----------|-------------|
147 | | `--lora_path` | LoRA weight path |
148 | | `--lora_split` | Export LoRA weights separately |
149 |
150 | ### Other Options
151 |
152 | | Parameter | Description |
153 | |-----------|-------------|
154 | | `--tokenizer_path` | Tokenizer path (default uses `--path`) |
155 | | `--gptq_path` | GPTQ quantized model path |
156 | | `--mnnconvert` | Local MNNConvert path |
157 | | `--onnx_slim` | Enable ONNX-Slim optimization |
158 | | `--ppl` | Get logits for all tokens |
159 | | `--seperate_embed` | Separate embedding layer to avoid quantization |
160 | | `--calib_data` | Calibration data path |
161 |
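As a combined (hypothetical) example of these parameters (all paths are placeholders; per the help text below, an invalid `--mnnconvert` path falls back to pymnn):

```bash
llmexport --path Qwen2.5-1.5B-Instruct --export mnn \
    --tokenizer_path ./Qwen2.5-1.5B-Instruct \
    --mnnconvert ./MNN/build/MNNConvert \
    --dst_path ./qwen2.5-1.5b-mnn --verbose
```
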
162 | ## Command Args
163 | ```
164 | usage: llmexport.py [-h] --path PATH [--type TYPE] [--tokenizer_path TOKENIZER_PATH] [--lora_path LORA_PATH] [--gptq_path GPTQ_PATH] [--dst_path DST_PATH]
165 | [--verbose] [--test TEST] [--export EXPORT] [--onnx_slim] [--quant_bit QUANT_BIT] [--quant_block QUANT_BLOCK] [--lm_quant_bit LM_QUANT_BIT]
166 | [--mnnconvert MNNCONVERT] [--ppl] [--awq] [--sym] [--tie_embed] [--lora_split]
167 |
168 | llm_exporter
169 |
170 | options:
171 | -h, --help show this help message and exit
172 | --path PATH path(`str` or `os.PathLike`):
173 | Can be either:
174 | - A string, the *model id* of a pretrained model like `THUDM/chatglm-6b`. [TODO]
175 | - A path to a *directory* clone from repo like `../chatglm-6b`.
176 | --type TYPE type(`str`, *optional*):
177 | The pretrain llm model type.
178 | --tokenizer_path TOKENIZER_PATH
179 |                         tokenizer path, default is `None`, meaning use the `--path` value.
180 |   --lora_path LORA_PATH
181 |                         lora path, default is `None`, meaning do not apply lora.
182 |   --gptq_path GPTQ_PATH
183 |                         gptq path, default is `None`, meaning do not apply gptq.
184 |   --dst_path DST_PATH  export onnx/mnn model to path, default is `./model`.
185 | --verbose Whether or not to print verbose.
186 | --test TEST test model inference with query `TEST`.
187 | --export EXPORT export model to an onnx/mnn model.
188 | --onnx_slim Whether or not to use onnx-slim.
189 | --quant_bit QUANT_BIT
190 | mnn quant bit, 4 or 8, default is 4.
191 | --quant_block QUANT_BLOCK
192 |                         mnn quant block, default is 0, meaning channel-wise.
193 | --lm_quant_bit LM_QUANT_BIT
194 | mnn lm_head quant bit, 4 or 8, default is `quant_bit`.
195 | --mnnconvert MNNCONVERT
196 | local mnnconvert path, if invalid, using pymnn.
197 | --ppl Whether or not to get all logits of input tokens.
198 | --awq Whether or not to use awq quant.
199 |   --sym                 Whether or not to use symmetric quant (without zero point), default is False.
200 |   --tie_embed           Whether or not to use tie_embedding, default is False.
201 |   --lora_split          Whether or not to export lora split, default is False.
202 | ```
203 |
204 | ## 📋 Supported Models
205 |
206 | Currently supports the following model types:
207 |
208 | ### Text Models
209 | - **Qwen Series**: Qwen2.5, Qwen2, Qwen1.5, Qwen-VL, etc.
210 | - **LLaMA Series**: Llama-3.2, Llama-3, Llama-2, etc.
211 | - **ChatGLM Series**: ChatGLM4, ChatGLM3, ChatGLM2, etc.
212 | - **Baichuan Series**: Baichuan2-7B-Chat, etc.
213 | - **Yi Series**: Yi-6B-Chat, etc.
214 | - **Others**: InternLM, DeepSeek, Phi, Gemma, TinyLlama, etc.
215 |
216 | ### Multimodal Models
217 | - **Vision Models**: Qwen2-VL, Qwen2.5-VL, Llama-3.2-Vision, InternVL, etc.
218 | - **Audio Models**: Qwen2-Audio, Qwen2.5-Omni, etc.
219 |
220 | ### Embedding Models
221 | - **Text Embedding**: bge-large-zh, gte-multilingual, etc.
222 |
223 | ## 💾 Model Downloads
224 |
225 | We provide optimized model downloads:
226 |
227 | - **Hugging Face**: [taobao-mnn](https://huggingface.co/taobao-mnn)
228 | - **ModelScope**: [MNN](https://modelscope.cn/organization/MNN)
229 |
230 | Popular models:
231 |
232 | | Model | Hugging Face | ModelScope |
233 | |-------|-------------|------------|
234 | | DeepSeek-R1-1.5B-Qwen | [Q4_1](https://huggingface.co/taobao-mnn/DeepSeek-R1-1.5B-Qwen-MNN) | [Q4_1](https://modelscope.cn/models/MNN/DeepSeek-R1-1.5B-Qwen-MNN) |
235 | | Qwen2.5-0.5B-Instruct | [Q4_1](https://huggingface.co/taobao-mnn/Qwen2.5-0.5B-Instruct-MNN) | [Q4_1](https://modelscope.cn/models/MNN/Qwen2.5-0.5B-Instruct-MNN) |
236 | | Qwen2.5-1.5B-Instruct | [Q4_1](https://huggingface.co/taobao-mnn/Qwen2.5-1.5B-Instruct-MNN) | [Q4_1](https://modelscope.cn/models/MNN/Qwen2.5-1.5B-Instruct-MNN) |
237 | | GPT-OSS-20B | [Q4_1](https://huggingface.co/taobao-mnn/gpt-oss-20b-MNN) | [Q4_1](https://modelscope.cn/models/MNN/gpt-oss-20b-MNN) |
238 | | Qwen3-4B-Instruct-2507 | [Q4_1](https://huggingface.co/taobao-mnn/Qwen3-4B-Instruct-2507-MNN) | [Q4_1](https://modelscope.cn/models/MNN/Qwen3-4B-Instruct-2507-MNN) |
239 |
240 | See the complete list for more models.
241 |
242 | ## 🔗 Related Projects
243 |
244 | - **MNN Inference**: [mnn-llm](https://github.com/wangzhaode/mnn-llm) - LLM inference library for MNN framework
245 | - **ONNX Inference**: [onnx-llm](https://github.com/wangzhaode/onnx-llm), [OnnxLLM](https://github.com/inisis/OnnxLLM) - ONNX format inference libraries
246 | - **Model Optimization**: [OnnxSlim](https://github.com/inisis/OnnxSlim) - ONNX model optimization tool
247 |
248 | ## 📄 License
249 |
250 | This project is licensed under the [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0); see the [LICENSE](./LICENSE) file for details.
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [2024] [wangzhaode]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/llmexport/utils/mtp.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import numpy as np
5 | from typing import Optional, List, Tuple
6 |
7 | from .transformers import Attention
8 | from .custom_op import FakeLinear
9 | from .spinner import spinner_run
10 | from .torch_utils import onnx_export
11 |
12 | class Mtp(torch.nn.Module):
13 | def __init__(self, mtp, base):
14 | super().__init__()
15 | self.model_type = base.model_type
16 | self.mtp = mtp
17 | self.embed_ = base.embed
18 | self.lm_ = base.lm
19 |
20 | self.config_ = base.config
21 | if not hasattr(base.config, 'head_dim'):
22 | self.config_.head_dim = base.head_dim
23 | self.config_.rotary = base.rotary
24 | self.config_.model_type = base.model_type
25 | self.config_.model_map = base.model_map
26 | self.hidden_size = base.hidden_size
27 | self.past_kv_shape = base.past_kv_shape
28 | self.num_attention_heads = base.num_attention_heads
29 | self.llm_config = base.llm_config
30 | self.load()
31 | self.unloaded_ops = {}
32 |
33 |
34 | @staticmethod
35 | def get_mtp(model_type):
36 | mtps = {
37 | 'mimo': MimoMtp,
38 | 'poi_qwen2_mtp' : PoiQwenMtp,
39 | }
40 | if model_type in mtps:
41 | return mtps[model_type]
42 | return None
43 |
44 | @spinner_run(f'export onnx model to ')
45 | def export(self, onnx_path):
46 | onnx_model = f'{onnx_path}/mtp.onnx'
47 |
48 | # unload linear weight to save export memory
49 | self.unload_param()
50 |
51 | self.seq_len = 3
52 | input_ids = torch.arange(3, dtype=torch.long)
53 | attention_mask = (1 - torch.tril(torch.ones([1, 1, self.seq_len, self.seq_len]))) * torch.finfo(torch.float32).min
54 | position_ids = torch.arange(self.seq_len, dtype=torch.int).unsqueeze(0)
55 | hidden_states = torch.ones([self.seq_len, 1, self.hidden_size], dtype=torch.float)
56 |
57 | # For export onnx, don't need image or audio's embedding
58 | input_embed = self.embed_(input_ids)
59 | past_key_values = torch.zeros(self.past_kv_shape[1:])
60 | logits_index = torch.tensor([-1], dtype=torch.int32)
61 | # export to onnx
62 | with torch.no_grad():
63 | onnx_export(
64 | self, (input_embed, hidden_states, attention_mask, position_ids, past_key_values, logits_index),
65 | onnx_model,
66 | input_names=[
67 | 'input_embed', 'hidden_states',
68 | 'attention_mask', 'position_ids',
69 | 'past_key_values', 'logits_index'
70 | ],
71 | output_names=['logits', 'presents'],
72 | dynamic_axes={
73 | "input_embed" : { 0: "seq_len" },
74 | "hidden_states" : { 0: "seq_len" },
75 | "attention_mask" : { 2: "seq_len", 3: "seq_len" },
76 | "position_ids" : { 1: "seq_len" },
77 | "past_key_values" : { 2: "history_len" }
78 | })
79 | return onnx_model
80 |
81 | def load(self):
82 | raise NotImplementedError
83 |
84 | def forward(self, images):
85 | raise NotImplementedError
86 |
87 |
88 | class MimoMtp(Mtp):
89 | def __init__(self, mtp, base):
90 | super().__init__(mtp, base)
91 |
92 | def load(self):
93 | self.mtp.eval()
94 | self.token_layernorm = getattr(self.mtp[0], 'token_layernorm')
95 | self.hidden_layernorm = getattr(self.mtp[0], 'hidden_layernorm')
96 | self.input_proj = getattr(self.mtp[0], 'input_proj')
97 | self.input_layernorm = getattr(self.mtp[0], 'input_layernorm')
98 | self.self_attn = getattr(self.mtp[0], 'self_attn')
99 | self.post_attention_layernorm = getattr(self.mtp[0], 'post_attention_layernorm')
100 | self.mlp = getattr(self.mtp[0], 'mlp')
101 | self.final_layernorm = getattr(self.mtp[0], 'final_layernorm')
102 | self.self_attn = Attention(self.self_attn, 0, self.config_)
103 |
104 | def unload_param(self):
105 | def build_faker(real, name):
106 | faker = FakeLinear(real.in_features, real.out_features, real.bias is not None, name)
107 | self.unloaded_ops[name] = real
108 | return faker
109 | # replace linear with fakelinear to save export memory and time
110 | with torch.no_grad():
111 | # different kv cache shape in different layers
112 | if isinstance(self.num_attention_heads, list):
113 | self.self_attn.export_fused_attn = True
114 | for name, child in self.self_attn.named_children():
115 | if isinstance(child, torch.nn.Linear):
116 | setattr(self.self_attn, name, build_faker(child, f'/mtp_layers.0/self_attn/{name}/Linear'))
117 | for name, child in self.mlp.named_children():
118 | if isinstance(child, torch.nn.Linear):
119 | setattr(self.mlp, name, build_faker(child, f'/mtp_layers.0/mlp/{name}/Linear'))
120 | self.input_proj = build_faker(self.input_proj, f'/mtp/input_proj/Linear')
121 |
122 | def forward(self,
123 | input_embeds: torch.Tensor,
124 | hidden_states: torch.Tensor,
125 | attention_mask: torch.Tensor,
126 | position_ids: torch.Tensor,
127 | past_key_values: Optional[Tuple[torch.Tensor]] = None,
128 | logits_index: int = -1
129 | ):
130 | input_embeds = input_embeds.view(1, -1, self.hidden_size)
131 | hidden_states = hidden_states.view(1, -1, self.hidden_size)
132 | hidden_states = hidden_states[:, 0 : input_embeds.size(1), :]
133 |
134 | input_embeds = self.token_layernorm(input_embeds)
135 | previous_hidden_states = self.hidden_layernorm(hidden_states)
136 | hidden_states = self.input_proj(torch.cat([previous_hidden_states, input_embeds], dim=-1))
137 | residual = hidden_states
138 | hidden_states = self.input_layernorm(hidden_states)
139 |
140 | rotary_pos_emb = self.config_.rotary(position_ids)
141 |
142 | # Self Attention
143 | hidden_states, present_key_value = self.self_attn(
144 | hidden_states=hidden_states,
145 | rotary_pos_emb=rotary_pos_emb,
146 | attention_mask=attention_mask,
147 | past_key_value=past_key_values,
148 | cross_attention_states=None,
149 | )
150 |
151 | hidden_states = residual + hidden_states
152 | residual = hidden_states
153 | hidden_states = self.post_attention_layernorm(hidden_states)
154 | hidden_states = self.mlp(hidden_states)
155 | hidden_states = residual + hidden_states
156 |
157 | hidden_states = hidden_states[:, logits_index:, :]
158 | hidden_states = self.final_layernorm(hidden_states)
159 |
160 | logits = self.lm_(hidden_states)
161 | return logits, present_key_value
162 |
163 | class PoiQwenMtp(Mtp):
164 | def __init__(self, mtp, base):
165 | self.num_mtp_layers = 2
166 | super().__init__(mtp, base)
167 |
168 | def load(self):
169 | self.mtp[0].eval()
170 | self.mtp[1].eval()
171 | self.decode_layers = nn.ModuleList([])
172 | self.hidden_norm = nn.ModuleList([])
173 | self.last_norm = nn.ModuleList([])
174 |
175 | with torch.no_grad():
176 | for i in range(self.num_mtp_layers):
177 | self.decode_layers.append(getattr(self.mtp[i], 'layers'))
178 | self.hidden_norm.append(getattr(self.mtp[i], 'RMSorm_MTP_1'))
179 | self.last_norm.append(getattr(self.mtp[i], 'norm'))
180 |
181 | self.input_layernorm = nn.ModuleList([])
182 | self.post_attention_layernorm = nn.ModuleList([])
183 | self.mlp = nn.ModuleList([])
184 | self.self_attn = nn.ModuleList([])
185 |
186 | with torch.no_grad():
187 | for i in range(self.num_mtp_layers):
188 | self.input_layernorm.append(getattr(self.decode_layers[i], 'input_layernorm'))
189 | self.ori_attn = getattr(self.decode_layers[i], 'self_attn')
190 | self.post_attention_layernorm.append(getattr(self.decode_layers[i], 'post_attention_layernorm'))
191 | self.mlp.append(getattr(self.decode_layers[i], 'mlp'))
192 | self.self_attn.append(Attention(self.ori_attn, i, self.config_))
193 |
194 | def unload_param(self):
195 | def build_faker(real, name):
196 | faker = FakeLinear(real.in_features, real.out_features, real.bias is not None, name)
197 | self.unloaded_ops[name] = real
198 | return faker
199 | # replace linear with fakelinear to save export memory and time
200 | with torch.no_grad():
201 | for i in range(self.num_mtp_layers):
202 | # different kv cache shape in different layers
203 | if isinstance(self.num_attention_heads, list):
204 | self.self_attn[i].export_fused_attn = True
205 | for name, child in self.self_attn[i].named_children():
206 | if isinstance(child, torch.nn.Linear):
207 | setattr(self.self_attn[i], name, build_faker(child, f'/mtp_layers.{i}/self_attn/{name}/Linear'))
208 | for name, child in self.mlp[i].named_children():
209 | if isinstance(child, torch.nn.Linear):
210 | setattr(self.mlp[i], name, build_faker(child, f'/mtp_layers.{i}/mlp/{name}/Linear'))
211 |
212 | def forward(self,
213 | input_embeds: torch.Tensor,
214 | hidden_states: torch.Tensor,
215 | attention_mask: torch.Tensor,
216 | position_ids: torch.Tensor,
217 | past_key_values: Optional[Tuple[torch.Tensor]] = None,
218 | logits_index: int = -1
219 | ):
220 | present_key_value = []
221 | # [1, -1, self.hidden_size]
222 | mtp_hidden_states = []
223 |
224 | rotary_pos_emb = self.config_.rotary(position_ids)
225 | hidden_states = hidden_states.view(1, -1, self.hidden_size)
226 | hidden_states = hidden_states[:, 0 : input_embeds.size(0), :]
227 |
228 | for i in range(self.num_mtp_layers):
229 | # first norm
230 | hidden_states = self.hidden_norm[i](hidden_states)
231 |
232 | # Decoder Layer
233 | residual = hidden_states
234 | hidden_states = self.input_layernorm[i](hidden_states)
235 |
236 | # Self Attention
237 | hidden_states, kv = self.self_attn[i](
238 | hidden_states=hidden_states,
239 | rotary_pos_emb=rotary_pos_emb,
240 | attention_mask=attention_mask,
241 | past_key_value=past_key_values,
242 | cross_attention_states=None,
243 | )
244 | present_key_value.append(kv)
245 |
246 | hidden_states = residual + hidden_states
247 | residual = hidden_states
248 | hidden_states = self.post_attention_layernorm[i](hidden_states)
249 | hidden_states = self.mlp[i](hidden_states)
250 | hidden_states = residual + hidden_states
251 |
252 | # last norm
253 | hidden_states = self.last_norm[i](hidden_states)
254 |
255 | mtp_hidden_states.append(hidden_states)
256 | hidden_states = mtp_hidden_states[i]
257 |
258 | for i in range(self.num_mtp_layers):
259 | mtp_hidden_states[i] = mtp_hidden_states[i][:, logits_index:, :]
260 |
261 | mtp_logits = self.lm_(mtp_hidden_states[0])
262 | for i in range(self.num_mtp_layers-1):
263 | logits = self.lm_(mtp_hidden_states[i+1])
264 | mtp_logits = torch.cat([mtp_logits, logits], dim=0)
265 | return mtp_logits, present_key_value
--------------------------------------------------------------------------------
/llmexport/utils/eagle.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import numpy as np
5 | import os
6 | from typing import Optional, List, Tuple
7 | from .transformers import Attention
8 | from .transformers import RMSNorm
9 | from .transformers import Rotary
10 | from .transformers import Embedding
11 | from .custom_op import FakeLinear
12 | from transformers.configuration_utils import PretrainedConfig
13 | from transformers.activations import ACT2FN
14 | from .spinner import spinner_run
15 | from .torch_utils import onnx_export
16 |
17 |
18 | class Eagle(torch.nn.Module):
19 | def __init__(self, eagle_path, base):
20 | super().__init__()
21 | # load eagle config.json
22 | config_file_path = eagle_path + "/config.json"
23 | self.eagle_config = PretrainedConfig.from_json_file(config_file_path)
24 |
25 | self.model_type = base.model_type
26 | self.eagle_path = eagle_path
27 |
28 | self.config = base.config
29 | if not hasattr(base.config, 'head_dim'):
30 | self.config.head_dim = base.head_dim
31 |
32 | self.rope_theta = 10000
33 | self.rope_ratio = 1.0
34 | self.head_dim = self.config.head_dim
35 | self.config.model_type = base.model_type
36 | self.config.model_map = base.model_map
37 | self.hidden_size = base.hidden_size
38 | if self.eagle_config.hidden_size != self.hidden_size:
39 | raise RuntimeError(f'eagle_config hidden_size not equal: {self.eagle_config.hidden_size}, {self.hidden_size}!')
40 | self.past_kv_shape = base.past_kv_shape
41 | self.num_attention_heads = base.num_attention_heads
42 | self.llm_config = base.llm_config
43 |
44 | self.head_dim = self.config.head_dim
45 | self.num_key_value_heads = self.config.num_key_value_heads
46 | self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
47 | # self.config.head_dim = self.head_dim
48 | self.config.rotary = Rotary(self)
49 | # eagle config params
50 | self.padding_idx = self.eagle_config.pad_token_id
51 | self.vocab_size = self.eagle_config.vocab_size
52 | self.draft_vocab_size = self.eagle_config.draft_vocab_size
53 | # embed_tokens api
54 | self.embed_tokens = nn.Embedding(self.vocab_size, self.hidden_size, self.padding_idx)
55 | if not hasattr(self.eagle_config, "target_hidden_size"):
56 | self.embed_tokens.weight = base.embed.embed.weight
57 |
58 | # fc api
59 | if hasattr(self.eagle_config, "target_hidden_size"):
60 | self.fc = nn.Linear(self.eagle_config.target_hidden_size * 3, self.hidden_size, bias=False)
61 | else:
62 | self.fc = nn.Linear(self.hidden_size * 3, self.hidden_size, bias=False)
63 |
64 | self.midlayer = nn.Module()
65 | # midlayer.hidden_norm
66 | self.midlayer.hidden_norm = RMSNorm(self.hidden_size, eps=self.eagle_config.rms_norm_eps)
67 | # midlayer.input_layernorm
68 | self.midlayer.input_layernorm = RMSNorm(self.hidden_size, eps=self.eagle_config.rms_norm_eps)
69 | # midlayer.self_attn
70 | self.midlayer.self_attn = Attention(None, 0, self.config)
71 | self.midlayer.self_attn.q_proj = nn.Linear(self.hidden_size * 2, self.num_attention_heads * self.head_dim, bias=False)
72 | self.midlayer.self_attn.k_proj = nn.Linear(self.hidden_size * 2, self.num_key_value_heads * self.head_dim, bias=False)
73 | self.midlayer.self_attn.v_proj = nn.Linear(self.hidden_size * 2, self.num_key_value_heads * self.head_dim, bias=False)
74 | self.midlayer.self_attn.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=False)
75 | # midlayer.post_attention_layernorm
76 | self.midlayer.post_attention_layernorm = RMSNorm(self.hidden_size, eps=self.eagle_config.rms_norm_eps)
77 | # midlayer.mlp
78 | self.midlayer.mlp = nn.Module()
79 | self.intermediate_size = self.eagle_config.intermediate_size
80 | self.midlayer.mlp.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
81 | self.midlayer.mlp.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
82 | self.midlayer.mlp.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
83 | self.midlayer.mlp.act_fn = ACT2FN[self.eagle_config.hidden_act]
84 |
85 | # norm api
86 | self.norm = RMSNorm(self.hidden_size, eps=self.eagle_config.rms_norm_eps)
87 | # lm_head api
88 |         self.lm_head = nn.Linear(self.hidden_size, self.draft_vocab_size, bias=False)
89 | # logsoftmax api
90 | self.logsoftmax = nn.LogSoftmax(dim=-1)
91 | # d2t
92 | d2t = torch.zeros((self.draft_vocab_size), dtype=torch.int64)
93 | self.register_buffer("d2t", d2t)
94 |
95 | self.load()
96 | self.unloaded_ops = {}
97 |
98 | def unload_param(self):
99 | def build_faker(real, name):
100 | faker = FakeLinear(real.in_features, real.out_features, real.bias is not None, name)
101 | self.unloaded_ops[name] = real
102 | return faker
103 | # replace linear with fakelinear to save export memory and time
104 | with torch.no_grad():
105 | # different kv cache shape in different layers
106 | if isinstance(self.num_attention_heads, list):
107 | self.midlayer.self_attn.export_fused_attn = True
108 | for name, child in self.midlayer.self_attn.named_children():
109 | if isinstance(child, torch.nn.Linear):
110 | setattr(self.midlayer.self_attn, name, build_faker(child, f'/eagle_layers.0/self_attn/{name}/Linear'))
111 | for name, child in self.midlayer.mlp.named_children():
112 | if isinstance(child, torch.nn.Linear):
113 | setattr(self.midlayer.mlp, name, build_faker(child, f'/eagle_layers.0/mlp/{name}/Linear'))
114 | self.lm_head = build_faker(self.lm_head, f'/eagle/lm_head/Linear')
115 | self.fc = build_faker(self.fc, f'/eagle/fc/Linear')
116 |
117 | @staticmethod
118 | def get_eagle(model_type):
119 | eagles = {
120 | 'llama': LlamaEagle,
121 | 'qwen3': LlamaEagle,
122 | }
123 | if model_type in eagles:
124 | return eagles[model_type]
125 | return None
126 |
127 | @spinner_run(f'export onnx model to ')
128 | def export(self, onnx_path):
129 | # save d2t to file
130 | import MNN.expr as expr
131 | torch_d2t = self.d2t.detach().to(torch.int32).contiguous().cpu()
132 | mnn_d2t = expr.const(torch_d2t.data_ptr(), torch_d2t.shape, expr.data_format.NHWC, expr.dtype.int)
133 | mnn_d2t.name = 'd2t'
134 | expr.save([mnn_d2t], f'{onnx_path}/../eagle_d2t.mnn')
135 |
136 | eagle_model = f'{onnx_path}/eagle.onnx'
137 | eagle_fc_model = f'{onnx_path}/eagle_fc.onnx'
138 | # unload linear weight to save export memory
139 | # self.unload_param()
140 |
141 | self.seq_len = 3
142 | input_ids = torch.arange(3, dtype=torch.long)
143 | attention_mask = (1 - torch.tril(torch.ones([1, 1, self.seq_len, self.seq_len]))) * torch.finfo(torch.float32).min
144 | position_ids = torch.arange(self.seq_len, dtype=torch.int).unsqueeze(0)
145 | hidden_states = torch.ones([1, self.seq_len, self.hidden_size], dtype=torch.float)
146 | logits_index = torch.tensor([-1], dtype=torch.int32)
147 |
148 | fc_hidden = torch.ones([1, self.seq_len, self.hidden_size * 3], dtype=torch.float)
149 |
150 | # For export onnx, don't need image or audio's embedding
151 | input_embed = self.embed_tokens(input_ids)
152 | past_key_values = torch.zeros(self.past_kv_shape[1:-1] + [self.head_dim])
153 | # export to onnx
154 | with torch.no_grad():
155 | onnx_export(self.fc, (fc_hidden),
156 | eagle_fc_model,
157 | input_names=['fc_hidden'],
158 | output_names=['hidden_states'],
159 | dynamic_axes={ "fc_hidden" : { 1: "seq_len" } })
160 | onnx_export(
161 | self, (input_embed, hidden_states, attention_mask, position_ids, past_key_values, logits_index),
162 | eagle_model,
163 | input_names=[
164 | 'input_embed', 'hidden_states',
165 | 'attention_mask', 'position_ids',
166 | 'past_key_values', 'logits_index'
167 | ],
168 | output_names=['logits', 'out_hidden_states', 'presents'],
169 | dynamic_axes={
170 | "input_embed" : { 0: "seq_len" },
171 | "hidden_states" : { 1: "seq_len" },
172 | "attention_mask" : { 2: "seq_len", 3: "seq_len" },
173 | "position_ids" : { 1: "seq_len" },
174 | "past_key_values" : { 2: "history_len" }
175 | })
176 | return eagle_model, eagle_fc_model
177 |
178 | def load(self):
179 | raise NotImplementedError
180 |
181 | def forward(self, images):
182 | raise NotImplementedError
183 |
184 | class LlamaEagle(Eagle):
185 | def __init__(self, eagle_path, base):
186 | super().__init__(eagle_path, base)
187 |
188 | def load(self):
189 | safetensors_path = os.path.join(self.eagle_path, "model.safetensors")
190 | bin_path = os.path.join(self.eagle_path, "pytorch_model.bin")
191 | ea_layer_state_dict = None
192 | if os.path.exists(safetensors_path):
193 | from safetensors.torch import load_file
194 | ea_layer_state_dict = load_file(safetensors_path, device="cpu")
195 | elif os.path.exists(bin_path):
196 | ea_layer_state_dict = torch.load(bin_path, map_location="cpu")
197 | else:
198 | raise FileNotFoundError(
199 | f"Eagle path '{self.eagle_path}' not found 'model.safetensors' or 'pytorch_model.bin'."
200 | )
201 | self.load_state_dict(ea_layer_state_dict, strict=False)
202 |
203 | def forward(self,
204 | input_embeds: torch.Tensor,
205 | hidden_states: torch.Tensor,
206 | attention_mask: torch.Tensor,
207 | position_ids: torch.Tensor,
208 | past_key_values: Optional[Tuple[torch.Tensor]] = None,
209 | logits_index: int = -1
210 | ):
211 | # hidden_states = self.fc(hidden_states)
212 | hidden_states = hidden_states.view(1, -1, self.hidden_size)
213 | input_embeds = input_embeds.view(1, -1, self.hidden_size)
214 |
215 | residual = hidden_states
216 |
217 | input_embeds = self.midlayer.input_layernorm(input_embeds)
218 | previous_hidden_states = self.midlayer.hidden_norm(hidden_states)
219 | hidden_states = torch.cat([input_embeds, previous_hidden_states], dim=-1)
220 |
221 | rotary_pos_emb = self.config.rotary(position_ids)
222 |
223 | # Self Attention
224 | hidden_states, present_key_value = self.midlayer.self_attn(
225 | hidden_states=hidden_states,
226 | rotary_pos_emb=rotary_pos_emb,
227 | attention_mask=attention_mask,
228 | past_key_value=past_key_values,
229 | cross_attention_states=None,
230 | )
231 |
232 | hidden_states = residual + hidden_states
233 | residual = hidden_states
234 | hidden_states = self.midlayer.post_attention_layernorm(hidden_states)
235 | hidden_states = self.midlayer.mlp.down_proj(self.midlayer.mlp.act_fn(self.midlayer.mlp.gate_proj(hidden_states)) * self.midlayer.mlp.up_proj(hidden_states))
236 | hidden_states = residual + hidden_states
237 |
238 | hidden_states = hidden_states[:, logits_index:, :]
239 | last_hidden = self.norm(hidden_states)
240 |
241 | logits = self.lm_head(last_hidden)
242 | logits = self.logsoftmax(logits)
243 | return logits, hidden_states, present_key_value
244 |
--------------------------------------------------------------------------------
/llmexport/utils/audio.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .transformers import Decoder
3 | from .spinner import spinner_run
4 | from .torch_utils import onnx_export
5 |
6 | class Audio(torch.nn.Module):
7 | def __init__(self, audio, base):
8 | super().__init__()
9 | self.model_type = base.model_type
10 | self.audio = audio
11 | self.embed_ = base.embed
12 | self.tokenizer = base.tokenizer
13 | self.config = base.config
14 | self.hidden_size = base.hidden_size
15 | self.llm_config = base.llm_config
16 | self.rope_ratio = 1.0
17 | self.quant_bit = 16
18 | self.init_config()
19 | self.load()
20 |
21 | @staticmethod
22 | def get_audio(model_type):
23 | audio_models = {
24 | 'qwen2_audio': Qwen2Audio,
25 | 'qwen2_5_omni_audio_encoder': Qwen2_5OmniAudio,
26 | }
27 | if model_type in audio_models:
28 | return audio_models[model_type]
29 | return None
30 |
31 | def init_config(self):
32 | self.llm_config['is_audio'] = True
33 |
34 | def load(self):
35 | raise NotImplementedError
36 |
37 | def str_to_ids(self, prompt):
38 | input_ids = self.tokenizer(prompt, return_tensors="pt")['input_ids']
39 | return input_ids
40 |
41 | def forward(self, images):
42 | raise NotImplementedError
43 |
44 | def embed(self, input_ids, images = None, videos = None):
45 | raise NotImplementedError
46 |
47 | def export(self, onnx_path):
48 | raise NotImplementedError
49 |
50 | class Qwen2Audio(Audio):
51 | def __init__(self, audio, base):
52 | super().__init__(audio, base)
53 | self.audio_embeds = None
54 | self.audio_pad_id = 151646
55 | self.n_fft = 400
56 | self.sampling_rate = 16000
57 | self.hop_length = 160
58 | self.chunk_length = 30
59 | self.feature_size = 128
60 | self.n_samples = self.chunk_length * self.sampling_rate
61 | self.max_length = self.n_samples // self.hop_length
62 | from transformers.audio_utils import mel_filter_bank
63 | self.mel_filters = mel_filter_bank(
64 | num_frequency_bins=1 + self.n_fft // 2,
65 | num_mel_filters=self.feature_size,
66 | min_frequency=0.0,
67 | max_frequency=8000.0,
68 | sampling_rate=self.sampling_rate,
69 | norm="slaney",
70 | mel_scale="slaney",
71 | )
72 |
73 | def load(self):
74 | # model
75 | self.audio_tower = self.audio.audio_tower
76 | self.multi_modal_projector = self.audio.multi_modal_projector
77 | # config
78 | self.llm_config['is_audio'] = True
79 |
80 | def str_to_ids(self, prompt):
81 |         if '<audio>' in prompt:
82 | import re
83 | from io import BytesIO
84 | from urllib.request import urlopen
85 | import librosa
86 |             pattern = r'(<audio>.*?</audio>)'
87 | parts = re.split(pattern, prompt)
88 | txt_prompt = ''
89 | for part in parts:
90 | if re.match(pattern, part):
91 |                     audio_content = re.search(r'<audio>(.*?)</audio>', part).group(1)
92 | if audio_content.startswith('http://') or audio_content.startswith('https://'):
93 | audio_obj = librosa.load(BytesIO(urlopen(audio_content).read()), sr=self.sampling_rate)[0]
94 | else:
95 | # local file
96 | audio_obj = librosa.load(audio_content, sr=self.sampling_rate)[0]
97 | audio_embed_len = self.audio_process(audio_obj)
98 | audio_pad_str = '<|AUDIO|>' * audio_embed_len
99 | txt_prompt += audio_pad_str
100 | else:
101 | txt_prompt += part
102 | else:
103 | txt_prompt = prompt
104 | input_ids = self.tokenizer(txt_prompt, return_tensors="pt")['input_ids']
105 | return input_ids
106 |
107 | def forward(self, input_features):
108 | input_features = input_features.to(dtype=self.audio_tower.conv1.weight.dtype, device=self.audio_tower.conv1.weight.device)
109 | inputs_embeds = torch.nn.functional.gelu(self.audio_tower.conv1(input_features))
110 | inputs_embeds = torch.nn.functional.gelu(self.audio_tower.conv2(inputs_embeds))
111 | inputs_embeds = inputs_embeds.permute(0, 2, 1)
112 | _, seq_len, _ = inputs_embeds.shape
113 | embed_pos = self.audio_tower.embed_positions.weight[:seq_len, :]
114 | hidden_states = inputs_embeds + embed_pos
115 | for encoder_layer in self.audio_tower.layers:
116 | hidden_states = encoder_layer(hidden_states, None, None)[0]
117 | hidden_states = hidden_states.permute(0, 2, 1)
118 | hidden_states = self.audio_tower.avg_pooler(hidden_states)
119 | hidden_states = hidden_states.permute(0, 2, 1)
120 | hidden_states = self.audio_tower.layer_norm(hidden_states)
121 | audio_features = self.multi_modal_projector(hidden_states)
122 | return audio_features
123 |
124 | def _torch_extract_fbank_features(self, waveform):
125 | window = torch.hann_window(self.n_fft)
126 | stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True)
127 | magnitudes = stft[..., :-1].abs() ** 2
128 | mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32)
129 | mel_spec = mel_filters.T @ magnitudes
130 | log_spec = torch.clamp(mel_spec, min=1e-10).log10()
131 | if waveform.dim() == 2:
132 | max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0]
133 | log_spec = torch.maximum(log_spec, max_val - 8.0)
134 | else:
135 | log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
136 | log_spec = (log_spec + 4.0) / 4.0
137 | return log_spec
138 |
139 | def audio_process(self, audio_obj):
140 | # audio_obj = np.pad(audio_obj, (0, self.n_samples - audio_obj.shape[0]))
141 | waveform = torch.from_numpy(audio_obj).type(torch.float32)
142 | input_features = self._torch_extract_fbank_features(waveform).unsqueeze(0)
143 | audio_embeds = self.forward(input_features)
144 | self.audio_embeds = audio_embeds.permute([1, 0, 2])
145 | return self.audio_embeds.shape[0]
146 |
147 | def embed(self, input_ids, images = None, videos = None):
148 | input_embeds = self.embed_(input_ids)
149 | if self.audio_embeds is not None:
150 | audio_mask = (input_ids == self.audio_pad_id).squeeze()
151 | input_embeds[audio_mask] = self.audio_embeds.type(input_embeds.dtype)
152 | return input_embeds
153 |
154 | @spinner_run(f'export audio to ')
155 | def export(self, onnx_path):
156 | input_features = torch.randn((1, self.feature_size, self.max_length))
157 |
158 | model = self.float()
159 | onnx_model = f'{onnx_path}/audio.onnx'
160 | onnx_export(model, (input_features),
161 | onnx_model,
162 | input_names=['input_features'],
163 | output_names=['audio_embeds'],
164 | dynamic_axes={"input_features": {
165 | 2: "size"
166 | }})
167 | return onnx_model
168 |
169 | class AudioMlp(torch.nn.Module):
170 | def __init__(self, fc1, fc2, act):
171 | super().__init__()
172 | self.fc1 = fc1
173 | self.fc2 = fc2
174 | self.act = act
175 |
176 | def forward(self, hidden_states):
177 | hidden_states = self.fc1(hidden_states)
178 | hidden_states = self.act(hidden_states)
179 | hidden_states = self.fc2(hidden_states)
180 | return hidden_states
181 |
182 | class Qwen2_5OmniAudio(Qwen2Audio):
183 | def __init__(self, audio, base):
184 | super().__init__(audio, base)
185 | self.quant_bit = 4
186 |
187 | def load(self):
188 | # config
189 | config = self.audio.config
190 | self.n_window = config.n_window
191 | self.llm_config['is_audio'] = True
192 | self.llm_config['n_window'] = self.n_window
193 | self.hidden_size = config.d_model
194 | self.num_attention_heads = config.encoder_attention_heads
195 | self.num_key_value_heads = self.num_attention_heads
196 | self.head_dim = self.hidden_size // self.num_attention_heads
197 | self.rotary = None
198 | self.model_map = {
199 | 'decoder': {
200 | 'self_attn': 'self_attn',
201 | 'input_layernorm': 'self_attn_layer_norm',
202 | 'post_attention_layernorm': 'final_layer_norm'
203 | },
204 | 'attention': {
205 | 'q_proj': 'q_proj',
206 | 'k_proj': 'k_proj',
207 | 'v_proj': 'v_proj',
208 | 'o_proj': 'out_proj'
209 | }
210 | }
211 | self.blocks = []
212 | for layer in self.audio.layers:
213 | layer_id = len(self.blocks)
214 | block = Decoder(layer, layer_id, self)
215 | block.mlp = AudioMlp(layer.fc1, layer.fc2, layer.activation_fn)
216 | self.blocks.append(block)
217 |
218 | def forward(self, input_features, attention_mask = None):
219 | input_features = input_features.to(dtype=self.audio.conv1.weight.dtype, device=self.audio.conv1.weight.device)
220 | inputs_embeds = torch.nn.functional.gelu(self.audio.conv1(input_features))
221 | inputs_embeds = torch.nn.functional.gelu(self.audio.conv2(inputs_embeds))
222 | inputs_embeds = inputs_embeds.permute(0, 2, 1)
223 | _, seq_len, _ = inputs_embeds.shape
224 | embed_pos = self.audio.positional_embedding.positional_embedding[:seq_len, :]
225 | hidden_states = inputs_embeds + embed_pos
226 | for block in self.blocks:
227 | hidden_states = block(hidden_states, attention_mask=attention_mask)[0]
228 | hidden_states = hidden_states.permute(0, 2, 1)
229 | hidden_states = self.audio.avg_pooler(hidden_states)
230 | hidden_states = hidden_states.permute(0, 2, 1)
231 | hidden_states = self.audio.ln_post(hidden_states)
232 | audio_features = self.audio.proj(hidden_states)
233 | return audio_features
234 |
235 | def audio_process(self, audio_obj):
236 | # audio_obj = np.pad(audio_obj, (0, self.n_samples - audio_obj.shape[0]))
237 | waveform = torch.from_numpy(audio_obj).type(torch.float32)
238 | input_features = self._torch_extract_fbank_features(waveform).unsqueeze(0)
239 | _, _, seq_len = input_features.shape
240 | seq_len = int(seq_len // 2)
241 | cu_seqlens = [i for i in range(0, seq_len, self.n_window)]
242 | if seq_len % self.n_window != 0:
243 | cu_seqlens.append(seq_len)
244 | cu_seqlens = torch.tensor(cu_seqlens)
245 | attention_mask = torch.full(
246 | [1, seq_len, seq_len], torch.finfo(torch.float32).min
247 | )
248 | for i in range(1, len(cu_seqlens)):
249 | attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
250 | audio_embeds = self.forward(input_features, attention_mask)
251 | self.audio_embeds = audio_embeds.permute([1, 0, 2])
252 | return self.audio_embeds.shape[0]
253 |
254 | @spinner_run(f'export audio to ')
255 | def export(self, onnx_path):
256 | input_features = torch.randn((1, self.feature_size, self.max_length))
257 | seq_len = self.max_length // 2
258 | attention_mask = torch.randn([1, seq_len, seq_len])
259 | model = self.float()
260 | onnx_model = f'{onnx_path}/audio.onnx'
261 | onnx_export(model, (input_features, attention_mask),
262 | onnx_model,
263 | input_names=['input_features', 'attention_mask'],
264 | output_names=['audio_embeds'],
265 | dynamic_axes={"input_features": {
266 | 0: "size"
267 | }, "attention_mask": {
268 | 1: "size", 2: "size"
269 | }})
270 | return onnx_model
--------------------------------------------------------------------------------
/llmexport/gguf/gguf_reader.py:
--------------------------------------------------------------------------------
1 | #
2 | # GGUF file reading/modification support. For API usage information,
3 | # please see the files scripts/ for some fairly simple examples.
4 | #
5 | from __future__ import annotations
6 |
7 | import logging
8 | import os
9 | from collections import OrderedDict
10 | from typing import Any, Literal, NamedTuple, Sequence, TypeVar, Union
11 |
12 | import numpy as np
13 | import numpy.typing as npt
14 |
15 | def quant_shape_to_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType) -> tuple[int, ...]:
16 | block_size, type_size = GGML_QUANT_SIZES[quant_type]
17 | if shape[-1] % block_size != 0:
18 | raise ValueError(f"Quantized tensor row size ({shape[-1]}) is not a multiple of {quant_type.name} block size ({block_size})")
19 | return (*shape[:-1], shape[-1] // block_size * type_size)
20 |
21 | if __name__ == "__main__":
22 | import sys
23 | from pathlib import Path
24 |
25 | # Allow running file in package as a script.
26 | sys.path.insert(0, str(Path(__file__).parent.parent))
27 |
28 | # Imported at module level because the names are used below at import time.
29 | from gguf.constants import (
30 |     GGML_QUANT_SIZES,
31 |     GGUF_DEFAULT_ALIGNMENT,
32 |     GGUF_MAGIC,
33 |     GGUF_VERSION,
34 |     GGMLQuantizationType,
35 |     GGUFValueType,
36 | )
36 |
37 | logger = logging.getLogger(__name__)
38 |
39 | READER_SUPPORTED_VERSIONS = [2, GGUF_VERSION]
40 |
41 |
42 | class ReaderField(NamedTuple):
43 | # Offset to start of this field.
44 | offset: int
45 |
46 | # Name of the field (not necessarily from file data).
47 | name: str
48 |
49 | # Data parts. Some types have multiple components, such as strings
50 | # that consist of a length followed by the string data.
51 | parts: list[npt.NDArray[Any]] = []
52 |
53 | # Indexes into parts that we can call the actual data. For example
54 | # an array of strings will be populated with indexes to the actual
55 | # string data.
56 | data: list[int] = [-1]
57 |
58 | types: list[GGUFValueType] = []
59 |
60 |
61 | class ReaderTensor(NamedTuple):
62 | name: str
63 | tensor_type: GGMLQuantizationType
64 | shape: npt.NDArray[np.uint32]
65 | n_elements: int
66 | n_bytes: int
67 | data_offset: int
68 | data: npt.NDArray[Any]
69 | field: ReaderField
70 |
71 |
72 | class GGUFReader:
73 | # I - same as host, S - swapped
74 | byte_order: Literal['I', 'S'] = 'I'
75 | alignment: int = GGUF_DEFAULT_ALIGNMENT
76 | data_offset: int
77 |
78 | # Note: Internal helper, API may change.
79 | gguf_scalar_to_np: dict[GGUFValueType, type[np.generic]] = {
80 | GGUFValueType.UINT8: np.uint8,
81 | GGUFValueType.INT8: np.int8,
82 | GGUFValueType.UINT16: np.uint16,
83 | GGUFValueType.INT16: np.int16,
84 | GGUFValueType.UINT32: np.uint32,
85 | GGUFValueType.INT32: np.int32,
86 | GGUFValueType.FLOAT32: np.float32,
87 | GGUFValueType.UINT64: np.uint64,
88 | GGUFValueType.INT64: np.int64,
89 | GGUFValueType.FLOAT64: np.float64,
90 | GGUFValueType.BOOL: np.bool_,
91 | }
92 |
93 | def __init__(self, path: os.PathLike[str] | str, mode: Literal['r', 'r+', 'c'] = 'r'):
94 | self.data = np.memmap(path, mode = mode)
95 | offs = 0
96 |
97 | # Check for GGUF magic
98 | if self._get(offs, np.uint32, override_order = '<')[0] != GGUF_MAGIC:
99 | raise ValueError('GGUF magic invalid')
100 | offs += 4
101 |
102 | # Check GGUF version
103 | temp_version = self._get(offs, np.uint32)
104 | if temp_version[0] & 65535 == 0:
105 | # If we get 0 here that means it's (probably) a GGUF file created for
106 | # the opposite byte order of the machine this script is running on.
107 | self.byte_order = 'S'
108 | temp_version = temp_version.newbyteorder(self.byte_order)
109 | version = temp_version[0]
110 | if version not in READER_SUPPORTED_VERSIONS:
111 | raise ValueError(f'Sorry, file appears to be version {version} which we cannot handle')
112 | self.fields: OrderedDict[str, ReaderField] = OrderedDict()
113 | self.tensors: list[ReaderTensor] = []
114 | offs += self._push_field(ReaderField(offs, 'GGUF.version', [temp_version], [0], [GGUFValueType.UINT32]))
115 |
116 | # Check tensor count and kv count
117 | temp_counts = self._get(offs, np.uint64, 2)
118 | offs += self._push_field(ReaderField(offs, 'GGUF.tensor_count', [temp_counts[:1]], [0], [GGUFValueType.UINT64]))
119 | offs += self._push_field(ReaderField(offs, 'GGUF.kv_count', [temp_counts[1:]], [0], [GGUFValueType.UINT64]))
120 | tensor_count, kv_count = temp_counts
121 | offs = self._build_fields(offs, kv_count)
122 |
123 | # Build Tensor Info Fields
124 | offs, tensors_fields = self._build_tensor_info(offs, tensor_count)
125 | new_align = self.fields.get('general.alignment')
126 | if new_align is not None:
127 | if new_align.types != [GGUFValueType.UINT32]:
128 | raise ValueError('Bad type for general.alignment field')
129 | self.alignment = new_align.parts[-1][0]
130 | padding = offs % self.alignment
131 | if padding != 0:
132 | offs += self.alignment - padding
133 | self.data_offset = offs
134 | self._build_tensors(offs, tensors_fields)
135 |
136 | _DT = TypeVar('_DT', bound = npt.DTypeLike)
137 |
138 | # Fetch a key/value metadata field by key.
139 | def get_field(self, key: str) -> Union[ReaderField, None]:
140 | return self.fields.get(key, None)
141 |
142 | # Fetch a tensor from the list by index.
143 | def get_tensor(self, idx: int) -> ReaderTensor:
144 | return self.tensors[idx]
145 |
146 | def _get(
147 | self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I', 'S', '<'] = None,
148 | ) -> npt.NDArray[Any]:
149 | count = int(count)
150 | itemsize = int(np.empty([], dtype = dtype).itemsize)
151 | end_offs = offset + itemsize * count
152 | arr = self.data[offset:end_offs].view(dtype=dtype)[:count]
153 | if override_order is None:
154 | return arr
155 | return arr.view(arr.dtype.newbyteorder(override_order))
156 |
157 | def _push_field(self, field: ReaderField, skip_sum: bool = False) -> int:
158 | if field.name in self.fields:
159 | # TODO: add option to generate error on duplicate keys
160 | # raise KeyError(f'Duplicate {field.name} already in list at offset {field.offset}')
161 |
162 | logger.warning(f'Duplicate key {field.name} at offset {field.offset}')
163 | self.fields[field.name + '_{}'.format(field.offset)] = field
164 | else:
165 | self.fields[field.name] = field
166 | return 0 if skip_sum else sum(int(part.nbytes) for part in field.parts)
167 |
168 | def _get_str(self, offset: int) -> tuple[npt.NDArray[np.uint64], npt.NDArray[np.uint8]]:
169 | slen = self._get(offset, np.uint64)
170 | return slen, self._get(offset + 8, np.uint8, slen[0])
171 |
172 | def _get_field_parts(
173 | self, orig_offs: int, raw_type: int,
174 | ) -> tuple[int, list[npt.NDArray[Any]], list[int], list[GGUFValueType]]:
175 | offs = orig_offs
176 | types: list[GGUFValueType] = []
177 | gtype = GGUFValueType(raw_type)
178 | types.append(gtype)
179 | # Handle strings.
180 | if gtype == GGUFValueType.STRING:
181 | sparts: list[npt.NDArray[Any]] = list(self._get_str(offs))
182 | size = sum(int(part.nbytes) for part in sparts)
183 | return size, sparts, [1], types
184 | # Check if it's a simple scalar type.
185 | nptype = self.gguf_scalar_to_np.get(gtype)
186 | if nptype is not None:
187 | val = self._get(offs, nptype)
188 | return int(val.nbytes), [val], [0], types
189 | # Handle arrays.
190 | if gtype == GGUFValueType.ARRAY:
191 | raw_itype = self._get(offs, np.uint32)
192 | offs += int(raw_itype.nbytes)
193 | alen = self._get(offs, np.uint64)
194 | offs += int(alen.nbytes)
195 | aparts: list[npt.NDArray[Any]] = [raw_itype, alen]
196 | data_idxs: list[int] = []
197 | for idx in range(alen[0]):
198 | curr_size, curr_parts, curr_idxs, curr_types = self._get_field_parts(offs, raw_itype[0])
199 | if idx == 0:
200 | types += curr_types
201 | idxs_offs = len(aparts)
202 | aparts += curr_parts
203 | data_idxs += (idx + idxs_offs for idx in curr_idxs)
204 | offs += curr_size
205 | return offs - orig_offs, aparts, data_idxs, types
206 | # We can't deal with this one.
207 |         raise ValueError(f'Unknown/unhandled field type {gtype}')
208 |
209 | def _get_tensor_info_field(self, orig_offs: int) -> ReaderField:
210 | offs = orig_offs
211 |
212 | # Get Tensor Name
213 | name_len, name_data = self._get_str(offs)
214 | offs += int(name_len.nbytes + name_data.nbytes)
215 |
216 | # Get Tensor Dimensions Count
217 | n_dims = self._get(offs, np.uint32)
218 | offs += int(n_dims.nbytes)
219 |
220 | # Get Tensor Dimension Array
221 | dims = self._get(offs, np.uint64, n_dims[0])
222 | offs += int(dims.nbytes)
223 |
224 | # Get Tensor Encoding Scheme Type
225 | raw_dtype = self._get(offs, np.uint32)
226 | offs += int(raw_dtype.nbytes)
227 |
228 | # Get Tensor Offset
229 | offset_tensor = self._get(offs, np.uint64)
230 | offs += int(offset_tensor.nbytes)
231 |
232 | return ReaderField(
233 | orig_offs,
234 | str(bytes(name_data), encoding = 'utf-8'),
235 | [name_len, name_data, n_dims, dims, raw_dtype, offset_tensor],
236 | [1, 3, 4, 5],
237 | )
238 |
239 | def _build_fields(self, offs: int, count: int) -> int:
240 | for _ in range(count):
241 | orig_offs = offs
242 | kv_klen, kv_kdata = self._get_str(offs)
243 | offs += int(kv_klen.nbytes + kv_kdata.nbytes)
244 | raw_kv_type = self._get(offs, np.uint32)
245 | offs += int(raw_kv_type.nbytes)
246 | parts: list[npt.NDArray[Any]] = [kv_klen, kv_kdata, raw_kv_type]
247 | idxs_offs = len(parts)
248 | field_size, field_parts, field_idxs, field_types = self._get_field_parts(offs, raw_kv_type[0])
249 | parts += field_parts
250 | self._push_field(ReaderField(
251 | orig_offs,
252 | str(bytes(kv_kdata), encoding = 'utf-8'),
253 | parts,
254 | [idx + idxs_offs for idx in field_idxs],
255 | field_types,
256 | ), skip_sum = True)
257 | offs += field_size
258 | return offs
259 |
260 | def _build_tensor_info(self, offs: int, count: int) -> tuple[int, list[ReaderField]]:
261 | tensor_fields = []
262 | for _ in range(count):
263 | field = self._get_tensor_info_field(offs)
264 | offs += sum(int(part.nbytes) for part in field.parts)
265 | tensor_fields.append(field)
266 | return offs, tensor_fields
267 |
268 | def _build_tensors(self, start_offs: int, fields: list[ReaderField]) -> None:
269 | tensors = []
270 | tensor_names = set() # keep track of name to prevent duplicated tensors
271 | for field in fields:
272 | _name_len, name_data, _n_dims, dims, raw_dtype, offset_tensor = field.parts
273 | # check if there's any tensor having same name already in the list
274 | tensor_name = str(bytes(name_data), encoding = 'utf-8')
275 | if tensor_name in tensor_names:
276 | raise ValueError(f'Found duplicated tensor with name {tensor_name}')
277 | tensor_names.add(tensor_name)
278 | ggml_type = GGMLQuantizationType(raw_dtype[0])
279 | n_elems = int(np.prod(dims))
280 | np_dims = tuple(reversed(dims.tolist()))
281 | block_size, type_size = GGML_QUANT_SIZES[ggml_type]
282 | n_bytes = n_elems * type_size // block_size
283 | data_offs = int(start_offs + offset_tensor[0])
284 | item_type: npt.DTypeLike
285 | if ggml_type == GGMLQuantizationType.F16:
286 | item_count = n_elems
287 | item_type = np.float16
288 | elif ggml_type == GGMLQuantizationType.F32:
289 | item_count = n_elems
290 | item_type = np.float32
291 | elif ggml_type == GGMLQuantizationType.F64:
292 | item_count = n_elems
293 | item_type = np.float64
294 | elif ggml_type == GGMLQuantizationType.I8:
295 | item_count = n_elems
296 | item_type = np.int8
297 | elif ggml_type == GGMLQuantizationType.I16:
298 | item_count = n_elems
299 | item_type = np.int16
300 | elif ggml_type == GGMLQuantizationType.I32:
301 | item_count = n_elems
302 | item_type = np.int32
303 | elif ggml_type == GGMLQuantizationType.I64:
304 | item_count = n_elems
305 | item_type = np.int64
306 | else:
307 | item_count = n_bytes
308 | item_type = np.uint8
309 | np_dims = quant_shape_to_byte_shape(np_dims, ggml_type)
310 | tensors.append(ReaderTensor(
311 | name = tensor_name,
312 | tensor_type = ggml_type,
313 | shape = dims,
314 | n_elements = n_elems,
315 | n_bytes = n_bytes,
316 | data_offset = data_offs,
317 | data = self._get(data_offs, item_type, item_count).reshape(np_dims),
318 | field = field,
319 | ))
320 | self.tensors = tensors
321 |
--------------------------------------------------------------------------------
/llmexport/utils/smooth_quantizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import logging
3 | import gc
4 | import functools
5 | import json
6 | import inspect
7 | from typing import Dict
8 | from tqdm import tqdm
9 | from collections import defaultdict
10 | #from datasets import load_from_disk
11 |
12 |
13 | logging.basicConfig(level=logging.ERROR)
14 |
15 | class SmoothQuantizer:
16 | def __init__(
17 | self,
18 | model,
19 | n_parallel_calib_samples=None,
20 | max_calib_samples=128,
21 | max_calib_seq_len=512,
22 | alpha=0.5,
23 | act_bit=8
24 | ) -> None:
25 |
26 | self.model = model
27 | self.tokenizer = model.tokenizer
28 | #self.w_bit = model.args.quant_bit
29 | self.act_bit = act_bit
30 | self.group_size = model.args.quant_block
31 | self.alpha = alpha
32 |
33 | self.max_calib_samples = max_calib_samples
34 | self.max_calib_seq_len = max_calib_seq_len
35 | self.split = 'train'
36 | self.calib_data = 'wikitext' if model.args.calib_data is None else model.args.calib_data
37 | self.best_device = SmoothQuantizer.get_best_device()
38 |
39 | self.modules = self.model.blocks
40 | if "cpu" != self.best_device:
41 | for idx in range(len(self.modules)):
42 | SmoothQuantizer.to_device(self.modules[idx], "cpu")
43 |
44 | self.act_scales = [{} for _ in range(len(self.modules))]
45 | self.act_dict = [defaultdict(dict) for _ in range(len(self.modules))]
46 |
47 | self.n_parallel_calib_samples = n_parallel_calib_samples
48 |
49 | self.samples = self.init_quant(
50 | n_samples=self.max_calib_samples,
51 | max_seq_len=self.max_calib_seq_len,
52 | )
53 |
54 | @staticmethod
55 | def get_calib_dataset(
56 | data,
57 | tokenizer=None,
58 | n_samples=128,
59 | max_seq_len=512,
60 | split="train",
61 | ):
62 | if isinstance(data, str):
63 | from datasets import load_dataset
64 | if data == "pileval":
65 | dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation")
66 | elif data == "wikitext":
67 | dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split=split)
68 | #dataset = load_from_disk("./wikitest-2-raw-v1")
69 | else:
70 | dataset = load_dataset(data, split=split)
71 | # dataset = dataset.shuffle(seed=42)
72 | else:
73 | raise NotImplementedError(
74 |                 "Either pass a string naming a huggingface dataset "
75 |                 "that is preprocessed with one sample of text per element"
76 | )
77 |
78 | samples = []
79 | dataset = dataset.shuffle(seed=42)
80 |
81 | for i in range(n_samples):
82 | input_ids = tokenizer(
83 | dataset[i]["text"], return_tensors="pt", max_length=max_seq_len, truncation=True
84 | ).input_ids
85 | samples.append(input_ids)
86 |
87 | return samples
88 |
89 | @staticmethod
90 | def get_best_device():
91 | if torch.backends.mps.is_available():
92 | return "mps"
93 | elif torch.cuda.is_available():
94 | return "cuda:0"
95 | else:
96 | return "cpu"
97 |
98 | @staticmethod
99 | def clear_memory(weight=None):
100 | if weight is not None:
101 | del weight
102 | gc.collect()
103 | torch.cuda.empty_cache()
104 |
105 |
106 | def init_quant(self, n_samples=128, max_seq_len=512):
107 | samples = SmoothQuantizer.get_calib_dataset(
108 | data=self.calib_data,
109 | tokenizer=self.tokenizer,
110 | n_samples=n_samples,
111 | max_seq_len=max_seq_len,
112 | split=self.split
113 | )
114 | return samples
115 |
116 | def _get_first_input(self, sample):
117 | layer_kwargs = {}
118 | self.model.seq_len = sample.numel()
119 | self.model.context_len = sample.numel() - 2
120 | self.model.token_len = 0
121 | inps = self.model.embedding(sample).to(self.best_device)
122 | position_ids = self.model.get_position_ids()
123 | rotary_pos_emb = self.model.rotary(position_ids)
124 | attention_mask = self.model.get_attention_mask()
125 | layer_kwargs["rotary_pos_emb"] = rotary_pos_emb.to(self.best_device)
126 | layer_kwargs["attention_mask"] = attention_mask.to(self.best_device)
127 | del sample
128 | SmoothQuantizer.clear_memory()
129 | return layer_kwargs, inps
130 |
131 | def _get_max_input(self, idx, layer, named_linears):
132 |
133 | def stat_tensor(name, tensor):
134 | hidden_dim = tensor.shape[-1]
135 | tensor = tensor.view(-1, hidden_dim).abs().detach()
136 |             coming_max = torch.max(tensor, dim=0)[0].float().cpu()
137 |             if name in self.act_scales[idx]:
138 |                 self.act_scales[idx][name] = torch.max(self.act_scales[idx][name], coming_max)
139 |             else:
140 |                 self.act_scales[idx][name] = coming_max
141 |
142 | def stat_input_hook(m, x, y, name):
143 | if isinstance(x, tuple):
144 | x = x[0]
145 | stat_tensor(name, x)
146 | handles = []
147 | for name in named_linears:
148 | handles.append(
149 | named_linears[name].register_forward_hook(
150 | functools.partial(stat_input_hook, name=name)
151 | )
152 | )
153 | module_kwargs = self._sanitize_kwargs(self.module_kwargs, layer)
154 |
155 | self.inps = self._module_forward(self.inps, layer, module_kwargs)
156 | for h in handles:
157 | h.remove()
158 |
159 | def _sanitize_kwargs(self, inputs_kwargs, module):
160 | """
161 | Remove the arguments that are not supported in the module's
162 | forward pass to avoid breaking behaviour between different versions
163 | of transformers.
164 |
165 | Args:
166 | inputs_kwargs (`dict`):
167 | The input dictionary to pass to the model layer
168 | module (`torch.nn.Module`):
169 | Target module to quantize.
170 | """
171 | module_signature = inspect.signature(module.forward).parameters
172 | sanitized_kwargs = {}
173 | for k, v in inputs_kwargs.items():
174 | if k in module_signature:
175 | sanitized_kwargs[k] = v
176 | return sanitized_kwargs
177 |
178 | @torch.no_grad()
179 | def _module_forward(
180 | self, x: torch.Tensor, module: torch.nn.Module, module_kwargs: Dict
181 | ) -> torch.Tensor:
182 |
183 | if self.n_parallel_calib_samples is None:
184 | # runs through all samples at once
185 | # print(module, x, module_kwargs); exit(0)
186 | module_output = module(x, **module_kwargs)
187 | if isinstance(module_output, tuple):
188 | module_output = module_output[0]
189 | else:
190 | # memory efficiently runs through all calibration samples
191 | # but only n_parallel_calib_samples at a time
192 | module_output = []
193 | partitioned_inputs = torch.split(x, self.n_parallel_calib_samples)
194 | for x_partial in partitioned_inputs:
195 | partial_output = module(x_partial, **module_kwargs)
196 |
197 | if isinstance(partial_output, tuple):
198 | partial_output = partial_output[0]
199 |
200 | module_output.append(partial_output.cpu())
201 |
202 | module_output = torch.cat(module_output, dim=0)
203 |
204 | return module_output
205 |
206 | @staticmethod
207 | def to_device(module, device):
208 | for child_name, child_module in module.named_children():
209 | if child_name == 'self_attn':
210 | for sub_name, sub_child in child_module.named_children():
211 | if sub_name != 'config':
212 | sub_child.to(device)
213 | else:
214 | child_module.to(device)
215 |
216 | @staticmethod
217 | def get_named_linears(module):
218 | linears = {}
219 | for child_name, child_module in module.named_children():
220 | if child_name == 'self_attn':
221 | for name, mod in child_module.named_children():
222 | if name != 'config':
223 | if isinstance(mod, torch.nn.Linear):
224 | linears[f"{child_name}.{name}"] = mod
225 | else:
226 | for name, mod in child_module.named_modules():
227 | if isinstance(mod, torch.nn.Linear):
228 | full_name = f"{child_name}.{name}" if name else child_name
229 | linears[full_name] = mod
230 |
231 | return linears
232 |
233 | @staticmethod
234 | @torch.no_grad()
235 | def smooth_ln_fcs(ln, fcs, act_scales, alpha=0.5):
236 | if not isinstance(fcs, list):
237 | fcs = [fcs]
238 | if not SmoothQuantizer.is_allowed_norms(ln):
239 | raise NotImplementedError(
240 | f"LayerNorm {ln} is not supported for smooth quantization."
241 | )
242 | for fc in fcs:
243 | assert isinstance(fc, torch.nn.Linear)
244 | assert ln.weight.numel() == fc.in_features == act_scales.numel()
245 | device, dtype = fcs[0].weight.device, fcs[0].weight.dtype
246 | act_scales = act_scales.to(device=device, dtype=dtype)
247 | weight_scales = torch.cat(
248 | [fc.weight.abs().max(dim=0, keepdim=True)[0] for fc in fcs], dim=0
249 | )
250 | weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5)
251 | scales = (
252 | (act_scales.pow(alpha) / weight_scales.pow(1 - alpha))
253 | .clamp(min=1e-5)
254 | .to(device)
255 | .to(dtype)
256 | )
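        # SmoothQuant-style smoothing: per-channel scale s_j = max|X_j|^alpha / max|W_j|^(1-alpha).
        # Dividing the preceding norm weight (and bias) by s while multiplying the linear
        # weights by s leaves the layer output unchanged but moves activation outliers into
        # the weights, which are easier to quantize.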
257 |
258 | if 'GemmaRMSNorm' in str(type(ln)):
259 | ln.weight += 1
260 | ln.weight.div_(scales)
261 | ln.weight -= 1
262 | else:
263 | ln.weight.div_(scales)
264 |
265 | if hasattr(ln, "bias") and ln.bias is not None:
266 | ln.bias.div_(scales)
267 |
268 | for fc in fcs:
269 | fc.weight.mul_(scales.view(1, -1))
270 |
271 | @staticmethod
272 | def is_allowed_norms(op):
273 | if isinstance(op, torch.nn.LayerNorm):
274 | return True
275 | if any(t in str(type(op)) for t in ['LlamaRMSNorm', 'GemmaRMSNorm', 'CohereLayerNorm']):
276 | return True
277 | if "rmsnorm" in str(op.__class__).lower():
278 | return True
279 | return False
280 |
281 | def _apply_scale(self, idx, module):
282 | attn_ln = module.input_layernorm
283 | qkv = [
284 | module.self_attn.q_proj,
285 | module.self_attn.k_proj,
286 | module.self_attn.v_proj,
287 | ]
288 |
289 | qkv_input_scales = self.act_scales[idx]["self_attn.q_proj"]
290 | SmoothQuantizer.smooth_ln_fcs(attn_ln, qkv, qkv_input_scales, self.alpha)
291 |
292 | ffn_ln = module.post_attention_layernorm # feed forward norm
293 | fcs = [module.mlp.gate_proj, module.mlp.up_proj]
294 | ffn_input_scales = self.act_scales[idx]["mlp.gate_proj"]
295 | SmoothQuantizer.smooth_ln_fcs(ffn_ln, fcs, ffn_input_scales, self.alpha)
296 |
297 | @torch.no_grad()
298 | def _get_all_static_scales(self, idx, layer, named_linears):
299 | def stat_io_hook(m, x, y, name):
300 | if isinstance(x, tuple):
301 | x = x[0]
302 | if name not in self.act_dict[idx] or "input" not in self.act_dict[idx][name]:
303 | self.act_dict[idx][name]["input"] = x.detach().abs().max().item()
304 | else:
305 | self.act_dict[idx][name]["input"] = max(
306 | self.act_dict[idx][name]["input"], x.detach().abs().max().item()
307 | )
308 | if isinstance(y, tuple):
309 | y = y[0]
310 | if name not in self.act_dict[idx] or "output" not in self.act_dict[idx][name]:
311 | self.act_dict[idx][name]["output"] = y.detach().abs().max().item()
312 | else:
313 | self.act_dict[idx][name]["output"] = max(
314 | self.act_dict[idx][name]["output"], y.detach().abs().max().item()
315 | )
316 | handles = []
317 | for name in named_linears:
318 | handles.append(
319 | named_linears[name].register_forward_hook(
320 | functools.partial(stat_io_hook, name=name)
321 | )
322 | )
323 | module_kwargs = self._sanitize_kwargs(self.module_kwargs, layer)
324 |
325 | self.inps = self._module_forward(self.inps, layer, module_kwargs)
326 | for h in handles:
327 | h.remove()
328 |
329 | @torch.no_grad()
330 | def _extract_static_scales(self):
331 |
332 | print("Extracting static scales...")
333 |
334 | scale = 2 ** (self.act_bit - 1) - 1
335 |
336 | for idx in range(len(self.modules)):
337 | for name, input_output in self.act_dict[idx].items():
338 | self.act_dict[idx][name]['input'] = input_output['input'] / scale
339 | self.act_dict[idx][name]['output'] = input_output['output'] / scale
340 |
341 | def quantize(self):
342 |
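        # Two passes over the calibration samples: the first collects per-channel input maxima
        # and applies the smoothing scales to the norm/linear weights; the second re-runs the
        # smoothed blocks to record per-tensor static activation ranges consumed later by apply().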
343 | for i in tqdm(range(len(self.samples)), desc="collecting data and computing scales..."):
344 | sample = self.samples[i]
345 | if sample.numel() == 0:
346 | continue
347 | self.module_kwargs, self.inps = self._get_first_input(sample)
348 |
349 | for idx in range(len(self.modules)):
350 | SmoothQuantizer.to_device(self.modules[idx], self.best_device)
351 |
352 | if self.module_kwargs.get("position_ids", None) is not None:
353 | self.module_kwargs["position_ids"] = self.module_kwargs["position_ids"].to(self.best_device)
354 |
355 | if self.module_kwargs.get("attention_mask", None) is not None:
356 | self.module_kwargs["attention_mask"] = self.module_kwargs["attention_mask"].to(self.best_device)
357 |
358 | named_linears = SmoothQuantizer.get_named_linears(self.modules[idx])
359 |
360 | self._get_max_input(idx, self.modules[idx], named_linears)
361 | if "cpu" != self.best_device:
362 | SmoothQuantizer.to_device(self.modules[idx], "cpu")
363 |
364 | for idx in tqdm(range(len(self.modules)), desc="applying scales..."):
365 | self._apply_scale(idx, self.modules[idx])
366 |
367 | for i in tqdm(range(len(self.samples)), desc="collecting static activation scales..."):
368 | sample = self.samples[i]
369 | if sample.numel() == 0:
370 | continue
371 | self.module_kwargs, self.inps = self._get_first_input(sample)
372 |
373 | for idx in range(len(self.modules)):
374 | SmoothQuantizer.to_device(self.modules[idx], self.best_device)
375 |
376 | if self.module_kwargs.get("position_ids", None) is not None:
377 | self.module_kwargs["position_ids"] = self.module_kwargs["position_ids"].to(self.best_device)
378 |
379 | if self.module_kwargs.get("attention_mask", None) is not None:
380 | self.module_kwargs["attention_mask"] = self.module_kwargs["attention_mask"].to(self.best_device)
381 |
382 | named_linears = SmoothQuantizer.get_named_linears(self.modules[idx])
383 |
384 | self._get_all_static_scales(idx, self.modules[idx], named_linears)
385 | if "cpu" != self.best_device:
386 | SmoothQuantizer.to_device(self.modules[idx], "cpu")
387 | self._extract_static_scales()
388 |
389 | SmoothQuantizer.clear_memory()
390 | for idx in range(len(self.modules)):
391 | SmoothQuantizer.to_device(self.modules[idx], "cpu")
392 |
393 |
394 |
395 | def apply(self, base_path):
396 | mnn = json.load(open(base_path, 'rt'))
397 | mnn['extraTensorDescribe'] = []
398 |
399 | max_val = 2 ** (self.act_bit - 1) - 1
400 | min_val = -max_val
401 | data_type = 'DT_INT16'
402 | if self.act_bit <= 8:
403 | data_type = 'DT_INT8'
404 | if self.act_bit > 8 and self.act_bit <= 16:
405 | data_type = 'DT_INT16'
406 |
407 | quant_info_dict = {}
408 |
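        # Convolution op names are assumed to look like '/<blocks>.<layer_idx>/<submodule>/<proj>/...'
        # (an assumption based on the parsing below); layer_idx and '<submodule>.<proj>' select the
        # matching activation scale in self.act_dict, which is attached to the op's input and
        # output tensors as quantInfo.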
409 | for op in mnn['oplists']:
410 | if op['type'] == 'Convolution' and 'lm_head' not in op['name']:
411 | name_vec = op['name'].split('/')
412 | layer_idx = int(name_vec[1].split('.')[-1])
413 | layer_name = name_vec[2] + '.' + name_vec[3]
414 |
415 | tensor_input_index = op['inputIndexes'][0]
416 | tensor_output_index = op['outputIndexes'][0]
417 |
418 | if tensor_input_index not in quant_info_dict:
419 | quant_info_dict[tensor_input_index] = {
420 | 'index': tensor_input_index,
421 | 'quantInfo': {
422 | 'scale': self.act_dict[layer_idx][layer_name]['input'],
423 | 'min': min_val,
424 | 'max': max_val,
425 | "type":data_type
426 | }
427 | }
428 |
429 | if tensor_output_index not in quant_info_dict:
430 | quant_info_dict[tensor_output_index] = {
431 | 'index': tensor_output_index,
432 | 'quantInfo': {
433 | 'scale': self.act_dict[layer_idx][layer_name]['output'],
434 | 'min': min_val,
435 | 'max': max_val,
436 | "type":data_type
437 | }
438 | }
439 | mnn['extraTensorDescribe'] = list(quant_info_dict.values())
440 |
441 | with open(base_path, 'w', encoding='utf-8') as f:
442 | json.dump(mnn, f, ensure_ascii=False, indent=4)
443 |
444 | return base_path
445 |
446 |
447 |
448 |
449 |
450 |
451 |
452 |
453 |
454 |
455 |
--------------------------------------------------------------------------------
/llmexport/utils/token2wav.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn.functional as F
4 | torch.set_printoptions(precision=4, sci_mode=False)
5 | from .model_mapper import ModelMapper
6 | from .transformers import Rotary, Embedding, Decoder, Attention
7 | from .spinner import spinner_run
8 | from .torch_utils import onnx_export
9 |
10 | class Token2Wav(torch.nn.Module):
11 |     def __init__(self, token2wav, base):
12 | super().__init__()
13 | self.args = base.args
14 | self.token2wav = token2wav.float()
15 | self.config = base.config
16 | self.llm_config = base.llm_config
17 | self.rope_ratio = 1.0
18 | self.quant_bit = 8
19 | self.load()
20 |
21 | def load(self):
22 | raise NotImplementedError
23 |
24 | def add_token_embeds(self, thinker_embeds):
25 | raise NotImplementedError
26 |
27 | def add_hidden_states(self, thinker_hidden_states):
28 | raise NotImplementedError
29 |
30 | def add_generate_ids(self, token_id):
31 | raise NotImplementedError
32 |
33 | def forward(self, inputs_embeds, attention_mask, position_ids, past_key_values = None):
34 | raise NotImplementedError
35 |
36 | def export(self, onnx_path):
37 | raise NotImplementedError
38 |
39 | class UpSample1d(torch.nn.Module):
40 | def __init__(self, upsample, channel):
41 | super().__init__()
42 | self.ratio = upsample.ratio
43 | self.stride = upsample.stride
44 | self.pad = upsample.pad
45 | self.pad_left = upsample.pad_left
46 | self.pad_right = upsample.pad_right
47 | self.filter = upsample.filter.expand(channel, -1, -1).clone()
48 | self.channel = channel
49 |
50 | def forward(self, x):
51 | x = F.pad(x, (self.pad, self.pad), mode="replicate")
52 | x = self.ratio * F.conv_transpose1d(x, self.filter, stride=self.stride, groups=self.channel)
53 | x = x[..., self.pad_left : -self.pad_right]
54 | return x
55 |
56 | class DownSample1d(torch.nn.Module):
57 | def __init__(self, downsample, channel):
58 | super().__init__()
59 | self.pad_left = downsample.pad_left
60 | self.pad_right = downsample.pad_right
61 | self.stride = downsample.stride
62 | self.filter = downsample.filter.expand(channel, -1, -1).clone()
63 | self.channel = channel
64 |
65 | def forward(self, x):
66 | x = F.pad(x, (self.pad_left, self.pad_right), mode="replicate")
67 | out = F.conv1d(x, self.filter, stride=self.stride, groups=self.channel)
68 | return out
69 |
70 | class TorchActivation1d(torch.nn.Module):
71 | def __init__(
72 | self,
73 | activation
74 | ):
75 | super().__init__()
76 | self.act = activation.act
77 | channel = self.act.in_features
78 | self.upsample = UpSample1d(activation.upsample, channel)
79 | self.downsample = DownSample1d(activation.downsample, channel)
80 |
81 | def forward(self, x):
82 | x = self.upsample(x)
83 | x = self.act(x)
84 | x = self.downsample(x)
85 | return x
86 |
87 | # DiT model code
88 | class ECAPA_TDNN(torch.nn.Module):
89 | def __init__(self, spk_encoder):
90 | super().__init__()
91 | self.blocks = spk_encoder.blocks
92 | self.mfa = spk_encoder.mfa
93 | self.asp = spk_encoder.asp
94 | self.fc = spk_encoder.fc
95 |
96 | def forward(self, x):
97 | # Minimize transpose for efficiency
98 | x = x.transpose(1, 2)
99 | xl = []
100 | for layer in self.blocks:
101 | x = layer(x)
102 | xl.append(x)
103 | # Multi-layer feature aggregation
104 | x = torch.cat(xl[1:], dim=1)
105 | x = self.mfa(x)
106 | # Attentive Statistical Pooling
107 | x = self.asp(x)
108 | # Final linear transformation
109 | x = self.fc(x)
110 | # x = x.squeeze(-1) # avoid If when export to onnx
111 | x = x.permute(0, 2, 1)
112 | return x
113 |
114 | class DitRotary(Rotary):
115 | def __init__(self):
116 | super().__init__(None)
117 | self.model_type = 'dit'
118 | self.rope_theta = 10000
119 | self.rotary_dim = 64
120 | self.theta = 1.0 / (self.rope_theta ** (torch.arange(0, self.rotary_dim, 2, dtype=torch.float32) / self.rotary_dim))
121 |
122 | def forward(self, position_ids):
123 | position_ids = position_ids.float().reshape(-1, 1)
124 | idx_theta = position_ids * self.theta
125 | rotary_pos_emb = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)])
126 | rotary_pos_emb = torch.stack((rotary_pos_emb, rotary_pos_emb), dim=-1)
127 | rotary_pos_emb = rotary_pos_emb.reshape(*rotary_pos_emb.shape[:-2], -1)
128 | rotary_pos_emb = rotary_pos_emb.unsqueeze(2).unsqueeze(1)
129 | return rotary_pos_emb
130 |
131 | @staticmethod
132 | def apply_rotary_pos(x, cos, sin):
133 | def rotate_half(x):
134 | x = x.reshape(*x.shape[:-1], -1, 2)
135 | x1, x2 = x.unbind(dim=-1)
136 | x = torch.stack((-x2, x1), dim=-1)
137 | return x.reshape(*x.shape[:-2], -1)
138 |
139 | x = (x * cos) + (rotate_half(x) * sin)
140 | return x
141 |
142 | import math
143 | class DiTAttention(torch.nn.Module):
144 | def __init__(self, attn):
145 | super().__init__()
146 | self.dim = attn.dim
147 | self.heads = attn.heads
148 | self.inner_dim = attn.inner_dim
149 | self.to_q = attn.to_q
150 | self.to_k = attn.to_k
151 | self.to_v = attn.to_v
152 | self.to_out = attn.to_out
153 |
154 | def forward(
155 | self,
156 | x,
157 | rope=None,
158 | mask=None,
159 | ) -> torch.Tensor:
160 | batch_size = x.shape[0]
161 |
162 | # `sample` projections.
163 | query = self.to_q(x)
164 | key = self.to_k(x)
165 | value = self.to_v(x)
166 |
167 | # attention
168 | inner_dim = key.shape[-1]
169 | head_dim = inner_dim // self.heads
170 | query = query.view(batch_size, -1, self.heads, head_dim)
171 | key = key.view(batch_size, -1, self.heads, head_dim)
172 | value = value.view(batch_size, -1, self.heads, head_dim)
173 | # apply rotary position embedding
174 |         # Due to the training setup, RoPE is applied only to the first head; this will be fixed in the next release
175 | cos, sin = rope[0], rope[1]
176 | first_query = query[:, :, :1, :]
177 | first_key = key[:, :, :1, :]
178 | other_query = query[:, :, 1:, :]
179 | other_key = key[:, :, 1:, :]
180 | first_query = DitRotary.apply_rotary_pos(first_query, cos, sin)
181 | first_key = DitRotary.apply_rotary_pos(first_key, cos, sin)
182 | query = torch.concat([first_query, other_query], dim=2)
183 | key = torch.concat([first_key, other_key], dim=2)
184 |
185 | attention_mask = (~mask) * torch.finfo(torch.float32).min
186 |
187 | query = query.transpose(1, 2)
188 | key = key.permute([0, 2, 3, 1])
189 | value = value.transpose(1, 2)
190 | attn_weights = torch.matmul(query, key) / math.sqrt(head_dim)
191 | attn_weights = attn_weights + attention_mask
192 | attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
193 | attn_output = torch.matmul(attn_weights, value)
194 | x = attn_output.transpose(1, 2)
195 |
196 | # mask. e.g. inference got a batch with different target durations, mask out the padding
197 | x = x.reshape(batch_size, -1, self.heads * head_dim)
198 | x = x.to(query.dtype)
199 |
200 | # linear proj
201 | x = self.to_out[0](x)
202 | # dropout
203 | x = self.to_out[1](x)
204 |
205 | return x
206 |
207 | class DiTBlock(torch.nn.Module):
208 | def __init__(self, block):
209 | super().__init__()
210 | self.attn_norm = block.attn_norm
211 | self.attn = DiTAttention(block.attn)
212 | self.attn_ = block.attn
213 | self.look_ahead_block = block.look_ahead_block
214 | self.look_backward_block = block.look_backward_block
215 | self.ff_norm = block.ff_norm
216 | self.ff = block.ff
217 |
218 | def forward(self, x, t, rope=None, block_diff=None):
219 | norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
220 |
221 | attn_output = self.attn(
222 | x=norm,
223 | rope=rope,
224 | mask=(block_diff >= -float(self.look_backward_block)) & (block_diff <= float(self.look_ahead_block)),
225 | )
226 |
227 | # process attention output for input x
228 | x = x + gate_msa.unsqueeze(1) * attn_output
229 |
230 | norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
231 | ff_output = self.ff(norm)
232 | x = x + gate_mlp.unsqueeze(1) * ff_output
233 |
234 | return x
235 |
236 | class DitPreprocess(torch.nn.Module):
237 | def __init__(self, dit):
238 | super().__init__()
239 | self.code_embed = dit.code_embed
240 | self.input_proj = dit.proj_in_other
241 | self.rotary_embed = DitRotary()
242 | self.block_size = 24
243 |
244 | def forward(self, cond, spk, code):
245 | max_duration = code.shape[1] * 2
246 | spk = spk.repeat(1, max_duration, 1)
247 | cond = cond.repeat(1, max_duration, 1)
248 | code_embed = self.code_embed(code)
249 | input_embeds = torch.cat((cond, code_embed, spk), dim=-1)
250 | code_embeds = self.input_proj(input_embeds)
251 | position_ids = torch.arange(max_duration)
252 | rope = self.rotary_embed(position_ids)
253 |
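        # Group positions into fixed-size blocks; block_diff[i, j] is the block index of key j
        # minus that of query i. DiTBlock keeps only pairs whose difference lies within
        # [-look_backward_block, look_ahead_block], i.e. block-local attention.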
254 | block_indices = position_ids // self.block_size
255 | block_i = block_indices.unsqueeze(1)
256 | block_j = block_indices.unsqueeze(0)
257 | block_diff = block_j - block_i
258 | mask = block_diff.reshape(1, 1, max_duration, max_duration)
259 | return code_embeds, rope, mask
260 |
261 | class DitWrapper(torch.nn.Module):
262 | def __init__(self, dit):
263 | super().__init__()
264 | self.dit = dit
265 | self.cfg = False
266 | self.time_embed = dit.time_embed
267 | self.code_embed = dit.text_embed
268 | self.rotary_embed = DitRotary()
269 | self.transformer_blocks = torch.nn.ModuleList()
270 | for i in range(len(dit.transformer_blocks)):
271 | self.transformer_blocks.append(DiTBlock(dit.transformer_blocks[i]))
272 | self._create_block_diff = dit._create_block_diff
273 | self.norm_out = dit.norm_out
274 | self.proj_out = dit.proj_out
275 | proj_in = dit.input_embed.proj
276 | oc, ic = proj_in.weight.shape
277 | x_ic = 80
278 | other_ic = ic - x_ic
279 | self.proj_in_x = torch.nn.Linear(x_ic, oc)
280 | self.proj_in_x.weight.data = proj_in.weight[:, :x_ic]
281 | self.proj_in_x.bias = None
282 | self.proj_in_other = torch.nn.Linear(other_ic, oc)
283 | self.proj_in_other.weight.data = proj_in.weight[:, x_ic:]
284 | self.proj_in_other.bias = proj_in.bias
285 | self.spk_encoder = ECAPA_TDNN(dit.input_embed.spk_encoder)
286 | self.preprocess = DitPreprocess(self)
287 |
288 | def spk_encode(self, spk):
289 | return self.spk_encoder(spk)
290 |
291 | def forward(self, x, code_embeds, rope, mask, time):
292 | t = self.time_embed(time)
293 | hidden = self.proj_in_x(x) + code_embeds
294 | for block in self.transformer_blocks:
295 | hidden = block(hidden, t, rope=rope, block_diff=mask)
296 | hidden = self.norm_out(hidden, t)
297 | output = self.proj_out(hidden)
298 | return output
299 |
300 | # end
301 |
302 | class Qwen2_5OmniToken2Wav(Token2Wav):
303 | def __init__(self, token2wav, base):
304 | super().__init__(token2wav, base)
305 |
306 | def load(self):
307 | self.dit = self.token2wav.code2wav_dit_model
308 | self.bigvgan = self.token2wav.code2wav_bigvgan_model
309 | # some code change for export
310 | self.dit = DitWrapper(self.dit)
311 |         # the up/downsample filters in bigvgan resblock activations depend on the input channel count, so rebuild them with pre-expanded filters for export
312 | for i in range(len(self.bigvgan.resblocks)):
313 | for j in range(len(self.bigvgan.resblocks[i].activations)):
314 | old_act = self.bigvgan.resblocks[i].activations[j]
315 | self.bigvgan.resblocks[i].activations[j] = TorchActivation1d(old_act)
316 | self.bigvgan.activation_post = TorchActivation1d(self.bigvgan.activation_post)
317 | # spk
318 | path = os.path.join(self.args.path, 'spk_dict.pt')
319 | self.speaker_map = {}
320 | for key, value in torch.load(path).items():
321 | spk = value["cond"].float()
322 | cond = value['ref_mel'].float()
323 | value.pop("ref_mel", None)
324 | value['spk'] = spk.unsqueeze(1)
325 |             value['cond'] = self.dit.spk_encode(cond)
326 | self.speaker_map[key] = value
327 | spk = "Chelsie"
328 | self.speaker_params = self.speaker_map[spk]
329 |
330 | def dit_forward(self, code, initial_noise = None):
331 | spk = self.speaker_params["spk"].float()
332 | cond = self.speaker_params["cond"].float()
333 | max_duration = code.shape[1] * 2
334 | code_embeds, rope, mask = self.dit.preprocess(cond, spk, code)
335 | def func(t, x):
336 | pred = self.dit(x=x, code_embeds=code_embeds, rope=rope, mask=mask, time=torch.tensor([t]))
337 | return pred
338 |
339 | steps = 5
340 | t = torch.linspace(0, 1, steps, dtype=cond.dtype)
341 | t = 1 - torch.cos(torch.pi / 2 * t)
342 |
343 | if initial_noise is None:
344 | torch.manual_seed(42)
345 | y0 = torch.randn([1, max_duration, 80], dtype=cond.dtype)
346 | else:
347 | y0 = initial_noise.clone()
348 |
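        # Integrate the predicted velocity field from t=0 to t=1 with a fixed-step four-stage
        # Runge-Kutta update (weights 1/8, 3/8, 3/8, 1/8) on a cosine-warped time grid;
        # y0 starts as Gaussian noise and ends as the generated mel spectrogram.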
349 | for t0, t1 in zip(t[:-1], t[1:]):
350 | dt = t1 - t0
351 | k1 = func(t0, y0)
352 | k2 = func(t0 + dt * 1/3, y0 + dt * k1 * 1/3)
353 | k3 = func(t0 + dt * 2/3, y0 + dt * (k2 - k1 * 2/3))
354 | k4 = func(t1, y0 + dt * (k1 - k2 + k3))
355 | dy = (k1 + 3 * (k2 + k3) + k4) * dt * 0.125
356 | y0 += dy
357 |
358 | generated_mel = y0.permute(0, 2, 1)
359 | # print('generated_mel = ', generated_mel, generated_mel.shape)
360 | # print('generated_mel.shape = ', generated_mel.shape)
361 | return generated_mel
362 |
363 | @torch.no_grad()
364 | def generate(self, code):
365 | generated_mel = self.dit_forward(code)
366 | waveform = self.bigvgan(generated_mel)
367 | return waveform
368 |
369 | @torch.no_grad()
370 | def generate_stream(self, code):
371 |         # Define dit streaming parameters
372 | dit_chunk_size = 48
373 | dit_left_context = 24
374 | dit_right_context = 12
375 | dit_left_padding = 0
376 | dit_right_padding = dit_right_context
377 | dit_start_index = 0
378 | dit_mel_len = 0
379 |
380 | # Define vocoder streaming parameters
381 | vocoder_left_context = 10
382 | vocoder_right_context = 10
383 | vocoder_left_pad = 0
384 | vocoder_right_pad = vocoder_right_context
385 | vocoder_upsample_rate = 240
386 |
387 | torch.manual_seed(42)
388 | initial_noise = torch.randn([1, 30000, 80], dtype=torch.float32)
389 | code_buffer = torch.full((1, 0), 0, dtype=torch.long, device=code.device)
390 | mel_buffer = torch.full((1, 80, 0), 0, dtype=torch.float32, device=code.device)
391 | waveform_buffer = torch.full((0,), 0, dtype=torch.float32)
392 | for next_code in code[0]:
393 | code_buffer = torch.cat([code_buffer, next_code.reshape(1, 1)], dim=1)
394 | if code_buffer.size(1) == dit_left_padding + dit_chunk_size + dit_right_padding:
395 | # dit
396 | generated_mel = self.dit_forward(code_buffer, initial_noise[:, dit_start_index: dit_start_index + code_buffer.size(1) * 2])
397 | generated_mel = generated_mel[:, :, dit_left_padding * 2: -dit_right_padding * 2]
398 | dit_left_padding = dit_left_context
399 | code_buffer = code_buffer[:, -(dit_left_padding + dit_right_padding):]
400 | dit_mel_len += generated_mel.size(-1)
401 | dit_start_index = dit_mel_len - dit_left_context * 2
402 | # bigvgan
403 | mel_buffer = torch.cat([mel_buffer, generated_mel], dim=-1)
404 | waveform = self.bigvgan(mel_buffer)
405 | waveform = waveform[vocoder_left_pad * vocoder_upsample_rate: -vocoder_right_pad * vocoder_upsample_rate]
406 | waveform_buffer = torch.cat([waveform_buffer, waveform], dim=-1)
407 | vocoder_left_pad = vocoder_left_context
408 | mel_buffer = mel_buffer[:, :, -(vocoder_left_pad + vocoder_right_pad):]
409 |
410 | if code_buffer.size(1) > 0:
411 | generated_mel = self.dit_forward(code_buffer, initial_noise[:, dit_start_index: dit_start_index + code_buffer.size(1) * 2])
412 | generated_mel = generated_mel[:, :, dit_left_padding * 2:]
413 | mel_buffer = torch.cat([mel_buffer, generated_mel], dim=-1)
414 | waveform = self.bigvgan(mel_buffer)
415 | waveform = waveform[vocoder_left_pad * vocoder_upsample_rate:]
416 | waveform_buffer = torch.cat([waveform_buffer, waveform], dim=-1)
417 |
418 | return waveform_buffer
419 |
420 | def export_spk(self):
421 | import MNN.expr as expr
422 | def torch_to_mnn(x):
423 | return expr.const(x.data_ptr(), x.shape)
424 | var_list = []
425 | for key, value in self.speaker_map.items():
426 | for k, v in value.items():
427 | if type(v) is not torch.Tensor:
428 | v = torch.tensor(v)
429 | mnn_var = torch_to_mnn(v.contiguous().float())
430 | mnn_var.name = f'{key}_{k}'
431 | var_list.append(mnn_var)
432 | expr.save(var_list, f'{self.args.dst_path}/spk_dict.mnn')
433 |
434 | @spinner_run(f'export token2wav.predit to ')
435 | def export_predit(self, onnx_path):
436 | cond = torch.randn([1, 1, 128], dtype=torch.float32)
437 | spk = torch.randn([1, 1, 192], dtype=torch.float32)
438 | code = torch.ones([1, 256], dtype=torch.int32)
439 | onnx_model = f'{onnx_path}/predit.onnx'
440 | onnx_export(self.dit.preprocess, (cond, spk, code),
441 | onnx_model,
442 | input_names=['cond', 'spk', 'code'],
443 | output_names=['code_embeds', 'rope', 'mask'],
444 | dynamic_axes={
445 | "code": { 1: "size" },
446 | })
447 | return onnx_model
448 |
449 | @spinner_run(f'export token2wav.dit to ')
450 | def export_dit(self, onnx_path):
451 | x = torch.randn([1, 512, 80], dtype=torch.float32)
452 | code_embeds = torch.randn([1, 512, 1024], dtype=torch.float32)
453 | rope = torch.randn([2, 1, 512, 1, 64], dtype=torch.float32)
454 | mask = torch.ones([1, 1, 512, 512], dtype=torch.int32)
455 | time = torch.tensor([0.0])
456 | onnx_model = f'{onnx_path}/dit.onnx'
457 | onnx_export(self.dit, (x, code_embeds, rope, mask, time),
458 | onnx_model,
459 | input_names=['x', 'code_embeds', 'rope', 'mask', 'time'],
460 | output_names=['mel'],
461 | dynamic_axes={
462 | "x": { 1: "size" },
463 | "code_embeds": { 1: "size" },
464 | "rope": { 2: "size" },
465 | "mask": { 2: "size", 3: "size" },
466 | })
467 | return onnx_model
468 |
469 | @spinner_run(f'export token2wav.bigvgan to ')
470 | def export_bigvgan(self, onnx_path):
471 | generated_mel = torch.randn([1, 80, 512], dtype=torch.float32)
472 | onnx_model = f'{onnx_path}/bigvgan.onnx'
473 | onnx_export(self.bigvgan, (generated_mel),
474 | onnx_model,
475 | input_names=['generated_mel'],
476 | output_names=['waveform'],
477 | dynamic_axes={
478 | "generated_mel": { 2: "size" },
479 | })
480 | return onnx_model
481 |
482 | def export(self, onnx_path):
483 | self.export_spk()
484 | predit = self.export_predit(onnx_path)
485 | dit = self.export_dit(onnx_path)
486 | bigvgan = self.export_bigvgan(onnx_path)
487 | return predit, dit, bigvgan
--------------------------------------------------------------------------------
/llmexport/utils/mnn_converter.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import copy
4 | import json
5 | import torch
6 | import numpy as np
7 |
8 | from .torch_utils import quant as torch_quant
9 | from .torch_utils import onnx_export
10 | from tqdm import tqdm
11 | from .spinner import spinner_run
12 | from .gptq import GPTQ
13 | from .lora import LoRA
14 |
15 | EXPORT_LOG = '.export.log'
16 |
17 | class MNNConveter:
18 | def __init__(self, config, weight_ops = None):
19 | self.weight_ops = weight_ops
20 | self.config = config
21 | self.quant_block = config.args.quant_block
22 | self.quant_bit = config.args.quant_bit
23 | self.lm_quant_bit = config.args.lm_quant_bit
24 | self.symmetric = config.args.sym
25 | self.hqq = config.args.hqq
26 | self.mnn_weight_offset = 0
27 | if os.path.exists(config.args.mnnconvert):
28 | self.mnnconvert = config.args.mnnconvert
29 | else:
30 | self.mnnconvert = None
31 | self.lm_weight = None
32 |
33 | def convert(self, convert_args):
34 | sfd = os.dup(1)
35 | log_fp = open(EXPORT_LOG, "a")
36 | log_fd = log_fp.fileno()
37 | # mnnconvert ... > .export.log
38 | os.dup2(log_fd, 1)
39 | try:
40 | sys.argv = convert_args
41 | sys.argc = len(convert_args)
42 | if self.mnnconvert is None:
43 | from MNN.tools import mnnconvert
44 | mnnconvert.main()
45 | else:
46 | convert_args[0] = self.mnnconvert
47 | cmd = ' '.join(convert_args)
48 | message = os.popen(cmd).read()
49 | print(message)
50 | sys.argv = []
51 | finally:
52 | os.dup2(sfd, 1)
53 | log_fp.close()
54 |
55 | @spinner_run(f'convert onnx model to ')
56 | def onnx2mnn(self, onnx_path, mnn_path, args = [], transformer_fuse = True, group_conv_native = False, weight_sym = False, save_external_data = True):
57 | convert_args = [
58 | '',
59 | '-f',
60 | 'ONNX',
61 | '--modelFile',
62 | str(onnx_path),
63 | '--MNNModel',
64 | str(mnn_path),
65 | '--allowCustomOp'
66 | ]
67 | if transformer_fuse:
68 | convert_args += ['--transformerFuse']
69 | if group_conv_native:
70 | convert_args += ['--groupConvNative']
71 | if weight_sym:
72 | convert_args += ['--weightQuantAsymmetric=0']
73 | if save_external_data:
74 | convert_args += ['--saveExternalData']
75 | if self.hqq:
76 | convert_args += ['--hqq']
77 | convert_args += args
78 | self.convert(convert_args)
79 | return mnn_path
80 |
81 | def mnn2json(self, mnn_path, json_path):
82 | convert_args = [
83 | '',
84 | '-f',
85 | 'MNN',
86 | '--modelFile',
87 | str(mnn_path),
88 | '--JsonFile',
89 | str(json_path)
90 | ]
91 | self.convert(convert_args)
92 | return json_path
93 |
94 | def json2mnn(self, json_path, mnn_path):
95 | convert_args = [
96 | '',
97 | '-f',
98 | 'JSON',
99 | '--modelFile',
100 | str(json_path),
101 | '--MNNModel',
102 | str(mnn_path)
103 | ]
104 | self.convert(convert_args)
105 | return mnn_path
106 |
107 | def removeDupOps(self, mnn_path):
108 | convert_args = [
109 | '',
110 | '-f',
111 | 'MNN',
112 | '--modelFile',
113 | str(mnn_path),
114 | '--MNNModel',
115 | str(mnn_path),
116 | '--optimizeLevel=1'
117 | ]
118 | self.convert(convert_args)
119 | return mnn_path
120 |
121 | def export(self, onnx_path, quant_bit = None, quant_block = None, transformer_fuse = True, group_conv_native = False, weight_sym = None):
122 | self.onnx_model_path = onnx_path
123 | self.mnn_name = os.path.basename(onnx_path).replace('.onnx', '.mnn')
124 | self.mnn_model_path = os.path.join(self.config.args.dst_path, self.mnn_name)
125 | self.mnn_weight_path = f'{self.mnn_model_path}.weight'
126 | if self.weight_ops is None:
127 | if quant_bit is None:
128 | quant_bit = self.quant_bit
129 | if quant_block is None:
130 | quant_block = self.quant_block
131 | if weight_sym is None:
132 | weight_sym = self.symmetric
133 | if quant_bit == 16:
134 | quant_args = ['--fp16']
135 | else:
136 | quant_args = [
137 | '--weightQuantBits',
138 | str(quant_bit),
139 | '--weightQuantBlock',
140 | str(quant_block)
141 | ]
142 | if quant_bit == 32:
143 | quant_args = []
144 | self.onnx2mnn(self.onnx_model_path, self.mnn_model_path, quant_args, transformer_fuse=transformer_fuse, group_conv_native=group_conv_native, weight_sym=weight_sym)
145 | else:
146 | mnn_json = f'{self.mnn_model_path}.json'
147 | self.onnx2mnn(self.onnx_model_path, self.mnn_model_path, transformer_fuse=transformer_fuse, group_conv_native=group_conv_native, weight_sym=weight_sym)
148 | self.mnn2json(self.mnn_model_path, mnn_json)
149 | self.rebuild(mnn_json)
150 | self.json2mnn(mnn_json, self.mnn_model_path)
151 | self.removeDupOps(self.mnn_model_path)
152 | self.mnn2json(self.mnn_model_path, mnn_json)
153 | if self.config.args.gptq_path is not None:
154 | self.apply_gptq(mnn_json)
155 | if self.config.args.lora_path is not None and self.config.args.lora_split:
156 | self.export_lora(mnn_json)
157 | if self.config.args.smooth:
158 | self.export_smooth_quant(mnn_json)
159 |
160 | def get_experts_graphs(self, experts):
161 | hidden_states = torch.randn((1, self.config.hidden_size))
162 | layers_num = len(experts)
163 | expert_num = len(experts[0])
164 | dummy_expert = experts[0][0]
165 | onnx_model = f'{self.config.onnx_path}/expert.onnx'
166 | onnx_export(
167 | dummy_expert, (hidden_states),
168 | onnx_model,
169 | input_names=['hidden_states'],
170 | output_names=['hidden_states'])
171 | mnn_model = f'{onnx_model}.mnn'
172 | mnn_json = f'{mnn_model}.json'
173 | self.onnx2mnn(onnx_model, mnn_model)
174 | self.mnn2json(mnn_model, mnn_json)
175 | expert_graph = json.load(open(mnn_json, 'rt'))
176 | tensors = expert_graph['tensorName']
177 | nodes = expert_graph['oplists']
178 | # get input and output
179 | inputs = []
180 | outputs = []
181 | for node in nodes:
182 | if node['type'] == 'Input':
183 | inputs.append(node['outputIndexes'][0])
184 | for output_name in expert_graph['outputName']:
185 | outputs.append(tensors.index(output_name))
186 | subgraphs = []
187 | for i in range(layers_num):
188 | for j in range(expert_num):
189 | ijnodes = copy.deepcopy(nodes)
190 | for op in ijnodes:
191 | if op['type'] == 'Extra':
192 | for attr in op['main']['attr']:
193 | if attr['key'] == 'name':
194 | names = attr['s'].split('/')
195 | names[2] = f'{i}_{j}'
196 | attr['s'] = '/'.join(names)
197 | subgraph = {
198 | 'name': f'/expert/{i}_{j}',
199 | 'inputs': inputs,
200 | 'outputs': outputs,
201 | 'tensors': copy.deepcopy(tensors),
202 | 'nodes': ijnodes
203 | }
204 | subgraphs.append(subgraph)
205 | return subgraphs
206 |
207 |
208 | @spinner_run(f'apply gptq to ')
209 | def apply_gptq(self, mnn_json):
210 | GPTQ(self.config.args.gptq_path).apply(mnn_json, self.mnn_weight_path)
211 | return self.mnn_weight_path
212 |
213 | @spinner_run(f'export split lora to ')
214 | def export_lora(self, mnn_json):
215 | lora_model = os.path.join(self.config.args.dst_path, 'lora.mnn')
216 | lora_json = f'{lora_model}.json'
217 | LoRA(self.config.args.lora_path).apply(mnn_json, lora_json)
218 | self.json2mnn(lora_json, lora_model)
219 | if os.path.exists(lora_json):
220 | os.remove(lora_json)
221 | return lora_model
222 |
223 | @spinner_run(f'export smooth quant scale to ')
224 | def export_smooth_quant(self, mnn_json):
225 | self.config.smooth_quantizer.apply(mnn_json)
226 | self.json2mnn(mnn_json, self.mnn_model_path)
227 | return self.mnn_model_path
228 |
229 | @spinner_run(f'quant model weight to ', True)
230 | def rebuild(self, json_path):
231 | mnn_graph = json.load(open(json_path, 'rt'))
232 | has_experts = len(self.config.experts) > 0
233 | if has_experts:
234 | subgraphs = self.get_experts_graphs(self.config.experts)
235 | mnn_graph['subgraphs'] = subgraphs
236 | new_ops = []
237 | # Load layernorm weight from external
238 | with open(self.mnn_weight_path, 'rb') as f:
239 | for op in tqdm(mnn_graph['oplists'], 'Load LayerNorm data'):
240 | if op['type'] == 'LayerNorm' and 'external' in op['main']:
241 | external = op['main']['external']
242 | f.seek(external[0])
243 | op['main']['gamma'] = np.frombuffer(f.read(external[1]), np.float32).tolist()
244 | op['main']['beta'] = np.frombuffer(f.read(external[2]), np.float32).tolist()
245 | del op['main']['external']
246 | # Rebuild ops
247 | with open(self.mnn_weight_path, 'wb') as self.mnn_weight:
248 | for op in tqdm(mnn_graph['oplists'], 'Quant weights'):
249 | if op['type'] == 'Extra' or op['type'] == 'LayerNorm':
250 | new_ops += self.rebuild_op(op, mnn_graph)
251 | else:
252 | new_ops.append(op)
253 | mnn_graph['oplists'] = new_ops
254 | if has_experts and 'subgraphs' in mnn_graph:
255 | for subgraph in tqdm(mnn_graph['subgraphs'], 'Quant subgraphs weights'):
256 | new_subops = []
257 | for op in subgraph['nodes']:
258 | if op['type'] == 'Extra' or op['type'] == 'LayerNorm':
259 | new_subops += self.rebuild_op(op, subgraph)
260 | else:
261 | new_subops.append(op)
262 | subgraph['nodes'] = new_subops
263 | with open(json_path, 'w', encoding='utf-8') as file:
264 | json.dump(mnn_graph, file, ensure_ascii=False, indent=4)
265 | return self.mnn_weight_path
266 |
267 | def quant(self, weight, quant_bit, quant_block, symmetric):
268 | q_weight, alpha = torch_quant(weight, quant_bit, quant_block, symmetric, self.config.args.awq, self.config.args.hqq)
269 | return q_weight, alpha
270 |
271 | def write_weight(self, data):
272 | if isinstance(data, torch.Tensor):
273 | data = data.numpy()
274 | if isinstance(data, list):
275 | data = np.array(data).astype(np.float32)
276 | return self.mnn_weight.write(data.tobytes())
277 |
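    # For sub-16-bit quantization, each linear is written to the .weight file as one record:
    #   [header: dim-count byte, (oc, ic) shape, int8 weight map] [quantized weights]
    #   [alpha scales] [optional float bias]
    # referenced from the graph via external = [offset, weight_len, alpha_len, bias_len, 0];
    # fp16 weights skip the header and alpha. (Summary of write_header/build_weight below.)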
278 | def write_header(self, ic, oc, quant_bit):
279 | dim_num = self.mnn_weight.write(b'\x02')
280 | shape_dtype = np.int16
281 | if oc > 65535 or ic > 65535:
282 | shape_dtype = np.int32
283 | dim_length = self.write_weight(np.array([oc, ic]).astype(shape_dtype))
284 | offset = 1 << (quant_bit - 1)
285 | weight_map = [i for i in range(-offset, offset)]
286 | if len(weight_map) == 256:
287 | weight_map.insert(0, 0)
288 | else:
289 | weight_map.insert(0, len(weight_map))
290 | map_length = self.write_weight(np.array(weight_map, dtype=np.int8))
291 | header_length = dim_num + dim_length + map_length
292 | return header_length, shape_dtype == np.int32
293 |
294 | def build_weight(self, linear, quant_bit, quant_block, symmetric):
295 | ic, oc = linear.in_features, linear.out_features
296 | if quant_bit == 16:
297 | half_weight = linear.weight.data.flatten().half()
298 | weight_len = self.write_weight(half_weight)
299 | alpha_len, q_min, shape_int32, header_len = 0, 0, False, 0
300 | else:
301 | q_min = 1
302 | assert(quant_bit in (1, 2, 4, 8))
303 | q_weight, alpha = self.quant(linear.weight.data, quant_bit, quant_block, symmetric)
304 | header_len, shape_int32 = self.write_header(ic, oc, quant_bit)
305 | weight_len = self.write_weight(q_weight) + header_len
306 | alpha_len = self.write_weight(alpha)
307 | if linear.bias is not None:
308 | bias = linear.bias.data.flatten().float()
309 | bias_length = self.write_weight(bias)
310 | else:
311 | bias_length = 0
312 | # bias = np.zeros([oc], dtype=np.float32)
313 | # bias_length = self.write_weight(bias)
314 | external = [self.mnn_weight_offset, weight_len, alpha_len, bias_length, 0]
315 | self.mnn_weight_offset += (weight_len + alpha_len + bias_length)
316 | return external, q_min, shape_int32, header_len
317 |
318 | def build_tensor(self, graph, tensor_name):
319 | tensor_key = 'tensorName'
320 | if tensor_key not in graph and 'tensors' in graph:
321 | tensor_key = 'tensors'
322 | tensor_idx = [len(graph[tensor_key])]
323 | graph[tensor_key].append(tensor_name)
324 | return tensor_idx
325 |
326 | def rebuild_op(self, op, graph):
327 | if "type" in op['main']:
328 | op_type = op['main']['type']
329 | else:
330 | op_type = op['type']
331 | if op_type == 'FakeLinear':
332 | return self.rebuild_linear(op, graph)
333 | if op_type == 'FusedAttention':
334 |             return self.rebuild_attention(op, graph)
335 | if op_type == "LayerNorm":
336 | return self.rebuild_layernorm(op, graph)
337 | if op_type == 'MoE':
338 | return self.rebuild_moe(op, graph)
339 | return None
340 |
341 | def rebuild_moe(self, op, graph):
342 | moe = copy.deepcopy(op)
343 | moe['main'] = { 'attr': moe['main']['attr'][:3] }
344 | moe['type'] = 'MoE'
345 | return [moe]
346 |
347 | def rebuild_layernorm(self, op, graph):
348 | if "gamma" not in op['main'] or "beta" not in op['main']:
349 | return [op]
350 | attr = op['main']
351 | gamma = attr['gamma']
352 | beta = attr['beta']
353 | gamma_len = self.write_weight(gamma)
354 | beta_len = self.write_weight(beta)
355 | del attr['gamma']
356 | del attr['beta']
357 | external = [self.mnn_weight_offset, gamma_len, beta_len]
358 | self.mnn_weight_offset += (gamma_len + beta_len)
359 | attr['external'] = external
360 | layernorm_op = {
361 | "name": op['name'],
362 | "inputIndexes": op['inputIndexes'],
363 | "outputIndexes": op['outputIndexes'],
364 | "type": "LayerNorm",
365 | "main_type": "LayerNorm",
366 | "main": attr,
367 | "defaultDimentionFormat": op['defaultDimentionFormat']
368 | }
369 | return [layernorm_op]
370 |
371 |     def rebuild_attention(self, op, graph):
372 | attrs = op['main']['attr']
373 | for attr in attrs:
374 | if attr['key'] == 'name':
375 | name = attr['s']
376 | origin_input = op['inputIndexes']
377 | origin_output = op['outputIndexes']
378 | fused_attention = {
379 | "inputIndexes": origin_input,
380 | "main_type": "AttentionParam",
381 | "main": { "kv_cache": True },
382 | "name": name,
383 | "outputIndexes": origin_output,
384 | "type": "Attention",
385 | "defaultDimentionFormat": "NHWC"
386 | }
387 | return [fused_attention]
388 |
389 | def rebuild_linear(self, op, graph):
390 | attrs = op['main']['attr']
391 | for attr in attrs:
392 | if attr['key'] == 'name':
393 | name = attr['s']
394 | elif attr['key'] == "in_features":
395 | ic = attr["i"]
396 | elif attr['key'] == "out_features":
397 | oc = attr["i"]
398 | elif attr['key'] == "has_bias":
399 | has_bias = attr["i"]
400 | linear = self.weight_ops[name]
401 | assert(linear.in_features == ic and
402 | linear.out_features == oc and
403 | (linear.bias is not None) == has_bias)
404 |
405 | is_lm = 'lm_head' in name
406 | quant_bit = self.lm_quant_bit if is_lm else self.quant_bit
407 | block_size = ic if self.quant_block == 0 else self.quant_block
408 | if is_lm and self.lm_weight is not None:
409 | external, q_min, shape_int32, header_len = self.lm_weight
410 | else:
411 | external, q_min, shape_int32, header_len = self.build_weight(linear, quant_bit, self.quant_block, self.symmetric)
412 | if is_lm and self.lm_weight is None:
413 | self.lm_weight = [external, q_min, shape_int32, header_len]
414 | if is_lm and self.config.tie_word_embeddings:
415 | weight_offset = external[0] + header_len
416 | alpha_offset = external[0] + external[1]
417 | alpha_size = external[2]
418 | self.config.llm_config['tie_embeddings'] = [weight_offset, alpha_offset, alpha_size, quant_bit, self.quant_block]
419 |
420 | origin_input = op['inputIndexes']
421 | origin_output = op['outputIndexes']
422 | # build new tensor
423 | pre_reshape_name = f'{name}/pre_reshape'
424 | pre_convert_name = f'{name}/pre_convert'
425 | conv_name = name
426 | post_convert_name = f'{name}/post_convert'
427 | post_reshape_name = f'{name}/post_reshape'
428 | pre_reshape_output = self.build_tensor(graph, pre_reshape_name)
429 | pre_convert_output = self.build_tensor(graph, pre_convert_name)
430 | conv_output = self.build_tensor(graph, conv_name)
431 | post_convert_output = self.build_tensor(graph, post_convert_name)
432 | # [batch, seq, hidden_size_i] -[Linear] -> [batch, seq, hidden_size_o]
433 | # [1, seq, hidden_size_i] ->[Reshape]-> [seq, hidden_size_i, 1, 1]
434 | # -[Convert]-[Convolution]-[Convert]-> [Reshape] -> [1, seq, hidden_size_o]
435 | pre_reshape = {
436 | "name": pre_reshape_name,
437 | "type": "Reshape",
438 | "inputIndexes": origin_input,
439 | "outputIndexes": pre_reshape_output,
440 | "main_type": "Reshape",
441 | "main": {
442 | "dims": [-1, ic, 1, 1],
443 | "dimType": "NCHW"
444 | },
445 | "defaultDimentionFormat": "NHWC"
446 | }
447 | pre_convert = {
448 | "name": pre_convert_name,
449 | "inputIndexes": pre_reshape_output,
450 | "outputIndexes": pre_convert_output,
451 | "type": "ConvertTensor",
452 | "main_type": "TensorConvertInfo",
453 | "main": {
454 | "source": "NCHW",
455 | "dest": "NC4HW4"
456 | },
457 | "defaultDimentionFormat": "NHWC"
458 | }
459 |
460 | if quant_bit == 16:
461 | quanParameter = { "type": 3 }
462 | else:
463 | if self.symmetric:
464 | aMin = 0
465 | readType = 0
466 | else:
467 | aMin = q_min
468 | readType = oc * (ic // block_size)
469 |
470 | quanParameter = {
471 | "quantScale": 1.0, "scaleIn": 0.0, "scaleOut": 0.0,
472 | "useInt32": False, "has_scaleInt": False, "shapeInt32": shape_int32,
473 | "type": 1, "aMaxOrBits": quant_bit, "aMin": aMin, "readType": readType, "weightSize": 0
474 | }
475 | conv_op = {
476 | "name": conv_name,
477 | "inputIndexes": pre_convert_output,
478 | "outputIndexes": conv_output,
479 | "type": "Convolution",
480 | "main_type": "Convolution2D",
481 | "main": {
482 | 'common': {
483 | 'dilateX': 1, 'dilateY': 1, 'strideX': 1, 'strideY': 1,
484 | 'kernelX': 1, 'kernelY': 1, 'padX': 0, 'padY': 0, 'group': 1,
485 | 'outputCount': oc, 'relu': False, 'padMode': 'CAFFE',
486 | 'relu6': False, 'inputCount': ic, 'hasOutputShape': False
487 | },
488 | "quanParameter": quanParameter,
489 | "external": external
490 | },
491 | "defaultDimentionFormat": "NHWC"
492 | }
493 | post_convert = {
494 | "name": post_convert_name,
495 | "inputIndexes": conv_output,
496 | "outputIndexes": post_convert_output,
497 | "type": "ConvertTensor",
498 | "main_type": "TensorConvertInfo",
499 | "main": {
500 | "source": "NC4HW4",
501 | "dest": "NCHW"
502 | },
503 | "defaultDimentionFormat": "NHWC"
504 | }
505 | post_reshape = {
506 | "name": post_reshape_name,
507 | "type": "Reshape",
508 | "inputIndexes": post_convert_output,
509 | "outputIndexes": origin_output,
510 | "main_type": "Reshape",
511 | "main": {
512 | "dims": [1, -1, oc],
513 | "dimType": "NCHW"
514 | },
515 | "defaultDimentionFormat": "NHWC"
516 | }
517 | if name.startswith('/expert/'):
518 | post_reshape['main']['dims'] = [-1, oc]
519 | return [pre_reshape, pre_convert, conv_op, post_convert, post_reshape]
520 |
--------------------------------------------------------------------------------
/llmexport/gguf2mnn.py:
--------------------------------------------------------------------------------
1 | import os
2 | from gguf import gguf_reader
3 | from gguf import constants
4 | import numpy
5 | import json
6 | import argparse
7 |
8 | from .utils.mnn_utils import *
9 |
10 | class TokenContent:
11 | def __init__(self):
12 | self.token_type = -1
13 | self.spec_ids = []
14 | self.names = []
15 | self.stop_ids = []
16 | self.pre_ids = []
17 | self.token_num = 0
18 |
19 | def load_token(reader):
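19 |     # Collect normal (type 1) and control/user-defined (type 3/4) tokens from the GGUF tokenizer fields, recording special and stop token ids.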
20 | content = TokenContent()
21 | model = reader.fields['tokenizer.ggml.model'].parts[4].tobytes().decode('utf-8')
22 | field = reader.fields['tokenizer.ggml.token_type']
23 | valids = []
24 | for i in range(0, len(field.data)):
25 | p = field.data[i]
26 | if field.parts[p] == 1:
27 | #normal
28 | valids.append(i)
29 | elif field.parts[p] == 3 or field.parts[p] == 4:
30 | valids.append(i)
31 | content.spec_ids.append(i)
32 | tokens = reader.fields['tokenizer.ggml.tokens']
33 | stopes = ["<|eot_id|>", "<|im_end|>", "<|end|>", "", "<|endoftext|>", "<|eom_id|>", ""]
34 |
35 | for i in valids:
36 | p = tokens.data[i]
37 | tok = tokens.parts[p].tobytes().decode('utf-8')
38 | if tok in stopes:
39 | content.stop_ids.append(i)
40 | content.names.append(tok)
41 | content.token_num = len(content.names)
42 | if model == "gpt2":
43 | # bpe -> HUGGINGFACE
44 | content.token_type = 3
45 | # load merge
46 | merges = reader.fields['tokenizer.ggml.merges']
47 | for i in range(0, len(merges.data)):
48 | p = merges.data[i]
49 | tok = merges.parts[p].tobytes().decode('utf-8')
50 | content.names.append(tok)
51 | elif model == 'llama':
52 | content.token_type = 1
53 | else:
54 |         print("[Error] Unsupported tokenizer model:", model, "- you can try downloading tokenizer.txt from an old MNN LLM model")
55 | return content
56 |
57 | def write_token_file(filename, token):
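57 |     # Write tokenizer.txt in the layout used by MNN LLM: a "430 <type>" header, special/stop id counts and ids, the token (and BPE merge) counts, then one entry per line.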
58 | with open(filename, 'w') as f:
59 | f.write("430 %d\n" %token.token_type)
60 |         f.write("%d %d 0\n" % (len(token.spec_ids), len(token.stop_ids)))
61 | l = ""
62 | for i in token.spec_ids:
63 | l += "%d " %i
64 | for i in token.stop_ids:
65 | l += "%d " %i
66 | l+='\n'
67 | f.write(l)
68 | if token.token_type == 3:
69 | merge_num = len(token.names) - token.token_num
70 | f.write("%d " %token.token_num + "%d\n" %merge_num)
71 | else:
72 | f.write("%d\n" %token.token_num)
73 | for name in token.names:
74 | f.write(name + '\n')
75 | return
76 |
77 | def shuffle_weight_int4(weight_main):
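77 |     # GGML Q4 blocks pack element j in the low nibble and element j+16 in the high nibble of byte j; re-pair the nibbles so each output byte holds two consecutive weights.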
78 | # shuffle weight
79 | block_number = weight_main.shape[0]
80 | half_block_size = weight_main.shape[1]
81 | weight_main_low = weight_main % 16
82 | weight_main_high = weight_main // 16
83 | weight_main = numpy.concatenate([weight_main_low, weight_main_high], axis = 1).reshape([block_number, half_block_size, 2])
84 | weight_main_low = weight_main[:, :, 1]
85 | weight_main_high = weight_main[:, :, 0]
86 | weight_main = weight_main_low + weight_main_high * 16
87 | return weight_main
88 |
89 | # const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
90 | # const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
91 |
92 | # const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16;
93 | # const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16;
94 |
95 | # y[i*qk + j + 0 ] = x0*d;
96 | # y[i*qk + j + qk/2] = x1*d;
97 |
98 | def shuffle_weight_int5(weight, repack = True):
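98 |     # Unpack a Q5 block (4 bytes of high bits + 16 bytes of low nibbles, see the C reference above) into 32 five-bit values, then optionally repack them densely.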
99 | block_number = weight.shape[0]
100 | qh = weight[:, 0:4]
101 | qs = weight[:, 4:20]
102 | x0 = qs & 0x0F
103 | x1 = qs >> 4
104 | qh = numpy.frombuffer(qh.tobytes(), numpy.uint32).reshape([block_number, 1])
105 | mask_0 = []
106 | mask_1 = []
107 | for i in range(0, 16):
108 | mask_0.append(((qh >> i)<< 4) & 0x10)
109 | mask_1.append(((qh >> (i+12))) & 0x10)
110 | mask_0 = numpy.concatenate(mask_0, axis=1)
111 | mask_1 = numpy.concatenate(mask_1, axis=1)
112 | x0 = x0 + mask_0
113 | x1 = x1 + mask_1
114 | x = numpy.concatenate([x0, x1], axis=1)
115 | if repack:
116 | return repack_low_bits(x, 5, 32)
117 | return x
118 |
119 | def extract_tensor_as_int8(weight):
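119 |     # Unpack Q6_K / Q5_0 blocks into raw quantized integers plus per-block float scales; returns (values, scales, block_size, bits).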
120 | ic = int(weight.shape[0])
121 | oc = int(weight.shape[1])
122 | if weight.tensor_type == constants.GGMLQuantizationType.Q6_K:
123 | block_size, type_size = constants.GGML_QUANT_SIZES[weight.tensor_type]
124 | block_number = oc * ic // block_size
125 | weight = weight.data.reshape([oc * ic // block_size, type_size])
126 | scale_int8 = weight[:, 192:208]
127 | scale_half = weight[:, 208:210]
128 | scale_int8 = numpy.frombuffer(scale_int8.tobytes(), numpy.int8).astype(numpy.float32).reshape([block_number, 16, 1])
129 | scale_half = numpy.frombuffer(scale_half.tobytes(), numpy.float16).astype(numpy.float32).reshape([block_number, 1, 1])
130 | weight_scale = scale_half * scale_int8
131 |
132 | # Extract to int8
133 | ql = weight[:, 0:128]
134 | qh = weight[:, 128:192]
135 |
136 | qall = []
137 | for i in range(256):
138 | qall.append(None)
139 | for nnp in range(0, 2):
140 | for l in range(0, 32):
141 | q1 = ((ql[:, l + 0 + 64 * nnp] & 0xF) | (((qh[:, l + 32*nnp] >> 0) & 3) << 4))
142 | q2 = ((ql[:, l + 32 + 64 * nnp] & 0xF) | (((qh[:, l + 32*nnp] >> 2) & 3) << 4))
143 | q3 = ((ql[:, l + 0 + 64 * nnp] >> 4) | (((qh[:, l + 32*nnp] >> 4) & 3) << 4))
144 | q4 = ((ql[:, l + 32 + 64 * nnp] >> 4) | (((qh[:, l + 32*nnp] >> 6) & 3) << 4))
145 | qall[l + 0 + 128 * nnp] = q1.reshape([block_number, 1])
146 | qall[l + 32 + 128 * nnp] = q2.reshape([block_number, 1])
147 | qall[l + 64 + 128 * nnp] = q3.reshape([block_number, 1])
148 | qall[l + 96 + 128 * nnp] = q4.reshape([block_number, 1])
149 | q_raw = numpy.concatenate(qall, axis = 1)
150 | return q_raw, weight_scale, 16, 6
151 | elif weight.tensor_type == constants.GGMLQuantizationType.Q5_0:
152 | block_size, type_size = constants.GGML_QUANT_SIZES[weight.tensor_type]
153 | weight = weight.data.reshape([oc * ic // block_size, type_size])
154 |         # Separate Scale and Bias
155 | weight_main = weight[:, 2:type_size]
156 | weight_main = shuffle_weight_int5(weight_main, False)
157 | weight_scale = weight[:, 0:2]
158 | weight_scale = numpy.frombuffer(weight_scale.tobytes(), numpy.float16).astype(numpy.float32)
159 | return weight_main, weight_scale, 32, 5
160 | return None
161 |
162 | def write_external_weight(weight, mnn_weight_file, mnn_weight_offset):
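162 |     # Convert one GGUF tensor into MNN's external-weight layout, append it to the weight file, and return (new_offset, conv_params, tie_embedding_ok, block_size, quant_bit, header_len).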
163 | ic = int(weight.shape[0])
164 | oc = int(weight.shape[1])
165 | bias_length = oc * 4
166 | conv = {}
167 | block_size = 0
168 | block_number = 0
169 | quant_bit = 0
170 | tie_embedding = False
171 | header_len = 0
172 | if weight.tensor_type == constants.GGMLQuantizationType.F16:
173 | # FP16
174 | quan = {}
175 | quan['type'] = 3
176 | conv['quanParameter'] = quan
177 | rawbytes = weight.data.tobytes()
178 | weightlen = mnn_weight_file.write(rawbytes)
179 | external = [mnn_weight_offset, weightlen, 0, bias_length, 0]
180 | conv['external'] = external
181 | mnn_weight_offset += weightlen
182 | tie_embedding = True
183 | quant_bit = 16
184 | elif weight.tensor_type == constants.GGMLQuantizationType.F32:
185 |         # FP32, stored as FP16
186 | quan = {}
187 | quan['type'] = 3
188 | conv['quanParameter'] = quan
189 | rawbytes = weight.data.astype(numpy.float16).tobytes()
190 | weightlen = mnn_weight_file.write(rawbytes)
191 | external = [mnn_weight_offset, weightlen, 0, bias_length, 0]
192 | conv['external'] = external
193 | mnn_weight_offset += weightlen
194 | elif weight.tensor_type == constants.GGMLQuantizationType.Q4_0:
195 | tie_embedding = True
196 | quant_bit = 4
197 | block_size, type_size = constants.GGML_QUANT_SIZES[weight.tensor_type]
198 | block_number = oc * ic // block_size
199 | weight = weight.data.reshape([block_number, type_size])
200 |         # Separate Scale and Bias
201 | weight_main = weight[:, 2:type_size]
202 | weight_scale = weight[:, 0:2]
203 | weight_scale = numpy.frombuffer(weight_scale.tobytes(), numpy.float16).astype(numpy.float32)
204 |
205 | # shuffle weight
206 | weight_main = shuffle_weight_int4(weight_main)
207 | conv, header_len, mnn_weight_offset = write_quant_parameters(quant_bit, False, mnn_weight_file, ic, oc, weight_main, weight_scale, mnn_weight_offset)
208 | elif weight.tensor_type == constants.GGMLQuantizationType.Q4_1:
209 | quant_bit = 4
210 | tie_embedding = True
211 | block_size, type_size = constants.GGML_QUANT_SIZES[weight.tensor_type]
212 | block_number = oc * ic // block_size
213 | weight = weight.data.reshape([oc * ic // block_size, type_size])
214 |         # Separate Scale and Bias
215 |         weight_main = weight[:, 4:type_size]
216 |
217 |         # shuffle weight
218 |         weight_main = shuffle_weight_int4(weight_main)
219 |
220 | weight_scale = weight[:, 0:2]
221 | weight_bias = weight[:, 2:4]
222 | weight_scale = numpy.frombuffer(weight_scale.tobytes(), numpy.float16).reshape((block_number, 1))
223 | weight_bias = numpy.frombuffer(weight_bias.tobytes(), numpy.float16).reshape((block_number, 1))
224 | scalebias = numpy.concatenate((weight_bias, weight_scale), axis=1).astype(numpy.float32)
225 |
226 | conv, header_len, mnn_weight_offset = write_quant_parameters(quant_bit, True, mnn_weight_file, ic, oc, weight_main, scalebias, mnn_weight_offset)
227 | elif weight.tensor_type == constants.GGMLQuantizationType.Q4_K:
228 | quant_bit = 4
229 | tie_embedding = True
230 | block_size, type_size = constants.GGML_QUANT_SIZES[weight.tensor_type]
231 | block_number = oc * ic // block_size
232 | weight = weight.data.reshape([oc * ic // block_size, type_size])
233 |         # Separate Scale and Bias
234 | d = weight[:, 0:2]
235 | dmin = weight[:, 2:4]
236 | scales = weight[:, 4:16]
237 | weight_main = weight[:, 16:type_size]
238 |
239 | # shuffle weight
240 |         weight_main = weight_main.reshape((block_number * 4, 32))
241 |         weight_main = shuffle_weight_int4(weight_main)
242 |
243 |         # Compute Scale
244 |         d = numpy.frombuffer(d.tobytes(), numpy.float16).reshape((block_number, 1)).astype(numpy.float32)
245 |         dmin = numpy.frombuffer(dmin.tobytes(), numpy.float16).reshape((block_number, 1)).astype(numpy.float32)
246 |
247 |         def get_scale_min_k4(j, q):
248 |             if j < 4:
249 |                 d = q[:, j] & 63
250 |                 m = q[:, j + 4] & 63
251 |             else:
252 |                 d = (q[:, j+4] & 0xF) | ((q[:, j-4] >> 6) << 4)
253 |                 m = (q[:, j+4] >> 4) | ((q[:, j-0] >> 6) << 4)
254 |             return d, m
255 |         dgroup = []
256 |         mgroup = []
257 |         for j in range(0, 8):
258 |             dgroup.append(None)
259 |             mgroup.append(None)
260 |         for j in range(0, 8):
261 |             vd, vm = get_scale_min_k4(j, scales)
262 |             vd = vd.reshape((block_number, 1))
263 |             vm = vm.reshape((block_number, 1))
264 |             vd = vd.astype(numpy.float32) * d
265 |             vm = vm.astype(numpy.float32) * dmin
266 |             dgroup[j] = vd
267 |             mgroup[j] = -vm
268 | weight_scale = numpy.concatenate(dgroup, -1).reshape((block_number, 8, 1))
269 | weight_bias = numpy.concatenate(mgroup, -1).reshape((block_number, 8, 1))
270 | scalebias = numpy.concatenate((weight_bias, weight_scale), axis=-1).astype(numpy.float32)
271 |
272 |
273 | block_size = 32
274 |
275 | conv, header_len, mnn_weight_offset = write_quant_parameters(quant_bit, True, mnn_weight_file, ic, oc, weight_main, scalebias, mnn_weight_offset)
276 |
277 | elif weight.tensor_type == constants.GGMLQuantizationType.Q8_0:
278 | quant_bit = 8
279 | tie_embedding = True
280 | block_size, type_size = constants.GGML_QUANT_SIZES[weight.tensor_type]
281 | weight = weight.data.reshape([oc * ic // block_size, type_size])
282 |         # Separate Scale and Bias
283 | weight_main = weight[:, 2:type_size]
284 | weight_scale = weight[:, 0:2]
285 | weight_scale = numpy.frombuffer(weight_scale.tobytes(), numpy.float16).astype(numpy.float32)
286 | weight_main = numpy.frombuffer(weight_main.tobytes(), numpy.int8).astype(numpy.int16) + 128
287 | weight_main = weight_main.astype(numpy.uint8)
288 | conv, header_len, mnn_weight_offset = write_quant_parameters(quant_bit, False, mnn_weight_file, ic, oc, weight_main, weight_scale, mnn_weight_offset)
289 |
290 | elif weight.tensor_type == constants.GGMLQuantizationType.Q5_0:
291 | tie_embedding = False
292 | block_size, type_size = constants.GGML_QUANT_SIZES[weight.tensor_type]
293 | weight = weight.data.reshape([oc * ic // block_size, type_size])
294 |         # Separate Scale and Bias
295 | weight_main = weight[:, 2:type_size]
296 | weight_main = shuffle_weight_int5(weight_main)
297 | weight_scale = weight[:, 0:2]
298 | weight_scale = numpy.frombuffer(weight_scale.tobytes(), numpy.float16).astype(numpy.float32)
299 | quant_bit = 5
300 | conv, header_len, mnn_weight_offset = write_quant_parameters(quant_bit, False, mnn_weight_file, ic, oc, weight_main, weight_scale, mnn_weight_offset)
301 |
302 | elif weight.tensor_type == constants.GGMLQuantizationType.Q5_1:
303 | tie_embedding = False
304 | block_size, type_size = constants.GGML_QUANT_SIZES[weight.tensor_type]
305 | block_number = oc * ic // block_size
306 | weight = weight.data.reshape([oc * ic // block_size, type_size])
307 |         # Separate Scale and Bias
308 | weight_main = weight[:, 4:type_size]
309 | weight_main = shuffle_weight_int5(weight_main)
310 | weight_scale = weight[:, 0:2]
311 | weight_bias = weight[:, 2:4]
312 | weight_scale = numpy.frombuffer(weight_scale.tobytes(), numpy.float16).reshape((block_number, 1))
313 | weight_bias = numpy.frombuffer(weight_bias.tobytes(), numpy.float16).reshape((block_number, 1))
314 | weight_scale = numpy.concatenate((weight_bias, weight_scale), axis=1).astype(numpy.float32)
315 | quant_bit = 5
316 | conv, header_len, mnn_weight_offset = write_quant_parameters(quant_bit, True, mnn_weight_file, ic, oc, weight_main, weight_scale, mnn_weight_offset)
317 | elif weight.tensor_type == constants.GGMLQuantizationType.Q6_K:
318 | block_size, type_size = constants.GGML_QUANT_SIZES[weight.tensor_type]
319 | block_number = oc * ic // block_size
320 | q_raw, weight_scale, block_size, bits = extract_tensor_as_int8(weight)
321 | weight_main = repack_low_bits(q_raw, 6, 256)
322 | quant_bit = 6
323 | conv, header_len, mnn_weight_offset = write_quant_parameters(quant_bit, False, mnn_weight_file, ic, oc, weight_main, weight_scale, mnn_weight_offset)
324 |
325 | else:
326 |         print('Unsupported tensor type: ', weight.tensor_type)
327 | print(weight.data.shape, ic, oc)
328 | assert(False)
329 | return mnn_weight_offset, conv, tie_embedding, block_size, quant_bit, header_len
330 |
331 | def convert(args):
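331 |     # Rewrite the exported llm.mnn.json graph with weights taken from the GGUF file, regenerate llm.mnn.weight, then invoke mnnconvert to produce llm.mnn.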
332 | gguf = args.gguf
333 | mnn_dir = args.mnn_dir
334 | src_json = os.path.join(mnn_dir, "llm.mnn.json")
335 | dst_json = os.path.join(mnn_dir, "llm.mnn_new.json")
336 |
337 | mnn, opmap, convs, _, __ = load_mnn(src_json)
338 | llm_config = {}
339 | with open(os.path.join(mnn_dir, "llm_config.json")) as f:
340 | llm_config = json.load(f)
341 |
342 | reader = gguf_reader.GGUFReader(gguf)
343 | if args.load_token:
344 | write_token_file(os.path.join(mnn_dir, "tokenizer.txt"), load_token(reader))
345 | arch = reader.fields['general.architecture'].parts[4].tobytes().decode('utf-8')
346 | print("Arch:", arch)
347 | tensormap = {}
348 | for t in reader.tensors:
349 | tensormap[t.name] = t
350 |
351 | mnn_weight_file = open(os.path.join(mnn_dir, "llm.mnn.weight"), "wb")
352 | mnn_weight_offset = 0
353 | if 'tie_embeddings' in llm_config:
354 | del llm_config['tie_embeddings']
355 | for name in opmap:
356 | op = opmap[name]
357 |         if op['type'] == 'LayerNorm':
358 |             print('Load layernorm: ', name)
359 | weight_tensor = tensormap[name+'.weight']
360 | layernorm = op['main']
361 | layernorm['gamma'] = weight_tensor.data.tolist()
362 | if name+'.bias' in tensormap:
363 | layernorm['beta'] = tensormap[name+'.bias'].data.tolist()
364 | else:
365 | layernorm['beta'] = [0.0] * len(layernorm['gamma'])
366 | continue
367 | for op in convs:
368 | conv = op['main']
369 | name = op['name']
370 | if 'quanParameter' in conv:
371 | del conv['quanParameter']
372 | weight_name = name+'.weight'
373 | weight = None
374 | tie_embedding = False
375 | ichannel = conv['common']['inputCount']
376 | ochannel = conv['common']['outputCount']
377 | if name == 'output':
378 | print('hidden size: ', ichannel)
379 | llm_config['hidden_size'] = ichannel
380 | if weight_name in tensormap:
381 | weight = tensormap[weight_name]
382 | elif name == 'output':
383 | weight = tensormap['token_embd.weight']
384 | tie_embedding = True
385 | else:
386 | print("Error: Can't find weight for " + name)
387 | assert(False)
388 | print('Load Convolution: ', name, ", weight type: ", weight.tensor_type)
389 | if weight.shape[0] != ichannel or weight.shape[1] != ochannel:
390 |             print(name, ": channel mismatch, expected (", ichannel, ",", ochannel, "), got", weight.shape, ", resetting to the tensor shape")
391 | ichannel = int(weight.shape[0])
392 | ochannel = int(weight.shape[1])
393 | conv['common']['inputCount'] = ichannel
394 | conv['common']['outputCount'] = ochannel
395 | # Change post reshape for convolution
396 | outputIndex = op['outputIndexes'][0]
397 | for subop in mnn["oplists"]:
398 | if 'inputIndexes' not in subop:
399 | continue
400 | if subop['inputIndexes'][0] == outputIndex and subop['type'] == 'ConvertTensor':
401 | outputIndex = subop['outputIndexes'][0]
402 | break
403 | for subop in mnn["oplists"]:
404 | if 'inputIndexes' not in subop:
405 | continue
406 | if subop['inputIndexes'][0] == outputIndex and subop['type'] == 'Reshape':
407 | subop['main']['dims'][2] = ochannel
408 | break
409 | mnn_weight_offset, conv_new, can_tie_embedding, block_size, quant_bit, header_len = write_external_weight(weight, mnn_weight_file, mnn_weight_offset)
410 | if not can_tie_embedding:
411 | tie_embedding = False
412 | conv['quanParameter'] = conv_new['quanParameter']
413 | conv['external'] = conv_new['external']
414 |
415 | bias = None
416 | bias_name = name + '.bias'
417 | if bias_name in tensormap:
418 | if tensormap[bias_name].tensor_type > 1:
419 | print('Error: Bias is quant: ', tensormap[bias_name].tensor_type)
420 | assert(False)
421 | bias = tensormap[bias_name].data.astype(numpy.float32)
422 | else:
423 | bias = numpy.zeros(ochannel).astype(numpy.float32)
424 | mnn_weight_offset += mnn_weight_file.write(bias.tobytes())
425 | if tie_embedding:
426 | external = conv['external']
427 | weight_offset = external[0] + header_len
428 | alpha_offset = external[0] + external[1]
429 | alpha_size = external[2]
430 | llm_config['tie_embeddings'] = [weight_offset, alpha_offset, alpha_size, quant_bit, 32]
431 | embedding_file = os.path.join(mnn_dir, "embeddings_bf16.bin")
432 |
433 |     embedding_in_weight = True
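433 |     # With tie_embeddings set, the embedding table is read directly from llm.mnn.weight; embeddings_bf16.bin is only needed otherwise.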
434 | if 'tie_embeddings' not in llm_config:
435 | # Need write embedding
436 | weight = tensormap['token_embd.weight']
437 | print("Embedding type: ", weight.tensor_type)
438 | if weight.tensor_type <= 1:
439 |             embedding_in_weight = False
440 | print("Write ", embedding_file)
441 | weight = weight.data.astype(numpy.float32)
442 | weight = numpy.frombuffer(weight.tobytes(), numpy.uint32) >> 16
443 | weight = weight.astype(numpy.uint16)
444 | with open(embedding_file, 'wb') as f:
445 | f.write(weight.tobytes())
446 | elif weight.tensor_type == constants.GGMLQuantizationType.Q8_0 or weight.tensor_type == constants.GGMLQuantizationType.Q4_0 or weight.tensor_type == constants.GGMLQuantizationType.Q4_1:
447 | mnn_weight_offset, conv, can_tie_embedding, block_size, quant_bit, header_len = write_external_weight(weight, mnn_weight_file, mnn_weight_offset)
448 | external = conv['external']
449 | weight_offset = external[0] + header_len
450 | alpha_offset = external[0] + external[1]
451 | alpha_size = external[2]
452 | llm_config['tie_embeddings'] = [weight_offset, alpha_offset, alpha_size, quant_bit, block_size]
453 | elif weight.tensor_type == constants.GGMLQuantizationType.Q6_K or weight.tensor_type == constants.GGMLQuantizationType.Q5_0:
454 | q_raw, weight_scale, block_size, bits = extract_tensor_as_int8(weight)
455 |             # embedding_in_weight = False
456 | ic = int(weight.shape[0])
457 | oc = int(weight.shape[1])
458 | offset = (1 << (bits - 1))
459 | q_raw = repack_low_bits(q_raw, 8, q_raw.shape[1])
460 | q_raw = q_raw + (128-offset)
461 | quant_bit = 8
462 | conv, header_len, mnn_weight_offset = write_quant_parameters(quant_bit, False, mnn_weight_file, ic, oc, q_raw, weight_scale, mnn_weight_offset)
463 | external = conv['external']
464 | weight_offset = external[0] + header_len
465 | alpha_offset = external[0] + external[1]
466 | alpha_size = external[2]
467 | llm_config['tie_embeddings'] = [weight_offset, alpha_offset, alpha_size, quant_bit, block_size]
468 | else:
469 | assert(False)
470 |
471 |     if embedding_in_weight:
472 | if os.path.exists(embedding_file):
473 | os.remove(embedding_file)
474 |
475 | mnn_weight_file.close()
476 | with open(dst_json, 'w') as f:
477 | f.write(json.dumps(mnn, indent=4))
478 | with open(os.path.join(mnn_dir, "llm_config.json"), 'w') as f:
479 | f.write(json.dumps(llm_config, indent=4))
480 |
481 | convert_args = [
482 | '',
483 | '-f',
484 | 'JSON',
485 | '--modelFile',
486 | dst_json,
487 | '--MNNModel',
488 | os.path.join(mnn_dir, 'llm.mnn'),
489 | ]
490 |
491 | print(convert_args)
492 | from MNN.tools import mnnconvert
493 | mnnconvert.convert(convert_args)
494 | os.remove(dst_json)
495 |
496 | if __name__ == '__main__':
497 | parser = argparse.ArgumentParser(description='gguf2mnn', formatter_class=argparse.RawTextHelpFormatter)
498 | parser.add_argument('--gguf', type=str, required=True,help='src gguf model')
499 | parser.add_argument('--mnn_dir', type=str, required=True,help='mnn llm dir')
500 |     parser.add_argument('--load_token', action='store_true', help='Override tokenizer.txt from gguf')
501 | args = parser.parse_args()
502 | import time
503 | sta = time.time()
504 | convert(args)
505 | fin = time.time()
506 | print("Cost time ", fin - sta, " s")
507 |
--------------------------------------------------------------------------------