├── .gitignore
├── README.md
├── generate.js
├── input.txt
├── model.js
├── package.json
└── train.js

/.gitignore:
--------------------------------------------------------------------------------
/weights.safetensors
/tokenizer.json
*.swp

yarn.lock
package-lock.json
npm-debug.log
yarn-error.log
/node_modules/

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Train text generation model with JavaScript

This repo hosts some educational scripts for doing basic training on a
decoder-only transformer, using [node-mlx](https://github.com/frost-beta/node-mlx)
with Node.js.

Files:

* `model.js` - defines the model.
* `input.txt` - text file used for training the model.
* `train.js` - script for training.
* `generate.js` - script for generating text using the trained model.

## Platform

Only Macs with Apple Silicon are supported.

## How to use

Download dependencies and run the training script, which generates
`tokenizer.json` and `weights.safetensors`:

```bash
npm install
node train.js
```

Then run the generate script to produce text from the trained weights:

```bash
node generate.js
```
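
You can also poke at the exported pieces directly. For example, here is a
minimal sketch (assuming training has already written `tokenizer.json`) that
round-trips a string through the character-level tokenizer exported by
`train.js`:

```js
import fs from 'node:fs'
import {getTokenizer} from './train.js'

// Rebuild the tokenizer from the saved vocabulary.
const {vocab} = JSON.parse(fs.readFileSync('tokenizer.json'))
const {vocabSize, encode, decode} = getTokenizer(vocab)

console.log(vocabSize)               // one id per distinct character in input.txt
console.log(encode('To be'))         // an array of token ids, one per character
console.log(decode(encode('To be'))) // back to the string 'To be'
```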

## What's next

After understanding the basics of model training, you can check out the
[train-llama3-js](https://github.com/frost-beta/train-llama3-js) repo to see
how to train a Llama3 model with large datasets.

## License

Public domain.

--------------------------------------------------------------------------------
/generate.js:
--------------------------------------------------------------------------------
#!/usr/bin/env node

import fs from 'node:fs'
import nextTick from 'tick-promise'
import {core as mx} from '@frost-beta/mlx'

import Model from './model.js'
import {contextSize, hiddenSize, numHiddenLayers, numAttentionHeads, getTokenizer} from './train.js'

main()

async function main() {
  // Create tokenizer.
  const {vocab} = JSON.parse(fs.readFileSync('tokenizer.json'))
  const {vocabSize, encode, decode} = getTokenizer(vocab)

  // Create model.
  const model = new Model({vocabSize, hiddenSize, numHiddenLayers, numAttentionHeads})
  model.loadWeights('weights.safetensors')

  // Generate.
  const prompt = 'MIRANDA:\n'
  process.stdout.write(prompt)
  for await (const token of step(encode(prompt), model)) {
    const char = decode([token])
    process.stdout.write(char)
  }
  // process.stdout can not be closed, so write instead of calling end().
  process.stdout.write('\n')
}

// Generate tokens from prompt.
async function* step(prompt, model, maxTokens = 512, temperature = 0.8) {
  // Pass the tokens to the model and get the next token.
  const forward = (tokens) => {
    const inputs = mx.array([tokens.slice(-contextSize)], mx.int32)
    const logits = model.forward(inputs)
    return sample(logits.index(0, -1), temperature)
  }

  let tokens = prompt
  while (true) {
    const token = mx.tidy(() => forward(tokens).item())
    tokens.push(token)
    // Yield the result in the next tick of the loop, so the GC gets a chance
    // to run.
    await nextTick()
    yield token
    // Stop when the maxTokens limit is hit.
    if (tokens.length > maxTokens)
      return
  }
}

// Sample the next token from the logits.
function sample(logits, temperature) {
  if (temperature == 0)
    return mx.argmax(logits, -1)
  else
    return mx.random.categorical(mx.multiply(logits, 1 / temperature))
}
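
// A note on temperature: multiplying the logits by 1 / temperature before
// mx.random.categorical is the usual softmax-temperature trick. Temperatures
// below 1 sharpen the distribution toward the most likely token, and values
// above 1 flatten it. For example, a sketch of greedy (deterministic)
// decoding with the same generator would be:
//
//   for await (const token of step(encode(prompt), model, 512, 0)) { ... }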

--------------------------------------------------------------------------------
/model.js:
--------------------------------------------------------------------------------
import {core as mx, nn} from '@frost-beta/mlx'

// A decoder-only Transformer.
export default class Model extends nn.Module {
  constructor({vocabSize, hiddenSize, numHiddenLayers, numAttentionHeads}) {
    super()
    this.embedding = new nn.Embedding(vocabSize, hiddenSize)
    this.pe = new nn.SinusoidalPositionalEncoding(hiddenSize)
    this.transformer = new nn.TransformerEncoder(numHiddenLayers,
                                                 hiddenSize,
                                                 numAttentionHeads)
    this.outProj = new nn.Linear(hiddenSize, vocabSize)
  }

  forward(x) {
    const L = x.shape[1]
    // The causal mask keeps each position from attending to later positions.
    const mask = nn.MultiHeadAttention.createAdditiveCausalMask(L)
    x = this.embedding.forward(x)
    x = mx.add(x, this.pe.forward(mx.arange(L)))
    x = this.transformer.forward(x, mask)
    return this.outProj.forward(x)
  }
}

--------------------------------------------------------------------------------
/package.json:
--------------------------------------------------------------------------------
{
  "type": "module",
  "license": "Unlicense",
  "dependencies": {
    "@frost-beta/mlx": "0.0.12",
    "tick-promise": "1.0.0"
  }
}

--------------------------------------------------------------------------------
/train.js:
--------------------------------------------------------------------------------
#!/usr/bin/env node

import fs from 'node:fs'
import nextTick from 'tick-promise'
import {core as mx, optimizers as optim, nn, utils} from '@frost-beta/mlx'

import Model from './model.js'

// Hyperparameters.
export const contextSize = 128
export const hiddenSize = 128
export const numHiddenLayers = 8
export const numAttentionHeads = 4

// Training configs.
const batchSize = 32
const epochs = 24
const learningRate = 1e-3

if (process.argv[1].endsWith('train.js'))
  main()

async function main() {
  // Use the text file for training.
  const filename = 'input.txt'
  const text = fs.readFileSync(filename).toString()

  // Create tokenizer.
  const vocab = getVocabulary(text)
  const {vocabSize, encode, decode} = getTokenizer(vocab)
  console.log('Vocabulary size is', vocabSize)

  // Encode the text to tokens, and split them for training and validation.
  const data = encode(text)
  const dataTrain = data.slice(0, Math.floor(data.length * 0.9))
  const dataValid = data.slice(dataTrain.length)

  // Convert the tokens into features and labels.
  const {x: xTrain, y: yTrain} = loadData(dataTrain, contextSize)
  const {x: xValid, y: yValid} = loadData(dataValid, contextSize)
  console.log('Training dataset\'s shape is', xTrain.shape)
  console.log('Validation dataset\'s shape is', xValid.shape)

  // Create model.
  const model = new Model({vocabSize, hiddenSize, numHiddenLayers, numAttentionHeads})

  // Calculate how many parameters the model has.
  let nparams = 0
  for (const [k, x] of utils.treeFlatten(model.parameters())) {
    if (!k.includes('embedding'))
      nparams += x.size
  }
  console.log(`Model has ${(nparams / 1024 ** 2).toFixed(1)}M parameters.`)

  // Prepare utilities for gradient descent.
  const lossAndGradFunction = nn.valueAndGrad(model, lossFunction)
  const optimizer = new optim.AdamW(learningRate)

  const reportPerIter = 100
  const totalIterations = epochs * Math.floor(xTrain.shape[0] / batchSize)

  // Train the model with the training dataset.
  let losses = []
  for (let e = 0, iterations = 1, start = Date.now(); e < epochs; ++e) {
    for await (const [x, y] of iterateBatches(xTrain, yTrain, batchSize)) {
      // Use mx.tidy to free all the intermediate tensors immediately.
      mx.tidy(() => {
        // Compute loss and gradients, then update the model.
        const [loss, grads] = lossAndGradFunction(model, x, y)
        optimizer.update(model, grads)
        mx.eval(model.state, optimizer.state)
        losses.push(loss.item())
        // Keep the states of model and optimizer from getting freed.
        return [model.state, optimizer.state]
      })
      mx.dispose([x, y])
      // Report updates.
      if (++iterations % reportPerIter === 0) {
        const stop = Date.now()
        const trainLoss = mean(losses)
        console.log(`Iter ${iterations} (${(iterations / totalIterations * 100).toFixed(1)}%):`,
                    `Train loss ${trainLoss.toFixed(3)},`,
                    `It/sec ${(reportPerIter / (stop - start) * 1000).toFixed(3)}.`)
        start = Date.now()
        losses = []
      }
    }
  }
  const trainLoss = mean(losses)

  // Evaluate the model by running it on the validation dataset.
  model.eval()
  losses = []
  for await (const [x, y] of iterateBatches(xValid, yValid, batchSize)) {
    mx.tidy(() => {
      const loss = lossFunction(model, x, y)
      losses.push(loss.item())
      return [model.state, optimizer.state]
    })
  }
  const validLoss = mean(losses)
  console.log('Train Loss:', trainLoss.toFixed(3),
              'Valid Loss:', validLoss.toFixed(3))

  console.log('Save weights to weights.safetensors')
  model.saveWeights('weights.safetensors')
  console.log('Save tokenizer to tokenizer.json')
  fs.writeFileSync('tokenizer.json', JSON.stringify({vocab}))
}

// Analyze the text into a vocabulary.
function getVocabulary(text) {
  return Array.from(new Set(text.split(''))).sort()
}

// Create a simple character-mapped tokenizer.
export function getTokenizer(vocab) {
  const vocabSize = vocab.length

  const itos = {}
  const stoi = {}
  vocab.forEach((c, i) => {
    itos[i] = c
    stoi[c] = i
  })

  const encode = (x) => x.split('').map(c => stoi[c])
  const decode = (x) => x.map(i => itos[i]).join('')

  return {vocabSize, encode, decode}
}

// Take tokens and split them into features and labels.
function loadData(tokens, contextSize) {
  let x = []
  let y = []
  for (let i = 0; i < tokens.length - contextSize - 1; i += contextSize) {
    x.push(tokens.slice(i, i + contextSize))
    y.push(tokens.slice(i + 1, i + contextSize + 1))
  }
  return {x: mx.array(x, mx.uint32), y: mx.array(y, mx.uint32)}
}
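
// For example, with contextSize = 4 and tokens = [5, 6, 7, 8, 9, 10], the
// first (and only) window is x = [5, 6, 7, 8] and y = [6, 7, 8, 9]: the label
// at every position is simply the next token in the text.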

// Iterate the dataset in batches.
async function* iterateBatches(x, y, batchSize) {
  for (let i = 0; i < x.shape[0]; i += batchSize) {
    const slice = mx.Slice(i, i + batchSize)
    // Yield the result in the next tick of the loop, so the GC gets a chance
    // to run.
    await nextTick()
    yield [x.index(slice), y.index(slice)]
  }
}

// Calculate the loss by 1) running the model with the inputs, and 2) using the
// cross-entropy function to measure the difference between the results and the
// targets.
function lossFunction(model, x, y) {
  const logits = model.forward(x)
  const losses = nn.losses.crossEntropy(logits, y)
  return mx.mean(losses)
}

// Compute the mean value of an array.
function mean(array) {
  if (array.length == 0)
    return 0
  return array.reduce((a, b) => a + b) / array.length
}
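
// Shapes, for reference: each call to lossFunction sees x and y of shape
// [batchSize, contextSize] (the final batch of an epoch may be smaller), and
// logits of shape [batchSize, contextSize, vocabSize]; crossEntropy yields
// one loss per position, which mx.mean reduces to a single scalar.

--------------------------------------------------------------------------------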