├── .gitattributes
├── .clang-format
├── .gitignore
├── tools
│   ├── cudaprof.sh
│   ├── requirements.txt
│   ├── download.py
│   ├── chart.py
│   ├── cudaprof.cu
│   └── convert.py
├── src
│   ├── sampler.h
│   ├── tokenizer.h
│   ├── tensors.h
│   ├── sampler.c
│   ├── model.h
│   ├── tokenizer.c
│   ├── tensors.c
│   ├── helpers.cuh
│   ├── infer.m
│   ├── infer.c
│   ├── infer.metal
│   ├── run.c
│   └── infer.cu
├── .github
│   └── workflows
│       └── build.yml
├── LICENSE.md
├── Makefile
└── README.md

/.gitattributes:
--------------------------------------------------------------------------------
 1 | *.ipynb linguist-documentation
 2 | 
--------------------------------------------------------------------------------
/.clang-format:
--------------------------------------------------------------------------------
 1 | UseTab: ForIndentation
 2 | TabWidth: 4
 3 | IndentWidth: 4
 4 | ColumnLimit: 0
 5 | PointerAlignment: Left
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
 1 | # build intermediates
 2 | .vscode/
 3 | build/
 4 | 
 5 | # model files
 6 | *.calm
 7 | 
 8 | # profiling tools
 9 | *.sqlite
10 | *.nsys-rep
11 | *.ncu-rep
12 | perf.data*
13 | *.gputrace/
--------------------------------------------------------------------------------
/tools/cudaprof.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | set -e
 3 | make -q build/cudaprof || make build/cudaprof
 4 | if [ "$1" == "-s" ]; then
 5 | 	shift
 6 | 	export PROF_SYNC=1
 7 | fi
 8 | CUDA_INJECTION64_PATH=build/cudaprof "$@"
 9 | 
--------------------------------------------------------------------------------
/tools/requirements.txt:
--------------------------------------------------------------------------------
 1 | # convert.py
 2 | safetensors
 3 | torch
 4 | numpy
 5 | 
 6 | # convert.py, only if using models without HF tokenizer.json
 7 | sentencepiece
 8 | 
 9 | # download.py
10 | huggingface_hub
11 | 
--------------------------------------------------------------------------------
/src/sampler.h:
--------------------------------------------------------------------------------
 1 | #pragma once
 2 | 
 3 | struct Sampler {
 4 | 	int vocab_size;
 5 | 	unsigned long long rng_state;
 6 | 
 7 | 	float temperature;
 8 | 	float minp;
 9 | };
10 | 
11 | float sample_prob(int idx, float* logits, int size);
12 | 
13 | int sample(struct Sampler* sampler, float* logits);
14 | 
--------------------------------------------------------------------------------
/tools/download.py:
--------------------------------------------------------------------------------
 1 | # Download model folder from HuggingFace
 2 | # python download.py model_folder repo_id
 3 | 
 4 | import argparse
 5 | import huggingface_hub
 6 | 
 7 | argp = argparse.ArgumentParser()
 8 | argp.add_argument("output", type=str)
 9 | argp.add_argument("repo", type=str)
10 | argp.add_argument("--all", action="store_true")
11 | args = argp.parse_args()
12 | 
13 | # download model folder from HuggingFace, excluding .bin files (assume the model contains safetensors)
14 | ignore_patterns = ["*.bin", "*.pth", "*.pt", "*.gguf", "consolidated.safetensors"] if not args.all else []
15 | huggingface_hub.snapshot_download(repo_id=args.repo, local_dir=args.output, ignore_patterns=ignore_patterns)
16 | 
--------------------------------------------------------------------------------
/.github/workflows/build.yml:
--------------------------------------------------------------------------------
 1 | name: build
 2 | 
 3 | on: [push, pull_request]
 4 | 
 5 | jobs:
 6 | 
linux: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v4 10 | - name: cuda install 11 | run: | 12 | # sudo apt install -y nvidia-cuda-toolkit 13 | wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb 14 | sudo dpkg -i cuda-keyring_1.1-1_all.deb 15 | sudo apt-get update 16 | sudo apt-get install -y cuda-compiler-12-4 # cuda-libraries-dev-12-4 17 | - name: make 18 | run: | 19 | export NVCC=/usr/local/cuda/bin/nvcc 20 | make -j2 21 | 22 | macos: 23 | runs-on: macos-latest 24 | steps: 25 | - uses: actions/checkout@v4 26 | - name: make 27 | run: make -j2 28 | -------------------------------------------------------------------------------- /src/tokenizer.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | struct TokenIndex { 4 | char* str; 5 | int id; 6 | }; 7 | 8 | struct Tokenizer { 9 | char** vocab; 10 | float* vocab_scores; 11 | struct TokenIndex* sorted_vocab; 12 | 13 | int vocab_size; 14 | int bos_id; 15 | int eos_id; 16 | int eot_id; 17 | int byte_fallbacks; 18 | 19 | char byte_pieces[256][2]; 20 | }; 21 | 22 | enum TokenizerFlags { 23 | TF_ENCODE_BOS = 1 << 0, 24 | TF_ENCODE_EOS = 1 << 1, 25 | }; 26 | 27 | void tokenizer_init(struct Tokenizer* tokenizer, char* tokens, float* scores, int bos_id, int eos_id, int vocab_size, int total_length); 28 | void tokenizer_free(struct Tokenizer* tokenizer); 29 | 30 | int tokenizer_bound(int bytes); 31 | 32 | char* tokenizer_decode(struct Tokenizer* tokenizer, int prev_token, int token); 33 | int tokenizer_encode(struct Tokenizer* tokenizer, char* text, unsigned flags, int* tokens); 34 | 35 | int tokenizer_find(struct Tokenizer* tokenizer, char* token); 36 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023-2024 Arseny Kapoulkine 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /src/tensors.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | enum DType { 6 | dt_f32, 7 | dt_f16, 8 | dt_bf16, 9 | dt_f8e5m2, 10 | dt_f8e4m3, 11 | dt_i32, 12 | dt_i16, 13 | dt_i8, 14 | dt_u8, 15 | }; 16 | 17 | struct Tensor { 18 | char* name; 19 | enum DType dtype; 20 | int shape[4]; 21 | void* data; 22 | size_t size; 23 | }; 24 | 25 | struct Metadata { 26 | char* key; 27 | char* value; 28 | }; 29 | 30 | struct Tensors { 31 | void* data; 32 | size_t size; 33 | 34 | struct Metadata metadata[128]; 35 | int n_metadata; 36 | 37 | struct Tensor tensors[1024]; 38 | int n_tensors; 39 | }; 40 | 41 | int tensors_parse(struct Tensors* tensors, void* data, size_t size); 42 | 43 | int tensors_open(struct Tensors* tensors, const char* filename); 44 | void tensors_close(struct Tensors* tensors); 45 | 46 | struct Tensor* tensors_find(struct Tensors* tensors, const char* name, int layer); 47 | void* tensors_get(struct Tensors* tensors, const char* name, int layer, enum DType dtype, int shape[4]); 48 | 49 | const char* tensors_metadata_find(struct Tensors* tensors, const char* name); 50 | const char* tensors_metadata(struct Tensors* tensors, const char* name); 51 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | MAKEFLAGS+=-r -j 2 | 3 | UNAME=$(shell uname) 4 | 5 | NVCC?=nvcc 6 | 7 | BUILD=build 8 | 9 | SOURCES=$(wildcard src/*.c) 10 | 11 | ifeq ($(UNAME),Darwin) 12 | SOURCES+=$(wildcard src/*.m) 13 | SOURCES+=$(wildcard src/*.metal) 14 | endif 15 | 16 | ifneq ($(UNAME),Darwin) 17 | SOURCES+=$(wildcard src/*.cu) 18 | endif 19 | 20 | OBJECTS=$(SOURCES:%=$(BUILD)/%.o) 21 | BINARY=$(BUILD)/run 22 | 23 | CFLAGS=-g -Wall -Wpointer-arith -Werror -O3 -ffast-math 24 | LDFLAGS=-lm 25 | 26 | ifeq ($(UNAME),Darwin) 27 | ifneq (,$(wildcard /opt/homebrew/opt/libomp)) 28 | CFLAGS+=-Xclang -fopenmp -I/opt/homebrew/opt/libomp/include 29 | LDFLAGS+=-L/opt/homebrew/opt/libomp/lib -lomp 30 | endif 31 | LDFLAGS+=-framework Metal -framework Foundation 32 | METALFLAGS=-std=metal3.0 -O2 33 | else 34 | CFLAGS+=-fopenmp -mf16c -mavx2 -mfma 35 | LDFLAGS+=-fopenmp 36 | endif 37 | 38 | ifneq ($(UNAME),Darwin) 39 | LDFLAGS+=-lcudart 40 | endif 41 | 42 | ifneq (,$(wildcard /usr/local/cuda)) 43 | LDFLAGS+=-L/usr/local/cuda/lib64 44 | endif 45 | 46 | CUFLAGS+=-g -O2 -lineinfo 47 | CUFLAGS+=-allow-unsupported-compiler # for recent CUDA versions 48 | 49 | ifeq ($(CUARCH),) 50 | CUFLAGS+=-gencode arch=compute_80,code=sm_80 -gencode arch=compute_90,code=compute_90 --threads 2 51 | else 52 | CUFLAGS+=-arch=$(CUARCH) 53 | endif 54 | 55 | all: $(BINARY) 56 | 57 | format: 58 | clang-format -i src/* tools/*.cu 59 | 60 | $(BUILD)/fuzz-tensors: src/tensors.c 61 | clang $(CFLAGS) -DFUZZING -O1 -fsanitize=address,fuzzer -o $@ $^ 62 | 63 | $(BUILD)/cudaprof: tools/cudaprof.cu 64 | $(NVCC) $< $(CUFLAGS) -Xcompiler -fPIC -shared -lcupti -MMD -MP -o $@ 65 | 66 | $(BINARY): $(OBJECTS) 67 | $(CC) $^ $(LDFLAGS) -o $@ 68 | 69 | $(BUILD)/%.c.o: %.c 70 | @mkdir -p $(dir $@) 71 | $(CC) $< $(CFLAGS) -c -MMD -MP -o $@ 72 | 73 | $(BUILD)/%.m.o: %.m 74 | @mkdir -p $(dir $@) 75 | $(CC) $< $(CFLAGS) -c -MMD -MP -o $@ 76 | 77 | $(BUILD)/%.metal.o: %.metal 78 | @mkdir -p $(dir $@) 79 | xcrun metal $< $(METALFLAGS) -c -MMD -MP -o $@.ir 80 | xcrun metallib -o $@.metallib $@.ir 81 | 
xxd -i -n $(basename $(notdir $<))_metallib $@.metallib > $@.c 82 | $(CC) $@.c -c -o $@ 83 | 84 | $(BUILD)/%.cu.o: %.cu 85 | @mkdir -p $(dir $@) 86 | $(NVCC) $< $(CUFLAGS) -c -MMD -MP -o $@ 87 | 88 | -include $(OBJECTS:.o=.d) 89 | -include $(BUILD)/cudaprof.d 90 | 91 | clean: 92 | rm -rf $(BUILD) 93 | 94 | .PHONY: all clean format 95 | -------------------------------------------------------------------------------- /tools/chart.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pandas as pd 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import re 6 | import os.path 7 | import time 8 | 9 | argp = argparse.ArgumentParser() 10 | argp.add_argument("input", type=str) 11 | argp.add_argument("--output", type=str, default="chart.png") 12 | args = argp.parse_args() 13 | 14 | # Function to parse the markdown table and return a DataFrame 15 | def parse_markdown_table(file_path): 16 | with open(file_path, 'r') as file: 17 | lines = file.readlines() 18 | 19 | # Finding the start and end of the table 20 | start, end = None, None 21 | for i, line in enumerate(lines): 22 | if line.startswith('|'): 23 | start = i if start is None else start 24 | end = i 25 | elif start is not None: 26 | break 27 | 28 | # Extracting the table content 29 | table_lines = lines[start:end+1] 30 | 31 | headers = [header.strip() for header in table_lines[0].split("|")[1:-1]] 32 | data = [] 33 | for line in table_lines[1:]: 34 | if "----" in line: 35 | continue 36 | row = [value.strip() for value in line.split("|")[1:-1]] 37 | if len(row) == len(headers): 38 | data.append(row) 39 | 40 | return pd.DataFrame(data, columns=headers) 41 | 42 | # Function to extract numerical values from the table cells 43 | def extract_values(cell): 44 | tokens = re.search(r"(\d+)\s*tok/s", cell) 45 | gb_s = re.search(r"(\d+)\s*GB/s", cell) 46 | return int(tokens.group(1)) if tokens else None, int(gb_s.group(1)) if gb_s else None 47 | 48 | # Parsing the Markdown table 49 | df = parse_markdown_table(args.input) 50 | 51 | # Extracting and processing the data 52 | df['Tokens/s'], df['GB/s'] = zip(*df['Performance (first 32 tokens)'].apply(extract_values)) 53 | 54 | # Creating the scatter plot 55 | plt.figure(figsize=(10, 6)) 56 | colors = plt.cm.tab20(np.linspace(0, 1, len(df))) 57 | for index, row in df.iterrows(): 58 | plt.scatter(row['Tokens/s'], row['GB/s'], color=colors[index], label=row['Model (context)'], s=100, marker='o', linewidths=2, zorder=3) 59 | 60 | mtime = os.path.getmtime(args.input) 61 | 62 | plt.xscale('log') 63 | 64 | # Customizing the ticks on the Tokens/s axis 65 | # Generating approximately 10 evenly spaced ticks 66 | min_tok_s = min(df['Tokens/s']) 67 | max_tok_s = max(df['Tokens/s']) 68 | ticks = np.logspace(np.log10(min_tok_s), np.log10(max_tok_s), num=10) 69 | plt.xticks(ticks, [f"{int(tick)}" for tick in ticks]) 70 | 71 | plt.title('calm performance (RTX 4090), ' + time.strftime("%b %Y", time.localtime(mtime))) 72 | plt.xlabel('Performance (Tokens/s)') 73 | plt.ylabel('Performance (GB/s)') 74 | plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left') 75 | plt.grid(True, zorder=1) 76 | plt.savefig(args.output, format='png', bbox_inches='tight') 77 | -------------------------------------------------------------------------------- /src/sampler.c: -------------------------------------------------------------------------------- 1 | #include "sampler.h" 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | static unsigned int random_u32(unsigned long long* 
state) { 8 | // xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A 9 | *state ^= *state >> 12; 10 | *state ^= *state << 25; 11 | *state ^= *state >> 27; 12 | return (*state * 0x2545F4914F6CDD1Dull) >> 32; 13 | } 14 | 15 | static float random_f32(unsigned long long* state) { // random float32 in [0,1) 16 | return (random_u32(state) >> 8) / 16777216.0f; 17 | } 18 | 19 | float sample_prob(int idx, float* logits, int size) { 20 | // find max value (for numerical stability) 21 | float max_val = -FLT_MAX; 22 | for (int i = 0; i < size; i++) { 23 | max_val = logits[i] > max_val ? logits[i] : max_val; 24 | } 25 | // exp and sum 26 | float sum = 0.0f; 27 | for (int i = 0; i < size; i++) { 28 | sum += expf(logits[i] - max_val); 29 | } 30 | // return probability of the given index 31 | return expf(logits[idx] - max_val) / sum; 32 | } 33 | 34 | static int sample_argmax(float* logits, int n) { 35 | int max_i = -1; 36 | float max_p = -FLT_MAX; 37 | for (int i = 0; i < n; i++) { 38 | max_i = logits[i] > max_p ? i : max_i; 39 | max_p = logits[i] > max_p ? logits[i] : max_p; 40 | } 41 | return max_i; 42 | } 43 | 44 | static int sample_minp(float* logits, int n, float minp, float temperature, float coin) { 45 | // find max logit; we will use this to derive minp cutoff (in log space), since minp is scale-invariant (wrt softmax) 46 | float max_logit = -FLT_MAX; 47 | for (int i = 0; i < n; i++) { 48 | max_logit = logits[i] > max_logit ? logits[i] : max_logit; 49 | } 50 | 51 | // exp(logit / temp) <= exp(max_logit / temp) * minp -> logit <= max_logit + log(minp) * temp 52 | float logit_cutoff = max_logit + logf(minp) * temperature; 53 | 54 | // convert from logits to probabilities in-place while simultaneously doing (unscaled) softmax; we'll rescale later 55 | float* probs = logits; 56 | int fallback = 0; 57 | float cumulative_prob = 0.0f; 58 | for (int i = 0; i < n; i++) { 59 | if (logits[i] >= logit_cutoff) { 60 | probs[i] = expf((logits[i] - max_logit) / temperature); 61 | cumulative_prob += probs[i]; 62 | fallback = i; // for fallback due to rounding errors 63 | } else { 64 | probs[i] = 0.0f; 65 | } 66 | } 67 | 68 | // sample from the truncated list 69 | float r = coin * cumulative_prob; 70 | float cdf = 0.0f; 71 | for (int i = 0; i < n; i++) { 72 | cdf += probs[i]; 73 | if (r < cdf) { 74 | return i; 75 | } 76 | } 77 | return fallback; // in case of rounding errors 78 | } 79 | 80 | int sample(struct Sampler* sampler, float* logits) { 81 | if (sampler->temperature == 0.0f || sampler->minp >= 1.0f) { 82 | // greedy argmax sampling: take the token with the highest probability 83 | return sample_argmax(logits, sampler->vocab_size); 84 | } else { 85 | float coin = random_f32(&sampler->rng_state); 86 | // min-p (cutoff) sampling, clamping the least likely tokens to zero 87 | return sample_minp(logits, sampler->vocab_size, sampler->minp, sampler->temperature, coin); 88 | } 89 | } 90 | -------------------------------------------------------------------------------- /src/model.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #define MAX_LAYERS 128 7 | #define MAX_EXPERTS 64 8 | 9 | // How many attention sinks to use for rolling buffer 10 | #define KV_SINKS 2 11 | 12 | struct Config { 13 | int dim; // transformer dimension 14 | int hidden_dim; // for ffn layers 15 | int head_dim; // for attention heads; usually dim / n_heads 16 | int n_layers; // number of layers 17 | int n_heads; // number of query heads 18 | int 
n_kv_heads; // number of key/value heads (can be < query heads because of multiquery) 19 | int vocab_size; // vocabulary size, usually 256 (byte-level) 20 | int seq_len; // max sequence length 21 | float rope_theta; // RoPE theta 22 | int rotary_dim; // RoPE rotary dimension (elements after that don't get rotated) 23 | int n_experts; // number of experts for MoE models 24 | int n_experts_ac; // number of active experts for MoE models 25 | float norm_eps; // epsilon for layer normalization 26 | bool act_gelu; // use GELU activation function 27 | bool norm_ln; // use full LN normalization 28 | bool norm_par; // use parallel MLP/attention by omitting intermediate normalization 29 | float qkv_clip; // clip qkv values to [-clip, clip] 30 | }; 31 | 32 | struct Weights { 33 | int dbits; // 4 for gf4, 8 for fp8, 16 for fp16; determines type of void* below 34 | 35 | // token embedding table 36 | void* token_embedding_table; // (vocab_size, dim) 37 | // weights for norms 38 | float* rms_att_weight[MAX_LAYERS]; // (dim) rmsnorm weights 39 | float* rms_ffn_weight[MAX_LAYERS]; // (dim) 40 | // weights for matmuls 41 | void* wq[MAX_LAYERS]; // (n_heads * head_dim, dim) 42 | void* wk[MAX_LAYERS]; // (n_kv_heads * head_dim, dim) 43 | void* wv[MAX_LAYERS]; // (n_kv_heads * head_dim, dim) 44 | void* wo[MAX_LAYERS]; // (dim, n_heads * head_dim) 45 | // weights for ffn 46 | void* w1[MAX_LAYERS]; // (n_experts?, hidden_dim, dim) 47 | void* w2[MAX_LAYERS]; // (n_experts?, dim, hidden_dim) 48 | void* w3[MAX_LAYERS]; // (n_experts?, hidden_dim, dim) 49 | // final norm 50 | float* rms_final_weight; // (dim,) 51 | // classifier weights for the logits, on the last layer 52 | void* wcls; 53 | // biases for qkv (qwen) 54 | float* bqkv[MAX_LAYERS]; // ((n_heads + n_kv_heads * 2) * head_dim) 55 | // moe gate weights (mixtral) 56 | void* moegate[MAX_LAYERS]; // (n_experts, dim) 57 | }; 58 | 59 | struct RunState { 60 | // current wave of activations 61 | float* x; // activation at current time stamp (dim,) 62 | float* xb; // same, but inside a residual branch (dim,) 63 | float* xb2; // an additional buffer just for convenience (dim,) 64 | float* hb; // buffer for hidden dimension in the ffn (hidden_dim,) 65 | float* hb2; // buffer for hidden dimension in the ffn (hidden_dim,) 66 | float* he; // buffer for hidden dimension in the ffn (n_experts_ac,hidden_dim,) 67 | float* q; // query (dim,) 68 | float* k; // key (dim,) 69 | float* v; // value (dim,) 70 | float* att; // buffer for scores/attention values (n_heads, seq_len) 71 | float* exp; // buffer for MoE computations (n_experts + n_experts_ac * 2) 72 | float* logits; // output logits 73 | // kv cache 74 | int kvbits; // 8 for fp8, 16 for fp16; determines type of void* below 75 | void* key_cache; // (layer, seq_len, dim) 76 | void* value_cache; // (layer, seq_len, dim) 77 | }; 78 | 79 | struct Transformer { 80 | struct Config config; // the hyperparameters of the architecture (the blueprint) 81 | struct Weights weights; // the weights of the model 82 | struct RunState state; // buffers for the "wave" of activations in the forward pass 83 | size_t n_params, n_bytes, n_bandwidth; 84 | float* (*forward)(struct Transformer* transformer, int token, int pos, unsigned flags); 85 | }; 86 | 87 | enum ForwardFlags { 88 | FF_UPDATE_KV_ONLY = 1 << 0, // only update kv cache and don't output logits 89 | }; 90 | -------------------------------------------------------------------------------- /tools/cudaprof.cu: 
-------------------------------------------------------------------------------- 1 | // Based on NVIDIA's cupti_trace_injection sample 2 | #include 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | #define CUPTI_CHECK(call) \ 9 | do { \ 10 | CUptiResult _status = call; \ 11 | if (_status != CUPTI_SUCCESS) { \ 12 | const char* err = "?"; \ 13 | cuptiGetResultString(_status, &err); \ 14 | fprintf(stderr, "CUPTI error in %s at %s:%d: %s (%d)\n", \ 15 | __FUNCTION__, __FILE__, __LINE__, err, _status); \ 16 | abort(); \ 17 | } \ 18 | } while (0) 19 | 20 | #define BUFFER_SIZE (8 * 1024 * 1024) 21 | #define MAX_TOKENS (1024 * 1024) 22 | #define MAX_KERNELS 1024 23 | 24 | struct KernelInfo { 25 | const char* name; 26 | 27 | float total_time; 28 | int calls; 29 | float call_avg; 30 | float call_m2; 31 | float peak_bw; 32 | float peak_util; 33 | int limit_occ; 34 | }; 35 | 36 | static CUpti_ActivityDevice3 device; 37 | 38 | static uint64_t tokens[MAX_TOKENS]; 39 | 40 | static KernelInfo kernels[MAX_KERNELS]; 41 | static int n_kernels; 42 | 43 | static KernelInfo* get_kernel(const char* name) { 44 | for (int i = 0; i < n_kernels; i++) { 45 | if (strcmp(kernels[i].name, name) == 0) { 46 | return &kernels[i]; 47 | } 48 | } 49 | 50 | assert(n_kernels < MAX_KERNELS); 51 | KernelInfo* kernel = &kernels[n_kernels++]; 52 | kernel->name = name; 53 | 54 | return kernel; 55 | } 56 | 57 | static void CUPTIAPI buffer_requested(uint8_t** buffer, size_t* size, size_t* maxNumRecords) { 58 | *size = BUFFER_SIZE; 59 | *buffer = (uint8_t*)malloc(BUFFER_SIZE); 60 | *maxNumRecords = 0; 61 | } 62 | 63 | static void CUPTIAPI buffer_completed(CUcontext ctx, uint32_t streamId, uint8_t* buffer, size_t size, size_t validSize) { 64 | CUpti_Activity* record = NULL; 65 | 66 | for (;;) { 67 | CUptiResult status = cuptiActivityGetNextRecord(buffer, validSize, &record); 68 | if (status == CUPTI_ERROR_MAX_LIMIT_REACHED) { 69 | break; 70 | } 71 | CUPTI_CHECK(status); 72 | 73 | switch (record->kind) { 74 | case CUPTI_ACTIVITY_KIND_DEVICE: { 75 | device = *(CUpti_ActivityDevice3*)record; 76 | break; 77 | } 78 | 79 | case CUPTI_ACTIVITY_KIND_KERNEL: 80 | case CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL: { 81 | CUpti_ActivityKernel5* activity = (CUpti_ActivityKernel5*)record; 82 | KernelInfo* info = get_kernel(activity->name); 83 | 84 | float time = (float)(activity->end - activity->start) / 1e6; 85 | uint64_t token = tokens[activity->correlationId % MAX_TOKENS]; 86 | 87 | info->total_time += time; 88 | 89 | // Welford's algorithm 90 | float delta = time - info->call_avg; 91 | info->calls++; 92 | info->call_avg += delta / info->calls; 93 | info->call_m2 += delta * (time - info->call_avg); 94 | 95 | // update peak bandwidth for kernel calls that specify profiling token as the first argument 96 | if ((token >> 48) == 0xCDAF) { 97 | uint64_t bytes = token & ((1ull << 48) - 1); 98 | float bw = ((double)bytes / 1e9) / (time / 1e3); 99 | info->peak_bw = fmaxf(info->peak_bw, bw); 100 | } 101 | 102 | int blocks = activity->gridX * activity->gridY * activity->gridZ; 103 | int blocks_rounded = (blocks + device.numMultiprocessors - 1) / device.numMultiprocessors * device.numMultiprocessors; 104 | info->peak_util = fmaxf(info->peak_util, (float)blocks / blocks_rounded); 105 | 106 | int block_size = activity->blockX * activity->blockY * activity->blockZ; 107 | int block_size_warps = (block_size + device.numThreadsPerWarp - 1) / device.numThreadsPerWarp; 108 | 109 | int occ_limit_blocks = device.maxBlocksPerMultiprocessor; 110 | int occ_limit_warps = 
device.maxWarpsPerMultiprocessor / block_size_warps; 111 | int occ_limit_smem = (activity->sharedMemoryExecuted) / (activity->staticSharedMemory + activity->dynamicSharedMemory + 1024); 112 | int occ_limit_regs = device.maxRegistersPerMultiprocessor / (((activity->registersPerThread * device.numThreadsPerWarp + 255) & ~255) * block_size_warps); 113 | int occ_limit = min(occ_limit_blocks, min(occ_limit_warps, min(occ_limit_smem, occ_limit_regs))); 114 | info->limit_occ = max(info->limit_occ, occ_limit * block_size_warps); 115 | break; 116 | } 117 | default: 118 | break; 119 | } 120 | } 121 | 122 | free(buffer); 123 | 124 | size_t dropped = 0; 125 | CUPTI_CHECK(cuptiActivityGetNumDroppedRecords(ctx, streamId, &dropped)); 126 | 127 | if (dropped != 0) { 128 | printf("WARNING: dropped %u CUPTI activity records.\n", (unsigned int)dropped); 129 | } 130 | } 131 | 132 | static void CUPTIAPI callback_handler(void* userdata, CUpti_CallbackDomain domain, CUpti_CallbackId cbid, const void* cbdata) { 133 | const CUpti_CallbackData* cbinfo = (CUpti_CallbackData*)cbdata; 134 | 135 | switch (domain) { 136 | case CUPTI_CB_DOMAIN_RUNTIME_API: { 137 | switch (cbid) { 138 | case CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernel_v7000: { 139 | if (cbinfo->callbackSite == CUPTI_API_ENTER) { 140 | cudaLaunchKernel_v7000_params* params = (cudaLaunchKernel_v7000_params*)cbinfo->functionParams; 141 | tokens[cbinfo->correlationId % MAX_TOKENS] = *(uint64_t*)*params->args; 142 | } 143 | break; 144 | } 145 | case CUPTI_RUNTIME_TRACE_CBID_cudaLaunchCooperativeKernel_v9000: { 146 | if (cbinfo->callbackSite == CUPTI_API_ENTER) { 147 | cudaLaunchCooperativeKernel_v9000_params* params = (cudaLaunchCooperativeKernel_v9000_params*)cbinfo->functionParams; 148 | tokens[cbinfo->correlationId % MAX_TOKENS] = *(uint64_t*)*params->args; 149 | } 150 | break; 151 | } 152 | default: 153 | break; 154 | } 155 | break; 156 | } 157 | default: 158 | break; 159 | } 160 | } 161 | 162 | static void atexit_handler(void) { 163 | CUPTI_CHECK(cuptiActivityFlushAll(CUPTI_ACTIVITY_FLAG_FLUSH_FORCED)); 164 | 165 | if (n_kernels) { 166 | printf("\n"); 167 | printf("%20s%10s%23s%12s%15s%25s\n", "Kernel", "Time", "Avg Time (us)", "Calls", "BW (GB/s)", "Utilization (limit)"); 168 | printf("%20s%10s%23s%12s%15s%25s\n", "---", "---", "---", "---", "---", "---"); 169 | 170 | float total_time = 0; 171 | for (int i = 0; i < n_kernels; i++) { 172 | total_time += kernels[i].total_time; 173 | } 174 | 175 | for (int i = 0; i < n_kernels; i++) { 176 | KernelInfo* kernel = &kernels[i]; 177 | 178 | const char* name = kernel->name; 179 | size_t length = strlen(name); 180 | 181 | if (strncmp(name, "_Z", 2) == 0 && length >= 2) { 182 | name += 2; 183 | char* end; 184 | length = strtoul(name, &end, 10); 185 | name = end; 186 | length = length > strlen(name) ? 
strlen(name) : length; 187 | } 188 | 189 | if (strncmp(name, "kernel_", 7) == 0 && length >= 7) { 190 | name += 7; 191 | length -= 7; 192 | } 193 | 194 | const char* namecont = ""; 195 | 196 | if (length > 20) { 197 | length = 19; 198 | namecont = "…"; 199 | } 200 | 201 | char avgtime[64]; 202 | snprintf(avgtime, sizeof(avgtime), "%.2f ± %.2f", 203 | kernel->call_avg * 1e3, 204 | sqrtf(kernel->call_m2 / kernel->calls) * 1e3); 205 | 206 | char util[64]; 207 | snprintf(util, sizeof(util), "%.0f%% SMs, %d wrp/SM", kernel->peak_util * 100, kernel->limit_occ); 208 | 209 | printf("%20.*s%s%9.1f%%%24s%12d%15.1f%25s\n", (int)length, name, namecont, 210 | kernel->total_time / total_time * 100, avgtime, kernel->calls, kernel->peak_bw, util); 211 | } 212 | } 213 | } 214 | 215 | extern "C" int InitializeInjection(void) { 216 | atexit(&atexit_handler); 217 | 218 | CUpti_SubscriberHandle subscriber; 219 | CUPTI_CHECK(cuptiSubscribe(&subscriber, callback_handler, NULL)); 220 | CUPTI_CHECK(cuptiEnableCallback(1, subscriber, CUPTI_CB_DOMAIN_RUNTIME_API, CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernel_v7000)); 221 | CUPTI_CHECK(cuptiEnableCallback(1, subscriber, CUPTI_CB_DOMAIN_RUNTIME_API, CUPTI_RUNTIME_TRACE_CBID_cudaLaunchCooperativeKernel_v9000)); 222 | 223 | const char* sync = getenv("PROF_SYNC"); 224 | 225 | // note: KIND_KERNEL serializes kernel launches; KIND_CONCURRENT_KERNEL does not but it results in less stable timings 226 | if (sync && atoi(sync)) { 227 | CUPTI_CHECK(cuptiActivityEnable(CUPTI_ACTIVITY_KIND_KERNEL)); 228 | } else { 229 | CUPTI_CHECK(cuptiActivityEnable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL)); 230 | } 231 | 232 | CUPTI_CHECK(cuptiActivityEnable(CUPTI_ACTIVITY_KIND_DEVICE)); 233 | 234 | CUPTI_CHECK(cuptiActivityRegisterCallbacks(buffer_requested, buffer_completed)); 235 | return 1; 236 | } 237 | -------------------------------------------------------------------------------- /src/tokenizer.c: -------------------------------------------------------------------------------- 1 | #include "tokenizer.h" 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #define MAX_TOKEN_LENGTH 512 9 | 10 | static int compare_tokens(const void* a, const void* b) { 11 | return strcmp(((struct TokenIndex*)a)->str, ((struct TokenIndex*)b)->str); 12 | } 13 | 14 | static int str_lookup(char* str, struct TokenIndex* sorted_vocab, int vocab_size) { 15 | // efficiently find the perfect match for str in vocab, return its index or -1 if not found 16 | struct TokenIndex tok = {str, -1}; // acts as the key to search for 17 | struct TokenIndex* res = bsearch(&tok, sorted_vocab, vocab_size, sizeof(struct TokenIndex), compare_tokens); 18 | return res != NULL ? 
res->id : -1; 19 | } 20 | 21 | void tokenizer_init(struct Tokenizer* tokenizer, char* tokens, float* scores, int bos_id, int eos_id, int vocab_size, int total_length) { 22 | tokenizer->vocab_size = vocab_size; 23 | tokenizer->bos_id = bos_id; 24 | tokenizer->eos_id = eos_id; 25 | tokenizer->eot_id = -1; 26 | 27 | tokenizer->vocab = (char**)malloc(vocab_size * sizeof(char*)); 28 | tokenizer->sorted_vocab = (struct TokenIndex*)malloc(vocab_size * sizeof(struct TokenIndex)); 29 | tokenizer->vocab_scores = scores; 30 | 31 | assert(tokens[total_length - 1] == '\0'); 32 | int token_offset = 0; 33 | 34 | for (int i = 0; i < vocab_size; ++i) { 35 | tokenizer->vocab[i] = tokens + token_offset; 36 | tokenizer->sorted_vocab[i].str = tokens + token_offset; 37 | tokenizer->sorted_vocab[i].id = i; 38 | 39 | int token_length = strlen(tokens + token_offset); 40 | assert(token_length <= MAX_TOKEN_LENGTH && token_offset + token_length + 1 <= total_length); 41 | token_offset += token_length + 1; 42 | } 43 | 44 | assert(token_offset == total_length); 45 | 46 | qsort(tokenizer->sorted_vocab, vocab_size, sizeof(struct TokenIndex), compare_tokens); 47 | 48 | tokenizer->byte_fallbacks = str_lookup("<0x00>", tokenizer->sorted_vocab, vocab_size); 49 | 50 | if (tokenizer->byte_fallbacks >= 0) { 51 | for (int i = 0; i < 256; i++) { 52 | tokenizer->byte_pieces[i][0] = (char)i; 53 | tokenizer->byte_pieces[i][1] = '\0'; 54 | } 55 | } 56 | 57 | if (tokenizer->eot_id < 0) { 58 | tokenizer->eot_id = str_lookup("<|eot_id|>", tokenizer->sorted_vocab, vocab_size); 59 | } 60 | if (tokenizer->eot_id < 0) { 61 | tokenizer->eot_id = str_lookup("<|end|>", tokenizer->sorted_vocab, vocab_size); 62 | } 63 | if (tokenizer->eot_id < 0) { 64 | tokenizer->eot_id = str_lookup("<|im_end|>", tokenizer->sorted_vocab, vocab_size); 65 | } 66 | } 67 | 68 | void tokenizer_free(struct Tokenizer* tokenizer) { 69 | free(tokenizer->vocab); 70 | free(tokenizer->sorted_vocab); 71 | } 72 | 73 | int tokenizer_bound(int bytes) { 74 | return bytes + 3; // +3 for prefix space, ?BOS, ?EOS 75 | } 76 | 77 | char* tokenizer_decode(struct Tokenizer* tokenizer, int prev_token, int token) { 78 | char* piece = tokenizer->vocab[token]; 79 | // following BOS token, sentencepiece decoder strips any leading whitespace (see PR #89) 80 | if (prev_token == tokenizer->bos_id && piece[0] == ' ') { 81 | piece++; 82 | } 83 | // return byte piece for byte fallback tokens (<0x00>, <0x01>, etc.) 
84 | if (tokenizer->byte_fallbacks >= 0 && (unsigned)(token - tokenizer->byte_fallbacks) < 256) { 85 | piece = tokenizer->byte_pieces[token - tokenizer->byte_fallbacks]; 86 | } 87 | return piece; 88 | } 89 | 90 | struct Merge { 91 | int lpos, lid; 92 | int rpos, rid; 93 | int resid; 94 | float score; 95 | }; 96 | 97 | static void heap_swap(struct Merge* heap, int i, int j) { 98 | struct Merge tmp = heap[i]; 99 | heap[i] = heap[j]; 100 | heap[j] = tmp; 101 | } 102 | 103 | static void heap_insert(struct Merge* heap, int n_heap, struct Merge merge) { 104 | // insert a new element at the end (breaks heap invariant) 105 | heap[n_heap] = merge; 106 | n_heap++; 107 | 108 | // bubble up the new element to its correct position 109 | int i = n_heap - 1; 110 | while (i > 0 && heap[i].score > heap[(i - 1) / 2].score) { 111 | heap_swap(heap, i, (i - 1) / 2); 112 | i = (i - 1) / 2; 113 | } 114 | } 115 | 116 | static void heap_poptop(struct Merge* heap, int n_heap) { 117 | // move the last element to the top (breaks heap invariant) 118 | n_heap--; 119 | heap[0] = heap[n_heap]; 120 | 121 | // bubble down the new top element to its correct position 122 | int i = 0; 123 | while (i * 2 + 1 < n_heap) { 124 | // find the largest child 125 | int j = i * 2 + 1; 126 | if (j + 1 < n_heap && heap[j + 1].score > heap[j].score) { 127 | j++; 128 | } 129 | // if the largest child is smaller than the parent, we're done 130 | if (heap[j].score <= heap[i].score) { 131 | break; 132 | } 133 | // otherwise, swap the parent and child 134 | heap_swap(heap, i, j); 135 | i = j; 136 | } 137 | } 138 | 139 | static int merge_tokens_tryadd(struct Tokenizer* tokenizer, struct Merge* heap, int n_heap, int lpos, int lid, int rpos, int rid) { 140 | char str_buffer[MAX_TOKEN_LENGTH * 2 + 1]; 141 | strcpy(str_buffer, tokenizer->vocab[lid]); 142 | strcat(str_buffer, tokenizer->vocab[rid]); 143 | int id = str_lookup(str_buffer, tokenizer->sorted_vocab, tokenizer->vocab_size); 144 | if (id != -1) { 145 | struct Merge merge = {lpos, lid, rpos, rid, id, tokenizer->vocab_scores[id]}; 146 | heap_insert(heap, n_heap++, merge); 147 | } 148 | return n_heap; 149 | } 150 | 151 | static int merge_tokens(struct Tokenizer* tokenizer, int* tokens, int n_tokens) { 152 | // create heap for all token merge pairs 153 | struct Merge* heap = malloc(2 * n_tokens * sizeof(struct Merge)); 154 | int n_heap = 0; 155 | 156 | // insert all initial pairs 157 | for (int i = 0; i < n_tokens - 1; i++) { 158 | n_heap = merge_tokens_tryadd(tokenizer, heap, n_heap, i, tokens[i], i + 1, tokens[i + 1]); 159 | } 160 | 161 | // merge all pairs 162 | while (n_heap > 0) { 163 | struct Merge merge = heap[0]; 164 | heap_poptop(heap, n_heap--); 165 | 166 | if (tokens[merge.lpos] != merge.lid || tokens[merge.rpos] != merge.rid) { 167 | continue; // this pair was already merged, skip it 168 | } 169 | 170 | // merge 171 | tokens[merge.lpos] = merge.resid; 172 | tokens[merge.rpos] = -1; 173 | 174 | // we might have new pairs to merge 175 | for (int i = merge.lpos - 1; i >= 0; i--) { 176 | if (tokens[i] != -1) { 177 | n_heap = merge_tokens_tryadd(tokenizer, heap, n_heap, i, tokens[i], merge.lpos, merge.resid); 178 | break; 179 | } 180 | } 181 | 182 | for (int i = merge.rpos + 1; i < n_tokens; i++) { 183 | if (tokens[i] != -1) { 184 | n_heap = merge_tokens_tryadd(tokenizer, heap, n_heap, merge.lpos, merge.resid, i, tokens[i]); 185 | break; 186 | } 187 | } 188 | } 189 | 190 | free(heap); 191 | 192 | // compact tokens 193 | int nm_tokens = 0; 194 | for (int i = 0; i < n_tokens; i++) { 195 | 
if (tokens[i] != -1) { 196 | tokens[nm_tokens++] = tokens[i]; 197 | } 198 | } 199 | 200 | return nm_tokens; 201 | } 202 | 203 | int tokenizer_encode(struct Tokenizer* tokenizer, char* text, unsigned flags, int* tokens) { 204 | int n_tokens = 0; 205 | 206 | // add optional BOS token, if desired 207 | if ((flags & TF_ENCODE_BOS) && tokenizer->bos_id >= 0) { 208 | tokens[n_tokens++] = tokenizer->bos_id; 209 | } 210 | 211 | // process the raw (UTF-8) byte sequence of the input string 212 | for (char* c = text; *c != '\0';) { 213 | char codepoint[5] = {}; 214 | 215 | codepoint[0] = *c++; 216 | 217 | if (codepoint[0] == '<' && *c == '|') { 218 | // special token, skip until '|>' 219 | char* e = c + 1; 220 | while (*e && !(e[0] == '|' && e[1] == '>')) { 221 | e++; 222 | } 223 | if (e[0] == '|' && e[1] == '>' && e - c + 3 <= MAX_TOKEN_LENGTH) { 224 | // we found the end of the special token, try to encode it as is 225 | char special[MAX_TOKEN_LENGTH + 1]; 226 | memcpy(special, c - 1, e - c + 3); 227 | special[e - c + 3] = '\0'; 228 | 229 | int sid = str_lookup(special, tokenizer->sorted_vocab, tokenizer->vocab_size); 230 | if (sid != -1) { 231 | // we found special codepoint in vocab, add it as a token 232 | tokens[n_tokens++] = sid; 233 | c = e + 2; 234 | continue; 235 | } 236 | } 237 | } 238 | 239 | // this byte is a leading byte (11...), so it's a multi-byte UTF8 codepoint 240 | if ((codepoint[0] & 0xC0) == 0xC0) { 241 | for (int i = 1; i < 4 && (*c & 0xC0) == 0x80; ++i) { 242 | codepoint[i] = *c++; 243 | } 244 | } 245 | 246 | int id = str_lookup(codepoint, tokenizer->sorted_vocab, tokenizer->vocab_size); 247 | 248 | if (id != -1) { 249 | // we found this codepoint in vocab, add it as a token 250 | tokens[n_tokens++] = id; 251 | } else if (tokenizer->byte_fallbacks >= 0) { 252 | // byte_fallback encoding: just encode each byte as a token 253 | for (char* fb = codepoint; *fb != '\0'; ++fb) { 254 | tokens[n_tokens++] = (unsigned char)*fb + tokenizer->byte_fallbacks; 255 | } 256 | } 257 | } 258 | 259 | // optimized heap-based merge 260 | n_tokens = merge_tokens(tokenizer, tokens, n_tokens); 261 | 262 | // add optional EOS token, if desired 263 | if (flags & TF_ENCODE_EOS) { 264 | tokens[n_tokens++] = tokenizer->eos_id; 265 | } 266 | 267 | assert(n_tokens <= tokenizer_bound(strlen(text))); 268 | return n_tokens; 269 | } 270 | 271 | int tokenizer_find(struct Tokenizer* tokenizer, char* token) { 272 | return str_lookup(token, tokenizer->sorted_vocab, tokenizer->vocab_size); 273 | } 274 | -------------------------------------------------------------------------------- /src/tensors.c: -------------------------------------------------------------------------------- 1 | #include "tensors.h" 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | #include 14 | #include 15 | 16 | static char* json_skipws(char* json) { 17 | while (*json == ' ' || *json == '\t' || *json == '\n' || *json == '\r') { 18 | json++; 19 | } 20 | return json; 21 | } 22 | 23 | static char* json_string(char* json, char** res) { 24 | if (*json != '"') { 25 | return NULL; 26 | } 27 | json++; 28 | 29 | *res = json; 30 | while (*json != '"') { 31 | if (*json == 0 || *json == '\\') { 32 | return NULL; 33 | } 34 | json++; 35 | } 36 | 37 | *json = 0; 38 | return json_skipws(json + 1); 39 | } 40 | 41 | static char* json_array(char* json, long long* res, int size) { 42 | if (*json != '[') { 43 | return NULL; 44 | } 45 | json = json_skipws(json + 1); 46 | 47 | for (int 
i = 0; i < size; ++i) { 48 | char* end; 49 | res[i] = strtoll(json, &end, 10); 50 | if (end == json) { 51 | return NULL; 52 | } 53 | json = json_skipws(end); 54 | if (*json == ']') { 55 | return json_skipws(json + 1); 56 | } 57 | if (*json != ',') { 58 | return NULL; 59 | } 60 | json = json_skipws(json + 1); 61 | } 62 | 63 | if (*json != ']') { 64 | return NULL; 65 | } 66 | return json_skipws(json + 1); 67 | } 68 | 69 | static int json_dtype(const char* str, enum DType* dtype, int* dsize) { 70 | static const struct { 71 | const char* str; 72 | enum DType dtype; 73 | int dsize; 74 | } dtypes[] = { 75 | {"F32", dt_f32, 4}, 76 | {"F16", dt_f16, 2}, 77 | {"BF16", dt_bf16, 2}, 78 | {"F8_E5M2", dt_f8e5m2, 1}, 79 | {"F8_E4M3", dt_f8e4m3, 1}, 80 | {"I32", dt_i32, 4}, 81 | {"I16", dt_i16, 2}, 82 | {"I8", dt_i8, 1}, 83 | {"U8", dt_u8, 1}, 84 | }; 85 | 86 | for (size_t i = 0; i < sizeof(dtypes) / sizeof(dtypes[0]); ++i) { 87 | if (strcmp(str, dtypes[i].str) == 0) { 88 | *dtype = dtypes[i].dtype; 89 | *dsize = dtypes[i].dsize; 90 | return 0; 91 | } 92 | } 93 | 94 | return -1; 95 | } 96 | 97 | static bool validate_shape(int dsize, int shape[4], size_t length) { 98 | size_t expected_length = 1; 99 | int max_elements = INT_MAX; 100 | 101 | for (int i = 0; i < 4; ++i) { 102 | int dim = shape[i] == 0 ? 1 : shape[i]; 103 | if (dim < 0 || dim > max_elements) { 104 | return false; 105 | } 106 | 107 | expected_length *= dim; 108 | max_elements /= dim; 109 | } 110 | 111 | return expected_length * dsize == length; 112 | } 113 | 114 | static char* parse_tensor(struct Tensor* tensor, void* bytes, size_t bytes_size, char* name, char* json) { 115 | tensor->name = name; 116 | 117 | if (*json != '{') { 118 | return NULL; 119 | } 120 | json = json_skipws(json + 1); 121 | 122 | int dsize = 0; 123 | 124 | while (*json != '}') { 125 | char* key; 126 | json = json_string(json, &key); 127 | if (!json || *json != ':') { 128 | return NULL; 129 | } 130 | json = json_skipws(json + 1); 131 | 132 | if (strcmp(key, "dtype") == 0) { 133 | char* val; 134 | json = json_string(json, &val); 135 | if (!json) { 136 | return NULL; 137 | } 138 | if (json_dtype(val, &tensor->dtype, &dsize) != 0) { 139 | return NULL; 140 | } 141 | } else if (strcmp(key, "shape") == 0) { 142 | long long shape[4] = {}; 143 | json = json_array(json, shape, 4); 144 | if (!json) { 145 | return NULL; 146 | } 147 | 148 | for (int j = 0; j < 4; ++j) { 149 | if (shape[j] < 0 || shape[j] > INT_MAX) { 150 | return NULL; 151 | } 152 | tensor->shape[j] = (int)shape[j]; 153 | } 154 | } else if (strcmp(key, "data_offsets") == 0) { 155 | long long offsets[2] = {}; 156 | json = json_array(json, offsets, 2); 157 | if (!json) { 158 | return NULL; 159 | } 160 | 161 | if (offsets[0] < 0 || offsets[1] <= offsets[0] || offsets[1] > bytes_size) { 162 | return NULL; 163 | } 164 | 165 | tensor->data = (char*)bytes + offsets[0]; 166 | tensor->size = offsets[1] - offsets[0]; 167 | } else { 168 | return NULL; 169 | } 170 | 171 | if (*json != '}' && *json != ',') { 172 | return NULL; 173 | } 174 | json = (*json == ',') ? 
json_skipws(json + 1) : json; 175 | } 176 | 177 | if (!validate_shape(dsize, tensor->shape, tensor->size)) { 178 | return NULL; 179 | } 180 | 181 | return json_skipws(json + 1); 182 | } 183 | 184 | static char* parse_metadata(struct Tensors* tensors, char* json) { 185 | if (*json != '{') { 186 | return NULL; 187 | } 188 | json = json_skipws(json + 1); 189 | 190 | while (*json != '}') { 191 | struct Metadata metadata = {}; 192 | json = json_string(json, &metadata.key); 193 | if (!json || *json != ':') { 194 | return NULL; 195 | } 196 | json = json_skipws(json + 1); 197 | json = json_string(json, &metadata.value); 198 | if (!json) { 199 | return NULL; 200 | } 201 | 202 | if (tensors->n_metadata >= sizeof(tensors->metadata) / sizeof(tensors->metadata[0])) { 203 | return NULL; 204 | } 205 | tensors->metadata[tensors->n_metadata++] = metadata; 206 | 207 | if (*json != '}' && *json != ',') { 208 | return NULL; 209 | } 210 | json = (*json == ',') ? json_skipws(json + 1) : json; 211 | } 212 | 213 | return json_skipws(json + 1); 214 | } 215 | 216 | int tensors_parse(struct Tensors* tensors, void* data, size_t size) { 217 | if (size < sizeof(uint64_t)) { 218 | return -1; 219 | } 220 | 221 | uint64_t json_size = *(uint64_t*)data; 222 | if (json_size == 0 || json_size > size - sizeof(uint64_t)) { 223 | return -1; 224 | } 225 | 226 | char* json = (char*)data + sizeof(uint64_t); 227 | void* bytes = (char*)data + sizeof(uint64_t) + json_size; 228 | size_t bytes_size = size - sizeof(uint64_t) - json_size; 229 | 230 | json[json_size - 1] = 0; 231 | 232 | if (*json != '{') { 233 | return -1; 234 | } 235 | json = json_skipws(json + 1); 236 | 237 | while (*json && *json != '}') { 238 | char* key; 239 | json = json_string(json, &key); 240 | if (!json || *json != ':') { 241 | return -1; 242 | } 243 | json = json_skipws(json + 1); 244 | 245 | if (strcmp(key, "__metadata__") == 0) { 246 | json = parse_metadata(tensors, json); 247 | if (!json) { 248 | return -1; 249 | } 250 | } else { 251 | struct Tensor tensor = {}; 252 | json = parse_tensor(&tensor, bytes, bytes_size, key, json); 253 | if (!json) { 254 | return -1; 255 | } 256 | 257 | if (tensors->n_tensors >= sizeof(tensors->tensors) / sizeof(tensors->tensors[0])) { 258 | return -1; 259 | } 260 | tensors->tensors[tensors->n_tensors++] = tensor; 261 | } 262 | 263 | if (*json != '}' && *json != ',' && *json != '\0') { 264 | return -1; 265 | } 266 | json = (*json == ',') ? 
json_skipws(json + 1) : json; 267 | } 268 | 269 | return 0; 270 | } 271 | 272 | int tensors_open(struct Tensors* tensors, const char* filename) { 273 | int fd = open(filename, O_RDONLY); 274 | if (fd == -1) { 275 | return -1; 276 | } 277 | 278 | struct stat st; 279 | if (fstat(fd, &st) != 0) { 280 | close(fd); 281 | return -1; 282 | } 283 | 284 | size_t size = st.st_size; 285 | void* data = mmap(NULL, size, PROT_READ | PROT_WRITE, MAP_PRIVATE, fd, 0); 286 | if (data == MAP_FAILED) { 287 | close(fd); 288 | return -1; 289 | } 290 | 291 | #ifdef __linux__ 292 | // increases readahead buffer size, resulting in faster cold loads 293 | posix_fadvise(fd, 0, size, POSIX_FADV_SEQUENTIAL); 294 | #endif 295 | 296 | close(fd); // fd can be closed after mmap returns without invalidating the mapping 297 | 298 | if (tensors_parse(tensors, data, size) != 0) { 299 | munmap(data, size); 300 | return -2; 301 | } 302 | 303 | tensors->data = data; 304 | tensors->size = size; 305 | 306 | return 0; 307 | } 308 | 309 | void tensors_close(struct Tensors* tensors) { 310 | munmap(tensors->data, tensors->size); 311 | } 312 | 313 | struct Tensor* tensors_find(struct Tensors* tensors, const char* name, int layer) { 314 | char key[128]; 315 | snprintf(key, sizeof(key), name, layer); 316 | 317 | for (int i = 0; i < tensors->n_tensors; ++i) { 318 | if (strcmp(tensors->tensors[i].name, key) == 0) { 319 | return &tensors->tensors[i]; 320 | } 321 | } 322 | return NULL; 323 | } 324 | 325 | void* tensors_get(struct Tensors* tensors, const char* name, int layer, enum DType dtype, int shape[4]) { 326 | struct Tensor* tensor = tensors_find(tensors, name, layer); 327 | if (tensor == NULL) { 328 | fprintf(stderr, "FATAL: Tensor not found: %s\n", name); 329 | assert(false); 330 | return NULL; 331 | } 332 | 333 | if (tensor->dtype != dtype || memcmp(tensor->shape, shape, sizeof(tensor->shape)) != 0) { 334 | fprintf(stderr, "FATAL: Tensor mismatch: %s\n", name); 335 | fprintf(stderr, " Expected: dtype=%d shape=[%d,%d,%d,%d]\n", dtype, shape[0], shape[1], shape[2], shape[3]); 336 | fprintf(stderr, " Actual: dtype=%d shape=[%d,%d,%d,%d]\n", tensor->dtype, tensor->shape[0], tensor->shape[1], tensor->shape[2], tensor->shape[3]); 337 | assert(false); 338 | return NULL; 339 | } 340 | 341 | return tensor->data; 342 | } 343 | 344 | const char* tensors_metadata_find(struct Tensors* tensors, const char* name) { 345 | for (int i = 0; i < tensors->n_metadata; ++i) { 346 | if (strcmp(tensors->metadata[i].key, name) == 0) { 347 | return tensors->metadata[i].value; 348 | } 349 | } 350 | return NULL; 351 | } 352 | 353 | const char* tensors_metadata(struct Tensors* tensors, const char* name) { 354 | const char* res = tensors_metadata_find(tensors, name); 355 | if (res == NULL) { 356 | fprintf(stderr, "FATAL: Metadata not found: %s\n", name); 357 | assert(false); 358 | } 359 | return res; 360 | } 361 | 362 | #ifdef FUZZING 363 | int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { 364 | struct Tensors tensors = {}; 365 | void* copy = malloc(size); 366 | memcpy(copy, data, size); 367 | tensors_parse(&tensors, copy, size); 368 | free(copy); 369 | return 0; 370 | } 371 | #endif 372 | -------------------------------------------------------------------------------- /src/helpers.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | // note: we expect loads to be broken into units of up to 16b due to specified alignment 10 | 
template 11 | union _ALIGNAS(sizeof(T) * N) ablock { 12 | T v[N]; 13 | }; 14 | 15 | __device__ inline float warpreduce_sum(float v) { 16 | #pragma unroll 17 | for (int mask = warpSize / 2; mask > 0; mask >>= 1) { 18 | v += __shfl_xor_sync(0xffffffff, v, mask); 19 | } 20 | return v; 21 | } 22 | 23 | __device__ inline float warpreduce_max(float v) { 24 | #pragma unroll 25 | for (int mask = warpSize / 2; mask > 0; mask >>= 1) { 26 | v = max(v, __shfl_xor_sync(0xffffffff, v, mask)); 27 | } 28 | return v; 29 | } 30 | 31 | __device__ inline int warpreduce_maxi(int v) { 32 | #pragma unroll 33 | for (int mask = warpSize / 2; mask > 0; mask >>= 1) { 34 | v = max(v, __shfl_xor_sync(0xffffffff, v, mask)); 35 | } 36 | return v; 37 | } 38 | 39 | __device__ inline float blocktranspose(float v, float def) { 40 | int lane = threadIdx.x % warpSize; 41 | int warp = threadIdx.x / warpSize; 42 | 43 | __shared__ float sm[32]; 44 | sm[warp] = v; 45 | __syncthreads(); 46 | 47 | return lane < blockDim.x / warpSize ? sm[lane] : def; 48 | } 49 | 50 | __device__ inline float blockreduce_sum(float v) { 51 | v = warpreduce_sum(v); 52 | v = blocktranspose(v, 0.f); 53 | v = warpreduce_sum(v); 54 | return v; 55 | } 56 | 57 | __device__ inline float blockreduce_max(float v) { 58 | v = warpreduce_max(v); 59 | v = blocktranspose(v, -FLT_MAX); 60 | v = warpreduce_max(v); 61 | return v; 62 | } 63 | 64 | // fast fp8x4 => float4 conversion; drops unnecessary NaN handling from __nv_cvt_fp8_to_halfraw 65 | __device__ inline float4 fp8x4_e5m2_ff(__nv_fp8x4_e5m2 v) { 66 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 67 | return float4(v); 68 | #else 69 | unsigned int vlo = v.__x, vhi = v.__x >> 16; 70 | __half2_raw hlo = {(unsigned short)(vlo << 8), (unsigned short)(vlo & 0xff00)}; 71 | __half2_raw hhi = {(unsigned short)(vhi << 8), (unsigned short)(vhi & 0xff00)}; 72 | float2 rlo = __internal_halfraw2_to_float2(hlo); 73 | float2 rhi = __internal_halfraw2_to_float2(hhi); 74 | float4 res = {rlo.x, rlo.y, rhi.x, rhi.y}; 75 | return res; 76 | #endif 77 | } 78 | 79 | // fast fp8x2 => half2 conversion; drops unnecessary NaN handling from __nv_cvt_fp8_to_halfraw 80 | __device__ inline half2 fp8x2_e5m2_ff(unsigned int v) { 81 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 82 | __nv_fp8x2_e5m2 p; 83 | p.__x = v; 84 | return half2(p); 85 | #else 86 | __half2_raw h = {(unsigned short)(v << 8), (unsigned short)(v & 0xff00)}; 87 | return h; 88 | #endif 89 | } 90 | 91 | __device__ inline half fp8_e5m2_ff(uint8_t v) { 92 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 93 | __half_raw h = __nv_cvt_fp8_to_halfraw(v, __NV_E5M2); 94 | #else 95 | __half_raw h = {(unsigned short)(v << 8)}; 96 | #endif 97 | return h; 98 | } 99 | 100 | // gf4 decoding: 8 3-bit values + 1 fp8 scale are packed in a 32-bit word 101 | __device__ inline half gf4_ff(uint32_t v, int k) { 102 | half s = fp8_e5m2_ff(v & 0xff) * half(-0.25f); // we expect compiler to reuse this across multiple calls 103 | return half(int((v >> (8 + k * 3)) & 7) - 4) * s; 104 | } 105 | 106 | // gf4 decoding (2 values): 8 3-bit values + 1 fp8 scale are packed in a 32-bit word 107 | __device__ inline half2 gf4x2_ff(uint32_t v, int k) { 108 | half us = fp8_e5m2_ff(v & 0xff); // we expect compiler to reuse this across multiple calls 109 | half s = us * half(-0.25f); // we expect compiler to reuse this across multiple calls 110 | uint32_t p = v >> (8 + k * 3); 111 | half2 q = half2(int(p & 7), int((p >> 3) & 7)); 112 | return __hfma2(q, half2(s, s), half2(us, us)); 113 | } 114 | 115 | 
// regular mat*vec; naive and unoptimized (won't reach peak bw or flops) 116 | template 117 | __device__ inline float matmul(float* x, T* w, int i, int n) { 118 | float val = 0.0f; 119 | for (int j = 0; j < n; j++) { 120 | val += float(w[i * n + j]) * x[j]; 121 | } 122 | return val; 123 | } 124 | 125 | // warp-parallel mat*vec; each warp collaboratively computes mat*vec for a single row 126 | // specialized for half weights and ensures that we maximize transaction sizes by reading 4 bytes per thread 127 | __device__ inline float matmul_warppar(float* x, half* w, int i, int n) { 128 | int lane = threadIdx.x % warpSize; 129 | float val = 0.0f; 130 | for (int j = lane * 2; j < n; j += warpSize * 2) { 131 | float2 ww = __half22float2(*(half2*)&w[i * n + j]); 132 | float2 xx = *(float2*)&x[j]; 133 | val += ww.x * xx.x; 134 | val += ww.y * xx.y; 135 | } 136 | return warpreduce_sum(val); 137 | } 138 | 139 | // warp-parallel mat*vec; each warp collaboratively computes mat*vec for a single row 140 | // specialized for half weights and ensures that we maximize transaction sizes by reading 4 bytes per thread 141 | __device__ inline float matmul_warppar(half* x, half* w, int i, int n) { 142 | int lane = threadIdx.x % warpSize; 143 | half2 val = {0, 0}; 144 | for (int j = lane * 2; j < n; j += warpSize * 2) { 145 | half2 ww = *(half2*)&w[i * n + j]; 146 | half2 xx = *(half2*)&x[j]; 147 | val = __hfma2(ww, xx, val); 148 | } 149 | return warpreduce_sum(float(val.x + val.y)); 150 | } 151 | 152 | // warp-parallel mat*vec; each warp collaboratively computes mat*vec for a single row 153 | // specialized for fp8 weights and ensures that we maximize transaction sizes by reading 4 bytes per thread 154 | __device__ inline float matmul_warppar(float* x, __nv_fp8_e5m2* w, int i, int n) { 155 | int lane = threadIdx.x % warpSize; 156 | float val = 0.0f; 157 | // use 64-bit loads instead of 32-bit loads to increase memory throughput on H100/A100 158 | // without this we are seeing lower throughput given the limited number of parallel warps in coop kernel 159 | // this is performance-neutral on 4090 but results in issues with x[] load coalescing (that are benign) 160 | for (int j = lane * 8; j < n; j += warpSize * 8) { 161 | ablock<__nv_fp8x4_e5m2, 2> wwp = *(ablock<__nv_fp8x4_e5m2, 2>*)&w[i * n + j]; 162 | #pragma unroll 163 | for (int k = 0; k < 2; ++k) { 164 | float4 ww = fp8x4_e5m2_ff(wwp.v[k]); 165 | float4 xx = *(float4*)&x[j + k * 4]; 166 | val += ww.x * xx.x; 167 | val += ww.y * xx.y; 168 | val += ww.z * xx.z; 169 | val += ww.w * xx.w; 170 | } 171 | } 172 | return warpreduce_sum(val); 173 | } 174 | 175 | // warp-parallel mat*vec; each warp collaboratively computes mat*vec for a single row 176 | // specialized for fp8 weights and ensures that we maximize transaction sizes by reading 4 bytes per thread 177 | __device__ inline float matmul_warppar(half* x, __nv_fp8_e5m2* w, int i, int n) { 178 | int lane = threadIdx.x % warpSize; 179 | half2 val = {0, 0}; 180 | // use 64-bit loads instead of 32-bit loads to increase memory throughput on H100/A100 181 | // without this we are seeing lower throughput given the limited number of parallel warps in coop kernel 182 | // this is performance-neutral on 4090 but results in issues with x[] load coalescing (that are benign) 183 | for (int j = lane * 8; j < n; j += warpSize * 8) { 184 | ablock<__nv_fp8x2_e5m2, 4> wwp = *(ablock<__nv_fp8x2_e5m2, 4>*)&w[i * n + j]; 185 | ablock<__half2_raw, 4> xxp = *(ablock<__half2_raw, 4>*)&x[j]; 186 | #pragma unroll 187 | for (int k = 0; k 
< 4; ++k) { 188 | half2 ww = fp8x2_e5m2_ff(wwp.v[k].__x); 189 | half2 xx = xxp.v[k]; 190 | val = __hfma2(ww, xx, val); 191 | } 192 | } 193 | return warpreduce_sum(float(val.x + val.y)); 194 | } 195 | 196 | // warp-parallel mat*vec; each warp collaboratively computes mat*vec for a single row 197 | // specialized for gf4 weights and ensures that we maximize transaction sizes by reading 4 bytes per thread 198 | __device__ inline float matmul_warppar(float* x, uint32_t* w, int i, int n) { 199 | int lane = threadIdx.x % warpSize; 200 | if (n % (warpSize * 16) == 0) { 201 | float val = 0.0f; 202 | for (int j = lane * 8; j < n; j += warpSize * 16) { 203 | uint32_t wg0 = w[i * n / 8 + j / 8]; 204 | uint32_t wg1 = w[i * n / 8 + j / 8 + warpSize]; 205 | 206 | ablock xx0 = *(ablock*)&x[j]; 207 | #pragma unroll 208 | for (int k = 0; k < 8; ++k) { 209 | val += float(gf4_ff(wg0, k)) * xx0.v[k]; 210 | } 211 | 212 | ablock xx1 = *(ablock*)&x[j + warpSize * 8]; 213 | #pragma unroll 214 | for (int k = 0; k < 8; ++k) { 215 | val += float(gf4_ff(wg1, k)) * xx1.v[k]; 216 | } 217 | } 218 | return warpreduce_sum(val); 219 | } else { 220 | float val = 0.0f; 221 | for (int j = lane * 8; j < n; j += warpSize * 8) { 222 | uint32_t wg = w[i * n / 8 + j / 8]; 223 | 224 | ablock xx = *(ablock*)&x[j]; 225 | #pragma unroll 226 | for (int k = 0; k < 8; ++k) { 227 | val += float(gf4_ff(wg, k)) * xx.v[k]; 228 | } 229 | } 230 | return warpreduce_sum(val); 231 | } 232 | } 233 | 234 | // warp-parallel mat*vec; each warp collaboratively computes mat*vec for a single row 235 | // specialized for gf4 weights and ensures that we maximize transaction sizes by reading 4 bytes per thread 236 | __device__ inline float matmul_warppar(half* x, uint32_t* w, int i, int n) { 237 | int lane = threadIdx.x % warpSize; 238 | if (n % (warpSize * 64) == 0) { 239 | half2 val = {0, 0}; 240 | for (int j = lane * 16; j < n; j += warpSize * 64) { 241 | ablock wgp[4] = { 242 | *(ablock*)&w[i * n / 8 + j / 8], 243 | *(ablock*)&w[i * n / 8 + j / 8 + (warpSize * 16) / 8], 244 | *(ablock*)&w[i * n / 8 + j / 8 + (warpSize * 32) / 8], 245 | *(ablock*)&w[i * n / 8 + j / 8 + (warpSize * 48) / 8], 246 | }; 247 | #pragma unroll 248 | for (int u = 0; u < 4; ++u) { 249 | ablock<__half2_raw, 8> xx = *(ablock<__half2_raw, 8>*)&x[j + warpSize * 16 * u]; 250 | #pragma unroll 251 | for (int k = 0; k < 8; k += 2) { 252 | val = __hfma2(gf4x2_ff(wgp[u].v[0], k), xx.v[k / 2], val); 253 | } 254 | #pragma unroll 255 | for (int k = 0; k < 8; k += 2) { 256 | val = __hfma2(gf4x2_ff(wgp[u].v[1], k), xx.v[k / 2 + 4], val); 257 | } 258 | } 259 | } 260 | return warpreduce_sum(float(val.x + val.y)); 261 | } else { 262 | half2 val = {0, 0}; 263 | for (int j = lane * 16; j < n; j += warpSize * 16) { 264 | ablock wgp = *(ablock*)&w[i * n / 8 + j / 8]; 265 | 266 | ablock<__half2_raw, 8> xx = *(ablock<__half2_raw, 8>*)&x[j]; 267 | #pragma unroll 268 | for (int k = 0; k < 8; k += 2) { 269 | val = __hfma2(gf4x2_ff(wgp.v[0], k), xx.v[k / 2], val); 270 | } 271 | #pragma unroll 272 | for (int k = 0; k < 8; k += 2) { 273 | val = __hfma2(gf4x2_ff(wgp.v[1], k), xx.v[k / 2 + 4], val); 274 | } 275 | } 276 | return warpreduce_sum(float(val.x + val.y)); 277 | } 278 | } 279 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 😌 calm 2 | 3 | This is an implementation of language model inference, aiming to get maximum single-GPU single-batch hardware utilization for LLM 
architectures with a minimal implementation and no dependencies[^1]. 4 | 5 | The goal of this project is experimentation and prototyping; it does not aim to be production ready or stable. 6 | 7 | Parts of this code are based on Andrej Karpathy's [llama2.c](https://github.com/karpathy/llama2.c). 8 | 9 | ## Running 10 | 11 | To build and run `calm`, you need to download and convert a model, build the code using `make`[^2] and run it: 12 | 13 | ```sh 14 | git clone https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2 15 | python tools/convert.py mistral-7b-instruct.calm Mistral-7B-Instruct-v0.2/ 16 | make && ./build/run mistral-7b-instruct.calm -i "Q: What is the meaning of life?" -t 0 17 | ``` 18 | 19 | You can also run the model in chat mode (for models like Mistral/Mixtral you might want to increase context size via `-c` from the default 4096): 20 | 21 | ```sh 22 | make && ./build/run mistral-7b-instruct.calm -y "You are a helpful AI assistant." 23 | ``` 24 | 25 | Before running Python you may want to install the dependencies via `pip install -r tools/requirements.txt`. When using git to download models, git-lfs is required and the download size may be larger than necessary; you can use `tools/download.py` instead (assumes models use Safetensors by default): 26 | 27 | ```sh 28 | python tools/download.py Mistral-7B-Instruct-v0.2/ mistralai/Mistral-7B-Instruct-v0.2 29 | ``` 30 | 31 | ## Supported models 32 | 33 | calm supports a subset of decoder-only transformer architectures: 34 | 35 | - Llama-like baseline (pre/post normalization, gated FFN, sequential attention mixing and FFN, RoPE) 36 | - RoPE enhancements (partial rotary dimension, independent head dimension) 37 | - SiLU or GELU FFN gate activation 38 | - RMSNorm or LayerNorm* normalization (no bias support) 39 | - Optional minor variations (QKV bias, QKV clipping, tied embeddings) 40 | - Optional mixture of experts (with top-k expert selection) 41 | 42 | It has been tested on following models: 43 | 44 | | Architecture | Models | 45 | |-------------------|--------| 46 | | Llama | [Llama2 7B](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf), [Llama2 13B](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf), [Llama3 8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) | 47 | | Llama-like | [TinyLlama 1.1B](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0), [Cosmo 1B](https://huggingface.co/HuggingFaceTB/cosmo-1b), [LLaMA Pro 8B](https://huggingface.co/TencentARC/LLaMA-Pro-8B-Instruct), [H2O Danube 1.8B](https://huggingface.co/h2oai/h2o-danube-1.8b-chat), [DeepSeekMath 7B](https://huggingface.co/deepseek-ai/deepseek-math-7b-instruct), [LargeWorldModel 7B 1M](https://huggingface.co/LargeWorldModel/LWM-Text-Chat-1M), [Xverse 7B](https://huggingface.co/xverse/XVERSE-7B-Chat), [LLM360 K2](https://huggingface.co/LLM360/K2-Chat) | 48 | | Yi | [Yi 1.5 6B](https://huggingface.co/01-ai/Yi-1.5-6B-Chat), [Yi 1.5 9B](https://huggingface.co/01-ai/Yi-1.5-9B-Chat), [Yi 1.5 34B](https://huggingface.co/01-ai/Yi-1.5-34B-Chat) | 49 | | Mistral | [Mistral 7B](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3), [Mistral Nemo 12B](https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407), [Codestral 22B](https://huggingface.co/mistralai/Codestral-22B-v0.1), [Mistral Pro 8B](https://huggingface.co/TencentARC/Mistral_Pro_8B_v0.1), [SOLAR 10.7B](https://huggingface.co/upstage/SOLAR-10.7B-Instruct-v1.0), [GritLM 7B](https://huggingface.co/GritLM/GritLM-7B), [Starling 
7B](https://huggingface.co/Nexusflow/Starling-LM-7B-beta) | 50 | | Qwen2 | [Qwen1.5 0.5B](https://huggingface.co/Qwen/Qwen1.5-0.5B), [Qwen1.5 1.8B](https://huggingface.co/Qwen/Qwen1.5-1.8B), [Qwen1.5 4B](https://huggingface.co/Qwen/Qwen1.5-4B), [Qwen1.5 7B](https://huggingface.co/Qwen/Qwen1.5-7B), [Qwen1.5 14B](https://huggingface.co/Qwen/Qwen1.5-14B), [Qwen2 0.5B](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct), [Qwen2 1.5B](https://huggingface.co/Qwen/Qwen2-1.5B-Instruct), [Qwen2 7B](https://huggingface.co/Qwen/Qwen2-7B-Instruct) | 51 | | Mixtral | [Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1), [Mixtral 8x22B](https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1), [GritLM 8x7B](https://huggingface.co/GritLM/GritLM-8x7B) | 52 | | OLMo | [OLMo 1B](https://huggingface.co/allenai/OLMo-1B), [OLMo 7B](https://huggingface.co/allenai/OLMo-7B), [OLMo 1.7 7B](https://huggingface.co/allenai/OLMo-1.7-7B) | 53 | | Gemma | [Gemma 2B](https://huggingface.co/google/gemma-2b-it), [Gemma 7B](https://huggingface.co/google/gemma-7b-it) (*note: 7B version has issues with fp8 quantization*) | 54 | | MiniCPM | [MiniCPM 2B](https://huggingface.co/openbmb/MiniCPM-2B-dpo-bf16), [MiniCPM 2B 128K](https://huggingface.co/openbmb/MiniCPM-2B-128k), [MiniCPM MoE 8x2B](https://huggingface.co/openbmb/MiniCPM-MoE-8x2B) | 55 | | Cohere | [Command-R](https://huggingface.co/CohereForAI/c4ai-command-r-v01), [Aya 23 8B](https://huggingface.co/CohereForAI/aya-23-8B), [Aya 23 35B](https://huggingface.co/CohereForAI/aya-23-35B) | 56 | | InternLM | [InternLM2-1.8B](https://huggingface.co/internlm/internlm2-1_8b), [InternLM2-7B](https://huggingface.co/internlm/internlm2-7b), [InternLM2-20B](https://huggingface.co/internlm/internlm2-20b) | 57 | | DBRX | [DBRX 132B](https://huggingface.co/databricks/dbrx-instruct) | 58 | | Phi3 | [Phi3 Mini 3.8B](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/), [Phi3 Medium 14B](https://huggingface.co/microsoft/Phi-3-medium-4k-instruct) | 59 | 60 | ## Supported formats 61 | 62 | Model weights support `fp16`, `fp8` and `gf4` formats; the weight type is specified at conversion time via `--dtype` argument to `convert.py`, and defaults to `fp8`. 63 | 64 | `fp16` corresponds to 16-bit floating point (e5m10). Note that some models store weights in bf16 which will be automatically converted. 65 | 66 | `fp8` corresponds to 8-bit floating point (e5m2). Using `fp8` carries a ~0.5% perplexity penalty at almost double the inference speed and half the model size. e4m3 variant of `fp8` would result in a much smaller perplexity penalty (~0.1%) with basic tensor scaling, but it's currently not used because of performance issues wrt floating-point conversion. 67 | 68 | `gf4` corresponds to 4-bit grouped floating point (8 values are stored in 32 bits using 3 bit quantized scale per value and one fp8 group scale). Using `gf4` currently carries a perplexity penalty but increases inference speed by ~75% and halves the model size compared to `fp8`. Unlike llama.cpp's K-quants, `gf4` quantization is pure and uniform - all layers are quantized to exactly 4 bits per weight. 69 | 70 | KV cache is using `fp16` by default; when using longer contexts (> 4096), CUDA implementation automatically switches to `fp8` to improve memory/performance. This comes at a small perplexity cost. 71 | 72 | ## Model files 73 | 74 | calm uses [🤗 Safetensors](https://huggingface.co/docs/safetensors/index) to store model files. 
Note that the models require conversion (see below), because calm stores model hyperparameters in .safetensors metadata and may expect a particular set of tensor names or weight order within tensors that is not always compatible with the source. Tokenizer data is stored as tensors inside the model file as well. 75 | 76 | ## Performance 77 | 78 | Auto-regressive prediction for a single sequence needs to read the entire model and the entire KV cache (until current token) for every token. As such, given an optimal implementation we'd expect the process to be bandwidth bound. Note that the cost of token generation at the beginning of the sequence should be smaller than the cost at the end of the sequence due to the need to read data from KV cache. 79 | 80 | Currently prompts are processed serially, one token at a time; in the future, prompt processing will need to be parallelized to avoid the bandwidth bottleneck. 81 | 82 | With smaller weights on small models, getting closer to bandwidth limit becomes more difficult. Future optimizations may increase the gap here for small models, although smaller weights are most valuable to be able to infer larger models. 83 | 84 | ### NVidia 85 | 86 | When using NVidia GeForce RTX 4090, `calm` gets the following performance on a few models; each model is measured with `fp16`, `fp8` and `gf4` weights at the beginning of the context window (first 32 tokens) and at the end (last 32 tokens with an offset 2000 for 2048 contexts, 4000 for 4096 contexts and 16000 for 16384 contexts): 87 | 88 | | Model (context) | Performance (first 32) | Performance (last 32) | 89 | | ----------- | ----------- | ----------- | 90 | | Llama3 8B (4096), fp16 | 61 tok/s (923 GB/s) | 59 tok/s (919 GB/s) | 91 | | Llama3 8B (4096), fp8 | 120 tok/s (903 GB/s) | 110 tok/s (889 GB/s) | 92 | | Llama3 8B (4096), gf4 | 225 tok/s (846 GB/s) | 194 tok/s (830 GB/s) | 93 | | Llama2 7B (4096), fp16 | 69 tok/s (919 GB/s) | 60 tok/s (921 GB/s) | 94 | | Llama2 7B (4096), fp8 | 135 tok/s (893 GB/s) | 103 tok/s (899 GB/s) | 95 | | Llama2 7B (4096), gf4 | 246 tok/s (815 GB/s) | 158 tok/s (857 GB/s) | 96 | | Llama2 13B (4096), fp8 | 70 tok/s (910 GB/s) | 56 tok/s (907 GB/s) | 97 | | Llama2 13B (4096), gf4 | 131 tok/s (848 GB/s) | 88 tok/s (863 GB/s) | 98 | | Mistral 7B (4096), fp16 | 65 tok/s (925 GB/s) | 62 tok/s (916 GB/s) | 99 | | Mistral 7B (4096), fp8 | 127 tok/s (902 GB/s) | 116 tok/s (888 GB/s) | 100 | | Mistral 7B (4096), gf4 | 237 tok/s (843 GB/s) | 203 tok/s (832 GB/s) | 101 | | Mixtral 8x7B (4096), gf4 | 137 tok/s (875 GB/s) | 125 tok/s (862 GB/s) | 102 | | Mixtral 8x7B (16384), gf4 | 137 tok/s (879 GB/s) | 105 tok/s (781 GB/s) | 103 | | Yi 34B (4096), gf4 | 52 tok/s (884 GB/s) | 47 tok/s (851 GB/s) | 104 | 105 | RTX 4090 has a peak bandwidth of ~1008 GB/s, however it's unclear if a peak higher than ~950 GB/s is attainable in practice[^3]. 106 | 107 | `calm` can run on A100/H100 accelerators (but is mostly tuned for H100 `fp8` weights). When using Mixtral 8x7B (fp8) on 1xH100 SXM, it runs at ~200 tok/s (2550 GB/s) for 256-token outputs. 
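As a rough sanity check on the numbers above, the decode speed of a bandwidth-bound implementation can be estimated by dividing the sustainable memory bandwidth by the number of bytes read per token (all of the weights plus the KV cache up to the current position). The sketch below walks through that arithmetic; the bandwidth figure and the Llama3-8B-like shape it assumes are illustrative values, not measurements from this project.

```c
// Back-of-the-envelope decode speed for a bandwidth-bound model.
// All inputs are assumptions for illustration (roughly Llama3-8B-shaped, fp8 weights).
#include <stdio.h>

int main(void) {
	double bandwidth = 900e9;                      // assumed sustainable read bandwidth, bytes/s
	double weight_bytes = 8.0e9 * 1.0;             // ~8B parameters at 1 byte per weight (fp8)
	double kv_bytes_per_pos = 32 * 1024.0 * 2 * 2; // n_layers * kv_dim * (K+V) * fp16 bytes
	double pos = 2000;                             // current position in the context window

	double bytes_per_token = weight_bytes + kv_bytes_per_pos * pos;
	printf("~%.0f tok/s\n", bandwidth / bytes_per_token); // ~109 tok/s for these inputs
	return 0;
}
```

Near the start of the context the KV term is negligible and the estimate reduces to bandwidth divided by model size, which is why the "first 32" columns are slightly faster than the "last 32" ones.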
108 | 109 | ### Apple 110 | 111 | When using Apple Silicon (Metal), `calm` gets the following performance; each model is measured with `fp8` and `gf4` weights at the beginning of the context window (first 32 tokens) and at the end (last 32 tokens with an offset 2000 for 2048 contexts, 4000 for 4096 contexts and 16000 for 16384 contexts): 112 | 113 | | Chip | Model (context) | Performance (first 32) | Performance (last 32) | 114 | | ----- | ----------- | ----------- | ----------- | 115 | | M4 (120 GB/s) | Llama3 8B (4096), fp8 | 13 tok/s (101 GB/s) | 13 tok/s (102 GB/s) | 116 | | M4 (120 GB/s) | Llama3 8B (4096), gf4 | 26 tok/s (99 GB/s) | 24 tok/s (99 GB/s) | 117 | | M2 (100 GB/s) | Llama3 8B (4096), fp8 | 12 tok/s (90 GB/s) | 11 tok/s (89 GB/s) | 118 | | M2 (100 GB/s) | Llama3 8B (4096), gf4 | 23 tok/s (89 GB/s) | 20 tok/s (85 GB/s) | 119 | | M2 Pro (200 GB/s) | Llama3 8B (4096), fp8 | 24 tok/s (180 GB/s) | 21 tok/s (172 GB/s) | 120 | | M2 Pro (200 GB/s) | Llama3 8B (4096), gf4 | 45 tok/s (169 GB/s) | 36 tok/s (157 GB/s) | 121 | | M1 Max (400 GB/s) | Llama3 8B (4096), fp8 | 44 tok/s (332 GB/s) | 38 tok/s (306 GB/s) | 122 | | M1 Max (400 GB/s) | Llama3 8B (4096), gf4 | 73 tok/s (274 GB/s) | 58 tok/s (248 GB/s) | 123 | 124 | Note: on higher-end chips `calm` currently doesn't reach peak performance; some of this is due to limitations of the chips in other areas, and some is due to the author not having access to the higher-end hardware to profile and optimize for. Hardware donations are welcome ;) 125 | 126 | [^1]: The CUDA runtime and compiler are used for GPU acceleration, but no CUDA or C libraries are used. Python conversion scripts use safetensors and torch; see `tools/requirements.txt`. 127 | [^2]: Linux is the main supported OS at the moment; calm also works on macOS (on CPU) and has experimental Metal support. 128 | [^3]: Based on testing a specific Gigabyte GeForce RTX 4090, where both individual kernels from this repository and cuBLAS peak at about 955 GB/s. 
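For reference, the ~1008 GB/s peak quoted for the RTX 4090 above is plain memory-spec arithmetic, and the ~955 GB/s ceiling from footnote 3 is roughly 95% of it. A tiny worked example; the bus width and per-pin data rate below are the publicly listed GDDR6X specs, not values measured by this project:

```c
// Theoretical peak bandwidth of an RTX 4090 from its memory spec, and how the
// ~955 GB/s measured ceiling compares to it.
#include <stdio.h>

int main(void) {
	double bus_width_bits = 384; // GDDR6X bus width
	double pin_rate_gbps = 21;   // effective data rate per pin, Gbit/s
	double peak = bus_width_bits * pin_rate_gbps / 8; // 1008 GB/s
	double measured = 955;       // sustained ceiling from footnote 3

	printf("peak %.0f GB/s, attainable ~%.0f%% of peak\n", peak, 100 * measured / peak);
	return 0;
}
```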
129 | -------------------------------------------------------------------------------- /src/infer.m: -------------------------------------------------------------------------------- 1 | #include "model.h" 2 | 3 | #include 4 | 5 | extern unsigned char infer_metallib[]; 6 | extern unsigned int infer_metallib_len; 7 | 8 | static id device; 9 | static id queue; 10 | static id kernels[256]; 11 | static const char* kernel_names[256]; 12 | 13 | static MTLCaptureManager* capture; 14 | 15 | static void dispatch2(id encoder, const char* name, const char* variant, unsigned int threadgroups_x, unsigned int threadgroups_y, unsigned int threadgroup_size, unsigned int threadgroup_smem, void* params, size_t params_size, void** buffers, size_t buffer_count) { 16 | char expected[256]; 17 | strcpy(expected, name); 18 | if (variant) { 19 | strcat(expected, "_"); 20 | strcat(expected, variant); 21 | } 22 | 23 | id state = nil; 24 | for (size_t i = 0; kernels[i]; ++i) { 25 | if (strcmp(kernel_names[i], expected) == 0) { 26 | state = kernels[i]; 27 | break; 28 | } 29 | } 30 | assert(state); 31 | assert(state.maxTotalThreadsPerThreadgroup >= threadgroup_size); 32 | 33 | static const NSUInteger offsets[16] = {}; 34 | 35 | [encoder setComputePipelineState:state]; 36 | [encoder setBytes:params length:params_size atIndex:0]; 37 | [encoder setBuffers:(const id*)buffers offsets:offsets withRange:NSMakeRange(1, buffer_count)]; 38 | [encoder setThreadgroupMemoryLength:threadgroup_smem atIndex:0]; 39 | [encoder dispatchThreadgroups:MTLSizeMake(threadgroups_x, threadgroups_y, 1) threadsPerThreadgroup:MTLSizeMake(threadgroup_size, 1, 1)]; 40 | } 41 | 42 | static void dispatch(id encoder, const char* name, const char* variant, unsigned int threadgroups, unsigned int threadgroup_size, unsigned int threadgroup_smem, void* params, size_t params_size, void** buffers, size_t buffer_count) { 43 | dispatch2(encoder, name, variant, threadgroups, 1, threadgroup_size, threadgroup_smem, params, params_size, buffers, buffer_count); 44 | } 45 | 46 | void init_metal(void) { 47 | NSArray>* devices = MTLCopyAllDevices(); 48 | assert(devices.count > 0); 49 | 50 | device = devices[0]; 51 | queue = [device newCommandQueue]; 52 | 53 | dispatch_data_t lib_data = dispatch_data_create(infer_metallib, infer_metallib_len, dispatch_get_main_queue(), ^{ 54 | }); 55 | 56 | NSError* error = nil; 57 | id library = [device newLibraryWithData:lib_data error:&error]; 58 | assert(library); 59 | 60 | NSArray* functions = library.functionNames; 61 | for (size_t i = 0; i < functions.count; i++) { 62 | id function = [library newFunctionWithName:functions[i]]; 63 | MTLComputePipelineDescriptor* desc = [[MTLComputePipelineDescriptor alloc] init]; 64 | desc.computeFunction = function; 65 | desc.threadGroupSizeIsMultipleOfThreadExecutionWidth = YES; 66 | id state = [device newComputePipelineStateWithDescriptor:desc options:MTLPipelineOptionNone reflection:nil error:&error]; 67 | assert(state); 68 | kernels[i] = state; 69 | kernel_names[i] = [functions[i] UTF8String]; 70 | } 71 | } 72 | 73 | void* upload_metal(void* host, size_t size) { 74 | id buffer = [device newBufferWithBytes:host length:size options:MTLResourceStorageModeShared]; 75 | assert(buffer); 76 | return buffer; 77 | } 78 | 79 | static void* newbuffer(size_t size) { 80 | if (size == 0) return nil; 81 | id buffer = [device newBufferWithLength:size options:MTLResourceStorageModeShared]; 82 | assert(buffer); 83 | return buffer; 84 | } 85 | 86 | void prepare_metal(struct Transformer* transformer) { 87 | 
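	// prepare_metal sets up the per-run GPU state: it allocates activation and KV cache buffers as shared
	// Metal buffers, substitutes a shared dummy bias buffer for layers without QKV bias, pre-processes gf4
	// weights in place on the GPU via prepare_gf4 (for gf4 models), and optionally starts a .gputrace
	// capture when MTL_CAPTURE_ENABLED is set.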
struct Config* config = &transformer->config; 88 | struct Weights* weights = &transformer->weights; 89 | struct RunState* state = &transformer->state; 90 | 91 | assert(device); 92 | printf("# Metal: %s, %.1f GiB\n", device.name.UTF8String, (double)device.recommendedMaxWorkingSetSize / (1024 * 1024 * 1024)); 93 | 94 | int dim = config->dim; 95 | int hidden_dim = config->hidden_dim; 96 | int q_dim = config->head_dim * config->n_heads; 97 | int kv_dim = config->head_dim * config->n_kv_heads; 98 | 99 | state->x = (float*)newbuffer(dim * sizeof(float)); 100 | state->xb = (float*)newbuffer(dim * sizeof(float)); 101 | state->hb = (float*)newbuffer(hidden_dim * sizeof(float)); 102 | state->he = (float*)newbuffer(config->n_experts_ac * hidden_dim * sizeof(float)); 103 | state->q = (float*)newbuffer(q_dim * sizeof(float)); 104 | state->att = (float*)newbuffer(config->n_heads * config->seq_len * sizeof(float)); 105 | state->exp = (float*)newbuffer(((config->n_experts_ac ? config->n_experts_ac : 1) * 2) * sizeof(float)); 106 | 107 | assert(state->kvbits == 8 || state->kvbits == 16); 108 | state->key_cache = newbuffer((size_t)config->n_layers * config->seq_len * kv_dim * (state->kvbits / 8)); 109 | state->value_cache = newbuffer((size_t)config->n_layers * config->seq_len * kv_dim * (state->kvbits / 8)); 110 | 111 | // logits are going to be read by the host so we just allocate them in host and write to host directly 112 | state->logits = (float*)newbuffer(config->vocab_size * sizeof(float)); 113 | 114 | float* bqkv = (float*)newbuffer((q_dim + kv_dim * 2) * sizeof(float)); 115 | 116 | for (int l = 0; l < config->n_layers; ++l) { 117 | if (weights->bqkv[l] == NULL) { 118 | weights->bqkv[l] = bqkv; 119 | } 120 | } 121 | 122 | if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) { 123 | fprintf(stderr, "# Warning: allocated size %.0f MB exceeds recommended maximum %.0f MB; try running `sudo sysctl iogpu.wired_limit_mb=`\n", 124 | device.currentAllocatedSize / 1024.f / 1024.f, device.recommendedMaxWorkingSetSize / 1024.f / 1024.f); 125 | } 126 | 127 | if (config->n_experts == 0) { 128 | // setup expert buffer to always point to the first (and only) expert 129 | float* moe = [(id)state->exp contents]; 130 | moe[0] = 1.0f; 131 | moe[1] = 0.0f; 132 | } 133 | 134 | if (weights->dbits == 4) { 135 | id commands = [queue commandBufferWithUnretainedReferences]; 136 | id encoder = [commands computeCommandEncoder]; 137 | 138 | dispatch(encoder, "prepare_gf4", NULL, dim * config->vocab_size / 256, 32, 0, (int[]){0}, sizeof(int), (void*[]){weights->token_embedding_table}, 1); 139 | 140 | for (int l = 0; l < config->n_layers; ++l) { 141 | dispatch(encoder, "prepare_gf4", NULL, q_dim * dim / 256, 32, 0, (int[]){0}, sizeof(int), (void*[]){weights->wq[l]}, 1); 142 | dispatch(encoder, "prepare_gf4", NULL, kv_dim * dim / 256, 32, 0, (int[]){0}, sizeof(int), (void*[]){weights->wk[l]}, 1); 143 | dispatch(encoder, "prepare_gf4", NULL, kv_dim * dim / 256, 32, 0, (int[]){0}, sizeof(int), (void*[]){weights->wv[l]}, 1); 144 | dispatch(encoder, "prepare_gf4", NULL, dim * q_dim / 256, 32, 0, (int[]){0}, sizeof(int), (void*[]){weights->wo[l]}, 1); 145 | 146 | int n_experts = config->n_experts ? 
config->n_experts : 1; 147 | 148 | dispatch(encoder, "prepare_gf4", NULL, n_experts * hidden_dim * dim / 256, 32, 0, (int[]){0}, sizeof(int), (void*[]){weights->w1[l]}, 1); 149 | dispatch(encoder, "prepare_gf4", NULL, n_experts * dim * hidden_dim / 256, 32, 0, (int[]){0}, sizeof(int), (void*[]){weights->w2[l]}, 1); 150 | dispatch(encoder, "prepare_gf4", NULL, n_experts * hidden_dim * dim / 256, 32, 0, (int[]){0}, sizeof(int), (void*[]){weights->w3[l]}, 1); 151 | 152 | if (weights->moegate[l]) { 153 | dispatch(encoder, "prepare_gf4", NULL, config->n_experts * dim / 256, 32, 0, (int[]){0}, sizeof(int), (void*[]){weights->moegate[l]}, 1); 154 | } 155 | } 156 | 157 | if (weights->wcls != weights->token_embedding_table) { 158 | dispatch(encoder, "prepare_gf4", NULL, dim * config->vocab_size / 256, 32, 0, (int[]){0}, sizeof(int), (void*[]){weights->wcls}, 1); 159 | } 160 | 161 | [encoder endEncoding]; 162 | [commands commit]; 163 | [commands waitUntilCompleted]; 164 | } 165 | 166 | const char* capenv = getenv("MTL_CAPTURE_ENABLED"); 167 | if (capenv && atoi(capenv)) { 168 | capture = [MTLCaptureManager sharedCaptureManager]; 169 | assert(capture); 170 | 171 | NSString* path = @"calm.gputrace"; 172 | 173 | MTLCaptureDescriptor* desc = [[MTLCaptureDescriptor alloc] init]; 174 | desc.captureObject = queue; 175 | desc.destination = MTLCaptureDestinationGPUTraceDocument; 176 | desc.outputURL = [NSURL fileURLWithPath:path]; 177 | 178 | NSError* error = nil; 179 | [[NSFileManager defaultManager] removeItemAtPath:path error:&error]; 180 | 181 | BOOL started = [capture startCaptureWithDescriptor:desc error:&error]; 182 | assert(started); 183 | 184 | NSLog(@"Capturing first token to %@", desc.outputURL); 185 | } 186 | } 187 | 188 | struct SinkArgs { 189 | int kv_dim; 190 | int head_dim; 191 | int rotary_dim; 192 | 193 | int kv_sink; 194 | int seq_len; 195 | 196 | float theta_log2; 197 | }; 198 | 199 | struct NormArgs { 200 | int size; 201 | float eps; 202 | bool ln; 203 | }; 204 | 205 | struct QkvArgs { 206 | int dim; 207 | int q_dim; 208 | int kv_dim; 209 | int head_dim; 210 | int rotary_dim; 211 | 212 | int pos; 213 | int kv_pos; 214 | int seq_len; 215 | 216 | size_t loff; 217 | 218 | float qkv_clip; 219 | float theta_log2; 220 | }; 221 | 222 | struct AttnArgs { 223 | int seq_len; 224 | int kv_len; 225 | int head_dim; 226 | int kv_mul; 227 | int n_heads; 228 | 229 | size_t loff; 230 | }; 231 | 232 | float* forward_metal(struct Transformer* transformer, int token, int pos, unsigned flags) { 233 | struct Config* p = &transformer->config; 234 | struct Weights* w = &transformer->weights; 235 | struct RunState* s = &transformer->state; 236 | 237 | // a few convenience variables 238 | float* x = s->x; 239 | int dim = p->dim; 240 | int hidden_dim = p->hidden_dim; 241 | int kv_dim = p->head_dim * p->n_kv_heads; 242 | int q_dim = p->head_dim * p->n_heads; 243 | int kv_mul = p->n_heads / p->n_kv_heads; 244 | assert(s->kvbits == 16); // TODO 245 | 246 | const char* dvar = w->dbits == 16 ? "half" : (w->dbits == 8 ? "fp8" : (w->dbits == 4 ? "gf4" : "?")); 247 | const char* kvar = "half"; 248 | const char* kmvar = w->dbits == 16 ? "half_float" : "half_half"; 249 | const char* nvar = w->dbits == 16 ? "float" : "half"; 250 | 251 | char dkvar[32]; 252 | snprintf(dkvar, sizeof(dkvar), "%s_%s", dvar, kvar); 253 | 254 | // following "attention sinks" from StreamingLLM we keep the first few tokens in the KV cache as is 255 | int kv_sink = pos >= p->seq_len ? 
KV_SINKS : 0; 256 | int kv_pos = kv_sink + (pos - kv_sink) % (p->seq_len - kv_sink); 257 | int kv_len = pos >= p->seq_len ? p->seq_len : pos + 1; 258 | 259 | // ensure all dimensions are warp-aligned 260 | assert(dim % 32 == 0 && kv_dim % 32 == 0 && hidden_dim % 32 == 0); 261 | 262 | const int matmul_par = 4; 263 | 264 | // begin command recording 265 | id commands = [queue commandBufferWithUnretainedReferences]; 266 | id encoder = [commands computeCommandEncoder]; 267 | 268 | // copy the token embedding into x 269 | assert(token < p->vocab_size); 270 | dispatch(encoder, "embed", dvar, dim / 32, 32, 0, (int[]){token * dim}, sizeof(int), (void*[]){x, w->token_embedding_table}, 2); 271 | 272 | // rotate sink tokens forward to keep pace with non-sink tokens 273 | if (kv_sink > 0) { 274 | dispatch(encoder, "rotate_sink", kvar, (kv_sink * kv_dim / 64) * p->n_layers, 32, 0, &(struct SinkArgs){kv_dim, p->head_dim, p->rotary_dim, kv_sink, p->seq_len, log2(p->rope_theta)}, sizeof(struct SinkArgs), (void*[]){s->key_cache}, 1); 275 | } 276 | 277 | // forward all the layers 278 | for (int l = 0; l < p->n_layers; ++l) { 279 | size_t loff = (size_t)l * p->seq_len * kv_dim; // kv cache layer offset for convenience 280 | 281 | // pre-attention rmsnorm 282 | dispatch(encoder, "rmsnorm", nvar, 1, 1024, 0, &(struct NormArgs){dim, p->norm_eps, p->norm_ln}, sizeof(struct NormArgs), (void*[]){s->xb, x, w->rms_att_weight[l]}, 3); 283 | 284 | // qkv 285 | dispatch(encoder, "qkv", dkvar, (q_dim + kv_dim * 2) / 2 / matmul_par, 32 * matmul_par, 0, &(struct QkvArgs){dim, q_dim, kv_dim, p->head_dim, p->rotary_dim, pos, kv_pos, p->seq_len, loff, p->qkv_clip, log2(p->rope_theta)}, sizeof(struct QkvArgs), (void*[]){s->xb, s->q, s->key_cache, s->value_cache, w->wq[l], w->wk[l], w->wv[l], w->bqkv[l]}, 8); 286 | 287 | // attn score 288 | int kv_lent = (kv_len + 7) / 8; 289 | 290 | dispatch(encoder, "attn_score", kvar, kv_lent * p->n_heads, 32, 0, &(struct AttnArgs){p->seq_len, kv_len, p->head_dim, kv_mul, p->n_heads, loff}, sizeof(struct AttnArgs), (void*[]){s->att, s->q, s->key_cache}, 3); 291 | 292 | // attn softmax 293 | dispatch(encoder, "attn_softmax", NULL, p->n_heads, 1024, 0, &(struct AttnArgs){p->seq_len, kv_len, p->head_dim, kv_mul, p->n_heads, loff}, sizeof(struct AttnArgs), (void*[]){s->att}, 1); 294 | 295 | // attn mix 296 | dispatch(encoder, "attn_mix", kmvar, q_dim, 32, 0, &(struct AttnArgs){p->seq_len, kv_len, p->head_dim, kv_mul, p->n_heads, loff}, sizeof(struct AttnArgs), (void*[]){s->q, s->att, s->value_cache}, 3); 297 | 298 | // attn out 299 | dispatch(encoder, "attn_out", dvar, dim / matmul_par, 32 * matmul_par, 0, (int[]){q_dim}, sizeof(int), (void*[]){x, s->q, w->wo[l]}, 3); 300 | 301 | if (!p->norm_par) { 302 | // post-attention rmsnorm 303 | dispatch(encoder, "rmsnorm", nvar, 1, 1024, 0, &(struct NormArgs){dim, p->norm_eps, p->norm_ln}, sizeof(struct NormArgs), (void*[]){s->xb, x, w->rms_ffn_weight[l]}, 3); 304 | } 305 | 306 | // moe gate 307 | if (p->n_experts) { 308 | dispatch(encoder, "moe_gate", dvar, 1, p->n_experts * 32, 0, (int[]){dim, p->n_experts, p->n_experts_ac}, sizeof(int) * 3, (void*[]){s->exp, s->xb, w->moegate[l]}, 3); 309 | } 310 | 311 | // ffn 312 | float* hb = p->n_experts ? s->he : s->hb; 313 | int n_experts_ac = p->n_experts_ac ? p->n_experts_ac : 1; 314 | 315 | dispatch2(encoder, p->act_gelu ? 
"ffn13_gelu" : "ffn13_silu", dvar, hidden_dim / matmul_par, n_experts_ac, 32 * matmul_par, 0, (int[]){dim, hidden_dim}, sizeof(int) * 2, (void*[]){hb, s->xb, s->exp, w->w1[l], w->w3[l]}, 5); 316 | dispatch2(encoder, "ffn2", dvar, dim / matmul_par, n_experts_ac, 32 * matmul_par, 0, (int[]){hidden_dim, dim}, sizeof(int) * 2, (void*[]){x, hb, s->exp, w->w2[l]}, 4); 317 | } 318 | 319 | // classifier into logits 320 | if ((flags & FF_UPDATE_KV_ONLY) == 0) { 321 | dispatch(encoder, "rmsnorm", nvar, 1, 1024, 0, &(struct NormArgs){dim, p->norm_eps, p->norm_ln}, sizeof(struct NormArgs), (void*[]){s->xb, x, w->rms_final_weight}, 3); 322 | dispatch(encoder, "output", dvar, p->vocab_size / matmul_par, 32 * matmul_par, 0, (int[]){dim}, sizeof(int), (void*[]){s->logits, s->xb, w->wcls}, 3); 323 | } 324 | 325 | // submit commands and wait 326 | [encoder endEncoding]; 327 | [commands commit]; 328 | [commands waitUntilCompleted]; 329 | 330 | if (commands.status != MTLCommandBufferStatusCompleted) { 331 | NSError* error = commands.error; 332 | fprintf(stderr, "Metal error %ld during command execution: %s\n", error.code, error.localizedDescription.UTF8String); 333 | abort(); 334 | } 335 | 336 | if (capture) { 337 | [capture stopCapture]; 338 | capture = nil; 339 | } 340 | 341 | if (flags & FF_UPDATE_KV_ONLY) { 342 | // only update kv cache and don't output logits 343 | return NULL; 344 | } 345 | 346 | return [(id)s->logits contents]; 347 | } -------------------------------------------------------------------------------- /src/infer.c: -------------------------------------------------------------------------------- 1 | #include "model.h" 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #ifdef _OPENMP 11 | #include 12 | #endif 13 | 14 | #if defined(__AVX2__) && defined(__F16C__) 15 | #include 16 | #endif 17 | 18 | // we only support CPU inference when the compiler supports _Float16 type natively 19 | #if defined(__FLT16_MANT_DIG__) 20 | typedef _Float16 half; 21 | #else 22 | typedef short half; 23 | #endif 24 | 25 | // we only support fp16 kv cache by default; this can be changed to float with a recompile 26 | typedef half kvtype_t; 27 | 28 | inline half fp82half(unsigned char v) { 29 | union { 30 | unsigned short u; 31 | half f; 32 | } u; 33 | u.u = v << 8; 34 | return u.f; 35 | } 36 | 37 | inline float gf4_ff(uint32_t v, int k) { 38 | float s = fp82half(v & 0xff) / -4.f; // we expect compiler to reuse this across multiple calls 39 | return ((int)((v >> (8 + k * 3)) & 7) - 4) * s; 40 | } 41 | 42 | typedef float (*dotprod_t)(void* w, int n, int i, float* x); 43 | 44 | static float dotprod_fp16(void* w, int n, int i, float* x) { 45 | half* r = (half*)w + i * n; 46 | #if defined(__AVX2__) && defined(__F16C__) 47 | assert(n % 16 == 0); 48 | __m256 acc0 = _mm256_setzero_ps(), acc1 = _mm256_setzero_ps(); 49 | for (int j = 0; j < n; j += 16) { 50 | __m256i rw = _mm256_loadu_si256((__m256i*)&r[j]); 51 | __m128i rlo = _mm256_castsi256_si128(rw); 52 | __m128i rhi = _mm256_extractf128_si256(rw, 1); 53 | __m256 x0 = _mm256_loadu_ps(&x[j]); 54 | __m256 x1 = _mm256_loadu_ps(&x[j + 8]); 55 | acc0 = _mm256_add_ps(_mm256_mul_ps(x0, _mm256_cvtph_ps(rlo)), acc0); 56 | acc1 = _mm256_add_ps(_mm256_mul_ps(x1, _mm256_cvtph_ps(rhi)), acc1); 57 | } 58 | __m256 acc8 = _mm256_add_ps(acc0, acc1); 59 | __m128 acc4 = _mm_add_ps(_mm256_castps256_ps128(acc8), _mm256_extractf128_ps(acc8, 1)); 60 | __m128 accf = _mm_dp_ps(acc4, _mm_set1_ps(1.0f), 0xf1); 61 | return _mm_cvtss_f32(accf); 62 | #else 63 | 
float val = 0.0f; 64 | #pragma omp simd reduction(+ : val) simdlen(32) 65 | for (int j = 0; j < n; j++) { 66 | val += r[j] * x[j]; 67 | } 68 | return val; 69 | #endif 70 | } 71 | 72 | static float dotprod_fp8(void* w, int n, int i, float* x) { 73 | char* r = (char*)w + i * n; 74 | #if defined(__AVX2__) && defined(__F16C__) 75 | assert(n % 16 == 0); 76 | __m256 acc0 = _mm256_setzero_ps(), acc1 = _mm256_setzero_ps(); 77 | for (int j = 0; j < n; j += 16) { 78 | __m128i rw = _mm_loadu_si128((__m128i*)&r[j]); 79 | __m128i rlo = _mm_unpacklo_epi8(_mm_setzero_si128(), rw); 80 | __m128i rhi = _mm_unpackhi_epi8(_mm_setzero_si128(), rw); 81 | __m256 x0 = _mm256_loadu_ps(&x[j]); 82 | __m256 x1 = _mm256_loadu_ps(&x[j + 8]); 83 | acc0 = _mm256_add_ps(_mm256_mul_ps(x0, _mm256_cvtph_ps(rlo)), acc0); 84 | acc1 = _mm256_add_ps(_mm256_mul_ps(x1, _mm256_cvtph_ps(rhi)), acc1); 85 | } 86 | __m256 acc8 = _mm256_add_ps(acc0, acc1); 87 | __m128 acc4 = _mm_add_ps(_mm256_castps256_ps128(acc8), _mm256_extractf128_ps(acc8, 1)); 88 | __m128 accf = _mm_dp_ps(acc4, _mm_set1_ps(1.0f), 0xf1); 89 | return _mm_cvtss_f32(accf); 90 | #else 91 | float val = 0.0f; 92 | #pragma omp simd reduction(+ : val) simdlen(32) 93 | for (int j = 0; j < n; j++) { 94 | val += fp82half(r[j]) * x[j]; 95 | } 96 | return val; 97 | #endif 98 | } 99 | 100 | static float dotprod_gf4(void* w, int n, int i, float* x) { 101 | uint32_t* r = (uint32_t*)w + i * n / 8; 102 | #if defined(__AVX2__) && defined(__F16C__) 103 | assert(n % 32 == 0); 104 | __m256 acc0 = _mm256_setzero_ps(), acc1 = _mm256_setzero_ps(); 105 | for (int j = 0; j < n; j += 32) { 106 | __m128i wg = _mm_loadu_si128((__m128i*)&r[j / 8]); 107 | const __m128i wgfm = _mm_setr_epi8(-1, 0, -1, 4, -1, 8, -1, 12, -1, -1, -1, -1, -1, -1, -1, -1); 108 | __m128 wgf = _mm_cvtph_ps(_mm_shuffle_epi8(wg, wgfm)); // note: scale 1/-4.f is baked into wgtab below 109 | __m256 x0 = _mm256_loadu_ps(&x[j]); 110 | __m256 x1 = _mm256_loadu_ps(&x[j + 8]); 111 | __m256 x2 = _mm256_loadu_ps(&x[j + 16]); 112 | __m256 x3 = _mm256_loadu_ps(&x[j + 24]); 113 | __m256i wgp = _mm256_broadcastsi128_si256(wg); 114 | __m256 wgfp = _mm256_castsi256_ps(_mm256_broadcastsi128_si256(_mm_castps_si128(wgf))); 115 | const __m256i wgbits = _mm256_setr_epi32(8, 11, 14, 17, 20, 23, 26, 29); 116 | const __m256 wgtab = _mm256_setr_ps(-4 / -4.f, -3 / -4.f, -2 / -4.f, -1 / -4.f, 0 / -4.f, 1 / -4.f, 2 / -4.f, 3 / -4.f); 117 | __m256 w0 = _mm256_permutevar8x32_ps(wgtab, _mm256_srlv_epi32(_mm256_shuffle_epi32(wgp, 0x00), wgbits)); 118 | __m256 w1 = _mm256_permutevar8x32_ps(wgtab, _mm256_srlv_epi32(_mm256_shuffle_epi32(wgp, 0x55), wgbits)); 119 | __m256 w2 = _mm256_permutevar8x32_ps(wgtab, _mm256_srlv_epi32(_mm256_shuffle_epi32(wgp, 0xaa), wgbits)); 120 | __m256 w3 = _mm256_permutevar8x32_ps(wgtab, _mm256_srlv_epi32(_mm256_shuffle_epi32(wgp, 0xff), wgbits)); 121 | acc0 = _mm256_add_ps(_mm256_mul_ps(w0, _mm256_mul_ps(x0, _mm256_shuffle_ps(wgfp, wgfp, 0x00))), acc0); 122 | acc1 = _mm256_add_ps(_mm256_mul_ps(w1, _mm256_mul_ps(x1, _mm256_shuffle_ps(wgfp, wgfp, 0x55))), acc1); 123 | acc0 = _mm256_add_ps(_mm256_mul_ps(w2, _mm256_mul_ps(x2, _mm256_shuffle_ps(wgfp, wgfp, 0xaa))), acc0); 124 | acc1 = _mm256_add_ps(_mm256_mul_ps(w3, _mm256_mul_ps(x3, _mm256_shuffle_ps(wgfp, wgfp, 0xff))), acc1); 125 | } 126 | __m256 acc8 = _mm256_add_ps(acc0, acc1); 127 | __m128 acc4 = _mm_add_ps(_mm256_castps256_ps128(acc8), _mm256_extractf128_ps(acc8, 1)); 128 | __m128 accf = _mm_dp_ps(acc4, _mm_set1_ps(1.0f), 0xf1); 129 | return _mm_cvtss_f32(accf); 130 | #else 
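	// portable scalar fallback: decode one 8-weight gf4 group at a time via gf4_ff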
131 | float val = 0.0f; 132 | for (int j = 0; j < n; j += 8) { 133 | uint32_t wg = r[j / 8]; 134 | for (int k = 0; k < 8; ++k) { 135 | val += gf4_ff(wg, k) * x[j + k]; 136 | } 137 | } 138 | return val; 139 | #endif 140 | } 141 | 142 | void prepare(struct Transformer* transformer) { 143 | struct Config* p = &transformer->config; 144 | struct RunState* s = &transformer->state; 145 | 146 | int q_dim = p->head_dim * p->n_heads; 147 | int kv_dim = p->head_dim * p->n_kv_heads; 148 | 149 | // we calloc instead of malloc to keep valgrind happy 150 | s->x = calloc(p->dim, sizeof(float)); 151 | s->xb = calloc(p->dim, sizeof(float)); 152 | s->xb2 = calloc(p->dim, sizeof(float)); 153 | s->hb = calloc(p->hidden_dim, sizeof(float)); 154 | s->hb2 = calloc(p->hidden_dim, sizeof(float)); 155 | s->q = calloc(q_dim, sizeof(float)); 156 | s->k = calloc(kv_dim, sizeof(float)); 157 | s->v = calloc(kv_dim, sizeof(float)); 158 | s->att = calloc(p->n_heads * p->seq_len, sizeof(float)); 159 | s->exp = calloc(p->n_experts + (p->n_experts_ac ? p->n_experts_ac : 1) * 2, sizeof(float)); 160 | s->logits = calloc(p->vocab_size, sizeof(float)); 161 | assert(s->kvbits == sizeof(kvtype_t) * 8); 162 | s->key_cache = calloc((size_t)p->n_layers * p->seq_len * kv_dim, sizeof(kvtype_t)); 163 | s->value_cache = calloc((size_t)p->n_layers * p->seq_len * kv_dim, sizeof(kvtype_t)); 164 | 165 | // ensure all mallocs went fine 166 | if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q || !s->key_cache || !s->value_cache || !s->att || !s->logits) { 167 | fprintf(stderr, "malloc failed!\n"); 168 | abort(); 169 | } 170 | 171 | #if defined(_OPENMP) && defined(__linux__) 172 | // avoid SMT overhead by default 173 | if (getenv("OMP_NUM_THREADS") == NULL) { 174 | omp_set_num_threads(omp_get_num_procs() / 2); 175 | } 176 | #endif 177 | 178 | #if !defined(__FLT16_MANT_DIG__) 179 | assert(!"_Float16 compiler support is required for CPU backend\n"); 180 | #endif 181 | } 182 | 183 | static void rmsnorm(float* o, float* x, float* weight, int size, float eps, bool ln) { 184 | // calculate mean 185 | float mean = 0.0f; 186 | 187 | if (ln) { 188 | for (int j = 0; j < size; j++) { 189 | mean += x[j]; 190 | } 191 | mean /= size; 192 | } 193 | 194 | // calculate sum of squared deltas 195 | float ss = 0.0f; 196 | for (int j = 0; j < size; j++) { 197 | ss += (x[j] - mean) * (x[j] - mean); 198 | } 199 | 200 | float var = ss / size; 201 | 202 | // normalize and scale 203 | float scale = 1.0f / sqrtf(var + eps); 204 | for (int j = 0; j < size; j++) { 205 | o[j] = (x[j] - mean) * scale * weight[j]; 206 | } 207 | } 208 | 209 | static void matmul(float* xout, float* x, void* w, float* b, int n, int d, dotprod_t dotprod) { 210 | // W (d,n) @ x (n,) -> xout (d,) 211 | // by far the most amount of time is spent inside this little function 212 | int i; 213 | #pragma omp parallel for private(i) 214 | for (i = 0; i < d; i++) { 215 | float val = dotprod(w, n, i, x); 216 | if (b) { 217 | val += b[i]; 218 | } 219 | xout[i] = val; 220 | } 221 | } 222 | 223 | static void rope(float* vec, int d, int head_dim, int pos, float theta, int rotary_dim) { 224 | for (int i = 0; i < d; i += 2) { 225 | int j_head = i % head_dim; 226 | float freq = j_head >= rotary_dim ? 
0.f : 1.0f / powf(theta, (float)j_head / (float)rotary_dim); 227 | float val = pos * freq; 228 | float fcr = cosf(val); 229 | float fci = sinf(val); 230 | 231 | float v0 = vec[i]; 232 | float v1 = vec[i + 1]; 233 | vec[i] = v0 * fcr - v1 * fci; 234 | vec[i + 1] = v0 * fci + v1 * fcr; 235 | } 236 | } 237 | 238 | static void attn(float* xout, float* atth, float* qh, kvtype_t* kh, kvtype_t* vh, int head_dim, int kv_dim, int kv_len) { 239 | float score_max = -FLT_MAX; 240 | 241 | // calculate attention scores as dot products of q and k; also track score max for this head 242 | for (int t = 0; t < kv_len; ++t) { 243 | float score = 0.0f; 244 | for (int j = 0; j < head_dim; ++j) { 245 | score += qh[j] * kh[t * kv_dim + j]; 246 | } 247 | score /= sqrtf(head_dim); 248 | score_max = (score_max < score) ? score : score_max; 249 | atth[t] = score; 250 | } 251 | 252 | // softmax the scores to get attention weights over [0..kv_len) 253 | float score_sum = 0.f; 254 | for (int t = 0; t < kv_len; ++t) { 255 | atth[t] = expf(atth[t] - score_max); 256 | score_sum += atth[t]; 257 | } 258 | 259 | // mix values with attention weights 260 | for (int j = 0; j < head_dim; ++j) { 261 | float res = 0.f; 262 | for (int t = 0; t < kv_len; ++t) { 263 | res += (atth[t] / score_sum) * vh[t * kv_dim + j]; 264 | } 265 | xout[j] = res; 266 | } 267 | } 268 | 269 | inline float gelu(float x) { 270 | return 0.5f * x * (1.0f + tanhf(0.797885f * (x + 0.044715f * x * x * x))); 271 | } 272 | 273 | inline float silu(float x) { 274 | return x / (1.0f + expf(-x)); 275 | } 276 | 277 | static void moe_gate(float* moe_weights, int* moe_experts, float* x, int d, int active) { 278 | // softmax across experts 279 | float max_val = -FLT_MAX; 280 | for (int j = 0; j < d; ++j) { 281 | max_val = (max_val < x[j]) ? x[j] : max_val; 282 | } 283 | 284 | // top k 285 | uint64_t mask = 0; 286 | float wsum = 0.0f; 287 | 288 | for (int k = 0; k < active; ++k) { 289 | int best = -1; 290 | for (int j = 0; j < d; ++j) { 291 | if ((mask & (1ull << j)) == 0 && (best == -1 || x[j] > x[best])) { 292 | best = j; 293 | } 294 | } 295 | 296 | moe_experts[k] = best; 297 | wsum += expf(x[moe_experts[k]] - max_val); 298 | mask |= 1ull << best; 299 | } 300 | 301 | // top k weights, normalized 302 | for (int k = 0; k < active; ++k) { 303 | moe_weights[k] = expf(x[moe_experts[k]] - max_val) / wsum; 304 | } 305 | } 306 | 307 | inline float clip(float x, float v) { 308 | return x < -v ? -v : (x > v ? v : x); 309 | } 310 | 311 | float* forward(struct Transformer* transformer, int token, int pos, unsigned flags) { 312 | if (transformer->weights.dbits != 4 && transformer->weights.dbits != 8 && transformer->weights.dbits != 16) { 313 | assert(!"Unsupported dbits: must be 8 or 16 for CPU"); 314 | } 315 | 316 | dotprod_t dotprod = transformer->weights.dbits == 4 ? dotprod_gf4 : (transformer->weights.dbits == 8 ? dotprod_fp8 : dotprod_fp16); 317 | 318 | // a few convenience variables 319 | struct Config* p = &transformer->config; 320 | struct Weights* w = &transformer->weights; 321 | struct RunState* s = &transformer->state; 322 | float* x = s->x; 323 | int dim = p->dim; 324 | int hidden_dim = p->hidden_dim; 325 | int q_dim = p->head_dim * p->n_heads; 326 | int kv_dim = p->head_dim * p->n_kv_heads; 327 | int kv_mul = p->n_heads / p->n_kv_heads; // integer multiplier of the kv sharing in multiquery 328 | 329 | // following "attention sinks" from StreamingLLM we keep the first few tokens in the KV cache as is 330 | int kv_sink = pos >= p->seq_len ? 
KV_SINKS : 0; 331 | int kv_pos = kv_sink + (pos - kv_sink) % (p->seq_len - kv_sink); 332 | int kv_len = pos >= p->seq_len ? p->seq_len : pos + 1; 333 | 334 | // copy the token embedding into x 335 | char* content_row = (char*)w->token_embedding_table + token * dim * (size_t)w->dbits / 8; 336 | if (w->dbits == 4) { 337 | for (int i = 0; i < dim; i += 8) { 338 | uint32_t wg = ((uint32_t*)content_row)[i / 8]; 339 | for (int k = 0; k < 8; ++k) { 340 | x[i + k] = gf4_ff(wg, k); 341 | } 342 | } 343 | } else { 344 | for (int i = 0; i < dim; ++i) { 345 | x[i] = w->dbits == 8 ? fp82half(content_row[i]) : ((half*)content_row)[i]; 346 | } 347 | } 348 | 349 | // forward all the layers 350 | for (int l = 0; l < p->n_layers; l++) { 351 | // attention rmsnorm 352 | rmsnorm(s->xb, x, w->rms_att_weight[l], dim, p->norm_eps, p->norm_ln); 353 | 354 | // key and value point to the kv cache 355 | size_t loff = (size_t)l * p->seq_len * kv_dim; // kv cache layer offset for convenience 356 | kvtype_t* kb = (kvtype_t*)s->key_cache + loff; 357 | kvtype_t* vb = (kvtype_t*)s->value_cache + loff; 358 | 359 | // qkv matmuls for this position 360 | matmul(s->q, s->xb, w->wq[l], w->bqkv[l], dim, q_dim, dotprod); 361 | matmul(s->k, s->xb, w->wk[l], w->bqkv[l] ? w->bqkv[l] + q_dim : NULL, dim, kv_dim, dotprod); 362 | matmul(s->v, s->xb, w->wv[l], w->bqkv[l] ? w->bqkv[l] + q_dim + kv_dim : NULL, dim, kv_dim, dotprod); 363 | 364 | // some models require clipping qkv values 365 | for (int i = 0; i < q_dim; i++) { 366 | s->q[i] = clip(s->q[i], p->qkv_clip); 367 | } 368 | for (int i = 0; i < kv_dim; i++) { 369 | s->k[i] = clip(s->k[i], p->qkv_clip); 370 | s->v[i] = clip(s->v[i], p->qkv_clip); 371 | } 372 | 373 | // RoPE relative positional encoding: complex-valued rotate q and k in each head 374 | rope(s->q, q_dim, p->head_dim, pos, p->rope_theta, p->rotary_dim); 375 | rope(s->k, kv_dim, p->head_dim, pos, p->rope_theta, p->rotary_dim); 376 | 377 | // update kv cache 378 | for (int i = 0; i < kv_dim; i++) { 379 | kb[kv_pos * kv_dim + i] = s->k[i]; 380 | vb[kv_pos * kv_dim + i] = s->v[i]; 381 | } 382 | 383 | // rotate sink tokens forward to keep pace with non-sink tokens 384 | for (int r = 0; r < kv_sink; r++) { 385 | for (int i = 0; i < kv_dim; i++) { 386 | s->k[i] = kb[r * kv_dim + i]; 387 | } 388 | 389 | rope(s->k, kv_dim, p->head_dim, 1, p->rope_theta, p->rotary_dim); 390 | 391 | for (int i = 0; i < kv_dim; i++) { 392 | kb[r * kv_dim + i] = s->k[i]; 393 | } 394 | } 395 | 396 | // multihead attention. iterate over all heads 397 | int h; 398 | #pragma omp parallel for private(h) 399 | for (h = 0; h < p->n_heads; h++) { 400 | float* qh = s->q + h * p->head_dim; 401 | float* atth = s->att + h * p->seq_len; 402 | kvtype_t* kh = kb + (h / kv_mul) * p->head_dim; 403 | kvtype_t* vh = vb + (h / kv_mul) * p->head_dim; 404 | 405 | attn(s->xb2 + h * p->head_dim, atth, qh, kh, vh, p->head_dim, kv_dim, kv_len); 406 | } 407 | 408 | // final matmul to get the output of the attention 409 | // TODO: we're using hb as a temporary storage, hacky 410 | matmul(s->hb, s->xb2, w->wo[l], NULL, q_dim, dim, dotprod); 411 | 412 | // residual connection back into x 413 | for (int i = 0; i < dim; i++) { 414 | x[i] += s->hb[i]; 415 | } 416 | 417 | if (!p->norm_par) { 418 | // ffn rmsnorm 419 | rmsnorm(s->xb, x, w->rms_ffn_weight[l], dim, p->norm_eps, p->norm_ln); 420 | } 421 | 422 | float* moe_weights = s->exp + p->n_experts; 423 | int* moe_experts = (int*)moe_weights + (p->n_experts_ac ? 
p->n_experts_ac : 1); 424 | 425 | if (p->n_experts) { 426 | // moe gate 427 | matmul(s->exp, s->xb, w->moegate[l], NULL, dim, p->n_experts, dotprod); 428 | moe_gate(moe_weights, moe_experts, s->exp, p->n_experts, p->n_experts_ac); 429 | } else { 430 | moe_weights[0] = 1.0f; 431 | moe_experts[0] = 0; 432 | } 433 | 434 | // mix self.w2(F.silu(self.w1(x)) * self.w3(x)) 435 | for (int e = 0; e < (p->n_experts_ac ? p->n_experts_ac : 1); ++e) { 436 | size_t esize = dim * hidden_dim * (size_t)w->dbits / 8; 437 | matmul(s->hb, s->xb, (char*)w->w1[l] + moe_experts[e] * esize, NULL, dim, hidden_dim, dotprod); 438 | matmul(s->hb2, s->xb, (char*)w->w3[l] + moe_experts[e] * esize, NULL, dim, hidden_dim, dotprod); 439 | 440 | if (p->act_gelu) { 441 | // GEGLU non-linearity 442 | for (int i = 0; i < hidden_dim; i++) { 443 | s->hb[i] = gelu(s->hb[i]) * s->hb2[i]; 444 | } 445 | } else { 446 | // SwiGLU non-linearity 447 | for (int i = 0; i < hidden_dim; i++) { 448 | s->hb[i] = silu(s->hb[i]) * s->hb2[i]; 449 | } 450 | } 451 | 452 | matmul(s->xb2, s->hb, (char*)w->w2[l] + moe_experts[e] * esize, NULL, hidden_dim, dim, dotprod); 453 | 454 | for (int i = 0; i < dim; i++) { 455 | x[i] += s->xb2[i] * moe_weights[e]; 456 | } 457 | } 458 | } 459 | 460 | if (flags & FF_UPDATE_KV_ONLY) { 461 | // only update kv cache and don't output logits 462 | return NULL; 463 | } 464 | 465 | // final rmsnorm 466 | rmsnorm(x, x, w->rms_final_weight, dim, p->norm_eps, p->norm_ln); 467 | 468 | // classifier into logits 469 | matmul(s->logits, x, w->wcls, NULL, p->dim, p->vocab_size, dotprod); 470 | 471 | return s->logits; 472 | } 473 | -------------------------------------------------------------------------------- /src/infer.metal: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | using namespace metal; 5 | 6 | template struct ActT { typedef float type; }; 7 | template <> struct ActT { typedef half type; }; 8 | template <> struct ActT { typedef half type; }; 9 | template using Act = typename ActT::type; 10 | 11 | static constant int warpSize = 32; 12 | 13 | inline float blockreduce_sum(threadgroup float* vs, float val, uint id) { 14 | val = simd_sum(val); 15 | 16 | vs[id / warpSize] = val; 17 | threadgroup_barrier(mem_flags::mem_threadgroup); 18 | 19 | return simd_sum(vs[id % warpSize]); 20 | } 21 | 22 | inline float blockreduce_max(threadgroup float* vs, float val, uint id) { 23 | val = simd_max(val); 24 | 25 | vs[id / warpSize] = val; 26 | threadgroup_barrier(mem_flags::mem_threadgroup); 27 | 28 | return simd_max(vs[id % warpSize]); 29 | } 30 | 31 | inline half gf4_ff(uint32_t v, int k) { 32 | half s = as_type(uint16_t(v << 8)) * half(-0.25f); // we expect compiler to reuse this across multiple calls 33 | return half((int((v >> (8 + k * 3)) & 7) ^ 4) - 4) * s; 34 | } 35 | 36 | inline float matmul_warppar(device float* x, device half* w, int i, int n, uint id) { 37 | int lane = id % warpSize; 38 | float val = 0.0f; 39 | for (int j = lane * 2; j < n; j += warpSize * 2) { 40 | float2 ww = float2(*(device half2*)&w[i * n + j]); 41 | float2 xx = *(device float2*)&x[j]; 42 | val += ww.x * xx.x; 43 | val += ww.y * xx.y; 44 | } 45 | return simd_sum(val); 46 | } 47 | 48 | template 49 | inline float matmul_warppar(device AT* x, device uint8_t* w, int i, int n, uint id) { 50 | typedef __attribute__((__ext_vector_type__(4))) AT AT4; 51 | 52 | int lane = id % warpSize; 53 | float val = 0.0f; 54 | for (int j = lane * 8; j < n; j += warpSize * 8) { 55 | uint2 wwp = 
*(device uint2*)&w[i * n + j]; 56 | AT4 xxp[2] = {*(device AT4*)&x[j], *(device AT4*)&x[j + 4]}; 57 | for (int k = 0; k < 2; ++k) { 58 | half2 wwe = as_type(wwp[k] & 0xff00ff00); 59 | half2 wwo = as_type((wwp[k] << 8) & 0xff00ff00); 60 | val += wwo.x * xxp[k].x; 61 | val += wwe.x * xxp[k].y; 62 | val += wwo.y * xxp[k].z; 63 | val += wwe.y * xxp[k].w; 64 | } 65 | } 66 | return simd_sum(val); 67 | } 68 | 69 | template 70 | inline float matmul_warppar(device AT* x, device uint32_t* w, int i, int n, uint id) { 71 | typedef __attribute__((__ext_vector_type__(4))) AT AT4; 72 | 73 | int lane = id % warpSize; 74 | float val = 0.0f; 75 | for (int j = lane * 8; j < n; j += warpSize * 8) { 76 | uint32_t wg = w[i * n / 8 + j / 8]; 77 | AT4 xx0 = *(device AT4*)&x[j]; 78 | AT4 xx1 = *(device AT4*)&x[j + 4]; 79 | 80 | int wgi = ((wg & 0xfff00000) | ((wg >> 4) & 0xfff0)); 81 | 82 | float us = as_type(uint16_t(wg << 8)); 83 | float s = us * -0.25f * exp2(-13.f); 84 | 85 | float acc = 0; 86 | for (int k = 0; k < 4; ++k) { 87 | int wgk = wgi << (9 - k * 3); 88 | if (k != 0) wgk &= 0xe000e000; 89 | short2 wgkp = as_type(wgk); 90 | acc += float(wgkp.x) * xx0[k]; 91 | acc += float(wgkp.y) * xx1[k]; 92 | } 93 | val += acc * s; 94 | } 95 | return simd_sum(val); 96 | } 97 | 98 | kernel void prepare_gf4(constant int& n [[buffer(0)]], device uint32_t* data [[buffer(1)]], uint id [[thread_position_in_grid]]) { 99 | uint32_t wg = data[id]; 100 | wg ^= 0x92492400; 101 | data[id] = wg; 102 | } 103 | 104 | inline float gelu(float x) { 105 | return 0.5f * x * (1.0f + precise::tanh(0.797885f * (x + 0.044715f * x * x * x))); 106 | } 107 | 108 | inline float silu(float x) { 109 | return x / (1.0f + exp(-x)); 110 | } 111 | 112 | inline float embed(device half* w, int i) { 113 | return w[i]; 114 | } 115 | 116 | inline float embed(device uint8_t* w, int i) { 117 | return as_type(uint16_t(w[i] << 8)); 118 | } 119 | 120 | inline float embed(device uint32_t* w, int i) { 121 | return gf4_ff(w[i / 8], i % 8); 122 | } 123 | 124 | template 125 | kernel void kernel_embed(constant int& token_offset [[buffer(0)]], device float* o [[buffer(1)]], device T* weight [[buffer(2)]], uint id [[thread_position_in_grid]]) { 126 | o[id] = embed(weight, id + token_offset); 127 | } 128 | 129 | template [[host_name("embed_half")]] kernel void kernel_embed(constant int&, device float*, device half*, uint); 130 | template [[host_name("embed_fp8")]] kernel void kernel_embed(constant int&, device float*, device uint8_t*, uint); 131 | template [[host_name("embed_gf4")]] kernel void kernel_embed(constant int&, device float*, device uint32_t*, uint); 132 | 133 | struct SinkArgs { 134 | int kv_dim; 135 | int head_dim; 136 | int rotary_dim; 137 | 138 | int kv_sink; 139 | int seq_len; 140 | 141 | float theta_log2; 142 | }; 143 | 144 | template 145 | kernel void kernel_rotate_sink(constant SinkArgs& args [[buffer(0)]], device KVT* keyc [[buffer(1)]], uint id [[thread_position_in_grid]]) { 146 | int i = (id * 2) % (args.kv_sink * args.kv_dim); 147 | int l = id / (args.kv_sink * args.kv_dim / 2); 148 | 149 | int j_head = i % args.head_dim; 150 | float freq = j_head >= args.rotary_dim ? 
0.f : exp2(-args.theta_log2 * (float)j_head / (float)args.rotary_dim); 151 | 152 | // rotate sink tokens forward to keep pace with non-sink tokens 153 | float fcr; 154 | float fci = sincos(freq, fcr); 155 | 156 | size_t loff = (size_t)l * args.seq_len * args.kv_dim; 157 | device KVT* kb = keyc + loff; 158 | 159 | // note: k layout is transposed / tiled to improve attn_score performance 160 | int t = i / args.kv_dim; 161 | int k = i % args.kv_dim; 162 | int o = t * 16 + args.seq_len * (k / 16) * 16 + (k % 16); 163 | 164 | float v0 = float(kb[o + 0]); 165 | float v1 = float(kb[o + 1]); 166 | 167 | float r0 = v0 * fcr - v1 * fci; 168 | float r1 = v0 * fci + v1 * fcr; 169 | 170 | kb[o + 0] = KVT(r0); 171 | kb[o + 1] = KVT(r1); 172 | } 173 | 174 | template [[host_name("rotate_sink_half")]] kernel void kernel_rotate_sink(constant SinkArgs&, device half*, uint); 175 | 176 | struct NormArgs { 177 | int size; 178 | float eps; 179 | bool ln; 180 | }; 181 | 182 | template 183 | kernel void kernel_rmsnorm(constant NormArgs& args [[buffer(0)]], device T* o [[buffer(1)]], device float* x [[buffer(2)]], device float* weight [[buffer(3)]], uint id [[thread_position_in_grid]]) { 184 | int i = id; 185 | const int blockSize = 1024; 186 | int size = args.size; 187 | 188 | threadgroup float vs[32]; 189 | 190 | float mean = 0.0f; 191 | if (args.ln) { 192 | // calculate sum (per thread) 193 | float sum = 0.0f; 194 | for (int j = i; j < size; j += blockSize) { 195 | sum += x[j]; 196 | } 197 | 198 | // sum across threads in block 199 | mean = blockreduce_sum(vs, sum, i) / size; 200 | } 201 | 202 | // calculate sum of squares (per thread) 203 | float ss = 0.0f; 204 | for (int j = i; j < size; j += blockSize) { 205 | float v = x[j] - mean; 206 | ss += v * v; 207 | } 208 | 209 | // sum across threads in block 210 | ss = blockreduce_sum(vs, ss, i); 211 | 212 | float scale = rsqrt(ss / size + args.eps); 213 | 214 | for (int j = i; j < size; j += blockSize) { 215 | o[j] = (x[j] - mean) * weight[j] * scale; 216 | } 217 | } 218 | 219 | template [[host_name("rmsnorm_float")]] kernel void kernel_rmsnorm(constant NormArgs&, device float*, device float*, device float*, uint); 220 | template [[host_name("rmsnorm_half")]] kernel void kernel_rmsnorm(constant NormArgs&, device half*, device float*, device float*, uint); 221 | 222 | struct QkvArgs { 223 | int dim; 224 | int q_dim; 225 | int kv_dim; 226 | int head_dim; 227 | int rotary_dim; 228 | 229 | int pos; 230 | int kv_pos; 231 | int seq_len; 232 | 233 | size_t loff; 234 | 235 | float qkv_clip; 236 | float theta_log2; 237 | }; 238 | 239 | template 240 | kernel void kernel_qkv(constant QkvArgs& args [[buffer(0)]], device Act* x [[buffer(1)]], device float* qout [[buffer(2)]], device KVT* keyc [[buffer(3)]], device KVT* valc [[buffer(4)]], device T* wq [[buffer(5)]], device T* wk [[buffer(6)]], device T* wv [[buffer(7)]], device float* bqkv [[buffer(8)]], uint id [[thread_position_in_grid]]) { 241 | int dim = args.dim; 242 | int q_dim = args.q_dim; 243 | int kv_dim = args.kv_dim; 244 | 245 | int j = (id / warpSize) * 2; 246 | device T* w = j < q_dim ? wq : (j < q_dim + kv_dim ? wk : wv); 247 | int k = j < q_dim ? j : (j < q_dim + kv_dim ? 
j - q_dim : j - q_dim - kv_dim); 248 | 249 | float v0 = matmul_warppar(x, w, k + 0, dim, id); 250 | float v1 = matmul_warppar(x, w, k + 1, dim, id); 251 | 252 | v0 += bqkv[j + 0]; 253 | v1 += bqkv[j + 1]; 254 | 255 | v0 = min(max(v0, -args.qkv_clip), args.qkv_clip); 256 | v1 = min(max(v1, -args.qkv_clip), args.qkv_clip); 257 | 258 | if (id % warpSize == 0) { 259 | int j_head = j % args.head_dim; 260 | float freq = j_head >= args.rotary_dim ? 0.f : exp2(-args.theta_log2 * (float)j_head / (float)args.rotary_dim); 261 | float fcr; 262 | float fci = sincos(args.pos * freq, fcr); 263 | 264 | if (j < q_dim) { 265 | qout[k + 0] = v0 * fcr - v1 * fci; 266 | qout[k + 1] = v0 * fci + v1 * fcr; 267 | } else if (j < q_dim + kv_dim) { 268 | // note: k layout is transposed / tiled to improve attn_score performance 269 | int off = args.kv_pos * 16 + args.seq_len * (k / 16) * 16 + (k % 16); 270 | keyc[args.loff + off + 0] = KVT(v0 * fcr - v1 * fci); 271 | keyc[args.loff + off + 1] = KVT(v0 * fci + v1 * fcr); 272 | } else { 273 | // note: v layout is transposed (we store all positions for a given head contiguously) to improve attn_mix performance 274 | valc[args.loff + args.kv_pos + args.seq_len * (k + 0)] = KVT(v0); 275 | valc[args.loff + args.kv_pos + args.seq_len * (k + 1)] = KVT(v1); 276 | } 277 | } 278 | } 279 | 280 | template [[host_name("qkv_half_half")]] kernel void kernel_qkv(constant QkvArgs&, device float*, device float*, device half*, device half*, device half*, device half*, device half*, device float*, uint); 281 | template [[host_name("qkv_fp8_half")]] kernel void kernel_qkv(constant QkvArgs&, device half*, device float*, device half*, device half*, device uint8_t*, device uint8_t*, device uint8_t*, device float*, uint); 282 | template [[host_name("qkv_gf4_half")]] kernel void kernel_qkv(constant QkvArgs&, device half*, device float*, device half*, device half*, device uint32_t*, device uint32_t*, device uint32_t*, device float*, uint); 283 | 284 | inline float4 attn_load4(device half* p) { 285 | return float4(*(device half4*)p); 286 | } 287 | 288 | template 289 | inline float attn_score(device KVT* kht, device float* qh, int head_dim, int seq_len, int t, int off) { 290 | float score = 0.0f; 291 | for (int j = 0; j < head_dim; j += 16) { 292 | float4 kk = attn_load4(&kht[j * seq_len + t * 16 + off]); 293 | float4 qq = *(device float4*)&qh[j + off]; 294 | score += kk.x * qq.x; 295 | score += kk.y * qq.y; 296 | score += kk.z * qq.z; 297 | score += kk.w * qq.w; 298 | } 299 | 300 | return score; 301 | } 302 | 303 | template 304 | inline float attn_warpdot(device KVT* val, device float* atth, int kv_len, uint id) { 305 | int kv_len4 = kv_len & ~3; 306 | int lane = id % warpSize; 307 | 308 | float res = 0.0f; 309 | float sum = 0.0f; 310 | for (int t = lane * 4; t < kv_len4; t += warpSize * 4) { 311 | float4 vv = attn_load4(&val[t]); 312 | float4 aa = *(device float4*)&atth[t]; 313 | res += vv.x * aa.x; 314 | res += vv.y * aa.y; 315 | res += vv.z * aa.z; 316 | res += vv.w * aa.w; 317 | sum += aa.x + aa.y + aa.z + aa.w; 318 | } 319 | 320 | if (kv_len4 + lane < kv_len) { 321 | float a = atth[kv_len4 + lane]; 322 | res += a * float(val[kv_len4 + lane]); 323 | sum += a; 324 | } 325 | 326 | res = simd_sum(res); 327 | sum = simd_sum(sum); 328 | 329 | return res / sum; 330 | } 331 | 332 | struct AttnArgs { 333 | int seq_len; 334 | int kv_len; 335 | int head_dim; 336 | int kv_mul; 337 | int n_heads; 338 | 339 | size_t loff; 340 | }; 341 | 342 | template 343 | kernel void kernel_attn_score(constant 
AttnArgs& args [[buffer(0)]], device float* att [[buffer(1)]], device float* q [[buffer(2)]], device KVT* keyc [[buffer(3)]], uint id [[thread_position_in_grid]]) { 344 | int j = id / warpSize; 345 | 346 | int h = j % args.n_heads; 347 | int kvh = h / args.kv_mul; 348 | int t = (j / args.n_heads) * 8 + (id % warpSize) / 4; 349 | 350 | if (t < args.kv_len) { 351 | device float* qh = q + h * args.head_dim; 352 | device KVT* kh = keyc + args.loff + kvh * args.head_dim * args.seq_len; 353 | device float* atth = att + h * args.seq_len; 354 | 355 | float score = attn_score(kh, qh, args.head_dim, args.seq_len, t, 4 * (id % 4)); 356 | 357 | // reduce score across threads in warp; every 4 threads are processing the same output score 358 | score += simd_shuffle_xor(score, 2); 359 | score += simd_shuffle_xor(score, 1); 360 | score /= sqrt(float(args.head_dim)); 361 | 362 | atth[t] = score; 363 | } 364 | } 365 | 366 | template [[host_name("attn_score_half")]] kernel void kernel_attn_score(constant AttnArgs&, device float*, device float*, device half*, uint); 367 | 368 | [[host_name("attn_softmax")]] kernel void kernel_attn_softmax(constant AttnArgs& args [[buffer(0)]], device float* att [[buffer(1)]], uint id [[thread_position_in_grid]]) { 369 | const int blockSize = 1024; 370 | int h = id / blockSize; 371 | device float* atth = att + h * args.seq_len; 372 | 373 | int i = id % blockSize; 374 | int size = args.kv_len; 375 | device float* x = atth; 376 | 377 | threadgroup float vs[32]; 378 | 379 | // find max value per thread (for numerical stability) 380 | float max_val = -FLT_MAX; 381 | for (int j = i; j < size; j += blockSize) { 382 | max_val = max(max_val, x[j]); 383 | } 384 | 385 | // max across threads in block 386 | max_val = blockreduce_max(vs, max_val, i); 387 | 388 | // exp per thread 389 | for (int j = i; j < size; j += blockSize) { 390 | x[j] = exp(x[j] - max_val); 391 | } 392 | } 393 | 394 | template 395 | kernel void kernel_attn_mix(constant AttnArgs& args [[buffer(0)]], device AT* qout [[buffer(1)]], device float* att [[buffer(2)]], device KVT* valc [[buffer(3)]], uint id [[thread_position_in_grid]]) { 396 | int j = id / warpSize; 397 | 398 | int h = j / args.head_dim; 399 | int kvh = h / args.kv_mul; 400 | int j_head = j % args.head_dim; 401 | 402 | device float* atth = att + h * args.seq_len; 403 | device KVT* vh = valc + args.loff + kvh * args.head_dim * args.seq_len; 404 | device KVT* val = vh + j_head * args.seq_len; 405 | 406 | float res = attn_warpdot(val, atth, args.kv_len, id); 407 | 408 | if (id % warpSize == 0) { 409 | qout[j] = res; 410 | } 411 | } 412 | 413 | template [[host_name("attn_mix_half_float")]] kernel void kernel_attn_mix(constant AttnArgs&, device float*, device float*, device half*, uint); 414 | template [[host_name("attn_mix_half_half")]] kernel void kernel_attn_mix(constant AttnArgs&, device half*, device float*, device half*, uint); 415 | 416 | template 417 | kernel void kernel_attn_out(constant int& n [[buffer(0)]], device float* xout [[buffer(1)]], device Act* x [[buffer(2)]], device T* w [[buffer(3)]], uint id [[thread_position_in_grid]]) { 418 | int j = id / warpSize; 419 | float val = matmul_warppar(x, w, j, n, id); 420 | 421 | if (id % warpSize == 0) { 422 | xout[j] += val; 423 | } 424 | } 425 | 426 | template [[host_name("attn_out_half")]] kernel void kernel_attn_out(constant int&, device float*, device float*, device half*, uint); 427 | template [[host_name("attn_out_fp8")]] kernel void kernel_attn_out(constant int&, device float*, device half*, device 
uint8_t*, uint); 428 | template [[host_name("attn_out_gf4")]] kernel void kernel_attn_out(constant int&, device float*, device half*, device uint32_t*, uint); 429 | 430 | inline void moe_gate_warp(device float* moebuf, threadgroup float* weights, int experts, int active, uint id) { 431 | int i = id; 432 | 433 | // (unscaled) softmax across experts 434 | float w = (i < experts) ? weights[i] : -FLT_MAX; 435 | float max_val = simd_max(w); 436 | w = exp(w - max_val); 437 | 438 | // weight in top 24 bits, index in bottom 8 439 | int wi = (as_type(w) & 0xffffff00) | i; 440 | 441 | // top k within warp 442 | float sumw = 0.f; 443 | int acti = -1; 444 | 445 | for (int k = 0; k < active; ++k) { 446 | int maxi = simd_max(wi); 447 | 448 | sumw += as_type(maxi); 449 | 450 | // keeps top weight in thread k, clears weight for thread with max thread to avoid re-selection 451 | acti = (i == k) ? maxi : acti; 452 | wi = (wi == maxi) ? 0 : wi; 453 | } 454 | 455 | // write normalized weights 456 | if (i < active) { 457 | assert(acti >= 0); 458 | 459 | moebuf[i * 2 + 0] = as_type(acti) / sumw; 460 | *(device int*)&moebuf[i * 2 + 1] = acti & 0xff; 461 | } 462 | } 463 | 464 | template 465 | kernel void kernel_moe_gate(constant int* args [[buffer(0)]], device float* moebuf [[buffer(1)]], device Act* x [[buffer(2)]], device T* w [[buffer(3)]], uint id [[thread_position_in_grid]]) { 466 | int n = args[0]; 467 | int experts = args[1]; 468 | int active = args[2]; 469 | 470 | int j = id / warpSize; 471 | float v = matmul_warppar(x, w, j, n, id); 472 | 473 | threadgroup float ws[32]; 474 | ws[j] = v; 475 | threadgroup_barrier(mem_flags::mem_threadgroup); 476 | 477 | if (id < warpSize) { 478 | moe_gate_warp(moebuf, ws, experts, active, id); 479 | } 480 | } 481 | 482 | template [[host_name("moe_gate_half")]] kernel void kernel_moe_gate(constant int*, device float*, device float*, device half*, uint); 483 | template [[host_name("moe_gate_fp8")]] kernel void kernel_moe_gate(constant int*, device float*, device half*, device uint8_t*, uint); 484 | template [[host_name("moe_gate_gf4")]] kernel void kernel_moe_gate(constant int*, device float*, device half*, device uint32_t*, uint); 485 | 486 | template 487 | kernel void kernel_ffn13(constant int* args [[buffer(0)]], device Act* xout [[buffer(1)]], device Act* x [[buffer(2)]], device float* moebuf [[buffer(3)]], device T* w1 [[buffer(4)]], device T* w3 [[buffer(5)]], uint2 id [[thread_position_in_grid]]) { 488 | int n = args[0]; 489 | int d = args[1]; 490 | 491 | int j = id.x / warpSize; 492 | int e = id.y; 493 | int je = j + ((device int*)moebuf)[e * 2 + 1] * d; 494 | 495 | float v1 = matmul_warppar(x, w1, je, n, id.x); 496 | float v3 = matmul_warppar(x, w3, je, n, id.x); 497 | 498 | if (id.x % warpSize == 0) { 499 | xout[j + e * d] = (act_gelu ? 
gelu(v1) : silu(v1)) * v3; 500 | } 501 | } 502 | 503 | template [[host_name("ffn13_silu_half")]] kernel void kernel_ffn13(constant int*, device float*, device float*, device float*, device half*, device half*, uint2); 504 | template [[host_name("ffn13_silu_fp8")]] kernel void kernel_ffn13(constant int*, device half*, device half*, device float*, device uint8_t*, device uint8_t*, uint2); 505 | template [[host_name("ffn13_silu_gf4")]] kernel void kernel_ffn13(constant int*, device half*, device half*, device float*, device uint32_t*, device uint32_t*, uint2); 506 | 507 | template [[host_name("ffn13_gelu_half")]] kernel void kernel_ffn13(constant int*, device float*, device float*, device float*, device half*, device half*, uint2); 508 | template [[host_name("ffn13_gelu_fp8")]] kernel void kernel_ffn13(constant int*, device half*, device half*, device float*, device uint8_t*, device uint8_t*, uint2); 509 | template [[host_name("ffn13_gelu_gf4")]] kernel void kernel_ffn13(constant int*, device half*, device half*, device float*, device uint32_t*, device uint32_t*, uint2); 510 | 511 | template 512 | kernel void kernel_ffn2(constant int* args [[buffer(0)]], device float* xout [[buffer(1)]], device Act* x [[buffer(2)]], device float* moebuf [[buffer(3)]], device T* w2 [[buffer(4)]], uint2 id [[thread_position_in_grid]]) { 513 | int n = args[0]; 514 | int d = args[1]; 515 | 516 | int j = id.x / warpSize; 517 | int e = id.y; 518 | int je = j + ((device int*)moebuf)[e * 2 + 1] * d; 519 | 520 | float val = matmul_warppar(x + e * n, w2, je, n, id.x); 521 | 522 | if (id.x % warpSize == 0) { 523 | atomic_fetch_add_explicit((volatile device atomic_float*)&xout[j], val * moebuf[e * 2 + 0], memory_order_relaxed); 524 | } 525 | } 526 | 527 | template [[host_name("ffn2_half")]] kernel void kernel_ffn2(constant int*, device float*, device float*, device float*, device half*, uint2); 528 | template [[host_name("ffn2_fp8")]] kernel void kernel_ffn2(constant int*, device float*, device half*, device float*, device uint8_t*, uint2); 529 | template [[host_name("ffn2_gf4")]] kernel void kernel_ffn2(constant int*, device float*, device half*, device float*, device uint32_t*, uint2); 530 | 531 | template 532 | kernel void kernel_output(constant int& n [[buffer(0)]], device float* xout [[buffer(1)]], device Act* x [[buffer(2)]], device T* w [[buffer(3)]], uint id [[thread_position_in_grid]]) { 533 | int j = id / warpSize; 534 | float val = matmul_warppar(x, w, j, n, id); 535 | 536 | if (id % warpSize == 0) { 537 | xout[j] = val; 538 | } 539 | } 540 | 541 | template [[host_name("output_half")]] kernel void kernel_output(constant int&, device float*, device float*, device half*, uint); 542 | template [[host_name("output_fp8")]] kernel void kernel_output(constant int&, device float*, device half*, device uint8_t*, uint); 543 | template [[host_name("output_gf4")]] kernel void kernel_output(constant int&, device float*, device half*, device uint32_t*, uint); 544 | -------------------------------------------------------------------------------- /src/run.c: -------------------------------------------------------------------------------- 1 | // Inference for Llama-2 Transformer model in pure C 2 | // Based on llama2.c by Andrej Karpathy 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include "model.h" 15 | #include "sampler.h" 16 | #include "tensors.h" 17 | #include "tokenizer.h" 18 | 19 | void prepare(struct Transformer* transformer); 20 | 
float* forward(struct Transformer* transformer, int token, int pos, unsigned flags); 21 | 22 | void* upload_cuda(void* host, size_t size); 23 | void prepare_cuda(struct Transformer* transformer); 24 | float* forward_cuda(struct Transformer* transformer, int token, int pos, unsigned flags); 25 | void perf_cuda(void); 26 | 27 | void init_metal(void); 28 | void* upload_metal(void* host, size_t size); 29 | void prepare_metal(struct Transformer* transformer); 30 | float* forward_metal(struct Transformer* transformer, int token, int pos, unsigned flags); 31 | 32 | void get_config(struct Config* config, struct Tensors* tensors, int context) { 33 | config->dim = atoi(tensors_metadata(tensors, "dim")); 34 | config->hidden_dim = atoi(tensors_metadata(tensors, "hidden_dim")); 35 | config->n_layers = atoi(tensors_metadata(tensors, "n_layers")); 36 | config->n_heads = atoi(tensors_metadata(tensors, "n_heads")); 37 | config->n_kv_heads = atoi(tensors_metadata(tensors, "n_kv_heads")); 38 | config->vocab_size = atoi(tensors_metadata(tensors, "vocab_size")); 39 | config->head_dim = atoi(tensors_metadata(tensors, "head_dim")); 40 | 41 | // for now limit seq_len to 4096 to avoid KV cache OOM for models like Mistral since window size isn't correctly specified 42 | const char* max_seq_len = tensors_metadata_find(tensors, "max_seq_len"); 43 | config->seq_len = max_seq_len && atoi(max_seq_len) < 4096 ? atoi(max_seq_len) : 4096; 44 | 45 | if (context) { 46 | config->seq_len = context; 47 | } 48 | 49 | config->rope_theta = atof(tensors_metadata(tensors, "rope_theta")); 50 | config->rotary_dim = atoi(tensors_metadata(tensors, "rotary_dim")); 51 | 52 | if (tensors_metadata_find(tensors, "n_experts")) { 53 | config->n_experts = atoi(tensors_metadata(tensors, "n_experts")); 54 | config->n_experts_ac = atoi(tensors_metadata(tensors, "n_experts_active")); 55 | } 56 | 57 | const char* norm_eps = tensors_metadata_find(tensors, "norm_eps"); 58 | config->norm_eps = norm_eps ? atof(norm_eps) : 1e-5; 59 | 60 | const char* act_type = tensors_metadata_find(tensors, "act_type"); 61 | config->act_gelu = act_type && strcmp(act_type, "gelu") == 0; 62 | 63 | const char* norm_type = tensors_metadata_find(tensors, "norm_type"); 64 | config->norm_ln = norm_type && strncmp(norm_type, "layernorm", 9) == 0; // note: we currently don't support layernorm bias 65 | config->norm_par = norm_type && strcmp(norm_type, "layernorm_par") == 0; // note: we currently don't support layernorm bias 66 | 67 | const char* qkv_clip = tensors_metadata_find(tensors, "qkv_clip"); 68 | config->qkv_clip = qkv_clip ? atof(qkv_clip) : FLT_MAX; 69 | } 70 | 71 | void get_weights(struct Config* config, struct Weights* weights, struct Tensors* tensors) { 72 | const char* dtype = tensors_metadata(tensors, "dtype"); 73 | 74 | enum DType wtype = strcmp(dtype, "gf4") == 0 ? dt_i32 : (strcmp(dtype, "fp8") == 0 ? dt_f8e5m2 : dt_f16); 75 | int gsize = strcmp(dtype, "gf4") == 0 ? 8 : 1; 76 | 77 | weights->dbits = strcmp(dtype, "gf4") == 0 ? 4 : (strcmp(dtype, "fp8") == 0 ? 
8 : 16); 78 | 79 | weights->token_embedding_table = tensors_get(tensors, "model.embed.weight", 0, wtype, (int[]){config->vocab_size, config->dim / gsize, 0, 0}); 80 | 81 | for (int l = 0; l < config->n_layers; ++l) { 82 | weights->rms_att_weight[l] = (float*)tensors_get(tensors, "model.layers.%d.attn.norm.weight", l, dt_f32, (int[]){config->dim, 0, 0, 0}); 83 | 84 | if (!config->norm_par) { 85 | weights->rms_ffn_weight[l] = (float*)tensors_get(tensors, "model.layers.%d.mlp.norm.weight", l, dt_f32, (int[]){config->dim, 0, 0, 0}); 86 | } 87 | 88 | weights->wq[l] = tensors_get(tensors, "model.layers.%d.attn.wq.weight", l, wtype, (int[]){config->n_heads * config->head_dim, config->dim / gsize, 0, 0}); 89 | weights->wk[l] = tensors_get(tensors, "model.layers.%d.attn.wk.weight", l, wtype, (int[]){config->n_kv_heads * config->head_dim, config->dim / gsize, 0, 0}); 90 | weights->wv[l] = tensors_get(tensors, "model.layers.%d.attn.wv.weight", l, wtype, (int[]){config->n_kv_heads * config->head_dim, config->dim / gsize, 0, 0}); 91 | weights->wo[l] = tensors_get(tensors, "model.layers.%d.attn.wo.weight", l, wtype, (int[]){config->dim, config->n_heads * config->head_dim / gsize, 0, 0}); 92 | 93 | if (tensors_find(tensors, "model.layers.%d.attn.wqkv.bias", l)) { 94 | weights->bqkv[l] = (float*)tensors_get(tensors, "model.layers.%d.attn.wqkv.bias", l, dt_f32, (int[]){(config->n_heads + config->n_kv_heads * 2) * config->head_dim, 0, 0, 0}); 95 | } 96 | 97 | if (config->n_experts) { 98 | weights->moegate[l] = tensors_get(tensors, "model.layers.%d.moegate.weight", l, wtype, (int[]){config->n_experts, config->dim / gsize, 0, 0}); 99 | 100 | weights->w1[l] = tensors_get(tensors, "model.layers.%d.mlp.w1.weight", l, wtype, (int[]){config->n_experts, config->hidden_dim, config->dim / gsize, 0}); 101 | weights->w2[l] = tensors_get(tensors, "model.layers.%d.mlp.w2.weight", l, wtype, (int[]){config->n_experts, config->dim, config->hidden_dim / gsize, 0}); 102 | weights->w3[l] = tensors_get(tensors, "model.layers.%d.mlp.w3.weight", l, wtype, (int[]){config->n_experts, config->hidden_dim, config->dim / gsize, 0}); 103 | } else { 104 | weights->w1[l] = tensors_get(tensors, "model.layers.%d.mlp.w1.weight", l, wtype, (int[]){config->hidden_dim, config->dim / gsize, 0, 0}); 105 | weights->w2[l] = tensors_get(tensors, "model.layers.%d.mlp.w2.weight", l, wtype, (int[]){config->dim, config->hidden_dim / gsize, 0, 0}); 106 | weights->w3[l] = tensors_get(tensors, "model.layers.%d.mlp.w3.weight", l, wtype, (int[]){config->hidden_dim, config->dim / gsize, 0, 0}); 107 | } 108 | } 109 | 110 | weights->rms_final_weight = (float*)tensors_get(tensors, "model.norm.weight", 0, dt_f32, (int[]){config->dim, 0, 0, 0}); 111 | 112 | if (tensors_find(tensors, "model.output.weight", 0) == NULL) { 113 | weights->wcls = weights->token_embedding_table; // tied weights 114 | } else { 115 | weights->wcls = tensors_get(tensors, "model.output.weight", 0, wtype, (int[]){config->vocab_size, config->dim / gsize, 0, 0}); 116 | } 117 | } 118 | 119 | void build_tokenizer(struct Tokenizer* t, struct Tensors* tensors, int vocab_size) { 120 | struct Tensor* tensor = tensors_find(tensors, "tokenizer.tokens", 0); 121 | 122 | char* tokens = (char*)tensors_get(tensors, "tokenizer.tokens", 0, dt_u8, (int[]){tensor->shape[0], 0, 0, 0}); 123 | float* scores = (float*)tensors_get(tensors, "tokenizer.scores", 0, dt_f32, (int[]){vocab_size, 0, 0, 0}); 124 | 125 | int bos_id = atoi(tensors_metadata(tensors, "bos_token_id")); 126 | int eos_id = 
atoi(tensors_metadata(tensors, "eos_token_id")); 127 | 128 | tokenizer_init(t, tokens, scores, bos_id, eos_id, vocab_size, tensor->shape[0]); 129 | } 130 | 131 | size_t count_bytes(struct Tensors* tensors, const char* prefix, const char* filter, size_t* out_params) { 132 | size_t bytes = 0, params = 0; 133 | for (int i = 0; i < tensors->n_tensors; ++i) { 134 | struct Tensor* tensor = &tensors->tensors[i]; 135 | if (strncmp(tensor->name, prefix, strlen(prefix)) != 0) { 136 | continue; 137 | } 138 | if (filter && strstr(tensor->name, filter) == NULL) { 139 | continue; 140 | } 141 | int elts = tensor->dtype == dt_i32 ? 8 : 1; // gsize hack for gf4 142 | for (int j = 0; j < 4 && tensor->shape[j] != 0; ++j) { 143 | elts *= tensor->shape[j]; 144 | } 145 | params += elts; 146 | bytes += tensor->size; 147 | } 148 | if (out_params) { 149 | *out_params = params; 150 | } 151 | return bytes; 152 | } 153 | 154 | long time_in_ms() { 155 | // return time in milliseconds, for benchmarking the model speed 156 | struct timespec time; 157 | clock_gettime(CLOCK_REALTIME, &time); 158 | return time.tv_sec * 1000 + time.tv_nsec / 1000000; 159 | } 160 | 161 | size_t kvcache_bandwidth(struct Config* config, int kvbits, int pos) { 162 | int kv_dim = config->head_dim * config->n_kv_heads; 163 | int kv_len = pos >= config->seq_len ? config->seq_len : pos + 1; 164 | return 2 * (size_t)(kvbits / 8) * config->n_layers * kv_dim * kv_len; 165 | } 166 | 167 | void generate(struct Transformer* transformer, struct Tokenizer* tokenizer, struct Sampler* sampler, char* prompt, int steps, int pos_offset) { 168 | char* empty_prompt = ""; 169 | if (prompt == NULL) { 170 | prompt = empty_prompt; 171 | } 172 | 173 | // encode the (string) prompt into tokens sequence 174 | int* prompt_tokens = (int*)malloc(tokenizer_bound(strlen(prompt)) * sizeof(int)); 175 | int num_prompt_tokens = tokenizer_encode(tokenizer, prompt, TF_ENCODE_BOS, prompt_tokens); 176 | if (num_prompt_tokens < 1) { 177 | fprintf(stderr, "something is wrong, expected at least 1 prompt token\n"); 178 | exit(EXIT_FAILURE); 179 | } 180 | 181 | char* tokens_env = getenv("CALM_TOKENS"); 182 | if (tokens_env && atoi(tokens_env)) { 183 | for (int i = 0; i < num_prompt_tokens; i++) { 184 | printf("[%s:%d]", tokenizer_decode(tokenizer, prompt_tokens[i], prompt_tokens[i]), prompt_tokens[i]); 185 | } 186 | printf("\n"); 187 | } 188 | 189 | // start the main loop 190 | size_t read_bytes = 0; 191 | long start = time_in_ms(); 192 | 193 | int next; // will store the next token in the sequence 194 | int token = prompt_tokens[0]; // kick off with the first token in the prompt 195 | int pos = 0; // position in the sequence 196 | 197 | // print first prompt token since it won't be decoded 198 | if (token != tokenizer->bos_id) { 199 | char* piece = tokenizer_decode(tokenizer, tokenizer->bos_id, token); 200 | printf("%s", piece); 201 | fflush(stdout); 202 | } 203 | 204 | float* logits_last = NULL; 205 | 206 | while (pos < steps || steps < 0) { 207 | // forward the transformer to get logits for the next token 208 | unsigned flags = pos < num_prompt_tokens - 1 ? 
FF_UPDATE_KV_ONLY : 0; 209 | float* logits = transformer->forward(transformer, token, pos + pos_offset, flags); 210 | 211 | read_bytes += transformer->n_bandwidth; 212 | read_bytes += kvcache_bandwidth(&transformer->config, transformer->state.kvbits, pos + pos_offset); 213 | logits_last = logits; 214 | 215 | // advance the state machine 216 | if (pos < num_prompt_tokens - 1) { 217 | // if we are still processing the input prompt, force the next prompt token 218 | next = prompt_tokens[pos + 1]; 219 | } else { 220 | // otherwise sample the next token from the logits 221 | next = sample(sampler, logits); 222 | assert(next >= 0); 223 | 224 | // data-dependent terminating condition: the BOS token delimits sequences, EOS token ends the sequence, EOT token ends the turn 225 | if (next == tokenizer->bos_id || next == tokenizer->eos_id || next == tokenizer->eot_id) { 226 | break; 227 | } 228 | } 229 | pos++; 230 | 231 | // print the token as string, decode it with the Tokenizer object 232 | char* piece = tokenizer_decode(tokenizer, token, next); 233 | printf("%s", piece); 234 | fflush(stdout); 235 | token = next; 236 | } 237 | printf("\n"); 238 | 239 | long end = time_in_ms(); 240 | 241 | // fold last token's logits into a hash for validation 242 | unsigned logits_hash = 0; 243 | if (logits_last) { 244 | for (int k = 0; k < transformer->config.vocab_size; ++k) { 245 | logits_hash = logits_hash * 5 + *(unsigned*)(&logits_last[k]); 246 | } 247 | } 248 | 249 | fprintf(stderr, "# %d tokens: throughput: %.2f tok/s; latency: %.2f ms/tok; bandwidth: %.2f GB/s; total %.3f sec; #%08x\n", 250 | pos, 251 | pos / (double)(end - start) * 1000, (double)(end - start) / pos, 252 | ((double)read_bytes / 1e9) / ((double)(end - start) / 1000), 253 | (double)(end - start) / 1000, logits_hash); 254 | 255 | free(prompt_tokens); 256 | } 257 | 258 | void study(struct Transformer* transformer, struct Tokenizer* tokenizer, const char* path, int steps) { 259 | int max_input_size = 64 * 1024; 260 | int max_tokens = tokenizer_bound(max_input_size); 261 | 262 | FILE* file = fopen(path, "r"); 263 | if (!file) { 264 | fprintf(stderr, "failed to open %s\n", path); 265 | exit(EXIT_FAILURE); 266 | } 267 | 268 | char* input = (char*)malloc(max_input_size + 1); 269 | size_t input_size = fread(input, 1, max_input_size, file); 270 | fclose(file); 271 | 272 | input[input_size] = '\0'; 273 | 274 | long start = time_in_ms(); 275 | 276 | int* tokens = (int*)malloc(max_tokens * sizeof(int)); 277 | int n_tokens = tokenizer_encode(tokenizer, input, TF_ENCODE_BOS, tokens); 278 | 279 | long mid = time_in_ms(); 280 | 281 | free(input); 282 | 283 | printf("# %s: %d tokens (%.3f sec), chunked with size %d\n", 284 | path, n_tokens, (double)(mid - start) / 1000, steps); 285 | 286 | int vocab_size = transformer->config.vocab_size; 287 | 288 | double sum = 0, ss = 0, den = 0; 289 | double ppl = 0, pplerr = 0; 290 | 291 | for (int i = 0; i + 1 < n_tokens; i++) { 292 | if (i != 0 && i % 1000 == 0) { 293 | printf("# progress (%d/%d): %.3f ± %.3f\n", i, n_tokens, ppl, pplerr); 294 | } 295 | 296 | int pos = steps <= 0 ? 
i : i % steps; 297 | float* logits = transformer->forward(transformer, tokens[i], pos, 0); 298 | double logprob = log(sample_prob(tokens[i + 1], logits, vocab_size)); 299 | 300 | // update stats for mean/std 301 | sum += logprob; 302 | ss += logprob * logprob; 303 | den += 1; 304 | 305 | // update ppl and ppl error using standard error of the mean 306 | ppl = exp(-sum / den); 307 | pplerr = ppl * sqrt((ss - sum * sum / den) / den / den); 308 | } 309 | 310 | long end = time_in_ms(); 311 | 312 | free(tokens); 313 | 314 | printf("# perplexity: %.3f ± %.3f (%.2f sec, %.2f tok/s)\n", 315 | ppl, pplerr, (double)(end - mid) / 1000, (double)(n_tokens - 1) / (double)(end - mid) * 1000); 316 | } 317 | 318 | static const char* chatframe(struct Tokenizer* tokenizer, bool has_system) { 319 | if (tokenizer_find(tokenizer, "<|eot_id|>") >= 0) { 320 | // llama3 321 | return has_system ? "<|start_header_id|>system<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" 322 | : "<|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"; 323 | } else if (tokenizer_find(tokenizer, "<|im_start|>") >= 0) { 324 | // chatml 325 | return has_system ? "<|im_start|>system\n%s<|im_end|>\n<|im_start|>user\n%s<|im_end|>\n<|im_start|>assistant\n" 326 | : "\n<|im_start|>user\n%s<|im_end|>\n<|im_start|>assistant\n"; 327 | } else if (tokenizer_find(tokenizer, "<start_of_turn>") >= 0) { 328 | // gemma 329 | return has_system ? "<start_of_turn>user\nSYSTEM: %s\n\n%s<end_of_turn>\n<start_of_turn>model\n" 330 | : "\n<start_of_turn>user\n%s<end_of_turn>\n<start_of_turn>model\n"; 331 | } else if (tokenizer_find(tokenizer, "<|START_OF_TURN_TOKEN|>") >= 0) { 332 | // cohere 333 | return has_system ? "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>%s<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>%s<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" 334 | : "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>%s<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"; 335 | } else if (tokenizer_find(tokenizer, "<|assistant|>") >= 0) { 336 | // phi3 337 | return has_system ? "<|system|>\n%s<|end|>\n<|user|>\n%s<|end|>\n<|assistant|>\n" 338 | : "\n<|user|>\n%s<|end|>\n<|assistant|>\n"; 339 | } else if (tokenizer_find(tokenizer, "<|beginofsystem|>") >= 0) { 340 | // k2 341 | return has_system ? "<|beginofsystem|>%s<|endofsystemprompt|><|beginofuser|>%s<|beginofsystem|>" 342 | : "<|beginofuser|>%s<|beginofsystem|>"; 343 | } else { 344 | // llama 345 | return has_system ? "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]" : "[INST] %s [/INST]"; 346 | } 347 | } 348 | 349 | void chat(struct Transformer* transformer, struct Tokenizer* tokenizer, struct Sampler* sampler, char* cli_prompt, char* system_prompt) { 350 | char user_prompt[512]; 351 | char rendered_prompt[sizeof(user_prompt) * 2]; 352 | int prompt_tokens[sizeof(rendered_prompt) + 4]; 353 | int num_prompt_tokens = 0; 354 | 355 | int user_idx = 0; 356 | int user_turn = 1; // user starts 357 | int next = 0; // will store the next token in the sequence 358 | int token = 0; // stores the current token to feed into the transformer 359 | int pos = 0; // position in the sequence 360 | for (;;) { 361 | // when it is the user's turn to contribute tokens to the dialog... 
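// (the user prompt is rendered through chatframe() above and encoded into prompt_tokens;
// while prompt tokens remain, the forward pass below runs with FF_UPDATE_KV_ONLY since no
// logits are needed for forced tokens, and once the prompt is consumed the assistant's
// tokens are sampled until an EOS/EOT token hands the turn back to the user)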
362 | if (user_turn) { 363 | // get the user prompt 364 | if (pos == 0 && cli_prompt != NULL) { 365 | // user prompt for position 0 was passed in, use it 366 | snprintf(user_prompt, sizeof(user_prompt), "%s\n", cli_prompt); 367 | } else { 368 | // otherwise get user prompt from stdin 369 | double seq_pct = (double)pos / (double)transformer->config.seq_len; 370 | char progress[64] = {}; 371 | for (int i = 0; i < 10; ++i) { 372 | strcat(progress, seq_pct < i * 0.1 ? "░" : (seq_pct < i * 0.1 + 0.05 ? "▒" : "█")); 373 | } 374 | printf("%s \033[32mUser:\033[37m ", progress); 375 | fflush(stdout); 376 | char* x = fgets(user_prompt, sizeof(user_prompt), stdin); 377 | (void)x; 378 | } 379 | // render user/system prompts into the chat schema 380 | if (pos == 0 && system_prompt[0] != '\0') { 381 | snprintf(rendered_prompt, sizeof(rendered_prompt), chatframe(tokenizer, true), system_prompt, user_prompt); 382 | } else { 383 | snprintf(rendered_prompt, sizeof(rendered_prompt), chatframe(tokenizer, false), user_prompt); 384 | } 385 | 386 | // encode the rendered prompt into tokens 387 | num_prompt_tokens = tokenizer_encode(tokenizer, rendered_prompt, pos == 0 ? TF_ENCODE_BOS : 0, prompt_tokens); 388 | user_idx = 0; // reset the user index 389 | user_turn = 0; 390 | printf("\n\033[33mAssistant:\033[00m "); 391 | } 392 | 393 | // if we are still processing the input prompt, force the next prompt token, otherwise use the next token sampled from previous turn 394 | if (user_idx < num_prompt_tokens) { 395 | token = prompt_tokens[user_idx++]; 396 | } else { 397 | token = next; 398 | } 399 | 400 | // forward the transformer to get logits for the next token 401 | unsigned flags = user_idx < num_prompt_tokens ? FF_UPDATE_KV_ONLY : 0; 402 | float* logits = transformer->forward(transformer, token, pos, flags); 403 | pos++; 404 | 405 | if (user_idx >= num_prompt_tokens) { 406 | next = sample(sampler, logits); 407 | 408 | // EOS token ends the Assistant turn 409 | if (next == tokenizer->eos_id || next == tokenizer->eot_id) { 410 | printf("\n\n"); 411 | user_turn = 1; 412 | } else { 413 | char* piece = tokenizer_decode(tokenizer, token, next); 414 | printf("%s", piece); 415 | fflush(stdout); 416 | } 417 | } 418 | } 419 | } 420 | 421 | void error_usage() { 422 | fprintf(stderr, "Usage: run [options]\n"); 423 | fprintf(stderr, "Example: run model.bin -n 256 -i \"Once upon a time\"\n"); 424 | fprintf(stderr, "Options:\n"); 425 | fprintf(stderr, " -t temperature in [0,inf], default 1.0\n"); 426 | fprintf(stderr, " -p p value in min-p (cutoff) sampling in [0,1] default 0.1\n"); 427 | fprintf(stderr, " -s random seed, default time(NULL)\n"); 428 | fprintf(stderr, " -n number of steps to run for, default 256. 0 = max_seq_len, -1 = infinite\n"); 429 | fprintf(stderr, " -r number of sequences to decode, default 1\n"); 430 | fprintf(stderr, " -c context length, default to model-specific maximum\n"); 431 | fprintf(stderr, " -i input prompt (- to read from stdin)\n"); 432 | fprintf(stderr, " -x compute perplexity for text file\n"); 433 | fprintf(stderr, " -y chat mode with a system prompt\n"); 434 | exit(EXIT_FAILURE); 435 | } 436 | 437 | int main(int argc, char* argv[]) { 438 | 439 | // default parameters 440 | char* checkpoint_path = NULL; // e.g. out/model.bin 441 | float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher 442 | float minp = 0.1f; // min-p sampling. 
0.0 = off 443 | int steps = 256; // number of steps to run for 444 | int sequences = 1; // number of sequences to decode 445 | char* prompt = NULL; // prompt string 446 | char* perplexity = NULL; // text file for perplexity 447 | char* system_prompt = NULL; // chat system prompt 448 | unsigned long long rng_seed = 0; // seed rng with time by default 449 | int context = 0; // context length 450 | 451 | // poor man's C argparse so we can override the defaults above from the command line 452 | if (argc >= 2) { 453 | checkpoint_path = argv[1]; 454 | } else { 455 | error_usage(); 456 | } 457 | for (int i = 2; i < argc; i += 2) { 458 | // do some basic validation 459 | if (i + 1 >= argc) { 460 | error_usage(); 461 | } // must have arg after flag 462 | if (argv[i][0] != '-') { 463 | error_usage(); 464 | } // must start with dash 465 | if (strlen(argv[i]) != 2) { 466 | error_usage(); 467 | } // must be -x (one dash, one letter) 468 | // read in the args 469 | if (argv[i][1] == 't') { 470 | temperature = atof(argv[i + 1]); 471 | } else if (argv[i][1] == 'p') { 472 | minp = atof(argv[i + 1]); 473 | } else if (argv[i][1] == 's') { 474 | rng_seed = atoi(argv[i + 1]); 475 | } else if (argv[i][1] == 'n') { 476 | steps = atoi(argv[i + 1]); 477 | } else if (argv[i][1] == 'r') { 478 | sequences = atoi(argv[i + 1]); 479 | } else if (argv[i][1] == 'i') { 480 | prompt = argv[i + 1]; 481 | } else if (argv[i][1] == 'x') { 482 | perplexity = argv[i + 1]; 483 | } else if (argv[i][1] == 'c') { 484 | context = atoi(argv[i + 1]); 485 | } else if (argv[i][1] == 'y') { 486 | system_prompt = argv[i + 1]; 487 | } else { 488 | error_usage(); 489 | } 490 | } 491 | 492 | // parameter validation/overrides 493 | if (rng_seed <= 0) 494 | rng_seed = (unsigned int)time(NULL); 495 | 496 | if (prompt && strcmp(prompt, "-") == 0) { 497 | static char input[65536]; 498 | size_t input_size = fread(input, 1, sizeof(input) - 1, stdin); 499 | input[input_size] = '\0'; 500 | prompt = input; 501 | } 502 | 503 | #ifdef __linux__ 504 | char* cpu = getenv("CALM_CPU"); 505 | bool cuda = !cpu || atoi(cpu) == 0; 506 | #endif 507 | 508 | #ifdef __APPLE__ 509 | char* cpu = getenv("CALM_CPU"); 510 | bool metal = !cpu || atoi(cpu) == 0; 511 | #endif 512 | 513 | // read .safetensors model 514 | struct Tensors tensors = {}; 515 | if (tensors_open(&tensors, checkpoint_path) != 0) { 516 | fprintf(stderr, "failed to open %s\n", checkpoint_path); 517 | exit(EXIT_FAILURE); 518 | } 519 | 520 | // build transformer using tensors from the input model file 521 | struct Transformer transformer = {}; 522 | get_config(&transformer.config, &tensors, context); 523 | transformer.n_bytes = count_bytes(&tensors, "model.", NULL, &transformer.n_params); 524 | transformer.n_bandwidth = transformer.n_bytes - count_bytes(&tensors, "model.embed.", NULL, NULL); 525 | if (tensors_find(&tensors, "model.output.weight", 0) == NULL) { 526 | transformer.n_bandwidth += tensors_find(&tensors, "model.embed.weight", 0)->size; 527 | } 528 | if (transformer.config.n_experts) { 529 | size_t mlp = count_bytes(&tensors, "model.layers.", ".mlp.w", NULL); 530 | transformer.n_bandwidth -= mlp; 531 | transformer.n_bandwidth += mlp / transformer.config.n_experts * transformer.config.n_experts_ac; 532 | } 533 | 534 | transformer.state.kvbits = 16; 535 | 536 | #ifdef __linux__ 537 | if (cuda && transformer.config.seq_len > 4096) { 538 | transformer.state.kvbits = 8; // for now use fp8 for larger contexts automatically without explicit control 539 | } 540 | #endif 541 | 542 | printf("# %s: %.1fB 
params (%.1f GiB @ %.2f bpw), %d context (kvcache %.1f GiB @ fp%d)\n", 543 | checkpoint_path, 544 | (double)transformer.n_params / 1e9, (double)transformer.n_bytes / 1024 / 1024 / 1024, 545 | (double)transformer.n_bytes * 8 / (double)transformer.n_params, 546 | transformer.config.seq_len, 547 | (double)kvcache_bandwidth(&transformer.config, transformer.state.kvbits, transformer.config.seq_len - 1) / 1024 / 1024 / 1024, 548 | transformer.state.kvbits); 549 | 550 | #ifdef __linux__ 551 | // upload tensors to the GPU 552 | if (cuda) { 553 | int i; 554 | for (i = 0; i < tensors.n_tensors; ++i) { 555 | struct Tensor* tensor = &tensors.tensors[i]; 556 | if (strncmp(tensor->name, "model.", 6) == 0) { 557 | tensor->data = upload_cuda(tensor->data, tensor->size); 558 | } 559 | } 560 | } 561 | #endif 562 | 563 | #ifdef __APPLE__ 564 | // upload tensors to the GPU 565 | if (metal) { 566 | init_metal(); 567 | for (int i = 0; i < tensors.n_tensors; ++i) { 568 | struct Tensor* tensor = &tensors.tensors[i]; 569 | if (strncmp(tensor->name, "model.", 6) == 0) { 570 | tensor->data = upload_metal(tensor->data, tensor->size); 571 | } 572 | } 573 | } 574 | #endif 575 | 576 | get_weights(&transformer.config, &transformer.weights, &tensors); 577 | 578 | #ifdef __linux__ 579 | if (cuda) { 580 | prepare_cuda(&transformer); 581 | transformer.forward = forward_cuda; 582 | } 583 | #endif 584 | 585 | #ifdef __APPLE__ 586 | if (metal) { 587 | prepare_metal(&transformer); 588 | transformer.forward = forward_metal; 589 | } 590 | #endif 591 | 592 | // CPU fallback 593 | if (!transformer.forward) { 594 | prepare(&transformer); 595 | transformer.forward = forward; 596 | } 597 | 598 | // build the Tokenizer via the tokenizer .bin file 599 | struct Tokenizer tokenizer; 600 | build_tokenizer(&tokenizer, &tensors, transformer.config.vocab_size); 601 | 602 | // build the Sampler 603 | struct Sampler sampler = {transformer.config.vocab_size, rng_seed, temperature, minp}; 604 | 605 | // hack for profiling: offset pos to make sure we need to use most of kv cache 606 | char* pos_offset_env = getenv("CALM_POSO"); 607 | int pos_offset = pos_offset_env ? atoi(pos_offset_env) : 0; 608 | 609 | // do one inference as warmup 610 | // when using cpu, this makes sure tensors are loaded into memory (via mmap) 611 | // when using cuda, this makes sure all kernels are compiled and instantiated 612 | transformer.forward(&transformer, 0, pos_offset, 0); 613 | 614 | // -n 0 means use the full context length 615 | if (steps == 0) 616 | steps = transformer.config.seq_len; 617 | 618 | // run! 
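// example invocations for the three modes dispatched below (model/text file names are illustrative):
//   run model.calm -i "Once upon a time" -n 256           generate one sequence from a prompt
//   run model.calm -x input.txt                           compute perplexity over a text file
//   run model.calm -y "You are a helpful assistant."      interactive chat with a system prompt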
619 | if (perplexity) { 620 | study(&transformer, &tokenizer, perplexity, steps); 621 | } else if (system_prompt) { 622 | chat(&transformer, &tokenizer, &sampler, prompt, system_prompt); 623 | } else { 624 | for (int s = 0; s < sequences; ++s) { 625 | generate(&transformer, &tokenizer, &sampler, prompt, steps, pos_offset); 626 | } 627 | } 628 | 629 | #ifdef __linux__ 630 | if (cuda && getenv("CUDA_INJECTION64_PATH")) { 631 | perf_cuda(); 632 | } 633 | #endif 634 | 635 | // memory and file handles cleanup 636 | // TODO: free transformer.state and transformer.weights for CUDA 637 | tokenizer_free(&tokenizer); 638 | tensors_close(&tensors); 639 | return 0; 640 | } 641 | -------------------------------------------------------------------------------- /src/infer.cu: -------------------------------------------------------------------------------- 1 | #include "model.h" 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | 10 | #include "helpers.cuh" 11 | 12 | #define CUDA_CHECK(x) \ 13 | do { \ 14 | cudaError_t err = x; \ 15 | if (err != cudaSuccess) { \ 16 | fprintf(stderr, "CUDA error in %s at %s:%d: %s (%s=%d)\n", __FUNCTION__, __FILE__, __LINE__, \ 17 | cudaGetErrorString(err), cudaGetErrorName(err), err); \ 18 | abort(); \ 19 | } \ 20 | } while (0) 21 | 22 | #define PROF_TOKEN(bytes) ((0xCDAFull << 48) | (bytes)) 23 | 24 | template 25 | struct CoopLayer { 26 | float* rms_att_weight; 27 | T* wq; 28 | T* wk; 29 | T* wv; 30 | T* wo; 31 | float* bqkv; 32 | 33 | float* rms_ffn_weight; 34 | T* moegate; 35 | T* w1; 36 | T* w2; 37 | T* w3; 38 | }; 39 | 40 | static cudaStream_t stream; 41 | 42 | static int coopsms; 43 | 44 | static __constant__ CoopLayer cooplayers[MAX_LAYERS]; 45 | 46 | static uint64_t* coopperf; 47 | static uint64_t coopperfbw[16]; 48 | static int coopruns; 49 | 50 | static void* cuda_devicecopy(void* host, size_t size) { 51 | void* device = NULL; 52 | CUDA_CHECK(cudaMalloc(&device, size)); 53 | CUDA_CHECK(cudaMemcpyAsync(device, host, size, cudaMemcpyHostToDevice)); 54 | return device; 55 | } 56 | 57 | static void* cuda_devicealloc(size_t size) { 58 | void* ptr = NULL; 59 | CUDA_CHECK(cudaMalloc(&ptr, size)); 60 | return ptr; 61 | } 62 | 63 | static void* cuda_hostalloc(size_t size) { 64 | void* ptr = NULL; 65 | CUDA_CHECK(cudaHostAlloc(&ptr, size, 0)); 66 | return ptr; 67 | } 68 | 69 | extern "C" void* upload_cuda(void* host, size_t size) { 70 | return cuda_devicecopy(host, size); 71 | } 72 | 73 | extern "C" void prepare_cuda(struct Transformer* transformer) { 74 | struct Config* config = &transformer->config; 75 | struct Weights* weights = &transformer->weights; 76 | struct RunState* state = &transformer->state; 77 | 78 | cudaDeviceProp devprop = {}; 79 | CUDA_CHECK(cudaGetDeviceProperties(&devprop, 0)); 80 | assert(devprop.cooperativeLaunch); 81 | 82 | printf("# CUDA: %s, compute %d.%d, %d SMs, %.1f GiB, peak bandwidth %.0f GB/s (ECC %d)\n", 83 | devprop.name, devprop.major, devprop.minor, devprop.multiProcessorCount, 84 | (double)devprop.totalGlobalMem / (1024 * 1024 * 1024), 85 | (double)devprop.memoryClockRate * (devprop.memoryBusWidth / 8) * 2 / 1e6, devprop.ECCEnabled); 86 | 87 | coopsms = devprop.multiProcessorCount; 88 | 89 | if (getenv("CUDA_INJECTION64_PATH")) { 90 | coopperf = (uint64_t*)cuda_devicealloc(sizeof(uint64_t) * 16); 91 | CUDA_CHECK(cudaMemset(coopperf, 0, sizeof(uint64_t) * 16)); 92 | } 93 | 94 | CUDA_CHECK(cudaStreamCreate(&stream)); 95 | 96 | int dim = config->dim; 97 | int hidden_dim = config->hidden_dim; 98 | int q_dim = 
config->head_dim * config->n_heads; 99 | int kv_dim = config->head_dim * config->n_kv_heads; 100 | 101 | state->x = (float*)cuda_devicealloc(dim * sizeof(float)); 102 | state->hb = (float*)cuda_devicealloc(hidden_dim * sizeof(float)); 103 | state->he = (float*)cuda_devicealloc(config->n_experts_ac * hidden_dim * sizeof(float)); 104 | state->q = (float*)cuda_devicealloc(q_dim * sizeof(float)); 105 | state->att = (float*)cuda_devicealloc(config->n_heads * config->seq_len * 2 * sizeof(float)); 106 | 107 | assert(state->kvbits == 8 || state->kvbits == 16); 108 | state->key_cache = cuda_devicealloc((size_t)config->n_layers * config->seq_len * kv_dim * (state->kvbits / 8)); 109 | state->value_cache = cuda_devicealloc((size_t)config->n_layers * config->seq_len * kv_dim * (state->kvbits / 8)); 110 | 111 | // logits are going to be read by the host so we just allocate them in host and write to host directly 112 | state->logits = (float*)cuda_hostalloc(config->vocab_size * sizeof(float)); 113 | 114 | CoopLayer layers[MAX_LAYERS]; 115 | for (int l = 0; l < config->n_layers; ++l) { 116 | layers[l].rms_att_weight = weights->rms_att_weight[l]; 117 | layers[l].wq = weights->wq[l]; 118 | layers[l].wk = weights->wk[l]; 119 | layers[l].wv = weights->wv[l]; 120 | layers[l].wo = weights->wo[l]; 121 | layers[l].bqkv = weights->bqkv[l]; 122 | 123 | layers[l].rms_ffn_weight = weights->rms_ffn_weight[l]; 124 | layers[l].moegate = weights->moegate[l]; 125 | layers[l].w1 = weights->w1[l]; 126 | layers[l].w2 = weights->w2[l]; 127 | layers[l].w3 = weights->w3[l]; 128 | } 129 | 130 | CUDA_CHECK(cudaMemcpyToSymbol(cooplayers, layers, sizeof(layers))); 131 | } 132 | 133 | template 134 | __device__ inline float embed(T* weight, int idx) { 135 | return float(weight[idx]); 136 | } 137 | 138 | __device__ inline float embed(uint32_t* weight, int idx) { 139 | return gf4_ff(weight[idx / 8], idx % 8); 140 | } 141 | 142 | template 143 | __global__ static void kernel_embed(float* o, T* weight, int token, int n) { 144 | int i = blockIdx.x * blockDim.x + threadIdx.x; 145 | assert(i < n); 146 | 147 | o[i] = embed(weight, token * n + i); 148 | } 149 | 150 | template 151 | __global__ static void kernel_rotate_sink(uint64_t, int kvd, KVT* key_cache, int head_dim, int kv_sink, float theta_log2, int seq_len, int rotary_dim) { 152 | int i = (blockIdx.x * blockDim.x + threadIdx.x) * 2; 153 | assert(i < kv_sink * kvd); 154 | 155 | int l = blockIdx.y; 156 | 157 | int j_head = i % head_dim; 158 | float freq = j_head >= rotary_dim ? 
0.f : exp2f(-theta_log2 * (float)j_head / (float)rotary_dim); 159 | 160 | // rotate sink tokens forward to keep pace with non-sink tokens 161 | float fcr, fci; 162 | sincosf(freq, &fci, &fcr); 163 | 164 | size_t loff = (size_t)l * seq_len * kvd; 165 | KVT* kb = key_cache + loff; 166 | 167 | // note: k layout is transposed / tiled to improve attn_score performance 168 | int t = i / kvd; 169 | int k = i % kvd; 170 | int o = t * 16 + seq_len * (k / 16) * 16 + (k % 16); 171 | 172 | float v0 = float(kb[o + 0]); 173 | float v1 = float(kb[o + 1]); 174 | 175 | float r0 = v0 * fcr - v1 * fci; 176 | float r1 = v0 * fci + v1 * fcr; 177 | 178 | kb[o + 0] = KVT(r0); 179 | kb[o + 1] = KVT(r1); 180 | } 181 | 182 | __device__ inline float gelu(float x) { 183 | return 0.5f * x * (1.0f + tanhf(0.797885f * (x + 0.044715f * x * x * x))); 184 | } 185 | 186 | __device__ inline float silu(float x) { 187 | return x / (1.0f + expf(-x)); 188 | } 189 | 190 | __device__ static void moe_gate_warp(float* moe_weights, int* moe_experts, float* weights, int experts, int active) { 191 | int i = threadIdx.x; 192 | 193 | // (unscaled) softmax across experts 194 | float w = (i < experts) ? weights[i] : -FLT_MAX; 195 | float max_val = warpreduce_max(w); 196 | w = expf(w - max_val); 197 | 198 | // weight in top 24 bits, index in bottom 8 199 | int wi = (__float_as_int(w) & 0xffffff00) | i; 200 | 201 | // top k within warp 202 | float sumw = 0.f; 203 | int acti = -1; 204 | 205 | for (int k = 0; k < active; ++k) { 206 | int maxi = warpreduce_maxi(wi); 207 | 208 | sumw += __int_as_float(maxi); 209 | 210 | // keeps top weight in thread k, clears weight for thread with max thread to avoid re-selection 211 | acti = (i == k) ? maxi : acti; 212 | wi = (wi == maxi) ? 0 : wi; 213 | } 214 | 215 | // write normalized weights 216 | if (i < active) { 217 | assert(acti >= 0); 218 | 219 | moe_experts[i] = acti & 0xff; 220 | moe_weights[i] = __int_as_float(acti) / sumw; 221 | } 222 | } 223 | 224 | __device__ inline float4 attn_load4(half* p) { 225 | ablock<__half2_raw, 2> h = *(ablock<__half2_raw, 2>*)p; 226 | float2 h0 = __half22float2(h.v[0]), h1 = __half22float2(h.v[1]); 227 | return {h0.x, h0.y, h1.x, h1.y}; 228 | } 229 | 230 | __device__ inline float4 attn_load4(__nv_fp8_e5m2* p) { 231 | return fp8x4_e5m2_ff(*(__nv_fp8x4_e5m2*)p); 232 | } 233 | 234 | template 235 | __device__ inline float attn_score(KVT* kht, float* qh, int head_dim, int seq_len, int t, int off) { 236 | float score = 0.0f; 237 | for (int j = 0; j < head_dim; j += 16) { 238 | float4 kk = attn_load4(&kht[j * seq_len + t * 16 + off]); 239 | float4 qq = *(float4*)&qh[j + off]; 240 | score += kk.x * qq.x; 241 | score += kk.y * qq.y; 242 | score += kk.z * qq.z; 243 | score += kk.w * qq.w; 244 | } 245 | 246 | return score; 247 | } 248 | 249 | template 250 | __device__ inline float attn_warpdot(KVT* val, float* atth, int kv_len) { 251 | int kv_len4 = kv_len & ~3; 252 | int lane = threadIdx.x % warpSize; 253 | 254 | float res = 0.0f; 255 | float sum = 0.0f; 256 | for (int t = lane * 4; t < kv_len4; t += warpSize * 4) { 257 | float4 vv = attn_load4(&val[t]); 258 | float4 aa = *(float4*)&atth[t]; 259 | res += vv.x * aa.x; 260 | res += vv.y * aa.y; 261 | res += vv.z * aa.z; 262 | res += vv.w * aa.w; 263 | sum += aa.x + aa.y + aa.z + aa.w; 264 | } 265 | 266 | if (kv_len4 + lane < kv_len) { 267 | float a = atth[kv_len4 + lane]; 268 | res += a * float(val[kv_len4 + lane]); 269 | sum += a; 270 | } 271 | 272 | res = warpreduce_sum(res); 273 | sum = warpreduce_sum(sum); 274 | 275 | return 
res / sum; 276 | } 277 | 278 | __device__ static void softmax(float* xout, float* x, int size) { 279 | int i = threadIdx.x; 280 | 281 | // find max value per thread (for numerical stability) 282 | float max_val = -FLT_MAX; 283 | for (int j = i; j < size; j += blockDim.x) { 284 | max_val = max(max_val, x[j]); 285 | } 286 | 287 | // max across threads in block 288 | max_val = blockreduce_max(max_val); 289 | 290 | // exp per thread 291 | for (int j = i; j < size; j += blockDim.x) { 292 | xout[j] = expf(x[j] - max_val); 293 | } 294 | } 295 | 296 | template 297 | __device__ static float rmsnorm(T* o, float* x, float* weight, int size, float eps, bool ln) { 298 | int i = threadIdx.x; 299 | int blockSize = blockDim.x; 300 | 301 | float mean = 0.0f; 302 | if (ln) { 303 | // calculate sum (per thread) 304 | float sum = 0.0f; 305 | for (int j = i; j < size; j += blockSize) { 306 | sum += x[j]; 307 | } 308 | 309 | // sum across threads in block 310 | mean = blockreduce_sum(sum) / size; 311 | } 312 | 313 | // calculate sum of squares (per thread) 314 | float ss = 0.0f; 315 | for (int j = i * 2; j < size; j += blockSize * 2) { 316 | float2 xx = *(float2*)&x[j]; 317 | float2 ww = *(float2*)&weight[j]; 318 | float v0 = xx.x - mean; 319 | float v1 = xx.y - mean; 320 | ss += v0 * v0; 321 | ss += v1 * v1; 322 | *(ablock*)&o[j] = { v0 * ww.x, v1 * ww.y }; 323 | } 324 | 325 | // sum across threads in block 326 | ss = blockreduce_sum(ss); 327 | 328 | // caller is responsible for normalization 329 | return rsqrtf(ss / size + eps); 330 | } 331 | 332 | __device__ static void syncgrid() { 333 | volatile unsigned int* barrier = &cooperative_groups::details::get_grid_workspace()->barrier; 334 | 335 | if (threadIdx.x == 0) { 336 | unsigned int nb = 1; 337 | if (blockIdx.x == 0) { 338 | nb = 0x80000000 - (gridDim.x - 1); 339 | } 340 | 341 | unsigned int old_arrive; 342 | asm volatile("atom.add.release.gpu.u32 %0,[%1],%2;" : "=r"(old_arrive) : _CG_ASM_PTR_CONSTRAINT(barrier), "r"(nb) : "memory"); 343 | 344 | unsigned int current_arrive; 345 | do { 346 | asm volatile("ld.acquire.gpu.u32 %0,[%1];" : "=r"(current_arrive) : _CG_ASM_PTR_CONSTRAINT(barrier) : "memory"); 347 | } while (((old_arrive ^ current_arrive) & 0x80000000) == 0); 348 | } 349 | 350 | __syncthreads(); 351 | } 352 | 353 | template 354 | struct CoopArgs { 355 | uint64_t bw; 356 | uint64_t* perfstats; 357 | 358 | float* x; 359 | float* hb; 360 | float* q; 361 | float* att; 362 | 363 | KVT* key_cache; 364 | KVT* val_cache; 365 | 366 | int n_layers; 367 | 368 | int dim; 369 | int hidden_dim; 370 | int head_dim; 371 | int n_heads; 372 | int n_kv_heads; 373 | int n_experts; 374 | int n_experts_ac; 375 | int seq_len; 376 | int rotary_dim; 377 | 378 | bool norm_ln; 379 | bool act_gelu; 380 | 381 | int kv_len; 382 | int kv_pos; 383 | int pos; 384 | 385 | float norm_eps; 386 | float theta_log2; 387 | float qkv_clip; 388 | }; 389 | 390 | __device__ static void coopstage(uint64_t* stats, int stage) { 391 | __shared__ uint64_t lastt; 392 | 393 | if (stats && blockIdx.x == 0 && threadIdx.x == 0) { 394 | uint64_t t; 395 | asm volatile("mov.u64 %0, %%globaltimer;" : "=l"(t)); 396 | 397 | if (stage >= 0) { 398 | stats[stage] += t - lastt; 399 | } 400 | lastt = t; 401 | } 402 | } 403 | 404 | template 405 | __global__ __launch_bounds__(1024, 1) static void kernel_forward(const __grid_constant__ CoopArgs args) { 406 | extern __shared__ char smem[]; 407 | __shared__ float rmsscale; 408 | 409 | __shared__ float moe_weights[32]; 410 | __shared__ int moe_experts[32]; 411 | 412 | 
AT* xs = (AT*)smem; 413 | 414 | int dim = args.dim; 415 | int hidden_dim = args.hidden_dim; 416 | int head_dim = args.head_dim; 417 | 418 | int kv_mul = args.n_heads / args.n_kv_heads; 419 | int q_dim = args.head_dim * args.n_heads; 420 | int kv_dim = args.head_dim * args.n_kv_heads; 421 | 422 | const int IK = 4; // K consecutive warps per block, groups of K are interleaved across SMs for better work distribution 423 | int io = blockIdx.x * IK + (threadIdx.x / warpSize % IK) + gridDim.x * IK * (threadIdx.x / warpSize / IK); 424 | int ib = (gridDim.x * blockDim.x) / warpSize; 425 | 426 | // dummy moe weights for non-moe models; will be overwritten by moe gate 427 | moe_weights[0] = 1.f; 428 | moe_experts[0] = 0; 429 | 430 | coopstage(args.perfstats, -1); // init timing 431 | 432 | static __device__ int badsoftmax = 0; 433 | 434 | for (int l = 0; l < args.n_layers; ++l) { 435 | const CoopLayer* L = (const CoopLayer*)&cooplayers[l]; 436 | 437 | if (blockIdx.x == 0 && threadIdx.x < warpSize) { 438 | badsoftmax = 0; 439 | } 440 | 441 | // pre-attention rmsnorm (into shared memory) 442 | rmsscale = rmsnorm(xs, args.x, L->rms_att_weight, dim, args.norm_eps, args.norm_ln); 443 | 444 | size_t loff = (size_t)l * args.seq_len * kv_dim; // kv cache layer offset for convenience 445 | KVT* keyb = args.key_cache + loff; 446 | KVT* valb = args.val_cache + loff; 447 | 448 | // qkv matmul + RoPE encoding + update KV cache 449 | for (int j = io * 2; j < q_dim + kv_dim * 2; j += ib * 2) { 450 | T* w = j < q_dim ? L->wq : (j < q_dim + kv_dim ? L->wk : L->wv); 451 | int k = j < q_dim ? j : (j < q_dim + kv_dim ? j - q_dim : j - q_dim - kv_dim); 452 | 453 | float v0 = matmul_warppar(xs, w, k + 0, dim) * rmsscale; 454 | float v1 = matmul_warppar(xs, w, k + 1, dim) * rmsscale; 455 | 456 | if (L->bqkv) { 457 | v0 += L->bqkv[j + 0]; 458 | v1 += L->bqkv[j + 1]; 459 | } 460 | 461 | v0 = min(max(v0, -args.qkv_clip), args.qkv_clip); 462 | v1 = min(max(v1, -args.qkv_clip), args.qkv_clip); 463 | 464 | if (threadIdx.x % warpSize == 0) { 465 | int j_head = j % head_dim; 466 | float freq = j_head >= args.rotary_dim ? 
0.f : exp2f(-args.theta_log2 * (float)j_head / (float)args.rotary_dim); 467 | float fcr, fci; 468 | sincosf(args.pos * freq, &fci, &fcr); 469 | 470 | if (j < q_dim) { 471 | args.q[k + 0] = v0 * fcr - v1 * fci; 472 | args.q[k + 1] = v0 * fci + v1 * fcr; 473 | } else if (j < q_dim + kv_dim) { 474 | // note: k layout is transposed / tiled to improve attn_score performance 475 | int off = args.kv_pos * 16 + args.seq_len * (k / 16) * 16 + (k % 16); 476 | keyb[off + 0] = KVT(v0 * fcr - v1 * fci); 477 | keyb[off + 1] = KVT(v0 * fci + v1 * fcr); 478 | } else { 479 | // note: v layout is transposed (we store all positions for a given head contiguously) to improve attn_mix performance 480 | valb[args.kv_pos + args.seq_len * (k + 0)] = KVT(v0); 481 | valb[args.kv_pos + args.seq_len * (k + 1)] = KVT(v1); 482 | } 483 | } 484 | } 485 | 486 | __syncthreads(); // TODO: unclear why this is needed for determinism 487 | syncgrid(); 488 | coopstage(args.perfstats, 0); 489 | 490 | // attention score 491 | int kv_lent = (args.kv_len + 7) / 8; 492 | 493 | for (int j = io; j < kv_lent * args.n_heads; j += ib) { 494 | int h = j % args.n_heads; 495 | int kvh = h / kv_mul; 496 | int t = (j / args.n_heads) * 8 + (threadIdx.x % warpSize) / 4; 497 | 498 | unsigned active = __ballot_sync(0xffffffff, t < args.kv_len); 499 | 500 | if (t < args.kv_len) { 501 | float* qh = args.q + h * head_dim; 502 | KVT* kh = keyb + kvh * head_dim * args.seq_len; 503 | float* atth = args.att + h * args.seq_len * 2; 504 | 505 | float score = attn_score(kh, qh, head_dim, args.seq_len, t, 4 * (threadIdx.x % 4)); 506 | 507 | // reduce score across threads in warp; every 4 threads are processing the same output score 508 | score += __shfl_xor_sync(active, score, 2); 509 | score += __shfl_xor_sync(active, score, 1); 510 | score /= sqrtf(head_dim); 511 | 512 | atth[t] = expf(score); 513 | atth[t + args.seq_len] = score; 514 | 515 | // to reduce latency we prefer computing softmax without the numeric stabilization, which is safe if all inputs are small 516 | if (fabsf(score) > 40) { 517 | badsoftmax = 1; 518 | } 519 | } 520 | } 521 | 522 | syncgrid(); 523 | coopstage(args.perfstats, 1); 524 | 525 | if (badsoftmax) { 526 | // attention softmax 527 | if (blockIdx.x < args.n_heads) { 528 | int h = blockIdx.x; 529 | float* atth = args.att + h * args.seq_len * 2; 530 | 531 | softmax(atth, atth + args.seq_len, args.kv_len); 532 | } 533 | 534 | syncgrid(); 535 | coopstage(args.perfstats, 2); 536 | } 537 | 538 | // attention mix 539 | for (int j = io; j < q_dim; j += ib) { 540 | int h = j / head_dim; 541 | int kvh = h / kv_mul; 542 | int j_head = j % head_dim; 543 | 544 | float* atth = args.att + h * args.seq_len * 2; 545 | KVT* vh = valb + kvh * head_dim * args.seq_len; 546 | KVT* val = vh + j_head * args.seq_len; 547 | 548 | float res = attn_warpdot(val, atth, args.kv_len); 549 | 550 | if (threadIdx.x % warpSize == 0) { 551 | args.q[j] = res; 552 | } 553 | } 554 | 555 | syncgrid(); 556 | coopstage(args.perfstats, 3); 557 | 558 | // attention output 559 | for (int j = io; j < dim; j += ib) { 560 | float val = matmul_warppar(args.q, L->wo, j, q_dim); 561 | 562 | if (threadIdx.x % warpSize == 0) { 563 | args.x[j] += val; 564 | } 565 | } 566 | 567 | __syncthreads(); // TODO: unclear why this is needed for determinism 568 | syncgrid(); 569 | coopstage(args.perfstats, 4); 570 | 571 | // post-attention rmsnorm (into shared memory) 572 | if (L->rms_ffn_weight) { 573 | rmsscale = rmsnorm(xs, args.x, L->rms_ffn_weight, dim, args.norm_eps, args.norm_ln); 574 | } 
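// for MoE models, the gating below computes one router logit per warp into shared memory, then a
// single warp selects the top n_experts_ac experts: moe_gate_warp packs each (unscaled) softmax
// weight into the top 24 bits of an int with the expert index in the low 8 bits, so a warp max
// reduction picks weight and index together before normalizing; dense models fall through using
// the dummy weight/expert initialized at the top of the kernel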
575 | 576 | // moegate 577 | if (args.n_experts) { 578 | __shared__ float exp[32]; 579 | int j = threadIdx.x / warpSize; 580 | 581 | if (j < args.n_experts) { 582 | float val = matmul_warppar(xs, L->moegate, j, dim) * rmsscale; 583 | 584 | exp[j] = val; 585 | } 586 | 587 | __syncthreads(); 588 | 589 | if (threadIdx.x < warpSize) { 590 | moe_gate_warp(moe_weights, moe_experts, exp, args.n_experts, args.n_experts_ac); 591 | } 592 | 593 | __syncthreads(); 594 | } 595 | 596 | // F.silu(self.w1(x)) * self.w3(x) 597 | for (int j = io; j < hidden_dim * args.n_experts_ac; j += ib) { 598 | int je = (j % hidden_dim) + moe_experts[j / hidden_dim] * hidden_dim; 599 | float v1 = matmul_warppar(xs, L->w1, je, dim) * rmsscale; 600 | float v3 = matmul_warppar(xs, L->w3, je, dim) * rmsscale; 601 | 602 | float val = (args.act_gelu ? gelu(v1) : silu(v1)) * v3; 603 | 604 | if (threadIdx.x % warpSize == 0) { 605 | args.hb[j] = val; 606 | } 607 | } 608 | 609 | syncgrid(); 610 | coopstage(args.perfstats, 5); 611 | 612 | // self.w2(...) + pre-rmsnorm residual 613 | for (int j = io; j < dim * args.n_experts_ac; j += ib) { 614 | int je = (j % dim) + moe_experts[j / dim] * dim; 615 | float val = matmul_warppar(args.hb + (j / dim) * hidden_dim, L->w2, je, hidden_dim); 616 | 617 | if (threadIdx.x % warpSize == 0) { 618 | atomicAdd(&args.x[j % dim], val * moe_weights[j / dim]); 619 | } 620 | } 621 | 622 | __syncthreads(); // TODO: unclear why this is needed for determinism 623 | syncgrid(); 624 | coopstage(args.perfstats, 6); 625 | } 626 | } 627 | 628 | template <typename T, typename AT> 629 | __global__ static void kernel_output(uint64_t, float* xout, float* x, T* w, float* rms_weight, int n, int d, float norm_eps, bool norm_ln) { 630 | extern __shared__ char smem[]; 631 | 632 | AT* xs = (AT*)smem; 633 | 634 | float rmsscale = rmsnorm(xs, x, rms_weight, n, norm_eps, norm_ln); 635 | 636 | int io = (blockIdx.x * blockDim.x + threadIdx.x) / warpSize; 637 | int ib = (gridDim.x * blockDim.x) / warpSize; 638 | 639 | for (int j = io; j < d; j += ib) { 640 | float val = matmul_warppar(xs, w, j, n) * rmsscale; 641 | 642 | // instead of writing one value per block, we transpose the values and write all results from first warp 643 | val = blocktranspose(val, 0.f); 644 | 645 | if (threadIdx.x < blockDim.x / warpSize) { 646 | xout[j + threadIdx.x] = val; 647 | } 648 | } 649 | } 650 | 651 | template <typename T, typename KVT, typename AT> 652 | static float* forward(struct Transformer* transformer, int token, int pos, unsigned flags) { 653 | struct Config* p = &transformer->config; 654 | struct Weights* w = &transformer->weights; 655 | struct RunState* s = &transformer->state; 656 | 657 | // a few convenience variables 658 | float* x = s->x; 659 | int dim = p->dim; 660 | int hidden_dim = p->hidden_dim; 661 | int kv_dim = p->head_dim * p->n_kv_heads; 662 | size_t dbits = w->dbits; // size_t prevents integer overflow in multiplications below 663 | 664 | // following "attention sinks" from StreamingLLM we keep the first few tokens in the KV cache as is 665 | int kv_sink = pos >= p->seq_len ? KV_SINKS : 0; 666 | int kv_pos = kv_sink + (pos - kv_sink) % (p->seq_len - kv_sink); 667 | int kv_len = pos >= p->seq_len ? 
p->seq_len : pos + 1; 668 | 669 | // ensure all dimensions are warp-aligned 670 | assert(dim % 32 == 0 && kv_dim % 32 == 0 && hidden_dim % 32 == 0); 671 | 672 | // copy the token embedding into x 673 | assert(token < p->vocab_size); 674 | kernel_embed<<<dim / 32, 32, 0, stream>>>(x, (T*)w->token_embedding_table, token, dim); 675 | 676 | // rotate sink tokens forward to keep pace with non-sink tokens 677 | if (kv_sink > 0) { 678 | kernel_rotate_sink<<<dim3(kv_sink * kv_dim / 64, p->n_layers), 32, 0, stream>>>( 679 | PROF_TOKEN(kv_sink * kv_dim * sizeof(KVT)), kv_dim, (KVT*)s->key_cache, p->head_dim, kv_sink, log2(p->rope_theta), p->seq_len, p->rotary_dim); 680 | } 681 | 682 | // forward all the layers 683 | size_t kvbw = p->n_kv_heads * p->head_dim * kv_len * sizeof(KVT) + p->n_heads * kv_len * sizeof(float); 684 | 685 | uint64_t bw = 0; 686 | bw += p->head_dim * (p->n_heads + p->n_kv_heads * 2) * dim * dbits / 8; // QKV 687 | bw += kvbw * 2; // attn scoring and mixing 688 | bw += p->head_dim * p->n_heads * dim * dbits / 8; // attn output 689 | bw += 3 * (hidden_dim * dim * dbits / 8) * max(p->n_experts_ac, 1); // MLP 690 | bw *= p->n_layers; 691 | 692 | coopruns++; 693 | coopperfbw[0] += (size_t)p->n_layers * (p->head_dim * (p->n_heads + p->n_kv_heads * 2) * dim * dbits / 8); // QKV 694 | coopperfbw[1] += (size_t)p->n_layers * kvbw; // attn scoring 695 | coopperfbw[2] += 0; // attn softmax 696 | coopperfbw[3] += (size_t)p->n_layers * kvbw; // attn mixing 697 | coopperfbw[4] += (size_t)p->n_layers * (p->head_dim * p->n_heads * dim * dbits / 8); // attn output 698 | coopperfbw[5] += (size_t)p->n_layers * (2 * (hidden_dim * dim * dbits / 8) * max(p->n_experts_ac, 1)); // MLP 699 | coopperfbw[6] += (size_t)p->n_layers * (1 * (hidden_dim * dim * dbits / 8) * max(p->n_experts_ac, 1)); // MLP 700 | 701 | CoopArgs<KVT> args = { 702 | PROF_TOKEN(bw), 703 | coopperf, 704 | // token state 705 | x, p->n_experts ? 
s->he : s->hb, s->q, s->att, 706 | // key/value cache; note that layers are passed via cooplayers[] 707 | (KVT*)s->key_cache, (KVT*)s->value_cache, 708 | // model dimensions 709 | p->n_layers, 710 | dim, hidden_dim, p->head_dim, 711 | p->n_heads, p->n_kv_heads, p->n_experts, max(p->n_experts_ac, 1), 712 | p->seq_len, p->rotary_dim, 713 | // model configuration 714 | p->norm_ln, p->act_gelu, 715 | // token position (and derived data) 716 | kv_len, kv_pos, pos, 717 | // model parameters 718 | p->norm_eps, log2(p->rope_theta), p->qkv_clip, 719 | }; 720 | void* argsp = &args; 721 | 722 | CUDA_CHECK(cudaLaunchCooperativeKernel((void*)kernel_forward<T, KVT, AT>, coopsms, 1024, &argsp, dim * sizeof(AT), stream)); 723 | 724 | if (flags & FF_UPDATE_KV_ONLY) { 725 | // only update kv cache and don't output logits 726 | return NULL; 727 | } 728 | 729 | int output_blk = 32 * 32; 730 | int output_par = 1; 731 | CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&output_par, kernel_output<T, AT>, output_blk, dim * sizeof(AT))); 732 | 733 | // classifier into logits 734 | kernel_output<T, AT><<<coopsms * output_par, output_blk, dim * sizeof(AT), stream>>>( 735 | PROF_TOKEN(p->vocab_size * dim * dbits / 8), s->logits, x, (T*)w->wcls, w->rms_final_weight, dim, p->vocab_size, p->norm_eps, p->norm_ln); 736 | 737 | CUDA_CHECK(cudaStreamSynchronize(stream)); 738 | CUDA_CHECK(cudaGetLastError()); // check for kernel launch errors; they might fail with OOM due to lazy kernel compilation 739 | 740 | return s->logits; 741 | } 742 | 743 | extern "C" float* forward_cuda(struct Transformer* transformer, int token, int pos, unsigned flags) { 744 | #define CASE(dbits_, dtype, kvbits_, kvtype, atype) \ 745 | if (transformer->weights.dbits == dbits_ && transformer->state.kvbits == kvbits_) \ 746 | return forward<dtype, kvtype, atype>(transformer, token, pos, flags) 747 | 748 | CASE(4, uint32_t, 8, __nv_fp8_e5m2, float); 749 | CASE(4, uint32_t, 16, __half, float); 750 | CASE(8, __nv_fp8_e5m2, 8, __nv_fp8_e5m2, float); 751 | CASE(8, __nv_fp8_e5m2, 16, __half, float); 752 | CASE(16, __half, 8, __nv_fp8_e5m2, float); 753 | CASE(16, __half, 16, __half, float); 754 | 755 | assert(!"Unsupported dbits/kvbits combination for CUDA: dbits must be 4, 8 or 16, kvbits must be 8 or 16"); 756 | return NULL; 757 | 758 | #undef CASE 759 | } 760 | 761 | extern "C" void perf_cuda() { 762 | if (coopperf == NULL || coopruns == 0) 763 | return; 764 | 765 | uint64_t hostperf[16] = {}; 766 | CUDA_CHECK(cudaMemcpy(hostperf, coopperf, sizeof(hostperf), cudaMemcpyDeviceToHost)); 767 | 768 | static const char* stagenames[16] = { 769 | "matmul_qkv", 770 | "attn_score", 771 | "attn_softmax", 772 | "attn_mix", 773 | "matmul_attn", 774 | "matmul_ffn_up", 775 | "matmul_ffn_down", 776 | }; 777 | 778 | double freq = 1e9; 779 | 780 | uint64_t total = 0; 781 | for (int stage = 0; stage < 16; ++stage) { 782 | total += hostperf[stage]; 783 | } 784 | 785 | printf("\nkernel_forward breakdown (over %d runs, avg %.1f usec/run):\n", 786 | coopruns, (double)total / (double)coopruns / freq * 1e6); 787 | 788 | for (int stage = 0; stage < 16; ++stage) { 789 | if (hostperf[stage] == 0) 790 | continue; 791 | 792 | uint64_t t = hostperf[stage]; 793 | uint64_t tbw = coopperfbw[stage]; 794 | 795 | printf("\t[%d] %16s: %4.1f%%; %8.1f usec/run, %6.1f GB/s\n", 796 | stage, stagenames[stage], 797 | (double)t / (double)total * 100, 798 | (double)(t / coopruns) / freq * 1e6, 799 | ((double)tbw / 1e9) / ((double)t / freq)); 800 | } 801 | } 802 | -------------------------------------------------------------------------------- /tools/convert.py: 
-------------------------------------------------------------------------------- 1 | # Produce a safetensors model file out of multiple inputs 2 | # python convert.py model.safetensors --config config.json --models file1.bin file2.bin ... 3 | 4 | import argparse 5 | import base64 6 | import json 7 | import os.path 8 | import safetensors 9 | import safetensors.torch 10 | import torch 11 | # optionally imports sentencepiece below when converting models without HF tokenizer.json 12 | 13 | argp = argparse.ArgumentParser() 14 | argp.add_argument("output", type=str) 15 | argp.add_argument("input", type=str, nargs="?") 16 | argp.add_argument("--config", type=str) 17 | argp.add_argument("--tokenizer", type=str) 18 | argp.add_argument("--models", type=str, nargs="+") 19 | argp.add_argument("--dtype", type=str, default="fp8", choices=["fp16", "fp8", "gf4"]) 20 | args = argp.parse_args() 21 | 22 | if args.input is not None: 23 | # assume input is a directory with HuggingFace layout 24 | if args.config is None: 25 | args.config = os.path.join(args.input, "config.json") 26 | if not os.path.exists(args.config): 27 | argp.error("no config.json found in {}".format(args.input)) 28 | if args.tokenizer is None: 29 | args.tokenizer = os.path.join(args.input, "tokenizer.json") 30 | if not os.path.exists(args.tokenizer): 31 | args.tokenizer = os.path.join(args.input, "tokenizer.model") 32 | if not os.path.exists(args.tokenizer): 33 | argp.error("no tokenizer.json or tokenizer.model found in {}".format(args.input)) 34 | if args.models is None: 35 | files = os.listdir(args.input) 36 | args.models = [os.path.join(args.input, fn) for fn in files if os.path.splitext(fn)[1] == ".safetensors"] 37 | if len(args.models) == 0: 38 | args.models = [os.path.join(args.input, fn) for fn in files if os.path.splitext(fn)[1] == ".bin"] 39 | if len(args.models) == 0: 40 | argp.error("no .safetensors or .bin files found in {}".format(args.input)) 41 | elif args.config is None or args.models is None: 42 | argp.error("arguments --config, --tokenizer and --models are required unless argument input is specified") 43 | 44 | with open(args.config, "r") as f: 45 | config = json.load(f) 46 | 47 | metadata = {} 48 | tensors = {} 49 | 50 | arch = config["architectures"][0] 51 | arch_remap = {"LlamaForCausalLM": "llama", "MistralForCausalLM": "mistral", "MixtralForCausalLM": "mixtral", "Qwen2ForCausalLM": "qwen2", "OLMoForCausalLM": "olmo", "GemmaForCausalLM": "gemma", "MiniCPMForCausalLM": "minicpm", "CohereForCausalLM": "cohere", "InternLM2ForCausalLM": "internlm2", "DbrxForCausalLM": "dbrx", "XverseForCausalLM": "xverse", "Phi3ForCausalLM": "phi3", "OlmoeForCausalLM": "olmoe"} 52 | assert arch in arch_remap, "Unsupported architecture: {}; must be one of: {}".format(arch, list(arch_remap.keys())) 53 | arch = arch_remap[arch] 54 | 55 | metadata["arch"] = arch 56 | metadata["dtype"] = args.dtype 57 | 58 | if arch in ["llama", "mistral", "mixtral", "qwen2", "gemma", "minicpm", "cohere", "internlm2", "xverse", "phi3", "olmoe"]: 59 | metadata["dim"] = config["hidden_size"] 60 | metadata["hidden_dim"] = config["intermediate_size"] 61 | metadata["head_dim"] = config.get("head_dim", config["hidden_size"] // config["num_attention_heads"]) 62 | metadata["n_layers"] = config["num_hidden_layers"] 63 | metadata["n_heads"] = config["num_attention_heads"] 64 | metadata["n_kv_heads"] = config.get("num_key_value_heads", config["num_attention_heads"]) 65 | metadata["vocab_size"] = config["vocab_size"] 66 | metadata["max_seq_len"] = 2048 if arch == "phi3" 
else config["max_position_embeddings"] 67 | metadata["bos_token_id"] = -1 if arch in ["qwen2", "olmoe"] else config["bos_token_id"] 68 | metadata["eos_token_id"] = config["eos_token_id"] 69 | metadata["rope_theta"] = config.get("rope_theta", 10000.0) 70 | metadata["rotary_dim"] = int(metadata["head_dim"] * config.get("partial_rotary_factor", 1)) 71 | metadata["norm_eps"] = config["layer_norm_eps"] if arch == "cohere" else config["rms_norm_eps"] 72 | metadata["norm_type"] = "layernorm_par" if arch == "cohere" else "rmsnorm" 73 | 74 | assert config["hidden_act"] in ["gelu", "silu"] 75 | metadata["act_type"] = config["hidden_act"] 76 | 77 | # moe 78 | if arch in ["mixtral"]: 79 | metadata["n_experts"] = config["num_local_experts"] 80 | metadata["n_experts_active"] = config["num_experts_per_tok"] 81 | elif arch in ["minicpm"] and "num_experts" in config: 82 | metadata["n_experts"] = config["num_experts"] 83 | metadata["n_experts_active"] = config["num_experts_per_tok"] 84 | elif arch in ["olmoe"]: 85 | metadata["n_experts"] = config["num_experts"] 86 | metadata["n_experts_active"] = config["num_experts_per_tok"] 87 | elif arch == "olmo": 88 | metadata["dim"] = config["d_model"] 89 | metadata["hidden_dim"] = (config["mlp_hidden_size"] or config["d_model"] * config["mlp_ratio"]) // 2 90 | metadata["n_layers"] = config["n_layers"] 91 | metadata["n_heads"] = config["n_heads"] 92 | metadata["n_kv_heads"] = config["n_heads"] 93 | metadata["vocab_size"] = config["embedding_size"] 94 | metadata["max_seq_len"] = config["max_sequence_length"] 95 | metadata["bos_token_id"] = -1 96 | metadata["eos_token_id"] = config["eos_token_id"] 97 | metadata["rope_theta"] = 10000.0 98 | metadata["rotary_dim"] = config["d_model"] // config["n_heads"] 99 | metadata["norm_eps"] = 1e-5 100 | metadata["norm_type"] = "layernorm" 101 | 102 | assert config["activation_type"] == "swiglu" 103 | metadata["act_type"] = "silu" 104 | 105 | if config.get("clip_qkv", None): 106 | metadata["qkv_clip"] = config["clip_qkv"] 107 | elif arch == "dbrx": 108 | metadata["dim"] = config["d_model"] 109 | metadata["hidden_dim"] = config["ffn_config"]["ffn_hidden_size"] 110 | metadata["head_dim"] = config["d_model"] // config["n_heads"] 111 | metadata["n_layers"] = config["n_layers"] 112 | metadata["n_heads"] = config["n_heads"] 113 | metadata["n_kv_heads"] = config["attn_config"]["kv_n_heads"] 114 | metadata["vocab_size"] = config["vocab_size"] 115 | metadata["max_seq_len"] = config["max_seq_len"] 116 | metadata["bos_token_id"] = -1 117 | metadata["eos_token_id"] = 100257 118 | metadata["rope_theta"] = config["attn_config"]["rope_theta"] 119 | metadata["rotary_dim"] = config["d_model"] // config["n_heads"] 120 | metadata["norm_eps"] = 1e-5 121 | metadata["norm_type"] = "layernorm" 122 | metadata["act_type"] = "silu" 123 | metadata["n_experts"] = config["ffn_config"]["moe_num_experts"] 124 | metadata["n_experts_active"] = config["ffn_config"]["moe_top_k"] 125 | metadata["qkv_clip"] = config["attn_config"]["clip_qkv"] 126 | 127 | # this is a horrible gpt-2 unicode byte encoder hack from https://github.com/openai/gpt-2/blob/master/src/encoder.py#L9 128 | # this has poisoned all HF tokenizer configs that use ByteLevel decoder/preprocessor 129 | # as a result we get crazy UTF-8-as-bytes-as-UTF8 in the tokenizer data that we need to convert back 130 | def gpt2_bytes_to_unicode(): 131 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 132 | cs = bs[:] 133 | n = 0 134 | for b in range(2**8): 
135 | if b not in bs: 136 | bs.append(b) 137 | cs.append(2**8+n) 138 | n += 1 139 | cs = [chr(n) for n in cs] 140 | return dict(zip(bs, cs)) 141 | 142 | # load tokenizer model 143 | tokens = [""] * metadata["vocab_size"] 144 | scores = [0] * metadata["vocab_size"] 145 | tokens_gpt2 = False 146 | 147 | ext = os.path.splitext(args.tokenizer)[1] 148 | if ext == ".json": 149 | with open(args.tokenizer, "r") as f: 150 | tokenizer = json.load(f) 151 | 152 | vocab = tokenizer["model"]["vocab"] 153 | assert len(vocab) <= config["vocab_size"] 154 | 155 | tokens_gpt2 = not tokenizer["model"].get("byte_fallback", False) 156 | 157 | for t, i in vocab.items(): 158 | tokens[i] = t 159 | 160 | for added in tokenizer["added_tokens"]: 161 | tokens[added["id"]] = added["content"] 162 | 163 | # compute score as negative merge index so that earlier merges get selected first 164 | for i, m in enumerate(tokenizer["model"]["merges"]): 165 | t1, t2 = (m[0], m[1]) if isinstance(m, list) else m.split(" ", 2) 166 | ti = vocab[t1 + t2] 167 | if scores[ti] == 0: 168 | scores[ti] = -(1 + i) 169 | elif ext == ".model": 170 | import sentencepiece 171 | sp_model = sentencepiece.SentencePieceProcessor(model_file=args.tokenizer) 172 | assert sp_model.vocab_size() <= config["vocab_size"] 173 | assert sp_model.bos_id() == config["bos_token_id"] 174 | assert sp_model.eos_id() == config["eos_token_id"] 175 | 176 | for i in range(sp_model.vocab_size()): 177 | tokens[i] = sp_model.id_to_piece(i) 178 | scores[i] = sp_model.get_score(i) 179 | elif ext == ".tiktoken": 180 | with open(args.tokenizer, "r") as f: 181 | vocab = f.readlines() 182 | assert len(vocab) <= config["vocab_size"] 183 | 184 | for i, l in enumerate(vocab): 185 | t, r = l.rstrip().split(" ") 186 | t = base64.b64decode(t) 187 | tokens[i] = t.decode("utf-8", errors="replace").replace("\0", "\7") 188 | scores[i] = -int(r) 189 | else: 190 | raise Exception("Unknown tokenizer file extension: {}; expected .json or .model/.tiktoken".format(ext)) 191 | 192 | # postprocess tokens 193 | gpt2_decode = {v: k for k, v in gpt2_bytes_to_unicode().items()} 194 | 195 | for i, t in enumerate(tokens): 196 | if tokens_gpt2: 197 | b = bytes([gpt2_decode.get(c, 0) for c in t]) 198 | else: 199 | t = t.replace('\u2581', ' ') # sentencepiece uses this character as whitespace 200 | b = t.encode('utf-8') 201 | 202 | b = b.replace(b"\0", b"\7") # replace null bytes with bell characters 203 | assert b.count(0) == 0 # no null bytes allowed 204 | 205 | tokens[i] = b 206 | 207 | # load model files 208 | weights = {} 209 | for fn in args.models: 210 | ext = os.path.splitext(fn)[1] 211 | if ext == ".safetensors": 212 | with safetensors.safe_open(fn, framework="pt") as f: 213 | for k in f.keys(): 214 | assert(k not in weights) 215 | weights[k] = f.get_tensor(k) 216 | elif ext == ".bin": 217 | pth = torch.load(fn, map_location="cpu", weights_only=True) 218 | for k in pth.keys(): 219 | assert(k not in weights) 220 | weights[k] = pth[k] 221 | else: 222 | raise Exception("Unknown model file extension: {}; expected .safetensors or .bin".format(ext)) 223 | 224 | # huggingface permutes WQ and WK, this function reverses it 225 | # see https://github.com/huggingface/transformers/blob/b132c1703eb1c8bd9dfa4ad6a9be2bfd6ef819e9/src/transformers/models/llama/convert_llama_weights_to_hf.py#L122 226 | def permute_reverse(w, heads, rotary_dim): 227 | head_dim = w.shape[0] // heads 228 | assert rotary_dim <= head_dim 229 | w = torch.unflatten(w, 0, (-1, head_dim)) 230 | # wr is the rotary part, wk is the part kept 
unrotated 231 | wr = w[:, :rotary_dim] 232 | wk = w[:, rotary_dim:] 233 | # switch wr from outputting two rotary_dim/2 chunks to outputting values interleaved 234 | wr = torch.unflatten(wr, 1, (2, -1)) 235 | wr = wr.transpose(1, 2) 236 | wr = wr.flatten(1, 2) 237 | # assemble the heads back 238 | w = torch.cat([wr, wk], dim=1) 239 | return torch.flatten(w, 0, 1) 240 | 241 | # fp8 support requires torch 2.1, but we support other dtypes on earlier versions 242 | dtype = {"fp16": torch.float16, "fp8": getattr(torch, "float8_e5m2", None), "gf4": torch.uint8}[args.dtype] 243 | assert dtype 244 | 245 | # gf4 quantization: 8 values get quantized to 32 bits, 3-bit normalized int per value + shared fp8 scale factor 246 | # int range is asymmetric; we use this fact to encode the max value as -4 to expand the range a little bit 247 | def gf4(t): 248 | if torch.cuda.is_available(): 249 | t.max() # work around cuda load from mmap using small block size for reading... 250 | t = t.cuda() 251 | # groups of 8 values 252 | gt = t.unflatten(-1, (-1, 8)) 253 | # max (abs) of each group 254 | _, gmaxi = gt.abs().max(-1) 255 | gmax = gt.gather(-1, gmaxi.unsqueeze(-1)) 256 | # round gmax to fp8 to make sure we're quantizing to the right range 257 | gmax = gmax.to(torch.float8_e5m2).to(gmax.dtype) 258 | # normalize gt; note that gmax may be zero 259 | gt /= gmax 260 | torch.nan_to_num(gt, nan=0.0, posinf=0.0, neginf=0.0, out=gt) 261 | # normalize each group by -max ([-1, 1]) and quantize to [0, 8) 262 | # note that 8 needs to be clamped to 7 since positive half of the range is shorter 263 | gtq = (gt.to(torch.float16) * -4 + 4).clamp(0, 7).round().to(torch.int32) 264 | # assemble the results 265 | gtq <<= torch.tensor([8 + i * 3 for i in range(8)], dtype=torch.int32, device=gtq.device) 266 | gtr = gtq.sum(-1, dtype=torch.int32) 267 | gtr += gmax.squeeze(-1).to(torch.float8_e5m2).view(torch.uint8) 268 | return gtr.cpu() 269 | 270 | # preprocess weights 271 | if arch == "minicpm": 272 | # apply various scaling factors that other models don't have to tensors 273 | embed_scale = config["scale_emb"] 274 | resid_scale = config["scale_depth"] / (config["num_hidden_layers"] ** 0.5) 275 | final_scale = config["dim_model_base"] / config["hidden_size"] 276 | 277 | weights["model.norm.weight"] *= final_scale / (1.0 if config.get("tie_word_embeddings", None) == False else embed_scale) 278 | weights["model.embed_tokens.weight"] *= embed_scale 279 | 280 | for l in range(config["num_hidden_layers"]): 281 | weights[f"model.layers.{l}.self_attn.o_proj.weight"] *= resid_scale 282 | 283 | if "num_experts" in config: 284 | for e in range(config["num_experts"]): 285 | weights[f"model.layers.{l}.mlp.experts.{e}.w2.weight"] *= resid_scale 286 | else: 287 | weights[f"model.layers.{l}.mlp.down_proj.weight"] *= resid_scale 288 | elif arch == "gemma": 289 | # gemma's norm weights are stored relative to 1.0 290 | weights["model.norm.weight"] = weights["model.norm.weight"].float() + 1 291 | 292 | for l in range(config["num_hidden_layers"]): 293 | weights[f"model.layers.{l}.input_layernorm.weight"] = weights[f"model.layers.{l}.input_layernorm.weight"].float() + 1 294 | weights[f"model.layers.{l}.post_attention_layernorm.weight"] = weights[f"model.layers.{l}.post_attention_layernorm.weight"].float() + 1 295 | 296 | # apply embedding scale (and counter it since output weights are tied) 297 | # this improves precision for fp8 298 | embed_scale = config["hidden_size"] ** 0.5 299 | 300 | weights["model.norm.weight"] *= 1 / embed_scale 301 | 
weights["model.embed_tokens.weight"] = weights["model.embed_tokens.weight"].float() * embed_scale 302 | elif arch == "cohere": 303 | weights["model.norm.weight"] *= config["logit_scale"] 304 | 305 | # convert weights 306 | progress = 0 307 | def conv(t): 308 | global progress 309 | progress += 1 310 | print(f"\rConverting tensor {progress}: {t.shape}", end="", flush=True) 311 | return gf4(t) if dtype == torch.uint8 else t.to(dtype) 312 | 313 | if arch in ["llama", "mistral", "mixtral", "qwen2", "gemma", "minicpm", "cohere", "xverse", "olmoe"]: 314 | if arch == "olmoe": 315 | print("Warning: Olmoe uses QK norm which we do not support") 316 | 317 | tensors["model.embed.weight"] = conv(weights["model.embed_tokens.weight"]) 318 | 319 | for l in range(config["num_hidden_layers"]): 320 | tensors[f"model.layers.{l}.attn.norm.weight"] = weights[f"model.layers.{l}.input_layernorm.weight"].float() 321 | 322 | rotary_dim = metadata["rotary_dim"] 323 | n_heads = config["num_attention_heads"] 324 | n_kv_heads = config.get("num_key_value_heads", n_heads) 325 | 326 | if arch == "cohere": 327 | tensors[f"model.layers.{l}.attn.wq.weight"] = conv(weights[f"model.layers.{l}.self_attn.q_proj.weight"]) 328 | tensors[f"model.layers.{l}.attn.wk.weight"] = conv(weights[f"model.layers.{l}.self_attn.k_proj.weight"]) 329 | else: 330 | tensors[f"model.layers.{l}.attn.wq.weight"] = conv(permute_reverse(weights[f"model.layers.{l}.self_attn.q_proj.weight"], n_heads, rotary_dim)) 331 | tensors[f"model.layers.{l}.attn.wk.weight"] = conv(permute_reverse(weights[f"model.layers.{l}.self_attn.k_proj.weight"], n_kv_heads, rotary_dim)) 332 | 333 | tensors[f"model.layers.{l}.attn.wv.weight"] = conv(weights[f"model.layers.{l}.self_attn.v_proj.weight"]) 334 | tensors[f"model.layers.{l}.attn.wo.weight"] = conv(weights[f"model.layers.{l}.self_attn.o_proj.weight"]) 335 | 336 | if arch in ["qwen2"]: 337 | tensors[f"model.layers.{l}.attn.wqkv.bias"] = torch.cat([ 338 | permute_reverse(weights[f"model.layers.{l}.self_attn.q_proj.bias"], n_heads, rotary_dim).float(), 339 | permute_reverse(weights[f"model.layers.{l}.self_attn.k_proj.bias"], n_kv_heads, rotary_dim).float(), 340 | weights[f"model.layers.{l}.self_attn.v_proj.bias"].float() 341 | ]) 342 | 343 | if arch != "cohere": 344 | tensors[f"model.layers.{l}.mlp.norm.weight"] = weights[f"model.layers.{l}.post_attention_layernorm.weight"].float() 345 | 346 | if arch in ["mixtral"]: 347 | tensors[f"model.layers.{l}.moegate.weight"] = conv(weights[f"model.layers.{l}.block_sparse_moe.gate.weight"]) 348 | 349 | tensors[f"model.layers.{l}.mlp.w1.weight"] = torch.stack([conv(weights[f"model.layers.{l}.block_sparse_moe.experts.{e}.w1.weight"]) for e in range(config["num_local_experts"])]) 350 | tensors[f"model.layers.{l}.mlp.w2.weight"] = torch.stack([conv(weights[f"model.layers.{l}.block_sparse_moe.experts.{e}.w2.weight"]) for e in range(config["num_local_experts"])]) 351 | tensors[f"model.layers.{l}.mlp.w3.weight"] = torch.stack([conv(weights[f"model.layers.{l}.block_sparse_moe.experts.{e}.w3.weight"]) for e in range(config["num_local_experts"])]) 352 | elif arch in ["minicpm"] and "num_experts" in config: 353 | tensors[f"model.layers.{l}.moegate.weight"] = conv(weights[f"model.layers.{l}.mlp.gate.weight"]) 354 | 355 | tensors[f"model.layers.{l}.mlp.w1.weight"] = torch.stack([conv(weights[f"model.layers.{l}.mlp.experts.{e}.w1.weight"]) for e in range(config["num_experts"])]) 356 | tensors[f"model.layers.{l}.mlp.w2.weight"] = 
torch.stack([conv(weights[f"model.layers.{l}.mlp.experts.{e}.w2.weight"]) for e in range(config["num_experts"])]) 357 | tensors[f"model.layers.{l}.mlp.w3.weight"] = torch.stack([conv(weights[f"model.layers.{l}.mlp.experts.{e}.w3.weight"]) for e in range(config["num_experts"])]) 358 | elif arch in ["olmoe"]: 359 | tensors[f"model.layers.{l}.moegate.weight"] = conv(weights[f"model.layers.{l}.mlp.gate.weight"]) 360 | 361 | tensors[f"model.layers.{l}.mlp.w1.weight"] = torch.stack([conv(weights[f"model.layers.{l}.mlp.experts.{e}.gate_proj.weight"]) for e in range(config["num_experts"])]) 362 | tensors[f"model.layers.{l}.mlp.w2.weight"] = torch.stack([conv(weights[f"model.layers.{l}.mlp.experts.{e}.down_proj.weight"]) for e in range(config["num_experts"])]) 363 | tensors[f"model.layers.{l}.mlp.w3.weight"] = torch.stack([conv(weights[f"model.layers.{l}.mlp.experts.{e}.up_proj.weight"]) for e in range(config["num_experts"])]) 364 | else: 365 | tensors[f"model.layers.{l}.mlp.w1.weight"] = conv(weights[f"model.layers.{l}.mlp.gate_proj.weight"]) 366 | tensors[f"model.layers.{l}.mlp.w2.weight"] = conv(weights[f"model.layers.{l}.mlp.down_proj.weight"]) 367 | tensors[f"model.layers.{l}.mlp.w3.weight"] = conv(weights[f"model.layers.{l}.mlp.up_proj.weight"]) 368 | 369 | tensors["model.norm.weight"] = weights["model.norm.weight"].float() 370 | if config.get("tie_word_embeddings", None) != True: 371 | tensors["model.output.weight"] = conv(weights["lm_head.weight"]) 372 | elif arch == "internlm2": 373 | tensors["model.embed.weight"] = conv(weights["model.tok_embeddings.weight"]) 374 | 375 | for l in range(config["num_hidden_layers"]): 376 | tensors[f"model.layers.{l}.attn.norm.weight"] = weights[f"model.layers.{l}.attention_norm.weight"].float() 377 | 378 | head_dim = metadata["head_dim"] 379 | n_heads = config["num_attention_heads"] 380 | n_kv_heads = config.get("num_key_value_heads", n_heads) 381 | kv_mul = n_heads // n_kv_heads 382 | 383 | wqkv = weights[f"model.layers.{l}.attention.wqkv.weight"] 384 | wqkv = wqkv.unflatten(0, (n_kv_heads, kv_mul + 2, head_dim)) 385 | 386 | tensors[f"model.layers.{l}.attn.wq.weight"] = conv(permute_reverse(wqkv[:, :kv_mul].flatten(0, 2), n_heads, head_dim)) 387 | tensors[f"model.layers.{l}.attn.wk.weight"] = conv(permute_reverse(wqkv[:, kv_mul].flatten(0, 1), n_kv_heads, head_dim)) 388 | 389 | tensors[f"model.layers.{l}.attn.wv.weight"] = conv(wqkv[:, kv_mul+1].flatten(0, 1)) 390 | tensors[f"model.layers.{l}.attn.wo.weight"] = conv(weights[f"model.layers.{l}.attention.wo.weight"]) 391 | 392 | tensors[f"model.layers.{l}.mlp.norm.weight"] = weights[f"model.layers.{l}.ffn_norm.weight"].float() 393 | 394 | tensors[f"model.layers.{l}.mlp.w1.weight"] = conv(weights[f"model.layers.{l}.feed_forward.w1.weight"]) 395 | tensors[f"model.layers.{l}.mlp.w2.weight"] = conv(weights[f"model.layers.{l}.feed_forward.w2.weight"]) 396 | tensors[f"model.layers.{l}.mlp.w3.weight"] = conv(weights[f"model.layers.{l}.feed_forward.w3.weight"]) 397 | 398 | tensors["model.norm.weight"] = weights["model.norm.weight"].float() 399 | tensors["model.output.weight"] = conv(weights["output.weight"]) 400 | elif arch == "olmo": 401 | tensors["model.embed.weight"] = conv(weights["model.transformer.wte.weight"]) 402 | 403 | for l in range(config["n_layers"]): 404 | tensors[f"model.layers.{l}.attn.norm.weight"] = torch.ones(config["d_model"], dtype=torch.float32) 405 | 406 | dim = config["d_model"] 407 | head_dim = dim // config["n_heads"] 408 | hidden_dim = (config["mlp_hidden_size"] or config["d_model"] * 
config["mlp_ratio"]) // 2 409 | 410 | attn_proj = weights[f"model.transformer.blocks.{l}.att_proj.weight"] 411 | assert attn_proj.shape == (dim * 3, dim) 412 | 413 | tensors[f"model.layers.{l}.attn.wq.weight"] = conv(permute_reverse(attn_proj[:dim], config["n_heads"], head_dim)) 414 | tensors[f"model.layers.{l}.attn.wk.weight"] = conv(permute_reverse(attn_proj[dim:dim*2], config["n_heads"], head_dim)) 415 | tensors[f"model.layers.{l}.attn.wv.weight"] = conv(attn_proj[dim*2:]) 416 | tensors[f"model.layers.{l}.attn.wo.weight"] = conv(weights[f"model.transformer.blocks.{l}.attn_out.weight"]) 417 | 418 | tensors[f"model.layers.{l}.mlp.norm.weight"] = torch.ones(config["d_model"], dtype=torch.float32) 419 | 420 | mlp_proj = weights[f"model.transformer.blocks.{l}.ff_proj.weight"] 421 | assert mlp_proj.shape == (hidden_dim * 2, dim) 422 | 423 | tensors[f"model.layers.{l}.mlp.w1.weight"] = conv(mlp_proj[hidden_dim:]) 424 | tensors[f"model.layers.{l}.mlp.w2.weight"] = conv(weights[f"model.transformer.blocks.{l}.ff_out.weight"]) 425 | tensors[f"model.layers.{l}.mlp.w3.weight"] = conv(mlp_proj[:hidden_dim]) 426 | 427 | tensors["model.norm.weight"] = torch.ones(config["d_model"], dtype=torch.float32) 428 | if not config["weight_tying"]: 429 | tensors["model.output.weight"] = conv(weights["model.transformer.ff_out.weight"]) 430 | elif arch == "dbrx": 431 | tensors["model.embed.weight"] = conv(weights["transformer.wte.weight"]) 432 | 433 | for l in range(config["n_layers"]): 434 | tensors[f"model.layers.{l}.attn.norm.weight"] = weights[f"transformer.blocks.{l}.norm_attn_norm.norm_1.weight"].float() 435 | 436 | head_dim = config["d_model"] // config["n_heads"] 437 | n_heads = config["n_heads"] 438 | n_kv_heads = config["attn_config"]["kv_n_heads"] 439 | 440 | dim = config["d_model"] 441 | hidden_dim = config["ffn_config"]["ffn_hidden_size"] 442 | n_experts = config["ffn_config"]["moe_num_experts"] 443 | 444 | wqkv = weights[f"transformer.blocks.{l}.norm_attn_norm.attn.Wqkv.weight"] 445 | 446 | tensors[f"model.layers.{l}.attn.wq.weight"] = conv(permute_reverse(wqkv[:n_heads*head_dim], n_heads, head_dim)) 447 | tensors[f"model.layers.{l}.attn.wk.weight"] = conv(permute_reverse(wqkv[n_heads*head_dim:(n_heads+n_kv_heads)*head_dim], n_kv_heads, head_dim)) 448 | tensors[f"model.layers.{l}.attn.wv.weight"] = conv(wqkv[(n_heads+n_kv_heads)*head_dim:]) 449 | tensors[f"model.layers.{l}.attn.wo.weight"] = conv(weights[f"transformer.blocks.{l}.norm_attn_norm.attn.out_proj.weight"]) 450 | 451 | tensors[f"model.layers.{l}.mlp.norm.weight"] = weights[f"transformer.blocks.{l}.norm_attn_norm.norm_2.weight"].float() 452 | 453 | tensors[f"model.layers.{l}.moegate.weight"] = conv(weights[f"transformer.blocks.{l}.ffn.router.layer.weight"]) 454 | 455 | tensors[f"model.layers.{l}.mlp.w1.weight"] = conv(weights[f"transformer.blocks.{l}.ffn.experts.mlp.w1"].view(n_experts, hidden_dim, dim)) 456 | tensors[f"model.layers.{l}.mlp.w2.weight"] = conv(weights[f"transformer.blocks.{l}.ffn.experts.mlp.w2"].view(n_experts, hidden_dim, dim).transpose(1, 2).contiguous()) 457 | tensors[f"model.layers.{l}.mlp.w3.weight"] = conv(weights[f"transformer.blocks.{l}.ffn.experts.mlp.v1"].view(n_experts, hidden_dim, dim)) 458 | 459 | tensors["model.norm.weight"] = weights["transformer.norm_f.weight"].float() 460 | tensors["model.output.weight"] = conv(weights["lm_head.weight"]) 461 | elif arch == "phi3": 462 | tensors["model.embed.weight"] = conv(weights["model.embed_tokens.weight"]) 463 | 464 | for l in range(config["num_hidden_layers"]): 465 | 
tensors[f"model.layers.{l}.attn.norm.weight"] = weights[f"model.layers.{l}.input_layernorm.weight"].float() 466 | 467 | head_dim = config["hidden_size"] // config["num_attention_heads"] 468 | n_heads = config["num_attention_heads"] 469 | n_kv_heads = config.get("num_key_value_heads", n_heads) 470 | 471 | wqkv = weights[f"model.layers.{l}.self_attn.qkv_proj.weight"] 472 | 473 | tensors[f"model.layers.{l}.attn.wq.weight"] = conv(permute_reverse(wqkv[:n_heads*head_dim], n_heads, head_dim)) 474 | tensors[f"model.layers.{l}.attn.wk.weight"] = conv(permute_reverse(wqkv[n_heads*head_dim:(n_heads+n_kv_heads)*head_dim], n_kv_heads, head_dim)) 475 | 476 | tensors[f"model.layers.{l}.attn.wv.weight"] = conv(wqkv[(n_heads+n_kv_heads)*head_dim:]) 477 | tensors[f"model.layers.{l}.attn.wo.weight"] = conv(weights[f"model.layers.{l}.self_attn.o_proj.weight"]) 478 | 479 | tensors[f"model.layers.{l}.mlp.norm.weight"] = weights[f"model.layers.{l}.post_attention_layernorm.weight"].float() 480 | 481 | hidden_dim = config["intermediate_size"] 482 | 483 | mlp_proj = weights[f"model.layers.{l}.mlp.gate_up_proj.weight"] 484 | 485 | tensors[f"model.layers.{l}.mlp.w1.weight"] = conv(mlp_proj[:hidden_dim]) 486 | tensors[f"model.layers.{l}.mlp.w2.weight"] = conv(weights[f"model.layers.{l}.mlp.down_proj.weight"]) 487 | tensors[f"model.layers.{l}.mlp.w3.weight"] = conv(mlp_proj[hidden_dim:]) 488 | 489 | tensors["model.norm.weight"] = weights["model.norm.weight"].float() 490 | tensors["model.output.weight"] = conv(weights["lm_head.weight"]) 491 | 492 | # add tokenizer tensors at the end (to maximize the chance of model tensor alignment) 493 | # note: we concatenate all bytes of all tokens into a single tensor 494 | tensors["tokenizer.tokens"] = torch.cat([torch.tensor([x for x in b] + [0], dtype=torch.uint8) for b in tokens]) 495 | tensors["tokenizer.scores"] = torch.tensor(scores, dtype=torch.float32) 496 | 497 | print(f"\rSaving {len(tensors)} tensors..." + " " * 40) 498 | 499 | # in a perfect world, we would just use HF safetensors.torch.save_file 500 | # however, not only does it not support fp8 (https://github.com/huggingface/safetensors/pull/404), it also copies every tensor 501 | # our models are large, so we'll implement a custom save function. could even materialize converted tensors lazily later. 
502 | def save_file(tensors, filename, metadata=None): 503 | _TYPES = { 504 | torch.float32: "F32", 505 | torch.float16: "F16", 506 | torch.bfloat16: "BF16", 507 | getattr(torch, "float8_e5m2", None): "F8_E5M2", 508 | getattr(torch, "float8_e4m3fn", None): "F8_E4M3", 509 | torch.int32: "I32", 510 | torch.int16: "I16", 511 | torch.int8: "I8", 512 | torch.uint8: "U8", 513 | } 514 | _ALIGN = 256 515 | 516 | header = {} 517 | offset = 0 518 | if metadata: 519 | header["__metadata__"] = metadata 520 | for k, v in tensors.items(): 521 | size = v.numel() * v.element_size() 522 | header[k] = { "dtype": _TYPES[v.dtype], "shape": v.shape, "data_offsets": [offset, offset + size] } 523 | offset += size 524 | 525 | hjson = json.dumps(header).encode("utf-8") 526 | hjson += b" " * (-(len(hjson) + 8) % _ALIGN) 527 | 528 | with open(filename, "wb") as f: 529 | f.write(len(hjson).to_bytes(8, byteorder="little")) 530 | f.write(hjson) 531 | for k, v in tensors.items(): 532 | assert v.layout == torch.strided and v.is_contiguous() 533 | v.view(torch.uint8).numpy().tofile(f) 534 | 535 | # metadata values must be strings in safetensors 536 | save_file(tensors, args.output, {k: str(v) for k, v in metadata.items()}) 537 | --------------------------------------------------------------------------------