├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── condensed
│   └── condensed.js
├── globals.js
├── index.html
├── instructions.js
├── model.js
├── other
│   ├── conversion_scripts
│   │   ├── README.md
│   │   ├── ckpt.pt
│   │   ├── convert_checkpoint_pytorch.py
│   │   ├── convert_pretrained_pytorch.py
│   │   └── sample_shakespeare_ckpt.pt
│   ├── int8-gemm.js
│   ├── misc
│   │   ├── files.png
│   │   └── header.png
│   ├── scratchpad.js
│   ├── test.js
│   └── validation
│       ├── README.md
│       ├── test
│       │   ├── gpt2medium_validation.json
│       │   └── shakepeare_validation.json
│       └── validation.js
├── tokenizer.js
├── visuals.js
└── weights
    ├── better_shakespeare
    │   ├── lm_head.weight_gpt.bin
    │   ├── params_gpt.json
    │   ├── transformer.h.0.attn.c_attn.bias_gpt.bin
    │   ├── transformer.h.0.attn.c_attn.weight_gpt.bin
    │   ├── transformer.h.0.attn.c_proj.bias_gpt.bin
    │   ├── transformer.h.0.attn.c_proj.weight_gpt.bin
    │   ├── transformer.h.0.ln_1.bias_gpt.bin
    │   ├── transformer.h.0.ln_1.weight_gpt.bin
    │   ├── transformer.h.0.ln_2.bias_gpt.bin
    │   ├── transformer.h.0.ln_2.weight_gpt.bin
    │   ├── transformer.h.0.mlp.c_fc.bias_gpt.bin
    │   ├── transformer.h.0.mlp.c_fc.weight_gpt.bin
    │   ├── transformer.h.0.mlp.c_proj.bias_gpt.bin
    │   ├── transformer.h.0.mlp.c_proj.weight_gpt.bin
    │   ├── transformer.h.1.attn.c_attn.bias_gpt.bin
    │   ├── transformer.h.1.attn.c_attn.weight_gpt.bin
    │   ├── transformer.h.1.attn.c_proj.bias_gpt.bin
    │   ├── transformer.h.1.attn.c_proj.weight_gpt.bin
    │   ├── transformer.h.1.ln_1.bias_gpt.bin
    │   ├── transformer.h.1.ln_1.weight_gpt.bin
    │   ├── transformer.h.1.ln_2.bias_gpt.bin
    │   ├── transformer.h.1.ln_2.weight_gpt.bin
    │   ├── transformer.h.1.mlp.c_fc.bias_gpt.bin
    │   ├── transformer.h.1.mlp.c_fc.weight_gpt.bin
    │   ├── transformer.h.1.mlp.c_proj.bias_gpt.bin
    │   ├── transformer.h.1.mlp.c_proj.weight_gpt.bin
    │   ├── transformer.h.2.attn.c_attn.bias_gpt.bin
    │   ├── transformer.h.2.attn.c_attn.weight_gpt.bin
    │   ├── transformer.h.2.attn.c_proj.bias_gpt.bin
    │   ├── transformer.h.2.attn.c_proj.weight_gpt.bin
    │   ├── transformer.h.2.ln_1.bias_gpt.bin
    │   ├── transformer.h.2.ln_1.weight_gpt.bin
    │   ├── transformer.h.2.ln_2.bias_gpt.bin
    │   ├── transformer.h.2.ln_2.weight_gpt.bin
    │   ├── transformer.h.2.mlp.c_fc.bias_gpt.bin
    │   ├── transformer.h.2.mlp.c_fc.weight_gpt.bin
    │   ├── transformer.h.2.mlp.c_proj.bias_gpt.bin
    │   ├── transformer.h.2.mlp.c_proj.weight_gpt.bin
    │   ├── transformer.h.3.attn.c_attn.bias_gpt.bin
    │   ├── transformer.h.3.attn.c_attn.weight_gpt.bin
    │   ├── transformer.h.3.attn.c_proj.bias_gpt.bin
    │   ├── transformer.h.3.attn.c_proj.weight_gpt.bin
    │   ├── transformer.h.3.ln_1.bias_gpt.bin
    │   ├── transformer.h.3.ln_1.weight_gpt.bin
    │   ├── transformer.h.3.ln_2.bias_gpt.bin
    │   ├── transformer.h.3.ln_2.weight_gpt.bin
    │   ├── transformer.h.3.mlp.c_fc.bias_gpt.bin
    │   ├── transformer.h.3.mlp.c_fc.weight_gpt.bin
    │   ├── transformer.h.3.mlp.c_proj.bias_gpt.bin
    │   ├── transformer.h.3.mlp.c_proj.weight_gpt.bin
    │   ├── transformer.ln_f.bias_gpt.bin
    │   ├── transformer.ln_f.weight_gpt.bin
    │   ├── transformer.wpe.weight_gpt.bin
    │   └── transformer.wte.weight_gpt.bin
    ├── gpt2
    │   ├── lm_head.weight_gpt.bin
    │   ├── params_gpt.json
    │   ├── transformer.h.0.attn.bias_gpt.bin
    │   ├── transformer.h.0.attn.c_attn.bias_gpt.bin
    │   ├── transformer.h.0.attn.c_attn.weight_gpt.bin
    │   ├── transformer.h.0.attn.c_proj.bias_gpt.bin
    │   ├── transformer.h.0.attn.c_proj.weight_gpt.bin
    │   ├── transformer.h.0.attn.masked_bias_gpt.bin
    │   ├── transformer.h.0.ln_1.bias_gpt.bin
    │   ├── transformer.h.0.ln_1.weight_gpt.bin
    │   ├── transformer.h.0.ln_2.bias_gpt.bin
    │   ├── transformer.h.0.ln_2.weight_gpt.bin
    │   ├── transformer.h.0.mlp.c_fc.bias_gpt.bin
    │   ├── transformer.h.0.mlp.c_fc.weight_gpt.bin
    │   ├── transformer.h.0.mlp.c_proj.bias_gpt.bin
    │   ├── transformer.h.0.mlp.c_proj.weight_gpt.bin
    │   ├── transformer.h.1.attn.bias_gpt.bin
    │   ├── transformer.h.1.attn.c_attn.bias_gpt.bin
    │   ├── transformer.h.1.attn.c_attn.weight_gpt.bin
    │   ├── transformer.h.1.attn.c_proj.bias_gpt.bin
    │   ├── transformer.h.1.attn.c_proj.weight_gpt.bin
    │   ├── transformer.h.1.attn.masked_bias_gpt.bin
    │   ├── transformer.h.1.ln_1.bias_gpt.bin
    │   ├── transformer.h.1.ln_1.weight_gpt.bin
    │   ├── transformer.h.1.ln_2.bias_gpt.bin
    │   ├── transformer.h.1.ln_2.weight_gpt.bin
    │   ├── transformer.h.1.mlp.c_fc.bias_gpt.bin
    │   ├── transformer.h.1.mlp.c_fc.weight_gpt.bin
    │   ├── transformer.h.1.mlp.c_proj.bias_gpt.bin
    │   ├── transformer.h.1.mlp.c_proj.weight_gpt.bin
    │   ├── transformer.h.10.attn.bias_gpt.bin
    │   ├── transformer.h.10.attn.c_attn.bias_gpt.bin
    │   ├── transformer.h.10.attn.c_attn.weight_gpt.bin
    │   ├── transformer.h.10.attn.c_proj.bias_gpt.bin
    │   ├── transformer.h.10.attn.c_proj.weight_gpt.bin
    │   ├── transformer.h.10.attn.masked_bias_gpt.bin
    │   ├── transformer.h.10.ln_1.bias_gpt.bin
    │   ├── transformer.h.10.ln_1.weight_gpt.bin
    │   ├── transformer.h.10.ln_2.bias_gpt.bin
    │   ├── transformer.h.10.ln_2.weight_gpt.bin
    │   ├── transformer.h.10.mlp.c_fc.bias_gpt.bin
    │   ├── transformer.h.10.mlp.c_fc.weight_gpt.bin
    │   ├── transformer.h.10.mlp.c_proj.bias_gpt.bin
    │   ├── transformer.h.10.mlp.c_proj.weight_gpt.bin
    │   ├── transformer.h.11.attn.bias_gpt.bin
    │   ├── transformer.h.11.attn.c_attn.bias_gpt.bin
    │   ├── transformer.h.11.attn.c_attn.weight_gpt.bin
    │   ├── transformer.h.11.attn.c_proj.bias_gpt.bin
    │   ├── transformer.h.11.attn.c_proj.weight_gpt.bin
    │   ├── transformer.h.11.attn.masked_bias_gpt.bin
    │   ├── transformer.h.11.ln_1.bias_gpt.bin
    │   ├── transformer.h.11.ln_1.weight_gpt.bin
    │   ├── transformer.h.11.ln_2.bias_gpt.bin
    │   ├── transformer.h.11.ln_2.weight_gpt.bin
    │   ├── transformer.h.11.mlp.c_fc.bias_gpt.bin
    │   ├── transformer.h.11.mlp.c_fc.weight_gpt.bin
    │   ├── transformer.h.11.mlp.c_proj.bias_gpt.bin
    │   ├── transformer.h.11.mlp.c_proj.weight_gpt.bin
    │   ├── transformer.h.2.attn.bias_gpt.bin
    │   ├── transformer.h.2.attn.c_attn.bias_gpt.bin
    │   ├── transformer.h.2.attn.c_attn.weight_gpt.bin
    │   ├── transformer.h.2.attn.c_proj.bias_gpt.bin
    │   ├── transformer.h.2.attn.c_proj.weight_gpt.bin
    │   ├── transformer.h.2.attn.masked_bias_gpt.bin
    │   ├── transformer.h.2.ln_1.bias_gpt.bin
    │   ├── transformer.h.2.ln_1.weight_gpt.bin
    │   ├── transformer.h.2.ln_2.bias_gpt.bin
    │   ├── transformer.h.2.ln_2.weight_gpt.bin
    │   ├── transformer.h.2.mlp.c_fc.bias_gpt.bin
    │   ├── transformer.h.2.mlp.c_fc.weight_gpt.bin
    │   ├── transformer.h.2.mlp.c_proj.bias_gpt.bin
    │   ├── transformer.h.2.mlp.c_proj.weight_gpt.bin
    │   ├── transformer.h.3.attn.bias_gpt.bin
    │   ├── transformer.h.3.attn.c_attn.bias_gpt.bin
    │   ├── transformer.h.3.attn.c_attn.weight_gpt.bin
    │   ├── transformer.h.3.attn.c_proj.bias_gpt.bin
    │   ├── transformer.h.3.attn.c_proj.weight_gpt.bin
    │   ├── transformer.h.3.attn.masked_bias_gpt.bin
    │   ├── transformer.h.3.ln_1.bias_gpt.bin
    │   ├── transformer.h.3.ln_1.weight_gpt.bin
    │   ├── transformer.h.3.ln_2.bias_gpt.bin
    │   ├── transformer.h.3.ln_2.weight_gpt.bin
    │   ├── transformer.h.3.mlp.c_fc.bias_gpt.bin
    │   ├── transformer.h.3.mlp.c_fc.weight_gpt.bin
    │   ├── transformer.h.3.mlp.c_proj.bias_gpt.bin
    │   ├── transformer.h.3.mlp.c_proj.weight_gpt.bin
    │   ├── transformer.h.4.attn.bias_gpt.bin
    │   ├── transformer.h.4.attn.c_attn.bias_gpt.bin
    │   ├── transformer.h.4.attn.c_attn.weight_gpt.bin
    │   ├── transformer.h.4.attn.c_proj.bias_gpt.bin
    │   ├── transformer.h.4.attn.c_proj.weight_gpt.bin
    │   ├── transformer.h.4.attn.masked_bias_gpt.bin
    │   ├── transformer.h.4.ln_1.bias_gpt.bin
    │   ├── transformer.h.4.ln_1.weight_gpt.bin
    │   ├── transformer.h.4.ln_2.bias_gpt.bin
    │   ├── transformer.h.4.ln_2.weight_gpt.bin
    │   ├── transformer.h.4.mlp.c_fc.bias_gpt.bin
    │   ├── transformer.h.4.mlp.c_fc.weight_gpt.bin
    │   ├── transformer.h.4.mlp.c_proj.bias_gpt.bin
    │   ├── transformer.h.4.mlp.c_proj.weight_gpt.bin
    │   ├── transformer.h.5.attn.bias_gpt.bin
    │   ├── transformer.h.5.attn.c_attn.bias_gpt.bin
    │   ├── transformer.h.5.attn.c_attn.weight_gpt.bin
    │   ├── transformer.h.5.attn.c_proj.bias_gpt.bin
    │   ├── transformer.h.5.attn.c_proj.weight_gpt.bin
    │   ├── transformer.h.5.attn.masked_bias_gpt.bin
    │   ├── transformer.h.5.ln_1.bias_gpt.bin
    │   ├── transformer.h.5.ln_1.weight_gpt.bin
    │   ├── transformer.h.5.ln_2.bias_gpt.bin
    │   ├── transformer.h.5.ln_2.weight_gpt.bin
    │   ├── transformer.h.5.mlp.c_fc.bias_gpt.bin
    │   ├── transformer.h.5.mlp.c_fc.weight_gpt.bin
    │   ├── transformer.h.5.mlp.c_proj.bias_gpt.bin
    │   ├── transformer.h.5.mlp.c_proj.weight_gpt.bin
    │   ├── transformer.h.6.attn.bias_gpt.bin
    │   ├── transformer.h.6.attn.c_attn.bias_gpt.bin
    │   ├── transformer.h.6.attn.c_attn.weight_gpt.bin
    │   ├── transformer.h.6.attn.c_proj.bias_gpt.bin
    │   ├── transformer.h.6.attn.c_proj.weight_gpt.bin
    │   ├── transformer.h.6.attn.masked_bias_gpt.bin
    │   ├── transformer.h.6.ln_1.bias_gpt.bin
    │   ├── transformer.h.6.ln_1.weight_gpt.bin
    │   ├── transformer.h.6.ln_2.bias_gpt.bin
    │   ├── transformer.h.6.ln_2.weight_gpt.bin
    │   ├── transformer.h.6.mlp.c_fc.bias_gpt.bin
    │   ├── transformer.h.6.mlp.c_fc.weight_gpt.bin
    │   ├── transformer.h.6.mlp.c_proj.bias_gpt.bin
    │   ├── transformer.h.6.mlp.c_proj.weight_gpt.bin
    │   ├── transformer.h.7.attn.bias_gpt.bin
    │   ├── transformer.h.7.attn.c_attn.bias_gpt.bin
    │   ├── transformer.h.7.attn.c_attn.weight_gpt.bin
    │   ├── transformer.h.7.attn.c_proj.bias_gpt.bin
    │   ├── transformer.h.7.attn.c_proj.weight_gpt.bin
    │   ├── transformer.h.7.attn.masked_bias_gpt.bin
    │   ├── transformer.h.7.ln_1.bias_gpt.bin
    │   ├── transformer.h.7.ln_1.weight_gpt.bin
    │   ├── transformer.h.7.ln_2.bias_gpt.bin
    │   ├── transformer.h.7.ln_2.weight_gpt.bin
    │   ├── transformer.h.7.mlp.c_fc.bias_gpt.bin
    │   ├── transformer.h.7.mlp.c_fc.weight_gpt.bin
    │   ├── transformer.h.7.mlp.c_proj.bias_gpt.bin
    │   ├── transformer.h.7.mlp.c_proj.weight_gpt.bin
    │   ├── transformer.h.8.attn.bias_gpt.bin
    │   ├── transformer.h.8.attn.c_attn.bias_gpt.bin
    │   ├── transformer.h.8.attn.c_attn.weight_gpt.bin
    │   ├── transformer.h.8.attn.c_proj.bias_gpt.bin
    │   ├── transformer.h.8.attn.c_proj.weight_gpt.bin
    │   ├── transformer.h.8.attn.masked_bias_gpt.bin
    │   ├── transformer.h.8.ln_1.bias_gpt.bin
    │   ├── transformer.h.8.ln_1.weight_gpt.bin
    │   ├── transformer.h.8.ln_2.bias_gpt.bin
    │   ├── transformer.h.8.ln_2.weight_gpt.bin
    │   ├── transformer.h.8.mlp.c_fc.bias_gpt.bin
    │   ├── transformer.h.8.mlp.c_fc.weight_gpt.bin
    │   ├── transformer.h.8.mlp.c_proj.bias_gpt.bin
    │   ├── transformer.h.8.mlp.c_proj.weight_gpt.bin
    │   ├── transformer.h.9.attn.bias_gpt.bin
    │   ├── transformer.h.9.attn.c_attn.bias_gpt.bin
    │   ├── transformer.h.9.attn.c_attn.weight_gpt.bin
    │   ├── transformer.h.9.attn.c_proj.bias_gpt.bin
    │   ├── transformer.h.9.attn.c_proj.weight_gpt.bin
    │   ├── transformer.h.9.attn.masked_bias_gpt.bin
    │   ├── transformer.h.9.ln_1.bias_gpt.bin
    │   ├── transformer.h.9.ln_1.weight_gpt.bin
    │   ├── transformer.h.9.ln_2.bias_gpt.bin
    │   ├── transformer.h.9.ln_2.weight_gpt.bin
    │   ├── transformer.h.9.mlp.c_fc.bias_gpt.bin
    │   ├── transformer.h.9.mlp.c_fc.weight_gpt.bin
    │   ├── transformer.h.9.mlp.c_proj.bias_gpt.bin
    │   ├── transformer.h.9.mlp.c_proj.weight_gpt.bin
    │   ├── transformer.ln_f.bias_gpt.bin
    │   ├── transformer.ln_f.weight_gpt.bin
    │   ├── transformer.wpe.weight_gpt.bin
    │   └── transformer.wte.weight_gpt.bin
    └── tokenization
        ├── gpt_tokens.json
        ├── simple_tokens.json
        └── vocab.bpe
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.bin filter=lfs diff=lfs merge=lfs -text
2 | *.json filter=lfs diff=lfs merge=lfs -text
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | weights/large-models
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | IFCOOLTELLME License
2 |
3 | Copyright (c) 2023 Will DePue
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | If this software is used for any purpose that is substantially epic, awesome, or
16 | incredible, notice is required to the Author, reachable at will@depue.net.
17 |
18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 | SOFTWARE.
25 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # WebGPT
2 |
3 | ![WebGPT](other/misc/header.png)
4 |
5 | After six years of development, WebGPU is about to launch across most major web browsers. This is massive: web applications now have near-native access to the GPU, with the added capability of compute shaders.
6 |
7 | WebGPT is a vanilla JS and HTML implementation of a transformer model, intended as a proof of concept as well as an educational resource. WebGPT has been tested to work with models up to 500M parameters, though it could likely support far more with further testing/optimization.
8 |
9 | ### Current Stats
10 | 2020 M1 Mac: 3ms/token at 5M parameters with f32 precision.
11 | 2020 M1 Mac: 30ms/token at 117M parameters with f32 precision.
12 | 2020 M1 Mac: 70ms/token at 377M parameters with f32 precision.
13 | 2020 M1 Mac: 120ms/token at 775M parameters with f32 precision.
14 | 1.5B is working but unstable, sitting around 1000ms/token due to inefficiencies.
15 |
16 | ## Running WebGPT
17 |
18 | Running WebGPT is remarkably simple, as it's just a set of HTML + JS files. Since WebGPU is still in the process of being released, you'll need to open it with a compatible browser. WebGPU is currently available in Chrome v113, but the most straightforward way to ensure proper functionality is to install [Chrome Canary](https://www.google.com/chrome/canary/) or Edge Canary.
19 |
20 | I've included two different models: a toy GPT-Shakespeare model (which is severely undertrained, ha) and GPT-2 117M. See main.js for more information on how to run these models. If you want to import custom models, take a look at other/conversion_scripts.
21 |
22 | If you want to try out WebGPT, visit the demo website at [KMeans.org](https://www.kmeans.org). I'd generally recommend cloning the repo and running locally, as loading the weights remotely is significantly slower.
23 | Note: **You'll need to use Git LFS** to download the model files after cloning the repository.
24 |
25 | ![File structure](other/misc/files.png)
26 |
27 | ## Roadmap / Fixing Stupid Decisions
28 |
29 | - [x] Embeddings / de-embeddings on GPU.
30 | - [x] Initializing pipelines on every step is incredibly inefficient.
31 | - [x] Key-value caching.
32 | - [x] Reuse buffers.
33 | - [x] Kernel shared memory for matmul!
34 | - [x] Destroy buffers after use!
35 | - [x] Create kernel instruction classes + optimize pipeline creation.
36 | - [X] Fuse all kernels.
37 | - [X] Optimize all other kernels.
38 | - [X] Compute pass splitting for larger models _(maxStorageBufferBindingSize)_
39 | - [ ] Run selection ops on GPU (topk, selection softmax)
40 | - [ ] The attention kernel is optimized for small models, not for large models where giving each head its own matmul is more efficient.
41 | - [ ] Investigate why attention cache isn't giving proper speed-ups.
42 | - [ ] Make simple instructional version without special stuff.
43 | - [ ] Optimize workgroup sizes, specifically for single row/col operations.
44 | - [ ] Convert into a package.
45 | - [ ] Write better comments + make Youtube explainer.
46 |
47 | ## Acknowledgements
48 |
49 | When I started this project I had no idea how transformers worked or how to implement them (or GPUs or matmul kernels or WebGPU or tokenization, for that matter), so Andrej Karpathy's series on neural networks and building GPT from scratch was invaluable: [Andrej's Youtube](https://www.youtube.com/@AndrejKarpathy). I've also used some code from the nanoGPT repository: [nanoGPT](https://github.com/karpathy/nanoGPT).
50 |
51 | I copied LatitudeGames' implementation of OpenAI's GPT-3 tokenizer in JavaScript: [GPT-3-Encoder](https://github.com/latitudegames/GPT-3-Encoder).
52 |
--------------------------------------------------------------------------------
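
A minimal sketch of driving the `GPT` class that model.js (below) defines, since main.js itself is not reproduced in this dump. The folder name, tokenizer type, and sampling values mirror the defaults in model.js; the page scaffolding around it is assumed:

```js
// Hypothetical driver for the GPT class from model.js (not the repo's actual main.js).
const model = new GPT("gpt2", "bpe"); // loads weights/gpt2/ with the BPE tokenizer
await model.initialize();             // requests the WebGPU device and loads all weight buffers

const prompt = "What is the answer to life, the universe, and everything?\n";
// generate(prompt, max_new_tokens, top_k, temperature) is an async generator
// that yields one decoded token at a time.
for await (const token of model.generate(prompt, 30, 3, 1)) {
  document.body.append(token);
}
```
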
/globals.js:
--------------------------------------------------------------------------------
1 | const FastMatMulBlock = new FastMatMulBlockClass();
2 | const AttentionBlock = new AttentionBlockClass();
3 | const ResidualBlock = new ResidualBlockClass();
4 | const EmbedBlock = new EmbedBlockClass();
5 | const DeEmbedBlock = new DeEmbedBlockClass();
6 | const GeluBlock = new GeluBlockClass();
7 | const LayerNormBlock = new LayerNormBlockClass();
8 | const SoftmaxBlock = new SoftmaxBlockClass();
9 |
10 | // Needed for deletion.
11 | let operations = [FastMatMulBlock, AttentionBlock, ResidualBlock, EmbedBlock, DeEmbedBlock, GeluBlock, LayerNormBlock, SoftmaxBlock];
12 |
13 | function initializeOperations(device) {
14 | for (const operation of operations) operation.initialize(device);
15 | }
16 |
17 | function destroyOperationBuffers() {
18 | for (const operation of operations) operation.destroyBuffers();
19 | }
20 |
21 | function clearOperationCache() {
22 | for (const operation of operations) operation.clearBufferCache();
23 | }
24 |
25 | function destroyOperations() {
26 | for (const operation of operations) operation.destroy();
27 | }
28 |
29 | const bufferUsageDict = {
30 | copy_from: GPUBufferUsage.COPY_SRC,
31 | copy_to: GPUBufferUsage.COPY_DST,
32 | storage: GPUBufferUsage.STORAGE,
33 | uniform: GPUBufferUsage.UNIFORM,
34 | map_read: GPUBufferUsage.MAP_READ,
35 | };
36 |
37 | // ---------------- Helper Functions ----------------
38 |
39 | async function fetchBin(url) {
40 | const response = await fetch(url);
41 | const buffer = await response.arrayBuffer();
42 | return new Float32Array(buffer);
43 | }
44 |
45 | const wgSize = (dim, size) => Math.min(Math.ceil(dim / size), Infinity);
46 |
47 | function sampleFromDistribution(probs) {
48 | const rand = Math.random();
49 | let cumulativeProb = 0;
50 | for (let i = 0; i < probs.length; i++) {
51 | cumulativeProb += probs[i];
52 | if (rand < cumulativeProb) {
53 | return i;
54 | }
55 | }
56 | return probs.length - 1;
57 | }
58 |
59 | function cpuSoftmax(logits, temperature = 1.0) {
60 | const maxLogit = Math.max(...logits);
61 | const expLogits = logits.map((logit) => Math.exp((logit - maxLogit) / temperature));
62 | const sumExpLogits = expLogits.reduce((a, b) => a + b, 0);
63 | return expLogits.map((expLogit) => expLogit / sumExpLogits);
64 | }
65 |
66 | function selectTopK(probs, top_k) {
67 | const sortedIndices = Array.from(probs)
68 | .map((value, index) => ({ value, index }))
69 | .sort((a, b) => b.value - a.value)
70 | .map(({ index }) => index);
71 | const topKIndices = sortedIndices.slice(0, top_k);
72 | const topKProbs = topKIndices.map((index) => probs[index]);
73 | return { topKIndices, topKProbs };
74 | }
75 |
76 | // ----------------------- Matrix Operations -----------------------
77 |
78 | const zeros = (dim) => new Float32Array(dim).fill(0);
79 |
80 | function transpose(array, input_rows, input_cols) {
81 | if (array.length !== input_rows * input_cols) {
82 | console.log(array.length, input_rows, input_cols);
83 | throw new Error("Transpose dims failed");
84 | }
85 |
86 | const transpose = [];
87 | for (let col = 0; col < input_cols; col++) {
88 | for (let row = 0; row < input_rows; row++) {
89 | transpose.push(array[row * input_cols + col]);
90 | }
91 | }
92 |
93 | return new Float32Array(transpose);
94 | }
95 |
96 | function leastPrimeFactor(n, start = 2) {
97 | for (let i = start; i <= Math.sqrt(n); i++) {
98 | if (n % i === 0) return i;
99 | }
100 | return n;
101 | }
102 |
103 | function formatAsMatrix(floatArray, dimA, dimB) {
104 | const resultMatrix = [];
105 | for (let i = 0; i < dimA; i++) {
106 | resultMatrix.push(floatArray.slice(i * dimB, (i + 1) * dimB));
107 | }
108 | return resultMatrix;
109 | }
110 |
--------------------------------------------------------------------------------
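
model.js picks each next token on the CPU by chaining the helpers above: selectTopK, then cpuSoftmax, then sampleFromDistribution. The same chain in isolation, with made-up logits:

```js
// Toy logits over a 5-token vocabulary (values are illustrative only).
const logits = new Float32Array([2.0, 1.0, 0.5, -1.0, -3.0]);

const { topKIndices, topKProbs } = selectTopK(logits, 3); // keep the 3 largest logits
const probs = cpuSoftmax(topKProbs, 0.8);                 // temperature < 1 sharpens the distribution
const nextToken = topKIndices[sampleFromDistribution(probs)];
console.log(topKIndices, probs, nextToken); // [0, 1, 2], ~[0.69, 0.20, 0.11], usually 0
```
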
/index.html:
--------------------------------------------------------------------------------
[index.html's markup was not preserved in this dump. Recoverable page text: the title and heading "WebGPU GPT Model Demo", a "Checking WebGPU support..." status line, the note "PS: Loading models is 5x slower on the web rather than running locally. Just clone the repo and open!", and a "Special models (download required):" selector. The page's scripts and controls (original lines 55-175) were elided.]
--------------------------------------------------------------------------------
/model.js:
--------------------------------------------------------------------------------
1 | class GPT {
2 | constructor(folder, type) {
3 | this.folder = folder;
4 | this.tokenizerType = type;
5 | this.initialized = false;
6 |
7 | this.device;
8 | this.model;
9 | this.tokenizer;
10 | this.params;
11 | this.minBufferOffset = 1;
12 |
13 | this.defaultPrompt;
14 | this.defaultTopK;
15 | this.defaultTemperature;
16 | this.defaultTokens;
17 |
18 | this.externalBuffer;
19 |
20 | this.unloadDeletionStack = [];
21 | }
22 |
23 | async initialize() {
24 | if (this.initialized) return console.error("Model already initialized");
25 | if (!navigator.gpu) throw new Error("WebGPU is not supported");
26 |
27 | const adapter = await navigator.gpu.requestAdapter();
28 | this.device = await adapter.requestDevice();
29 |
30 | initializeOperations(this.device);
31 |
32 | [this.model, this.params] = await this.loadModel(this.folder);
33 | this.tokenizer = this.tokenizerType == "bpe" ? new GPT2Tokenizer() : new SimpleTokenizer();
34 | await this.tokenizer.load();
35 |
36 | if (this.tokenizerType == "bpe") {
37 | this.defaultPrompt = `What is the answer to life, the universe, and everything?\n`;
38 | this.defaultTopK = 3;
39 | this.defaultTemperature = 1;
40 | this.defaultTokens = 30;
41 | } else {
42 | this.defaultPrompt = `WILL:\nAh, how dare you challenge me?\nHave you forgotten I built WebGPT?\n`;
43 | this.defaultTopK = 2;
44 | this.defaultTemperature = 1;
45 | this.defaultTokens = 80;
46 | }
47 |
48 | this.initialized = true;
49 |
50 | console.log("Model initialized");
51 | }
52 |
53 | async *generate(prompt, max_new_tokens, top_k, temperature) {
54 | if (!this.initialized) {
55 | console.error("Model not loaded yet");
56 | return;
57 | }
58 |
59 | // Buffer size (321644800) exceeds the max buffer size limit (268435456).
60 | // - While calling [Device].CreateBuffer([BufferDescriptor]).
61 |
62 | let history = this.tokenizer.encode(prompt);
63 | console.log(`Prompt (${history.length} tokens):\n${prompt}`);
64 |
65 | const warmupRuns = 3;
66 | let totalTime = 0;
67 |
68 | for (let i = 0; i < max_new_tokens; i++) {
69 | const idx_cond = history.slice(-this.params.n_ctx);
70 | const useAttCache = i !== 0 && history.length <= this.params.n_ctx;
71 |
72 | const startTime = performance.now();
73 | const logits = await this.run(idx_cond, useAttCache);
74 | const endTime = performance.now();
75 |
76 | // console.log(`\nIteration ${i + 1} of ${max_new_tokens}`);
77 | const lapsedTime = endTime - startTime;
78 | console.log(`Kernel execution time: ${lapsedTime} ms`);
79 | i >= warmupRuns && (totalTime += lapsedTime);
80 |
81 | const { topKIndices, topKProbs } = selectTopK(logits, top_k);
82 | const probs = cpuSoftmax(topKProbs, temperature);
83 | const idx_next = topKIndices[sampleFromDistribution(probs)];
84 |
85 | history = history.concat(idx_next);
86 |
87 | // console.log(`Output:\n${this.tokenizer.decode(history)}`);
88 |
89 | // const totalProbs = cpuSoftmax(logits, temperature);
90 | // const tokenProbsString = Array.from(totalProbs)
91 | // .map((value, index) => ({ value, index }))
92 | // .sort((a, b) => b.value - a.value)
93 | // .slice(0, 8)
94 | // .map((prob) => `{ ${this.tokenizer.decode([prob.index]).replace(/(\r\n|\n|\r)/gm, "newline")} } : ${prob.value.toPrecision(3)}`)
95 | // .join(" | ");
96 | // console.log("Top 8 token probs:", tokenProbsString);
97 |
98 | yield this.tokenizer.decode([idx_next]);
99 | }
100 |
101 | console.log(`Average kernel execution time: ${totalTime / (max_new_tokens - warmupRuns)} ms`);
102 | }
103 |
104 | async run(idx) {
105 | const { posEmbdBuffer, layer_buffers, normGammaBuffer, normBetaBuffer, embeddingsBuffers, deEmbeddingsBuffers } = this.model;
106 | const { attention_scale, n_embd, n_head, head_size, n_layer, vocab_size, hidden_size, vocab_chunk_size, vocab_chunk_instances } = this.params;
107 | const seq_length = idx.length;
108 |
109 | // ---------------- Create Passes ---------------- //
110 | // Note: These are re-initialized on every run because buffer sizes change whenever seq_length changes.
111 |
112 | // Pipeline creation is major bottleneck to spin up speed! Also buffer re-use.
113 |
114 | this.computePasses = [];
115 | let intermediateBuffer;
116 | let residualBuffer;
117 | {
118 | const { passes, resultBuffer } = EmbedBlock.newInstance(idx, seq_length, n_embd, vocab_chunk_size, embeddingsBuffers, posEmbdBuffer, ResidualBlock);
119 | intermediateBuffer = resultBuffer;
120 | residualBuffer = resultBuffer;
121 | this.computePasses.push(...passes);
122 | }
123 | for (let i = 0; i < n_layer; i++) {
124 | const buffers = layer_buffers[i];
125 | {
126 | const { passes, resultBuffer } = LayerNormBlock.newInstance(
127 | seq_length,
128 | n_embd,
129 | intermediateBuffer,
130 | buffers.normAttentionGammaBuffer,
131 | buffers.normAttentionBetaBuffer
132 | );
133 | intermediateBuffer = resultBuffer;
134 | this.computePasses.push(...passes);
135 | }
136 | {
137 | const { passes, resultBuffer } = AttentionBlock.newFusedInstance(
138 | seq_length,
139 | n_embd,
140 | attention_scale,
141 | n_head,
142 | head_size,
143 | intermediateBuffer,
144 | buffers.qkvWeightArray[0],
145 | buffers.qkvBiasArray[0],
146 | buffers.qkvWeightArray[1],
147 | buffers.qkvBiasArray[1],
148 | buffers.qkvWeightArray[2],
149 | buffers.qkvBiasArray[2],
150 | buffers.linearWeightsBuffer,
151 | buffers.linearBiasBuffer,
152 | FastMatMulBlock,
153 | SoftmaxBlock
154 | );
155 | intermediateBuffer = resultBuffer;
156 | this.computePasses.push(...passes);
157 | }
158 | {
159 | const { passes, resultBuffer } = ResidualBlock.newInstance(seq_length, n_embd, intermediateBuffer, residualBuffer);
160 | intermediateBuffer = resultBuffer;
161 | residualBuffer = resultBuffer;
162 | this.computePasses.push(...passes);
163 | }
164 | {
165 | const { passes, resultBuffer } = LayerNormBlock.newInstance(
166 | seq_length,
167 | n_embd,
168 | intermediateBuffer,
169 | buffers.normLinearGammaBuffer,
170 | buffers.normLinearBetaBuffer
171 | );
172 | intermediateBuffer = resultBuffer;
173 | this.computePasses.push(...passes);
174 | }
175 | {
176 | const { resultBuffer, passes } = FastMatMulBlock.newInstance(
177 | seq_length,
178 | hidden_size,
179 | n_embd,
180 | intermediateBuffer,
181 | buffers.firstLayerWeightsBuffer,
182 | buffers.firstLayerBiasBuffer
183 | );
184 | intermediateBuffer = resultBuffer;
185 | this.computePasses.push(...passes);
186 | }
187 | {
188 | const { resultBuffer, passes } = GeluBlock.newInstance(seq_length, hidden_size, intermediateBuffer);
189 | intermediateBuffer = resultBuffer;
190 | this.computePasses.push(...passes);
191 | }
192 | {
193 | const { resultBuffer, passes } = FastMatMulBlock.newInstance(
194 | seq_length,
195 | n_embd,
196 | hidden_size,
197 | intermediateBuffer,
198 | buffers.secondLayerWeightsBuffer,
199 | buffers.secondLayerBiasBuffer
200 | );
201 | intermediateBuffer = resultBuffer;
202 | this.computePasses.push(...passes);
203 | }
204 | {
205 | const { passes, resultBuffer } = ResidualBlock.newInstance(seq_length, n_embd, intermediateBuffer, residualBuffer);
206 | intermediateBuffer = resultBuffer;
207 | residualBuffer = resultBuffer;
208 | this.computePasses.push(...passes);
209 | }
210 | }
211 | {
212 | if (this.externalBuffer) {
213 | this.computePasses.push({
214 | flag: "copy",
215 | src: intermediateBuffer,
216 | srcOffset: 0,
217 | dst: this.externalBuffer,
218 | dstOffset: 0,
219 | size: this.bufferSize(seq_length, n_embd),
220 | });
221 | }
222 | }
223 | {
224 | const { passes, resultBuffer } = LayerNormBlock.newInstance(seq_length, n_embd, intermediateBuffer, normGammaBuffer, normBetaBuffer);
225 | intermediateBuffer = resultBuffer;
226 | this.computePasses.push(...passes);
227 | }
228 | {
229 | const { passes, resultBuffer } = DeEmbedBlock.newInstance(
230 | n_embd,
231 | vocab_size,
232 | vocab_chunk_size * vocab_chunk_instances,
233 | seq_length,
234 | vocab_chunk_size,
235 | intermediateBuffer,
236 | deEmbeddingsBuffers
237 | );
238 | intermediateBuffer = resultBuffer;
239 | this.computePasses.push(...passes);
240 | }
241 | const resultBuffer = intermediateBuffer;
242 |
243 | // ---------------- Compute Passes ----------------
244 |
245 | const commandEncoder = this.device.createCommandEncoder();
246 | for (const pass of this.computePasses) {
247 | if (pass.flag === "compute") {
248 | const passEncoder = commandEncoder.beginComputePass();
249 | passEncoder.setPipeline(pass.pipeline);
250 | for (let i = 0; i < pass.groups.length; i++) passEncoder.setBindGroup(i, pass.groups[i]);
251 | passEncoder.dispatchWorkgroups(pass.workgroups.x, pass.workgroups.y);
252 | passEncoder.end();
253 | } else if (pass.flag === "copy") {
254 | commandEncoder.copyBufferToBuffer(pass.src, pass.srcOffset, pass.dst, pass.dstOffset, pass.size);
255 | }
256 | }
257 | this.device.queue.submit([commandEncoder.finish()]);
258 |
259 | // ---------------- Read Results ----------------
260 |
261 | await resultBuffer.mapAsync(GPUMapMode.READ);
262 | const output = resultBuffer.getMappedRange();
263 | const outputArray = new Float32Array(output).slice(0); // Copy the array, otherwise it'll be destroyed.
264 |
265 | clearOperationCache();
266 |
267 | return outputArray;
268 | }
269 |
270 | async loadModel(folder) {
271 | if (this.initialized) return console.error("Model already loaded");
272 |
273 | console.log("Loading model from folder:", folder);
274 | const weightsFolder = `weights/${folder}/`;
275 |
276 | const params = await this.loadParameters(weightsFolder);
277 | const { embeddingsBuffers, deEmbeddingsBuffers } = await this.loadEmbeddings(params, weightsFolder);
278 | const { posEmbdBuffer } = await this.loadPositionalEmbeddings(params, weightsFolder);
279 | const layer_buffers = await this.loadLayers(params, weightsFolder);
280 |
281 | console.log("Loading final layer norm...");
282 | const { normGammaBuffer, normBetaBuffer } = await this.loadFinalLayerNorm(params, weightsFolder);
283 |
284 | const output = { layer_buffers, embeddingsBuffers, deEmbeddingsBuffers, posEmbdBuffer, normGammaBuffer, normBetaBuffer };
285 | console.log("Finished loading model.", output, params);
286 | return [output, params];
287 | }
288 |
289 | async loadParameters(weightsFolder) {
290 | console.log("Loading params...");
291 | const params = await (await fetch(`${weightsFolder}/params_gpt.json`)).json();
292 |
293 | // Did you enable GitHub LFS? Won't work without it.
294 | if (params.n_embd % 4 !== 0) throw new Error("Model load failed: n_embd must be divisible by 4.");
295 | if (params.n_embd % params.n_head !== 0) throw new Error("Model load failed: n_embd must be divisible by n_head.");
296 | // I'm unsure if this is a reasonable requirement here. At worst, I can figure out some padding method.
297 | if ((params.n_embd / params.n_head) % 4 !== 0) throw new Error("Model load failed: n_embd / n_head must be divisible by 4.");
298 | const tokenParam = this.bufferSize(params.vocab_size, params.n_embd);
299 | let minSplits = Math.ceil(tokenParam / this.device.limits.maxStorageBufferBindingSize);
300 | function vocabChunkSizeCalc(vocab_size, n_embd, splits, maxStorageBufferBindingSize) {
301 | // Possibly could be better? Needs actual benchmarking. Sizes below are in elements; the limit is in bytes, hence the BYTES_PER_ELEMENT factors.
302 | const optimisticSize = Math.ceil(vocab_size / splits / 4) * 4 * n_embd;
303 | const pessimisticSize = Math.floor(vocab_size / splits / 4) * 4 * n_embd;
304 | let vocab_chunk_size = optimisticSize;
305 | if (optimisticSize * Float32Array.BYTES_PER_ELEMENT > maxStorageBufferBindingSize) {
306 | vocab_chunk_size = pessimisticSize;
307 | if (pessimisticSize * Float32Array.BYTES_PER_ELEMENT * splits < tokenParam) {
308 | return vocabChunkSizeCalc(vocab_size, n_embd, splits + 1, maxStorageBufferBindingSize);
309 | }
310 | }
311 | return { vocab_chunk_size: vocab_chunk_size / n_embd, splits };
312 | }
313 | const { vocab_chunk_size, splits } = vocabChunkSizeCalc(params.vocab_size, params.n_embd, minSplits, this.device.limits.maxStorageBufferBindingSize);
314 | if (splits > minSplits) console.warn(`Non-optimal number of vocab splits. Optimal: ${minSplits}, Selected: ${splits}`);
315 |
316 | // Set derived parameters
317 | params.vocab_chunk_size = vocab_chunk_size;
318 | params.vocab_chunk_instances = splits;
319 | params.head_size = params.n_embd / params.n_head;
320 | params.hidden_size = params.n_embd * 4;
321 | params.attention_scale = 1 / Math.sqrt(params.n_embd / params.n_head);
322 | params.bias = params.bias == undefined ? true : params.bias;
323 |
324 | // Check for overflow in buffers larger than maxStorageBufferBindingSize
325 | const maxBufferSize = this.device.limits.maxStorageBufferBindingSize / 4;
326 | if (params.n_embd * params.n_ctx > maxBufferSize) console.warn("Model load failed: n_embd * n_ctx must be less than maxStorageBufferBindingSize.");
327 | if (params.n_embd * params.hidden_size > maxBufferSize)
328 | console.warn("Model load failed: n_embd * hidden_size must be less than maxStorageBufferBindingSize.");
329 | if (params.n_ctx * params.n_ctx * params.n_head > maxBufferSize)
330 | console.warn("Model load failed: n_ctx * n_ctx * n_head must be less than maxStorageBufferBindingSize.");
331 | if (params.n_embd * params.n_embd * 3 > maxBufferSize)
332 | console.warn("Model load failed: n_embd * n_embd * 3 must be less than maxStorageBufferBindingSize.");
333 |
334 | console.log("Params:", params);
335 |
336 | return params;
337 | }
338 |
339 | async loadEmbeddings(params, weightsFolder) {
340 | console.log("Loading token embeddings...");
341 | const embeddingWeights = await fetchBin(`${weightsFolder}/transformer.wte.weight_gpt.bin`);
342 |
343 | // Chunks are stored in row-major order and are of dimensions n_embd x vocab_chunk_size.
344 | // Embedding weights are imported in column-major order and are of dimensions vocab_size x n_embd.
345 | // We pre-transpose the chunk for the deEmbedding process for the matmul. Could do this on GPU later.
346 | const embeddingsBuffers = [];
347 | const deEmbeddingsBuffers = [];
348 | for (let i = 0; i < params.vocab_chunk_instances; i++) {
349 | console.log(`Loading deEmbedding chunk ${i + 1}/${params.vocab_chunk_instances}...`);
350 | const offset = i * params.vocab_chunk_size;
351 | let size = params.vocab_chunk_size;
352 |
353 | const paddedArray = new Float32Array(params.vocab_chunk_size * params.n_embd);
354 | if (i === params.vocab_chunk_instances - 1) {
355 | size = params.vocab_size - offset;
356 | paddedArray.set(zeros((params.vocab_chunk_size * params.vocab_chunk_instances - params.vocab_size) * params.n_embd), size * params.n_embd); // TypedArray.set takes (array, offset): zero-pad the tail of the final chunk.
357 | }
358 | paddedArray.set(embeddingWeights.subarray(offset * params.n_embd, offset * params.n_embd + size * params.n_embd));
359 |
360 | embeddingsBuffers.push(this.initTensor(paddedArray, [params.vocab_chunk_size, params.n_embd], ["copy_from"]));
361 |
362 | const chunk = transpose(paddedArray, params.vocab_chunk_size, params.n_embd); // Use GPU perhaps?
363 | deEmbeddingsBuffers.push(this.initTensor(chunk, [params.n_embd, params.vocab_chunk_size], ["storage"]));
364 | }
365 |
366 | return { embeddingsBuffers, deEmbeddingsBuffers };
367 | }
368 |
369 | async loadPositionalEmbeddings(params, weightsFolder) {
370 | console.log("Loading positional embeddings...");
371 | const posEmbeddings = await fetchBin(`${weightsFolder}/transformer.wpe.weight_gpt.bin`);
372 | const posEmbdBuffer = this.initTensor(posEmbeddings, [params.n_ctx, params.n_embd], ["copy_from"]);
373 |
374 | return { posEmbdBuffer };
375 | }
376 |
377 | async loadFinalLayerNorm(params, weightsFolder) {
378 | console.log("Loading final norm...");
379 | const prefix = `${weightsFolder}/transformer.ln_f.`;
380 |
381 | const tensorPromises = [
382 | this.fetchAndInitTensor(`${prefix}weight_gpt.bin`, [params.n_embd], ["storage"]),
383 | this.fetchAndInitTensor(`${prefix}bias_gpt.bin`, [params.n_embd], ["storage"]),
384 | ];
385 |
386 | const [normGammaBuffer, normBetaBuffer] = await Promise.all(tensorPromises);
387 |
388 | return { normGammaBuffer, normBetaBuffer };
389 | }
390 |
391 | async loadLayers(params, weightsFolder) {
392 | console.log("Loading layers...");
393 | const layerPromises = [];
394 |
395 | for (let i = 0; i < params.n_layer; i++) {
396 | layerPromises.push(this.loadLayer(params, weightsFolder, i));
397 | }
398 |
399 | const layer_buffers = await Promise.all(layerPromises);
400 | return layer_buffers;
401 | }
402 |
403 | async loadLayer(params, weightsFolder, layerIndex) {
404 | console.log("Starting to load layer...", layerIndex);
405 | const prefix = `${weightsFolder}transformer.h.${layerIndex}.`;
406 |
407 | // Create an array of promises for fetching and initializing the tensors
408 | const tensorPromises = [
409 | this.fetchAndInitTensor(`${prefix}ln_1.weight_gpt.bin`, [params.n_embd], ["storage"]),
410 | this.fetchAndInitTensor(`${prefix}ln_1.bias_gpt.bin`, [params.n_embd], ["storage"]),
411 | this.fetchAndSplitQKVWeightTensors(`${prefix}attn.c_attn.weight_gpt.bin`, [params.n_embd, 3 * params.n_embd], ["storage"]),
412 | this.fetchAndSplitQKVBiasTensors(`${prefix}attn.c_attn.bias_gpt.bin`, [params.n_embd], ["storage"]),
413 | this.fetchAndInitTensor(`${prefix}attn.c_proj.weight_gpt.bin`, [params.n_embd, params.n_embd], ["storage"]),
414 | this.fetchAndInitTensor(`${prefix}attn.c_proj.bias_gpt.bin`, [params.n_embd], ["storage"]),
415 | this.fetchAndInitTensor(`${prefix}ln_2.weight_gpt.bin`, [params.n_embd], ["storage"]),
416 | this.fetchAndInitTensor(`${prefix}ln_2.bias_gpt.bin`, [params.n_embd], ["storage"]),
417 | this.fetchAndInitTensor(`${prefix}mlp.c_fc.weight_gpt.bin`, [params.n_embd, params.hidden_size], ["storage"]),
418 | this.fetchAndInitTensor(`${prefix}mlp.c_fc.bias_gpt.bin`, [params.hidden_size], ["storage"]),
419 | this.fetchAndInitTensor(`${prefix}mlp.c_proj.weight_gpt.bin`, [params.hidden_size, params.n_embd], ["storage"]),
420 | this.fetchAndInitTensor(`${prefix}mlp.c_proj.bias_gpt.bin`, [params.n_embd], ["storage"]),
421 | ];
422 |
423 | // Wait for all tensors to be fetched and initialized
424 | const [
425 | normAttentionGammaBuffer,
426 | normAttentionBetaBuffer,
427 | qkvWeightArray,
428 | qkvBiasArray,
429 | linearWeightsBuffer,
430 | linearBiasBuffer,
431 | normLinearGammaBuffer,
432 | normLinearBetaBuffer,
433 | firstLayerWeightsBuffer,
434 | firstLayerBiasBuffer,
435 | secondLayerWeightsBuffer,
436 | secondLayerBiasBuffer,
437 | ] = await Promise.all(tensorPromises);
438 |
439 | // Process the fetched data and return the layer buffers
440 | return {
441 | normAttentionGammaBuffer,
442 | normAttentionBetaBuffer,
443 | qkvWeightArray,
444 | qkvBiasArray,
445 | linearWeightsBuffer,
446 | linearBiasBuffer,
447 | normLinearGammaBuffer,
448 | normLinearBetaBuffer,
449 | firstLayerWeightsBuffer,
450 | firstLayerBiasBuffer,
451 | secondLayerWeightsBuffer,
452 | secondLayerBiasBuffer,
453 | };
454 | }
455 |
456 | async fetchAndSplitQKVWeightTensors(url, dims, ops) {
457 | const data = transpose(await fetchBin(url), dims[0], dims[1]);
458 |
459 | const qWeights = transpose(data.subarray(0, dims[0] * dims[0]), dims[0], dims[0]);
460 | const kWeights = transpose(data.subarray(dims[0] * dims[0], dims[0] * dims[0] * 2), dims[0], dims[0]);
461 | const vWeights = transpose(data.subarray(dims[0] * dims[0] * 2, dims[0] * dims[0] * 3), dims[0], dims[0]);
462 |
463 | const qWeightsBuffer = this.initTensor(qWeights, [dims[0], dims[0]], ops);
464 | const kWeightsBuffer = this.initTensor(kWeights, [dims[0], dims[0]], ops);
465 | const vWeightsBuffer = this.initTensor(vWeights, [dims[0], dims[0]], ops);
466 |
467 | return [qWeightsBuffer, kWeightsBuffer, vWeightsBuffer];
468 | }
469 |
470 | async fetchAndSplitQKVBiasTensors(url, dims, ops) {
471 | const data = await fetchBin(url);
472 |
473 | const qBias = data.subarray(0, dims[0]);
474 | const kBias = data.subarray(dims[0], dims[0] * 2);
475 | const vBias = data.subarray(dims[0] * 2, dims[0] * 3);
476 |
477 | const qBiasBuffer = this.initTensor(qBias, [dims[0]], ops);
478 | const kBiasBuffer = this.initTensor(kBias, [dims[0]], ops);
479 | const vBiasBuffer = this.initTensor(vBias, [dims[0]], ops);
480 |
481 | return [qBiasBuffer, kBiasBuffer, vBiasBuffer];
482 | }
483 |
484 | async fetchAndInitTensor(url, dims, ops) {
485 | console.log("Fetching and initializing tensor...", url);
486 | const data = await fetchBin(url);
487 | return this.initTensor(data, dims, ops);
488 | }
489 |
490 | initTensor(data, dims, ops) {
491 | const buffer = this.device.createBuffer({
492 | size: this.bufferSize(dims[0], dims[1] || 1, dims[2] || 1),
493 | usage: ops.map((u) => bufferUsageDict[u]).reduce((a, b) => a | b),
494 | mappedAtCreation: true,
495 | });
496 | new Float32Array(buffer.getMappedRange()).set(data);
497 | buffer.unmap();
498 | this.unloadDeletionStack.push(buffer);
499 | return buffer;
500 | }
501 |
502 | unloadBuffers() {
503 | this.unloadDeletionStack.map((buffer) => buffer.destroy());
504 | this.unloadDeletionStack = [];
505 | }
506 |
507 | bufferSize(dimX, dimY = 1, dimZ = 1) {
508 | const size = Math.ceil((dimX * dimY * dimZ * Float32Array.BYTES_PER_ELEMENT) / this.minBufferOffset) * this.minBufferOffset;
509 | if (size > this.device.limits.maxStorageBufferBindingSize)
510 | console.warn("Warning: Buffer size calc result exceeds GPU limit, are you using this value for a tensor size?", dimX, dimY, dimZ, size);
511 | return size;
512 | }
513 | }
514 |
--------------------------------------------------------------------------------
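
To make loadParameters' derived values concrete: with the standard GPT-2 117M architecture numbers (n_embd 768, n_head 12, vocab_size 50257), the math works out as below, and the de-embedding weights are what force vocab chunking under WebGPU's default storage-binding limit. A sketch under those assumptions, not code from the repo:

```js
// Derived parameters, as computed in loadParameters (GPT-2 117M values).
const n_embd = 768, n_head = 12, vocab_size = 50257;
const head_size = n_embd / n_head;                // 64
const hidden_size = n_embd * 4;                   // 3072, the MLP width
const attention_scale = 1 / Math.sqrt(head_size); // 0.125, i.e. 1 / sqrt(d_k)

// Why vocab chunking exists: the (de-)embedding matrix in f32 is
// 50257 * 768 * 4 bytes ≈ 154 MB, which exceeds the WebGPU default
// maxStorageBufferBindingSize of 134217728 bytes (128 MiB).
const tokenParamBytes = vocab_size * n_embd * Float32Array.BYTES_PER_ELEMENT;
const minSplits = Math.ceil(tokenParamBytes / 134217728); // 2 chunks minimum
console.log({ head_size, hidden_size, attention_scale, minSplits });
```
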
/other/conversion_scripts/README.md:
--------------------------------------------------------------------------------
1 | # Running custom models on WebGPU
2 |
3 | It's fairly easy to run custom models on WebGPU. At the moment, I only support PyTorch models via the scripts below, but it should be fairly simple to export other model weights to work here.
4 |
5 | Importing weights requires you to export the transformer weights as a series of individual .bin files. Pardon the somewhat inconvenient process: loading files of this size into Javascript requires some clever engineering.
6 |
7 | An example structure with only two layers. Each matrix is collapsed into a row-major 1-dimensional array.
8 |
9 | ```
10 | transformer.wte.weight.bin: [65, 128]
11 | transformer.wpe.weight.bin: [64, 128]
12 | transformer.h.0.ln_1.weight.bin: [128]
13 | transformer.h.0.ln_1.bias.bin: [128]
14 | transformer.h.0.attn.c_attn.weight.bin: [384, 128]
15 | transformer.h.0.attn.c_attn.bias.bin: [384]
16 | transformer.h.0.attn.c_proj.weight.bin: [128, 128]
17 | transformer.h.0.attn.c_proj.bias.bin: [128]
18 | transformer.h.0.ln_2.weight.bin: [128]
19 | transformer.h.0.ln_2.bias.bin: [128]
20 | transformer.h.0.mlp.c_fc.weight.bin: [512, 128]
21 | transformer.h.0.mlp.c_fc.bias.bin: [512]
22 | transformer.h.0.mlp.c_proj.weight.bin: [128, 512]
23 | transformer.h.0.mlp.c_proj.bias.bin: [128]
24 | transformer.h.1.ln_1.weight.bin: [128]
25 | transformer.h.1.ln_1.bias.bin: [128]
26 | transformer.h.1.attn.c_attn.weight.bin: [384, 128]
27 | transformer.h.1.attn.c_attn.bias.bin: [384]
28 | transformer.h.1.attn.c_proj.weight.bin: [128, 128]
29 | transformer.h.1.attn.c_proj.bias.bin: [128]
30 | transformer.h.1.ln_2.weight.bin: [128]
31 | transformer.h.1.ln_2.bias.bin: [128]
32 | transformer.h.1.mlp.c_fc.weight.bin: [512, 128]
33 | transformer.h.1.mlp.c_fc.bias.bin: [512]
34 | transformer.h.1.mlp.c_proj.weight.bin: [128, 512]
35 | transformer.h.1.mlp.c_proj.bias.bin: [128]
36 | transformer.ln_f.weight.bin: [128]
37 | transformer.ln_f.bias.bin: [128]
38 | lm_head.weight.bin: [65, 128]
39 | ```
40 |
41 | I've included an export script for PyTorch models. Quite simply, take the model's state_dict() and export each tensor into an individual file. If you want to export pre-trained GPT models, you'll need to reformat the parameters slightly to work correctly.
42 |
--------------------------------------------------------------------------------
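
On the JavaScript side, each exported .bin is read back exactly as described above: raw little-endian float32 in row-major order (this mirrors fetchBin in globals.js). A sketch using the [384, 128] c_attn weight from the two-layer example; the fetch path is illustrative:

```js
// Read one exported tensor back into JS (mirrors fetchBin in globals.js).
const resp = await fetch("transformer.h.0.attn.c_attn.weight.bin");
const data = new Float32Array(await resp.arrayBuffer()); // 384 * 128 floats
const rows = 384, cols = 128;               // shape from the listing above
const value = (r, c) => data[r * cols + c]; // row-major indexing
console.log(data.length === rows * cols, value(0, 0));
```
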
/other/conversion_scripts/ckpt.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0hq/WebGPT/a83cdc8d46e8d140b55d87089482999580b64a3d/other/conversion_scripts/ckpt.pt
--------------------------------------------------------------------------------
/other/conversion_scripts/convert_checkpoint_pytorch.py:
--------------------------------------------------------------------------------
1 | import json
2 | import struct
3 | import torch
4 | import os
5 |
6 | transposed = ['attn.c_attn.weight', 'attn.c_proj.weight',
7 |               'mlp.c_fc.weight', 'mlp.c_proj.weight']
8 |
9 |
10 | def save_weights_to_bin_files(checkpoint, folder_name):
11 |     for key, value in checkpoint['model'].items():
12 |         print(f"{key}: {value.shape}")
13 |         if key.startswith('_orig_mod.'):
14 |             continue
15 |         with open(os.path.join(folder_name, f"{key}_gpt.bin"), 'wb') as file:
16 |             values = value.cpu().numpy()
17 |             # Only use this if using old minGPT model.
18 |             # if any(key.endswith(w) for w in transposed):
19 |             #     values = values.T
20 |
21 |             for single_value in values.flatten():
22 |                 file.write(struct.pack('<f', single_value))

[The remainder of this script, convert_pretrained_pytorch.py, sample_shakespeare_ckpt.pt, and the start of int8-gemm.js were not preserved in this dump.]

--------------------------------------------------------------------------------
/other/int8-gemm.js:
--------------------------------------------------------------------------------
55 | matrix[i] = (((packedValue << 24) >> 24) / 127.0) * absmax;
56 | matrix[i + 1] = (((packedValue << 16) >> 24) / 127.0) * absmax;
57 | matrix[i + 2] = (((packedValue << 8) >> 24) / 127.0) * absmax;
58 | matrix[i + 3] = ((packedValue >> 24) / 127.0) * absmax;
59 | }
60 |
61 | return matrix;
62 | }
63 |
64 | const qa = quantizeMatrix(A, M, K);
65 | const qb = quantizeMatrix(B, K, N);
66 |
67 | const quantizedA = qa.quantizedMatrix;
68 | const quantizedB = qb.quantizedMatrix;
69 |
70 | const dqB = dequantizeMatrix(quantizedB, qb.absmax, K, N);
71 |
72 | // for (let i = 0; i < 10; i++) {
73 | // console.log(B[i], dqB[i]);
74 | // }
75 |
76 | const absmax = Math.max(qa.absmax, qb.absmax);
77 |
78 | // Naive CPU implementation of matrix multiplication
79 | function multiplyMatrices(A, B, C, M, N, K) {
80 | for (let i = 0; i < M; i++) {
81 | for (let j = 0; j < N; j++) {
82 | let sum = 0;
83 | for (let k = 0; k < K; k++) {
84 | sum += A[i * K + k] * B[k * N + j];
85 | }
86 | C[i * N + j] = sum;
87 | }
88 | }
89 | }
90 |
91 | async function run() {
92 | // Create WebGPU device and queue
93 | const adapter = await navigator.gpu.requestAdapter();
94 | const device = await adapter.requestDevice();
95 | const queue = device.queue;
96 |
97 | // Create buffers for matrices A, B, and C
98 | const aBuffer = device.createBuffer({
99 | size: A.byteLength,
100 | usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
101 | });
102 | const bBuffer = device.createBuffer({
103 | size: quantizedB.byteLength,
104 | usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
105 | });
106 | const cBuffer = device.createBuffer({
107 | size: C.byteLength,
108 | usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC,
109 | });
110 |
111 | // Copy matrices A and B to their respective buffers
112 | queue.writeBuffer(aBuffer, 0, A);
113 | queue.writeBuffer(bBuffer, 0, quantizedB);
114 |
115 | // Create bind group layout and bind group
116 |
117 | const shaderCode = `
118 |
119 | @group(0) @binding(0) var<storage, read> array_a: array<vec4<f32>>;
120 | @group(0) @binding(1) var<storage, read> array_b: array<i32>;
121 |
122 | @group(0) @binding(2) var<storage, read_write> array_c: array<vec4<f32>>;
123 |
124 | const absmax = ${absmax};
125 |
126 | fn unpackInt8x4(value: i32) -> vec4<f32> {
127 | let x = f32((value << 24) >> 24) / 127.0 * absmax;
128 | let y = f32(((value << 16) >> 24)) / 127.0 * absmax;
129 | let z = f32(((value << 8) >> 24)) / 127.0 * absmax;
130 | let w = f32(((value >> 24))) / 127.0 * absmax;
131 | return vec4<f32>(x, y, z, w);
132 | }
133 |
134 | @compute @workgroup_size(${workgroupSizeX}, ${workgroupSizeY})
135 | fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
136 | var M: u32 = ${M};
137 | var N: u32 = ${N};
138 | var ND4: u32 = ${Math.ceil(N / 4)};
139 | var KD4: u32 = ${Math.ceil(K / 4)};
140 | var x: u32 = global_id.x;
141 | var y: u32 = global_id.y;
142 |
143 | if (x * 8 >= N || y * 4 >= M) {
144 | return;
145 | }
146 |
147 | var sum00: vec4<f32> = vec4<f32>();
148 | var sum01: vec4<f32> = vec4<f32>();
149 | var sum02: vec4<f32> = vec4<f32>();
150 | var sum03: vec4<f32> = vec4<f32>();
151 | var sum10: vec4<f32> = vec4<f32>();
152 | var sum11: vec4<f32> = vec4<f32>();
153 | var sum12: vec4<f32> = vec4<f32>();
154 | var sum13: vec4<f32> = vec4<f32>();
155 |
156 | for(var k: u32 = 0u; k < KD4; k = k + 1u) {
157 | var arow0: vec4<f32> = array_a[(y * 4u + 0u) * KD4 + k];
158 | var arow1: vec4<f32> = array_a[(y * 4u + 1u) * KD4 + k];
159 | var arow2: vec4<f32> = array_a[(y * 4u + 2u) * KD4 + k];
160 | var arow3: vec4<f32> = array_a[(y * 4u + 3u) * KD4 + k];
161 | var brow: vec4<f32>;
162 |
163 | brow = unpackInt8x4(array_b[(k * 4u + 0u) * ND4 + x * 2u + 0u]);
164 | sum00 = vec4<f32>(arow0.x) * brow + sum00;
165 | sum01 = vec4<f32>(arow1.x) * brow + sum01;
166 | sum02 = vec4<f32>(arow2.x) * brow + sum02;
167 | sum03 = vec4<f32>(arow3.x) * brow + sum03;
168 |
169 | brow = unpackInt8x4(array_b[(k * 4u + 0u) * ND4 + x * 2u + 1u]);
170 | sum10 = vec4<f32>(arow0.x) * brow + sum10;
171 | sum11 = vec4<f32>(arow1.x) * brow + sum11;
172 | sum12 = vec4<f32>(arow2.x) * brow + sum12;
173 | sum13 = vec4<f32>(arow3.x) * brow + sum13;
174 |
175 | brow = unpackInt8x4(array_b[(k * 4u + 1u) * ND4 + x * 2u + 0u]);
176 | sum00 = vec4<f32>(arow0.y) * brow + sum00;
177 | sum01 = vec4<f32>(arow1.y) * brow + sum01;
178 | sum02 = vec4<f32>(arow2.y) * brow + sum02;
179 | sum03 = vec4<f32>(arow3.y) * brow + sum03;
180 |
181 | brow = unpackInt8x4(array_b[(k * 4u + 1u) * ND4 + x * 2u + 1u]);
182 | sum10 = vec4<f32>(arow0.y) * brow + sum10;
183 | sum11 = vec4<f32>(arow1.y) * brow + sum11;
184 | sum12 = vec4<f32>(arow2.y) * brow + sum12;
185 | sum13 = vec4<f32>(arow3.y) * brow + sum13;
186 |
187 | brow = unpackInt8x4(array_b[(k * 4u + 2u) * ND4 + x * 2u + 0u]);
188 | sum00 = vec4<f32>(arow0.z) * brow + sum00;
189 | sum01 = vec4<f32>(arow1.z) * brow + sum01;
190 | sum02 = vec4<f32>(arow2.z) * brow + sum02;
191 | sum03 = vec4<f32>(arow3.z) * brow + sum03;
192 |
193 | brow = unpackInt8x4(array_b[(k * 4u + 2u) * ND4 + x * 2u + 1u]);
194 | sum10 = vec4<f32>(arow0.z) * brow + sum10;
195 | sum11 = vec4<f32>(arow1.z) * brow + sum11;
196 | sum12 = vec4<f32>(arow2.z) * brow + sum12;
197 | sum13 = vec4<f32>(arow3.z) * brow + sum13;
198 |
199 | brow = unpackInt8x4(array_b[(k * 4u + 3u) * ND4 + x * 2u + 0u]);
200 | sum00 = vec4<f32>(arow0.w) * brow + sum00;
201 | sum01 = vec4<f32>(arow1.w) * brow + sum01;
202 | sum02 = vec4<f32>(arow2.w) * brow + sum02;
203 | sum03 = vec4<f32>(arow3.w) * brow + sum03;
204 |
205 | brow = unpackInt8x4(array_b[(k * 4u + 3u) * ND4 + x * 2u + 1u]);
206 | sum10 = vec4<f32>(arow0.w) * brow + sum10;
207 | sum11 = vec4<f32>(arow1.w) * brow + sum11;
208 | sum12 = vec4<f32>(arow2.w) * brow + sum12;
209 | sum13 = vec4<f32>(arow3.w) * brow + sum13;
210 | }
211 |
212 | if (y * 4u + 0u < M) {
213 | array_c[x * 2u + 0u + (y * 4u + 0u) * ND4] = sum00;
214 | array_c[x * 2u + 1u + (y * 4u + 0u) * ND4] = sum10;
215 | }
216 | if (y * 4u + 1u < M) {
217 | array_c[x * 2u + 0u + (y * 4u + 1u) * ND4] = sum01;
218 | array_c[x * 2u + 1u + (y * 4u + 1u) * ND4] = sum11;
219 | }
220 | if (y * 4u + 2u < M) {
221 | array_c[x * 2u + 0u + (y * 4u + 2u) * ND4] = sum02;
222 | array_c[x * 2u + 1u + (y * 4u + 2u) * ND4] = sum12;
223 | }
224 | if (y * 4u + 3u < M) {
225 | array_c[x * 2u + 0u + (y * 4u + 3u) * ND4] = sum03;
226 | array_c[x * 2u + 1u + (y * 4u + 3u) * ND4] = sum13;
227 | }
228 | }
229 | `;
230 |
231 | const shaderModule = device.createShaderModule({
232 | code: shaderCode,
233 | });
234 |
235 | const bindGroupLayout = device.createBindGroupLayout({
236 | entries: [
237 | {
238 | binding: 0,
239 | visibility: GPUShaderStage.COMPUTE,
240 | buffer: {
241 | type: "read-only-storage",
242 | },
243 | },
244 | {
245 | binding: 1,
246 | visibility: GPUShaderStage.COMPUTE,
247 | buffer: {
248 | type: "read-only-storage",
249 | },
250 | },
251 | {
252 | binding: 2,
253 | visibility: GPUShaderStage.COMPUTE,
254 | buffer: {
255 | type: "storage",
256 | },
257 | },
258 | ],
259 | });
260 |
261 | const bindGroup = device.createBindGroup({
262 | layout: bindGroupLayout,
263 | entries: [
264 | {
265 | binding: 0,
266 | resource: {
267 | buffer: aBuffer,
268 | },
269 | },
270 | {
271 | binding: 1,
272 | resource: {
273 | buffer: bBuffer,
274 | },
275 | },
276 | {
277 | binding: 2,
278 | resource: {
279 | buffer: cBuffer,
280 | },
281 | },
282 | ],
283 | });
284 |
285 | const pipelineLayout = device.createPipelineLayout({
286 | bindGroupLayouts: [bindGroupLayout],
287 | });
288 |
289 | const pipeline = device.createComputePipeline({
290 | layout: pipelineLayout,
291 | compute: {
292 | module: shaderModule,
293 | entryPoint: "main",
294 | },
295 | });
296 | const encoder = device.createCommandEncoder();
297 | const passEncoder = encoder.beginComputePass();
298 |
299 | // Dispatch the compute kernel
300 | passEncoder.setPipeline(pipeline);
301 | passEncoder.setBindGroup(0, bindGroup);
302 | passEncoder.dispatchWorkgroups(workgroupSizeX, workgroupSizeY, 1);
303 | passEncoder.end();
304 |
305 | const readBuffer = device.createBuffer({
306 | size: C.byteLength,
307 | usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ,
308 | });
309 |
310 | // Copy matrix C from the GPU to the CPU
311 | encoder.copyBufferToBuffer(cBuffer, 0, readBuffer, 0, C.byteLength);
312 |
313 | device.queue.submit([encoder.finish()]);
314 |
315 | await readBuffer.mapAsync(GPUMapMode.READ);
316 | const readBufferData = new Float32Array(readBuffer.getMappedRange());
317 |
318 | const C_cpu = new Float32Array(M * N);
319 | multiplyMatrices(A, B, C_cpu, M, N, K);
320 |
321 | for (let i = 0; i < M * N; i++) {
322 | if (Math.abs(C_cpu[i] - readBufferData[i]) > 0.1) {
323 | console.error("CPU and GPU results differ at index", i);
324 | console.error("CPU:", C_cpu[i], "GPU:", readBufferData[i]);
325 | break;
326 | }
327 | // } else {
328 | // console.log("CPU and GPU results are the same at index", i);
329 | // console.log("CPU:", C_cpu[i], "GPU:", readBufferData[i]);
330 | // }
331 | }
332 |
333 | let mae = 0;
334 | for (let i = 0; i < M * N; i++) {
335 | mae += Math.abs(C_cpu[i] - readBufferData[i]);
336 | }
337 | mae /= M * N;
338 | console.log("Mean Absolute Error:", mae);
339 |
340 | const NUM_RUNS = 100;
341 |
342 | //warmup
343 |
344 | for (let i = 0; i < NUM_RUNS; i++) {
345 | // Dispatch the compute kernel
346 | const encoder = device.createCommandEncoder();
347 | const passEncoder = encoder.beginComputePass();
348 |
349 | // Dispatch the compute kernel
350 | passEncoder.setPipeline(pipeline);
351 | passEncoder.setBindGroup(0, bindGroup);
352 | passEncoder.dispatchWorkgroups(workgroupSizeX, workgroupSizeY, 1);
353 |
354 | passEncoder.end();
355 |
356 | const readBuffer = device.createBuffer({
357 | size: C.byteLength,
358 | usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ,
359 | });
360 |
361 | // Copy matrix C from the GPU to the CPU
362 | encoder.copyBufferToBuffer(cBuffer, 0, readBuffer, 0, C.byteLength);
363 | device.queue.submit([encoder.finish()]); } // Submit added: without it, the warmup work is encoded but never executed.
364 |
365 | // Run GPU kernel NUM_RUNS times and measure time
366 | let totalTime = 0;
367 | for (let i = 0; i < NUM_RUNS; i++) {
368 | const start = performance.now();
369 |
370 | // Dispatch the compute kernel
371 | const encoder = device.createCommandEncoder();
372 | const passEncoder = encoder.beginComputePass();
373 |
374 | // Dispatch the compute kernel
375 | passEncoder.setPipeline(pipeline);
376 | passEncoder.setBindGroup(0, bindGroup);
377 | passEncoder.dispatchWorkgroups(M / workgroupSizeX, N / workgroupSizeY, 1);
378 |
379 | passEncoder.end();
380 |
381 | const readBuffer = device.createBuffer({
382 | size: C.byteLength,
383 | usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ,
384 | });
385 |
386 | // Copy matrix C from the GPU to the CPU
387 | encoder.copyBufferToBuffer(cBuffer, 0, readBuffer, 0, C.byteLength);
388 | device.queue.submit([encoder.finish()]); await device.queue.onSubmittedWorkDone(); // Submit and wait so GPU execution is included in the timing.
389 | const end = performance.now();
390 | totalTime += end - start;
391 | }
392 | const averageTime = totalTime / NUM_RUNS;
393 | console.log(`Average time per run: ${averageTime.toFixed(2)} ms`);
394 | // print flops
395 |
396 | const flops = (2 * M * N * K) / (averageTime / 1e3); // averageTime is in ms; convert to seconds for FLOPS.
397 | console.log(`GFLOPS: ${flops / 1e9}`);
398 | }
399 |
400 | run();
401 |
--------------------------------------------------------------------------------
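
The kernel above leans on shift-based sign extension to unpack four int8 values from each i32 of the quantized matrix. A plain-JS mirror of that bit manipulation for reference; the packInt8x4 helper and the absmax value here are illustrative additions, not part of the original file:

```js
const absmax = 1.0; // illustrative; the shader bakes in Math.max(qa.absmax, qb.absmax)

// Mirror of the WGSL unpackInt8x4: shift a byte to the top of the i32, then
// arithmetic-shift right to sign-extend it back down, and rescale by absmax / 127.
function unpackInt8x4(value) {
  return [
    ((value << 24) >> 24) / 127.0 * absmax,
    ((value << 16) >> 24) / 127.0 * absmax,
    ((value << 8) >> 24) / 127.0 * absmax,
    (value >> 24) / 127.0 * absmax,
  ];
}

// Inverse: byte 0 holds the first lane, byte 3 the last, matching the unpack order.
function packInt8x4(a, b, c, d) {
  return (a & 0xff) | ((b & 0xff) << 8) | ((c & 0xff) << 16) | (d << 24);
}

console.log(unpackInt8x4(packInt8x4(-127, 0, 64, 127))); // ≈ [-1, 0, 0.504, 1]
```
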
/other/misc/files.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0hq/WebGPT/a83cdc8d46e8d140b55d87089482999580b64a3d/other/misc/files.png
--------------------------------------------------------------------------------
/other/misc/header.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0hq/WebGPT/a83cdc8d46e8d140b55d87089482999580b64a3d/other/misc/header.png
--------------------------------------------------------------------------------
/other/scratchpad.js:
--------------------------------------------------------------------------------
1 | class Instruction {
2 | constructor(device) {
3 | this.device = device;
4 | this.bufferDeletionStack = [];
5 | this.unloadDeletionStack = [];
6 |
7 | this.initBindGroups();
8 | }
9 |
10 | initBindGroup(layout, buffers, label = "") {
11 | return this.device.createBindGroup({
12 | layout,
13 | entries: buffers.map((buffer, i) => ({
14 | binding: i,
15 | resource: { buffer },
16 | })),
17 | label,
18 | });
19 | }
20 |
21 | initBuffer(ops, row, col = 1, noDelete = false) {
22 | const buffer = this.device.createBuffer({
23 | size: this.bufferSize(row, col),
24 | usage: ops.map((u) => bufferUsageDict[u]).reduce((a, b) => a | b),
25 | });
26 | if (!noDelete) this.bufferDeletionStack.push(buffer);
27 | else this.unloadDeletionStack.push(buffer);
28 | return buffer;
29 | }
30 |
31 | bufferSize(dimA, dimB = 1) {
32 | return dimA * dimB * Float32Array.BYTES_PER_ELEMENT; // f32 bytes; no alignment rounding needed here
33 | }
34 |
35 | initBindGroups() {
36 | const bg = (types) =>
37 | this.device.createBindGroupLayout({
38 | entries: types.map((entry, i) => ({
39 | binding: i,
40 | visibility: GPUShaderStage.COMPUTE,
41 | buffer: { type: entry },
42 | })),
43 | });
44 |
45 | this.r_r_r_Layout = bg(["read-only-storage", "read-only-storage", "read-only-storage"]);
46 | this.r_r_Layout = bg(["read-only-storage", "read-only-storage"]);
47 | this.r_Layout = bg(["read-only-storage"]);
48 | this.u_s_Layout = bg(["uniform", "storage"]);
49 | this.u_s_s_s_Layout = bg(["uniform", "storage", "storage", "storage"]);
50 | }
51 |
52 | initPipeline(code, bindGroupLayouts, label = "", constants = {}) {
53 | return this.device.createComputePipeline({
54 | layout: this.device.createPipelineLayout({ bindGroupLayouts }),
55 | compute: {
56 | module: this.device.createShaderModule({ code }),
57 | entryPoint: "main",
58 | constants,
59 | },
60 | label,
61 | });
62 | }
63 |
64 | unloadBuffers() {
65 | this.unloadDeletionStack.map((buffer) => buffer.destroy());
66 | this.unloadDeletionStack = [];
67 | }
68 |
69 | destroyBuffers() {
70 | this.bufferDeletionStack.map((buffer) => buffer.destroy());
71 | this.bufferDeletionStack = [];
72 | }
73 | }
74 |
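// Lifecycle note: the two deletion stacks give buffers different lifetimes.
// Buffers created with noDelete = false land on bufferDeletionStack and are
// freed after each pass by destroyBuffers(); noDelete = true buffers (e.g.
// model weights) land on unloadDeletionStack and survive until unloadBuffers().
// Subclasses like FastMatMul below build their pipeline once, then hand back
// { resultBuf, pass } descriptors for the caller to encode and submit.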
75 | class FastMatMul extends Instruction {
76 | constructor(device) {
77 | super(device);
78 | this.name = "fastMatMul";
79 | this.pipelineCache = new Map();
80 | }
81 |
82 | getPipeline(rows) {
83 | const div4 = rows % 4 === 0;
84 | const pipelineCacheKey = div4 ? "fastMatMulNoCheck" : "fastMatMul";
85 | if (this.pipelineCache.has(pipelineCacheKey)) {
86 | return this.pipelineCache.get(pipelineCacheKey);
87 | }
88 | const kernel = div4 ? this.fastMatMulNoCheck : this.fastMatMul;
89 | const pipeline = this.initPipeline(kernel, [this.u_s_Layout, this.r_r_Layout], pipelineCacheKey);
90 | this.pipelineCache.set(pipelineCacheKey, pipeline);
91 | return pipeline;
92 | }
93 |
94 | newInstance(rows, cols, shared, bufA, bufB) {
95 | const pipeline = this.getPipeline(rows);
96 | const uniformBuffer = this.initBuffer(["uniform", "copy_to"], 4);
97 | const resultBuf = this.initBuffer(["storage", "copy_from"], rows, cols);
98 | const opBindGroup = this.initBindGroup(this.u_s_Layout, [uniformBuffer, resultBuf], "opBindGroup");
99 | const inputBindGroup = this.initBindGroup(this.r_r_Layout, [bufA, bufB], "inputBindGroup");
100 | const workgroups = { x: wgSize(cols, 64), y: wgSize(rows, 32) };
101 | this.device.queue.writeBuffer(uniformBuffer, 0, new Uint32Array([rows, cols, Math.ceil(cols / 4), Math.ceil(shared / 4)]));
102 |
103 | return {
104 | resultBuf,
105 | pass: {
106 | pipeline,
107 | groups: [opBindGroup, inputBindGroup],
108 | workgroups,
109 | },
110 | };
111 | }
112 |
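// The uniform written above packs CMeta in declaration order:
// [M, N, ND4 = ceil(N / 4), KD4 = ceil(K / 4)], i.e. the output-column and
// shared dimensions in units of vec4. For the 10x10x10 case used by TestGPT
// below this is Uint32Array([10, 10, 3, 3]) and, assuming wgSize(dim, n) is
// Math.ceil(dim / n), a single 8x8 workgroup is dispatched.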
113 | fastMatMul = `
114 | struct CMeta {
115 | M: u32,
116 | N: u32,
117 | ND4: u32,
118 | KD4: u32,
119 | }
120 |
121 | @group(1) @binding(0) var<storage,read> array_a: array<vec4<f32>>;
122 | @group(1) @binding(1) var<storage,read> array_b: array<vec4<f32>>;
123 |
124 | @group(0) @binding(0) var<uniform> cmeta: CMeta;
125 | @group(0) @binding(1) var<storage,read_write> array_c: array<vec4<f32>>;
126 |
127 | @compute @workgroup_size(8, 8)
128 | fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
129 | var M: u32 = cmeta.M;
130 | var N: u32 = cmeta.N;
131 | var ND4: u32 = cmeta.ND4;
132 | var KD4: u32 = cmeta.KD4;
133 | var x: u32 = global_id.x;
134 | var y: u32 = global_id.y;
135 |
136 | if (x * 8 >= N || y * 4 >= M) {
137 | return;
138 | }
139 |
140 | var sum00: vec4<f32> = vec4<f32>();
141 | var sum01: vec4<f32> = vec4<f32>();
142 | var sum02: vec4<f32> = vec4<f32>();
143 | var sum03: vec4<f32> = vec4<f32>();
144 | var sum10: vec4<f32> = vec4<f32>();
145 | var sum11: vec4<f32> = vec4<f32>();
146 | var sum12: vec4<f32> = vec4<f32>();
147 | var sum13: vec4<f32> = vec4<f32>();
148 |
149 | for(var k: u32 = 0u; k < KD4; k = k + 1u) {
150 | var arow0: vec4<f32> = array_a[(y * 4u + 0u) * KD4 + k];
151 | var arow1: vec4<f32> = array_a[(y * 4u + 1u) * KD4 + k];
152 | var arow2: vec4<f32> = array_a[(y * 4u + 2u) * KD4 + k];
153 | var arow3: vec4<f32> = array_a[(y * 4u + 3u) * KD4 + k];
154 | var brow: vec4<f32>;
155 |
156 | brow = array_b[(k * 4u + 0u) * ND4 + x * 2u + 0u];
157 | sum00 = vec4<f32>(arow0.x) * brow + sum00;
158 | sum01 = vec4<f32>(arow1.x) * brow + sum01;
159 | sum02 = vec4<f32>(arow2.x) * brow + sum02;
160 | sum03 = vec4<f32>(arow3.x) * brow + sum03;
161 |
162 | brow = array_b[(k * 4u + 0u) * ND4 + x * 2u + 1u];
163 | sum10 = vec4<f32>(arow0.x) * brow + sum10;
164 | sum11 = vec4<f32>(arow1.x) * brow + sum11;
165 | sum12 = vec4<f32>(arow2.x) * brow + sum12;
166 | sum13 = vec4<f32>(arow3.x) * brow + sum13;
167 |
168 | brow = array_b[(k * 4u + 1u) * ND4 + x * 2u + 0u];
169 | sum00 = vec4<f32>(arow0.y) * brow + sum00;
170 | sum01 = vec4<f32>(arow1.y) * brow + sum01;
171 | sum02 = vec4<f32>(arow2.y) * brow + sum02;
172 | sum03 = vec4<f32>(arow3.y) * brow + sum03;
173 |
174 | brow = array_b[(k * 4u + 1u) * ND4 + x * 2u + 1u];
175 | sum10 = vec4<f32>(arow0.y) * brow + sum10;
176 | sum11 = vec4<f32>(arow1.y) * brow + sum11;
177 | sum12 = vec4<f32>(arow2.y) * brow + sum12;
178 | sum13 = vec4<f32>(arow3.y) * brow + sum13;
179 |
180 | brow = array_b[(k * 4u + 2u) * ND4 + x * 2u + 0u];
181 | sum00 = vec4<f32>(arow0.z) * brow + sum00;
182 | sum01 = vec4<f32>(arow1.z) * brow + sum01;
183 | sum02 = vec4<f32>(arow2.z) * brow + sum02;
184 | sum03 = vec4<f32>(arow3.z) * brow + sum03;
185 |
186 | brow = array_b[(k * 4u + 2u) * ND4 + x * 2u + 1u];
187 | sum10 = vec4<f32>(arow0.z) * brow + sum10;
188 | sum11 = vec4<f32>(arow1.z) * brow + sum11;
189 | sum12 = vec4<f32>(arow2.z) * brow + sum12;
190 | sum13 = vec4<f32>(arow3.z) * brow + sum13;
191 |
192 | brow = array_b[(k * 4u + 3u) * ND4 + x * 2u + 0u];
193 | sum00 = vec4<f32>(arow0.w) * brow + sum00;
194 | sum01 = vec4<f32>(arow1.w) * brow + sum01;
195 | sum02 = vec4<f32>(arow2.w) * brow + sum02;
196 | sum03 = vec4<f32>(arow3.w) * brow + sum03;
197 |
198 | brow = array_b[(k * 4u + 3u) * ND4 + x * 2u + 1u];
199 | sum10 = vec4<f32>(arow0.w) * brow + sum10;
200 | sum11 = vec4<f32>(arow1.w) * brow + sum11;
201 | sum12 = vec4<f32>(arow2.w) * brow + sum12;
202 | sum13 = vec4<f32>(arow3.w) * brow + sum13;
203 | }
204 |
205 | if (y * 4u + 0u < M) {
206 | array_c[x * 2u + 0u + (y * 4u + 0u) * ND4] = sum00;
207 | array_c[x * 2u + 1u + (y * 4u + 0u) * ND4] = sum10;
208 | }
209 | if (y * 4u + 1u < M) {
210 | array_c[x * 2u + 0u + (y * 4u + 1u) * ND4] = sum01;
211 | array_c[x * 2u + 1u + (y * 4u + 1u) * ND4] = sum11;
212 | }
213 | if (y * 4u + 2u < M) {
214 | array_c[x * 2u + 0u + (y * 4u + 2u) * ND4] = sum02;
215 | array_c[x * 2u + 1u + (y * 4u + 2u) * ND4] = sum12;
216 | }
217 | if (y * 4u + 3u < M) {
218 | array_c[x * 2u + 0u + (y * 4u + 3u) * ND4] = sum03;
219 | array_c[x * 2u + 1u + (y * 4u + 3u) * ND4] = sum13;
220 | }
221 | }
222 | `;
223 |
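// Register-tiled kernel: each invocation owns a 4-row x 8-column tile of C,
// held in eight vec4<f32> accumulators (sum00..sum13: first digit = column
// vec4 within the tile, second digit = row). Each k iteration consumes a 4x4
// block of A (arow0..arow3) and a 4x8 block of B (eight brow loads), keeping
// all arithmetic in vec4 form. With @workgroup_size(8, 8), one workgroup
// covers 32 rows x 64 columns, which is why newInstance dispatches
// wgSize(cols, 64) x wgSize(rows, 32) workgroups. The guarded stores at the
// end keep a partial bottom tile from writing past M rows.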
224 | fastMatMulNoCheck = `
225 | struct CMeta {
226 | M: u32,
227 | N: u32,
228 | ND4: u32,
229 | KD4: u32,
230 | }
231 |
232 | @group(1) @binding(0) var<storage,read> array_a: array<vec4<f32>>;
233 | @group(1) @binding(1) var<storage,read> array_b: array<vec4<f32>>;
234 |
235 | @group(0) @binding(0) var<uniform> cmeta: CMeta;
236 | @group(0) @binding(1) var<storage,read_write> array_c: array<vec4<f32>>;
237 |
238 | @compute @workgroup_size(8, 8)
239 | fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
240 | var M: u32 = cmeta.M;
241 | var N: u32 = cmeta.N;
242 | var ND4: u32 = cmeta.ND4;
243 | var KD4: u32 = cmeta.KD4;
244 | var x: u32 = global_id.x;
245 | var y: u32 = global_id.y;
246 |
247 | if (x * 8 >= N || y * 4 >= M) {
248 | return;
249 | }
250 |
251 | var sum00: vec4<f32> = vec4<f32>();
252 | var sum01: vec4<f32> = vec4<f32>();
253 | var sum02: vec4<f32> = vec4<f32>();
254 | var sum03: vec4<f32> = vec4<f32>();
255 | var sum10: vec4<f32> = vec4<f32>();
256 | var sum11: vec4<f32> = vec4<f32>();
257 | var sum12: vec4<f32> = vec4<f32>();
258 | var sum13: vec4<f32> = vec4<f32>();
259 |
260 | for(var k: u32 = 0u; k < KD4; k = k + 1u) {
261 | var arow0: vec4<f32> = array_a[(y * 4u + 0u) * KD4 + k];
262 | var arow1: vec4<f32> = array_a[(y * 4u + 1u) * KD4 + k];
263 | var arow2: vec4<f32> = array_a[(y * 4u + 2u) * KD4 + k];
264 | var arow3: vec4<f32> = array_a[(y * 4u + 3u) * KD4 + k];
265 | var brow: vec4<f32>;
266 |
267 | brow = array_b[(k * 4u + 0u) * ND4 + x * 2u + 0u];
268 | sum00 = vec4<f32>(arow0.x) * brow + sum00;
269 | sum01 = vec4<f32>(arow1.x) * brow + sum01;
270 | sum02 = vec4<f32>(arow2.x) * brow + sum02;
271 | sum03 = vec4<f32>(arow3.x) * brow + sum03;
272 |
273 | brow = array_b[(k * 4u + 0u) * ND4 + x * 2u + 1u];
274 | sum10 = vec4<f32>(arow0.x) * brow + sum10;
275 | sum11 = vec4<f32>(arow1.x) * brow + sum11;
276 | sum12 = vec4<f32>(arow2.x) * brow + sum12;
277 | sum13 = vec4<f32>(arow3.x) * brow + sum13;
278 |
279 | brow = array_b[(k * 4u + 1u) * ND4 + x * 2u + 0u];
280 | sum00 = vec4<f32>(arow0.y) * brow + sum00;
281 | sum01 = vec4<f32>(arow1.y) * brow + sum01;
282 | sum02 = vec4<f32>(arow2.y) * brow + sum02;
283 | sum03 = vec4<f32>(arow3.y) * brow + sum03;
284 |
285 | brow = array_b[(k * 4u + 1u) * ND4 + x * 2u + 1u];
286 | sum10 = vec4<f32>(arow0.y) * brow + sum10;
287 | sum11 = vec4<f32>(arow1.y) * brow + sum11;
288 | sum12 = vec4<f32>(arow2.y) * brow + sum12;
289 | sum13 = vec4<f32>(arow3.y) * brow + sum13;
290 |
291 | brow = array_b[(k * 4u + 2u) * ND4 + x * 2u + 0u];
292 | sum00 = vec4<f32>(arow0.z) * brow + sum00;
293 | sum01 = vec4<f32>(arow1.z) * brow + sum01;
294 | sum02 = vec4<f32>(arow2.z) * brow + sum02;
295 | sum03 = vec4<f32>(arow3.z) * brow + sum03;
296 |
297 | brow = array_b[(k * 4u + 2u) * ND4 + x * 2u + 1u];
298 | sum10 = vec4<f32>(arow0.z) * brow + sum10;
299 | sum11 = vec4<f32>(arow1.z) * brow + sum11;
300 | sum12 = vec4<f32>(arow2.z) * brow + sum12;
301 | sum13 = vec4<f32>(arow3.z) * brow + sum13;
302 |
303 | brow = array_b[(k * 4u + 3u) * ND4 + x * 2u + 0u];
304 | sum00 = vec4<f32>(arow0.w) * brow + sum00;
305 | sum01 = vec4<f32>(arow1.w) * brow + sum01;
306 | sum02 = vec4<f32>(arow2.w) * brow + sum02;
307 | sum03 = vec4<f32>(arow3.w) * brow + sum03;
308 |
309 | brow = array_b[(k * 4u + 3u) * ND4 + x * 2u + 1u];
310 | sum10 = vec4<f32>(arow0.w) * brow + sum10;
311 | sum11 = vec4<f32>(arow1.w) * brow + sum11;
312 | sum12 = vec4<f32>(arow2.w) * brow + sum12;
313 | sum13 = vec4<f32>(arow3.w) * brow + sum13;
314 | }
315 |
316 | array_c[x * 2u + 0u + (y * 4u + 0u) * ND4] = sum00;
317 | array_c[x * 2u + 1u + (y * 4u + 0u) * ND4] = sum10;
318 | array_c[x * 2u + 0u + (y * 4u + 1u) * ND4] = sum01;
319 | array_c[x * 2u + 1u + (y * 4u + 1u) * ND4] = sum11;
320 | array_c[x * 2u + 0u + (y * 4u + 2u) * ND4] = sum02;
321 | array_c[x * 2u + 1u + (y * 4u + 2u) * ND4] = sum12;
322 | array_c[x * 2u + 0u + (y * 4u + 3u) * ND4] = sum03;
323 | array_c[x * 2u + 1u + (y * 4u + 3u) * ND4] = sum13;
324 | }
325 | `;
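// Same inner loop as fastMatMul above; the only difference is that the final
// stores are unguarded, so this variant is only correct when M % 4 == 0.
// getPipeline() selects it by checking rows % 4 === 0, which saves four
// per-invocation row-bounds branches.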
326 | }
327 |
328 | class TestGPT {
329 | constructor(folder, type, doAttentionCache = false) {
330 | this.folder = folder;
331 | this.tokenizerType = type;
332 | this.initialized = false;
333 |
334 | this.device;
335 | this.model;
336 | this.tokenizer;
337 | this.params;
338 | this.minBufferOffset = 1;
339 | this.doAttentionCache = doAttentionCache;
340 |
341 | this.defaultPrompt;
342 | this.defaultTopK;
343 | this.defaultTemperature;
344 | this.defaultTokens;
345 |
346 | this.bufferDeletionStack = [];
347 | this.unloadDeletionStack = [];
348 | }
349 |
350 | async initialize() {
351 | if (this.initialized) return console.error("Model already initialized");
352 | if (!navigator.gpu) throw new Error("WebGPU is not supported");
353 |
354 | const adapter = await navigator.gpu.requestAdapter();
355 | this.device = await adapter.requestDevice();
356 |
357 | this.matMulOperation = new FastMatMul(this.device);
358 |
359 | const dimM = 10;
360 | const dimN = 10;
361 | const demo = new Float32Array(dimM * dimN);
362 | for (let i = 0; i < dimM * dimN; i++) demo[i] = 1;
363 | const weights1 = this.initTensor(demo, [dimM, dimN], ["storage", "copy_from"]);
364 | // const weights2 = this.initTensor(demo, [dimM, dimN], ["storage", "copy_from"]);
365 | this.inputBuffer = this.initBuffer(["storage", "copy_from", "copy_to"], dimM, dimN);
366 |
367 | this.computePasses = [];
368 | let intermediateBuffer = this.inputBuffer;
369 | for (let i = 0; i < 10; i++) {
370 | let { pass, resultBuf } = this.matMulOperation.newInstance(10, 10, 10, intermediateBuffer, weights1);
371 | intermediateBuffer = resultBuf;
372 | this.computePasses.push(pass);
373 | }
374 | this.resultBuffer = intermediateBuffer;
375 | this.outputBuffer = this.initBuffer(["map_read", "copy_to"], dimM, dimN);
376 |
377 | this.initialized = true;
378 | }
379 |
380 | async test() {
381 | const dimM = 10;
382 | const dimN = 10;
383 | const matrixA = new Float32Array(dimM * dimN);
384 | for (let i = 0; i < dimM * dimN; i++) matrixA[i] = i * 0.1;
385 |
386 | this.device.queue.writeBuffer(this.inputBuffer, 0, matrixA);
387 |
388 | const commandEncoder = this.device.createCommandEncoder();
389 | for (const pass of this.computePasses) {
390 | const passEncoder = commandEncoder.beginComputePass();
391 | passEncoder.setPipeline(pass.pipeline);
392 | for (let i = 0; i < pass.groups.length; i++) passEncoder.setBindGroup(i, pass.groups[i]);
393 | passEncoder.dispatchWorkgroups(pass.workgroups.x, pass.workgroups.y);
394 | passEncoder.end();
395 | }
396 | commandEncoder.copyBufferToBuffer(this.resultBuffer, 0, this.outputBuffer, 0, this.bufferSize(dimM, dimN));
397 | this.device.queue.submit([commandEncoder.finish()]);
398 |
399 | await this.outputBuffer.mapAsync(GPUMapMode.READ);
400 | const output = this.outputBuffer.getMappedRange();
401 | const outputArray = new Float32Array(output).slice(0); // Prevent destruction.
402 | console.log(outputArray, formatAsMatrix(outputArray, dimM, dimN));
403 |
404 | this.destroyBuffers();
405 | }
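// Readback pattern: WebGPU forbids MAP_READ on STORAGE buffers, so the result
// is copied into the map_read outputBuffer first, mapped, then copied out via
// slice(0) so the values survive once destroyBuffers() destroys the mapped
// buffer.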
406 |
407 | initBindGroup(layout, buffers) {
408 | return this.device.createBindGroup({
409 | layout,
410 | entries: buffers.map((buffer, i) => ({
411 | binding: i,
412 | resource: { buffer },
413 | })),
414 | });
415 | }
416 |
417 | initOutputBuffer(commandEncoder, buffer, row, col) {
418 | const outputBuffer = this.initBuffer(["map_read", "copy_to"], row, col);
419 | commandEncoder.copyBufferToBuffer(buffer, 0, outputBuffer, 0, this.bufferSize(row, col));
420 | return outputBuffer;
421 | }
422 |
423 | initBuffer(ops, row, col = 1, noDelete = false) {
424 | const buffer = this.device.createBuffer({
425 | size: this.bufferSize(row, col),
426 | usage: ops.map((u) => bufferUsageDict[u]).reduce((a, b) => a | b),
427 | });
428 | if (!noDelete) this.bufferDeletionStack.push(buffer);
429 | else this.unloadDeletionStack.push(buffer);
430 | return buffer;
431 | }
432 |
433 | initTensor(data, dims, ops) {
434 | const buffer = this.device.createBuffer({
435 | size: this.bufferSize(dims[0], dims[1], dims[2] || 1),
436 | usage: ops.map((u) => bufferUsageDict[u]).reduce((a, b) => a | b),
437 | mappedAtCreation: true,
438 | });
439 | const array = new Float32Array(buffer.getMappedRange());
440 | array.set(data);
441 | buffer.unmap();
442 | this.unloadDeletionStack.push(buffer);
443 | return buffer;
444 | }
445 |
446 | bufferSize(dimX, dimY = 1, dimZ = 1) {
447 | return Math.ceil((dimX * dimY * dimZ * Float32Array.BYTES_PER_ELEMENT) / this.minBufferOffset) * this.minBufferOffset;
448 | }
449 |
450 | unloadBuffers() {
451 | this.unloadDeletionStack.map((buffer) => buffer.destroy());
452 | this.unloadDeletionStack = [];
453 | }
454 |
455 | destroyBuffers() {
456 | this.bufferDeletionStack.map((buffer) => buffer.destroy());
457 | this.bufferDeletionStack = [];
458 | }
459 |
460 | initBindGroups() {
461 | const bg = (types) =>
462 | this.device.createBindGroupLayout({
463 | entries: types.map((entry, i) => ({
464 | binding: i,
465 | visibility: GPUShaderStage.COMPUTE,
466 | buffer: { type: entry },
467 | })),
468 | });
469 |
470 | this.r_r_r_Layout = bg(["read-only-storage", "read-only-storage", "read-only-storage"]);
471 | this.r_r_Layout = bg(["read-only-storage", "read-only-storage"]);
472 | this.r_Layout = bg(["read-only-storage"]);
473 | this.u_s_Layout = bg(["uniform", "storage"]);
474 | this.u_s_s_s_Layout = bg(["uniform", "storage", "storage", "storage"]);
475 | }
476 |
477 | async initPipelines() {
478 | const p = (code, bindGroupLayouts) => {
479 | return this.device.createComputePipelineAsync({
480 | layout: this.device.createPipelineLayout({ bindGroupLayouts }),
481 | compute: {
482 | module: this.device.createShaderModule({ code }),
483 | entryPoint: "main",
484 | },
485 | });
486 | };
487 | }
488 | }
489 |
490 | async function test() {
491 | const GPU = new TestGPT();
492 | await GPU.initialize();
493 | await GPU.test();
494 | }
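// Presumably invoked by hand from the devtools console, e.g.:
//   test().catch(console.error);
// It needs a WebGPU-capable browser; initialize() throws "WebGPU is not
// supported" otherwise.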
495 |
496 | /*
497 |
498 |
499 | Fast row-add shader, kept here for reference:
500 | struct BMeta {
501 | M: u32,
502 | N: u32,
503 | ND4: u32,
504 | }
505 |
506 | @group(1) @binding(0) var<storage,read> array_matrix: array<vec4<f32>>;
507 | @group(1) @binding(1) var<storage,read> array_bias: array<vec4<f32>>;
508 | @group(0) @binding(0) var<uniform> bmeta: BMeta;
509 | @group(0) @binding(1) var<storage,read_write> array_output: array<vec4<f32>>;
510 |
511 | @compute @workgroup_size(8,8)
512 | fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
513 | var col: u32 = global_id.x;
514 | var row: u32 = global_id.y;
515 | var ND4: u32 = bmeta.ND4;
516 | var M: u32 = bmeta.M;
517 |
518 | if (row >= M || col >= ND4) {
519 | return;
520 | }
521 |
522 | array_output[row * ND4 + col] = array_matrix[row * ND4 + col] + array_bias[col];
523 | }
524 |
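Dispatch notes for the row-add shader above: each invocation adds one vec4 of
bias to one vec4 of the matrix, and array_bias is indexed by column only, so
the same bias row is broadcast down all M rows. A matching dispatch would be
dispatchWorkgroups(Math.ceil(ND4 / 8), Math.ceil(M / 8)) with ND4 = ceil(N / 4):
x walks the vec4 columns, y walks the rows, and the early return trims the
ragged edge.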
525 | class FastMatMulBlockClass extends Block {
526 | constructor() {
527 | super();
528 | this.name = "fastMatMul";
529 | this.pipelineCache = new Map();
530 | }
531 |
532 | getPipeline(rows) {
533 | const div4 = rows % 4 === 0;
534 | const pipelineCacheKey = div4 ? "fastMatMulNoCheck" : "fastMatMul";
535 | if (this.pipelineCache.has(pipelineCacheKey)) return this.pipelineCache.get(pipelineCacheKey);
536 | const kernel = div4 ? this.fastMatMulNoCheck : this.fastMatMul;
537 | const pipeline = this.initPipeline(kernel, [this.u_s_Layout, this.r_r_Layout], `${this.name}_Pipeline_${pipelineCacheKey}`);
538 | this.pipelineCache.set(pipelineCacheKey, pipeline);
539 | return pipeline;
540 | }
541 |
542 | newInstance(rows, cols, shared, bufA, bufB) {
543 | const pipeline = this.getPipeline(rows);
544 | const uniformBuffer = this.initBuffer(["uniform", "copy_to"], [4]);
545 | const resultBuffer = this.initBuffer(["storage", "copy_from"], [rows, cols]);
546 | const opBindGroup = this.initBindGroup(this.u_s_Layout, [uniformBuffer, resultBuffer], `${this.name}_OpG`);
547 | const inputBindGroup = this.initBindGroup(this.r_r_Layout, [bufA, bufB], `${this.name}_InputG`);
548 | const workgroups = { x: wgSize(cols, 64), y: wgSize(rows, 32) };
549 | this.device.queue.writeBuffer(uniformBuffer, 0, new Uint32Array([rows, cols, Math.ceil(cols / 4), Math.ceil(shared / 4)]));
550 |
551 | return {
552 | resultBuffer,
553 | passes: [
554 | {
555 | flag: "compute",
556 | pipeline,
557 | groups: [opBindGroup, inputBindGroup],
558 | workgroups,
559 | },
560 | ],
561 | };
562 | }
563 |
564 | fastMatMul = `
565 | struct CMeta {
566 | M: u32,
567 | N: u32,
568 | ND4: u32,
569 | KD4: u32,
570 | }
571 |
572 | @group(1) @binding(0) var<storage,read> array_a: array<vec4<f32>>;
573 | @group(1) @binding(1) var<storage,read> array_b: array<vec4<f32>>;
574 |
575 | @group(0) @binding(0) var<uniform> cmeta: CMeta;
576 | @group(0) @binding(1) var<storage,read_write> array_c: array<vec4<f32>>;
577 |
578 | @compute @workgroup_size(8, 8)
579 | fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
580 | var M: u32 = cmeta.M;
581 | var N: u32 = cmeta.N;
582 | var ND4: u32 = cmeta.ND4;
583 | var KD4: u32 = cmeta.KD4;
584 | var x: u32 = global_id.x;
585 | var y: u32 = global_id.y;
586 |
587 | if (x * 8 >= N || y * 4 >= M) {
588 | return;
589 | }
590 |
591 | var sum00: vec4<f32> = vec4<f32>();
592 | var sum01: vec4<f32> = vec4<f32>();
593 | var sum02: vec4<f32> = vec4<f32>();
594 | var sum03: vec4<f32> = vec4<f32>();
595 | var sum10: vec4<f32> = vec4<f32>();
596 | var sum11: vec4<f32> = vec4<f32>();
597 | var sum12: vec4<f32> = vec4<f32>();
598 | var sum13: vec4<f32> = vec4<f32>();
599 |
600 | for(var k: u32 = 0u; k < KD4; k = k + 1u) {
601 | var arow0: vec4<f32> = array_a[(y * 4u + 0u) * KD4 + k];
602 | var arow1: vec4<f32> = array_a[(y * 4u + 1u) * KD4 + k];
603 | var arow2: vec4<f32> = array_a[(y * 4u + 2u) * KD4 + k];
604 | var arow3: vec4<f32> = array_a[(y * 4u + 3u) * KD4 + k];
605 | var brow: vec4<f32>;
606 |
607 | brow = array_b[(k * 4u + 0u) * ND4 + x * 2u + 0u];
608 | sum00 = vec4<f32>(arow0.x) * brow + sum00;
609 | sum01 = vec4<f32>(arow1.x) * brow + sum01;
610 | sum02 = vec4<f32>(arow2.x) * brow + sum02;
611 | sum03 = vec4<f32>(arow3.x) * brow + sum03;
612 |
613 | brow = array_b[(k * 4u + 0u) * ND4 + x * 2u + 1u];
614 | sum10 = vec4<f32>(arow0.x) * brow + sum10;
615 | sum11 = vec4<f32>(arow1.x) * brow + sum11;
616 | sum12 = vec4<f32>(arow2.x) * brow + sum12;
617 | sum13 = vec4<f32>(arow3.x) * brow + sum13;
618 |
619 | brow = array_b[(k * 4u + 1u) * ND4 + x * 2u + 0u];
620 | sum00 = vec4<f32>(arow0.y) * brow + sum00;
621 | sum01 = vec4<f32>(arow1.y) * brow + sum01;
622 | sum02 = vec4<f32>(arow2.y) * brow + sum02;
623 | sum03 = vec4<f32>(arow3.y) * brow + sum03;
624 |
625 | brow = array_b[(k * 4u + 1u) * ND4 + x * 2u + 1u];
626 | sum10 = vec4<f32>(arow0.y) * brow + sum10;
627 | sum11 = vec4<f32>(arow1.y) * brow + sum11;
628 | sum12 = vec4<f32>(arow2.y) * brow + sum12;
629 | sum13 = vec4<f32>(arow3.y) * brow + sum13;
630 |
631 | brow = array_b[(k * 4u + 2u) * ND4 + x * 2u + 0u];
632 | sum00 = vec4<f32>(arow0.z) * brow + sum00;
633 | sum01 = vec4<f32>(arow1.z) * brow + sum01;
634 | sum02 = vec4<f32>(arow2.z) * brow + sum02;
635 | sum03 = vec4<f32>(arow3.z) * brow + sum03;
636 |
637 | brow = array_b[(k * 4u + 2u) * ND4 + x * 2u + 1u];
638 | sum10 = vec4<f32>(arow0.z) * brow + sum10;
639 | sum11 = vec4<f32>(arow1.z) * brow + sum11;
640 | sum12 = vec4<f32>(arow2.z) * brow + sum12;
641 | sum13 = vec4<f32>(arow3.z) * brow + sum13;
642 |
643 | brow = array_b[(k * 4u + 3u) * ND4 + x * 2u + 0u];
644 | sum00 = vec4<f32>(arow0.w) * brow + sum00;
645 | sum01 = vec4<f32>(arow1.w) * brow + sum01;
646 | sum02 = vec4<f32>(arow2.w) * brow + sum02;
647 | sum03 = vec4<f32>(arow3.w) * brow + sum03;
648 |
649 | brow = array_b[(k * 4u + 3u) * ND4 + x * 2u + 1u];
650 | sum10 = vec4<f32>(arow0.w) * brow + sum10;
651 | sum11 = vec4<f32>(arow1.w) * brow + sum11;
652 | sum12 = vec4<f32>(arow2.w) * brow + sum12;
653 | sum13 = vec4<f32>(arow3.w) * brow + sum13;