├── QIGen.py ├── README.md ├── demo.ipynb ├── generate.py ├── intrin.py ├── mmm.cpp ├── requirements.txt ├── setup.py ├── template.py └── utils.py /QIGen.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import torch 3 | from tqdm import tqdm 4 | import gc 5 | 6 | import cQIGen as qinfer 7 | import math 8 | import numpy as np 9 | from gekko import GEKKO 10 | from utils import mem_model 11 | 12 | params = {} 13 | 14 | def compute_reductions(x, gs=-1, cpp=True): 15 | if cpp: 16 | if len(x.shape) != 1: 17 | rows, cols = x.shape 18 | else: 19 | rows = 1 20 | cols = x.shape[0] 21 | if gs == -1: 22 | out = torch.zeros(rows).float().contiguous() 23 | mygs = cols 24 | else: 25 | out = torch.zeros(rows, cols // gs).float().contiguous() 26 | mygs = gs 27 | 28 | qinfer.compute_reduction_cpp(x, out, rows, cols, mygs) 29 | return out 30 | if gs == -1: 31 | if len(x.shape) != 1: 32 | return torch.sum(x,1) 33 | else: 34 | return torch.sum(x) 35 | else: 36 | if len(x.shape) != 1: 37 | rows, cols = x.shape 38 | out = torch.zeros(rows, cols // gs).float().contiguous() 39 | for i in range(cols // gs): 40 | out[:,i] = torch.sum(x[:,i*gs:(i+1)*gs],1) 41 | return out 42 | else: 43 | cols = x.shape[0] 44 | out = torch.zeros(cols // gs).float().contiguous() 45 | for i in range(cols // gs): 46 | out[i] = torch.sum(x[i*gs:(i+1)*gs]) 47 | return out 48 | 49 | def process_zeros_scales(zeros, scales, bits, M): 50 | if zeros.dtype != torch.float32: 51 | new_zeros = torch.zeros_like(scales).float().contiguous() 52 | if bits == 4: 53 | qinfer.unpack_zeros4(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1]) 54 | elif bits == 2: 55 | qinfer.unpack_zeros2(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1]) 56 | elif bits == 3: 57 | print("Unpacking zeros for 3 bits") 58 | new_scales = scales.contiguous() 59 | else: 60 | if scales.shape[1] != M: 61 | new_scales = scales.transpose(0,1).contiguous() 62 | else: 63 | new_scales = scales.contiguous() 64 | if zeros.shape[1] != M: 65 | new_zeros = zeros.transpose(0,1).contiguous() 66 | else: 67 | new_zeros = zeros.contiguous() 68 | 69 | return new_zeros, new_scales 70 | 71 | class qLinear(torch.nn.Module): 72 | def __str__(self): 73 | return self.name 74 | 75 | def __init__(self, mode, p, l1, name="", other=None, N=0, M=0, qweights=None, zeros=None, scales=None, bias=None, bits=4, hint=1, verbose=False, gs=-1): 76 | super().__init__() 77 | self.bits = bits 78 | pack = 32 // bits 79 | 80 | if mode == 'linear': 81 | self.N, self.M = other.in_features, other.out_features 82 | else: 83 | self.N, self.M = N, M 84 | 85 | n = hint 86 | m = self.N 87 | t = self.M 88 | 89 | 90 | #registers for now are fixed 91 | if bits == 3: 92 | packed = 32 93 | unroll = 3 94 | nu = 1 #args.n 95 | mu = 32 96 | tu = 32 97 | else: 98 | packed = 32 // bits 99 | unroll = 2 100 | nu = 1 #args.n 101 | mu = 16 102 | tu = 32 103 | 104 | nb = n # it's always small for transformers 105 | 106 | global params 107 | if (m,t) in params: 108 | mb = params[(m,t)][0] 109 | tb = params[(m,t)][1] 110 | else: 111 | if verbose: 112 | print("Computing memory model for {}x{}x{} with {} bits".format(n,m,t,bits)) 113 | mb, tb = mem_model(n, m, t, mu, tu, bits, l1, p, gs, verbose=False) 114 | params[(m,t)] = (mb,tb) 115 | 116 | split = np.ones(p) 117 | split = split * tb 118 | while np.sum(split) < t: 119 | split = split + tb 120 | 121 | idx = p - 1 122 | while np.sum(split) > t: 123 | split[idx] = split[idx] - tb 124 | idx = idx - 1 125 | 126 | 
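# (illustrative note) The loop above balances the t output columns across
# the p threads in multiples of tb: every thread starts with the same
# number of column tiles, then tiles are taken back from the highest
# thread ids until the total equals t. For example (hypothetical sizes),
# with p=4, tb=32, t=160 the split evolves
#   [32,32,32,32] -> [64,64,64,64] -> [64,32,32,32]
# so tt = split[0] = 64, threads with tid >= cutoff = 1 own one tile
# less, and the assert below checks the tiles exactly cover all t columns.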
assert(np.sum(split) == t) 127 | 128 | split = split.astype(int) 129 | self.tt = int(split[0]) 130 | 131 | if split[0] == split[-1]: 132 | self.cutoff = int(p+1) 133 | else: 134 | self.cutoff = int(idx + 1) 135 | 136 | self.mb = mb #// packed 137 | self.tb = tb 138 | 139 | self.gs = gs 140 | 141 | if verbose: 142 | print("Chose parameters {}x{}x{} with {} bits and tt {}".format(nb,mb,tb,bits,self.tt)) 143 | 144 | 145 | self.name=name 146 | if bias is None: 147 | self.bias = torch.zeros(self.M) 148 | else: 149 | self.bias = bias 150 | 151 | self.zeros, self.scales = process_zeros_scales(zeros, scales, bits, self.M) 152 | 153 | 154 | if bits == 4: 155 | if verbose: 156 | print(self.N // packed, self.M, self.mb // packed, self.tb, self.cutoff) 157 | self.weight = torch.zeros(int(self.N // packed * self.M)).int().contiguous() 158 | qinfer.pack4(qweights.int().contiguous(),self.weight, self.N // packed, self.M, self.mb, self.tb, self.cutoff)# * (self.tt//tb)) 159 | elif bits == 3: 160 | self.weight = torch.zeros(int(self.N // packed * 3 * self.M)).int().contiguous() 161 | if verbose: 162 | print(self.N // packed * 3, self.M, self.mb // packed * 3, self.tb, self.cutoff) 163 | qinfer.pack3(qweights.int().contiguous(),self.weight, self.N // packed * 3, self.M, self.mb // packed * 3, self.tb, self.cutoff) 164 | elif bits == 2: 165 | self.weight = torch.zeros(int(self.N // packed * self.M)).int().contiguous() 166 | qinfer.pack2(qweights.int().contiguous(),self.weight, self.N // packed, self.M, self.mb, self.tb, self.cutoff)# * (self.tt//tb)) 167 | 168 | 169 | 170 | def forward(self, x): 171 | x = x.reshape((-1, x.shape[-1])) 172 | B = x.shape[0] 173 | new_x = x.T.contiguous() 174 | out = torch.zeros((B, self.M), dtype=torch.float32).contiguous() 175 | sums = compute_reductions(x,gs=self.gs,cpp=True) 176 | sums = sums.contiguous() 177 | if self.gs == -1: 178 | if self.bits == 4: 179 | qinfer.forward4(new_x.contiguous(), self.weight.contiguous(), out.contiguous(), self.bias.contiguous(), 180 | self.scales.contiguous(), self.zeros.contiguous(), sums.contiguous(), B, self.N, self.M, B, self.mb, self.tb, self.tt, self.cutoff) 181 | elif self.bits == 2: 182 | qinfer.forward2(new_x.contiguous(), self.weight.contiguous(), out.contiguous(), self.bias.contiguous(), 183 | self.scales.contiguous(), self.zeros.contiguous(), sums.contiguous(), B, self.N, self.M, B, self.mb, self.tb, self.tt, self.cutoff) 184 | elif self.bits == 3: 185 | qinfer.forward3(new_x.contiguous(), self.weight.contiguous(), out.contiguous(), self.bias.contiguous(), 186 | self.scales.contiguous(), self.zeros.contiguous(), sums.contiguous(), B, self.N, self.M, B, self.mb, self.tb, self.tt, self.cutoff) 187 | else: 188 | if self.bits == 4: 189 | qinfer.forward_gs4(new_x.contiguous(), self.weight.contiguous(), out.contiguous(), self.bias.contiguous(), 190 | self.scales.contiguous(), self.zeros.contiguous(), sums.contiguous(), B, self.N, self.M, B, self.mb, self.tb, self.tt, self.gs, self.cutoff) 191 | elif self.bits == 2: 192 | qinfer.forward_gs2(new_x.contiguous(), self.weight.contiguous(), out.contiguous(), self.bias.contiguous(), 193 | self.scales.contiguous(), self.zeros.contiguous(), sums.contiguous(), B, self.N, self.M, B, self.mb, self.tb, self.tt, self.gs, self.cutoff) 194 | elif self.bits == 3: 195 | qinfer.forward_gs3(new_x.contiguous(), self.weight.contiguous(), out.contiguous(), self.bias.contiguous(), 196 | self.scales.contiguous(), self.zeros.contiguous(), sums.contiguous(), B, self.N, self.M, B, self.mb, self.tb, self.tt, 
self.gs, self.cutoff) 197 | 198 | return out 199 | 200 | def swap_module(network, module_name, new_module): 201 | name_parts = module_name.split('.') 202 | parent = network 203 | for part in name_parts[:-1]: 204 | if part.isdigit(): 205 | parent = parent[int(part)] 206 | else: 207 | parent = getattr(parent, part) 208 | 209 | last_part = name_parts[-1] 210 | if last_part.isdigit(): 211 | parent[int(last_part)] = new_module 212 | else: 213 | setattr(parent, last_part, new_module) 214 | 215 | def swap_modules(version, in_network, checkpoint, bits, p, l1, inplace=False, verbose=False, hint=1, qzeros=True, gs=-1, simulate_gs=-1): 216 | global params 217 | params = {} 218 | 219 | if version == 'llama': 220 | preamble = "model" 221 | elif version == 'opt': 222 | preamble = "model.decoder" 223 | else: 224 | print(f'unknown version {version}') 225 | return 226 | 227 | if not inplace: 228 | network = deepcopy(in_network) 229 | else: 230 | network = in_network 231 | 232 | if not qzeros: 233 | zeros = 'zeros' 234 | else: 235 | zeros = 'qzeros' 236 | 237 | for name, module in network.named_modules(): 238 | is_linear = isinstance(module, torch.nn.Linear) 239 | 240 | if not is_linear: 241 | if verbose: 242 | print(f'module {name} not replaced') 243 | continue 244 | 245 | try: 246 | if version == 'llama': 247 | layer_type = name.split('.')[4] 248 | module_name = name.split('.')[3] 249 | index_number = name.split('.')[2] 250 | bias = None 251 | start = f"{preamble}.layers.{index_number}.{module_name}.{layer_type}" 252 | elif version == 'opt': 253 | layer_type = name.split('.')[-1] 254 | module_name = name.split('.')[-2] 255 | index_number = name.split('.')[-3] 256 | if 'fc' in layer_type: 257 | start = f"{preamble}.layers.{module_name}.{layer_type}" 258 | else: 259 | start = f"{preamble}.layers.{index_number}.{module_name}.{layer_type}" 260 | bias = checkpoint[f"{start}.bias"].float() 261 | 262 | 263 | 264 | if simulate_gs == -1: 265 | new_module = qLinear(mode='llama', p=p, l1=l1, name=f"{start}", 266 | zeros=checkpoint[f"{start}.{zeros}"], 267 | scales=checkpoint[f"{start}.scales"].float(), 268 | bias=bias, 269 | qweights=checkpoint[f"{start}.qweight"].contiguous(), 270 | N=module.in_features, M=module.out_features, bits=bits, hint=hint, verbose=verbose, gs=gs) 271 | else: 272 | tmp_zeros = checkpoint[f"{start}.{zeros}"] 273 | tmp_scales = checkpoint[f"{start}.scales"].float() 274 | if gs != -1: 275 | tmp_zeros = tmp_zeros[:,0] 276 | tmp_scales = tmp_scales[:,0] 277 | zeros_tensor = tmp_zeros.repeat(module.in_features//simulate_gs,1) 278 | scales_tensor = tmp_scales.repeat(module.in_features//simulate_gs,1) 279 | new_module = qLinear(mode='llama', p=p, l1=l1, name=f"{start}", zeros=zeros_tensor, 280 | scales=scales_tensor, 281 | bias=bias, 282 | qweights=checkpoint[f"{start}.qweight"].contiguous(), 283 | N=module.in_features, M=module.out_features, bits=bits, hint=hint, verbose=verbose, gs=simulate_gs) 284 | 285 | swap_module(network, name, new_module) 286 | 287 | if verbose: 288 | print(f'module {name} replaced with {start}') 289 | except Exception as e: 290 | if verbose: 291 | print(e) 292 | print(f'module {name} not replaced') 293 | 294 | 295 | return network 296 | 297 |
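# ------------------------------------------------------------------
# Usage sketch (illustrative; it mirrors the call in demo.ipynb). The
# model and checkpoint paths and the 2**18-bit L1 size are placeholders,
# not defaults:
#
#   import torch
#   from transformers import AutoModelForCausalLM
#   import QIGen
#
#   model = AutoModelForCausalLM.from_pretrained("../models/path")
#   checkpoint = torch.load("../checkpoint/path.pt", map_location="cpu")
#   qmodel = QIGen.swap_modules('llama', model, checkpoint, bits=4,
#                               p=64, l1=2**18, gs=-1, inplace=True)
#   # qmodel.generate(...) now runs through the quantized CPU kernels
# ------------------------------------------------------------------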
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Quantized Inference on Generative LLMs (QIGen) 2 | 3 | A code generator for inference on quantized Large Language Models (LLMs). Quantization is done using [GPTQ](https://github.com/IST-DASLab/gptq). 4 | 5 | ## Current features 6 | 7 | * Support for LLaMA and OPT 8 | * 4-, 3-, and 2-bit inference 9 | * x86 with AVX2 support 10 | * Support for PyTorch and `transformers` 11 | * Support for generic quantization group sizes 12 | 13 | ## TODOs 14 | 15 | * Support for ARM NEON 16 | * Support for AVX-512 17 | * Including quantization error analysis in code generation 18 | 19 | ## Usage 20 | 21 | ### Installation 22 | 23 | 1. Install the dependencies via `pip install -r requirements.txt` 24 | 2. Install `transformers` from source: `pip install git+https://github.com/huggingface/transformers` 25 | 3. Install the Python module with `python setup.py install`. This runs a search to find the best parameters for register usage. 26 | 27 | ### Example 28 | 29 | We give an example notebook in `demo.ipynb`. The basic workflow is: 30 | 31 | * load the floating-point model, 32 | * load the quantized checkpoint from GPTQ, 33 | * call `QIGen.swap_modules(arch, model, checkpoint, bits=4, p=64, l1=l1, gs=-1, inplace=False)`, where `arch` is `'llama'` or `'opt'`, `model` is the full-precision model, `checkpoint` is the quantized GPTQ checkpoint, `bits` is the number of bits used for quantization, `l1` is the size of the L1 data cache in bits, `p` is the number of cores to use, `gs` is the quantization group size (`-1` means full columns), and `inplace` chooses between swapping in place and creating a copy, 34 | * use the quantized model as a normal `transformers` model. 35 | -------------------------------------------------------------------------------- /demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "bfb07323-aea6-4eb9-b45a-1f2c5761f9b0", 7 | "metadata": { 8 | "tags": [] 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "import os\n", 13 | "\n", 14 | "import QIGen\n", 15 | "import torch\n", 16 | "import time\n", 17 | "import numpy as np" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 2, 23 | "id": "2a856859-71f1-4fb3-a95e-f4511b7a3051", 24 | "metadata": { 25 | "tags": [] 26 | }, 27 | "outputs": [ 28 | { 29 | "name": "stderr", 30 | "output_type": "stream", 31 | "text": [ 32 | "/local/home/tommaso/env/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 33 | " from .autonotebook import tqdm as notebook_tqdm\n", 34 | "Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 41/41 [00:36<00:00, 1.12it/s]\n" 35 | ] 36 | } 37 | ], 38 | "source": [ 39 | "from transformers import AutoTokenizer, AutoModelForCausalLM\n", 40 | "import transformers\n", 41 | "\n", 42 | "path = \"../models/path\" # PATH TO YOUR MODELS \n", 43 | "\n", 44 | "tokenizer = AutoTokenizer.from_pretrained(path, use_fast=False)\n", 45 | "model = AutoModelForCausalLM.from_pretrained(path)" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 3, 51 | "id": "c83d1203-131f-41bc-af81-7772847e091c", 52 | "metadata": { 53 | "tags": [] 54 | }, 55 | "outputs": [], 56 | "source": [ 57 | "checkpoint = torch.load(\"../checkpoint/path.pt\",map_location=torch.device('cpu'))" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 4, 63 | "id": "032e5a4f-6599-45dc-9105-c7f8925239ce", 64 | "metadata": { 65 | "tags": [] 66 | }, 67 | "outputs": [], 68 | "source": [ 69 | "l1 = 2**18 # Level 1 cache in bits\n", 70 | "p = 64 # number of cores\n", 71 | "bits = 3 # bits used for the model\n", 72 | "gs = 128 # Group size (-1 means full column)\n", 73 | "arch = 'llama' #or opt\n", 74 | "\n", 75 | "qzeros = not (bits == 3)\n", 76 | "qmodel = QIGen.swap_modules(arch, model, checkpoint, bits=bits, p=p, gs=gs,\n", 77 | " l1=l1, inplace=True, hint=1,\n", 78 | " verbose=False, qzeros=qzeros)" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 5, 84 | "id": "4ab68b96-26c2-4346-9926-66f4a7ac3984", 85 | "metadata": { 86 | "tags": [] 87 | }, 88 | "outputs": [], 89 | "source": [ 90 | "prompt = \"My favorite vacation was \"\n", 91 | "result_length = 128\n", 92 | "inputs = tokenizer(prompt, return_tensors=\"pt\")\n", 93 | "generation_config = transformers.GenerationConfig(\n", 94 | " temperature=0.8, #0.8\n", 95 | " top_p=0.95, #0.95\n", 96 | " top_k=40, #40\n", 97 | " num_beams=1,\n", 98 | " min_new_tokens=result_length,\n", 99 | " max_new_tokens=result_length,\n", 100 | " do_sample=False,\n", 101 | " repetition_penalty=1.1, #1.1\n", 102 | ")" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 6, 108 | "id": "e935c6dd-f2b1-405c-959d-b7f0b2785f82", 109 | "metadata": { 110 | "tags": [] 111 | }, 112 | "outputs": [ 113 | { 114 | "name": "stdout", 115 | "output_type": "stream", 116 | "text": [ 117 | "29.422704935073853 ⁇ My favorite vacation was 10 years ago. I went to the beach with my family and friends. We stayed in a hotel for two days. It was very hot, but we had fun. We swam in the ocean and played volleyball on the beach. We also went shopping at the mall. We bought some souvenirs and gifts for our families. We were so tired after the trip, but it was worth it.\n", 118 | "My favorite vacation was 10 years ago. I went to the beach with my family and friends. We stayed in a hotel for two days. 
It was very hot, but we had fun\n" 119 | ] 120 | } 121 | ], 122 | "source": [ 123 | "torch.manual_seed(0)\n", 124 | "start = time.time()\n", 125 | "output = tokenizer.decode(qmodel.generate(inputs[\"input_ids\"],generation_config=generation_config)[0])\n", 126 | "end = time.time() - start\n", 127 | "print(end, output)" 128 | ] 129 | } 130 | ], 131 | "metadata": { 132 | "kernelspec": { 133 | "display_name": "Python 3 (ipykernel)", 134 | "language": "python", 135 | "name": "python3" 136 | }, 137 | "language_info": { 138 | "codemirror_mode": { 139 | "name": "ipython", 140 | "version": 3 141 | }, 142 | "file_extension": ".py", 143 | "mimetype": "text/x-python", 144 | "name": "python", 145 | "nbconvert_exporter": "python", 146 | "pygments_lexer": "ipython3", 147 | "version": "3.8.10" 148 | } 149 | }, 150 | "nbformat": 4, 151 | "nbformat_minor": 5 152 | } 153 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import intrin 2 | import argparse 3 | import subprocess 4 | import time 5 | import template 6 | import numpy as np 7 | from gekko import GEKKO 8 | from utils import mem_model 9 | import pandas as pd 10 | 11 | def macros(): 12 | return "#include\n#include\n#include\n\n#define mymin(a,b) ((a)<(b)?(a):(b))\n#define mymax(a,b) ((a)>(b)?(a):(b))\n" 13 | 14 | def print_parameters(bits, n, m, t, nb, mb, tb, mu, nu, tu, unroll, p, gs=-1): 15 | res = "" 16 | res += "void print_parameters(){\n" 17 | res += f" std::cout << {bits} << \"bits,\" << {n} << \",\" << {m} << \",\" << {t} << \",\" << {nb} << \",\" << {mb} << \",\" << {tb} << \",\" << {nu} << \",\" << {mu} << \",\" << {tu} << \",\" << {unroll} << \",\" << {p} << \",\" << {gs} << \",\";\n" 18 | res += "}\n" 19 | return res 20 | 21 | def print_parameters_module(bits, mu, nu, tu, unroll, p, gs=-1): 22 | res = "" 23 | res += "void print_parameters(){\n" 24 | res += "std::ofstream outfile;\n" 25 | res += "outfile.open(\"tmp.csv\", std::ios_base::app);\n" 26 | res += f"outfile << {bits} << \",\" << {nu} << \",\" << {mu} << \",\" << {tu} << \",\" << {unroll} << \",\" << {p} << \",\" << {gs} << \",\";\n" 27 | res += "}\n" 28 | return res 29 | 30 | def pack_in(n, m, nb, mb): 31 | res = "" 32 | res += "inline void pack_input(float* A, float* B){\n" 33 | res += " // copy the full matrix A in blocked format into B\n" 34 | res += " uint64_t idx = 0;\n" 35 | res += f" const int N = {n};\n" 36 | res += f" const int M = {m};\n" 37 | res += f" const int nb = {nb};\n" 38 | res += f" const int mb = {mb};\n" 39 | res += " for(int i = 0; i < N; i+=nb){ \n \ 40 | for(int j = 0; j < M; j+=mb){\n \ 41 | for(int jj = j; jj < mymin(j+mb, M); jj++){\n \ 42 | for(int ii = i; ii < mymin(i+nb, N); ii++){\n \ 43 | B[idx] = A[ii*M+jj];\n \ 44 | idx++;\n \ 45 | }\n \ 46 | }\n \ 47 | }\n \ 48 | }\n \ 49 | }\n" 50 | return res 51 | 52 | def pack_out(n, t, nb, tb): 53 | res = "" 54 | res += "inline void pack_output(float* A, float* B){\n" 55 | res += " // copy the full matrix A in blocked format into B\n" 56 | res += " uint64_t idx = 0;\n" 57 | res += f" const int N = {n};\n" 58 | res += f" const int M = {t};\n" 59 | res += f" const int nb = {nb};\n" 60 | res += f" const int mb = {tb};\n" 61 | res += " for(int i = 0; i < N; i+=nb){ \n \ 62 | for(int j = 0; j < M; j+=mb){\n \ 63 | for(int ii = i; ii < mymin(i+nb, N); ii++){\n \ 64 | for(int jj = j; jj < mymin(j+mb, M); jj++){\n \ 65 | B[idx] = A[ii*M+jj];\n \ 66 | idx++;\n \ 67 | }\n \ 68 | }\n \ 69 | 
}\n \ 70 | }\n \ 71 | }\n" 72 | return res 73 | 74 | def pack_qw(m, t, mb, tb, tb1, bits=4, cutoff=-1): 75 | packed = 32 // bits 76 | res = "" 77 | if cutoff == -1: 78 | cutoff = 65 79 | if bits == 3: 80 | res += "inline void pack_qw_inner(int* A, int* B, int cutoff){\n" 81 | res += " // copy the full matrix A in blocked format into B\n" 82 | res += " uint64_t idx = 0;\n" 83 | res += f" const int N = {m // 32 * 3};\n" 84 | res += f" const int M = {t};\n" 85 | res += f" const int nb = {mb // 32 * 3};\n" 86 | res += f"int mb = {int(tb)};\n" 87 | res += " for(int j = 0, tid = 0; j < M; j+=mb, tid++){\n" 88 | # res += "if(tid==cutoff){\n " 89 | # res += f" mb = {tb1};\n" 90 | # res += "}\n" 91 | res += " for(int i = 0; i < N; i+=nb){\n \ 92 | for(int ii = i; ii < mymin(i+nb, N); ii+=3){\n \ 93 | for(int jj = j; jj < mymin(j+mb, M); jj+=8){\n \ 94 | for(int iii = ii; iii < ii + 3; iii++){\n \ 95 | for(int jjj = jj; jjj < jj + 8; jjj++){\n \ 96 | B[idx] = A[iii*M+jjj];\n \ 97 | idx++;\n \ 98 | }\n \ 99 | }\n \ 100 | }\n \ 101 | }\n \ 102 | }\n \ 103 | }\n \ 104 | }\n" 105 | res += "inline void pack_qw(int* A, int* B){\n" 106 | res += f" pack_qw_inner(A, B, {cutoff});\n" 107 | res += "}\n" 108 | return res 109 | else: 110 | # in case i do this for python i can just add the n,m,nb,mb as function parameters 111 | res += "inline void pack_qw_inner(int* A, int* B, int cutoff){\n" 112 | res += " // copy the full matrix A in blocked format into B\n" 113 | res += " uint64_t idx = 0;\n" 114 | res += f" const int N = {m // packed};\n" 115 | res += f" const int M = {t};\n" 116 | res += f" const int nb = {mb // packed};\n" 117 | res += f"int mb = {int(tb)};\n" 118 | res += " for(int j = 0, tid = 0; j < M; j+=mb, tid++){\n" 119 | # res += "if(tid==cutoff){\n " 120 | # res += f" mb = {tb1};\n" 121 | # res += "}\n" 122 | res += " for(int i = 0; i < N; i+=nb){\n \ 123 | for(int ii = i; ii < mymin(i+nb, N); ii++){\n \ 124 | for(int jj = j; jj < mymin(j+mb, M); jj++){\n \ 125 | B[idx] = A[ii*M+jj];\n \ 126 | idx++;\n \ 127 | }\n \ 128 | }\n \ 129 | }\n" 130 | res += "}\n" 131 | res += "}\n" 132 | res += "inline void pack_qw(int* A, int* B){\n" 133 | res += f" pack_qw_inner(A, B, {cutoff});\n" 134 | res += "}\n" 135 | return res 136 | 137 | def block_gs(nu_iter, mu, tu, rho, packed, unroll, bits): 138 | res = "" 139 | i = 0 140 | # unroll = 4 # number of bcasts and unpacks 141 | if bits == 3: 142 | for j in range(0,tu,8): 143 | res += f"__m256i w0_{j} = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/{packed}*3 + k*mb*tb/{packed}*3 + k3*tb/{packed}*3 + jw+{j*3}]);\n" 144 | res += f"__m256i w1_{j} = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/{packed}*3 + k*mb*tb/{packed}*3 + k3*tb/{packed}*3 + jw+{j*3}+8]);\n" 145 | res += f"__m256i w2_{j} = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/{packed}*3 + k*mb*tb/{packed}*3 + k3*tb/{packed}*3 + jw+{j*3}+16]);\n" 146 | 147 | u = 0 148 | first_off = 3 149 | second_off = 2 150 | wid = 0 151 | shift = 0 152 | while u < 32: 153 | 154 | if u == 10: 155 | 156 | res += f"__m256 v{i}_{u} = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k3+{u})*nb + i1+{i}]);\n" 157 | 158 | for j in range(0,tu,8): 159 | res += f"__m256i ws{j}_10 = _mm256_srli_epi32(w0_{j}, {bits*10});\n" 160 | res += f"__m256i temp0_{j} = _mm256_slli_epi32(w1_{j}, 2);\n" 161 | res += f"temp0_{j} = _mm256_and_si256(temp0_{j}, mask);\n" 162 | res += f"ws{j}_10 = _mm256_or_si256(ws{j}_10, temp0_{j});\n" 163 | 164 | res += f"__m256i wsa{j}_{u} = _mm256_and_si256(ws{j}_{u}, mask);\n" 165 | 166 | res += f"__m256 l{j}_{u} = 
_mm256_cvtepi32_ps(wsa{j}_{u});\n" 167 | 168 | res += f"acc{i}_{j} = _mm256_fmadd_ps(v{i}_{u}, l{j}_{u}, acc{i}_{j});\n" 169 | 170 | wid = wid + 1 171 | u = u + 1 172 | 173 | elif u == 21: 174 | res += f"__m256 v{i}_{u} = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k3+{u})*nb + i1+{i}]);\n" 175 | 176 | for j in range(0,tu,8): 177 | res += f"__m256i ws{j}_{u} = _mm256_srli_epi32(w1_{j}, 31);\n" 178 | res += f"__m256i temp1_{j} = _mm256_slli_epi32(w2_{j}, 1);\n" 179 | res += f"temp1_{j} = _mm256_and_si256(temp1_{j}, mask);\n" 180 | res += f"ws{j}_{u} = _mm256_or_si256(ws{j}_{u}, temp1_{j});\n" 181 | 182 | res += f"__m256i wsa{j}_{u} = _mm256_and_si256(ws{j}_{u}, mask);\n" 183 | 184 | res += f"__m256 l{j}_{u} = _mm256_cvtepi32_ps(wsa{j}_{u});\n" 185 | 186 | res += f"acc{i}_{j} = _mm256_fmadd_ps(v{i}_{u}, l{j}_{u}, acc{i}_{j});\n" 187 | 188 | wid = wid + 1 189 | u = u + 1 190 | 191 | for k in range(u,u + second_off): 192 | res += f"__m256 v{i}_{k} = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k3+{k})*nb + i1+{i}]);\n" 193 | 194 | for k in range(u,u + second_off): 195 | for j in range(0,tu,8): 196 | res += f"__m256i ws{j}_{k} = _mm256_srli_epi32(w{wid}_{j}, {bits*k-wid*32-shift});\n" 197 | 198 | for j in range(0,tu,8): 199 | res += f"__m256i wsa{j}_{k} = _mm256_and_si256(ws{j}_{k}, mask);\n" 200 | 201 | for j in range(0,tu,8): 202 | res += f"__m256 l{j}_{k} = _mm256_cvtepi32_ps(wsa{j}_{k});\n" 203 | 204 | for j in range(0,tu,8): 205 | res += f"acc{i}_{j} = _mm256_fmadd_ps(v{i}_{k}, l{j}_{k}, acc{i}_{j});\n" 206 | 207 | u = u + 2 208 | 209 | 210 | return res 211 | 212 | else: 213 | for j in range(0,tu,8): 214 | res += f"__m256i w{j} = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/{packed} + k*mb*tb/{packed} + k3*tb/{packed} + j1+{j}]);\n" 215 | 216 | for u in range(packed-unroll, -1, -unroll): 217 | for k in range(u+unroll-1,u-1,-1): 218 | res += f"__m256 v{i}_{k} = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k3+{k})*nb + i1+{i}]);\n" 219 | 220 | for k in range(u,u+unroll): 221 | for j in range(0,tu,8): 222 | res += f"__m256i ws{j}_{k} = _mm256_srli_epi32(w{j}, {bits*k});\n" 223 | 224 | for j in range(0,tu,8): 225 | res += f"__m256i wsa{j}_{k}= _mm256_and_si256(ws{j}_{k}, mask);\n" 226 | 227 | for j in range(0,tu,8): 228 | res += f"__m256 l{j}_{k} = _mm256_cvtepi32_ps(wsa{j}_{k});\n" 229 | 230 | for j in range(0,tu,8): 231 | res += f"acc{i}_{j} = _mm256_fmadd_ps(v{i}_{k}, l{j}_{k}, acc{i}_{j});\n" 232 | 233 | return res 234 | 235 | def block(nu_iter, mu, tu, rho, packed, unroll, bits): 236 | res = "" 237 | i = 0 238 | # unroll = 4 # number of bcasts and unpacks 239 | if bits == 3: 240 | for j in range(0,tu,8): 241 | res += f"__m256i w0_{j} = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/{packed}*3 + k*mb*tb/{packed}*3 + k2*tb/{packed}*3 + jw+{j*3}]);\n" 242 | res += f"__m256i w1_{j} = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/{packed}*3 + k*mb*tb/{packed}*3 + k2*tb/{packed}*3 + jw+{j*3}+8]);\n" 243 | res += f"__m256i w2_{j} = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/{packed}*3 + k*mb*tb/{packed}*3 + k2*tb/{packed}*3 + jw+{j*3}+16]);\n" 244 | 245 | u = 0 246 | first_off = 3 247 | second_off = 2 248 | wid = 0 249 | shift = 0 250 | while u < 32: 251 | 252 | if u == 10: 253 | 254 | res += f"__m256 v{i}_{u} = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+{u})*nb + i1+{i}]);\n" 255 | 256 | for j in range(0,tu,8): 257 | res += f"__m256i ws{j}_10 = _mm256_srli_epi32(w0_{j}, {bits*10});\n" 258 | res += f"__m256i temp0_{j} = _mm256_slli_epi32(w1_{j}, 2);\n" 259 | res += f"temp0_{j} = _mm256_and_si256(temp0_{j}, 
mask);\n" 260 | res += f"ws{j}_10 = _mm256_or_si256(ws{j}_10, temp0_{j});\n" 261 | 262 | res += f"__m256i wsa{j}_{u} = _mm256_and_si256(ws{j}_{u}, mask);\n" 263 | 264 | res += f"__m256 l{j}_{u} = _mm256_cvtepi32_ps(wsa{j}_{u});\n" 265 | 266 | res += f"acc{i}_{j} = _mm256_fmadd_ps(v{i}_{u}, l{j}_{u}, acc{i}_{j});\n" 267 | 268 | wid = wid + 1 269 | u = u + 1 270 | 271 | elif u == 21: 272 | res += f"__m256 v{i}_{u} = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+{u})*nb + i1+{i}]);\n" 273 | 274 | for j in range(0,tu,8): 275 | res += f"__m256i ws{j}_{u} = _mm256_srli_epi32(w1_{j}, 31);\n" 276 | res += f"__m256i temp1_{j} = _mm256_slli_epi32(w2_{j}, 1);\n" 277 | res += f"temp1_{j} = _mm256_and_si256(temp1_{j}, mask);\n" 278 | res += f"ws{j}_{u} = _mm256_or_si256(ws{j}_{u}, temp1_{j});\n" 279 | 280 | res += f"__m256i wsa{j}_{u} = _mm256_and_si256(ws{j}_{u}, mask);\n" 281 | 282 | res += f"__m256 l{j}_{u} = _mm256_cvtepi32_ps(wsa{j}_{u});\n" 283 | 284 | res += f"acc{i}_{j} = _mm256_fmadd_ps(v{i}_{u}, l{j}_{u}, acc{i}_{j});\n" 285 | 286 | wid = wid + 1 287 | u = u + 1 288 | 289 | for k in range(u,u + second_off): 290 | res += f"__m256 v{i}_{k} = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+{k})*nb + i1+{i}]);\n" 291 | 292 | for k in range(u,u + second_off): 293 | for j in range(0,tu,8): 294 | res += f"__m256i ws{j}_{k} = _mm256_srli_epi32(w{wid}_{j}, {bits*k-wid*32-shift});\n" 295 | 296 | for j in range(0,tu,8): 297 | res += f"__m256i wsa{j}_{k} = _mm256_and_si256(ws{j}_{k}, mask);\n" 298 | 299 | for j in range(0,tu,8): 300 | res += f"__m256 l{j}_{k} = _mm256_cvtepi32_ps(wsa{j}_{k});\n" 301 | 302 | for j in range(0,tu,8): 303 | res += f"acc{i}_{j} = _mm256_fmadd_ps(v{i}_{k}, l{j}_{k}, acc{i}_{j});\n" 304 | 305 | u = u + 2 306 | 307 | 308 | return res 309 | 310 | else: 311 | for j in range(0,tu,8): 312 | res += f"__m256i w{j} = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/{packed} + k*mb*tb/{packed} + k2*tb/{packed} + j1+{j}]);\n" 313 | 314 | for u in range(packed-unroll, -1, -unroll): 315 | for k in range(u+unroll-1,u-1,-1): 316 | res += f"__m256 v{i}_{k} = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+{k})*nb + i1+{i}]);\n" 317 | 318 | for k in range(u,u+unroll): 319 | for j in range(0,tu,8): 320 | res += f"__m256i ws{j}_{k} = _mm256_srli_epi32(w{j}, {bits*k});\n" 321 | 322 | for j in range(0,tu,8): 323 | res += f"__m256i wsa{j}_{k}= _mm256_and_si256(ws{j}_{k}, mask);\n" 324 | 325 | for j in range(0,tu,8): 326 | res += f"__m256 l{j}_{k} = _mm256_cvtepi32_ps(wsa{j}_{k});\n" 327 | 328 | for j in range(0,tu,8): 329 | res += f"acc{i}_{j} = _mm256_fmadd_ps(v{i}_{k}, l{j}_{k}, acc{i}_{j});\n" 330 | 331 | return res 332 | 333 | 334 | def accumulators_f(nu, tu, gs=False): 335 | accumulators = "" 336 | for i in range(nu): 337 | for j in range(0,tu,8): 338 | if gs: 339 | accumulators += f"__m256 acc{i}_{j} = _mm256_setzero_ps();\n" 340 | else: 341 | accumulators += f"__m256 acc{i}_{j} = _mm256_loadu_ps(&output[base_output + j + (i1+{i})*t + j1+{j}]);\n" 342 | return accumulators 343 | 344 | def stores_f(nu, tu, gs=False): 345 | store = "" 346 | if gs: 347 | for i in range(nu): 348 | for j in range(0,tu,8): 349 | store += f"__m256 o{i}_{j} = _mm256_loadu_ps(&output[base_output + j + (i1+{i})*t + j1+{j}]);\n" 350 | 351 | for i in range(nu): 352 | for j in range(0,tu,8): 353 | store += f"__m256 s{i}_{j} = _mm256_loadu_ps(&scales[(k*mb+k1)/gs * t + base_output + j + j1+{j}]);\n" 354 | 355 | for i in range(nu): 356 | for j in range(0,tu,8): 357 | store += f"__m256 f{i}_{j} = _mm256_fmadd_ps(acc{i}_{j}, s{i}_{j}, 
o{i}_{j});\n" 358 | 359 | for i in range(nu): 360 | for j in range(0,tu,8): 361 | store += f"_mm256_storeu_ps(&output[base_output + j + (i1+{i})*t + j1+{j}], f{i}_{j});\n" 362 | else: 363 | for i in range(nu): 364 | for j in range(0,tu,8): 365 | store += f"_mm256_storeu_ps(&output[base_output + j + (i1+{i})*t + j1+{j}], acc{i}_{j});\n" 366 | return store 367 | 368 | def qforward(nu, mu, tu, p, unroll, bits, n=0, m=0, t =0, nb=0, mb=0, tb=0, tt=0, cutoff=-1, gs=False, gs_val=-1, module=True): 369 | assert(module or (gs and gs_val != -1) or (not gs and gs_val == -1)) 370 | if cutoff == -1: 371 | cutoff = p+1 372 | # packed = 32 // bits 373 | if bits == 3: 374 | packed = 32 375 | loopguard = packed 376 | else: 377 | packed = 32 // bits 378 | loopguard = packed 379 | #compute the parameters from the model 380 | 381 | accumulators = accumulators_f(nu, tu, gs) 382 | store = stores_f(nu, tu, gs) 383 | 384 | ugemm = "" 385 | if gs: 386 | ugemm += "int j1 = 0;\n" 387 | if bits == 3: 388 | ugemm += "int jw = 0;\n" 389 | ugemm += f"for(; j1 < tb-tu+1; j1+=tu, jw+={tu*3})" 390 | ugemm += "{\n" 391 | else: 392 | ugemm += "for(; j1 < tb-tu+1; j1+=tu) {\n" 393 | ugemm += "for(int k1 = 0; k1 < mb; k1+=gs) {\n" 394 | ugemm += accumulators 395 | ugemm += f"for(int k2 = k1; k2 < k1+gs; k2+={loopguard})\n" 396 | ugemm += "{\n" 397 | ugemm += block(nu, mu, tu, 16, packed, unroll, bits) 398 | ugemm += "}\n" 399 | ugemm += store 400 | ugemm += "}\n" 401 | ugemm += "}\n" 402 | else: 403 | ugemm += "int j1 = 0;\n" 404 | if bits == 3: 405 | ugemm += "int jw = 0;\n" 406 | ugemm += f"for(; j1 < tb-tu+1; j1+=tu, jw+={tu*3})" 407 | ugemm += "{\n" 408 | else: 409 | ugemm += "for(; j1 < tb-tu+1; j1+=tu) {\n" 410 | ugemm += accumulators 411 | ugemm += "for(int k1 = 0; k1 < mb; k1+=mu) {\n" 412 | ugemm += f"for(int k2 = k1; k2 < k1+mu; k2+={loopguard})" 413 | ugemm += "{\n" 414 | ugemm += block(nu, mu, tu, 16, packed, unroll, bits) 415 | ugemm += "}\n" 416 | ugemm += "}\n" 417 | ugemm += store 418 | ugemm += "}\n" 419 | 420 | 421 | res = "" 422 | res += "inline\n" 423 | if gs: 424 | res += f"void q{bits}gemm_gs(const float* __restrict__ input, \n" 425 | else: 426 | res += f"void q{bits}gemm(const float* __restrict__ input, \n" 427 | res += "const int* __restrict__ W, \n" 428 | res += "const float* __restrict__ scales, \n" 429 | res += "const float* __restrict__ zeros, \n" 430 | res +="const float* __restrict__ bias, \n " 431 | res +="const float* __restrict__ sums, \n " 432 | res +="float* __restrict__ output,\n\ 433 | const int n,\n\ 434 | const int m,\n\ 435 | const int t,\n\ 436 | const int nb,\n\ 437 | const int mb,\n\ 438 | const int tb,\n\ 439 | int ogtt,\n" 440 | if gs: 441 | res += "const int gs,\n" 442 | res += "const int cutoff){\n" 443 | 444 | res += f"#pragma omp parallel num_threads({p})\n" 445 | res += "{\n" 446 | res += "int tid;\n" 447 | res += f"const int mu = {mu};\n" 448 | res += f"const int nu = {nu};\n" 449 | res += f"const int tu = {tu};\n" 450 | res += f"const int on = n / nb;\n" 451 | res += f"const int om = m / mb;\n" 452 | 453 | mask = (2**bits)-1 454 | res += f"const __m256i mask = _mm256_set1_epi32({mask});\n" 455 | if bits == 3: 456 | res += f"const __m256i mask4 = _mm256_set1_epi32(4);\n" 457 | res += f"const __m256i mask6 = _mm256_set1_epi32(6);\n" 458 | res += "tid = omp_get_thread_num();\n" 459 | 460 | res += "int tt = ogtt;\n" 461 | res += "if(tid >= cutoff){\n" 462 | res += f"tt -= tb;\n" 463 | res += "}\n" 464 | res += f"const int base_output = tid >= cutoff ?\n \ 465 | (tid-cutoff)*tt + 
(tt+tb)*cutoff: \n \ 466 | tid*tt;\n" #is this >= cutoff or > cutoff? 467 | if bits != 3: 468 | res += f"const int base_W = tid >= cutoff ?\n \ 469 | ((tid-cutoff)*tt + (tt+tb)*cutoff)*m/{packed}: \n \ 470 | tid*tt*m/{packed};\n" 471 | else: 472 | res += f"const int base_W = tid >= cutoff ?\n \ 473 | ((tid-cutoff)*tt + (tt+tb)*cutoff)*m/{packed}*3: \n \ 474 | tid*tt*m/{packed}*3;\n" 475 | 476 | res += "for(int j = 0; j < tt; j+=tb){\n" 477 | res += "for(int i = 0; i < on; i++) {\n" 478 | res += "for(int k = 0; k < om; k++) {\n" 479 | res += "for(int i1 = 0; i1 < nb; i1+=nu) {\n" 480 | res += ugemm 481 | res += "}\n" 482 | res += "}\n" 483 | res += "}\n" 484 | res += "}\n" 485 | res += "#pragma omp barrier\n" 486 | # res += "#pragma omp for\n" 487 | if gs: 488 | res += "const int ngs = m/gs;\n" 489 | res += "for (int i = 0; i < n; i++) {\n" 490 | res += f"for (int j = 0; j < tt; j+={tu})" 491 | res += "{\n" 492 | for i in range(0,tu,8): 493 | res += f"__m256 acc{i} = _mm256_setzero_ps();\n" 494 | res += "for (int i1 = 0; i1 < ngs; i1++){\n" 495 | res += "__m256 r = _mm256_set1_ps(sums[i*ngs + i1]);\n" 496 | for i in range(0,tu,8): 497 | res += f"__m256 z{i} = _mm256_loadu_ps(&zeros[base_output + i1* t + j + {i}]);\n" 498 | # if not module: 499 | if bits != 3 or not module: 500 | for i in range(0,tu,8): 501 | res += f"__m256 s{i} = _mm256_loadu_ps(&scales[base_output + i1 * t + j + {i}]);\n" 502 | for i in range(0,tu,8): 503 | res += f"__m256 zs{i} = _mm256_mul_ps(z{i}, s{i});\n" 504 | for i in range(0,tu,8): 505 | # if module: 506 | if bits == 3 and module: 507 | res += f"acc{i} = _mm256_fmadd_ps(z{i}, r, acc{i});\n" 508 | else: 509 | res += f"acc{i} = _mm256_fmadd_ps(zs{i}, r, acc{i});\n" 510 | res += "}\n" 511 | for i in range(0,tu,8): 512 | res += f"__m256 o{i} = _mm256_loadu_ps(&output[i*t + base_output + j + {i}]);\n" 513 | for i in range(0,tu,8): 514 | res += f"__m256 b{i} = _mm256_loadu_ps(&bias[base_output + j + {i}]);\n" 515 | for i in range(0,tu,8): 516 | if module: 517 | res += f"__m256 o1{i} = _mm256_sub_ps(o{i}, acc{i});\n" 518 | else: 519 | res += f"__m256 o1{i} = _mm256_add_ps(o{i}, acc{i});\n" 520 | for i in range(0,tu,8): 521 | res += f"__m256 o2{i} = _mm256_add_ps(o1{i}, b{i});\n" 522 | for i in range(0,tu,8): 523 | res += f"_mm256_storeu_ps(&output[i*t + base_output + j + {i}], o2{i});\n" 524 | res += "}\n" 525 | res += "}\n" 526 | res += "}\n" 527 | res += "}\n" 528 | else: 529 | res += "for (int i = 0; i < n; i++) {\n" 530 | res += "__m256 r = _mm256_set1_ps(sums[i]);\n" 531 | res += f"for (int j = 0; j < tt; j+={tu})" 532 | res += "{\n" 533 | for i in range(0,tu,8): 534 | res += f"__m256 o{i} = _mm256_loadu_ps(&output[i*t + base_output + j + {i}]);\n" 535 | for i in range(0,tu,8): 536 | res += f"__m256 z{i} = _mm256_loadu_ps(&zeros[base_output + j + {i}]);\n" 537 | for i in range(0,tu,8): 538 | res += f"__m256 b{i} = _mm256_loadu_ps(&bias[base_output + j + {i}]);\n" 539 | for i in range(0,tu,8): 540 | res += f"__m256 s{i} = _mm256_loadu_ps(&scales[base_output + j + {i}]);\n" 541 | if bits == 3 and module: 542 | for i in range(0,tu,8): 543 | res += f"__m256 os{i} = _mm256_mul_ps(o{i}, s{i});\n" 544 | for i in range(0,tu,8): 545 | if module: 546 | if bits == 3: 547 | res += f"__m256 zr{i} = _mm256_fnmadd_ps(z{i}, r, os{i});\n" 548 | else: 549 | res += f"__m256 zr{i} = _mm256_fnmadd_ps(z{i}, r, o{i});\n" 550 | else: 551 | res += f"__m256 zr{i} = _mm256_fmadd_ps(z{i}, r, o{i});\n" 552 | for i in range(0,tu,8): 553 | # j res += f"__m256 o2{i} = _mm256_mul_ps(zr{i}, 
s{i});\n" 554 | if bits == 3 and module: 555 | res += f"__m256 o2{i} = _mm256_add_ps(zr{i}, b{i});\n" 556 | else: 557 | res += f"__m256 o2{i} = _mm256_fmadd_ps(zr{i}, s{i}, b{i});\n" 558 | for i in range(0,tu,8): 559 | res += f"_mm256_storeu_ps(&output[i*t + base_output + j + {i}], o2{i});\n" 560 | res += "}\n" 561 | res += "}\n" 562 | res += "}\n" 563 | res += "}\n" 564 | 565 | # wrapper for qgemm if we call from cpp 566 | if module: 567 | if gs: 568 | res += f"inline void forward{bits}_gs_cpu(\n" 569 | else: 570 | res += f"inline void forward{bits}_cpu(\n" 571 | res += "torch::Tensor in, torch::Tensor weight, torch::Tensor out,\n" 572 | res += "torch::Tensor bias, torch::Tensor scales, torch::Tensor zeros, torch::Tensor sums,\n" 573 | if gs: 574 | res += "int N, int M, int T, int nb, int mb, int tb, int tt, int groupsize, int cutoff){\n" 575 | else: 576 | res += "int N, int M, int T, int nb, int mb, int tb, int tt, int cutoff){\n" 577 | res += "int* W = weight.data_ptr();\n" 578 | res += "float* input = in.data_ptr();\n" 579 | res += "float* b = bias.data_ptr();\n" 580 | res += "float* s = scales.data_ptr();\n" 581 | res += "float* z = zeros.data_ptr();\n" 582 | res += "float* r = sums.data_ptr();\n" 583 | res += "float* O = out.data_ptr();\n" 584 | res += "\n" 585 | if gs: 586 | res += f"q{bits}gemm_gs(input, W, s, z, b, r, O, N, M, T, nb, mb, tb, tt, groupsize, cutoff);\n" 587 | else: 588 | res += f"q{bits}gemm(input, W, s, z, b, r, O, N, M, T, nb, mb, tb, tt, cutoff);\n" 589 | res += "}\n" 590 | else: 591 | res += "inline void qforward(const float* __restrict__ input, \n \ 592 | const int* __restrict__ W, \n\ 593 | const float* __restrict__ scales, \n\ 594 | const float* __restrict__ zeros, \n\ 595 | const float* __restrict__ bias, \n\ 596 | const float* __restrict__ sums, \n\ 597 | float* __restrict__ output, \n\ 598 | int n, \n \ 599 | int m, \n \ 600 | int t) {\n" 601 | if gs: 602 | res += f"q{bits}gemm_gs(input, W, scales, zeros, bias, sums, output, n, m, t, {nb}, {mb}, {tb}, {tt}, {gs_val}, {cutoff});\n" 603 | else: 604 | res += f"q{bits}gemm(input, W, scales, zeros, bias, sums, output, n, m, t, {nb}, {mb}, {tb}, {tt}, {cutoff});\n" 605 | res += "}\n" 606 | return res 607 | 608 | 609 | def gen_model(n, m, t, bits, p, gs): 610 | 611 | # get parameters 612 | if bits == 3: 613 | packed = 32 614 | unroll = 3 615 | nu = 1 #args.n 616 | mu = 32 617 | tu = 32 618 | else: 619 | packed = 32 // bits 620 | unroll = 2 621 | nu = 1 #args.n 622 | mu = 16 623 | tu = 32 624 | 625 | #compute the parameters from the model 626 | 627 | nb = n # it's always small for transformers 628 | 629 | 630 | mb, tb = mem_model(n, m, t, mu, tu, bits, l1, p, gs) 631 | 632 | split = np.ones(p) 633 | split = split * tb 634 | while np.sum(split) < t: 635 | split = split + tb 636 | 637 | idx = p - 1 638 | while np.sum(split) > t: 639 | split[idx] = split[idx] - tb 640 | idx = idx - 1 641 | 642 | assert(np.sum(split) == t) 643 | 644 | split = split.astype(int) 645 | tt = int(split[0]) 646 | 647 | if split[0] == split[-1]: 648 | cutoff = int(p+1) 649 | else: 650 | cutoff = int(idx + 1) 651 | 652 | 653 | if gs == -1: 654 | code = qforward(nu, mu, tu, p, unroll, n=n, m=m, t=t, nb=nb, mb=mb, tb=tb, tt=tt, bits=bits, cutoff=cutoff, module=False) 655 | else: 656 | code = qforward(nu, mu, tu, p, unroll, n=n, m=m, t=t, nb=nb, mb=mb, tb=tb, tt=tt, bits=bits, cutoff=cutoff, gs=True, gs_val=gs, module=False) 657 | code += pack_in(n, m, nb, mb) 658 | # code += pack_qw(m, t, mb, tb, tb, bits=bits)#, cutoff=cutoff) 659 | code += 
pack_qw(m, t, mb, tb, tu,bits=bits) 660 | code += pack_out(n, t, nb, tb) 661 | code += print_parameters(bits, n, m, t, nb, mb, tb, mu, nu, tu, unroll, p) 662 | 663 | 664 | with open("forward.h", "w") as f: 665 | f.write(macros()) 666 | f.write(code) 667 | 668 | def gen_and_compile(n, m, t, nb, mb, tb, nu, mu, tu, p, unroll, bits=4, gs=-1, module=False): 669 | 670 | split = np.ones(p) 671 | split = split * tb 672 | while np.sum(split) < t: 673 | split = split + tb 674 | 675 | idx = p - 1 676 | while np.sum(split) > t: 677 | split[idx] = split[idx] - tb 678 | idx = idx - 1 679 | 680 | assert(np.sum(split) == t) 681 | 682 | split = split.astype(int) 683 | tt = int(split[0]) 684 | 685 | if split[0] == split[-1]: 686 | cutoff = int(p+1) 687 | else: 688 | cutoff = int(idx + 1) 689 | 690 | if gs == -1: 691 | code = qforward(nu, mu, tu, p, unroll, n=n, m=m, t=t, nb=nb, mb=mb, tb=tb, tt=tt, bits=bits, cutoff=cutoff, module=False) 692 | else: 693 | code = qforward(nu, mu, tu, p, unroll, n=n, m=m, t=t, nb=nb, mb=mb, tb=tb, tt=tt, bits=bits, cutoff=cutoff, gs=True, gs_val=gs, module=False) 694 | code += pack_in(n, m, nb, mb) 695 | code += pack_qw(m, t, mb, tb, tu,bits=bits) 696 | code += pack_out(n, t, nb, tb) 697 | if module: 698 | code += print_parameters_module(bits, mu, nu, tu, unroll, p, gs=gs) 699 | else: 700 | code += print_parameters(bits, n, m, t, nb, mb, tb, mu, nu, tu, unroll, p, gs=gs) 701 | 702 | # write the code to a file called forward.h 703 | with open("forward.h", "w") as f: 704 | f.write(macros()) 705 | f.write(code) 706 | 707 | 708 | # g++ mmm_test.cpp -O3 -ftree-vectorize -mfma -mavx -mavx2 -fno-signaling-nans -fno-trapping-math -fopenmp -o mmm_test 709 | start = time.time() 710 | if not module: 711 | subprocess.call(["g++", "-O3", "-o", "mmm_test", "mmm_test.cpp", "-mavx", "-mfma", "-mavx2", "-ftree-vectorize", "-fno-signaling-nans", "-fno-trapping-math", "-march=native", "-fopenmp"]) 712 | subprocess.call(["./mmm_test", f"{n}", f"{m}", f"{t}", f"{bits}", f"{gs}"]) 713 | else: 714 | subprocess.call(["g++", "-O3", "-o", "mmm", "mmm.cpp", "-mavx", "-mfma", "-mavx2", "-ftree-vectorize", "-fno-signaling-nans", "-fno-trapping-math", "-march=native", "-fopenmp"]) 715 | subprocess.call(["./mmm", f"{n}", f"{m}", f"{t}", f"{bits}", f"{gs}"]) 716 | # subprocess.call(["./mmm", f"{n}", f"{m}", f"{t}", f"{bits}", f"{gs}", ">>", "tmp.csv"]) 717 | end = time.time() - start 718 | return end 719 | 720 | def grid(): 721 | tt = 64 722 | for p in [32]: 723 | # for n in [1, 10]: 724 | for n in [1]: 725 | for m in [4096]: 726 | for t in [4096]: 727 | # for mb in range(1,m): 728 | # for mb in range(32,512,32): 729 | # for mb in [64, 128, 256, 512, 1024, 2048]: 730 | for mb in [512, 1024, 2048]: 731 | if m % mb == 0: 732 | # for tb in range(8,t,8): 733 | # for tb in range(32,512,32): 734 | # for tb in [16, 32, 64]:#, 128, 192, 256]: 735 | # for tb in [32]:#, 128, 192, 256]: 736 | for tb in [128, 256]: 737 | if t % tb == 0: 738 | # for mu in range(32,mb,32): 739 | for mu in [16, 32]: 740 | if mb % mu == 0: 741 | # for tu in range(8,tb,8): 742 | # for tu in [16, 32]: 743 | for tu in [16, 32, 64, 128]: 744 | if tb % tu == 0: 745 | for gs in [-1, 128, 64, 32, 16]: 746 | # for bits in [2, 3, 4]: 747 | for bits in [4, 3, 2]: 748 | if bits == 3: 749 | for u in [5]: 750 | gen_and_compile(n,m,t,n,mb,tb,1,mu,tu,p,u,bits=bits, gs=gs) 751 | else: 752 | for u in [1, 2, 4, 8]: 753 | gen_and_compile(n,m,t,n,mb,tb,1,mu,tu,p,u,bits=bits, gs=gs) 754 | 755 | 756 | def forward_module_gs(nu, mu, tu, p, unroll, bits): 757 | # 
packed = 32 // bits 758 | if bits == 3: 759 | packed = 32 760 | loopguard = packed 761 | else: 762 | packed = 32 // bits 763 | loopguard = packed 764 | #compute the parameters from the model 765 | 766 | accumulators = "" 767 | for i in range(nu): 768 | for j in range(0,tu,8): 769 | accumulators += f"__m256 acc{i}_{j} = _mm256_setzero_ps();\n" 770 | 771 | store = "" 772 | for i in range(nu): 773 | for j in range(0,tu,8): 774 | store += f"__m256 o{i}_{j} = _mm256_loadu_ps(&output[base_output + j + (i1+{i})*t + j1+{j}]);\n" 775 | 776 | for i in range(nu): 777 | for j in range(0,tu,8): 778 | store += f"__m256 s{i}_{j} = _mm256_loadu_ps(&scales[(k*mb+k1)/gs * t + base_output + j + j1+{j}]);\n" 779 | 780 | for i in range(nu): 781 | for j in range(0,tu,8): 782 | store += f"__m256 f{i}_{j} = _mm256_fmadd_ps(acc{i}_{j}, s{i}_{j}, o{i}_{j});\n" 783 | 784 | for i in range(nu): 785 | for j in range(0,tu,8): 786 | store += f"_mm256_storeu_ps(&output[base_output + j + (i1+{i})*t + j1+{j}], f{i}_{j});\n" 787 | 788 | ugemm = "" 789 | if bits == 3: 790 | ugemm += "int j1 = 0;\n" 791 | ugemm += "int jw = 0;\n" 792 | ugemm += f"for(; j1 < tb-tu+1; j1+=tu, jw+={tu*3})" 793 | ugemm += "{\n" 794 | else: 795 | ugemm += "int j1 = 0;\n" 796 | ugemm += "for(; j1 < tb-tu+1; j1+=tu) {\n" 797 | ugemm += "for(int k1 = 0; k1 < mb; k1+=gs) {\n" 798 | ugemm += accumulators 799 | ugemm += f"for(int k2 = k1; k2 < k1+gs; k2+={loopguard})\n" 800 | ugemm += "{\n" 801 | ugemm += block(nu, mu, tu, 16, packed, unroll, bits) 802 | ugemm += "}\n" 803 | ugemm += store 804 | ugemm += "}\n" 805 | ugemm += "}\n" 806 | 807 | 808 | res = "" 809 | res += "inline\n" 810 | res += f"void q{bits}gemm_gs(const float* __restrict__ input, \n" 811 | res += " const int* __restrict__ W, \n \ 812 | const float* __restrict__ scales, \n" 813 | res += "const float* __restrict__ zeros, \n" 814 | res +=" const float* __restrict__ bias, \n " 815 | res +=" const float* __restrict__ sums,\n" 816 | res +=" float* __restrict__ output,\n \ 817 | const int n,\n \ 818 | const int m,\n \ 819 | const int t,\n \ 820 | const int nb,\n \ 821 | const int mb,\n \ 822 | const int tb,\n \ 823 | int ogtt,\n \ 824 | const int gs,\n\ 825 | const int cutoff){\n" 826 | 827 | res += f"#pragma omp parallel num_threads({p})\n" 828 | res += "{\n" 829 | res += " int tid;\n" 830 | res += f" const int mu = {mu};\n" 831 | res += f" const int nu = {nu};\n" 832 | res += f" const int tu = {tu};\n" 833 | res += f" const int on = n / nb;\n" 834 | res += f" const int om = m / mb;\n" 835 | 836 | mask = (2**bits)-1 837 | res += f"const __m256i mask = _mm256_set1_epi32({mask});\n" 838 | if bits == 3: 839 | res += f"const __m256i mask4 = _mm256_set1_epi32(4);\n" 840 | res += f"const __m256i mask6 = _mm256_set1_epi32(6);\n" 841 | res += "tid = omp_get_thread_num();\n" 842 | 843 | res += "int tt = ogtt;\n" 844 | res += "if(tid >= cutoff){\n" 845 | res += f"tt -= tb;\n" 846 | res += "}\n" 847 | res += f"const int base_output = tid >= cutoff ?\n \ 848 | (tid-cutoff)*tt + (tt+tb)*cutoff: \n \ 849 | tid*tt;\n" #is this >= cutoff or > cutoff? 
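# (note on the question above) `>=` is the choice consistent with how
# cutoff is constructed elsewhere in this repo: split[] is shrunk from the
# highest thread id downwards and cutoff = idx + 1 is the first id whose
# block lost tb columns (cutoff = p + 1 when the split is even, so no
# thread takes the reduced branch). Threads tid < cutoff keep tt = ogtt
# columns starting at tid*tt; threads tid >= cutoff own ogtt - tb columns
# placed after the cutoff full-size blocks, which is exactly
# (tid-cutoff)*tt + (tt+tb)*cutoff with the reduced tt.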
850 | if bits != 3: 851 | res += f"const int base_W = tid >= cutoff ?\n \ 852 | ((tid-cutoff)*tt + (tt+tb)*cutoff)*m/{packed}: \n \ 853 | tid*tt*m/{packed};\n" 854 | else: 855 | res += f"const int base_W = tid >= cutoff ?\n \ 856 | ((tid-cutoff)*tt + (tt+tb)*cutoff)*m/{packed}*3: \n \ 857 | tid*tt*m/{packed}*3;\n" 858 | 859 | res += "for(int j = 0; j < tt; j+=tb){\n" 860 | res += "for(int i = 0; i < on; i++) {\n" 861 | res += "for(int k = 0; k < om; k++) {\n" 862 | res += "for(int i1 = 0; i1 < nb; i1+=nu) {\n" 863 | res += ugemm 864 | res += "}\n" 865 | res += "}\n" 866 | res += "}\n" 867 | res += "}\n" 868 | res += "const int ngs = m/gs;\n" 869 | res += "#pragma omp barrier\n" 870 | # res += "#pragma omp for collapse(2)\n" 871 | res += "for (int i = 0; i < n; i++) {\n" 872 | # res += f" for (int j = 0; j < t; j+={tu})" 873 | res += f"for (int j = 0; j < tt; j+={tu})" 874 | res += "{\n" 875 | # for i in range(0,tu,8): 876 | # res += f"__m256 o{i} = _mm256_loadu_ps(&output[i*t + j + {i}]);\n" 877 | for i in range(0,tu,8): 878 | res += f"__m256 acc{i} = _mm256_setzero_ps();\n" 879 | res += "for (int i1 = 0; i1 < ngs; i1++){\n" 880 | res += "__m256 r = _mm256_set1_ps(sums[i*ngs + i1]);\n" 881 | for i in range(0,tu,8): 882 | # res += f"__m256 z{i} = _mm256_loadu_ps(&zeros[i1 * t + j + {i}]);\n" 883 | res += f"__m256 z{i} = _mm256_loadu_ps(&zeros[base_output + i1* t + j + {i}]);\n" 884 | # for i in range(0,tu,8): 885 | # res += f"__m256 s{i} = _mm256_loadu_ps(&scales[i1 * t + j + {i}]);\n" 886 | # for i in range(0,tu,8): 887 | # res += f"__m256 zr{i} = _mm256_mul_ps(z{i}, r);\n" 888 | # for i in range(0,tu,8): 889 | # res += f"acc{i} = _mm256_fmadd_ps(zr{i}, s{i}, acc{i});\n" 890 | for i in range(0,tu,8): 891 | res += f"acc{i} = _mm256_fmadd_ps(z{i}, r, acc{i});\n" 892 | # for i in range(0,tu,8): 893 | # res += f"__m256 zr{i} = _mm256_mul_ps(z{i}, r);\n" 894 | # for i in range(0,tu,8): 895 | # res += f"o{i} = _mm256_fnmadd_ps(zr{i}, s{i}, o{i});\n" 896 | res += "}\n" 897 | for i in range(0,tu,8): 898 | # res += f"__m256 o{i} = _mm256_loadu_ps(&output[i*t + j + {i}]);\n" 899 | res += f"__m256 o{i} = _mm256_loadu_ps(&output[i*t + base_output + j + {i}]);\n" 900 | for i in range(0,tu,8): 901 | res += f"__m256 o1{i} = _mm256_sub_ps(o{i}, acc{i});\n" 902 | for i in range(0,tu,8): 903 | # res += f"_mm256_storeu_ps(&output[i*t + j + {i}], o1{i});\n" 904 | res += f"_mm256_storeu_ps(&output[i*t + base_output + j + {i}], o1{i});\n" 905 | res += " }\n" 906 | res += "}\n" 907 | res += "}\n" 908 | res += "}\n" 909 | 910 | 911 | # wrapper for qgemm if we call from cpp 912 | res += f"inline void forward{bits}_gs_cpu(\n" 913 | res += "torch::Tensor in, torch::Tensor weight, torch::Tensor out,\n" 914 | res += "torch::Tensor bias, torch::Tensor scales, torch::Tensor zeros, torch::Tensor sums,\n" 915 | res += "int N, int M, int T, int nb, int mb, int tb, int tt, int groupsize, int cutoff){\n" 916 | res += "int* W = weight.data_ptr();\n" 917 | res += "float* input = in.data_ptr();\n" 918 | res += "float* b = bias.data_ptr();\n" 919 | res += "float* s = scales.data_ptr();\n" 920 | # res += "int* z = zeros.data_ptr();\n" 921 | res += "float* z = zeros.data_ptr();\n" 922 | res += "float* r = sums.data_ptr();\n" 923 | res += "float* O = out.data_ptr();\n" 924 | res += "\n" 925 | res += f"q{bits}gemm_gs(input, W, s, z, b, r, O, N, M, T, nb, mb, tb, tt, groupsize, cutoff);\n" 926 | res += "}\n" 927 | return res 928 | 929 | def forward_module(nu, mu, tu, p, unroll, bits): 930 | # packed = 32 // bits 931 | if bits == 3: 
932 | packed = 32 933 | loopguard = packed 934 | else: 935 | packed = 32 // bits 936 | loopguard = packed 937 | #compute the parameters from the model 938 | 939 | accumulators = "" 940 | for i in range(nu): 941 | for j in range(0,tu,8): 942 | accumulators += f"__m256 acc{i}_{j} = _mm256_loadu_ps(&output[base_output + j + (i1+{i})*t + j1+{j}]);\n" 943 | 944 | 945 | store = "" 946 | for i in range(nu): 947 | for j in range(0,tu,8): 948 | store += f"_mm256_storeu_ps(&output[base_output + j + (i1+{i})*t + j1+{j}], acc{i}_{j});\n" 949 | 950 | ugemm = "" 951 | if bits == 3: 952 | ugemm += "int jw = 0;\n" 953 | ugemm += f"for(; j1 < tb-tu+1; j1+=tu, jw+={tu*3})" 954 | ugemm += "{\n" 955 | else: 956 | ugemm += "for(; j1 < tb-tu+1; j1+=tu) {\n" 957 | ugemm += accumulators 958 | ugemm += "for(int k1 = 0; k1 < mb; k1+=mu) {\n" 959 | ugemm += f"for(int k2 = k1; k2 < k1+mu; k2+={loopguard})" 960 | ugemm += "{\n" 961 | ugemm += block(nu, mu, tu, 16, packed, unroll, bits) 962 | ugemm += "}\n" 963 | ugemm += "}\n" 964 | ugemm += store 965 | ugemm += "}\n" 966 | 967 | res = "" 968 | res += "inline\n" 969 | res += f"void q{bits}gemm(const float* __restrict__ input, \n" 970 | res += "const int* __restrict__ W, \n" 971 | res += "const float* __restrict__ scales, \n" 972 | # res += "const int* __restrict__ zeros, \n" 973 | res += "const float* __restrict__ zeros, \n" 974 | res +="const float* __restrict__ bias, \n " 975 | res +="const float* __restrict__ sums," 976 | res +="float* __restrict__ output,\n \ 977 | const int n,\n \ 978 | const int m,\n \ 979 | const int t,\n \ 980 | const int nb,\n \ 981 | const int mb,\n \ 982 | const int tb,\n \ 983 | int ogtt,\n \ 984 | const int cutoff){\n" 985 | 986 | res += f"#pragma omp parallel num_threads({p})\n" 987 | res += "{\n" 988 | res += "int tid, nthreads;\n" 989 | res += f"const int mu = {mu};\n" 990 | res += f"const int nu = {nu};\n" 991 | res += f"const int tu = {tu};\n" 992 | res += f"const int on = n / nb;\n" 993 | res += f"const int om = m / mb;\n" 994 | 995 | mask = (2**bits)-1 996 | res += f"const __m256i mask = _mm256_set1_epi32({mask});\n" 997 | if bits == 3: 998 | res += f"const __m256i mask4 = _mm256_set1_epi32(4);\n" 999 | res += f"const __m256i mask6 = _mm256_set1_epi32(6);\n" 1000 | res += "tid = omp_get_thread_num();\n" 1001 | # res += " std::cout << \"thread \" << tid << \" started\" << std::endl;\n" 1002 | res += "nthreads = omp_get_num_threads();\n" 1003 | 1004 | res += "int tt = ogtt;\n" 1005 | res += "if(tid >= cutoff){\n" 1006 | res += f"tt -= tb;\n" 1007 | res += "}\n" 1008 | res += f"const int base_output = tid >= cutoff ?\n \ 1009 | (tid-cutoff)*tt + (tt+tb)*cutoff: \n \ 1010 | tid*tt;\n" #is this >= cutoff or > cutoff? 
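# (same cutoff question as above -- see the note in forward_module_gs;
# `>=` is the consistent reading.)
#
# (reference sketch, assuming the usual GPTQ convention W ~ s*(Wq - z))
# The epilogue generated further below folds dequantization out of the
# inner loop: with acc[j] = sum_k Wq[k,j]*x[k] and r = sum_k x[k],
#     y[j] = s[j] * (acc[j] - z[j]*r)
# which matches the emitted fnmadd (o - z*r) followed by the multiply
# by s. This is why the precomputed row sums are passed in as `sums`.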
1011 | if bits != 3: 1012 | res += f"const int base_W = tid >= cutoff ?\n \ 1013 | ((tid-cutoff)*tt + (tt+tb)*cutoff)*m/{packed}: \n \ 1014 | tid*tt*m/{packed};\n" 1015 | else: 1016 | res += f"const int base_W = tid >= cutoff ?\n \ 1017 | ((tid-cutoff)*tt + (tt+tb)*cutoff)*m/{packed}*3: \n \ 1018 | tid*tt*m/{packed}*3;\n" 1019 | 1020 | 1021 | res += "for(int j = 0; j < tt; j+=tb){\n" 1022 | res += "for(int i = 0; i < on; i++) {\n" 1023 | res += "for(int k = 0; k < om; k++) {\n" 1024 | res += "for(int i1 = 0; i1 < nb; i1+=nu) {\n" 1025 | res += "int j1 = 0;\n" 1026 | res += ugemm 1027 | res += "}\n" 1028 | res += "}\n" 1029 | res += "}\n" 1030 | res += "}\n" 1031 | # res += "#pragma omp barrier\n" 1032 | # res += "#pragma omp for\n" 1033 | res += "for (int i = 0; i < n; i++) {\n" 1034 | res += "__m256 r = _mm256_set1_ps(sums[i]);\n" 1035 | # res += f"for (int j = 0; j < t; j+={tu})" 1036 | res += f"for (int j = 0; j < tt; j+={tu})" 1037 | res += "{\n" 1038 | for i in range(0,tu,8): 1039 | # res += f"__m256 o{i} = _mm256_loadu_ps(&output[i*t + j + {i}]);\n" 1040 | res += f"__m256 o{i} = _mm256_loadu_ps(&output[i*t + base_output + j + {i}]);\n" 1041 | for i in range(0,tu,8): 1042 | res += f"__m256 z{i} = _mm256_loadu_ps(&zeros[base_output + j + {i}]);\n" 1043 | for i in range(0,tu,8): 1044 | res += f"__m256 s{i} = _mm256_loadu_ps(&scales[base_output + j + {i}]);\n" 1045 | for i in range(0,tu,8): 1046 | res += f"__m256 zr{i} = _mm256_fnmadd_ps(z{i}, r, o{i});\n" 1047 | for i in range(0,tu,8): 1048 | res += f"__m256 o2{i} = _mm256_mul_ps(zr{i}, s{i});\n" 1049 | for i in range(0,tu,8): 1050 | res += f"_mm256_storeu_ps(&output[i*t + base_output + j + {i}], o2{i});\n" 1051 | res += "}\n" 1052 | res += "}\n" 1053 | res += "}\n" 1054 | res += "}\n" 1055 | 1056 | 1057 | # wrapper for qgemm if we call from cpp 1058 | res += f"inline void forward{bits}_cpu(\n" 1059 | res += "torch::Tensor in, torch::Tensor weight, torch::Tensor out,\n" 1060 | res += "torch::Tensor bias, torch::Tensor scales, torch::Tensor zeros, torch::Tensor sums,\n" 1061 | res += "int N, int M, int T, int nb, int mb, int tb, int tt, int cutoff){\n" 1062 | res += "int* W = weight.data_ptr();\n" 1063 | res += "float* input = in.data_ptr();\n" 1064 | res += "float* b = bias.data_ptr();\n" 1065 | res += "float* s = scales.data_ptr();\n" 1066 | # res += "int* z = zeros.data_ptr();\n" 1067 | res += "float* z = zeros.data_ptr();\n" 1068 | res += "float* r = sums.data_ptr();\n" 1069 | res += "float* O = out.data_ptr();\n" 1070 | res += "\n" 1071 | res += f"q{bits}gemm(input, W, s, z, b, r, O, N, M, T, nb, mb, tb, tt, cutoff);\n" 1072 | res += "}\n" 1073 | return res 1074 | 1075 | def unpack_zeros(bits): 1076 | res = "" 1077 | res += f"void unpack_zeros{bits}_cpu(const int* zv, float* ov, int n, int m)" 1078 | packed = 32//bits 1079 | mask = (2**bits)-1 1080 | res += "{\nconst __m256i ones = _mm256_set1_epi32(1);\n" 1081 | res += f"const __m256i mask = _mm256_set1_epi32({mask});\n" 1082 | if bits == 4: 1083 | res += "const __m256i shift = _mm256_set_epi32(28,24,20,16,12,8,4,0);\n" 1084 | elif bits == 3: 1085 | pass 1086 | elif bits == 2: 1087 | res += "const __m256i shift0 = _mm256_set_epi32(30,28,26,24,22,20,18,16);\n" 1088 | res += "const __m256i shift1 = _mm256_set_epi32(14,12,10,8,6,4,2,0);\n" 1089 | else: 1090 | print("ERROR") 1091 | res += "for(int i = 0; i < n; i++){\n" 1092 | if bits == 4: 1093 | res += "for(int j = 0; j < m; j+=8){\n" 1094 | res += "__m256i z = _mm256_set1_epi32(zv[i*m/8 + j/8]);\n" 1095 | res += "__m256i z0 = 
1096 |         res += "__m256i z1 = _mm256_and_si256(z0, mask);\n"
1097 |         res += "__m256i z2 = _mm256_add_epi32(z1, ones);\n"
1098 |         res += "__m256 z3 = _mm256_cvtepi32_ps(z2);\n"
1099 |         res += "_mm256_storeu_ps(&ov[i*m +j], z3);\n"
1100 |     elif bits == 2:
1101 |         res += f"for (int j = 0; j < m; j+={packed})"
1102 |         res += "{\n"
1103 |         res += f"for (int k = 0; k < {packed}; k++)"
1104 |         res += "{\n"
1105 |         res += f"ov[i*m + j+k] = (((zv[i*m/{packed} + j/{packed}] >> ({bits}*k)) & {mask})+1);\n"
1106 |         res += "}\n"
1107 |         # res += "for(int j = 0; j < m; j+=16){\n"
1108 |         # res += "__m256i z = _mm256_set1_epi32(zv[i*m/16 + j/16]);\n"
1109 |         # res += "__m256i z00 = _mm256_srlv_epi32(z, shift0);\n"
1110 |         # res += "__m256i z01 = _mm256_srlv_epi32(z, shift1);\n"
1111 |         # res += "__m256i z10 = _mm256_and_si256(z00, mask);\n"
1112 |         # res += "__m256i z11 = _mm256_and_si256(z01, mask);\n"
1113 |         # res += "__m256i z20 = _mm256_add_epi32(z10, ones);\n"
1114 |         # res += "__m256i z21 = _mm256_add_epi32(z11, ones);\n"
1115 |         # res += "__m256 z30 = _mm256_cvtepi32_ps(z20);\n"
1116 |         # res += "__m256 z31 = _mm256_cvtepi32_ps(z21);\n"
1117 |         # res += "_mm256_storeu_ps(&ov[i*m +j], z30);\n"
1118 |         # res += "_mm256_storeu_ps(&ov[i*m +j+8], z31);\n"
1119 |     elif bits == 3:
1120 |         # pass
1121 |         res += "for(int j = 0; j < m; j+=32){\n"
1122 |         res += "std::cout<<\"not yet implemented\"<<std::endl;\n"
1126 |         # for i in range(10):
1127 |         #     res += f"unsigned int z0{i} = ((z0 >> {29 - i*3}) & 7) + 1;\n"
1128 |         # for i in range(10):
1129 |         #     res += f"ov[i*m + j + {i}] = z0{i} * sv[i*m + j + {i}];\n"
1130 |         # res += "unsigned int t0 = ((z0<<1 & 6) | (z1>>31)) + 1;\n"
1131 |         # res += "ov[i*m + j + 10] = t0 * sv[i*m + j + 10];\n"
1132 |         # for i in range(10):
1133 |         #     res += f"unsigned int z1{i} = ((z1 >> {28 - i*3}) & 7) + 1;\n"
1134 |         # for i in range(10):
1135 |         #     res += f"ov[i*m + j + {11 + i}] = z1{i} * sv[i*m + j + {11 + i}];\n"
1136 |         # res += "unsigned int t1 = ((z1<<2 & 6) | (z2>>30)) + 1;\n"
1137 |         # res += "ov[i*m + j + 21] = t1 * sv[i*m + j + 21];\n"
1138 |         # for i in range(10):
1139 |         #     res += f"unsigned int z2{i} = ((z2 >> {27 - i*3}) & 7) + 1;\n"
1140 |         # for i in range(10):
1141 |         #     res += f"ov[i*m + j + {22 + i}] = z2{i} * sv[i*m + j + {22 + i}];\n"
1142 | 
1143 |     res += "}\n"
1144 |     res += "}\n"
1145 |     res += "}\n"
1146 | 
1147 |     # write the pybind interface
1148 |     res += f"void unpack_zeros{bits}(torch::Tensor zeros, torch::Tensor out, int N, int M)"
1149 |     res += "{\nint* Z = zeros.data_ptr<int>();\n"
1150 |     res += "float* O = out.data_ptr<float>();\n"
1151 |     res += f"unpack_zeros{bits}_cpu(Z, O, N, M);\n"
1152 |     res += "}\n"
1153 | 
1154 |     return res
1155 | 
1156 | def gen_module(r, p, bits_list=[2,3,4]):
1157 |     code = ""
1158 |     for bits in bits_list:
1159 |         if bits == 3:
1160 |             unroll = 3
1161 |             nu = 1 #args.n
1162 |             mu = 32
1163 |             tu = 32
1164 |         else:
1165 |             unroll = 2
1166 |             nu = 1 #args.n
1167 |             mu = 16
1168 |             # mu = 32
1169 |             tu = 32
1170 | 
1171 |         code += qforward(nu, mu, tu, p, unroll, bits=bits, module=True, gs=False)
1172 |         code += qforward(nu, mu, tu, p, unroll, bits=bits, module=True, gs=True)
1173 |         code += pack_qw_module(bits)
1174 |         code += unpack_zeros(bits)
1175 | 
1176 |     with open("backend.cpp", "w") as f:
1177 |         f.write(template.includes())
1178 |         f.write(template.quant_scalar())
1179 |         f.write(compute_reduction(p))
1180 |         f.write(unquantize_sim(p))
1181 |         f.write(code)
1182 |         f.write(template.module(bits_list))
1183 | 
1184 | def compute_reduction(p):
1185 |     res = ""
1186 |     res += "void compute_reduction_cpu(const float* in, float* out, int n, int m, int gs){\n"
1187 |     res += f"#pragma omp parallel num_threads({p})\n"
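    # The generated compute_reduction_cpu computes, for each row i of the
    # input, the sum of every group of gs consecutive elements
    # (out[(i*m + j0)/gs] = sum of in[i][j0:j0+gs]); the forward kernels use
    # these sums to subtract the zero-point contribution from the integer
    # dot products.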
f"#pragma omp parallel num_threads({p})\n" 1188 | res += "{\n" 1189 | res += "#pragma omp for collapse(2)\n" 1190 | res += "for(int i = 0; i < n; i++){\n" 1191 | res += "for(int j0 = 0; j0 < m; j0+=gs){\n" 1192 | res += "__m256 acc = _mm256_setzero_ps();\n" 1193 | res += "for(int j1 = j0; j1 < j0+gs; j1+=8){\n" 1194 | res += "__m256 x = _mm256_loadu_ps(&in[i*m + j1]);\n" 1195 | res += "acc = _mm256_add_ps(acc, x);\n" 1196 | res += "}\n" 1197 | #compute simd add reduction 1198 | res += "const __m128 hiQuad = _mm256_extractf128_ps(acc, 1);\n" 1199 | res += "const __m128 loQuad = _mm256_castps256_ps128(acc);\n" 1200 | res += "const __m128 sumQuad = _mm_add_ps(loQuad, hiQuad);\n" 1201 | res += "const __m128 hiDual = _mm_movehl_ps(sumQuad, sumQuad);\n" 1202 | res += "const __m128 sumDual = _mm_add_ps(sumQuad, hiDual);\n" 1203 | res += "const __m128 hi = _mm_shuffle_ps(sumDual, sumDual, 0x1);\n" 1204 | res += "const __m128 sum = _mm_add_ss(hi, sumDual);\n" 1205 | res += "out[(i*m + j0)/gs] = _mm_cvtss_f32(sum);\n" 1206 | res += "}\n" 1207 | res += "}\n" 1208 | res += "}\n" 1209 | res += "}\n" 1210 | 1211 | # write the pybind interface 1212 | res += f"void compute_reduction(torch::Tensor in, torch::Tensor out, int N, int M, int gs)" 1213 | res += "{\nfloat* I = in.data_ptr();\n" 1214 | res += "float* O = out.data_ptr();\n" 1215 | res += f"compute_reduction_cpu(I, O, N, M, gs);\n" 1216 | res += "}\n" 1217 | 1218 | return res 1219 | 1220 | def unquantize_sim(p): 1221 | res = "" 1222 | res += "void unquantize_sim_cpu(const int* in, float* out, float* s, float* z, int n, int m, int bits, int gs){\n" 1223 | res += f"#pragma omp parallel num_threads({p})\n" 1224 | res += "{\n" 1225 | res += "int packed = 32/bits;\n" 1226 | res += "int mask = (1< 2 | #include "forward.h" 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #define mymin(a,b) ((a)<(b)?(a):(b)) 10 | #define mymax(a,b) ((a)>(b)?(a):(b)) 11 | 12 | void print_matrix(std::string name, float* A, int N, int M){ 13 | std::cout<> 2; 123 | int temp11 = ((int)((A[(i1+11)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 1; 124 | int temp12 = ((int)((A[(i1+12)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 4; 125 | int temp13 = ((int)((A[(i1+13)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 7; 126 | int temp14 = ((int)((A[(i1+14)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 10; 127 | int temp15 = ((int)((A[(i1+15)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 13; 128 | int temp16 = ((int)((A[(i1+16)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 16; 129 | int temp17 = ((int)((A[(i1+17)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 19; 130 | int temp18 = ((int)((A[(i1+18)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 22; 131 | int temp19 = ((int)((A[(i1+19)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 25; 132 | int temp20 = ((int)((A[(i1+20)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 28; 133 | int temp21_0 = ((int)((A[(i1+21)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 31; 134 | int temp21_1 = ((int)((A[(i1+21)*m+j] - zeros[row*m+j])/scales[row*m+j])) >> 1; 135 | int temp22 = ((int)((A[(i1+22)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 2; 136 | int temp23 = ((int)((A[(i1+23)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 5; 137 | int temp24 = ((int)((A[(i1+24)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 8; 138 | int temp25 = ((int)((A[(i1+25)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 11; 139 | int temp26 = ((int)((A[(i1+26)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 14; 140 | int temp27 = ((int)((A[(i1+27)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 
141 |                     int temp28 = ((int)((A[(i1+28)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 20;
142 |                     int temp29 = ((int)((A[(i1+29)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 23;
143 |                     int temp30 = ((int)((A[(i1+30)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 26;
144 |                     int temp31 = ((int)((A[(i1+31)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 29;
145 | 
146 |                     int acc0 = 0, acc1 = 0, acc2 = 0;
147 | 
148 |                     acc0 |= temp0;
149 |                     acc0 |= temp1;
150 |                     acc0 |= temp2;
151 |                     acc0 |= temp3;
152 |                     acc0 |= temp4;
153 |                     acc0 |= temp5;
154 |                     acc0 |= temp6;
155 |                     acc0 |= temp7;
156 |                     acc0 |= temp8;
157 |                     acc0 |= temp9;
158 |                     acc0 |= temp10_0;
159 | 
160 |                     acc1 |= temp10_1;
161 |                     acc1 |= temp11;
162 |                     acc1 |= temp12;
163 |                     acc1 |= temp13;
164 |                     acc1 |= temp14;
165 |                     acc1 |= temp15;
166 |                     acc1 |= temp16;
167 |                     acc1 |= temp17;
168 |                     acc1 |= temp18;
169 |                     acc1 |= temp19;
170 |                     acc1 |= temp20;
171 |                     acc1 |= temp21_0;
172 | 
173 |                     acc2 |= temp21_1;
174 |                     acc2 |= temp22;
175 |                     acc2 |= temp23;
176 |                     acc2 |= temp24;
177 |                     acc2 |= temp25;
178 |                     acc2 |= temp26;
179 |                     acc2 |= temp27;
180 |                     acc2 |= temp28;
181 |                     acc2 |= temp29;
182 |                     acc2 |= temp30;
183 |                     acc2 |= temp31;
184 | 
185 |                     BQ[(3*i1/32)*m+j] = acc0;
186 |                     BQ[(3*i1/32+1)*m+j] = acc1;
187 |                     BQ[(3*i1/32+2)*m+j] = acc2;
188 |                 }
189 | 
190 |             }else{
191 |                 for (int i1 = i0; i1 < i0+gs; i1+=packed){
192 |                     uint32_t acc = 0;
193 |                     for (int i2 = i1; i2 < i1+packed; i2++){
194 |                         int temp = (A[i2*m+j] - zeros[row*m+j])/scales[row*m+j];
195 |                         acc = acc | (temp << (bits*(i2-i1)));
196 |                     }
197 |                     BQ[(i1/packed)*m+j] = acc;
198 |                 }
199 |             }
200 |         }
201 |     }
202 | 
203 | }
204 | 
205 | int main(int argc, char *argv[]){
206 |     // read n m t from args
207 |     if(argc < 6){std::cout << "Parameters not given\n"; return 0;}
208 |     int n = atoi(argv[1]);
209 |     int m = atoi(argv[2]);
210 |     int t = atoi(argv[3]);
211 |     int bits = atoi(argv[4]);
212 |     int gs = atoi(argv[5]);
213 |     int ng;
214 |     if(gs == -1){
215 |         ng = 1;
216 |     }else{
217 |         ng = m/gs;
218 |     }
219 |     float* A = new float[n*m];
220 |     float* AB = new float[n*m];
221 |     float* B = new float[m*t];
222 |     float* BQS = new float[m*t];
223 |     float* scales = new float[t*ng];
224 |     float* zeros = new float[t*ng];
225 |     int* BQ = new int[m*t/8];
226 |     int* BQB = new int[m*t/8];
227 |     float* sums = new float[n*ng];
228 |     float* bias = new float[t];
229 |     float* C = new float[n*t];
230 |     float* CB = new float[n*t];
231 |     float* C2 = new float[n*t];
232 |     srand(1);
233 |     for (int i = 0; i < n*m; i++){
234 |         A[i] = (float)rand() / RAND_MAX;
235 |     }
236 |     for (int i = 0; i < t*m; i++){
237 |         B[i] = (float)rand() / RAND_MAX;
238 |     }
239 |     for (int i = 0; i < t; i++){
240 |         bias[i] = (float)rand() / RAND_MAX;
241 |     }
242 |     for (int i = 0; i < n*t; i++){
243 |         C[i] = 0.0;
244 |         C2[i] = 0.0;
245 |     }
246 |     quantize_sim(B,BQS,scales,zeros,m,t,bits,gs);
247 |     quantize(B,BQ,scales,zeros,m,t,bits,gs);
248 | 
249 |     quantize_sim(B,BQS,scales,zeros,m,t,bits,gs);
250 |     quantize(B,BQ,scales,zeros,m,t,bits,gs);
251 |     oracle_mmadd(A, BQS, bias, C, n, m, t);
252 |     pack_input(A,AB);
253 |     pack_qw(BQ,BQB);
254 |     pack_output(C,CB);
255 | 
256 |     compute_reduction(A,sums,n,m,gs);
257 |     qforward(AB,BQB,scales,zeros,bias,sums,C2,n,m,t);
258 | 
259 |     float norm = 0.0;
260 |     for (int i = 0; i < n*t; i++){
261 |         norm += (C[i] - C2[i]) * (C[i] - C2[i]);
262 |     }
263 |     if(norm / (n*t) < 0.0001){
264 |         int iter = 30;
265 |         for(int _ = 0; _ < iter; _++){
266 |             qforward(AB,BQB,scales,zeros,bias,sums,C2,n,m,t);
267 |         }
268 | 
269 |         int num_runs = 15;
270 |         std::vector<float> runs(num_runs);
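        // Benchmark protocol: the 30 calls above warm up caches and threads;
        // each timed block below measures iter back-to-back qforward calls,
        // and the sorted per-run times are reduced to a near-median value so
        // the reported figure is robust to scheduler noise.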
271 |         for(int r = 0; r < num_runs; r++){
272 |             auto start = std::chrono::high_resolution_clock::now();
273 |             for(int _ = 0; _ < iter; _++){
274 |                 qforward(AB,BQB,scales,zeros,bias,sums,C2,n,m,t);
275 |             }
276 |             auto end = std::chrono::high_resolution_clock::now();
277 |             runs[r] = std::chrono::duration_cast<std::chrono::nanoseconds>(end - start).count();
278 | 
279 |         }
280 | 
281 |         std::sort(runs.begin(), runs.end());
282 | 
283 |         float cycles_final = runs[num_runs/2 + 1] / iter;
284 | 
285 |         std::ofstream outfile;
286 |         outfile.open("tmp.csv", std::ios_base::app);
287 | 
288 |         print_parameters();
289 |         outfile << cycles_final << std::endl;
290 |     }else{
291 |         float cycles_final = 1e13;
292 | 
293 |         std::ofstream outfile;
294 |         outfile.open("tmp.csv", std::ios_base::app);
295 | 
296 |         print_parameters();
297 |         outfile << cycles_final << std::endl;
298 |     }
299 | 
300 |     return 0;
301 | }
302 | 
303 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | gekko==1.0.6
2 | numpy==1.24.3
3 | setuptools==44.0.0
4 | torch==2.0.0
5 | tqdm==4.65.0
6 | datasets==2.12.0
7 | sentencepiece==0.1.99
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, Extension
2 | from torch.utils import cpp_extension
3 | import os
4 | import subprocess
5 | import math
6 | 
7 | os.environ["CC"] = "g++"
8 | os.environ["CXX"] = "g++"
9 | 
10 | 
11 | p = int(subprocess.run("cat /proc/cpuinfo | grep cores | head -1", shell=True, check=True, text=True, stdout=subprocess.PIPE).stdout.split(" ")[2])
12 | 
13 | subprocess.call(["python", "generate.py", "--module", "--search", "--p", str(p)])
14 | 
15 | setup(
16 |     name='cQIGen',
17 |     ext_modules=[cpp_extension.CppExtension(
18 |         'cQIGen', ['backend.cpp'],
19 |         extra_compile_args = ["-O3", "-mavx", "-mavx2", "-mfma", "-march=native", "-ffast-math", "-ftree-vectorize", "-faligned-new", "-std=c++17", "-fopenmp", "-fno-signaling-nans", "-fno-trapping-math"]
20 |     )],
21 |     cmdclass={'build_ext': cpp_extension.BuildExtension}
22 | )
23 | 
--------------------------------------------------------------------------------
/template.py:
--------------------------------------------------------------------------------
1 | 
2 | def includes():
3 |     out = " \
4 | #include <torch/extension.h>\n \
5 | #include <omp.h>\n \
6 | #include <immintrin.h>\n \
7 | #include <cmath>\n \
8 | #include <iostream>\n \
9 | \n \
10 | #define mymin(a,b) ((a)<(b)?(a):(b))\n \
11 | #define mymax(a,b) ((a)>(b)?(a):(b))\n \
12 | "
13 |     return out
14 | 
15 | 
16 | def module(bits_list=[4, 2]):
17 |     out = 'PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n'
18 |     for bits in bits_list:
19 |         out += ' m.def("forward{}", &forward{}_cpu);\n'.format(bits, bits)
20 | 
21 |     for bits in bits_list:
22 |         out += ' m.def("unpack_zeros{}", &unpack_zeros{});\n'.format(bits, bits)
23 | 
24 |     for bits in bits_list:
25 |         out += ' m.def("forward_gs{}", &forward{}_gs_cpu);\n'.format(bits, bits)
26 | 
27 |     for bits in bits_list:
28 |         out += ' m.def("pack{}", &pack{}_w_cpu);\n'.format(bits, bits)
29 | 
30 |     out += 'm.def("compute_reduction_cpp", &compute_reduction);\n'
31 |     out += 'm.def("unquantize_sim", &unquantize_sim);\n'
32 | 
33 |     # if oracle:
34 |     #     out += ' m.def("forward4_oracle", &forward4_oracle_cpu);\n'
35 | 
36 | 
37 |     out += 'm.def("quant_scalar_scaled", &quant_scalar_cpu);\n'
38 | 
39 |     out += '}\n'
40 |     return out
41 | 
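# Illustrative use of the generated bindings (hypothetical call; the tensors
# are prepared as in QIGen.py's qLinear, and cQIGen is the extension built by
# setup.py):
#
#     import cQIGen as qinfer
#     qinfer.forward4(x_t, w_packed, out, bias, scales, zeros, sums,
#                     B, N, M, B, mb, tb, tt, cutoff)
#
# The argument order matches the forward{bits}_cpu wrappers emitted by
# generate.py.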
42 | def quant_scalar():
43 |     out = " \
44 | void quantize_scalar(float* A, int* BQ, float* scales, float* zeros, int n, int m, int bits){ \n \
45 | //find scales and zeros arrays \n \
46 | //quantize \n \
47 | int pack = 32/bits;\n \
48 | for (int j = 0; j < m; j++){\n \
49 | for (int i = 0; i < n; i+=pack){\n \
50 | uint32_t acc = 0;\n \
51 | for (int ii = i; ii < i+pack; ii++){\n \
52 | float ftemp = std::round((A[ii*m+j] + zeros[j])/scales[j]);\n \
53 | int temp = (int)ftemp;\n \
54 | acc = acc | (temp << (bits*(ii-i)));\n \
55 | }\n \
56 | BQ[(i/pack)*m+j] = acc;\n \
57 | //BQ[0] = acc;\n \
58 | }\n \
59 | }\n \
60 | }\n \
61 | \n \
62 | void quant_scalar_cpu(\n \
63 | torch::Tensor in, torch::Tensor out, \n \
64 | torch::Tensor scales, torch::Tensor zeros, int bits\n \
65 | ) {\n \
66 | \n \
67 | int N = in.size(0);\n \
68 | int M = in.size(1);\n \
69 | \n \
70 | float* input = in.data_ptr<float>(); \n \
71 | float* s = scales.data_ptr<float>();\n \
72 | float* z = zeros.data_ptr<float>();\n \
73 | int* O = out.data_ptr<int>();\n \
74 | \n \
75 | quantize_scalar(input, O, s, z, N, M, bits);\n \
76 | \n \
77 | }\n"
78 | 
79 |     return out
80 | 
81 | 
82 | 
83 | 
84 | 
85 | 
86 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | from gekko import GEKKO
4 | 
5 | def mem_model(N, M, T, mu, tu, bits, l1, p, gs, verbose=False):
6 |     m = GEKKO() # create GEKKO model
7 |     # if bits==3:
8 |     #     tu = tu*3
9 |     B = m.Const(value=bits)
10 |     TP = m.Const(value=T//p)
11 |     k = m.Var(1,integer=True,lb=1)
12 |     z = m.Var(1,integer=True,lb=1)
13 |     w = m.Var(1,integer=True,lb=1)
14 |     y = m.Var(1,integer=True,lb=1)
15 |     v = m.Var(1,integer=True,lb=1)
16 |     mb = m.Var(mu,integer=True,lb=1)
17 |     if gs != -1:
18 |         gg = m.Var(1,integer=True,lb=1)
19 |     tb = m.Var(tu,integer=True,lb=1,ub=int(T/p))
20 |     L = m.Var(integer=True,lb=0,ub=l1)
21 |     m.Equation(L == 32 * mb * N + B * mb * tb + 32 * tb * N)
22 |     m.Equation(mb * k == M)
23 |     if gs != -1:
24 |         m.Equation(gs * gg == mb)
25 |     # m.Equation(tb * z == T)
26 |     m.Equation(tb * z == TP)
27 |     m.Equation(mu * w == mb)
28 |     m.Equation(tu * y == tb)
29 |     # m.Equation(tb * v == tt)
30 |     m.Maximize(L)
31 |     m.options.SOLVER = 1
32 |     m.solver_options = ['minlp_maximum_iterations 1000', \
33 |                         # minlp iterations with integer solution
34 |                         'minlp_max_iter_with_int_sol 10', \
35 |                         # treat minlp as nlp
36 |                         'minlp_as_nlp 0', \
37 |                         # nlp sub-problem max iterations
38 |                         'nlp_maximum_iterations 100', \
39 |                         # 1 = depth first, 2 = breadth first
40 |                         'minlp_branch_method 2', \
41 |                         # maximum deviation from whole number
42 |                         'minlp_integer_tol 0.00', \
43 |                         # convergence tolerance
44 |                         'minlp_gap_tol 0.01']
45 |     try:
46 |         m.solve(disp=False)
47 |     except:
48 |         try:
49 |             m.solver_options = ['minlp_maximum_iterations 1000', \
50 |                                 # minlp iterations with integer solution
51 |                                 'minlp_max_iter_with_int_sol 10', \
52 |                                 # treat minlp as nlp
53 |                                 'minlp_as_nlp 0', \
54 |                                 # nlp sub-problem max iterations
55 |                                 'nlp_maximum_iterations 100', \
56 |                                 # 1 = depth first, 2 = breadth first
57 |                                 'minlp_branch_method 1', \
58 |                                 # maximum deviation from whole number
59 |                                 'minlp_integer_tol 0.00', \
60 |                                 # convergence tolerance
61 |                                 'minlp_gap_tol 0.01']
62 |             m.solve(disp=False)
63 |         except:
64 |             # mytb = T//p
65 |             mytb = tu
66 |             if gs != -1:
67 |                 mymb = gs
68 |                 while 32 * (mymb + gs) * N + bits * (mymb + gs) * mytb + 32 * mytb * N < l1:
69 |                     mymb += gs
70 |                 while M % mymb != 0:
71 |                     mymb -= gs
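                # Heuristic fallback when the MINLP solve fails: grow the mb
                # tile in steps of gs while the per-tile working set, in bits
                # (32*mb*N input panel + bits*mb*tb weight tile + 32*tb*N
                # output tile), still fits in the L1 budget l1, then shrink
                # mb until it divides M.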
72 |                 if verbose:
73 |                     print("Failed to solve, using heuristic. mb = ", mymb, "tb = ", mytb)
74 |                 return (int(mymb), int(mytb))
75 |             else:
76 |                 mymb = mu
77 |                 while 32 * (mymb + mu) * N + bits * (mymb + mu) * mytb + 32 * mytb * N < l1:
78 |                     mymb += mu
79 |                 while M % mymb != 0:
80 |                     mymb -= mu
81 |                 if verbose:
82 |                     print("Failed to solve, using heuristic. mb = ", mymb, "tb = ", mytb)
83 |                 return (int(mymb), int(mytb))
84 | 
85 |     if verbose:
86 |         print("mb = ", int(mb.value[0]), "tb = ", int(tb.value[0]))
87 |     return (int(mb.value[0]), int(tb.value[0]))
88 | 
--------------------------------------------------------------------------------
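To make the tile-size model concrete, the following is a minimal brute-force
sketch (not part of the repository) of the search that mem_model poses to
GEKKO: it enforces the same divisibility constraints (mu divides mb, mb
divides M, tu divides tb, tb divides T/p) and maximizes the L1 working set
32*mb*N + bits*mb*tb + 32*tb*N subject to staying under the budget l1,
ignoring the optional group-size constraint for brevity. The helper name
pick_tiles is hypothetical:

def pick_tiles(N, M, T, mu, tu, bits, l1, p):
    # Per-tile working set, in bits: input panel + packed weight tile + output tile.
    def footprint(mb, tb):
        return 32 * mb * N + bits * mb * tb + 32 * tb * N
    best = None
    for mb in range(mu, M + 1, mu):              # mu divides mb
        if M % mb != 0:                          # mb divides M
            continue
        for tb in range(tu, T // p + 1, tu):     # tu divides tb
            if (T // p) % tb != 0:               # tb divides T/p
                continue
            f = footprint(mb, tb)
            if f <= l1 and (best is None or f > best[0]):
                best = (f, mb, tb)               # maximize L1 use, like m.Maximize(L)
    return (best[1], best[2]) if best else (mu, tu)

For transformer shapes (M and T a few thousand, mu and tu at 16/32) this
exhaustive scan is cheap, so it can serve as a sanity check on the solver's
answer.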