├── .github └── workflows │ └── build.yml ├── .gitignore ├── .npmignore ├── LICENSE ├── README.md ├── package.json ├── src ├── image-processor.ts ├── index.ts └── model.ts ├── test.ts └── tsconfig.json /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | 3 | on: push 4 | 5 | jobs: 6 | build: 7 | runs-on: >- 8 | ${{ 9 | (matrix.os == 'mac' && matrix.arch == 'arm64') && 10 | 'macos-14' || 11 | (fromJson('{"linux":"ubuntu-22.04","mac":"macos-13","win":"windows-2022"}')[matrix.os]) 12 | }} 13 | continue-on-error: false 14 | 15 | strategy: 16 | fail-fast: false 17 | matrix: 18 | os: [linux, mac] 19 | arch: [x64] 20 | include: 21 | - os: mac 22 | arch: arm64 23 | 24 | steps: 25 | - name: Install linux dependencies 26 | if: matrix.os == 'linux' && matrix.arch == runner.arch 27 | run: sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev 28 | 29 | - name: Install mac dependencies 30 | if: matrix.os == 'mac' && matrix.arch == 'x64' 31 | run: brew install openblas 32 | 33 | - name: Checkout 34 | uses: actions/checkout@v4 35 | 36 | - name: Download models 37 | run: | 38 | npm install -g @frost-beta/huggingface 39 | huggingface download --filter=*.json --filter=*.safetensors openai/clip-vit-large-patch14 40 | 41 | - name: Test 42 | run: | 43 | yarn 44 | yarn prepack 45 | yarn tsx test.ts clip-vit-large-patch14 46 | 47 | publish: 48 | if: startsWith(github.ref, 'refs/tags/') 49 | needs: [build] 50 | runs-on: ubuntu-latest 51 | 52 | steps: 53 | - name: Checkout 54 | uses: actions/checkout@v4 55 | 56 | - name: Get tag 57 | run: echo "VERSION=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV 58 | 59 | - name: Set package version 60 | run: | 61 | npm config set git-tag-version=false 62 | npm version $VERSION 63 | 64 | - name: Install deps 65 | run: yarn 66 | 67 | - name: Publish npm package 68 | uses: JS-DevTools/npm-publish@v3 69 | with: 70 | token: ${{ secrets.NPM_TOKEN }} 71 | access: public 72 | ignore-scripts: false 73 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # TypeScript compiled files 2 | /dist/ 3 | 4 | # Everything else should keep same with .npmignore 5 | *.swp 6 | *.tgz 7 | 8 | yarn.lock 9 | package-lock.json 10 | npm-debug.log 11 | yarn-error.log 12 | /node_modules/ 13 | /clip-vit-*/ 14 | -------------------------------------------------------------------------------- /.npmignore: -------------------------------------------------------------------------------- 1 | # Unused source files 2 | /.github/ 3 | /src/ 4 | 5 | # Everything else should keep same with .gitignore 6 | *.swp 7 | *.tgz 8 | 9 | yarn.lock 10 | package-lock.json 11 | npm-debug.log 12 | yarn-error.log 13 | /node_modules/ 14 | /clip-vit-*/ 15 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright © 2023 Apple Inc. 
4 | Copyright © 2024 zcbenz 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Clip 2 | 3 | Node.js module for the [CLIP model](https://openai.com/index/clip/). 4 | 5 | Powered by [node-mlx](https://github.com/frost-beta/node-mlx), a machine 6 | learning framework for Node.js. 7 | 8 | ## APIs 9 | 10 | ```typescript 11 | import { core as mx } from '@frost-beta/mlx'; 12 | 13 | export type ImageInputType = Buffer | ArrayBuffer | string; 14 | 15 | export interface ProcessedImage { 16 | data: Buffer; 17 | info: sharp.OutputInfo; 18 | } 19 | 20 | export interface ClipInput { 21 | labels?: string[]; 22 | images?: ProcessedImage[]; 23 | } 24 | 25 | export interface ClipOutput { 26 | labelEmbeddings?: mx.array; 27 | imageEmbeddings?: mx.array; 28 | } 29 | 30 | export class Clip { 31 | constructor(modelDir: string); 32 | processImages(images: ImageInputType[]): Promise; 33 | computeEmbeddings({ labels, images }: ClipInput): ClipOutput; 34 | /** 35 | * Short hands of computeEmbeddings to convert results to JavaScript numbers 36 | * and ensure the intermediate arrays are destroyed. 37 | */ 38 | computeLabelEmbeddingsJs(labels: string[]): number[][]; 39 | computeImageEmbeddingsJs(images: ProcessedImage[]): number[][]; 40 | /** 41 | * Compute the cosine similarity between 2 embeddings. 42 | */ 43 | static computeCosineSimilaritiy(a1: mx.array, a2: mx.array): mx.array; 44 | /** 45 | * Compute the cosine similarities between 2 arrays of embeddings. 46 | * 47 | * A tuple will be returned, with the first element being the cosine 48 | * similarity scores, and the second element being the indices sorted by 49 | * their scores from larger to smalller. 
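   *
   * Hypothetical usage with embeddings computed by this class:
   *   const [scores, indices] = Clip.computeCosineSimilarities(labelEmbeddings, imageEmbeddings);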
50 | */ 51 | static computeCosineSimilarities(x1: mx.array | number[][], x2: mx.array | number[][]): [mx.array, mx.array]; 52 | } 53 | ``` 54 | 55 | ## License 56 | 57 | MIT 58 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "@frost-beta/clip", 3 | "version": "0.0.1-dev", 4 | "description": "Compute embeddings of text/images with CLIP model", 5 | "main": "dist/index.js", 6 | "types": "dist/index.d.ts", 7 | "scripts": { 8 | "prepack": "tsc" 9 | }, 10 | "author": "zcbenz", 11 | "license": "MIT", 12 | "keywords": [ "mlx", "embeddings", "clip" ], 13 | "repository": { 14 | "type": "git", 15 | "url": "git+https://github.com/frost-beta/clip.js.git" 16 | }, 17 | "bugs": { 18 | "url": "https://github.com/frost-beta/clip.js/issues" 19 | }, 20 | "devDependencies": { 21 | "@types/node": "22.5.4", 22 | "tsx": "4.19.1", 23 | "typescript": "5.6.2" 24 | }, 25 | "dependencies": { 26 | "@frost-beta/mlx": "0.0.19", 27 | "@lenml/tokenizers": "1.1.2", 28 | "sharp": "0.33.5" 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /src/image-processor.ts: -------------------------------------------------------------------------------- 1 | import sharp from 'sharp'; 2 | import {core as mx} from '@frost-beta/mlx'; 3 | 4 | export type ImageInputType = Buffer | ArrayBuffer | string; 5 | 6 | export interface PreprocessorConfig { 7 | cropSize: number, 8 | doCenterCrop: boolean, 9 | doNormalize: boolean, 10 | doResize: boolean, 11 | imageMean: number[], 12 | imageStd: number[], 13 | size: number 14 | } 15 | 16 | export interface ProcessedImage { 17 | data: Buffer; 18 | info: sharp.OutputInfo; 19 | } 20 | 21 | export class ClipImageProcessor { 22 | constructor(private config: PreprocessorConfig) {} 23 | 24 | async processImage(input: ImageInputType): Promise { 25 | let image = sharp(input); 26 | if (this.config.doResize && this.config.doCenterCrop && this.config.size == this.config.cropSize) { 27 | // Fast path for resize and crop with same size. 28 | image = image.resize(this.config.size, this.config.size); 29 | } else { 30 | // Slow path for doing resize and crop in 2 separate steps. 31 | if (this.config.doResize) 32 | image = image.resize(this.config.size, this.config.size, {fit: 'outside'}); 33 | if (this.config.doCenterCrop) 34 | image = await centerCrop(image, this.config.cropSize); 35 | } 36 | // The model only works with RGB. 37 | image = image.removeAlpha(); 38 | // Extract size and data. 39 | return await image.raw().toBuffer({resolveWithObject: true}); 40 | } 41 | 42 | processImages(inputs: ImageInputType[]): Promise { 43 | return Promise.all(inputs.map(this.processImage.bind(this))); 44 | } 45 | 46 | normalizeImages(images: ProcessedImage[]) { 47 | const {info} = images[0]; 48 | // The model expects the data to be a nested array. 49 | let tensor = mx.stack(images.map(i => mx.array(Array.from(i.data)))); 50 | tensor = tensor.reshape([ images.length, info.width, info.height, 3 ]); 51 | // Normalize the tensor. 52 | tensor = rescale(tensor); 53 | if (this.config.doNormalize) 54 | tensor = normalize(tensor, this.config.imageMean, this.config.imageStd); 55 | return tensor; 56 | } 57 | } 58 | 59 | async function centerCrop(image: sharp.Sharp, cropSize: number) { 60 | // Have to call toBuffer to get the new size after resize. 
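  // (sharp pipelines are lazy and metadata() reports the original input
  //  dimensions, so the resized size is only available from toBuffer()'s
  //  output info.)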
61 | const {info} = await image.toBuffer({resolveWithObject: true}); 62 | return image.extract({ 63 | top: (info.height - cropSize) / 2, 64 | left: (info.width - cropSize) / 2, 65 | width: cropSize, 66 | height: cropSize, 67 | }); 68 | } 69 | 70 | function rescale(tensor: mx.array) { 71 | return mx.divide(tensor.astype(mx.float32), 255); 72 | } 73 | 74 | function normalize(tensor: mx.array, mean: number[], std: number[]) { 75 | return mx.divide(mx.subtract(tensor, mx.array(mean)), 76 | mx.array(std)); 77 | } 78 | -------------------------------------------------------------------------------- /src/index.ts: -------------------------------------------------------------------------------- 1 | import {statSync, readFileSync} from 'node:fs' 2 | import {TokenizerLoader} from '@lenml/tokenizers'; 3 | import {core as mx, nn} from '@frost-beta/mlx'; 4 | 5 | import { 6 | ClipConfig, 7 | ClipModelInput, 8 | ClipModel, 9 | } from './model'; 10 | import { 11 | ImageInputType, 12 | ProcessedImage, 13 | PreprocessorConfig, 14 | ClipImageProcessor, 15 | } from './image-processor'; 16 | 17 | export * from './model'; 18 | export * from './image-processor'; 19 | 20 | export interface ClipInput { 21 | labels?: string[]; 22 | images?: ProcessedImage[]; 23 | } 24 | 25 | export interface ClipOutput { 26 | labelEmbeddings?: mx.array; 27 | imageEmbeddings?: mx.array; 28 | } 29 | 30 | /** 31 | * Provide APIs around the CLIP model. 32 | */ 33 | export class Clip { 34 | #tokenizer?: Tokenizer; 35 | #imageProcessor?: ClipImageProcessor; 36 | #model?: ClipModel; 37 | 38 | constructor(public modelDir: string, public batchSize?: number) {} 39 | 40 | get tokenizer() { 41 | if (!this.#tokenizer) 42 | this.#tokenizer = loadTokenizer(this.modelDir); 43 | return this.#tokenizer; 44 | } 45 | 46 | get imageProcessor() { 47 | if (!this.#imageProcessor) 48 | this.#imageProcessor = loadImageProcessor(this.modelDir); 49 | return this.#imageProcessor; 50 | } 51 | 52 | get model() { 53 | if (!this.#model) { 54 | if (this.batchSize) { 55 | // When batchSize is hinted, we will set a cache limit. This is needed 56 | // because the model can burst to use many RAM and MLX's cache memory 57 | // will leave app's RAM usage at the peak. We should eventually fix the 58 | // model but for now setting cache limit is enough. 59 | const {size} = statSync(`${this.modelDir}/model.safetensors`); 60 | mx.metal.setCacheLimit(size * (1 + this.batchSize)); 61 | } 62 | this.#model = loadModel(this.modelDir); 63 | } 64 | return this.#model; 65 | } 66 | 67 | processImages(images: ImageInputType[]): Promise { 68 | return this.imageProcessor.processImages(images); 69 | } 70 | 71 | computeEmbeddings({labels, images}: ClipInput): ClipOutput { 72 | const input: ClipModelInput = {}; 73 | if (labels) 74 | input.inputIds = this.tokenizer.encode(labels); 75 | if (images) 76 | input.pixelValues = this.imageProcessor.normalizeImages(images); 77 | const output = this.model.forward(input); 78 | return { 79 | labelEmbeddings: output.textEmbeds, 80 | imageEmbeddings: output.imageEmbeds, 81 | }; 82 | } 83 | 84 | /** 85 | * Short hands of computeEmbeddings to convert results to JavaScript numbers 86 | * and ensure the intermediate arrays are destroyed. 
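   *
   * The computation runs inside mx.tidy(), so the intermediate mx.array
   * objects are freed once the plain numbers have been extracted.
   * Hypothetical usage:
   *   const embeddings = clip.computeLabelEmbeddingsJs(['a photo of a cat']);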
87 | */ 88 | computeLabelEmbeddingsJs(labels: string[]): number[][] { 89 | this.model; // initialize model before mx.tidy 90 | return mx.tidy(() => this.computeEmbeddings({labels}).labelEmbeddings.tolist() as number[][]); 91 | } 92 | 93 | computeImageEmbeddingsJs(images: ProcessedImage[]): number[][] { 94 | this.model; // initialize model before mx.tidy 95 | return mx.tidy(() => this.computeEmbeddings({images}).imageEmbeddings.tolist() as number[][]); 96 | } 97 | 98 | /** 99 | * Compute the cosine similarity between 2 embeddings. 100 | */ 101 | static computeCosineSimilaritiy(a1: mx.array, a2: mx.array): mx.array { 102 | return nn.losses.cosineSimilarityLoss(a1, a2, 0); 103 | } 104 | 105 | /** 106 | * Compute the cosine similarities between 2 arrays of embeddings. 107 | * 108 | * A tuple will be returned, with the first element being the cosine 109 | * similarity scores, and the second element being the indices sorted by 110 | * their scores from larger to smalller. 111 | */ 112 | static computeCosineSimilarities(x1: mx.array | number[][], 113 | x2: mx.array | number[][]): [ mx.array, mx.array ] { 114 | if (!(x1 instanceof mx.array)) 115 | x1 = mx.array(x1); 116 | if (!(x2 instanceof mx.array)) 117 | x2 = mx.array(x2); 118 | const scores = nn.losses.cosineSimilarityLoss(x1, x2, 1); 119 | const indices = mx.argsort(scores).index(mx.Slice(null, null, -1)); 120 | return [ scores, indices ]; 121 | } 122 | } 123 | 124 | // The tokenizer for encoding multiple strings. 125 | export interface Tokenizer { 126 | encode(text: string[]): mx.array; 127 | } 128 | 129 | // Return the tokenizer. 130 | export function loadTokenizer(dir: string): Tokenizer { 131 | const tokenizer = TokenizerLoader.fromPreTrained({ 132 | tokenizerJSON: readJson(`${dir}/tokenizer.json`), 133 | tokenizerConfig: readJson(`${dir}/tokenizer_config.json`), 134 | }); 135 | return { 136 | encode(text: string[]) { 137 | const {input_ids} = tokenizer._call(text, {padding: true}); 138 | return mx.stack(input_ids as number[][]); 139 | } 140 | }; 141 | } 142 | 143 | // Return the image processor. 144 | export function loadImageProcessor(dir: string) { 145 | const json = readJson(`${dir}/preprocessor_config.json`); 146 | return new ClipImageProcessor(modelArgs(json) as PreprocessorConfig); 147 | } 148 | 149 | // Create the CLIP model. 150 | export function loadModel(dir: string) { 151 | // Read config files. 152 | const configJson = readJson(`${dir}/config.json`); 153 | const clipConfig = modelArgs(configJson) as ClipConfig; 154 | // Create model. 155 | const model = new ClipModel(clipConfig); 156 | const weights = Object.entries(mx.load(`${dir}/model.safetensors`)); 157 | // Sanitize the weights for MLX. 158 | const sanitizedWeights = []; 159 | for (const [ key, value ] of weights) { 160 | // Remove unused position_ids. 161 | if (key.includes('position_ids')) 162 | continue; 163 | // PyTorch Conv2d expects the weight tensor to be of shape: 164 | // [out_channels, in_channels, kH, KW] 165 | // MLX Conv2d expects the weight tensor to be of shape: 166 | // [out_channels, kH, KW, in_channels] 167 | if (key.endsWith('patch_embedding.weight')) 168 | sanitizedWeights.push([ key, value.transpose(0, 2, 3, 1) ]); 169 | else 170 | sanitizedWeights.push([ key, value ]); 171 | } 172 | model.loadWeights(sanitizedWeights); 173 | return model; 174 | } 175 | 176 | // Convert snake_case args into camelCase args. 
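// For example, {"vision_config": {"image_size": 224}} becomes
// {visionConfig: {imageSize: 224}}; nested objects and arrays are handled
// recursively.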
177 | function modelArgs(args: any): object{ 178 | if (Array.isArray(args)) 179 | return args.map(v => modelArgs(v)); 180 | if (typeof args != 'object') 181 | return args; 182 | const newArgs = {} 183 | for (const key in args) { 184 | const newKey = key.replace(/(\_\w)/g, (s) => s[1].toUpperCase()) 185 | newArgs[newKey] = modelArgs(args[key]); 186 | } 187 | return newArgs 188 | } 189 | 190 | // Helper for reading a .json file. 191 | function readJson(path: string) { 192 | return JSON.parse(String(readFileSync(path))); 193 | } 194 | -------------------------------------------------------------------------------- /src/model.ts: -------------------------------------------------------------------------------- 1 | import {core as mx, nn} from '@frost-beta/mlx'; 2 | 3 | export interface ClipConfig { 4 | textConfig: ClipTextConfig; 5 | visionConfig: ClipVisionConfig; 6 | projectionDim: number; 7 | } 8 | 9 | interface EncoderConfig { 10 | numHiddenLayers: number; 11 | hiddenSize: number; 12 | intermediateSize: number; 13 | numAttentionHeads: number; 14 | layerNormEps: number; 15 | } 16 | 17 | export interface ClipTextConfig extends EncoderConfig { 18 | maxPositionEmbeddings: number; 19 | vocabSize: number; 20 | } 21 | 22 | export interface ClipVisionConfig extends EncoderConfig { 23 | numChannels: number; 24 | imageSize: number; 25 | patchSize: number; 26 | } 27 | 28 | export interface ClipTextOutput { 29 | poolerOutput: mx.array; 30 | lastHiddenState: mx.array; 31 | } 32 | 33 | export interface ClipVisionOutput extends ClipTextOutput { 34 | hiddenStates?: mx.array; 35 | } 36 | 37 | export interface ClipModelInput { 38 | inputIds?: mx.array; 39 | pixelValues?: mx.array; 40 | returnLoss?: boolean; 41 | } 42 | 43 | export interface ClipModelOutput { 44 | loss?: mx.array; 45 | textEmbeds?: mx.array; 46 | imageEmbeds?: mx.array; 47 | textModelOutput?: ClipTextOutput; 48 | visionModelOutput?: ClipVisionOutput; 49 | } 50 | 51 | class Attention extends nn.Module { 52 | numHeads: number; 53 | qProj: nn.Linear; 54 | kProj: nn.Linear; 55 | vProj: nn.Linear; 56 | outProj: nn.Linear; 57 | 58 | constructor(dims: number, 59 | numHeads: number, 60 | queryInputDims: number | null = null, 61 | keyInputDims: number | null = null, 62 | valueInputDims: number | null = null, 63 | valueDims: number | null = null, 64 | valueOutputDims: number | null = null, 65 | bias: boolean = true) { 66 | if (dims % numHeads != 0) { 67 | throw new Error(`The input feature dimensions should be divisible by the ` + 68 | `number of heads (${dims} % ${numHeads}) != 0`); 69 | } 70 | 71 | super(); 72 | 73 | queryInputDims = queryInputDims || dims; 74 | keyInputDims = keyInputDims || dims; 75 | valueInputDims = valueInputDims || keyInputDims; 76 | valueDims = valueDims || dims; 77 | valueOutputDims = valueOutputDims || dims; 78 | 79 | this.numHeads = numHeads; 80 | this.qProj = new nn.Linear(queryInputDims, dims, bias); 81 | this.kProj = new nn.Linear(keyInputDims, dims, bias); 82 | this.vProj = new nn.Linear(valueInputDims, valueDims, bias); 83 | this.outProj = new nn.Linear(valueDims, valueOutputDims, bias); 84 | } 85 | 86 | forward(queries: mx.array, keys: mx.array, values: mx.array, mask?: mx.array) { 87 | queries = this.qProj.forward(queries); 88 | keys = this.kProj.forward(keys); 89 | values = this.vProj.forward(values); 90 | 91 | const numHeads = this.numHeads; 92 | const [ B, L, D ] = queries.shape; 93 | const [ , S, ] = keys.shape; 94 | queries = queries.reshape(B, L, numHeads, -1).transpose(0, 2, 1, 3); 95 | keys = keys.reshape(B, 
S, numHeads, -1).transpose(0, 2, 3, 1); 96 | values = values.reshape(B, S, numHeads, -1).transpose(0, 2, 1, 3); 97 | 98 | const scale = Math.sqrt(1 / queries.shape.at(-1)); 99 | let scores = mx.matmul(mx.multiply(queries, scale), keys); 100 | if (mask) 101 | scores = mx.add(scores, mask.astype(scores.dtype)); 102 | scores = mx.softmax(scores, -1); 103 | const valuesHat = mx.matmul(scores, values).transpose(0, 2, 1, 3) 104 | .reshape(B, L, -1); 105 | 106 | return this.outProj.forward(valuesHat); 107 | } 108 | } 109 | 110 | class MLP extends nn.Module { 111 | activationFn: (x: mx.array) => mx.array; 112 | fc1: nn.Linear; 113 | fc2: nn.Linear; 114 | 115 | constructor(config: EncoderConfig) { 116 | super(); 117 | this.activationFn = quickGelu; 118 | this.fc1 = new nn.Linear(config.hiddenSize, config.intermediateSize); 119 | this.fc2 = new nn.Linear(config.intermediateSize, config.hiddenSize); 120 | } 121 | 122 | forward(x: mx.array): mx.array { 123 | x = this.activationFn(this.fc1.forward(x)); 124 | x = this.fc2.forward(x); 125 | return x; 126 | } 127 | } 128 | 129 | class EncoderLayer extends nn.Module { 130 | embedDim: number; 131 | selfAttn: Attention; 132 | layerNorm1: nn.LayerNorm; 133 | mlp: MLP; 134 | layerNorm2: nn.LayerNorm; 135 | 136 | constructor(config: EncoderConfig) { 137 | super(); 138 | this.embedDim = config.hiddenSize; 139 | this.selfAttn = new Attention(config.hiddenSize, config.numAttentionHeads); 140 | this.layerNorm1 = new nn.LayerNorm(this.embedDim, config.layerNormEps); 141 | this.mlp = new MLP(config); 142 | this.layerNorm2 = new nn.LayerNorm(this.embedDim, config.layerNormEps); 143 | } 144 | 145 | forward(x: mx.array, mask?: mx.array): mx.array { 146 | let y = this.layerNorm1.forward(x); 147 | y = this.selfAttn.forward(y, y, y, mask); 148 | x = mx.add(x, y); 149 | y = this.layerNorm2.forward(x); 150 | y = this.mlp.forward(y); 151 | return mx.add(x, y); 152 | } 153 | } 154 | 155 | class Encoder extends nn.Module { 156 | layers: EncoderLayer[] = []; 157 | 158 | constructor(config: EncoderConfig) { 159 | super(); 160 | for (let i = 0; i < config.numHiddenLayers; ++i) 161 | this.layers.push(new EncoderLayer(config)) 162 | } 163 | 164 | forward(h: mx.array, mask?: mx.array) { 165 | for (const layer of this.layers) 166 | h = layer.forward(h, mask); 167 | return h; 168 | } 169 | } 170 | 171 | class TextEmbeddings extends nn.Module { 172 | tokenEmbedding: nn.Embedding; 173 | positionEmbedding: nn.Embedding; 174 | 175 | constructor(config: ClipTextConfig) { 176 | super(); 177 | const embedDim = config.hiddenSize; 178 | this.tokenEmbedding = new nn.Embedding(config.vocabSize, embedDim); 179 | this.positionEmbedding = new nn.Embedding(config.maxPositionEmbeddings, embedDim); 180 | } 181 | 182 | forward(x: mx.array): mx.array { 183 | const embeddings = this.tokenEmbedding.forward(x.astype(mx.int32)); 184 | return mx.add(embeddings, 185 | this.positionEmbedding.weight.index(mx.Slice(null, x.shape[1]))); 186 | } 187 | } 188 | 189 | /** 190 | * Implements the text encoder transformer from CLIP. 
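 *
 * The forward pass applies an additive causal attention mask and pools the
 * text feature at the end-of-text (EOT) token, which is located with argmax
 * over the token ids (the EOT token has the highest id in CLIP's vocabulary).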
191 | */ 192 | export class ClipTextModel extends nn.Module { 193 | embeddings: TextEmbeddings; 194 | encoder: Encoder; 195 | finalLayerNorm: nn.LayerNorm; 196 | 197 | constructor(config: ClipTextConfig) { 198 | super(); 199 | this.embeddings = new TextEmbeddings(config); 200 | this.encoder = new Encoder(config); 201 | this.finalLayerNorm = new nn.LayerNorm(config.hiddenSize); 202 | } 203 | 204 | forward(x: mx.array): ClipTextOutput { 205 | const [ B, N ] = x.shape; 206 | const eotTokens = mx.argmax(x, -1); 207 | x = this.embeddings.forward(x); 208 | const mask = nn.MultiHeadAttention.createAdditiveCausalMask(N, x.dtype); 209 | x = this.encoder.forward(x, mask); 210 | const lastHiddenState = this.finalLayerNorm.forward(x); 211 | const poolerOutput = lastHiddenState.index(mx.arange(B, mx.int32), eotTokens); 212 | return {poolerOutput, lastHiddenState}; 213 | } 214 | } 215 | 216 | class VisionEmbeddings extends nn.Module { 217 | embedDim: number; 218 | imageSize: number; 219 | patchSize: number; 220 | classEmbedding: mx.array; 221 | patchEmbedding: nn.Conv2d; 222 | numPatches: number; 223 | numPositions: number; 224 | positionEmbedding: nn.Embedding; 225 | 226 | constructor(config: ClipVisionConfig) { 227 | super(); 228 | this.embedDim = config.hiddenSize; 229 | this.imageSize = config.imageSize; 230 | this.patchSize = config.patchSize; 231 | 232 | this.classEmbedding = mx.zeros(config.hiddenSize); 233 | 234 | this.patchEmbedding = new nn.Conv2d(3, 235 | this.embedDim, 236 | this.patchSize, 237 | this.patchSize, 238 | undefined, 239 | undefined, 240 | false); 241 | 242 | this.numPatches = Math.pow(this.imageSize / this.patchSize, 2); 243 | this.numPositions = this.numPatches + 1; 244 | this.positionEmbedding = new nn.Embedding(this.numPositions, this.embedDim); 245 | } 246 | 247 | forward(x: mx.array): mx.array { 248 | const batchSize = x.shape[0]; 249 | // Patchify using conv: 250 | // [batch_size, sqrt(num_patches), sqrt(num_patches), embed_dim] 251 | let patchEmbeddings = this.patchEmbedding.forward(x); 252 | // [batch_size, num_patches, embed_dim] 253 | patchEmbeddings = mx.flatten(patchEmbeddings, 1, 2); 254 | const embedDim = patchEmbeddings.shape.at(-1); 255 | // Prepend embeddings 256 | // [batch_size, 1, embed_dim] 257 | const clsEmbeddings = mx.broadcastTo(this.classEmbedding, 258 | [ batchSize, 1, embedDim ]); 259 | // [batch_size, num_patches + 1, embed_dim] 260 | let embeddings = mx.concatenate([ clsEmbeddings, patchEmbeddings ], 1); 261 | // Add positional encoding 262 | embeddings = mx.add(embeddings, this.positionEmbedding.weight); 263 | return embeddings; 264 | } 265 | } 266 | 267 | /** 268 | * Implements the vision encoder transformer from CLIP. 
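 *
 * Images are patchified by a strided convolution, a learned class embedding
 * is prepended and position embeddings are added; the pooled output is the
 * post-layernorm hidden state of the class token.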
269 | */ 270 | export class ClipVisionModel extends nn.Module { 271 | embeddings: VisionEmbeddings; 272 | preLayrnorm: nn.LayerNorm; 273 | encoder: Encoder; 274 | postLayernorm: nn.LayerNorm; 275 | 276 | constructor(config: ClipVisionConfig) { 277 | super(); 278 | this.embeddings = new VisionEmbeddings(config); 279 | this.preLayrnorm = new nn.LayerNorm(config.hiddenSize); 280 | this.encoder = new Encoder(config); 281 | this.postLayernorm = new nn.LayerNorm(config.hiddenSize); 282 | } 283 | 284 | forward(x: mx.array, outputHiddenStates = false): ClipVisionOutput { 285 | x = this.embeddings.forward(x); 286 | x = this.preLayrnorm.forward(x); 287 | 288 | const encoderStates = [ x ]; 289 | for (const layer of this.encoder.layers) { 290 | x = layer.forward(x); 291 | encoderStates.push(x); 292 | } 293 | 294 | const poolerOutput = this.postLayernorm.forward(x.index(mx.Slice(), 0, mx.Slice())); 295 | return { 296 | poolerOutput, 297 | lastHiddenState: x, 298 | hiddenStates: outputHiddenStates ? mx.stack(encoderStates) : undefined, 299 | }; 300 | } 301 | } 302 | 303 | export class ClipModel extends nn.Module { 304 | textModel: ClipTextModel; 305 | visionModel: ClipVisionModel; 306 | textEmbedDim: number; 307 | visionEmbedDim: number; 308 | projectionDim: number; 309 | visualProjection: nn.Linear; 310 | textProjection: nn.Linear; 311 | logitScale: mx.array; 312 | 313 | constructor(config: ClipConfig) { 314 | super(); 315 | this.textModel = new ClipTextModel(config.textConfig); 316 | this.visionModel = new ClipVisionModel(config.visionConfig); 317 | 318 | this.textEmbedDim = config.textConfig.hiddenSize; 319 | this.visionEmbedDim = config.visionConfig.hiddenSize; 320 | this.projectionDim = config.projectionDim; 321 | 322 | this.visualProjection = new nn.Linear(this.visionEmbedDim, this.projectionDim, false); 323 | this.textProjection = new nn.Linear(this.textEmbedDim, this.projectionDim, false); 324 | this.logitScale = mx.array(0.); 325 | } 326 | 327 | forward({inputIds, pixelValues, returnLoss} : ClipModelInput): ClipModelOutput { 328 | let textEmbeds, textModelOutput, imageEmbeds, visionModelOutput; 329 | if (inputIds) { 330 | textModelOutput = this.textModel.forward(inputIds); 331 | textEmbeds = this.textProjection.forward(textModelOutput.poolerOutput); 332 | textEmbeds = mx.divide(textEmbeds, mx.linalg.norm(textEmbeds, undefined, -1, true)); 333 | } 334 | if (pixelValues) { 335 | visionModelOutput = this.visionModel.forward(pixelValues); 336 | imageEmbeds = this.visualProjection.forward(visionModelOutput.poolerOutput); 337 | imageEmbeds = mx.divide(imageEmbeds, mx.linalg.norm(imageEmbeds, undefined, -1, true)); 338 | } 339 | 340 | if (returnLoss && (!inputIds || !pixelValues)) { 341 | throw new Error("Must provide text and image inputs to compute loss."); 342 | } 343 | 344 | let loss; 345 | if (returnLoss) { 346 | const logits = mx.multiply(mx.matmul(textEmbeds, imageEmbeds.T), 347 | mx.exp(this.logitScale)); 348 | loss = clipLoss(logits); 349 | } 350 | 351 | return {loss, textEmbeds, imageEmbeds, visionModelOutput, textModelOutput}; 352 | } 353 | } 354 | 355 | // A fast GELU approximation https://github.com/hendrycks/GELUs 356 | function quickGelu(x: mx.array): mx.array { 357 | return mx.multiply(x, mx.sigmoid(mx.multiply(1.702, x))); 358 | } 359 | 360 | // Compute loss of CLIP model's output. 
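// This is the symmetric contrastive loss from the CLIP paper: cross-entropy
// over the text-to-image logits and over the transposed image-to-text logits,
// both using the diagonal (matching pairs) as targets, averaged together.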
361 | function clipLoss(logits: mx.array): mx.array { 362 | const [ N, M ] = logits.shape; 363 | const captionLoss = nn.losses.crossEntropy(logits, mx.arange(N, mx.int32), undefined, undefined, undefined, 'mean'); 364 | const imageLoss = nn.losses.crossEntropy(logits.T, mx.arange(M, mx.int32), undefined, undefined, undefined, 'mean'); 365 | return mx.divide(mx.add(captionLoss, imageLoss), 366 | 2.0); 367 | } 368 | -------------------------------------------------------------------------------- /test.ts: -------------------------------------------------------------------------------- 1 | import {Clip} from './src/index.ts'; 2 | 3 | main(); 4 | 5 | async function main() { 6 | const clip = new Clip(process.argv[2] ?? 'clip-vit-large-patch14'); 7 | const labelEmbeddings = clip.computeLabelEmbeddingsJs([ 'seagull', 'lovely dog' ]); 8 | const imageEmbeddings = clip.computeImageEmbeddingsJs( 9 | await clip.processImages(await Promise.all([ 10 | download('https://d29fhpw069ctt2.cloudfront.net/photo/34910/preview/u3x7cekkS16ajjtJcb5L_DSC_5869_npreviews_9e55.jpg'), 11 | download('https://d29fhpw069ctt2.cloudfront.net/photo/35183/preview/UzWklzFdRBSbkRKhEnvc_1-6128_npreviews_79e3.jpg'), 12 | ]))); 13 | const [ scores, indices ] = Clip.computeCosineSimilarities(labelEmbeddings, 14 | imageEmbeddings); 15 | console.log('Cosine similarity:', scores.tolist()); 16 | } 17 | 18 | async function download(url) { 19 | const response = await fetch(url); 20 | return Buffer.from(await response.arrayBuffer()); 21 | } 22 | -------------------------------------------------------------------------------- /tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "outDir": "dist", 4 | "rootDir": "src", 5 | "declaration": true, 6 | "module": "commonjs", 7 | "target": "es2023", 8 | "lib": [ "esnext" ], 9 | "esModuleInterop": true 10 | }, 11 | "include": [ "src" ], 12 | "exclude": [ "node_modules" ] 13 | } 14 | --------------------------------------------------------------------------------