├── .github └── workflows │ ├── node-client-tests.yml │ ├── pylint.yml │ ├── python-client-tests.yml │ └── test.yml ├── .gitignore ├── LICENSE ├── README.md ├── assets ├── demo-1.jpg └── demo-2.jpg ├── batch_generate_example.py ├── clients ├── node │ ├── .gitignore │ ├── .npmignore │ ├── README.MD │ ├── jest.config.js │ ├── package-lock.json │ ├── package.json │ ├── src │ │ ├── __tests__ │ │ │ ├── moondream.integration.test.ts │ │ │ ├── moondream.test.ts │ │ │ └── setup.ts │ │ ├── moondream.ts │ │ └── types.ts │ └── tsconfig.json └── python │ ├── README.md │ ├── moondream │ ├── __init__.py │ ├── cli.py │ ├── cloud_vl.py │ ├── moonfile.py │ ├── onnx_vl.py │ ├── preprocess.py │ ├── server.py │ ├── torch_vl.py │ ├── types.py │ └── version.py │ ├── pyproject.toml │ ├── scripts │ ├── test.py │ ├── test_cloud_parity.py │ └── test_local_server.py │ └── tests │ ├── test_api_inference.py │ ├── test_local_inference.py │ └── test_local_torch_inference.py ├── gradio_demo.py ├── moondream ├── __init__.py ├── config │ ├── config_md05.json │ └── config_md2.json ├── eval │ ├── chartqa.py │ ├── coco_map.py │ ├── countbenchqa.py │ ├── docvqa.py │ ├── eval_all.py │ ├── gazefollow.py │ ├── mmstar.py │ ├── naturalbench.py │ ├── pope.py │ ├── realworldqa.py │ ├── tallyqa.py │ ├── textvqa.py │ ├── utils.py │ └── waste_detection.py ├── finetune │ ├── README.md │ ├── __init__.py │ ├── finetune_region.py │ └── finetune_text.py └── torch │ ├── config.py │ ├── hf_moondream.py │ ├── hf_release.py │ ├── image_crops.py │ ├── layers.py │ ├── moondream.py │ ├── region.py │ ├── rope.py │ ├── sample.py │ ├── text.py │ ├── utils.py │ ├── vision.py │ └── weights.py ├── notebooks └── RepEng.ipynb ├── recipes ├── gaze-detection-video │ ├── .gitignore │ ├── README.md │ ├── gaze-detection-video.py │ ├── input │ │ └── .gitkeep │ ├── output │ │ └── .gitkeep │ ├── requirements.txt │ └── temp │ │ └── .gitkeep ├── promptable-content-moderation │ ├── .gitignore │ ├── README.md │ ├── app.py │ ├── 
deep_sort_integration.py │ ├── main.py │ ├── packages.txt │ ├── persistence.py │ ├── requirements.txt │ ├── video_visualization.py │ └── visualization.py └── promptable-video-redaction │ ├── .gitignore │ ├── README.md │ ├── app.py │ ├── main.py │ ├── packages.txt │ └── requirements.txt ├── requirements.txt ├── sample.py ├── tests └── test_image_crops.py └── webcam_gradio_demo.py /.github/workflows/node-client-tests.yml: -------------------------------------------------------------------------------- 1 | name: Node.js Client Tests 2 | 3 | on: 4 | # Only run on PRs to avoid duplicate runs 5 | pull_request: 6 | paths: 7 | - 'clients/node/**' 8 | - '.github/workflows/node-client-tests.yml' 9 | 10 | permissions: 11 | contents: read 12 | jobs: 13 | test: 14 | runs-on: ubuntu-latest 15 | strategy: 16 | matrix: 17 | node-version: [18.x, 20.x] 18 | 19 | steps: 20 | - uses: actions/checkout@v4 21 | 22 | - name: Set up Node.js ${{ matrix.node-version }} 23 | uses: actions/setup-node@v4 24 | with: 25 | node-version: ${{ matrix.node-version }} 26 | cache: 'npm' 27 | cache-dependency-path: clients/node/package-lock.json 28 | 29 | - name: Install dependencies 30 | working-directory: ./clients/node 31 | run: npm ci 32 | 33 | - name: Run unit tests 34 | working-directory: ./clients/node 35 | run: npm test -- src/__tests__/moondream.test.ts 36 | 37 | - name: Run integration tests 38 | working-directory: ./clients/node 39 | env: 40 | MOONDREAM_API_KEY: ${{ secrets.MOONDREAM_API_KEY }} 41 | run: npm test -- src/__tests__/moondream.integration.test.ts 42 | -------------------------------------------------------------------------------- /.github/workflows/pylint.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | permissions: 10 | contents: read 11 | jobs: 12 | build: 13 | runs-on: ubuntu-latest 14 | permissions: 15 | contents: read 16 | 
strategy: 17 | matrix: 18 | python-version: ["3.12"] # Run lint checks only on latest Python version 19 | steps: 20 | - uses: actions/checkout@v4 21 | - name: Set up Python ${{ matrix.python-version }} 22 | uses: actions/setup-python@v4 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | - name: Install dependencies 26 | run: | 27 | python -m pip install --upgrade pip 28 | pip install autoflake black 29 | - name: Checking for unused imports 30 | run: | 31 | autoflake -c -r . 32 | - name: Checking code style 33 | run: | 34 | black --check . 35 | -------------------------------------------------------------------------------- /.github/workflows/python-client-tests.yml: -------------------------------------------------------------------------------- 1 | name: Python Client Tests 2 | 3 | on: 4 | # Only run on PRs to avoid duplicate runs 5 | pull_request: 6 | paths: 7 | - 'clients/python/**' 8 | - '.github/workflows/python-client-tests.yml' 9 | 10 | permissions: 11 | contents: read 12 | jobs: 13 | test: 14 | runs-on: ubuntu-latest 15 | strategy: 16 | matrix: 17 | python-version: ['3.10', '3.11', '3.12'] 18 | 19 | steps: 20 | - uses: actions/checkout@v4 21 | 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v4 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | cache: 'pip' 27 | 28 | - name: Install Poetry 29 | run: | 30 | curl -sSL https://install.python-poetry.org | python3 - 31 | 32 | - name: Cache Poetry dependencies 33 | uses: actions/cache@v3 34 | with: 35 | path: ~/.cache/pypoetry 36 | key: ${{ runner.os }}-poetry-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }} 37 | restore-keys: | 38 | ${{ runner.os }}-poetry-${{ matrix.python-version }}- 39 | 40 | - name: Install dependencies 41 | working-directory: ./clients/python 42 | run: | 43 | poetry install --all-extras 44 | 45 | - name: Format code 46 | working-directory: ./clients/python 47 | run: | 48 | poetry run pip install black 49 | 
poetry run black tests/test_local_inference.py --check 50 | 51 | - name: Run tests 52 | working-directory: ./clients/python 53 | env: 54 | MOONDREAM_API_KEY: ${{ secrets.MOONDREAM_API_KEY }} 55 | run: | 56 | poetry run pip install pytest pytest-asyncio 57 | poetry run pytest tests/test_api_inference.py -v 58 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | test: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | matrix: 14 | python-version: ["3.9", "3.10", "3.11", "3.12"] 15 | 16 | steps: 17 | - uses: actions/checkout@v4 18 | - name: Set up Python ${{ matrix.python-version }} 19 | uses: actions/setup-python@v5 20 | with: 21 | python-version: ${{ matrix.python-version }} 22 | - name: Install dependencies 23 | run: | 24 | python -m pip install --upgrade pip 25 | pip install pytest 26 | pip install -r requirements.txt 27 | - name: Run tests 28 | run: | 29 | python -m pytest tests/test_image_crops.py -v -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .venv 2 | __pycache__ 3 | checkpoints 4 | data 5 | /pyproject.toml 6 | poetry.lock 7 | dist 8 | clients/python/moondream/torch 9 | wandb/ 10 | moondream_finetune.safetensors 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🌔 moondream 2 | 3 | a tiny vision language model that kicks ass and runs anywhere 4 | 5 | [Website](https://moondream.ai/) | [Demo](https://moondream.ai/playground) 6 | 7 | ## Examples 8 | 9 | | Image | Example | 10 | | ---------------------- | 
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | 11 | | ![](assets/demo-1.jpg) | **What is the girl doing?**
The girl is sitting at a table and eating a large hamburger.

**What color is the girl's hair?**
The girl's hair is white. | 12 | | ![](assets/demo-2.jpg) | **What is this?**
This is a computer server rack, which is a device used to store and manage multiple computer servers. The rack is filled with various computer servers, each with their own dedicated space and power supply. The servers are connected to the rack via multiple cables, indicating that they are part of a larger system. The rack is placed on a carpeted floor, and there is a couch nearby, suggesting that the setup is in a living or entertainment area.

**What is behind the stand?**
Behind the stand, there is a brick wall. | 13 | 14 | ## About 15 | 16 | Moondream is a highly efficient open-source vision language model that combines powerful image understanding capabilities with a remarkably small footprint. It's designed to be versatile and accessible, capable of running on a wide range of devices and platforms. 17 | 18 | The project offers two model variants: 19 | 20 | - **Moondream 2B**: The primary model with 2 billion parameters, offering robust performance for general-purpose image understanding tasks including captioning, visual question answering, and object detection. 21 | - **Moondream 0.5B**: A compact 500-million-parameter model specifically optimized as a distillation target for edge devices, enabling efficient deployment on resource-constrained hardware while maintaining impressive capabilities. 22 | 23 | ## How to use 24 | 25 | Moondream can be run locally or in the cloud. Please refer to the [Getting Started](https://moondream.ai/c/docs/quickstart) page for details. 26 | 27 | ## Special thanks 28 | 29 | * [Modal](https://modal.com/?utm_source=github&utm_medium=github&utm_campaign=moondream) - Modal lets you run jobs in the cloud by writing just a few lines of Python. Here's an [example of how to run Moondream on Modal](https://github.com/m87-labs/moondream-examples/tree/main/quickstart/modal).
30 | -------------------------------------------------------------------------------- /assets/demo-1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vikhyat/moondream/ce59395b386eb956e19bc1cba5f97deb9befcd05/assets/demo-1.jpg -------------------------------------------------------------------------------- /assets/demo-2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vikhyat/moondream/ce59395b386eb956e19bc1cba5f97deb9befcd05/assets/demo-2.jpg -------------------------------------------------------------------------------- /batch_generate_example.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from transformers import AutoTokenizer 3 | 4 | from moondream.hf import LATEST_REVISION, Moondream, detect_device 5 | 6 | device, dtype = detect_device() 7 | 8 | model_id = "vikhyatk/moondream2" 9 | tokenizer = AutoTokenizer.from_pretrained(model_id, revision=LATEST_REVISION) 10 | moondream = Moondream.from_pretrained( 11 | model_id, 12 | revision=LATEST_REVISION, 13 | torch_dtype=dtype, 14 | ).to(device=device) 15 | moondream.eval() 16 | 17 | image1 = Image.open("assets/demo-1.jpg") 18 | image2 = Image.open("assets/demo-2.jpg") 19 | prompts = [ 20 | "What is the girl doing?", 21 | "What color is the girl's hair?", 22 | "What is this?", 23 | "What is behind the stand?", 24 | ] 25 | 26 | answers = moondream.batch_answer( 27 | images=[image1, image1, image2, image2], 28 | prompts=prompts, 29 | tokenizer=tokenizer, 30 | ) 31 | 32 | for question, answer in zip(prompts, answers): 33 | print(f"Q: {question}") 34 | print(f"A: {answer}") 35 | print() 36 | -------------------------------------------------------------------------------- /clients/node/.gitignore: -------------------------------------------------------------------------------- 1 | node_modules 2 | coverage 3 | .env 
4 | .env.local 5 | .env.development 6 | .env.test 7 | .env.production 8 | dist/ 9 | build/ 10 | lib/ 11 | out/ 12 | *.log 13 | npm-debug.log* 14 | yarn-debug.log* 15 | yarn-error.log* 16 | .DS_Store 17 | .idea/ 18 | .vscode/ 19 | *.swp 20 | *.swo 21 | .next/ 22 | .cache/ 23 | .parcel-cache/ 24 | tsconfig.tsbuildinfo -------------------------------------------------------------------------------- /clients/node/.npmignore: -------------------------------------------------------------------------------- 1 | /src/ 2 | __tests__/ 3 | -------------------------------------------------------------------------------- /clients/node/README.MD: -------------------------------------------------------------------------------- 1 | # Moondream NodeJS Client Library 2 | 3 | Official NodeJS client library for Moondream, a tiny vision language model that can 4 | analyze images and answer questions about them. This client library provides easy 5 | access to Moondream's API endpoints for image analysis. 6 | 7 | ## Installation 8 | 9 | Install the package using npm: 10 | 11 | ```bash 12 | npm install moondream 13 | ``` 14 | 15 | Or using yarn: 16 | 17 | ```bash 18 | yarn add moondream 19 | ``` 20 | 21 | ## Quick Start 22 | 23 | Before using this client library, you'll need an API key to access Moondream's hosted service. 24 | You can get a free API key from [console.moondream.ai](https://console.moondream.ai). 
25 | 26 | ### Cloud 27 | 28 | ```javascript 29 | import { vl } from "moondream"; 30 | import fs from "fs"; 31 | 32 | // Initialize the client 33 | const model = new vl({ 34 | apiKey: "your-api-key", 35 | }); 36 | 37 | // Read an image file 38 | const image = fs.readFileSync("path/to/image.jpg"); 39 | 40 | // Basic usage examples 41 | async function main() { 42 | // Generate a caption for the image 43 | const caption = await model.caption({ 44 | image: image, 45 | length: "normal", 46 | stream: false 47 | }); 48 | console.log("Caption:", caption); 49 | 50 | // Ask a question about the image 51 | const answer = await model.query({ 52 | image: image, 53 | question: "What's in this image?", 54 | stream: false 55 | }); 56 | console.log("Answer:", answer); 57 | 58 | // Stream the response 59 | const stream = await model.caption({ 60 | image: image, 61 | length: "normal", 62 | stream: true 63 | }); 64 | for await (const chunk of stream.caption) { 65 | process.stdout.write(chunk); 66 | } 67 | } 68 | 69 | main(); 70 | ``` 71 | 72 | ### Local Inference 73 | 74 | - Install the `moondream` CLI: `pip install moondream` 75 | - Run the local server: `moondream serve --model path/to/model` 76 | - Set the `apiUrl` parameter to the URL of the local server (by default the local server listens on `http://localhost:3475`) 77 | 78 | ```javascript 79 | const model = new vl({ 80 | apiUrl: "http://localhost:3475", 81 | }); 82 | 83 | const image = fs.readFileSync("path/to/image.jpg"); 84 | 85 | // Basic usage examples 86 | async function main() { 87 | // Generate a caption for the image 88 | const caption = await model.caption({ 89 | image: image, 90 | length: "normal", 91 | stream: false 92 | }); 93 | console.log("Caption:", caption); 94 | 95 | // Ask a question about the image 96 | const answer = await model.query({ 97 | image: image, 98 | question: "What's in this image?", 99 | stream: false 100 | }); 101 | console.log("Answer:", answer); 102 | 103 | // Stream the response 104 | const stream = await model.caption({ 105 | 
image: image, 106 | length: "normal", 107 | stream: true 108 | }); 109 | for await (const chunk of stream.caption) { 110 | process.stdout.write(chunk); 111 | } 112 | } 113 | 114 | main(); 115 | ``` 116 | 117 | ## Features 118 | 119 | - **caption**: Generate descriptive captions for images 120 | - **query**: Ask questions about image content 121 | - **detect**: Find bounding boxes around objects in images 122 | - **point**: Identify the center location of specified objects in images 123 | 124 | ## API Reference 125 | 126 | ### Constructor 127 | 128 | ```javascript 129 | // for cloud inference 130 | const model = new vl({ 131 | apiKey: "your-api-key", 132 | }); 133 | 134 | // or for local inference 135 | const model = new vl({ 136 | apiUrl: "http://localhost:3475", 137 | }); 138 | ``` 139 | 140 | ### Methods 141 | 142 | #### caption({ image: Buffer | Base64EncodedImage, length?: "normal" | "short", stream?: boolean }) 143 | 144 | Generate a caption for an image. 145 | 146 | ```javascript 147 | const result = await model.caption({ 148 | image: image, 149 | length: "normal", 150 | stream: false 151 | }); 152 | 153 | // or with streaming 154 | const stream = await model.caption({ 155 | image: image, 156 | length: "normal", 157 | stream: true 158 | }); 159 | ``` 160 | 161 | #### query({ image: Buffer | Base64EncodedImage, question: string, stream?: boolean }) 162 | 163 | Ask a question about an image. 164 | 165 | ```javascript 166 | const result = await model.query({ 167 | image: image, 168 | question: "What's in this image?", 169 | stream: false 170 | }); 171 | 172 | // or with streaming 173 | const stream = await model.query({ 174 | image: image, 175 | question: "What's in this image?", 176 | stream: true 177 | }); 178 | ``` 179 | 180 | #### detect({ image: Buffer | Base64EncodedImage, object: string }) 181 | 182 | Detect specific objects in an image.
183 | 184 | ```javascript 185 | const result = await model.detect({ 186 | image: image, 187 | object: "car" 188 | }); 189 | ``` 190 | 191 | #### point({ image: string, object: string }) 192 | 193 | Get coordinates of specific objects in an image. 194 | 195 | ```javascript 196 | const result = await model.point({ 197 | image: image, 198 | object: "person" 199 | }); 200 | ``` 201 | 202 | ### Input Types 203 | 204 | - Images can be provided as: 205 | - Buffer: Raw image data 206 | - Base64EncodedImage: `{ imageUrl: string }` 207 | 208 | ### Response Types 209 | 210 | All methods return promises that resolve to typed responses: 211 | 212 | - CaptionOutput: `{ caption: string | AsyncGenerator }` 213 | - QueryOutput: `{ answer: string | AsyncGenerator }` 214 | - DetectOutput: `{ objects: Array }` 215 | - PointOutput: `{ points: Array }` 216 | 217 | ## Links 218 | 219 | - [Website](https://moondream.ai/) 220 | - [Demo](https://moondream.ai/playground) 221 | -------------------------------------------------------------------------------- /clients/node/jest.config.js: -------------------------------------------------------------------------------- 1 | // jest.config.js 2 | module.exports = { 3 | preset: 'ts-jest', 4 | testEnvironment: 'node', 5 | roots: ['/src'], 6 | testMatch: ['**/__tests__/**/*.test.ts'], 7 | moduleFileExtensions: ['ts', 'js', 'json', 'node'], 8 | collectCoverage: false, 9 | setupFiles: ['/src/__tests__/setup.ts'] 10 | }; -------------------------------------------------------------------------------- /clients/node/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "moondream", 3 | "version": "0.0.5", 4 | "description": "TypeScript client for the Moondream AI API", 5 | "main": "dist/src/moondream.js", 6 | "types": "dist/src/moondream.d.ts", 7 | "scripts": { 8 | "build": "tsc", 9 | "test": "jest", 10 | "prepare": "npm run build" 11 | }, 12 | "keywords": [ 13 | "moondream", 14 | "ai", 15 | 
"vision", 16 | "client" 17 | ], 18 | "author": "", 19 | "license": "", 20 | "dependencies": { 21 | "sharp": "^0.32.1" 22 | }, 23 | "devDependencies": { 24 | "@types/jest": "^29.0.0", 25 | "@types/node": "^18.0.0", 26 | "@types/node-fetch": "^2.6.1", 27 | "dotenv": "^16.4.7", 28 | "jest": "^29.0.0", 29 | "ts-jest": "^29.0.0", 30 | "typescript": "^4.9.5" 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /clients/node/src/__tests__/setup.ts: -------------------------------------------------------------------------------- 1 | // Jest setup file - currently empty as we don't need any global setup -------------------------------------------------------------------------------- /clients/node/src/moondream.ts: -------------------------------------------------------------------------------- 1 | import { Buffer } from 'buffer'; 2 | import sharp from 'sharp'; 3 | import http from 'http'; 4 | import https from 'https'; 5 | import { version } from '../package.json'; 6 | import { 7 | Base64EncodedImage, 8 | CaptionOutput, 9 | QueryOutput, 10 | DetectOutput, 11 | PointOutput, 12 | CaptionRequest, 13 | QueryRequest, 14 | DetectRequest, 15 | PointRequest, 16 | } from './types'; 17 | 18 | export interface MoondreamVLConfig { 19 | apiKey?: string; 20 | apiUrl?: string; 21 | } 22 | const DEFAULT_API_URL = 'https://api.moondream.ai/v1'; 23 | 24 | export class vl { 25 | private apiKey: string; 26 | private apiUrl: string; 27 | 28 | constructor(config: MoondreamVLConfig) { 29 | this.apiKey = config.apiKey || ''; 30 | this.apiUrl = config.apiUrl || DEFAULT_API_URL; 31 | if (this.apiKey === '' && this.apiUrl === DEFAULT_API_URL) { 32 | throw new Error( 33 | 'An apiKey is required for cloud inference. 
' 34 | ); 35 | } 36 | } 37 | 38 | private async encodeImage( 39 | image: Buffer | Base64EncodedImage 40 | ): Promise { 41 | if ('imageUrl' in image) { 42 | return image; 43 | } 44 | 45 | try { 46 | const MAX_SIZE = 768; 47 | 48 | // Process image with Sharp 49 | const metadata = await sharp(image).metadata(); 50 | 51 | if (!metadata.width || !metadata.height) { 52 | throw new Error('Unable to get image dimensions'); 53 | } 54 | 55 | const scale = MAX_SIZE / Math.max(metadata.width, metadata.height); 56 | let processedImage = sharp(image); 57 | 58 | if (scale < 1) { 59 | processedImage = processedImage.resize( 60 | Math.round(metadata.width * scale), 61 | Math.round(metadata.height * scale), 62 | { 63 | fit: 'inside', 64 | withoutEnlargement: true, 65 | } 66 | ); 67 | } 68 | 69 | const buffer = await processedImage 70 | .toFormat('jpeg', { quality: 95 }) 71 | .toBuffer(); 72 | 73 | const base64Image = buffer.toString('base64'); 74 | return { 75 | imageUrl: `data:image/jpeg;base64,${base64Image}`, 76 | }; 77 | } catch (error) { 78 | throw new Error( 79 | `Failed to convert image to JPEG: ${(error as Error).message}` 80 | ); 81 | } 82 | } 83 | 84 | private makeRequest(path: string, body: any, stream: boolean = false): Promise { 85 | return new Promise((resolve, reject) => { 86 | const url = new URL(this.apiUrl + path); 87 | const requestBody = JSON.stringify(body); 88 | 89 | const options = { 90 | method: 'POST', 91 | headers: { 92 | 'X-Moondream-Auth': this.apiKey, 93 | 'Content-Type': 'application/json', 94 | 'User-Agent': `moondream-node/${version}`, 95 | 'Content-Length': Buffer.byteLength(requestBody) 96 | } 97 | }; 98 | 99 | const client = url.protocol === 'https:' ? 
https : http; 100 | const req = client.request(url, options, (res) => { 101 | if (stream) { 102 | resolve(res); 103 | return; 104 | } 105 | 106 | let data = ''; 107 | res.on('data', chunk => data += chunk); 108 | res.on('end', () => { 109 | if (res.statusCode !== 200) { 110 | reject(new Error(`HTTP error! status: ${res.statusCode}`)); 111 | return; 112 | } 113 | try { 114 | resolve(JSON.parse(data)); 115 | } catch (error) { 116 | reject(new Error(`Failed to parse JSON response: ${(error as Error).message}`)); 117 | } 118 | }); 119 | }); 120 | 121 | req.on('error', (error) => { 122 | reject(error); 123 | }); 124 | 125 | req.write(requestBody); 126 | req.end(); 127 | }); 128 | } 129 | 130 | private async* streamResponse(response: any): AsyncGenerator { 131 | let buffer = ''; 132 | 133 | try { 134 | for await (const chunk of response) { 135 | buffer += chunk.toString(); 136 | const lines = buffer.split('\n'); 137 | buffer = lines.pop() || ''; 138 | 139 | for (const line of lines) { 140 | if (line.startsWith('data: ')) { 141 | try { 142 | const data = JSON.parse(line.slice(6)); 143 | if ('chunk' in data) { 144 | yield data.chunk; 145 | } 146 | if (data.completed) { 147 | return; 148 | } 149 | } catch (error) { 150 | throw new Error(`Failed to parse JSON response from server: ${(error as Error).message}`); 151 | } 152 | } 153 | } 154 | } 155 | 156 | // Handle any remaining data in the buffer 157 | if (buffer) { 158 | const lines = buffer.split('\n'); 159 | for (const line of lines) { 160 | if (line.startsWith('data: ')) { 161 | try { 162 | const data = JSON.parse(line.slice(6)); 163 | if ('chunk' in data) { 164 | yield data.chunk; 165 | } 166 | } catch (error) { 167 | throw new Error(`Failed to parse JSON response from server: ${(error as Error).message}`); 168 | } 169 | } 170 | } 171 | } 172 | } catch (error) { 173 | throw new Error(`Failed to stream response: ${(error as Error).message}`); 174 | } 175 | } 176 | 177 | public async caption( 178 | request: CaptionRequest 
179 | ): Promise { 180 | const encodedImage = await this.encodeImage(request.image); 181 | 182 | const response = await this.makeRequest('/caption', { 183 | image_url: encodedImage.imageUrl, 184 | length: request.length, 185 | stream: request.stream, 186 | }, request.stream); 187 | 188 | if (request.stream) { 189 | return { caption: this.streamResponse(response) }; 190 | } 191 | 192 | return { caption: response.caption }; 193 | } 194 | 195 | public async query( 196 | request: QueryRequest 197 | ): Promise { 198 | const encodedImage = await this.encodeImage(request.image); 199 | 200 | const response = await this.makeRequest('/query', { 201 | image_url: encodedImage.imageUrl, 202 | question: request.question, 203 | stream: request.stream, 204 | }, request.stream); 205 | 206 | if (request.stream) { 207 | return { answer: this.streamResponse(response) }; 208 | } 209 | 210 | return { answer: response.answer }; 211 | } 212 | 213 | public async detect( 214 | request: DetectRequest 215 | ): Promise { 216 | const encodedImage = await this.encodeImage(request.image); 217 | 218 | const response = await this.makeRequest('/detect', { 219 | image_url: encodedImage.imageUrl, 220 | object: request.object, 221 | }); 222 | 223 | return { objects: response.objects }; 224 | } 225 | 226 | public async point( 227 | request: PointRequest 228 | ): Promise { 229 | const encodedImage = await this.encodeImage(request.image); 230 | 231 | const response = await this.makeRequest('/point', { 232 | image_url: encodedImage.imageUrl, 233 | object: request.object, 234 | }); 235 | 236 | return { points: response.points }; 237 | } 238 | } -------------------------------------------------------------------------------- /clients/node/src/types.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * Base interface for encoded images 3 | */ 4 | export interface Base64EncodedImage { 5 | imageUrl: string; 6 | } 7 | 8 | /** 9 | * Length options for caption 
generation 10 | */ 11 | export type Length = 'normal' | 'short'; 12 | 13 | /** 14 | * Settings for controlling the model's generation behavior 15 | */ 16 | export interface SamplingSettings { 17 | maxTokens?: number; 18 | } 19 | 20 | /** 21 | * Request structure for image caption requests 22 | */ 23 | export interface CaptionRequest { 24 | image: Buffer | Base64EncodedImage; 25 | length?: Length; 26 | stream?: boolean; 27 | settings?: SamplingSettings; 28 | } 29 | /** 30 | * Response structure for image caption requests 31 | */ 32 | export interface CaptionOutput { 33 | caption: string | AsyncGenerator; 34 | } 35 | 36 | /** 37 | * Request structure for image query requests 38 | */ 39 | export interface QueryRequest { 40 | image: Buffer | Base64EncodedImage; 41 | question: string; 42 | stream?: boolean; 43 | settings?: SamplingSettings; 44 | } 45 | /** 46 | * Response structure for image query requests 47 | */ 48 | export interface QueryOutput { 49 | answer: string | AsyncGenerator; 50 | } 51 | 52 | /** 53 | * Request structure for object detection requests 54 | */ 55 | export interface DetectRequest { 56 | image: Buffer | Base64EncodedImage; 57 | object: string; 58 | } 59 | /** 60 | * Response structure for object detection requests 61 | */ 62 | export interface DetectOutput { 63 | objects: DetectedObject[]; 64 | } 65 | 66 | /** 67 | * Object detection result 68 | */ 69 | export interface DetectedObject { 70 | x_min: number; 71 | y_min: number; 72 | x_max: number; 73 | y_max: number; 74 | } 75 | 76 | /** 77 | * Response structure for object detection requests 78 | */ 79 | export interface DetectOutput { 80 | objects: DetectedObject[]; 81 | } 82 | 83 | /** 84 | * Error response from the API 85 | */ 86 | export interface ApiError { 87 | error: { 88 | message: string; 89 | code?: string; 90 | details?: unknown; 91 | }; 92 | } 93 | 94 | /** 95 | * Configuration options for the client 96 | */ 97 | export interface ClientConfig { 98 | apiKey: string; 99 | apiUrl?: 
string; 100 | timeout?: number; 101 | retries?: number; 102 | } 103 | 104 | /** 105 | * API response for streaming requests 106 | */ 107 | export interface StreamResponse { 108 | chunk?: string; 109 | completed?: boolean; 110 | error?: string; 111 | } 112 | 113 | /** 114 | * Options for image processing 115 | */ 116 | export interface ImageProcessingOptions { 117 | maxSize?: number; 118 | quality?: number; 119 | format?: 'jpeg' | 'png'; 120 | } 121 | 122 | /** 123 | * Common response type for all API endpoints 124 | */ 125 | export type ApiResponse = { 126 | success: boolean; 127 | data?: T; 128 | error?: ApiError; 129 | timestamp?: string; 130 | requestId?: string; 131 | } 132 | 133 | /** 134 | * Pointing request structure 135 | */ 136 | export interface PointRequest { 137 | image: Buffer | Base64EncodedImage; 138 | object: string; 139 | } 140 | /** 141 | * Point coordinates for object location 142 | */ 143 | export interface Point { 144 | x: number; 145 | y: number; 146 | } 147 | 148 | /** 149 | * Response structure for point requests 150 | */ 151 | export interface PointOutput { 152 | points: Point[]; 153 | } -------------------------------------------------------------------------------- /clients/node/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "target": "es2020", 4 | "module": "commonjs", 5 | "declaration": true, 6 | "outDir": "./dist", 7 | "strict": true, 8 | "esModuleInterop": true, 9 | "skipLibCheck": true, 10 | "forceConsistentCasingInFileNames": true, 11 | "moduleResolution": "node", 12 | "resolveJsonModule": true 13 | }, 14 | "include": ["src"], 15 | "exclude": ["node_modules", "dist", "test"] 16 | } -------------------------------------------------------------------------------- /clients/python/README.md: -------------------------------------------------------------------------------- 1 | # Moondream Python Client Library 2 | 3 | Official Python client library for 
Moondream, a tiny vision language model that 4 | can analyze images and answer questions about them. This library supports both 5 | local inference and cloud-based API access. 6 | 7 | ## Features 8 | 9 | - **Local Inference**: Run the model directly on your machine (CPU by default; GPU via the `gpu` extra) 10 | - **Cloud API**: Access Moondream's hosted service for faster inference 11 | - **Streaming**: Stream responses token by token for real-time output 12 | - **Multiple Model Sizes**: Choose between 0.5B and 2B parameter models 13 | - **Multiple Tasks**: Caption images, answer questions, detect objects, and locate points 14 | 15 | ## Installation 16 | 17 | Install the package from PyPI: 18 | 19 | ```bash 20 | pip install moondream==0.0.6 21 | ``` 22 | 23 | To install the CPU dependencies for local inference, run: 24 | 25 | ```bash 26 | pip install moondream[cpu] 27 | ``` 28 | 29 | To install the GPU dependencies for local inference, run: 30 | 31 | ```bash 32 | # From the repository root, copy the torch implementation into the Python client's moondream/torch directory 33 | cp -r moondream/torch clients/python/moondream/torch 34 | 35 | # Install the GPU dependencies 36 | pip install moondream[gpu] 37 | ``` 38 | 39 | ## Quick Start 40 | 41 | ### Using Cloud API 42 | 43 | To use Moondream's cloud API, you'll first need an API key. Sign up for a free 44 | account at [console.moondream.ai](https://console.moondream.ai) to get your key. 45 | Once you have your key, you can use it to initialize the client as shown below.
46 | 47 | ```python 48 | import moondream as md 49 | from PIL import Image 50 | 51 | # Initialize with API key 52 | model = md.vl(api_key="your-api-key") 53 | 54 | # Load an image 55 | image = Image.open("path/to/image.jpg") 56 | 57 | # Generate a caption 58 | caption = model.caption(image)["caption"] 59 | print("Caption:", caption) 60 | 61 | # Ask a question 62 | answer = model.query(image, "What's in this image?")["answer"] 63 | print("Answer:", answer) 64 | 65 | # Stream the response 66 | for chunk in model.caption(image, stream=True)["caption"]: 67 | print(chunk, end="", flush=True) 68 | ``` 69 | 70 | ### Using Local Inference 71 | 72 | First, download the model weights. We recommend the int8 weights for most applications: 73 | 74 | | Model | Precision | Download Size | Memory Usage | Download Link | 75 | | -------------- | --------- | ------------- | ------------ | ----------------------------------------------------------------------------------------------------------------------------------------------- | 76 | | Moondream 2B | int8 | 1,733 MiB | 2,624 MiB | [Download](https://huggingface.co/vikhyatk/moondream2/resolve/9dddae84d54db4ac56fe37817aeaeb502ed083e2/moondream-2b-int8.mf.gz?download=true) | 77 | | Moondream 2B | int4 | 1,167 MiB | 2,002 MiB | [Download](https://huggingface.co/vikhyatk/moondream2/resolve/9dddae84d54db4ac56fe37817aeaeb502ed083e2/moondream-2b-int4.mf.gz?download=true) | 78 | | Moondream 0.5B | int8 | 593 MiB | 996 MiB | [Download](https://huggingface.co/vikhyatk/moondream2/resolve/9dddae84d54db4ac56fe37817aeaeb502ed083e2/moondream-0_5b-int8.mf.gz?download=true) | 79 | | Moondream 0.5B | int4 | 422 MiB | 816 MiB | [Download](https://huggingface.co/vikhyatk/moondream2/resolve/9dddae84d54db4ac56fe37817aeaeb502ed083e2/moondream-0_5b-int4.mf.gz?download=true) | 80 | 81 | Then use the model locally: 82 | 83 | ```python 84 | import moondream as md 85 | from PIL import Image 86 | 87 | # Initialize with local model path 88 | model = 
md.vl(model="path/to/moondream-2b-int8.mf") 89 | 90 | # Load and encode image 91 | image = Image.open("path/to/image.jpg") 92 | 93 | # Since encoding an image is computationally expensive, you can encode it once 94 | # and reuse the encoded version for multiple queries/captions/etc. This avoids 95 | # having to re-encode the same image multiple times. 96 | encoded_image = model.encode_image(image) 97 | 98 | # Generate caption 99 | caption = model.caption(encoded_image)["caption"] 100 | print("Caption:", caption) 101 | 102 | # Ask questions 103 | answer = model.query(encoded_image, "What's in this image?")["answer"] 104 | print("Answer:", answer) 105 | ``` 106 | 107 | ## API Reference 108 | 109 | ### Constructor 110 | 111 | ```python 112 | model = md.vl( 113 | model="path/to/model.mf", # For local inference (.mf for CPU, .safetensors for GPU) 114 | api_key="your-api-key" # For cloud API access 115 | ) 116 | ``` 117 | 118 | ### Methods 119 | 120 | #### caption(image, length="normal", stream=False, settings=None) 121 | 122 | Generate a caption for an image. 123 | 124 | ```python 125 | result = model.caption(image) 126 | # or with streaming 127 | for chunk in model.caption(image, stream=True)["caption"]: 128 | print(chunk, end="") 129 | ``` 130 | 131 | #### query(image, question, stream=False, settings=None) 132 | 133 | Ask a question about an image. 134 | 135 | ```python 136 | result = model.query(image, "What's in this image?") 137 | # or with streaming 138 | for chunk in model.query(image, "What's in this image?", stream=True)["answer"]: 139 | print(chunk, end="") 140 | ``` 141 | 142 | #### detect(image, object) 143 | 144 | Detect and locate specific objects in an image. 145 | 146 | ```python 147 | result = model.detect(image, "car") 148 | ``` 149 | 150 | #### point(image, object) 151 | 152 | Get coordinates of specific objects in an image.
153 | 154 | ```python 155 | result = model.point(image, "person") 156 | ``` 157 | 158 | ### Input Types 159 | 160 | - Images can be provided as: 161 | - PIL.Image.Image objects 162 | - Encoded image objects (from model.encode_image()) 163 | 164 | ### Response Types 165 | 166 | All methods return typed dictionaries: 167 | 168 | - CaptionOutput: `{"caption": str | Generator}` 169 | - QueryOutput: `{"answer": str | Generator}` 170 | - DetectOutput: `{"objects": List[Region]}` 171 | - PointOutput: `{"points": List[Point]}` 172 | 173 | ## Performance Notes 174 | 175 | - Local inference currently only supports CPU execution 176 | - CUDA (GPU) and MPS (Apple Silicon) support coming soon 177 | - For optimal performance with GPU/MPS, use the PyTorch implementation for now 178 | 179 | ## Development Notes 180 | 181 | - Copy the torch implementation from the root moondream repo into the `torch` directory 182 | - Run `poetry install --extras "gpu"` to install the GPU dependencies 183 | - Run `poetry install --extras "cpu"` to install the CPU dependencies 184 | 185 | ## Links 186 | 187 | - [Website](https://moondream.ai/) 188 | - [Demo](https://moondream.ai/playground) -------------------------------------------------------------------------------- /clients/python/moondream/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from .cloud_vl import CloudVL 4 | from .types import VLM 5 | 6 | DEFAULT_API_URL = "https://api.moondream.ai/v1" 7 | 8 | 9 | def vl( 10 | *, 11 | model: Optional[str] = None, 12 | api_key: Optional[str] = None, 13 | api_url: Optional[str] = None, 14 | ) -> VLM: 15 | if model: 16 | model_filetype = model.split(".")[-1] 17 | if model_filetype == "safetensors": 18 | from .torch_vl import TorchVL 19 | 20 | return TorchVL(model=model) 21 | elif model_filetype == "mf" or model.endswith(".mf.gz"): # also accept the gzipped .mf archives linked in the README 22 | from .onnx_vl import OnnxVL 23 | 24 | return OnnxVL.from_path(model) 25 | 26 | raise ValueError( 27 |
"Unsupported model filetype. Please use a .safetensors model for GPU use or .mf model for CPU use." 28 | ) 29 | 30 | if api_key: 31 | if not api_url: 32 | api_url = DEFAULT_API_URL 33 | 34 | return CloudVL(api_key=api_key, api_url=api_url) 35 | 36 | if api_url: # keyless custom endpoint, e.g. a local `moondream serve` server 37 | if api_url == DEFAULT_API_URL and not api_key: 38 | raise ValueError("An api_key is required for cloud inference.") 39 | 40 | return CloudVL(api_url=api_url) 41 | 42 | raise ValueError("At least one of `model`, `api_key`, or `api_url` is required.") 43 | -------------------------------------------------------------------------------- /clients/python/moondream/cli.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | from http import server 4 | 5 | from .onnx_vl import OnnxVL 6 | from .server import MoondreamHandler 7 | 8 | 9 | def main(): 10 | parser = argparse.ArgumentParser(description="Moondream CLI") 11 | subparsers = parser.add_subparsers(dest="command", help="Command to run") 12 | 13 | # Server command 14 | server_parser = subparsers.add_parser("serve", help="Start the Moondream server") 15 | server_parser.add_argument("--model", type=str, help="Path to the model file") 16 | server_parser.add_argument( 17 | "--host", type=str, default="localhost", help="Host to bind to" 18 | ) 19 | server_parser.add_argument( 20 | "--port", type=int, default=3475, help="Port to listen on" 21 | ) 22 | 23 | args = parser.parse_args() 24 | 25 | if args.command == "serve": 26 | if args.model: 27 | model = OnnxVL.from_path(args.model) 28 | else: 29 | parser.error("Model path is required") 30 | 31 | MoondreamHandler.model = model 32 | server_address = (args.host, args.port) 33 | try: 34 | httpd = server.HTTPServer(server_address, MoondreamHandler) 35 | print(f"Starting Moondream server on http://{args.host}:{args.port}") 36 | httpd.serve_forever() 37 | except KeyboardInterrupt: 38 | print("\nShutting down server...") 39 | httpd.server_close() 40 |
except Exception as e: 41 | print(f"Error: {e}", file=sys.stderr) 42 | sys.exit(1) 43 | else: 44 | parser.print_help() 45 | 46 | 47 | if __name__ == "__main__": 48 | main() 49 | -------------------------------------------------------------------------------- /clients/python/moondream/cloud_vl.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import json 3 | import urllib.request 4 | from io import BytesIO 5 | from typing import Literal, Optional, Union 6 | 7 | from PIL import Image 8 | 9 | from .types import ( 10 | VLM, 11 | Base64EncodedImage, 12 | CaptionOutput, 13 | DetectOutput, 14 | EncodedImage, 15 | PointOutput, 16 | QueryOutput, 17 | SamplingSettings, 18 | ) 19 | from .version import __version__ 20 | 21 | 22 | class CloudVL(VLM): 23 | def __init__( 24 | self, 25 | *, 26 | api_url: str = "https://api.moondream.ai/v1", 27 | api_key: Optional[str] = None, 28 | ): 29 | self.api_key = api_key 30 | self.api_url = api_url 31 | 32 | def encode_image( 33 | self, image: Union[Image.Image, EncodedImage] 34 | ) -> Base64EncodedImage: 35 | if isinstance(image, EncodedImage): 36 | assert type(image) == Base64EncodedImage 37 | return image 38 | try: 39 | width, height = image.size 40 | max_size = 768 41 | scale = max_size / max(width, height) 42 | if scale < 1: 43 | new_size = (int(width * scale), int(height * scale)) 44 | image = image.resize(new_size, Image.Resampling.LANCZOS) 45 | 46 | if image.mode != "RGB": 47 | image = image.convert("RGB") 48 | buffered = BytesIO() 49 | image.save(buffered, format="JPEG", quality=95) 50 | img_str = base64.b64encode(buffered.getvalue()).decode() 51 | return Base64EncodedImage(image_url=f"data:image/jpeg;base64,{img_str}") 52 | except Exception as e: 53 | raise ValueError("Failed to convert image to JPEG.") from e 54 | 55 | def _stream_response(self, req): 56 | """Helper function to stream response chunks from the API.""" 57 | with urllib.request.urlopen(req) as response: 58 | 
for line in response: 59 | if not line: 60 | continue 61 | line = line.decode("utf-8") 62 | if line.startswith("data: "): 63 | try: 64 | data = json.loads(line[6:]) 65 | if "chunk" in data: 66 | yield data["chunk"] 67 | if data.get("completed"): 68 | break 69 | except json.JSONDecodeError as e: 70 | raise ValueError( 71 | "Failed to parse JSON response from server." 72 | ) from e 73 | 74 | def caption( 75 | self, 76 | image: Union[Image.Image, EncodedImage], 77 | length: Literal["normal", "short"] = "normal", 78 | stream: bool = False, 79 | settings: Optional[SamplingSettings] = None, 80 | ) -> CaptionOutput: 81 | encoded_image = self.encode_image(image) 82 | payload = { 83 | "image_url": encoded_image.image_url, 84 | "length": length, 85 | "stream": stream, 86 | } 87 | 88 | data = json.dumps(payload).encode("utf-8") 89 | headers = { 90 | "Content-Type": "application/json", 91 | "User-Agent": f"moondream-python/{__version__}", 92 | } 93 | if self.api_key: 94 | headers["X-Moondream-Auth"] = self.api_key 95 | req = urllib.request.Request( 96 | f"{self.api_url}/caption", 97 | data=data, 98 | headers=headers, 99 | ) 100 | 101 | def generator(): 102 | for chunk in self._stream_response(req): 103 | yield chunk 104 | 105 | if stream: 106 | return {"caption": generator()} 107 | 108 | with urllib.request.urlopen(req) as response: 109 | result = json.loads(response.read().decode("utf-8")) 110 | return {"caption": result["caption"]} 111 | 112 | def query( 113 | self, 114 | image: Union[Image.Image, EncodedImage], 115 | question: str, 116 | stream: bool = False, 117 | settings: Optional[SamplingSettings] = None, 118 | ) -> QueryOutput: 119 | encoded_image = self.encode_image(image) 120 | payload = { 121 | "image_url": encoded_image.image_url, 122 | "question": question, 123 | "stream": stream, 124 | # TODO: Pass sampling settings like max_tokens to the API. 
125 | } 126 | 127 | data = json.dumps(payload).encode("utf-8") 128 | headers = { 129 | "Content-Type": "application/json", 130 | "User-Agent": f"moondream-python/{__version__}", 131 | } 132 | if self.api_key: 133 | headers["X-Moondream-Auth"] = self.api_key 134 | req = urllib.request.Request( 135 | f"{self.api_url}/query", 136 | data=data, 137 | headers=headers, 138 | ) 139 | 140 | if stream: 141 | return {"answer": self._stream_response(req)} 142 | 143 | with urllib.request.urlopen(req) as response: 144 | result = json.loads(response.read().decode("utf-8")) 145 | return {"answer": result["answer"]} 146 | 147 | def detect( 148 | self, 149 | image: Union[Image.Image, EncodedImage], 150 | object: str, 151 | ) -> DetectOutput: 152 | encoded_image = self.encode_image(image) 153 | payload = {"image_url": encoded_image.image_url, "object": object} 154 | 155 | data = json.dumps(payload).encode("utf-8") 156 | headers = { 157 | "Content-Type": "application/json", 158 | "User-Agent": f"moondream-python/{__version__}", 159 | } 160 | if self.api_key: 161 | headers["X-Moondream-Auth"] = self.api_key 162 | req = urllib.request.Request( 163 | f"{self.api_url}/detect", 164 | data=data, 165 | headers=headers, 166 | ) 167 | 168 | with urllib.request.urlopen(req) as response: 169 | result = json.loads(response.read().decode("utf-8")) 170 | return {"objects": result["objects"]} 171 | 172 | def point( 173 | self, 174 | image: Union[Image.Image, EncodedImage], 175 | object: str, 176 | ) -> PointOutput: 177 | encoded_image = self.encode_image(image) 178 | payload = {"image_url": encoded_image.image_url, "object": object} 179 | 180 | data = json.dumps(payload).encode("utf-8") 181 | headers = { 182 | "Content-Type": "application/json", 183 | "User-Agent": f"moondream-python/{__version__}", 184 | } 185 | if self.api_key: 186 | headers["X-Moondream-Auth"] = self.api_key 187 | req = urllib.request.Request( 188 | f"{self.api_url}/point", 189 | data=data, 190 | headers=headers, 191 | ) 192 | 
193 | with urllib.request.urlopen(req) as response: 194 | result = json.loads(response.read().decode("utf-8")) 195 | return {"points": result["points"]} 196 | -------------------------------------------------------------------------------- /clients/python/moondream/moonfile.py: -------------------------------------------------------------------------------- 1 | import struct 2 | import gzip 3 | from typing import BinaryIO, Tuple, Iterator, Union 4 | 5 | MOON_MAGIC = b"MOON" 6 | MOON_VERSION = 1 7 | 8 | 9 | class MoonReader: 10 | def __init__(self, input_path: str): 11 | self.input_path = input_path 12 | 13 | def _get_file_handle(self) -> Union[BinaryIO, gzip.GzipFile]: 14 | """Returns appropriate file handle based on extension""" 15 | if self.input_path.endswith(".gz"): 16 | return gzip.open(self.input_path, "rb") 17 | return open(self.input_path, "rb") 18 | 19 | def _validate_header(self, f: Union[BinaryIO, gzip.GzipFile]) -> None: 20 | """Validate magic bytes and version""" 21 | magic = f.read(4) 22 | if magic != MOON_MAGIC: 23 | raise ValueError(f"Invalid magic bytes: {magic}") 24 | 25 | version = struct.unpack("!B", f.read(1))[0] 26 | if version != MOON_VERSION: 27 | raise ValueError(f"Unsupported version: {version}") 28 | 29 | def read_files(self) -> Iterator[Tuple[str, bytes]]: 30 | """Read and yield (filename, content) pairs from the archive""" 31 | with self._get_file_handle() as f: 32 | self._validate_header(f) 33 | 34 | while True: 35 | # Try to read filename length 36 | filename_len_bytes = f.read(4) 37 | if not filename_len_bytes: 38 | break # End of file 39 | 40 | filename_len = struct.unpack("!I", filename_len_bytes)[0] 41 | 42 | # Read filename 43 | filename = f.read(filename_len).decode("utf-8") 44 | 45 | # Read content length and content 46 | content_len = struct.unpack("!Q", f.read(8))[0] 47 | content = f.read(content_len) 48 | 49 | yield filename, content 50 | 51 | 52 | def unpack(input_path: str) -> Iterator[Tuple[str, bytes]]: 53 | """Unpack a 
.mf file""" 54 | reader = MoonReader(input_path) 55 | for filename, content in reader.read_files(): 56 | yield filename, content 57 | -------------------------------------------------------------------------------- /clients/python/moondream/preprocess.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import numpy as np 4 | from PIL import Image 5 | 6 | 7 | def im_resize( 8 | image: Image.Image, 9 | size: Tuple[int, int], 10 | resample: int = Image.Resampling.BICUBIC, 11 | ) -> Image.Image: 12 | return image.resize(size, resample=resample) 13 | 14 | 15 | def adaptive_avg_pool2d(x, output_size): 16 | """Applies 2D adaptive average pooling over an input signal. 17 | 18 | Resizes input to a target size by averaging values in local neighborhoods. 19 | The neighborhoods are computed to evenly cover the input image while 20 | maintaining approximately equal size. Similar to PyTorch's 21 | adaptive_avg_pool2d but expects input shape (H,W,C) rather than (N,C,H,W). 22 | 23 | Args: 24 | x: Input tensor of shape (height, width, channels) 25 | output_size: Target output size. 
Can be: 26 | - Single integer for square output (size, size) 27 | - Tuple of two ints (out_height, out_width) 28 | 29 | Returns: 30 | Tensor of shape (out_height, out_width, channels) 31 | 32 | Example: 33 | >>> img = np.random.randn(32, 32, 3) # 32x32 RGB image 34 | >>> pooled = adaptive_avg_pool2d(img, (7, 7)) # Resize to 7x7 35 | >>> pooled.shape 36 | (7, 7, 3) 37 | """ 38 | height, width, channels = x.shape 39 | 40 | if isinstance(output_size, int): 41 | output_size = (output_size, output_size) 42 | out_height, out_width = output_size 43 | 44 | stride_h = height // out_height 45 | stride_w = width // out_width 46 | kernel_h = height - (out_height - 1) * stride_h 47 | kernel_w = width - (out_width - 1) * stride_w 48 | 49 | output = np.zeros((out_height, out_width, channels), dtype=x.dtype) 50 | 51 | for i in range(out_height): 52 | for j in range(out_width): 53 | h_start = i * stride_h 54 | h_end = h_start + kernel_h 55 | w_start = j * stride_w 56 | w_end = w_start + kernel_w 57 | output[i, j, :] = x[h_start:h_end, w_start:w_end, :].mean(axis=(0, 1)) 58 | 59 | return output 60 | 61 | 62 | def normalize( 63 | image: np.ndarray, 64 | mean: List[float] = [0.5, 0.5, 0.5], 65 | std: List[float] = [0.5, 0.5, 0.5], 66 | ) -> np.ndarray: 67 | """ 68 | Normalize an image array. 69 | """ 70 | return (image - np.array(mean)) / np.array(std) 71 | 72 | 73 | def create_patches( 74 | image: Image.Image, image_patch_size=378 75 | ) -> Tuple[np.ndarray, Tuple[int, int]]: 76 | """ 77 | Split the given image into a variable number of patches depending upon its 78 | resolution. Returns the patches as a numpy array, and the selected patching 79 | template as a tuple of (rows, cols). 80 | """ 81 | image = image.convert("RGB") 82 | 83 | # Start off with the global patch. 84 | patches = [im_resize(image, (image_patch_size, image_patch_size))] 85 | 86 | # Find the closest resolution template. 
87 | # 88 | # (1, 2) (2, 1) (2, 2) 89 | # +-------+-------+ +-----------+ +-------+-------+ 90 | # | 1 | 2 | | 1 | | 1 | 2 | 91 | # +-------+-------+ +-----------+ +-------+-------+ 92 | # | 2 | | 3 | 4 | 93 | # +-----------+ +-------+-------+ 94 | res_templates = [(1, 2), (2, 1), (2, 2)] 95 | 96 | im_width, im_height = image.size 97 | max_dim = max(im_width, im_height) 98 | if max_dim < image_patch_size * 1.4: 99 | # If the image is already small, we avoid adding an extra patch 100 | # here to avoid redundant computation in the vision encoder, and 101 | # instead copy the global patch features after running the vision 102 | # encoder, before passing it through the vision projection. 103 | selected_template = (1, 1) 104 | else: 105 | aspect_ratio = im_width / im_height 106 | selected_template = min( 107 | res_templates, key=lambda size: abs((size[1] / size[0]) - aspect_ratio) 108 | ) 109 | 110 | patch_width = im_width // selected_template[1] 111 | patch_height = im_height // selected_template[0] 112 | 113 | for row in range(selected_template[0]): 114 | for col in range(selected_template[1]): 115 | x_min = col * patch_width 116 | y_min = row * patch_height 117 | x_max = x_min + patch_width 118 | y_max = y_min + patch_height 119 | patches.append( 120 | im_resize( 121 | image.crop((x_min, y_min, x_max, y_max)), 122 | (image_patch_size, image_patch_size), 123 | ) 124 | ) 125 | 126 | return ( 127 | np.stack( 128 | [ 129 | normalize( 130 | (np.array(patch_img) / 255.0), 131 | mean=[0.5, 0.5, 0.5], 132 | std=[0.5, 0.5, 0.5], 133 | ).transpose(2, 0, 1) 134 | for patch_img in patches 135 | ], 136 | dtype=np.float16, 137 | ), 138 | selected_template, 139 | ) 140 | -------------------------------------------------------------------------------- /clients/python/moondream/torch_vl.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Optional, Union 2 | 3 | import torch 4 | from PIL import Image 5 | 6 | from 
.torch.moondream import MoondreamConfig, MoondreamModel 7 | from .torch.weights import load_weights_into_model 8 | from .types import ( 9 | VLM, 10 | Base64EncodedImage, 11 | CaptionOutput, 12 | DetectOutput, 13 | EncodedImage, 14 | PointOutput, 15 | QueryOutput, 16 | SamplingSettings, 17 | ) 18 | from .version import __version__ 19 | 20 | 21 | class TorchVL(VLM): 22 | def __init__( 23 | self, 24 | *, 25 | model: str, 26 | ): 27 | config = MoondreamConfig() 28 | self.model = MoondreamModel(config) 29 | load_weights_into_model(model, self.model) 30 | self.model.eval() 31 | # Move model to the appropriate device 32 | if torch.cuda.is_available(): 33 | self.device = "cuda" 34 | elif torch.backends.mps.is_available(): 35 | self.device = "mps" 36 | else: 37 | self.device = "cpu" 38 | self.model.to(self.device) 39 | 40 | def encode_image( 41 | self, image: Union[Image.Image, EncodedImage] 42 | ) -> Base64EncodedImage: 43 | if isinstance(image, EncodedImage): 44 | assert type(image) == Base64EncodedImage 45 | return image 46 | 47 | if not self.model: 48 | raise ValueError("No local model loaded") 49 | 50 | return self.model.encode_image(image) 51 | 52 | def caption( 53 | self, 54 | image: Union[Image.Image, EncodedImage], 55 | length: Literal["normal", "short"] = "normal", 56 | stream: bool = False, 57 | settings: Optional[SamplingSettings] = None, 58 | ) -> CaptionOutput: 59 | if not self.model: 60 | raise ValueError("No local model loaded") 61 | 62 | encoded_image = ( 63 | self.model.encode_image(image) if isinstance(image, Image.Image) else image 64 | ) 65 | return self.model.caption( 66 | encoded_image, length=length, stream=stream, settings=settings 67 | ) 68 | 69 | def query( 70 | self, 71 | image: Union[Image.Image, EncodedImage], 72 | question: str, 73 | stream: bool = False, 74 | settings: Optional[SamplingSettings] = None, 75 | ) -> QueryOutput: 76 | if not self.model: 77 | raise ValueError("No local model loaded") 78 | 79 | encoded_image = ( 80 | 
self.model.encode_image(image) if isinstance(image, Image.Image) else image 81 | ) 82 | return self.model.query( 83 | encoded_image, question, stream=stream, settings=settings 84 | ) 85 | 86 | def detect( 87 | self, 88 | image: Union[Image.Image, EncodedImage], 89 | object: str, 90 | ) -> DetectOutput: 91 | if not self.model: 92 | raise ValueError("No local model loaded") 93 | 94 | encoded_image = ( 95 | self.model.encode_image(image) if isinstance(image, Image.Image) else image 96 | ) 97 | return self.model.detect(encoded_image, object) 98 | 99 | def point( 100 | self, 101 | image: Union[Image.Image, EncodedImage], 102 | object: str, 103 | ) -> PointOutput: 104 | if not self.model: 105 | raise ValueError("No local model loaded") 106 | 107 | encoded_image = ( 108 | self.model.encode_image(image) if isinstance(image, Image.Image) else image 109 | ) 110 | return self.model.point(encoded_image, object) 111 | -------------------------------------------------------------------------------- /clients/python/moondream/types.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from abc import ABC, abstractmethod 3 | from PIL import Image 4 | from dataclasses import dataclass 5 | from typing import Generator, List, TypedDict, Union, Optional, Literal 6 | 7 | 8 | @dataclass 9 | class EncodedImage(ABC): 10 | pass 11 | 12 | 13 | @dataclass 14 | class OnnxEncodedImage(EncodedImage): 15 | pos: int 16 | kv_cache: np.ndarray 17 | 18 | 19 | @dataclass 20 | class Base64EncodedImage(EncodedImage): 21 | image_url: str 22 | 23 | 24 | SamplingSettings = TypedDict( 25 | "SamplingSettings", 26 | {"max_tokens": int}, 27 | total=False, 28 | ) 29 | 30 | CaptionOutput = TypedDict( 31 | "CaptionOutput", {"caption": Union[str, Generator[str, None, None]]} 32 | ) 33 | 34 | QueryOutput = TypedDict( 35 | "QueryOutput", {"answer": Union[str, Generator[str, None, None]]} 36 | ) 37 | 38 | Region = TypedDict( 39 | "Region", {"x_min": float, 
"y_min": float, "x_max": float, "y_max": float} 40 | ) 41 | DetectOutput = TypedDict("DetectOutput", {"objects": List[Region]}) 42 | 43 | Point = TypedDict("Point", {"x": float, "y": float}) 44 | PointOutput = TypedDict("PointOutput", {"points": List[Point]}) 45 | 46 | 47 | class VLM(ABC): 48 | @abstractmethod 49 | def encode_image(self, image: Union[Image.Image, EncodedImage]) -> EncodedImage: 50 | """ 51 | Preprocess the image by running it through the model. Only supported for local 52 | inference. 53 | 54 | This method is useful if the user wants to make multiple queries with the same image. 55 | The output is not guaranteed to be backward-compatible across version updates, 56 | and should not be persisted out of band. 57 | 58 | Args: 59 | image (Image.Image): The input image to be encoded. 60 | 61 | Returns: 62 | The encoded representation of the image. 63 | """ 64 | 65 | @abstractmethod 66 | def caption( 67 | self, 68 | image: Union[Image.Image, EncodedImage], 69 | length: Literal["normal", "short"] = "normal", 70 | stream: bool = False, 71 | settings: Optional[SamplingSettings] = None, 72 | ) -> CaptionOutput: 73 | """ 74 | Generate a caption for the input image. 75 | 76 | Args: 77 | image (Union[Image.Image, EncodedImage]): The input image to be captioned. 78 | length (str): Length of caption to generate. Can be "normal" or "short". 79 | Defaults to "normal". 80 | stream (bool): If True, returns a generator that streams the output tokens. 81 | Defaults to False. 82 | settings (Optional[SamplingSettings]): Optional settings for the caption 83 | generation. If not provided, default settings will be used. 84 | 85 | Returns: 86 | CaptionOutput: A dictionary containing the 'caption' field with either a string 87 | or generator that yields strings for the caption.
88 | """ 89 | 90 | @abstractmethod 91 | def query( 92 | self, 93 | image: Union[Image.Image, EncodedImage], 94 | question: str, 95 | stream: bool = False, 96 | settings: Optional[SamplingSettings] = None, 97 | ) -> QueryOutput: 98 | """ 99 | Generate an answer to the input question about the input image. 100 | 101 | Args: 102 | image (Union[Image.Image, EncodedImage]): The input image to be queried. 103 | question (str): The question to be answered. 104 | stream (bool): If True, returns a generator that streams the output tokens. 105 | (default: False) 106 | settings (Optional[SamplingSettings]): Optional settings for the query 107 | generation. 108 | 109 | Returns: 110 | QueryOutput: A dictionary containing the 'answer' field with either a string 111 | or generator that yields strings for the response. 112 | """ 113 | 114 | @abstractmethod 115 | def detect( 116 | self, 117 | image: Union[Image.Image, EncodedImage], 118 | object: str, 119 | ) -> DetectOutput: 120 | """ 121 | Detect and localize the specified object in the input image. 122 | 123 | Args: 124 | image (Union[Image.Image, EncodedImage]): The input image to be analyzed. 125 | object (str): The object to be detected in the image. 126 | 127 | Returns: 128 | DetectOutput: A dictionary containing: 129 | 'objects' (List[Region]): List of detected object regions, where each 130 | Region has: 131 | - x_min (float): Left boundary of detection box 132 | - y_min (float): Top boundary of detection box 133 | - x_max (float): Right boundary of detection box 134 | - y_max (float): Bottom boundary of detection box 135 | """ 136 | 137 | @abstractmethod 138 | def point( 139 | self, 140 | image: Union[Image.Image, EncodedImage], 141 | object: str, 142 | ) -> PointOutput: 143 | """ 144 | Points out all instances of the given object in the input image. 145 | 146 | Args: 147 | image (Union[Image.Image, EncodedImage]): The input image to be analyzed for 148 | pointing out objects. 
149 | object (str): The object type to be pointed out in the image. 150 | 151 | Returns: 152 | PointOutput: A dictionary containing: 153 | 'points' (List[Point]): List of detected points, where each Point has: 154 | - x (float): X coordinate of the point marking the object 155 | - y (float): Y coordinate of the point marking the object 156 | 157 | This method identifies instances of the specified object in the image and returns 158 | a list of coordinates marking the location of each instance found. Each point 159 | indicates the approximate center or most relevant position for that object 160 | instance. 161 | """ 162 | -------------------------------------------------------------------------------- /clients/python/moondream/version.py: -------------------------------------------------------------------------------- 1 | from importlib import metadata 2 | 3 | try: 4 | __version__ = metadata.version("moondream") 5 | except Exception: 6 | __version__ = "unknown" 7 | -------------------------------------------------------------------------------- /clients/python/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ "poetry-core",] 3 | build-backend = "poetry.core.masonry.api" 4 | 5 | [tool.poetry] 6 | name = "moondream" 7 | version = "0.0.2" 8 | description = "Python client library for moondream" 9 | authors = [ "M87 Labs ",] 10 | readme = "README.md" 11 | [[tool.poetry.packages]] 12 | include = "moondream" 13 | from = "." 14 | 15 | [tool.pyright] 16 | venvPath = "." 
17 | venv = ".venv" 18 | reportMissingParameterType = false 19 | 20 | [tool.poetry.dependencies] 21 | python = "^3.10" 22 | pillow = "^10.4.0" 23 | numpy = "^2.1.2" 24 | onnxruntime = { version = ">=1.19.2", optional = true } 25 | tokenizers = { version = ">=0.20.1", optional = true } 26 | torch = { version = ">=2.5.0", optional = true } 27 | safetensors = { version = ">=0.4.2", optional = true } 28 | einops = { version = ">=0.7.0", optional = true } 29 | pyvips-binary = { version = ">=8.16.0", optional = true } 30 | pyvips = { version = ">=2.2.1", optional = true } 31 | 32 | [tool.poetry.extras] 33 | cpu = [ 34 | "onnxruntime", 35 | "tokenizers" 36 | ] 37 | gpu = [ 38 | "torch", 39 | "safetensors", 40 | "einops", 41 | "pyvips-binary", 42 | "pyvips", 43 | "tokenizers" 44 | ] 45 | 46 | [tool.poetry.scripts] 47 | moondream = "moondream.cli:main" 48 | -------------------------------------------------------------------------------- /clients/python/scripts/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import tracemalloc 4 | 5 | from PIL import Image, ImageDraw 6 | 7 | import moondream as md 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--model-path", type=str, required=True) 11 | args = parser.parse_args() 12 | 13 | 14 | class Colors: 15 | HEADER = "\033[95m" # Purple 16 | BLUE = "\033[94m" 17 | GREEN = "\033[92m" 18 | YELLOW = "\033[93m" 19 | RED = "\033[91m" 20 | ENDC = "\033[0m" 21 | BOLD = "\033[1m" 22 | 23 | 24 | def format_memory(memory_mb): 25 | """Format memory size with appropriate unit""" 26 | return f"{memory_mb:.2f} MiB" 27 | 28 | 29 | def print_section(title): 30 | """Print a section header with dynamic padding to center the text""" 31 | total_width = 65 32 | text_length = len(title) + 2 # Add 2 for spaces around title 33 | total_padding = total_width - text_length 34 | left_padding = total_padding // 2 35 | right_padding = total_padding - left_padding 
36 | print( 37 | f"\n{Colors.HEADER}{Colors.BOLD}{'-'*left_padding} {title} {'-'*right_padding}{Colors.ENDC}" 38 | ) 39 | 40 | 41 | def print_metric(label, value, color=Colors.BLUE): 42 | """Print a metric with consistent formatting""" 43 | print(f"| {color}{label}{Colors.ENDC}: {value}") 44 | 45 | 46 | def log_memory_and_time(operation_name, start_time, start_memory): 47 | """Log memory and time differences for an operation""" 48 | end_time = time.time() 49 | current_memory = get_memory_usage() 50 | time_diff = end_time - start_time 51 | memory_diff = current_memory - start_memory 52 | 53 | print("\nStats") 54 | print_metric("Time", f"{time_diff:.2f} seconds") 55 | print_metric("Memory usage", format_memory(current_memory)) 56 | 57 | # Color-code memory increase based on significance 58 | color = ( 59 | Colors.GREEN 60 | if memory_diff < 10 61 | else Colors.YELLOW if memory_diff < 100 else Colors.RED 62 | ) 63 | print_metric("Memory increase", format_memory(memory_diff), color) 64 | 65 | return end_time, current_memory 66 | 67 | 68 | def get_memory_usage(): 69 | """Get current memory usage in MiB""" 70 | current, peak = tracemalloc.get_traced_memory() 71 | return current / 1024 / 1024 72 | 73 | 74 | # Start tracking memory 75 | tracemalloc.start() 76 | 77 | # Initial memory measurement 78 | initial_memory = get_memory_usage() 79 | print_section("Initial State") 80 | print_metric("Initial memory usage", format_memory(initial_memory)) 81 | 82 | # Load image 83 | print_section("Image Loading") 84 | start_time = time.time() 85 | start_memory = get_memory_usage() 86 | image = Image.open("../../assets/demo-1.jpg") 87 | log_memory_and_time("Image Loading", start_time, start_memory) 88 | 89 | # Initialize model 90 | print_section("Model Initialization") 91 | start_time = time.time() 92 | start_memory = get_memory_usage() 93 | model = md.vl(model=args.model_path) 94 | log_memory_and_time("Model Initialization", start_time, start_memory) 95 | 96 | # Encode image 97 | 
print_section("Image Encoding") 98 | start_time = time.time() 99 | start_memory = get_memory_usage() 100 | encoded_image = model.encode_image(image) 101 | log_memory_and_time("Image Encoding", start_time, start_memory) 102 | 103 | # Generate caption 104 | print_section("Caption Generation") 105 | print(f"{Colors.BOLD}Caption:{Colors.ENDC}", end="", flush=True) 106 | start_time = time.time() 107 | start_memory = get_memory_usage() 108 | tokens = 0 109 | for tok in model.caption(encoded_image, stream=True)["caption"]: 110 | print(tok, end="", flush=True) 111 | tokens += 1 112 | print() 113 | end_time, end_memory = log_memory_and_time("Caption Stats", start_time, start_memory) 114 | print_metric("Token generation speed", f"{tokens / (end_time - start_time):.2f} tok/s") 115 | 116 | # Generate answer to question 117 | question = "How many people are in this image? Answer briefly." 118 | print_section("Question Answering") 119 | print(f"{Colors.BOLD}Question:{Colors.ENDC} {question}") 120 | print(f"{Colors.BOLD}Answer:{Colors.ENDC}", end="", flush=True) 121 | start_time = time.time() 122 | start_memory = get_memory_usage() 123 | tokens = 0 124 | for tok in model.query(encoded_image, question, stream=True)["answer"]: 125 | print(tok, end="", flush=True) 126 | tokens += 1 127 | print() 128 | end_time, end_memory = log_memory_and_time( 129 | "Question Answering Stats", start_time, start_memory 130 | ) 131 | print_metric("Token generation speed", f"{tokens / (end_time - start_time):.2f} tok/s") 132 | 133 | # Object detection 134 | object = "burger" 135 | print_section("Object Detection") 136 | print(f"{Colors.BOLD}Detect:{Colors.ENDC} {object}") 137 | start_time = time.time() 138 | start_memory = get_memory_usage() 139 | objects = model.detect(encoded_image, object)["objects"] 140 | print(len(objects), "detected") 141 | 142 | # Draw rectangles for each detected object 143 | width, height = image.size 144 | draw = ImageDraw.Draw(image) 145 | for obj in objects: 146 | x_min = 
int(obj["x_min"] * width) 147 | x_max = int(obj["x_max"] * width) 148 | y_min = int(obj["y_min"] * height) 149 | y_max = int(obj["y_max"] * height) 150 | draw.rectangle([(x_min, y_min), (x_max, y_max)], outline="green", width=2) 151 | image.save("detection_output.jpg") 152 | 153 | end_time, end_memory = log_memory_and_time( 154 | "Object Detection Stats", start_time, start_memory 155 | ) 156 | 157 | # Final summary 158 | print_section("Final Summary") 159 | final_memory = get_memory_usage() 160 | current, peak = tracemalloc.get_traced_memory() 161 | 162 | print_metric("Final memory usage", format_memory(final_memory)) 163 | print_metric("Total memory increase", format_memory(final_memory - initial_memory)) 164 | print_metric("Peak memory usage", format_memory(peak / 1024 / 1024)) 165 | 166 | # Stop tracking memory 167 | tracemalloc.stop() 168 | -------------------------------------------------------------------------------- /clients/python/scripts/test_cloud_parity.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import moondream as md 4 | 5 | from PIL import Image 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--model-path", type=str, required=True) 9 | args = parser.parse_args() 10 | 11 | local = md.vl(model_path=args.model_path) 12 | cloud = md.vl(api_key=os.environ["MOONDREAM_API_KEY"]) 13 | 14 | image_path = "../../assets/demo-1.jpg" 15 | image = Image.open(image_path) 16 | 17 | print("# Pointing") 18 | object = "person" 19 | print("Local:", local.point(image, object)) 20 | print("Cloud:", cloud.point(image, object)) 21 | 22 | print("# Captioning") 23 | print("Local:", local.caption(image)) 24 | print("Cloud:", cloud.caption(image)) 25 | 26 | print("# Querying") 27 | question = "What is the character eating?" 
28 | print("Local:", local.query(image, question)) 29 | print("Cloud:", cloud.query(image, question)) 30 | 31 | print("# Detecting") 32 | object_to_detect = "burger" 33 | print("Local:", local.detect(image, object_to_detect)) 34 | print("Cloud:", cloud.detect(image, object_to_detect)) 35 | 36 | print("# Streaming Caption") 37 | print("Local:") 38 | for tok in local.caption(image, stream=True)["caption"]: 39 | print(tok, end="", flush=True) 40 | print() 41 | print("Cloud:") 42 | for tok in cloud.caption(image, stream=True)["caption"]: 43 | print(tok, end="", flush=True) 44 | print() 45 | 46 | print("# Streaming Query") 47 | print("Local:") 48 | for tok in local.query(image, question, stream=True)["answer"]: 49 | print(tok, end="", flush=True) 50 | print() 51 | print("Cloud:") 52 | for tok in cloud.query(image, question, stream=True)["answer"]: 53 | print(tok, end="", flush=True) 54 | print() 55 | -------------------------------------------------------------------------------- /clients/python/scripts/test_local_server.py: -------------------------------------------------------------------------------- 1 | import moondream as md 2 | from PIL import Image 3 | 4 | base_url = "http://localhost:3475" 5 | local_server_client = md.vl(api_url=base_url) 6 | 7 | image_path = "../../assets/demo-1.jpg" 8 | image = Image.open(image_path) 9 | 10 | print("# Pointing") 11 | object = "person" 12 | print("Local Server:", local_server_client.point(image, object)) 13 | 14 | print("# Captioning") 15 | print("Local Server:", local_server_client.caption(image)) 16 | 17 | print("# Querying") 18 | question = "What is the character eating?" 
19 | print("Local Server:", local_server_client.query(image, question)) 20 | 21 | print("# Detecting") 22 | object_to_detect = "burger" 23 | print("Local Server:", local_server_client.detect(image, object_to_detect)) 24 | 25 | print("# Captioning Stream") 26 | print("Local Server:") 27 | for tok in local_server_client.caption(image, stream=True)["caption"]: 28 | print(tok, end="", flush=True) 29 | print() 30 | 31 | print("# Querying Stream") 32 | print("Local Server:") 33 | for tok in local_server_client.query(image, question, stream=True)["answer"]: 34 | print(tok, end="", flush=True) 35 | print() 36 | -------------------------------------------------------------------------------- /clients/python/tests/test_api_inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | from PIL import Image 4 | import moondream as md 5 | 6 | TEST_IMAGE_PATH = os.path.join( 7 | os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), 8 | "assets", 9 | "demo-1.jpg", 10 | ) 11 | 12 | 13 | @pytest.fixture 14 | def model(): 15 | api_key = os.getenv("MOONDREAM_API_KEY") 16 | if not api_key: 17 | pytest.skip("MOONDREAM_API_KEY environment variable not set") 18 | return md.vl(api_key=api_key) 19 | 20 | 21 | @pytest.fixture 22 | def test_image(): 23 | return Image.open(TEST_IMAGE_PATH) 24 | 25 | 26 | def test_api_initialization(model): 27 | assert model is not None 28 | assert isinstance(model, md.cloud_vl.CloudVL) 29 | 30 | 31 | def test_image_captioning(model, test_image): 32 | # Test normal length caption 33 | result = model.caption(test_image, length="normal") 34 | assert "caption" in result 35 | assert isinstance(result["caption"], str) 36 | assert len(result["caption"]) > 0 37 | 38 | # Test short length caption 39 | result = model.caption(test_image, length="short") 40 | assert "caption" in result 41 | assert isinstance(result["caption"], str) 42 | assert len(result["caption"]) > 0 43 | 44 | 
45 | def test_streaming_caption(model, test_image): 46 | result = model.caption(test_image, stream=True) 47 | assert "caption" in result 48 | 49 | # Test that we can iterate over the stream 50 | caption = "" 51 | for chunk in result["caption"]: 52 | assert isinstance(chunk, str) 53 | caption += chunk 54 | 55 | assert len(caption) > 0 56 | 57 | 58 | def test_query_answering(model, test_image): 59 | # Test basic question answering 60 | result = model.query(test_image, "What is in this image?") 61 | assert "answer" in result 62 | assert isinstance(result["answer"], str) 63 | assert len(result["answer"]) > 0 64 | 65 | 66 | def test_streaming_query(model, test_image): 67 | result = model.query(test_image, "What is in this image?", stream=True) 68 | assert "answer" in result 69 | 70 | # Test that we can iterate over the stream 71 | answer = "" 72 | for chunk in result["answer"]: 73 | assert isinstance(chunk, str) 74 | answer += chunk 75 | 76 | assert len(answer) > 0 77 | 78 | 79 | @pytest.mark.skip( 80 | reason="API handles invalid caption lengths differently than local model" 81 | ) 82 | def test_invalid_caption_length(model, test_image): 83 | with pytest.raises(ValueError, match="Model does not support caption length"): 84 | model.caption(test_image, length="invalid") 85 | 86 | 87 | def test_missing_api_key(): 88 | with pytest.raises(ValueError, match="An api_key is required for cloud inference"): 89 | md.vl(api_url="https://api.moondream.ai/v1") 90 | -------------------------------------------------------------------------------- /clients/python/tests/test_local_inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | from PIL import Image 4 | import moondream as md 5 | 6 | MODEL_URL = "https://huggingface.co/vikhyatk/moondream2/resolve/9dddae84d54db4ac56fe37817aeaeb502ed083e2/moondream-0_5b-int8.mf.gz" 7 | MODEL_PATH = os.path.join( 8 | os.path.dirname(__file__), "test_data", 
"moondream-0_5b-int8.mf" 9 | ) 10 | TEST_IMAGE_PATH = os.path.join( 11 | os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), 12 | "assets", 13 | "demo-1.jpg", 14 | ) 15 | 16 | 17 | @pytest.fixture(scope="session", autouse=True) 18 | def download_model(): 19 | if not os.path.exists(os.path.dirname(MODEL_PATH)): 20 | os.makedirs(os.path.dirname(MODEL_PATH)) 21 | 22 | if not os.path.exists(MODEL_PATH): 23 | import requests 24 | import gzip 25 | import io 26 | 27 | # Download the model file 28 | print("Downloading model file...") 29 | response = requests.get(MODEL_URL, stream=True) 30 | response.raise_for_status() 31 | 32 | # Read the gzipped content into memory 33 | content = response.content 34 | 35 | # Decompress and save 36 | print("Decompressing model file...") 37 | with gzip.open(io.BytesIO(content), "rb") as f_in: 38 | with open(MODEL_PATH, "wb") as f_out: 39 | f_out.write(f_in.read()) 40 | print("Model file ready.") 41 | 42 | 43 | @pytest.fixture 44 | def model(): 45 | return md.vl(model=MODEL_PATH) 46 | 47 | 48 | @pytest.fixture 49 | def test_image(): 50 | return Image.open(TEST_IMAGE_PATH) 51 | 52 | 53 | def test_model_initialization(model): 54 | assert model is not None 55 | assert hasattr(model, "vision_encoder") 56 | assert hasattr(model, "text_decoder") 57 | assert hasattr(model, "tokenizer") 58 | 59 | 60 | def test_image_encoding(model, test_image): 61 | encoded_image = model.encode_image(test_image) 62 | assert encoded_image is not None 63 | assert hasattr(encoded_image, "pos") 64 | assert hasattr(encoded_image, "kv_cache") 65 | 66 | 67 | def test_image_captioning(model, test_image): 68 | # Test normal length caption 69 | result = model.caption(test_image, length="normal") 70 | assert "caption" in result 71 | assert isinstance(result["caption"], str) 72 | assert len(result["caption"]) > 0 73 | 74 | # Test short length caption 75 | result = model.caption(test_image, length="short") 76 | assert "caption" in result 77 | assert 
isinstance(result["caption"], str) 78 | assert len(result["caption"]) > 0 79 | 80 | 81 | def test_streaming_caption(model, test_image): 82 | result = model.caption(test_image, stream=True) 83 | assert "caption" in result 84 | 85 | # Test that we can iterate over the stream 86 | caption = "" 87 | for chunk in result["caption"]: 88 | assert isinstance(chunk, str) 89 | caption += chunk 90 | 91 | assert len(caption) > 0 92 | 93 | 94 | def test_reuse_encoded_image(model, test_image): 95 | # Test that we can reuse an encoded image for multiple operations 96 | encoded_image = model.encode_image(test_image) 97 | 98 | # Generate two captions using the same encoded image 99 | result1 = model.caption(encoded_image) 100 | result2 = model.caption(encoded_image) 101 | 102 | assert result1["caption"] == result2["caption"] 103 | 104 | 105 | def test_invalid_caption_length(model, test_image): 106 | with pytest.raises(ValueError, match="Model does not support caption length"): 107 | model.caption(test_image, length="invalid") 108 | 109 | 110 | def test_invalid_model_path(): 111 | with pytest.raises( 112 | ValueError, 113 | match="Unsupported model filetype. 
Please use a .safetensors for GPU use or .mf for CPU use.", 114 | ): 115 | md.vl(model="invalid/path/to/model.bin") 116 | -------------------------------------------------------------------------------- /clients/python/tests/test_local_torch_inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | from PIL import Image 4 | import moondream as md 5 | 6 | MODEL_PATH = "/Users/caleb/Projects/moondream/moondream-playground-inf/src/ai-models/05/moondream-01-08-2025.safetensors" 7 | TEST_IMAGE_PATH = os.path.join( 8 | os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), 9 | "assets", 10 | "demo-1.jpg", 11 | ) 12 | 13 | 14 | @pytest.fixture 15 | def model(): 16 | return md.vl(model=MODEL_PATH) 17 | 18 | 19 | @pytest.fixture 20 | def test_image(): 21 | return Image.open(TEST_IMAGE_PATH) 22 | 23 | 24 | def test_image_captioning(model, test_image): 25 | # Test normal length caption 26 | result = model.caption(test_image, length="normal") 27 | assert "caption" in result 28 | assert isinstance(result["caption"], str) 29 | assert len(result["caption"]) > 0 30 | 31 | # Test short length caption 32 | result = model.caption(test_image, length="short") 33 | assert "caption" in result 34 | assert isinstance(result["caption"], str) 35 | assert len(result["caption"]) > 0 36 | 37 | # Test streaming caption 38 | result = model.caption(test_image, stream=True) 39 | assert "caption" in result 40 | 41 | # Test that we can iterate over the stream 42 | num_chunks = 0 43 | caption = "" 44 | for chunk in result["caption"]: 45 | assert isinstance(chunk, str) 46 | caption += chunk 47 | num_chunks += 1 48 | 49 | assert len(caption) > 0 50 | assert num_chunks > 1 51 | 52 | 53 | def test_query(model, test_image): 54 | result = model.query(test_image, "What is in this image?") 55 | assert "answer" in result 56 | assert isinstance(result["answer"], str) 57 | assert len(result["answer"]) > 0 58 | 59 | # 
Test streaming query 60 | result = model.query(test_image, "What is in this image?", stream=True) 61 | assert "answer" in result 62 | 63 | # Test that we can iterate over the stream 64 | num_chunks = 0 65 | answer = "" 66 | for chunk in result["answer"]: 67 | assert isinstance(chunk, str) 68 | answer += chunk 69 | num_chunks += 1 70 | 71 | assert num_chunks > 1 72 | assert len(answer) > 0 73 | 74 | 75 | def test_detect(model, test_image): 76 | result = model.detect(test_image, "person") 77 | assert "objects" in result 78 | assert isinstance(result["objects"], list) 79 | 80 | 81 | def test_point(model, test_image): 82 | result = model.point(test_image, "face") 83 | assert "points" in result 84 | assert isinstance(result["points"], list) 85 | -------------------------------------------------------------------------------- /gradio_demo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import re 3 | from threading import Thread 4 | 5 | import gradio as gr 6 | import torch 7 | from PIL import ImageDraw 8 | from torchvision.transforms.v2 import Resize 9 | from transformers import AutoTokenizer, TextIteratorStreamer 10 | 11 | from moondream.hf import LATEST_REVISION, Moondream, detect_device 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--cpu", action="store_true") 15 | args = parser.parse_args() 16 | 17 | if args.cpu: 18 | device = torch.device("cpu") 19 | dtype = torch.float32 20 | else: 21 | device, dtype = detect_device() 22 | if device != torch.device("cpu"): 23 | print("Using device:", device) 24 | print("If you run into issues, pass the `--cpu` flag to this script.") 25 | print() 26 | 27 | model_id = "vikhyatk/moondream2" 28 | tokenizer = AutoTokenizer.from_pretrained(model_id, revision=LATEST_REVISION) 29 | moondream = Moondream.from_pretrained( 30 | model_id, revision=LATEST_REVISION, torch_dtype=dtype 31 | ).to(device=device) 32 | moondream.eval() 33 | 34 | 35 | def 
answer_question(img, prompt): 36 | image_embeds = moondream.encode_image(img) 37 | streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True) 38 | thread = Thread( 39 | target=moondream.answer_question, 40 | kwargs={ 41 | "image_embeds": image_embeds, 42 | "question": prompt, 43 | "tokenizer": tokenizer, 44 | "streamer": streamer, 45 | }, 46 | ) 47 | thread.start() 48 | 49 | buffer = "" 50 | for new_text in streamer: 51 | buffer += new_text 52 | yield buffer 53 | 54 | 55 | def extract_floats(text): 56 | # Regular expression to match an array of four floating point numbers 57 | pattern = r"\[\s*(-?\d+\.\d+)\s*,\s*(-?\d+\.\d+)\s*,\s*(-?\d+\.\d+)\s*,\s*(-?\d+\.\d+)\s*\]" 58 | match = re.search(pattern, text) 59 | if match: 60 | # Extract the numbers and convert them to floats 61 | return [float(num) for num in match.groups()] 62 | return None # Return None if no match is found 63 | 64 | 65 | def extract_bbox(text): 66 | bbox = None 67 | if extract_floats(text) is not None: 68 | x1, y1, x2, y2 = extract_floats(text) 69 | bbox = (x1, y1, x2, y2) 70 | return bbox 71 | 72 | 73 | def process_answer(img, answer): 74 | if extract_bbox(answer) is not None: 75 | x1, y1, x2, y2 = extract_bbox(answer) 76 | draw_image = Resize(768)(img) 77 | width, height = draw_image.size 78 | x1, x2 = int(x1 * width), int(x2 * width) 79 | y1, y2 = int(y1 * height), int(y2 * height) 80 | bbox = (x1, y1, x2, y2) 81 | ImageDraw.Draw(draw_image).rectangle(bbox, outline="red", width=3) 82 | return gr.update(visible=True, value=draw_image) 83 | 84 | return gr.update(visible=False, value=None) 85 | 86 | 87 | with gr.Blocks() as demo: 88 | gr.Markdown( 89 | """ 90 | # 🌔 moondream 91 | """ 92 | ) 93 | with gr.Row(): 94 | prompt = gr.Textbox(label="Input Prompt", value="Describe this image.", scale=4) 95 | submit = gr.Button("Submit") 96 | with gr.Row(): 97 | img = gr.Image(type="pil", label="Upload an Image") 98 | with gr.Column(): 99 | output = gr.Markdown(label="Response") 100 | ann = 
gr.Image(visible=False, label="Annotated Image") 101 | 102 | submit.click(answer_question, [img, prompt], output) 103 | prompt.submit(answer_question, [img, prompt], output) 104 | output.change(process_answer, [img, output], ann, show_progress=False) 105 | 106 | demo.queue().launch(debug=True) 107 | -------------------------------------------------------------------------------- /moondream/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vikhyat/moondream/ce59395b386eb956e19bc1cba5f97deb9befcd05/moondream/__init__.py -------------------------------------------------------------------------------- /moondream/config/config_md05.json: -------------------------------------------------------------------------------- 1 | { 2 | "text": { 3 | "dim": 1024, 4 | "ff_dim": 4096, 5 | "n_layers": 24, 6 | "vocab_size": 51200, 7 | "max_context": 2048, 8 | "n_heads": 16, 9 | "prefix_attn": 730 10 | }, 11 | "vision": { 12 | "enc_dim": 720, 13 | "enc_patch_size": 14, 14 | "enc_n_layers": 27, 15 | "enc_ff_dim": 2690, 16 | "enc_n_heads": 10, 17 | "proj_out_dim": 1024, 18 | "crop_size": 378, 19 | "in_channels": 3, 20 | "max_crops": 12, 21 | "overlap_margin": 4, 22 | "proj_inner_dim": 8192 23 | }, 24 | "region": { 25 | "dim": 1024, 26 | "coord_feat_dim": 256, 27 | "coord_out_dim": 1024, 28 | "size_feat_dim": 512, 29 | "size_out_dim": 2048, 30 | "inner_dim": 8192 31 | }, 32 | "tokenizer": { 33 | "bos_id": 50256, 34 | "eos_id": 50256, 35 | "templates": { 36 | "caption": { 37 | "short": [ 38 | 198, 39 | 198, 40 | 16438, 41 | 8305, 42 | 25 43 | ], 44 | "normal": [ 45 | 198, 46 | 198, 47 | 24334, 48 | 1159, 49 | 25 50 | ] 51 | }, 52 | "query": { 53 | "prefix": [ 54 | 198, 55 | 198, 56 | 24361, 57 | 25 58 | ], 59 | "suffix": [ 60 | 198, 61 | 198, 62 | 33706, 63 | 25 64 | ] 65 | }, 66 | "detect": { 67 | "prefix": [ 68 | 198, 69 | 198, 70 | 47504, 71 | 25 72 | ], 73 | "suffix": [ 74 | 628 75 | ] 76 | }, 77 | "point": 
{ 78 | "prefix": [ 79 | 198, 80 | 198, 81 | 12727, 82 | 25 83 | ], 84 | "suffix": [ 85 | 628 86 | ] 87 | } 88 | } 89 | } 90 | } -------------------------------------------------------------------------------- /moondream/config/config_md2.json: -------------------------------------------------------------------------------- 1 | { 2 | "text": { 3 | "dim": 2048, 4 | "ff_dim": 8192, 5 | "n_layers": 24, 6 | "vocab_size": 51200, 7 | "max_context": 2048, 8 | "n_heads": 32, 9 | "prefix_attn": 730 10 | }, 11 | "vision": { 12 | "enc_dim": 1152, 13 | "enc_patch_size": 14, 14 | "enc_n_layers": 27, 15 | "enc_ff_dim": 4304, 16 | "enc_n_heads": 16, 17 | "proj_out_dim": 2048, 18 | "crop_size": 378, 19 | "in_channels": 3, 20 | "max_crops": 12, 21 | "overlap_margin": 4, 22 | "proj_inner_dim": 8192 23 | }, 24 | "region": { 25 | "dim": 2048, 26 | "coord_feat_dim": 256, 27 | "coord_out_dim": 1024, 28 | "size_feat_dim": 512, 29 | "size_out_dim": 2048, 30 | "inner_dim": 8192 31 | }, 32 | "tokenizer": { 33 | "bos_id": 50256, 34 | "eos_id": 50256, 35 | "templates": { 36 | "caption": { 37 | "short": [ 38 | 198, 39 | 198, 40 | 16438, 41 | 8305, 42 | 25 43 | ], 44 | "normal": [ 45 | 198, 46 | 198, 47 | 24334, 48 | 1159, 49 | 25 50 | ] 51 | }, 52 | "query": { 53 | "prefix": [ 54 | 198, 55 | 198, 56 | 24361, 57 | 25 58 | ], 59 | "suffix": [ 60 | 198, 61 | 198, 62 | 33706, 63 | 25 64 | ] 65 | }, 66 | "detect": { 67 | "prefix": [ 68 | 198, 69 | 198, 70 | 47504, 71 | 25 72 | ], 73 | "suffix": [ 74 | 628 75 | ] 76 | }, 77 | "point": { 78 | "prefix": [ 79 | 198, 80 | 198, 81 | 12727, 82 | 25 83 | ], 84 | "suffix": [ 85 | 628 86 | ] 87 | } 88 | } 89 | } 90 | } -------------------------------------------------------------------------------- /moondream/eval/chartqa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datasets 3 | import torch 4 | 5 | from tqdm import tqdm 6 | import json 7 | 8 | from ..torch.config import MoondreamConfig 9 
| from ..torch.moondream import MoondreamModel 10 | from ..torch.weights import load_weights_into_model 11 | 12 | PREFIX = "Analyze the chart carefully, consider both visual features and data values, and provide a precise answer without any additional explanation or formatting. " 13 | 14 | 15 | # https://github.com/google-research/pix2struct/blob/main/pix2struct/metrics.py#L81 16 | def relaxed_correctness( 17 | target: str, prediction: str, max_relative_change: float = 0.05 18 | ) -> bool: 19 | """Calculates relaxed correctness. 20 | 21 | The correctness tolerates certain error ratio defined by max_relative_change. 22 | See https://arxiv.org/pdf/2203.10244.pdf, end of section 5.1: 23 | “Following Methani et al. (2020), we use a relaxed accuracy measure for the 24 | numeric answers to allow a minor inaccuracy that may result from the automatic 25 | data extraction process. We consider an answer to be correct if it is within 26 | 5% of the gold answer. For non-numeric answers, we still need an exact match 27 | to consider an answer to be correct.” 28 | 29 | Args: 30 | target: Target string. 31 | prediction: Predicted string. 32 | max_relative_change: Maximum relative change. 33 | 34 | Returns: 35 | Whether the prediction was correct given the specified tolerance. 36 | """ 37 | 38 | def _to_float(text): 39 | try: 40 | if text.endswith("%"): 41 | # Convert percentages to floats. 
42 | return float(text.rstrip("%")) / 100.0 43 | else: 44 | return float(text) 45 | except ValueError: 46 | return None 47 | 48 | prediction = str(prediction) 49 | target = str(target) 50 | prediction_float = _to_float(prediction) 51 | target_float = _to_float(target) 52 | if prediction_float is not None and target_float: 53 | relative_change = abs(prediction_float - target_float) / abs(target_float) 54 | return relative_change <= max_relative_change 55 | else: 56 | return prediction == target 57 | 58 | 59 | def eval_chartqa(model, debug=False): 60 | dataset = datasets.load_dataset("vikhyatk/chartqa", split="test") 61 | 62 | correct = 0 63 | total = 0 64 | human_correct = 0 65 | human_total = 0 66 | results = [] 67 | 68 | for row in tqdm(dataset, disable=debug, desc="ChartQA"): 69 | image = row["image"] 70 | encoded_image = model.encode_image(image) 71 | 72 | result = [] 73 | for qa in row["qa"]: 74 | question = PREFIX + qa["question"] 75 | answer = qa["answer"] 76 | model_answer = model.query(encoded_image, question)["answer"] 77 | 78 | # Attempt to parse both answers into lists, otherwise 79 | try: 80 | answer_list = json.loads(answer) 81 | model_answer_list = json.loads(model_answer) 82 | if not ( 83 | isinstance(answer_list, list) 84 | and isinstance(model_answer_list, list) 85 | and len(answer_list) == len(model_answer_list) 86 | ): 87 | raise ValueError 88 | except: 89 | # If parsing fails or lengths are not equal, compare the strings directly instead 90 | answer_list = [answer] 91 | model_answer_list = [model_answer] 92 | 93 | total += 1 94 | if qa["source"] == "human": 95 | human_total += 1 96 | 97 | is_correct = False 98 | if all( 99 | relaxed_correctness( 100 | str(cur_answer).strip().lower(), 101 | str(cur_model_answer).strip().lower(), 102 | ) 103 | for cur_answer, cur_model_answer in zip(answer_list, model_answer_list) 104 | ): 105 | correct += 1 106 | if qa["source"] == "human": 107 | human_correct += 1 108 | is_correct = True 109 | if debug: 110 | 
print( 111 | f"Correct: {correct}, Total: {total}, Human Correct: {human_correct}, Human Total: {human_total}" 112 | ) 113 | print(f"Human Accuracy: {human_correct * 100 / human_total:.2f}") 114 | print(f"Total Accuracy: {correct * 100 / total:.2f}") 115 | print("---------") 116 | result.append( 117 | { 118 | "question": question, 119 | "ground_truth": answer_list, 120 | "model_answer": model_answer_list, 121 | "is_correct": is_correct, 122 | "source": qa["source"], 123 | } 124 | ) 125 | results.append(result) 126 | 127 | return { 128 | "human_acc": human_correct * 100 / human_total, 129 | "total_acc": correct * 100 / total, 130 | "results": results, 131 | } 132 | 133 | 134 | if __name__ == "__main__": 135 | parser = argparse.ArgumentParser() 136 | parser.add_argument("--model", type=str, required=True) 137 | parser.add_argument("--debug", action="store_true") 138 | args = parser.parse_args() 139 | 140 | if torch.cuda.is_available(): 141 | torch.set_default_device("cuda") 142 | elif torch.backends.mps.is_available(): 143 | torch.set_default_device("mps") 144 | 145 | config = MoondreamConfig() 146 | model = MoondreamModel(config) 147 | load_weights_into_model(args.model, model) 148 | model.compile() 149 | 150 | results = eval_chartqa(model, args.debug) 151 | print(f"Human Accuracy: {results['human_acc']:.2f}") 152 | print(f"Total Accuracy: {results['total_acc']:.2f}") 153 | -------------------------------------------------------------------------------- /moondream/eval/countbenchqa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datasets 3 | import torch 4 | 5 | from tqdm import tqdm 6 | 7 | from ..torch.config import MoondreamConfig 8 | from ..torch.moondream import MoondreamModel 9 | from ..torch.weights import load_weights_into_model 10 | 11 | PREFIX = "Look at the image carefully and count the objects. Answer with just a number, without any additional text. 
" 12 | 13 | 14 | def eval_countbenchqa(model, debug=False): 15 | dataset = datasets.load_dataset("vikhyatk/CountBenchQA", split="test") 16 | 17 | correct = 0 18 | total = 0 19 | results = [] 20 | 21 | for row in tqdm(dataset, disable=debug, desc="CountBenchQA"): 22 | image = row["image"] 23 | encoded_image = model.encode_image(image) 24 | 25 | question = PREFIX + row["question"] 26 | answer = str(row["number"]) 27 | model_answer = model.query(encoded_image, question)["answer"] 28 | is_correct = model_answer.strip().lower() == answer.strip().lower() 29 | 30 | results.append( 31 | { 32 | "question": question, 33 | "ground_truth": answer, 34 | "model_answer": model_answer, 35 | "is_correct": is_correct, 36 | } 37 | ) 38 | 39 | total += 1 40 | if is_correct: 41 | correct += 1 42 | elif debug: 43 | print(f"Question: {row['question']}") 44 | print(f"Answer: {answer}") 45 | print(f"Model Answer: {model_answer}") 46 | if debug: 47 | print(f"Correct: {correct}, Total: {total}") 48 | print(f"Accuracy: {correct * 100 / total:.2f}") 49 | print("---------") 50 | 51 | return { 52 | "acc": correct * 100 / total, 53 | "correct_count": correct, 54 | "total_count": total, 55 | "results": results, 56 | } 57 | 58 | 59 | if __name__ == "__main__": 60 | parser = argparse.ArgumentParser() 61 | parser.add_argument("--model", type=str, required=True) 62 | parser.add_argument("--debug", action="store_true") 63 | args = parser.parse_args() 64 | 65 | if torch.cuda.is_available(): 66 | torch.set_default_device("cuda") 67 | elif torch.backends.mps.is_available(): 68 | torch.set_default_device("mps") 69 | 70 | config = MoondreamConfig() 71 | model = MoondreamModel(config) 72 | load_weights_into_model(args.model, model) 73 | 74 | result = eval_countbenchqa(model, args.debug) 75 | 76 | print(f"Accuracy: {result['acc']:.2f}") 77 | print(f"Correct: {result['correct_count']}, Total: {result['total_count']}") 78 | -------------------------------------------------------------------------------- 
/moondream/eval/docvqa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import editdistance 3 | from datasets import load_dataset 4 | from tqdm import tqdm 5 | import torch 6 | 7 | from ..torch.config import MoondreamConfig 8 | from ..torch.moondream import MoondreamModel 9 | from ..torch.weights import load_weights_into_model 10 | 11 | SUFFIX = " The answer should be a short text span taken verbatim from the document." 12 | 13 | 14 | def get_anls(s1, s2): 15 | s1 = s1.lower().strip() 16 | s2 = s2.lower().strip() 17 | iou = 1 - editdistance.eval(s1, s2) / max(len(s1), len(s2)) 18 | anls = iou if iou >= 0.5 else 0.0 19 | return anls 20 | 21 | 22 | def eval_docvqa(model, debug=False): 23 | docvqa_val = load_dataset("vikhyatk/docvqa-val", split="validation") 24 | 25 | scores = [] 26 | results = [] 27 | 28 | for row in tqdm(docvqa_val, disable=debug, desc="DocVQA"): 29 | image = row["image"] 30 | encoded_image = model.encode_image(image) 31 | 32 | result = [] 33 | for qa in row["qa"]: 34 | question = qa["question"] 35 | answers = qa["answers"] 36 | prompt = question + SUFFIX 37 | 38 | model_answer = model.query(encoded_image, prompt)["answer"] 39 | anls = max(get_anls(model_answer, gt) for gt in answers) 40 | scores.append(anls) 41 | result.append( 42 | { 43 | "question": question, 44 | "ground_truth": answers, 45 | "model_answer": model_answer, 46 | "anls": anls, 47 | } 48 | ) 49 | 50 | if debug: 51 | print(f"Question: {question}") 52 | print(f"Ground Truth: {answers}") 53 | print(f"Model Answer: {model_answer}") 54 | print(f"ANLS: {anls}") 55 | print(f"Current Average ANLS: {sum(scores) / len(scores):.4f}") 56 | print("---------") 57 | results.append(result) 58 | 59 | return { 60 | "anls": sum(scores) / len(scores), 61 | "results": results, 62 | } 63 | 64 | 65 | if __name__ == "__main__": 66 | parser = argparse.ArgumentParser() 67 | parser.add_argument("--model", type=str, required=True) 68 | 
parser.add_argument("--debug", action="store_true") 69 | args = parser.parse_args() 70 | 71 | if torch.cuda.is_available(): 72 | torch.set_default_device("cuda") 73 | elif torch.backends.mps.is_available(): 74 | torch.set_default_device("mps") 75 | 76 | config = MoondreamConfig() 77 | model = MoondreamModel(config) 78 | load_weights_into_model(args.model, model) 79 | model.compile() 80 | 81 | result = eval_docvqa(model, args.debug) 82 | 83 | print(f"ANLS: {result['anls']:.4f}") 84 | -------------------------------------------------------------------------------- /moondream/eval/eval_all.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from pprint import pprint 5 | 6 | from ..torch.config import MoondreamConfig 7 | from ..torch.moondream import MoondreamModel 8 | from ..torch.weights import load_weights_into_model 9 | 10 | from .countbenchqa import eval_countbenchqa 11 | from .pope import evaluate_pope 12 | from .realworldqa import eval_realworldqa 13 | from .chartqa import eval_chartqa 14 | from .textvqa import eval_textvqa 15 | from .docvqa import eval_docvqa 16 | from .mmstar import eval_mmstar 17 | from .coco_map import eval_coco_map 18 | from .naturalbench import eval_naturalbench 19 | from .tallyqa import eval_tallyqa 20 | 21 | 22 | def create_model(ckpt_path): 23 | config = MoondreamConfig() 24 | model = MoondreamModel(config) 25 | load_weights_into_model(ckpt_path, model) 26 | model.compile() 27 | return model 28 | 29 | 30 | def eval_all(model, skip=[]): 31 | evals = { 32 | "countbenchqa": eval_countbenchqa, 33 | "pope": evaluate_pope, 34 | "realworldqa": eval_realworldqa, 35 | "chartqa": eval_chartqa, 36 | "mmstar": eval_mmstar, 37 | "docvqa": eval_docvqa, 38 | "coco_map": eval_coco_map, 39 | "textvqa": eval_textvqa, 40 | "naturalbench": eval_naturalbench, 41 | "tallyqa": eval_tallyqa, 42 | } 43 | 44 | for b in skip: 45 | del evals[b] 46 | 47 | results = {} 48 | for name, 
eval_fn in evals.items(): 49 | results[name] = eval_fn(model) 50 | pprint({k: v for k, v in results[name].items() if k != "results"}) 51 | 52 | return results 53 | 54 | 55 | if __name__ == "__main__": 56 | parser = argparse.ArgumentParser() 57 | parser.add_argument("--model", type=str, required=True) 58 | args = parser.parse_args() 59 | 60 | if torch.cuda.is_available(): 61 | torch.set_default_device("cuda") 62 | elif torch.backends.mps.is_available(): 63 | torch.set_default_device("mps") 64 | 65 | model = create_model(args.model) 66 | eval_all(model) 67 | -------------------------------------------------------------------------------- /moondream/eval/gazefollow.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import datasets 3 | import math 4 | 5 | from tqdm import tqdm 6 | 7 | from ..torch.config import MoondreamConfig 8 | from ..torch.moondream import MoondreamModel 9 | from ..torch.weights import load_weights_into_model 10 | 11 | 12 | def eval_gazefollow(model, debug=False): 13 | dataset = datasets.load_dataset("vikhyatk/gazefollow", split="test") 14 | 15 | mean_l2_error = [] 16 | min_l2_error = [] 17 | total = 0 18 | 19 | for i, row in tqdm(enumerate(dataset), total=len(dataset)): 20 | heads = [] 21 | 22 | for gaze in row["gazes"]: 23 | head_bbox = gaze["head_bbox"] # xmin, ymin, xmax, ymax 24 | eye_coord = (gaze["eye"]["x"], gaze["eye"]["y"]) 25 | mean_target_gaze = (gaze["gaze"]["x"], gaze["gaze"]["y"]) 26 | 27 | # Check if a head already exists with the same approximate bbox. 28 | # If so, use that head instead of creating a new one. 
29 | for head in heads: 30 | if ( 31 | abs(head["head_bbox"]["xmin"] - head_bbox["xmin"]) < 0.001 32 | and abs(head["head_bbox"]["xmax"] - head_bbox["xmax"]) < 0.001 33 | and abs(head["head_bbox"]["ymin"] - head_bbox["ymin"]) < 0.001 34 | and abs(head["head_bbox"]["ymax"] - head_bbox["ymax"]) < 0.001 35 | ): 36 | head["gazes"].append(mean_target_gaze) 37 | break 38 | else: 39 | heads.append( 40 | { 41 | "head_bbox": head_bbox, 42 | "eye_coord": eye_coord, 43 | "gazes": [mean_target_gaze], 44 | } 45 | ) 46 | 47 | for head in heads: 48 | pred_gaze = model.detect_gaze( 49 | row["image"], 50 | eye=head["eye_coord"], 51 | face={ 52 | "x_min": head["head_bbox"]["xmin"], 53 | "y_min": head["head_bbox"]["ymin"], 54 | "x_max": head["head_bbox"]["xmax"], 55 | "y_max": head["head_bbox"]["ymax"], 56 | }, 57 | unstable_settings={"force_detect": True}, 58 | )["gaze"] 59 | 60 | mean_target_gaze = ( 61 | sum(gaze[0] for gaze in head["gazes"]) / len(head["gazes"]), 62 | sum(gaze[1] for gaze in head["gazes"]) / len(head["gazes"]), 63 | ) 64 | mean_l2 = math.sqrt( 65 | (mean_target_gaze[0] - pred_gaze["x"]) ** 2 66 | + (mean_target_gaze[1] - pred_gaze["y"]) ** 2 67 | ) 68 | min_l2 = min( 69 | math.sqrt( 70 | (target_gaze[0] - pred_gaze["x"]) ** 2 71 | + (target_gaze[1] - pred_gaze["y"]) ** 2 72 | ) 73 | for target_gaze in head["gazes"] 74 | ) 75 | 76 | mean_l2_error.append(mean_l2) 77 | min_l2_error.append(min_l2) 78 | total += 1 79 | 80 | if i % 100 == 0 and debug: 81 | print("Mean L2 error:", sum(mean_l2_error) / total) 82 | print("Min L2 error:", sum(min_l2_error) / total) 83 | 84 | return { 85 | "mean_l2": sum(mean_l2_error) / total, 86 | "min_l2": sum(min_l2_error) / total, 87 | } 88 | 89 | 90 | if __name__ == "__main__": 91 | import argparse 92 | 93 | parser = argparse.ArgumentParser() 94 | parser.add_argument("--model", type=str, required=True) 95 | 96 | parser.add_argument("--debug", action="store_true") 97 | args = parser.parse_args() 98 | 99 | if torch.cuda.is_available(): 
100 | torch.set_default_device("cuda") 101 | elif torch.backends.mps.is_available(): 102 | torch.set_default_device("mps") 103 | 104 | config = MoondreamConfig() 105 | model = MoondreamModel(config) 106 | load_weights_into_model(args.model, model) 107 | 108 | results = eval_gazefollow(model, debug=args.debug) 109 | 110 | print(f"Mean L2 error: {results['mean_l2']:.4f}") 111 | print(f"Min L2 error: {results['min_l2']:.4f}") 112 | -------------------------------------------------------------------------------- /moondream/eval/mmstar.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | import torch 3 | 4 | from tqdm import tqdm 5 | 6 | from ..torch.config import MoondreamConfig 7 | from ..torch.moondream import MoondreamModel 8 | from ..torch.weights import load_weights_into_model 9 | 10 | SUFFIX = " Please answer directly with only the letter of the correct option and nothing else." 11 | 12 | 13 | def eval_mmstar(model, debug=False): 14 | dataset = datasets.load_dataset("Lin-Chen/MMStar", split="val") 15 | 16 | correct = 0 17 | total = 0 18 | category_stats = {} 19 | results = [] 20 | 21 | for row in tqdm(dataset, disable=debug, desc="MMStar"): 22 | image = row["image"] 23 | question = row["question"] + SUFFIX 24 | answer = row["answer"] 25 | model_answer = model.query(image, question)["answer"] 26 | is_correct = model_answer.strip().lower() == answer.strip().lower() 27 | 28 | category = f"{row['category']} / {row['l2_category']}" 29 | if category not in category_stats: 30 | category_stats[category] = {"correct": 0, "total": 0} 31 | 32 | total += 1 33 | category_stats[category]["total"] += 1 34 | 35 | results.append( 36 | { 37 | "question": question, 38 | "ground_truth": answer, 39 | "model_answer": model_answer, 40 | "is_correct": is_correct, 41 | "category": category, 42 | } 43 | ) 44 | 45 | if is_correct: 46 | correct += 1 47 | category_stats[category]["correct"] += 1 48 | elif debug: 49 | print(f"Index: 
{row['index']}") 50 | print(f"Question: {row['question']}") 51 | print(f"Answer: {answer}") 52 | print(f"Model Answer: {model_answer}") 53 | if debug: 54 | print(f"Correct: {correct}, Total: {total}") 55 | print(f"Accuracy: {correct * 100 / total:.2f}") 56 | print("Results by category:") 57 | for category, stats in category_stats.items(): 58 | acc = stats["correct"] * 100 / stats["total"] 59 | print(f"{category}: {stats['correct']}/{stats['total']} = {acc:.2f}%") 60 | print("---------") 61 | 62 | return { 63 | "acc": correct * 100 / total, 64 | "correct_count": correct, 65 | "total_count": total, 66 | "category_stats": category_stats, 67 | "results": results, 68 | } 69 | 70 | 71 | if __name__ == "__main__": 72 | import argparse 73 | 74 | parser = argparse.ArgumentParser() 75 | parser.add_argument("--model", type=str, required=True) 76 | parser.add_argument("--debug", action="store_true") 77 | args = parser.parse_args() 78 | 79 | if torch.cuda.is_available(): 80 | torch.set_default_device("cuda") 81 | elif torch.backends.mps.is_available(): 82 | torch.set_default_device("mps") 83 | 84 | config = MoondreamConfig() 85 | model = MoondreamModel(config) 86 | load_weights_into_model(args.model, model) 87 | model.compile() 88 | 89 | result = eval_mmstar(model, args.debug) 90 | 91 | print(f"Correct: {result['correct_count']}, Total: {result['total_count']}") 92 | print(f"Accuracy: {result['acc']:.2f}") 93 | 94 | print("\nResults by category:") 95 | for category, stats in result["category_stats"].items(): 96 | acc = stats["correct"] * 100 / stats["total"] 97 | print(f"{category}: {stats['correct']}/{stats['total']} = {acc:.2f}%") 98 | -------------------------------------------------------------------------------- /moondream/eval/naturalbench.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | from tqdm import tqdm 3 | import torch 4 | 5 | from ..torch.config import MoondreamConfig 6 | from 
..torch.moondream import MoondreamModel 7 | from ..torch.weights import load_weights_into_model 8 | 9 | 10 | def eval_naturalbench(model, debug=False): 11 | # Yes, the benchmark test set is stored in the 'train' split... 12 | dataset = load_dataset("BaiqiL/NaturalBench", split="train") 13 | 14 | acc = [] 15 | q_acc = [] 16 | i_acc = [] 17 | g_acc = [] 18 | 19 | for row in tqdm(dataset, disable=debug, desc="NaturalBench"): 20 | if row["Question_Type"] == "yes_no": 21 | suffix = " Answer yes or no." 22 | else: 23 | suffix = "" 24 | 25 | images = [row["Image_0"], row["Image_1"], row["Image_0"], row["Image_1"]] 26 | prompts = [ 27 | row["Question_0"] + suffix, 28 | row["Question_0"] + suffix, 29 | row["Question_1"] + suffix, 30 | row["Question_1"] + suffix, 31 | ] 32 | expected = [ 33 | row["Image_0_Question_0"].strip().lower(), 34 | row["Image_1_Question_0"].strip().lower(), 35 | row["Image_0_Question_1"].strip().lower(), 36 | row["Image_1_Question_1"].strip().lower(), 37 | ] 38 | 39 | answers = [] 40 | for img, prompt in zip(images, prompts): 41 | encoded_image = model.encode_image(img) 42 | answer = model.query(encoded_image, prompt)["answer"] 43 | answers.append(answer.strip().lower()) 44 | 45 | if debug: 46 | for i, (q, a, e) in enumerate(zip(prompts, answers, expected)): 47 | print(f"Q{i}: {q}") 48 | print(f"Model: {a}") 49 | print(f"Expected: {e}") 50 | print(f"Correct: {a == e}") 51 | print("---") 52 | 53 | acc.append(answers[0] == expected[0]) 54 | acc.append(answers[1] == expected[1]) 55 | acc.append(answers[2] == expected[2]) 56 | acc.append(answers[3] == expected[3]) 57 | 58 | i_acc.append(answers[0] == expected[0] and answers[2] == expected[2]) 59 | i_acc.append(answers[1] == expected[1] and answers[3] == expected[3]) 60 | 61 | q_acc.append(answers[0] == expected[0] and answers[1] == expected[1]) 62 | q_acc.append(answers[2] == expected[2] and answers[3] == expected[3]) 63 | 64 | g_acc.append( 65 | answers[0] == expected[0] 66 | and answers[1] ==
expected[1] 67 | and answers[2] == expected[2] 68 | and answers[3] == expected[3] 69 | ) 70 | 71 | if debug: 72 | print(f"Current Overall Accuracy: {sum(acc) / len(acc):.4f}") 73 | print(f"Current Image Accuracy: {sum(i_acc) / len(i_acc):.4f}") 74 | print(f"Current Question Accuracy: {sum(q_acc) / len(q_acc):.4f}") 75 | print(f"Current Group Accuracy: {sum(g_acc) / len(g_acc):.4f}") 76 | print("=========") 77 | 78 | return { 79 | "overall_acc": sum(acc) / len(acc), 80 | "image_acc": sum(i_acc) / len(i_acc), 81 | "question_acc": sum(q_acc) / len(q_acc), 82 | "group_acc": sum(g_acc) / len(g_acc), 83 | } 84 | 85 | 86 | if __name__ == "__main__": 87 | import argparse 88 | 89 | parser = argparse.ArgumentParser() 90 | parser.add_argument("--model", type=str, required=True) 91 | parser.add_argument("--debug", action="store_true") 92 | args = parser.parse_args() 93 | 94 | if torch.cuda.is_available(): 95 | torch.set_default_device("cuda") 96 | elif torch.backends.mps.is_available(): 97 | torch.set_default_device("mps") 98 | 99 | config = MoondreamConfig() 100 | model = MoondreamModel(config) 101 | load_weights_into_model(args.model, model) 102 | model.compile() 103 | 104 | results = eval_naturalbench(model, debug=args.debug) 105 | 106 | print(f"Overall Accuracy: {results['overall_acc']:.4f}") 107 | print(f"Image Accuracy: {results['image_acc']:.4f}") 108 | print(f"Question Accuracy: {results['question_acc']:.4f}") 109 | print(f"Group Accuracy: {results['group_acc']:.4f}") 110 | -------------------------------------------------------------------------------- /moondream/eval/pope.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from datasets import load_dataset 3 | from tqdm import tqdm 4 | import torch 5 | 6 | from ..torch.config import MoondreamConfig 7 | from ..torch.moondream import MoondreamModel 8 | from ..torch.weights import load_weights_into_model 9 | 10 | 11 | def evaluate_pope(model, debug=False): 12 | 
pope_dataset = load_dataset("vikhyatk/POPE", split="test") 13 | 14 | stats = { 15 | "random": (0, 0), 16 | "popular": (0, 0), 17 | "adversarial": (0, 0), 18 | } 19 | 20 | for row in tqdm(pope_dataset, disable=debug, desc="POPE"): 21 | image = row["image"] 22 | encoded_image = model.encode_image(image) 23 | for split in ["adversarial", "popular", "random"]: 24 | for qa in row[split]: 25 | question = qa["question"] 26 | answer = qa["answer"] 27 | prompt = f"{question}\nAnswer yes or no." 28 | model_answer = model.query(encoded_image, prompt)["answer"].strip() 29 | 30 | if debug: 31 | print(f"Split: {split}") 32 | print(f"Question: {question}") 33 | print(f"Model: {model_answer}") 34 | print(f"Expected: {answer}") 35 | print(f"Correct: {model_answer.lower() == answer.lower()}") 36 | print("---") 37 | 38 | if model_answer.lower() == answer.lower(): 39 | stats[split] = (stats[split][0] + 1, stats[split][1] + 1) 40 | else: 41 | stats[split] = (stats[split][0], stats[split][1] + 1) 42 | 43 | if debug: 44 | for s in stats: 45 | if stats[s][1] > 0: 46 | print( 47 | f"{s.capitalize()}: {stats[s][0]}/{stats[s][1]} = {stats[s][0] * 100.0 / stats[s][1]:.2f}%" 48 | ) 49 | print("=========") 50 | 51 | return { 52 | "random": stats["random"][0] * 100.0 / stats["random"][1], 53 | "popular": stats["popular"][0] * 100.0 / stats["popular"][1], 54 | "adversarial": stats["adversarial"][0] * 100.0 / stats["adversarial"][1], 55 | } 56 | 57 | 58 | if __name__ == "__main__": 59 | parser = argparse.ArgumentParser() 60 | parser.add_argument("--model", type=str, required=True) 61 | parser.add_argument("--debug", action="store_true") 62 | args = parser.parse_args() 63 | 64 | if torch.cuda.is_available(): 65 | torch.set_default_device("cuda") 66 | elif torch.backends.mps.is_available(): 67 | torch.set_default_device("mps") 68 | 69 | config = MoondreamConfig() 70 | model = MoondreamModel(config) 71 | load_weights_into_model(args.model, model) 72 | 73 | result = evaluate_pope(model, args.debug) 74 
| 75 | print(f"Random Accuracy: {result['random']:.2f}") 76 | print(f"Popular Accuracy: {result['popular']:.2f}") 77 | print(f"Adversarial Accuracy: {result['adversarial']:.2f}") 78 | -------------------------------------------------------------------------------- /moondream/eval/realworldqa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datasets 3 | import torch 4 | 5 | from tqdm import tqdm 6 | 7 | from ..torch.config import MoondreamConfig 8 | from ..torch.moondream import MoondreamModel 9 | from ..torch.weights import load_weights_into_model 10 | 11 | 12 | def eval_realworldqa(model, debug=False): 13 | dataset = datasets.load_dataset("lmms-lab/RealWorldQA", split="test") 14 | 15 | correct = 0 16 | total = 0 17 | results = [] 18 | 19 | for row in tqdm(dataset, disable=debug, desc="RealWorldQA"): 20 | image = row["image"] 21 | question = row["question"] 22 | answer = row["answer"] 23 | model_answer = model.query(image, question)["answer"] 24 | is_correct = model_answer.strip().lower() == answer.strip().lower() 25 | 26 | results.append( 27 | { 28 | "question": question, 29 | "ground_truth": answer, 30 | "model_answer": model_answer, 31 | "is_correct": is_correct, 32 | } 33 | ) 34 | 35 | total += 1 36 | if is_correct: 37 | correct += 1 38 | elif debug: 39 | print(f"Image: {row['image_path']}") 40 | print(f"Question: {question}") 41 | print(f"Answer: {answer}") 42 | print(f"Model Answer: {model_answer}") 43 | if debug: 44 | print(f"Correct: {correct}, Total: {total}") 45 | print(f"Accuracy: {correct * 100 / total:.2f}") 46 | print("---------") 47 | 48 | return { 49 | "acc": correct * 100 / total, 50 | "correct_count": correct, 51 | "total_count": total, 52 | "results": results, 53 | } 54 | 55 | 56 | if __name__ == "__main__": 57 | parser = argparse.ArgumentParser() 58 | parser.add_argument("--model", type=str, required=True) 59 | parser.add_argument("--debug", action="store_true") 60 | args = 
parser.parse_args() 61 | 62 | if torch.cuda.is_available(): 63 | torch.set_default_device("cuda") 64 | elif torch.backends.mps.is_available(): 65 | torch.set_default_device("mps") 66 | 67 | config = MoondreamConfig() 68 | model = MoondreamModel(config) 69 | load_weights_into_model(args.model, model) 70 | model.compile() 71 | 72 | result = eval_realworldqa(model, args.debug) 73 | 74 | print(f"Accuracy: {result['acc']:.2f}") 75 | print(f"Correct: {result['correct_count']} / {result['total_count']}") 76 | -------------------------------------------------------------------------------- /moondream/eval/tallyqa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datasets 3 | import torch 4 | 5 | from tqdm import tqdm 6 | 7 | from ..torch.config import MoondreamConfig 8 | from ..torch.moondream import MoondreamModel 9 | from ..torch.weights import load_weights_into_model 10 | 11 | PREFIX = "Look at the image carefully and count the objects. Answer with just a number, without any additional text.
" 12 | 13 | 14 | def eval_tallyqa(model, debug=False): 15 | dataset = datasets.load_dataset( 16 | "vikhyatk/tallyqa-test", 17 | split="test", 18 | download_config=datasets.DownloadConfig(num_proc=16), 19 | ) 20 | 21 | total = 0 22 | total_simple = 0 23 | correct = 0 24 | correct_simple = 0 25 | 26 | for row in tqdm(dataset, disable=debug, desc="TallyQA"): 27 | image = row["image"] 28 | encoded_image = model.encode_image(image) 29 | 30 | for qa in row["qa"]: 31 | question = PREFIX + qa["question"] 32 | answer = str(qa["answer"]) 33 | is_simple = qa["is_simple"] 34 | 35 | model_answer = model.query(encoded_image, question)["answer"] 36 | 37 | total += 1 38 | if model_answer.strip().lower() == answer.strip().lower(): 39 | correct += 1 40 | elif debug: 41 | print(f"Question: {qa['question']}") 42 | print(f"Answer: {answer}") 43 | print(f"Model Answer: {model_answer}") 44 | 45 | if is_simple: 46 | total_simple += 1 47 | if model_answer.strip().lower() == answer.strip().lower(): 48 | correct_simple += 1 49 | 50 | if debug: 51 | print(f"Simple - Correct: {correct_simple}, Total: {total_simple}") 52 | print(f"Simple Accuracy: {correct_simple * 100 / max(total_simple, 1):.2f}") 53 | print(f"All - Correct: {correct}, Total: {total}") 54 | print(f"All Accuracy: {correct * 100 / total:.2f}") 55 | print("---------") 56 | 57 | return { 58 | "simple_acc": correct_simple * 100 / total_simple, 59 | "full_acc": correct * 100 / total, 60 | } 61 | 62 | 63 | if __name__ == "__main__": 64 | parser = argparse.ArgumentParser() 65 | parser.add_argument("--model", type=str, required=True) 66 | parser.add_argument("--debug", action="store_true") 67 | args = parser.parse_args() 68 | 69 | if torch.cuda.is_available(): 70 | torch.set_default_device("cuda") 71 | elif torch.backends.mps.is_available(): 72 | torch.set_default_device("mps") 73 | 74 | config = MoondreamConfig() 75 | model = MoondreamModel(config) 76 | load_weights_into_model(args.model, model) 77 | model.compile() 78 | 79 | result =
eval_tallyqa(model, args.debug) 80 | 81 | print(f"Simple acc: {result['simple_acc']:.2f}") 82 | print(f"Full acc: {result['full_acc']:.2f}") 83 | -------------------------------------------------------------------------------- /moondream/eval/textvqa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datasets 3 | import torch 4 | 5 | from tqdm import tqdm 6 | 7 | from ..torch.config import MoondreamConfig 8 | from ..torch.moondream import MoondreamModel 9 | from ..torch.weights import load_weights_into_model 10 | from .utils import VQAScorer 11 | 12 | PREFIX_TEXTVQA = "Read the text in the image and provide a brief lowercase answer. Respond 'unanswerable' only if there is no plausible answer. " 13 | 14 | 15 | def eval_textvqa(model, debug=False): 16 | dataset = datasets.load_dataset("vikhyatk/textvqa_val", split="validation") 17 | 18 | scorer = VQAScorer() 19 | 20 | total_score = 0 21 | total_samples = 0 22 | results = [] 23 | 24 | for row in tqdm(dataset, disable=debug, desc="TextVQA"): 25 | image = row["image"] 26 | encoded_image = model.encode_image(image) 27 | question = PREFIX_TEXTVQA + row["question"] 28 | model_answer = model.query(encoded_image, question)["answer"] 29 | 30 | score = scorer.compute_score(model_answer, row["answers"]) 31 | total_score += score 32 | total_samples += 1 33 | 34 | results.append( 35 | { 36 | "question": question, 37 | "ground_truth": row["answers"], 38 | "model_answer": model_answer, 39 | "score": score, 40 | } 41 | ) 42 | 43 | if debug: 44 | print(f"Question: {row['question']}") 45 | print(f"Ground Truth Answers: {row['answers']}") 46 | print(f"Model Answer: {model_answer}") 47 | print(f"Score: {score}") 48 | print(f"Running Average Score: {total_score * 100 / total_samples:.2f}") 49 | print("---------") 50 | 51 | return {"score": total_score * 100 / total_samples, "results": results} 52 | 53 | 54 | if __name__ == "__main__": 55 | parser = argparse.ArgumentParser() 
56 | parser.add_argument("--model", type=str, required=True) 57 | parser.add_argument("--debug", action="store_true") 58 | args = parser.parse_args() 59 | 60 | if torch.cuda.is_available(): 61 | torch.set_default_device("cuda") 62 | elif torch.backends.mps.is_available(): 63 | torch.set_default_device("mps") 64 | 65 | config = MoondreamConfig() 66 | model = MoondreamModel(config) 67 | load_weights_into_model(args.model, model) 68 | model.compile() 69 | 70 | result = eval_textvqa(model, args.debug) 71 | 72 | print(f"Score: {result['score']}") 73 | -------------------------------------------------------------------------------- /moondream/eval/waste_detection.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from collections import defaultdict 3 | from typing import Dict, List, Tuple 4 | 5 | import torch 6 | from PIL import Image 7 | from tqdm import tqdm 8 | from datasets import load_dataset 9 | 10 | from ..torch.config import MoondreamConfig 11 | from ..torch.moondream import MoondreamModel 12 | from ..torch.weights import load_weights_into_model 13 | 14 | 15 | Box = Tuple[float, float, float, float] # (x1, y1, x2, y2) – in proportion form 16 | 17 | 18 | def iou(a: Box, b: Box) -> float: 19 | """Corner-format IoU. Returns 0 when either box has zero area.""" 20 | x1, y1 = max(a[0], b[0]), max(a[1], b[1]) 21 | x2, y2 = min(a[2], b[2]), min(a[3], b[3]) 22 | inter = max(0.0, x2 - x1) * max(0.0, y2 - y1) 23 | 24 | union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter 25 | return inter / union if union else 0.0 26 | 27 | 28 | def match(gt: List[Box], pr: List[Box], iou_thr: float) -> Tuple[int, int, int]: 29 | """ 30 | Greedy one-to-one matching with no confidences. 31 | Predictions are taken in the order produced by the model. 
32 | """ 33 | tp = fp = 0 34 | seen = [False] * len(gt) 35 | 36 | for p in pr: 37 | best, best_i = 0.0, -1 38 | for i, g in enumerate(gt): 39 | if seen[i]: 40 | continue 41 | iou_ = iou(p, g) 42 | if iou_ > best: 43 | best, best_i = iou_, i 44 | if best >= iou_thr: 45 | tp += 1 46 | seen[best_i] = True 47 | else: 48 | fp += 1 49 | 50 | fn = len(gt) - tp 51 | return tp, fp, fn 52 | 53 | 54 | class WasteDetection(torch.utils.data.Dataset): 55 | def __init__(self, name: str = "moondream/waste_detection", split: str = "test"): 56 | self.ds = load_dataset(name, split=split) 57 | 58 | def __len__(self): 59 | return len(self.ds) 60 | 61 | def __getitem__(self, idx: int) -> Dict: 62 | s = self.ds[idx] 63 | img = ( 64 | s["image"] 65 | if isinstance(s["image"], Image.Image) 66 | else Image.fromarray(s["image"]) 67 | ) 68 | W, H = float(s.get("width", img.width)), float(s.get("height", img.height)) 69 | 70 | lbl_to_boxes = defaultdict(list) 71 | for (xc, yc, bw, bh), lbl in zip(s["boxes"], s["labels"]): 72 | x1 = xc - bw / 2 73 | y1 = yc - bh / 2 74 | x2 = xc + bw / 2 75 | y2 = yc + bh / 2 76 | lbl_to_boxes[lbl].append((x1, y1, x2, y2)) 77 | 78 | return {"image": img, "gt": lbl_to_boxes, "W": W, "H": H} 79 | 80 | 81 | def evaluate( 82 | model: MoondreamModel, 83 | iou_thr: float, 84 | debug: bool, 85 | ): 86 | ds = WasteDetection(split="test") 87 | TP = FP = FN = 0 88 | 89 | for s in tqdm(ds, disable=debug, desc="Waste"): 90 | img, gts = s["image"], s["gt"] 91 | enc = model.encode_image(img) 92 | 93 | for lbl, gt_boxes in gts.items(): 94 | preds: List[Box] = [ 95 | ( 96 | o["x_min"], 97 | o["y_min"], 98 | o["x_max"], 99 | o["y_max"], 100 | ) 101 | for o in model.detect(enc, lbl)["objects"] 102 | ] 103 | tp, fp, fn = match(gt_boxes, preds, iou_thr) 104 | TP += tp 105 | FP += fp 106 | FN += fn 107 | 108 | prec = TP / (TP + FP) if TP + FP else 0.0 109 | rec = TP / (TP + FN) if TP + FN else 0.0 110 | f1 = 2 * prec * rec / (prec + rec) if prec + rec else 0.0 111 | return 
dict(precision=prec, recall=rec, f1=f1, tp=TP, fp=FP, fn=FN) 112 | 113 | 114 | def load_model(path: str, device: torch.device) -> MoondreamModel: 115 | cfg = MoondreamConfig() 116 | model = MoondreamModel(cfg) 117 | load_weights_into_model(path, model) 118 | model.compile() 119 | model.to(device) 120 | return model 121 | 122 | 123 | def main(): 124 | p = argparse.ArgumentParser() 125 | p.add_argument("--model", required=True) 126 | p.add_argument("--iou_thr", type=float, default=0.5) 127 | p.add_argument("--gpu", type=int, default=0) 128 | p.add_argument("--debug", action="store_true") 129 | args = p.parse_args() 130 | 131 | if torch.cuda.is_available(): 132 | torch.cuda.set_device(args.gpu) 133 | device = torch.device(f"cuda:{args.gpu}") 134 | elif torch.backends.mps.is_available(): 135 | device = torch.device("mps") 136 | else: 137 | device = torch.device("cpu") 138 | 139 | model = load_model(args.model, device) 140 | res = evaluate(model, args.iou_thr, args.debug) 141 | 142 | print(f"Precision: {res['precision']*100:.2f}%") 143 | print(f"Recall: {res['recall']*100:.2f}%") 144 | print(f"F1 Score: {res['f1']*100:.2f}%") 145 | print(f"TP: {res['tp']} FP: {res['fp']} FN: {res['fn']}") 146 | 147 | 148 | if __name__ == "__main__": 149 | """ 150 | Eval to accompany finetune_region.py. 151 | """ 152 | main() 153 | -------------------------------------------------------------------------------- /moondream/finetune/README.md: -------------------------------------------------------------------------------- 1 | # Finetuning Moondream 2B 2 | 3 | This readme will walk you through the process of finetuning the text and region encoders of the Moondream 2B model. 4 | 5 | > Make sure to run all commands from the root directory of the project. 
6 | 7 | ## Initial Setup 8 | 9 | ### Clone and Setup Environment 10 | ```bash 11 | git clone https://github.com/vikhyat/moondream 12 | cd moondream 13 | python -m venv .venv 14 | source .venv/bin/activate 15 | ``` 16 | 17 | ### Install Dependencies 18 | ```bash 19 | # Install base requirements 20 | pip install -r requirements.txt 21 | 22 | # Install finetuning-specific dependencies 23 | pip install safetensors datasets bitsandbytes tqdm wandb einops 24 | ``` 25 | 26 | ## Downloading the Base Model 27 | 28 | Download `model.safetensors` from the [Hugging Face repository](https://huggingface.co/vikhyatk/moondream2/tree/main) and place it in the `models` directory as `moondream_base.safetensors`. 29 | 30 | ```bash 31 | # Create models directory 32 | mkdir -p models 33 | 34 | # Download it using wget and save it as models/moondream_base.safetensors (run from root moondream directory) 35 | wget -O models/moondream_base.safetensors https://huggingface.co/vikhyatk/moondream2/resolve/main/model.safetensors 36 | ``` 37 | 38 | ## Weights & Biases 39 | 40 | We use Weights & Biases (wandb) to track finetuning progress. 41 | 42 | To set it up to track your runs, use `wandb login`. 43 | 44 | This will take you through creating an account if you don't have one set up already. Enter your API key and you're ready to go. 45 | 46 | ## Finetuning the Text Encoder 47 | 48 | For this example, we will be teaching Moondream to describe images. 49 | 50 | Given the prompt: 51 | `\n\nQuestion: Describe this image.\n\nAnswer:` 52 | 53 | The finetuned model returns a more detailed caption of the image than you would get from the base model. 54 | 55 | 1. Double check that you've updated MODEL_PATH to point to the base moondream model in `moondream/finetune/finetune_text.py` 56 | 2. Double check that the save path ends in `.safetensors`, otherwise the run will fail.
57 | > Navigate to line 150 in `moondream/finetune/finetune_text.py`, 58 | ``` # Add save path 59 | save_file( 60 | model.state_dict(), 61 | "moondream_finetune.safetensors", // update this line ex: "models/moondream_text_finetuned.safetensors" 62 | ) 63 | ``` 64 | 65 | ### Start Text Finetuning 66 | ```bash 67 | python -m moondream.finetune.finetune_text 68 | ``` 69 | 70 | The process will output a finetuned version of Moondream into your save path. Example output: `models/moondream_text_finetuned.safetensors`. 71 | 72 | ### Test the Finetuned Text Encoder 73 | 74 | You can test the finetuned model's performance with the following command (run from root moondream directory). 75 | 76 | This will return the caption of the image. 77 | 78 | ```bash 79 | # Remember to update the paths 80 | python -m moondream.torch.sample --model [FINETUNED_MODEL_PATH] --image "[DATASET_DIRECTORY]/test/[IMAGE_NAME]" --prompt "\n\nQuestion: Describe this image.\n\nAnswer:" 81 | ``` 82 | 83 | ## Finetuning the Region Encoder 84 | 85 | For this example, we will be teaching Moondream to detect railroad cracks in images of a railway. 86 | 87 | Our dataset trains our model such that, 88 | 89 | Given the prompt: 90 | `\n\nDetect: \n\n` 91 | 92 | We are returned the coordinates of a detected crack in the following format: 93 | ```{'objects': [{'x_min': [X_MIN], 'y_min': [Y_MIN], 'x_max': [X_MAX], 'y_max': [Y_MAX]}]}``` 94 | 95 | ### Setup Dataset Dependencies 96 | 97 | 1. Update MODEL_PATH to point to the base moondream model. 98 | 2. Double check that the save path ends in `.safetensors`, otherwise the run will fail. 99 | > Navigate to line 244 in `moondream/finetune/finetune_region.py`.
100 | ``` # Add save path 101 | save_file( 102 | model.state_dict(), 103 | "moondream_finetune.safetensors", // update this line 104 | ) 105 | ``` 106 | 107 | ### Start Region Finetuning 108 | ```bash 109 | python -m moondream.finetune.finetune_region 110 | ``` 111 | 112 | The process will output a finetuned version of Moondream into your save path. Example output: `models/moondream_region_finetuned.safetensors`. -------------------------------------------------------------------------------- /moondream/finetune/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vikhyat/moondream/ce59395b386eb956e19bc1cba5f97deb9befcd05/moondream/finetune/__init__.py -------------------------------------------------------------------------------- /moondream/finetune/finetune_text.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.data import Dataset 4 | import math 5 | from safetensors.torch import save_file 6 | 7 | from tqdm import tqdm 8 | from datasets import load_dataset 9 | from bitsandbytes.optim import AdamW8bit 10 | import wandb 11 | 12 | from ..torch.weights import load_weights_into_model 13 | from ..torch.moondream import MoondreamModel, MoondreamConfig, text_encoder 14 | from ..torch.text import _produce_hidden, _lm_head, TextConfig 15 | 16 | # This is a intended to be a basic starting point for fine-tuning the text encoder. 17 | # Your optimal hyperparams and data may be different. 18 | MODEL_PATH = "" 19 | # Your data should end with the eos token. Here is the textual representation. 
20 | ANSWER_EOS = "<|endoftext|>" 21 | LR = 3e-6 22 | EPOCHS = 3 23 | GRAD_ACCUM_STEPS = 128 24 | 25 | 26 | def lr_schedule(step, max_steps): 27 | x = step / max_steps 28 | if x < 0.1: 29 | return 0.1 * LR + 0.9 * LR * x / 0.1 30 | else: 31 | return 0.1 * LR + 0.9 * LR * (1 + math.cos(math.pi * (x - 0.1))) / 2 32 | 33 | 34 | def text_loss( 35 | inputs_embeds: torch.Tensor, w: nn.Module, labels: torch.Tensor, config: TextConfig 36 | ): 37 | _, q_len, _ = inputs_embeds.shape 38 | hidden_BTC = _produce_hidden(inputs_embeds, w, config) 39 | lm_logits = _lm_head(hidden_BTC, w) 40 | 41 | loss = None 42 | if labels is not None: 43 | _, _, l_len = labels.shape 44 | shift_index = (q_len - l_len) - 1 45 | shifted_logits = lm_logits[..., shift_index:-1, :].contiguous() 46 | shifted_labels = labels.contiguous() 47 | loss = nn.CrossEntropyLoss()( 48 | shifted_logits.view(-1, shifted_logits.size(-1)), 49 | shifted_labels.view(-1), 50 | ) 51 | return loss 52 | 53 | 54 | class DocciDataset(Dataset): 55 | def __init__(self, split="train"): 56 | self.data = load_dataset("google/docci", trust_remote_code=True)[split] 57 | 58 | def __len__(self): 59 | return len(self.data) 60 | 61 | def __getitem__(self, idx): 62 | sample = self.data[idx] 63 | description = sample["description"] 64 | return { 65 | "image": sample["image"], 66 | "qa": { 67 | "question": "\n\nQuestion: Describe this image.\n\nAnswer:", 68 | "answer": f"{description}{ANSWER_EOS}", 69 | }, 70 | } 71 | 72 | 73 | def main(): 74 | if torch.cuda.is_available(): 75 | torch.set_default_device("cuda") 76 | elif torch.backends.mps.is_available(): 77 | torch.set_default_device("mps") 78 | 79 | wandb.init( 80 | project="moondream-ft", 81 | config={ 82 | "EPOCHS": EPOCHS, 83 | "GRAD_ACCUM_STEPS": GRAD_ACCUM_STEPS, 84 | "LR": LR, 85 | }, 86 | ) 87 | 88 | config = MoondreamConfig() 89 | model = MoondreamModel(config) 90 | load_weights_into_model(MODEL_PATH, model) 91 | 92 | optimizer = AdamW8bit( 93 | [ 94 | {"params": 
model.text.parameters()}, 95 | ], 96 | lr=LR, 97 | betas=(0.9, 0.95), 98 | eps=1e-6, 99 | ) 100 | 101 | dataset = DocciDataset("train") 102 | 103 | total_steps = EPOCHS * len(dataset) // GRAD_ACCUM_STEPS 104 | pbar = tqdm(total=total_steps) 105 | 106 | i = 0 107 | for epoch in range(EPOCHS): 108 | for sample in dataset: 109 | i += 1 110 | with torch.no_grad(): 111 | img_emb = model._run_vision_encoder(sample["image"]) 112 | bos_emb = text_encoder( 113 | torch.tensor([[model.config.tokenizer.bos_id]], device=model.device), 114 | model.text, 115 | ) 116 | question_tokens = model.tokenizer.encode(sample["qa"]["question"]).ids 117 | question_emb = text_encoder( 118 | torch.tensor([[question_tokens]], device=model.device), 119 | model.text, 120 | ).squeeze(0) 121 | answer_tokens = model.tokenizer.encode(sample["qa"]["answer"]).ids 122 | answer_emb = text_encoder( 123 | torch.tensor([[answer_tokens]], device=model.device), 124 | model.text, 125 | ).squeeze(0) 126 | inputs_embeds = torch.cat( 127 | [bos_emb, img_emb[None], question_emb, answer_emb], dim=1 128 | ) 129 | loss = text_loss( 130 | inputs_embeds=inputs_embeds, 131 | w=model.text, 132 | labels=torch.tensor([[answer_tokens]], device=model.device), 133 | config=config.text, 134 | ) 135 | 136 | loss.backward() 137 | 138 | if i % GRAD_ACCUM_STEPS == 0: 139 | optimizer.step() 140 | optimizer.zero_grad() 141 | 142 | lr = lr_schedule(i / GRAD_ACCUM_STEPS, total_steps) 143 | for param_group in optimizer.param_groups: 144 | param_group["lr"] = lr 145 | pbar.set_postfix({"step": i // GRAD_ACCUM_STEPS, "loss": loss.item()}) 146 | pbar.update(1) 147 | wandb.log( 148 | {"loss/train": loss.item(), "lr": optimizer.param_groups[0]["lr"]} 149 | ) 150 | wandb.finish() 151 | # Add save path: ex. home/model.safetensors 152 | save_file( 153 | model.state_dict(), 154 | "moondream_finetune.safetensors", 155 | ) 156 | 157 | 158 | if __name__ == "__main__": 159 | """ 160 | Replace paths with your appropriate paths. 
161 | To run: python -m moondream.finetune.finetune_text 162 | """ 163 | main() 164 | -------------------------------------------------------------------------------- /moondream/torch/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Dict, List, Optional 3 | 4 | 5 | @dataclass(frozen=True) 6 | class TextConfig: 7 | dim: int = 2048 8 | ff_dim: int = 8192 9 | n_layers: int = 24 10 | vocab_size: int = 51200 11 | max_context: int = 2048 12 | n_heads: int = 32 13 | n_kv_heads: int = 32 14 | prefix_attn: int = 730 15 | group_size: Optional[int] = None 16 | 17 | 18 | @dataclass(frozen=True) 19 | class VisionConfig: 20 | enc_dim: int = 1152 21 | enc_patch_size: int = 14 22 | enc_n_layers: int = 27 23 | enc_ff_dim: int = 4304 24 | enc_n_heads: int = 16 25 | proj_out_dim: int = 2048 26 | crop_size: int = 378 27 | in_channels: int = 3 28 | max_crops: int = 12 29 | overlap_margin: int = 4 30 | proj_inner_dim: int = 8192 31 | 32 | 33 | @dataclass(frozen=True) 34 | class RegionConfig: 35 | dim: int = 2048 36 | coord_feat_dim: int = 256 37 | coord_out_dim: int = 1024 38 | size_feat_dim: int = 512 39 | size_out_dim: int = 2048 40 | inner_dim: int = 8192 41 | group_size: Optional[int] = None 42 | 43 | 44 | @dataclass(frozen=True) 45 | class TokenizerConfig: 46 | bos_id: int = 0 47 | eos_id: int = 0 48 | coord_id: int = 5 49 | size_id: int = 6 50 | templates: Dict[str, Optional[Dict[str, List[int]]]] = field( 51 | default_factory=lambda: { 52 | "caption": { 53 | "short": [1, 32708, 2, 12492, 3], 54 | "normal": [1, 32708, 2, 6382, 3], 55 | "long": [1, 32708, 2, 4059, 3], 56 | }, 57 | "query": {"prefix": [1, 15381, 2], "suffix": [3]}, 58 | "detect": {"prefix": [1, 7235, 476, 2], "suffix": [3]}, 59 | "point": {"prefix": [1, 2581, 2], "suffix": [3]}, 60 | } 61 | ) 62 | 63 | 64 | @dataclass(frozen=True) 65 | class MoondreamConfig: 66 | text: TextConfig = TextConfig() 67 | vision: 
VisionConfig = VisionConfig() 68 | region: RegionConfig = RegionConfig() 69 | tokenizer: TokenizerConfig = TokenizerConfig() 70 | 71 | @classmethod 72 | def from_dict(cls, config_dict: dict): 73 | text_config = TextConfig(**config_dict.get("text", {})) 74 | vision_config = VisionConfig(**config_dict.get("vision", {})) 75 | region_config = RegionConfig(**config_dict.get("region", {})) 76 | tokenizer_config = TokenizerConfig(**config_dict.get("tokenizer", {})) 77 | return cls( 78 | text=text_config, 79 | vision=vision_config, 80 | region=region_config, 81 | tokenizer=tokenizer_config, 82 | ) 83 | 84 | def to_dict(self): 85 | return { 86 | "text": self.text.__dict__, 87 | "vision": self.vision.__dict__, 88 | "region": self.region.__dict__, 89 | "tokenizer": self.tokenizer.__dict__, 90 | } 91 | -------------------------------------------------------------------------------- /moondream/torch/hf_moondream.py: -------------------------------------------------------------------------------- 1 | from transformers import PreTrainedModel, PretrainedConfig 2 | 3 | from .config import MoondreamConfig 4 | from .moondream import MoondreamModel 5 | 6 | # Files sometimes don't get loaded without these... 
7 | from .image_crops import * 8 | from .vision import * 9 | from .text import * 10 | from .region import * 11 | from .utils import * 12 | 13 | 14 | def extract_question(text): 15 | prefix = "\n\nQuestion: " 16 | suffix = "\n\nAnswer:" 17 | 18 | if text.startswith(prefix) and text.endswith(suffix): 19 | return text[len(prefix) : -len(suffix)] 20 | else: 21 | return None 22 | 23 | 24 | class HfConfig(PretrainedConfig): 25 | _auto_class = "AutoConfig" 26 | model_type = "moondream1" 27 | 28 | def __init__(self, **kwargs): 29 | super().__init__(**kwargs) 30 | self.config = {} 31 | 32 | 33 | class HfMoondream(PreTrainedModel): 34 | _auto_class = "AutoModelForCausalLM" 35 | config_class = HfConfig 36 | 37 | def __init__(self, config): 38 | super().__init__(config) 39 | self.model = MoondreamModel( 40 | MoondreamConfig.from_dict(config.config), setup_caches=False 41 | ) 42 | self._is_kv_cache_setup = False 43 | 44 | def _setup_caches(self): 45 | if not self._is_kv_cache_setup: 46 | self.model._setup_caches() 47 | self._is_kv_cache_setup = True 48 | 49 | @property 50 | def encode_image(self): 51 | self._setup_caches() 52 | return self.model.encode_image 53 | 54 | @property 55 | def query(self): 56 | self._setup_caches() 57 | return self.model.query 58 | 59 | @property 60 | def caption(self): 61 | self._setup_caches() 62 | return self.model.caption 63 | 64 | @property 65 | def detect(self): 66 | self._setup_caches() 67 | return self.model.detect 68 | 69 | @property 70 | def point(self): 71 | self._setup_caches() 72 | return self.model.point 73 | 74 | @property 75 | def detect_gaze(self): 76 | self._setup_caches() 77 | return self.model.detect_gaze 78 | 79 | def answer_question( 80 | self, 81 | image_embeds, 82 | question, 83 | tokenizer=None, 84 | chat_history="", 85 | result_queue=None, 86 | max_new_tokens=256, 87 | **kwargs 88 | ): 89 | answer = self.query(image_embeds, question)["answer"].strip() 90 | 91 | if result_queue is not None: 92 | result_queue.put(answer) 93 | 
return answer 94 | 95 | def batch_answer(self, images, prompts, tokenizer=None, **kwargs): 96 | answers = [] 97 | for image, prompt in zip(images, prompts): 98 | answers.append(self.query(image, prompt)["answer"].strip()) 99 | return answers 100 | 101 | def _unsupported_exception(self): 102 | raise NotImplementedError( 103 | "This method is not supported in the latest version of moondream. " 104 | "Consider upgrading to the updated API spec, or alternately pin " 105 | "to 'revision=2024-08-26'." 106 | ) 107 | 108 | def generate(self, image_embeds, prompt, tokenizer, max_new_tokens=128, **kwargs): 109 | """ 110 | Function definition remains unchanged for backwards compatibility. 111 | Be aware that tokenizer, max_new_takens, and kwargs are ignored. 112 | """ 113 | prompt_extracted = extract_question(prompt) 114 | if prompt_extracted is not None: 115 | answer = self.model.query( 116 | image=image_embeds, question=prompt_extracted, stream=False 117 | )["answer"] 118 | else: 119 | image_embeds = self.encode_image(image_embeds) 120 | prompt_tokens = torch.tensor( 121 | [self.model.tokenizer.encode(prompt).ids], 122 | device=self.device, 123 | ) 124 | 125 | def generator(): 126 | for token in self.model._generate_text( 127 | prompt_tokens, 128 | image_embeds.kv_cache, 129 | image_embeds.pos, 130 | max_new_tokens, 131 | ): 132 | yield token 133 | 134 | answer = "".join(list(generator())) 135 | 136 | return [answer] 137 | 138 | def get_input_embeddings(self) -> nn.Embedding: 139 | """ 140 | Lazily wrap the raw parameter `self.model.text.wte` in a real 141 | `nn.Embedding` layer so that HF mix-ins recognise it. The wrapper 142 | **shares** the weight tensor—no copy is made. 
143 | """ 144 | if not hasattr(self, "_input_embeddings"): 145 | self._input_embeddings = nn.Embedding.from_pretrained( 146 | self.model.text.wte, # tensor created in text.py 147 | freeze=True, # set to False if you need it trainable 148 | ) 149 | return self._input_embeddings 150 | 151 | def set_input_embeddings(self, value: Union[nn.Embedding, nn.Module]) -> None: 152 | """ 153 | Lets HF functions (e.g. `resize_token_embeddings`) replace or resize the 154 | embeddings and keeps everything tied to `self.model.text.wte`. 155 | """ 156 | # 1. point the low-level parameter to the new weight matrix 157 | self.model.text.wte = value.weight 158 | # 2. keep a reference for get_input_embeddings() 159 | self._input_embeddings = value 160 | 161 | def input_embeds( 162 | self, 163 | input_ids: Union[torch.LongTensor, list, tuple], 164 | *, 165 | device: torch.device | None = None 166 | ) -> torch.FloatTensor: 167 | """ 168 | Back-compat wrapper that turns token IDs into embeddings. 169 | 170 | Example: 171 | ids = torch.tensor([[1, 2, 3]]) 172 | embeds = model.input_embeds(ids) # (1, 3, hidden_dim) 173 | """ 174 | if not torch.is_tensor(input_ids): 175 | input_ids = torch.as_tensor(input_ids) 176 | if device is not None: 177 | input_ids = input_ids.to(device) 178 | 179 | return self.get_input_embeddings()(input_ids) 180 | -------------------------------------------------------------------------------- /moondream/torch/hf_release.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | 4 | from .weights import load_weights_into_model 5 | from .hf_moondream import HfConfig, HfMoondream 6 | 7 | if __name__ == "__main__": 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--model-name", type=str, default="vikhyatk/moondream-next") 10 | parser.add_argument("--ckpt", type=str, required=True) 11 | args = parser.parse_args() 12 | 13 | config = HfConfig() 14 | model = HfMoondream(config) 15 | 
load_weights_into_model(args.ckpt, model.model) 16 | 17 | model.push_to_hub(args.model_name, config=config) 18 | -------------------------------------------------------------------------------- /moondream/torch/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from dataclasses import dataclass 6 | from typing import Literal 7 | 8 | try: 9 | from torchao import quantize_ 10 | from torchao.quantization import int4_weight_only 11 | except ImportError: 12 | 13 | def quantize_(model, quant_mode): 14 | raise ImportError( 15 | "torchao is not installed. Please install it with `pip install torchao`." 16 | ) 17 | 18 | def int4_weight_only(group_size): 19 | raise ImportError( 20 | "torchao is not installed. Please install it with `pip install torchao`." 21 | ) 22 | 23 | 24 | def gelu_approx(x): 25 | return F.gelu(x, approximate="tanh") 26 | 27 | 28 | @dataclass 29 | class LinearWeights: 30 | weight: torch.Tensor 31 | bias: torch.Tensor 32 | 33 | 34 | def linear(x: torch.Tensor, w: LinearWeights) -> torch.Tensor: 35 | return F.linear(x, w.weight, w.bias) 36 | 37 | 38 | def dequantize_tensor(W_q, scale, zero, orig_shape, dtype=torch.bfloat16): 39 | _step = W_q.shape[0] 40 | W_r = torch.empty([2 * _step, W_q.shape[1]], dtype=dtype, device=W_q.device) 41 | W_r[:_step] = (W_q & 0b11110000) >> 4 42 | W_r[_step:] = W_q & 0b00001111 43 | W_r.sub_(zero).mul_(scale) 44 | return W_r.reshape(orig_shape) 45 | 46 | 47 | class QuantizedLinear(nn.Module): 48 | def __init__( 49 | self, 50 | in_features: int, 51 | out_features: int, 52 | dtype: torch.dtype, 53 | ): 54 | # TODO: Take group_size as an input instead of hardcoding it here. 
55 | super().__init__() 56 | self.in_features = in_features 57 | self.out_features = out_features 58 | self.weight = nn.ParameterDict( 59 | { 60 | "packed": nn.Parameter( 61 | torch.empty( 62 | out_features * in_features // (128 * 2), 128, dtype=torch.uint8 63 | ), 64 | requires_grad=False, 65 | ), 66 | "scale": nn.Parameter( 67 | torch.empty(out_features * in_features // 128, 1), 68 | requires_grad=False, 69 | ), 70 | "zero_point": nn.Parameter( 71 | torch.empty(out_features * in_features // 128, 1), 72 | requires_grad=False, 73 | ), 74 | } 75 | ) 76 | self.bias = nn.Parameter(torch.empty(out_features), requires_grad=False) 77 | self.unpacked = False 78 | 79 | def unpack(self): 80 | if self.unpacked: 81 | return 82 | 83 | self.weight = nn.Parameter( 84 | dequantize_tensor( 85 | self.weight["packed"], 86 | self.weight["scale"], 87 | self.weight["zero_point"], 88 | (self.out_features, self.in_features), 89 | torch.bfloat16, 90 | ) 91 | ) 92 | with torch.device("meta"): 93 | self.linear = nn.Linear( 94 | self.in_features, self.out_features, dtype=torch.bfloat16 95 | ) 96 | self.linear.weight = self.weight 97 | self.linear.bias = nn.Parameter( 98 | self.bias.to(torch.bfloat16), requires_grad=False 99 | ) 100 | 101 | del self.weight, self.bias 102 | quantize_(self, int4_weight_only(group_size=128)) 103 | self.unpacked = True 104 | torch.cuda.empty_cache() 105 | 106 | def forward(self, x: torch.Tensor) -> torch.Tensor: 107 | if not self.unpacked: 108 | self.unpack() 109 | return self.linear(x) 110 | 111 | 112 | @dataclass 113 | class LayerNormWeights: 114 | weight: torch.Tensor 115 | bias: torch.Tensor 116 | 117 | 118 | def layer_norm(x: torch.Tensor, w: LayerNormWeights) -> torch.Tensor: 119 | return F.layer_norm(x, w.bias.shape, w.weight, w.bias) 120 | 121 | 122 | @dataclass 123 | class MLPWeights: 124 | fc1: LinearWeights 125 | fc2: LinearWeights 126 | act: Literal["gelu_approx"] = "gelu_approx" 127 | 128 | 129 | def mlp(x: torch.Tensor, w: MLPWeights) -> 
torch.Tensor: 130 | x = w.fc1(x) 131 | x = gelu_approx(x) 132 | x = w.fc2(x) 133 | return x 134 | 135 | 136 | @dataclass 137 | class AttentionWeights: 138 | qkv: LinearWeights 139 | proj: LinearWeights 140 | 141 | 142 | def attn(x: torch.Tensor, w: AttentionWeights, n_heads: int) -> torch.Tensor: 143 | bsz, q_len, d_model = x.shape 144 | head_dim = d_model // n_heads 145 | 146 | q, k, v = [ 147 | t.view(bsz, q_len, n_heads, head_dim).transpose(1, 2) 148 | for t in linear(x, w.qkv).chunk(3, dim=-1) 149 | ] 150 | out = F.scaled_dot_product_attention(q, k, v) 151 | out = out.transpose(1, 2).reshape(bsz, q_len, d_model) 152 | out = linear(out, w.proj) 153 | return out 154 | -------------------------------------------------------------------------------- /moondream/torch/region.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | from typing import List, Tuple, Union 6 | 7 | from .layers import mlp 8 | 9 | SpatialRefs = List[Union[Tuple[float, float], Tuple[float, float, float, float]]] 10 | 11 | 12 | def fourier_features(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor: 13 | """ 14 | Applies Fourier feature mapping to input tensor x using frequency matrix w. This 15 | projects inputs through sinusoidal functions to create higher dimensional features 16 | that help mitigate spectral bias - the tendency of neural networks to learn 17 | low-frequency functions more easily than high-frequency ones. By explicitly 18 | mapping inputs to higher frequencies through sin/cos transformations, we enable 19 | better learning of fine details and higher frequency patterns. 
20 | 21 | Args: 22 | x: Input tensor to transform 23 | w: Matrix of frequencies for the Fourier features transformation 24 | 25 | Returns: 26 | Concatenated cosine and sine transformed features as a tensor 27 | """ 28 | f = 2 * math.pi * x @ w 29 | return torch.cat([f.cos(), f.sin()], dim=-1) 30 | 31 | 32 | def encode_coordinate(coord: torch.Tensor, w: nn.Module) -> torch.Tensor: 33 | """ 34 | Takes as input a tensor containing a single float coordinate value (x or y) 35 | and encodes it into hidden states for input to the text model. 36 | 37 | Args: 38 | coord: Tensor with single float coordinate value 39 | 40 | Returns: 41 | Encoded hidden states tensor for input to text model 42 | """ 43 | return w.coord_encoder(fourier_features(coord, w.coord_features)) 44 | 45 | 46 | def decode_coordinate(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor: 47 | """ 48 | Takes as input the last hidden state from the text model and outputs a single logit 49 | representing either an x or y coordinate prediction. 50 | 51 | Args: 52 | hidden_state: The final hidden state tensor from the text model. 53 | 54 | Returns: 55 | A single logit representing the predicted coordinate value (x or y) 56 | """ 57 | return mlp(hidden_state, w.coord_decoder) 58 | 59 | 60 | def encode_size(size: torch.Tensor, w: nn.Module) -> torch.Tensor: 61 | """ 62 | Takes a tensor containing width and height values and encodes them into 63 | hidden states for input to the text model. 64 | 65 | Args: 66 | size: Tensor with two floats for width and height 67 | 68 | Returns: 69 | Encoded hidden states tensor for input to text model 70 | """ 71 | return w.size_encoder(fourier_features(size, w.size_features)) 72 | 73 | 74 | def decode_size(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor: 75 | """ 76 | Takes as input the last hidden state from the text model and outputs logits 77 | for 1024 bins representing width and height in log-scale. 
78 | 79 | The bins are distributed according to the formula: 80 | bin = (log2(size) + 10.0) / 10.0 * 1023.0 81 | where size values are clamped to be at least 1/1024. 82 | 83 | To convert from bin back to size: 84 | size = 2^((bin / 1023.0) * 10.0 - 10.0) 85 | 86 | Args: 87 | hidden_state: The final hidden state tensor from the text model. 88 | 89 | Returns: 90 | A tensor containing logits for 1024 bins for width and height. 91 | Shape is (2, 1024) where the first dimension corresponds to width and height. 92 | """ 93 | return mlp(hidden_state, w.size_decoder).view(2, -1) 94 | 95 | 96 | def encode_spatial_refs(spatial_refs: SpatialRefs, w: nn.Module) -> torch.Tensor: 97 | """ 98 | Takes a list of spatial references (points or regions) and encodes them into 99 | hidden states for input to the text model. 100 | 101 | Args: 102 | spatial_refs: List of spatial references (points or boxes) 103 | - Points are represented as normalized (x, y) tuples 104 | - Boxes are represented as normalized (x_min, y_min, x_max, y_max) tuples 105 | 106 | Returns: 107 | {"coords": torch.Tensor, "sizes": Optional[torch.Tensor]} 108 | """ 109 | coords, sizes = [], [] 110 | for ref in spatial_refs: 111 | if len(ref) == 2: 112 | coords.append(ref[0]) 113 | coords.append(ref[1]) 114 | else: 115 | x_c = (ref[0] + ref[2]) / 2 116 | y_c = (ref[1] + ref[3]) / 2 117 | width = ref[2] - ref[0] 118 | height = ref[3] - ref[1] 119 | coords.append(x_c) 120 | coords.append(y_c) 121 | sizes.append([width, height]) 122 | 123 | coords = torch.tensor( 124 | coords, device=w.coord_features.device, dtype=w.coord_features.dtype 125 | ).view(-1, 1) 126 | coords = encode_coordinate(coords, w) 127 | 128 | if sizes: 129 | sizes = torch.tensor( 130 | sizes, device=w.size_features.device, dtype=w.size_features.dtype 131 | ) 132 | sizes = encode_size(sizes, w) 133 | else: 134 | sizes = None 135 | 136 | return {"coords": coords, "sizes": sizes} 137 | 
-------------------------------------------------------------------------------- /moondream/torch/rope.py: -------------------------------------------------------------------------------- 1 | # Ethically sourced from https://github.com/xjdr-alt/entropix 2 | 3 | import torch 4 | 5 | 6 | def precompute_freqs_cis( 7 | dim: int, 8 | end: int, 9 | theta: float = 10000.0, 10 | use_scaled: bool = False, 11 | dtype: torch.dtype = torch.float32, 12 | ) -> torch.Tensor: 13 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=dtype)[: (dim // 2)] / dim)) 14 | t = torch.arange(end, dtype=dtype).unsqueeze(1) 15 | freqs = t * freqs.unsqueeze(0) 16 | freqs = torch.exp(1j * freqs) 17 | return torch.stack([freqs.real, freqs.imag], dim=-1) 18 | 19 | 20 | def apply_rotary_emb( 21 | x: torch.Tensor, 22 | freqs_cis: torch.Tensor, 23 | position_ids: torch.Tensor, 24 | num_heads: int, 25 | rot_dim: int = 32, 26 | interleave: bool = False, 27 | ) -> torch.Tensor: 28 | assert rot_dim == freqs_cis.shape[-2] * 2 29 | assert num_heads == x.shape[1] 30 | 31 | x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:] 32 | 33 | if interleave: 34 | xq_r = x_rot.float().reshape(*x_rot.shape[:-1], -1, 2)[..., 0] 35 | xq_i = x_rot.float().reshape(*x_rot.shape[:-1], -1, 2)[..., 1] 36 | else: 37 | d_q = x_rot.shape[-1] // 2 38 | xq_r, xq_i = x_rot[..., :d_q], x_rot[..., d_q:] 39 | 40 | freqs_cos = freqs_cis[..., 0][position_ids, :].unsqueeze(0).unsqueeze(0) 41 | freqs_sin = freqs_cis[..., 1][position_ids, :].unsqueeze(0).unsqueeze(0) 42 | 43 | # Complex multiplication: (a + bi) * (c + di) = (ac - bd) + (ad + bc)i 44 | xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin 45 | xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos 46 | xq_out = torch.stack((xq_out_r, xq_out_i), dim=-1).flatten(-2) 47 | 48 | return torch.cat([xq_out.to(x.dtype), x_pass], dim=-1) 49 | -------------------------------------------------------------------------------- /moondream/torch/text.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from torch.nn import functional as F 5 | 6 | from .layers import layer_norm, mlp, QuantizedLinear 7 | from .rope import apply_rotary_emb, precompute_freqs_cis 8 | from .config import TextConfig 9 | 10 | 11 | def text_encoder(input_ids: torch.Tensor, w: nn.Module): 12 | return F.embedding(input_ids, w.wte) 13 | 14 | 15 | def attn( 16 | x: torch.Tensor, 17 | w: nn.Module, 18 | freqs_cis: torch.Tensor, 19 | kv_cache: nn.Module, 20 | attn_mask: torch.Tensor, 21 | n_heads: int, 22 | n_kv_heads: int, 23 | position_ids: torch.Tensor, 24 | ): 25 | bsz, q_len, d_model = x.shape 26 | head_dim = d_model // n_heads 27 | 28 | qkv_out = w.qkv(x) # shape: (bsz, q_len, (n_heads + 2*n_kv_heads)*head_dim) 29 | q_dim = n_heads * head_dim 30 | kv_dim = n_kv_heads * head_dim 31 | q, k, v = qkv_out.split([q_dim, kv_dim, kv_dim], dim=-1) 32 | del qkv_out 33 | 34 | q = q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2) 35 | k = k.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) 36 | v = v.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) 37 | 38 | q = apply_rotary_emb(q, freqs_cis, position_ids, n_heads) 39 | k = apply_rotary_emb(k, freqs_cis, position_ids, n_kv_heads) 40 | 41 | if kv_cache is not None: 42 | k, v = kv_cache.update(position_ids, k, v) 43 | 44 | out = F.scaled_dot_product_attention( 45 | q, k, v, attn_mask=attn_mask, enable_gqa=n_heads != n_kv_heads 46 | ) 47 | out = out.transpose(1, 2).reshape(bsz, q_len, d_model) 48 | out = w.proj(out) 49 | return out 50 | 51 | 52 | def _attn( 53 | x: torch.Tensor, 54 | w: torch.Tensor, 55 | freqs_cis: torch.Tensor, 56 | attn_mask: torch.Tensor, 57 | n_heads: int, 58 | n_kv_heads: int, 59 | ): 60 | bsz, q_len, d_model = x.shape 61 | head_dim = d_model // n_heads 62 | pos = 0 63 | 64 | qkv_out = w.qkv(x) # shape: (bsz, q_len, (n_heads + 2*n_kv_heads)*head_dim) 65 | q_dim = n_heads * head_dim 66 | 
kv_dim = n_kv_heads * head_dim 67 | 68 | q = qkv_out[..., :q_dim].view(bsz, q_len, n_heads, head_dim).transpose(1, 2) 69 | k = ( 70 | qkv_out[..., q_dim : q_dim + kv_dim] 71 | .view(bsz, q_len, n_kv_heads, head_dim) 72 | .transpose(1, 2) 73 | ) 74 | v = ( 75 | qkv_out[..., q_dim + kv_dim :] 76 | .view(bsz, q_len, n_kv_heads, head_dim) 77 | .transpose(1, 2) 78 | ) 79 | 80 | position_ids = torch.arange(pos, pos + q_len, dtype=torch.long) 81 | q = apply_rotary_emb(q, freqs_cis, position_ids, n_heads) 82 | k = apply_rotary_emb(k, freqs_cis, position_ids, n_kv_heads) 83 | out = F.scaled_dot_product_attention( 84 | q, k, v, attn_mask=attn_mask, enable_gqa=n_heads != n_kv_heads 85 | ) 86 | out = out.transpose(1, 2).reshape(bsz, q_len, d_model) 87 | out = w.proj(out) 88 | return out 89 | 90 | 91 | def _produce_hidden(inputs_embeds: torch.Tensor, w: nn.Module, config: TextConfig): 92 | hidden_BTC = inputs_embeds 93 | 94 | bsz, q_len, d_model = inputs_embeds.shape 95 | attn_mask = torch.zeros(q_len, q_len) 96 | attn_mask[:730, :730] = 1 97 | for i in range(730, q_len): 98 | attn_mask[i, : i + 1] = 1 99 | attn_mask = attn_mask.to(dtype=torch.bool) 100 | 101 | for i, block in enumerate(w.blocks): 102 | l_in = layer_norm(hidden_BTC, block.ln) 103 | l_attn = _attn( 104 | x=l_in, 105 | w=block.attn, 106 | freqs_cis=w.freqs_cis, 107 | attn_mask=attn_mask, 108 | n_heads=config.n_heads, 109 | n_kv_heads=config.n_kv_heads, 110 | ) 111 | l_mlp = mlp(l_in, block.mlp) 112 | hidden_BTC = hidden_BTC + l_attn + l_mlp 113 | 114 | return hidden_BTC 115 | 116 | 117 | def text_decoder( 118 | x: torch.Tensor, 119 | w: nn.Module, 120 | attn_mask: torch.Tensor, 121 | position_ids: torch.Tensor, 122 | config: TextConfig, 123 | ): 124 | for block in w.blocks: 125 | l_in = layer_norm(x, block.ln) 126 | l_attn = attn( 127 | l_in, 128 | block.attn, 129 | freqs_cis=w.freqs_cis, 130 | kv_cache=block.kv_cache, 131 | attn_mask=attn_mask, 132 | n_heads=config.n_heads, 133 | n_kv_heads=config.n_kv_heads, 
134 | position_ids=position_ids, 135 | ) 136 | l_mlp = mlp(l_in, block.mlp) 137 | x = x + l_attn + l_mlp 138 | 139 | return x 140 | 141 | 142 | def lm_head(hidden_BTC: torch.Tensor, w: nn.Module): 143 | hidden_BC = hidden_BTC[:, -1, :] 144 | hidden_BC = layer_norm(hidden_BC, w.post_ln) 145 | logits = w.lm_head(hidden_BC) 146 | return logits 147 | 148 | 149 | def _lm_head(hidden_BTC: torch.Tensor, w: nn.Module): 150 | hidden_BTC = layer_norm(hidden_BTC, w.post_ln) 151 | logits = w.lm_head(hidden_BTC) 152 | return logits 153 | 154 | 155 | def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module: 156 | qkv_dim = int(config.dim * (1 + 2 * config.n_kv_heads / config.n_heads)) 157 | linear_cls = QuantizedLinear if config.group_size is not None else nn.Linear 158 | 159 | text = nn.ModuleDict( 160 | { 161 | "blocks": nn.ModuleList( 162 | [ 163 | nn.ModuleDict( 164 | { 165 | "ln": nn.LayerNorm(config.dim, dtype=dtype), 166 | "attn": nn.ModuleDict( 167 | { 168 | "qkv": linear_cls(config.dim, qkv_dim, dtype=dtype), 169 | "proj": linear_cls( 170 | config.dim, config.dim, dtype=dtype 171 | ), 172 | } 173 | ), 174 | "mlp": nn.ModuleDict( 175 | { 176 | "fc1": linear_cls( 177 | config.dim, config.ff_dim, dtype=dtype 178 | ), 179 | "fc2": linear_cls( 180 | config.ff_dim, config.dim, dtype=dtype 181 | ), 182 | } 183 | ), 184 | } 185 | ) 186 | for _ in range(config.n_layers) 187 | ] 188 | ), 189 | "post_ln": nn.LayerNorm(config.dim, dtype=dtype), 190 | "lm_head": nn.Linear(config.dim, config.vocab_size, dtype=dtype), 191 | } 192 | ) 193 | text.wte = nn.Parameter(torch.empty(config.vocab_size, config.dim, dtype=dtype)) 194 | text.register_buffer( 195 | "freqs_cis", 196 | precompute_freqs_cis(config.dim // (2 * config.n_heads), config.max_context), 197 | persistent=False, 198 | ) 199 | 200 | return text 201 | -------------------------------------------------------------------------------- /moondream/torch/utils.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def remove_outlier_points(points_tuples, k_nearest=2, threshold=2.0): 5 | """ 6 | Robust outlier detection for list of (x,y) tuples. 7 | Only requires numpy. 8 | 9 | Args: 10 | points_tuples: list of (x,y) tuples 11 | k_nearest: number of neighbors to consider 12 | threshold: multiplier for median distance 13 | 14 | Returns: 15 | list: filtered list of (x,y) tuples with outliers removed 16 | list: list of booleans indicating which points were kept (True = kept) 17 | """ 18 | points = np.array(points_tuples) 19 | n_points = len(points) 20 | 21 | # Calculate pairwise distances manually 22 | dist_matrix = np.zeros((n_points, n_points)) 23 | for i in range(n_points): 24 | for j in range(i + 1, n_points): 25 | # Euclidean distance between points i and j 26 | dist = np.sqrt(np.sum((points[i] - points[j]) ** 2)) 27 | dist_matrix[i, j] = dist 28 | dist_matrix[j, i] = dist 29 | 30 | # Get k nearest neighbors' distances 31 | k = min(k_nearest, n_points - 1) 32 | neighbor_distances = np.partition(dist_matrix, k, axis=1)[:, :k] 33 | avg_neighbor_dist = np.mean(neighbor_distances, axis=1) 34 | 35 | # Calculate mask using median distance 36 | median_dist = np.median(avg_neighbor_dist) 37 | mask = avg_neighbor_dist <= threshold * median_dist 38 | 39 | # Return filtered tuples and mask 40 | filtered_tuples = [t for t, m in zip(points_tuples, mask) if m] 41 | return filtered_tuples 42 | -------------------------------------------------------------------------------- /moondream/torch/vision.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from typing import Union, Tuple 7 | from PIL import Image 8 | 9 | from .layers import attn, layer_norm, mlp 10 | from .image_crops import overlap_crop_image 11 | from .config import 
VisionConfig 12 | 13 | if torch.backends.mps.is_available(): 14 | # Non-divisible input sizes are not implemented on MPS device yet, so pool on CPU and return the result on the input's original device. 15 | # https://github.com/pytorch/pytorch/issues/96056 16 | def adaptive_avg_pool2d(input, output_size): 17 | return F.adaptive_avg_pool2d(input.to("cpu"), output_size).to(input.device) 18 | 19 | else: 20 | adaptive_avg_pool2d = F.adaptive_avg_pool2d 21 | 22 | DeviceLike = Union[str, torch.device, int] 23 | 24 | 25 | def prepare_crops( 26 | image: Image.Image, config: VisionConfig, device: DeviceLike 27 | ) -> Tuple[torch.Tensor, Tuple[int, int]]: 28 | np_image = np.array(image.convert("RGB")) 29 | overlap_crops = overlap_crop_image( 30 | np_image, max_crops=config.max_crops, overlap_margin=config.overlap_margin 31 | ) 32 | all_crops = overlap_crops["crops"] 33 | all_crops = np.transpose(all_crops, (0, 3, 1, 2)) 34 | all_crops = ( 35 | torch.from_numpy(all_crops) 36 | .to(device=device, dtype=torch.bfloat16) 37 | .div_(255.0) 38 | .sub_(0.5) 39 | .div_(0.5) 40 | ) 41 | return all_crops, overlap_crops["tiling"] 42 | 43 | 44 | def create_patches(x, patch_size): 45 | # Original shape: [B, C, H, W] 46 | B, C, H, W = x.shape 47 | P1 = P2 = patch_size 48 | 49 | # Step 1: Split H and W dimensions into patches 50 | # [B, C, H/P1, P1, W/P2, P2] 51 | x = x.reshape(B, C, H // P1, P1, W // P2, P2) 52 | 53 | # Step 2: Rearrange dimensions to match target shape 54 | # [B, H/P1, W/P2, C, P1, P2] 55 | x = x.permute(0, 2, 4, 1, 3, 5) 56 | 57 | # Step 3: Combine dimensions to get final shape 58 | # [B, (H/P1)*(W/P2), C*P1*P2] 59 | x = x.reshape(B, (H // P1) * (W // P2), C * P1 * P2) 60 | 61 | return x 62 | 63 | 64 | def vision_encoder(input_BCHW: torch.Tensor, w: nn.Module, config: VisionConfig): 65 | x = create_patches(input_BCHW, config.enc_patch_size) 66 | 67 | x = w.patch_emb(x) 68 | x = x + w.pos_emb 69 | for block in w.blocks: 70 | x = x + attn(layer_norm(x, block.ln1), block.attn, n_heads=config.enc_n_heads) 71 | x = x + mlp(layer_norm(x, block.ln2),
block.mlp) 72 | x = layer_norm(x, w.post_ln) 73 | 74 | return x 75 | 76 | 77 | def vision_projection( 78 | global_features: torch.Tensor, 79 | reconstructed: torch.Tensor, 80 | w: nn.Module, 81 | config: VisionConfig, 82 | ): 83 | # Pool to the encoder's per-crop token grid (the same grid_size that sizes pos_emb in build_vision_model). 84 | grid_size = config.crop_size // config.enc_patch_size 85 | reconstructed = adaptive_avg_pool2d(reconstructed.permute(2, 0, 1), output_size=(grid_size, grid_size)) 86 | # (C, g, g) -> (g * g, C) so rows line up with global_features for the concat below 87 | reconstructed = reconstructed.permute(1, 2, 0).view(grid_size * grid_size, config.enc_dim) 88 | final_features = torch.cat([global_features, reconstructed], dim=-1) 89 | return mlp(final_features, w.proj_mlp) 90 | 91 | 92 | def build_vision_model(config: VisionConfig, dtype: torch.dtype): 93 | patch_dim = config.enc_patch_size * config.enc_patch_size * config.in_channels 94 | grid_size = config.crop_size // config.enc_patch_size 95 | num_patches = grid_size * grid_size 96 | 97 | vision = nn.ModuleDict( 98 | { 99 | "patch_emb": nn.Linear(patch_dim, config.enc_dim, dtype=dtype), 100 | "blocks": nn.ModuleList( 101 | [ 102 | nn.ModuleDict( 103 | { 104 | "ln1": nn.LayerNorm(config.enc_dim, dtype=dtype), 105 | "attn": nn.ModuleDict( 106 | { 107 | "qkv": nn.Linear( 108 | config.enc_dim, 3 * config.enc_dim, dtype=dtype 109 | ), 110 | "proj": nn.Linear( 111 | config.enc_dim, config.enc_dim, dtype=dtype 112 | ), 113 | } 114 | ), 115 | "ln2": nn.LayerNorm(config.enc_dim, dtype=dtype), 116 | "mlp": nn.ModuleDict( 117 | { 118 | "fc1": nn.Linear( 119 | config.enc_dim, config.enc_ff_dim, dtype=dtype 120 | ), 121 | "fc2": nn.Linear( 122 | config.enc_ff_dim, config.enc_dim, dtype=dtype 123 | ), 124 | } 125 | ), 126 | } 127 | ) 128 | for _ in range(config.enc_n_layers) 129 | ] 130 | ), 131 | "post_ln": nn.LayerNorm(config.enc_dim, dtype=dtype), 132 | "proj_mlp": nn.ModuleDict( 133 | { 134 | "fc1": nn.Linear( 135 | config.enc_dim * 2, config.proj_inner_dim, dtype=dtype 136 | ), 137 | "fc2": nn.Linear( 138 | config.proj_inner_dim, config.proj_out_dim, dtype=dtype 139 | ), 140 | } 141 | ), 142 | } 143 |
) 144 | vision.pos_emb = nn.Parameter( 145 | torch.zeros(1, num_patches, config.enc_dim, dtype=dtype) 146 | ) 147 | return vision 148 | -------------------------------------------------------------------------------- /recipes/gaze-detection-video/.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.so 6 | .Python 7 | env/ 8 | build/ 9 | develop-eggs/ 10 | dist/ 11 | downloads/ 12 | eggs/ 13 | .eggs/ 14 | lib/ 15 | lib64/ 16 | parts/ 17 | sdist/ 18 | var/ 19 | wheels/ 20 | *.egg-info/ 21 | .installed.cfg 22 | *.egg 23 | 24 | # Virtual Environment 25 | venv/ 26 | ENV/ 27 | 28 | # IDE 29 | .idea/ 30 | .vscode/ 31 | *.swp 32 | *.swo 33 | 34 | # Project specific 35 | # input/* 36 | # !input/.gitkeep 37 | # output/* 38 | # !output/.gitkeep 39 | # temp/* 40 | # !temp/.gitkeep 41 | 42 | # Model files 43 | *.pt 44 | *.pth 45 | *.ckpt 46 | 47 | # Logs 48 | *.log 49 | 50 | # OS specific 51 | .DS_Store 52 | Thumbs.db 53 | -------------------------------------------------------------------------------- /recipes/gaze-detection-video/README.md: -------------------------------------------------------------------------------- 1 | # Gaze Detection Video Processor 2 | 3 | > **⚠️ IMPORTANT:** This project currently uses Moondream 2B (2025-01-09 release) via the Hugging Face Transformers library. We will migrate to the official Moondream client 4 | > libraries once they become available for this version.
5 | 6 | ## Table of Contents 7 | 8 | - [Overview](#overview) 9 | - [Sample Output](#sample-output) 10 | - [Features](#features) 11 | - [Prerequisites](#prerequisites) 12 | - [Installation](#installation) 13 | - [Linux/macOS Installation](#linuxmacos-installation) 14 | - [Windows Installation](#windows-installation) 15 | - [Usage](#usage) 16 | - [Output](#output) 17 | - [Troubleshooting](#troubleshooting) 18 | - [Performance Notes](#performance-notes) 19 | - [Dependencies](#dependencies) 20 | - [Model Details](#model-details) 21 | - [License](#license) 22 | 23 | ## Overview 24 | 25 | This project uses the Moondream 2B model to detect faces and their gaze directions in videos. It processes videos frame by frame, visualizing face detections and gaze directions. 26 | 27 | ## Sample Output 28 | 29 | | Input Video | Processed Output | 30 | | :-----------------------------------: | :-----------------------------------------: | 31 | | ![Input Video](https://github.com/parsakhaz/gaze-detection-video/blob/master/gif-input-sample.gif?raw=true) | ![Processed Output](https://github.com/parsakhaz/gaze-detection-video/blob/master/gif-output-sample.gif?raw=true) | 32 | 33 | ## Features 34 | 35 | - Face detection in video frames 36 | - Gaze direction tracking 37 | - Real-time visualization with: 38 | - Colored bounding boxes for faces 39 | - Gradient lines showing gaze direction 40 | - Gaze target points 41 | - Supports multiple faces per frame 42 | - Processes all common video formats (.mp4, .avi, .mov, .mkv) 43 | - Uses Moondream 2 (2025-01-09 release) via Hugging Face Transformers 44 | - Note: Will be migrated to official client libraries in future updates 45 | - No authentication required 46 | 47 | ## Prerequisites 48 | 49 | 1. Python 3.8 or later 50 | 2. CUDA-capable GPU recommended (but CPU mode works too) 51 | 3. FFmpeg installed on your system 52 | 53 | ## Installation 54 | 55 | ### Linux/macOS Installation 56 | 57 | 1. 
Install system dependencies: 58 | 59 | ```bash 60 | # Ubuntu/Debian 61 | sudo apt-get update && sudo apt-get install -y libvips42 libvips-dev ffmpeg 62 | 63 | # CentOS/RHEL 64 | sudo yum install vips vips-devel ffmpeg 65 | 66 | # macOS 67 | brew install vips ffmpeg 68 | ``` 69 | 70 | 2. Clone and set up the project: 71 | ```bash 72 | git clone https://github.com/vikhyat/moondream.git 73 | cd moondream/recipes/gaze-detection-video 74 | python3 -m venv venv 75 | source venv/bin/activate 76 | pip install -r requirements.txt 77 | ``` 78 | 79 | ### Windows Installation 80 | 81 | Windows setup requires a few additional steps for proper GPU support and libvips installation. 82 | 83 | 1. Clone the repository: 84 | 85 | ```bash 86 | git clone https://github.com/vikhyat/moondream.git 87 | cd moondream/recipes/gaze-detection-video 88 | ``` 89 | 90 | 2. Create and activate a virtual environment: 91 | 92 | ```bash 93 | python -m venv venv 94 | .\venv\Scripts\activate 95 | ``` 96 | 97 | 3. Install PyTorch with CUDA support: 98 | 99 | ```bash 100 | # For NVIDIA GPUs 101 | pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 102 | ``` 103 | 104 | 4. Install libvips: Download the appropriate version based on your system architecture: 105 | 106 | | Architecture | VIPS Version to Download | 107 | | ------------ | ------------------------ | 108 | | 32-bit x86 | vips-dev-w32-all-8.16.0.zip | 109 | | 64-bit x64 | vips-dev-w64-all-8.16.0.zip | 110 | 111 | - Extract the ZIP file 112 | - Copy all DLL files from `vips-dev-8.16\bin` to either: 113 | - Your project's root directory (easier) OR 114 | - `C:\Windows\System32` (requires admin privileges) 115 | - Add to PATH: 116 | 1. Open System Properties → Advanced → Environment Variables 117 | 2. Under System Variables, find PATH 118 | 3. Add the full path to the `vips-dev-8.16\bin` directory 119 | 120 | 5.
Install FFmpeg: 121 | 122 | - Download from https://ffmpeg.org/download.html#build-windows 123 | - Extract and add the `bin` folder to your system PATH (similar to step 4) or to the project root directory 124 | 125 | 6. Install other dependencies: 126 | ```bash 127 | pip install -r requirements.txt 128 | ``` 129 | 130 | ## Usage 131 | 132 | 1. Place your input videos in the `input` directory 133 | 134 | - Supported formats: .mp4, .avi, .mov, .mkv 135 | - The directory will be created automatically if it doesn't exist 136 | 137 | 2. Run the script: 138 | 139 | ```bash 140 | python gaze-detection-video.py 141 | ``` 142 | 143 | 3. The script will: 144 | - Process all videos in the input directory 145 | - Show progress bars for each video 146 | - Save processed videos to the `output` directory with prefix 'processed\_' 147 | 148 | ## Output 149 | 150 | - Processed videos are saved as `output/processed_[original_name].[ext]` 151 | - Each frame in the output video shows: 152 | - Colored boxes around detected faces 153 | - Lines indicating gaze direction 154 | - Points showing where each person is looking 155 | 156 | ## Troubleshooting 157 | 158 | 1. CUDA/GPU Issues: 159 | 160 | - Ensure you have CUDA installed for GPU support 161 | - The script will automatically fall back to CPU if no GPU is available 162 | 163 | 2. Memory Issues: 164 | 165 | - If processing large videos, ensure you have enough RAM 166 | - Consider reducing video resolution if needed 167 | 168 | 3. libvips Errors: 169 | 170 | - Make sure libvips is properly installed for your OS 171 | - Check system PATH includes libvips 172 | 173 | 4. 
Video Format Issues: 174 | - Ensure FFmpeg is installed and in your system PATH 175 | - Try converting problematic videos to MP4 format 176 | 177 | ## Performance Notes 178 | 179 | - GPU processing is significantly faster than CPU 180 | - Processing time depends on: 181 | - Video resolution 182 | - Number of faces per frame 183 | - Frame rate 184 | - Available computing power 185 | 186 | ## Dependencies 187 | 188 | - transformers (for Moondream 2 model access) 189 | - torch 190 | - opencv-python 191 | - pillow 192 | - matplotlib 193 | - numpy 194 | - tqdm 195 | - pyvips 196 | - accelerate 197 | - einops 198 | 199 | ## Model Details 200 | 201 | > **⚠️ IMPORTANT:** This project currently uses Moondream 2 (2025-01-09 release) via the Hugging Face Transformers library. We will migrate to the official Moondream client 202 | > libraries once they become available for this version. 203 | 204 | The model is loaded using: 205 | -------------------------------------------------------------------------------- /recipes/gaze-detection-video/input/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vikhyat/moondream/ce59395b386eb956e19bc1cba5f97deb9befcd05/recipes/gaze-detection-video/input/.gitkeep -------------------------------------------------------------------------------- /recipes/gaze-detection-video/output/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vikhyat/moondream/ce59395b386eb956e19bc1cba5f97deb9befcd05/recipes/gaze-detection-video/output/.gitkeep -------------------------------------------------------------------------------- /recipes/gaze-detection-video/requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=2.0.0 2 | transformers>=4.36.0 3 | opencv-python>=4.8.0 4 | pillow>=10.0.0 5 | matplotlib>=3.7.0 6 | numpy>=1.24.0 7 | tqdm>=4.65.0 8 | pyvips 9 | 
accelerate>=0.26.0 10 | einops -------------------------------------------------------------------------------- /recipes/gaze-detection-video/temp/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vikhyat/moondream/ce59395b386eb956e19bc1cba5f97deb9befcd05/recipes/gaze-detection-video/temp/.gitkeep -------------------------------------------------------------------------------- /recipes/promptable-content-moderation/.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.so 6 | .Python 7 | build/ 8 | develop-eggs/ 9 | dist/ 10 | downloads/ 11 | eggs/ 12 | .eggs/ 13 | lib/ 14 | lib64/ 15 | parts/ 16 | sdist/ 17 | var/ 18 | wheels/ 19 | *.egg-info/ 20 | .installed.cfg 21 | *.egg 22 | *.dll 23 | 24 | # Virtual Environment 25 | venv/ 26 | env/ 27 | ENV/ 28 | .venv/ 29 | 30 | # IDE 31 | .idea/ 32 | .vscode/ 33 | *.swp 34 | *.swo 35 | 36 | # Project specific 37 | inputs/* 38 | outputs/* 39 | !inputs/.gitkeep 40 | !outputs/.gitkeep 41 | inputs/ 42 | outputs/ 43 | 44 | # Model files 45 | *.pth 46 | *.onnx 47 | *.pt 48 | 49 | # Logs 50 | *.log 51 | 52 | certificate.pem -------------------------------------------------------------------------------- /recipes/promptable-content-moderation/README.md: -------------------------------------------------------------------------------- 1 | # Promptable Content Moderation with Moondream 2 | 3 | Welcome to the future of content moderation with Moondream 2B, a powerful and lightweight vision-language model that enables detection and moderation of video content using natural language prompts. 
4 | 5 | [Try it now.](https://huggingface.co/spaces/moondream/content-moderation) 6 | 7 | ## Features 8 | 9 | - Content moderation through natural language prompts 10 | - Multiple visualization styles 11 | - Intelligent scene detection and tracking: 12 | - DeepSORT tracking with scene-aware reset 13 | - Persistent moderation across frames 14 | - Smart tracker reset at scene boundaries 15 | - Optional grid-based detection for improved accuracy on complex scenes 16 | - Frame-by-frame processing with IoU-based merging 17 | - Web-compatible output format 18 | - Test mode (process only the first X seconds) 19 | - Advanced moderation analysis with multiple visualization plots 20 | 21 | ## Examples 22 | 23 | | Prompt | Output | 24 | |--------|-----------------| 25 | | "white cigarette" | ![Demo](https://github.com/parsakhaz/promptable-content-moderation/raw/main/examples/clip-cig.gif) | 26 | | "gun" | ![Demo](https://github.com/parsakhaz/promptable-content-moderation/raw/main/examples/clip-gu.gif) | 27 | | "confederate flag" | ![Demo](https://github.com/parsakhaz/promptable-content-moderation/raw/main/examples/clip-conflag.gif) | 28 | 29 | ## Requirements 30 | 31 | ### Python Dependencies 32 | 33 | For Windows users, before installing other requirements, first install PyTorch with CUDA support: 34 | 35 | ```bash 36 | pip install torch==2.5.1+cu121 torchvision==0.20.1+cu121 --index-url https://download.pytorch.org/whl/cu121 37 | ``` 38 | 39 | Then install the remaining dependencies: 40 | 41 | ```bash 42 | pip install -r requirements.txt 43 | ``` 44 | 45 | ### System Requirements 46 | 47 | - FFmpeg (required for video processing) 48 | - libvips (required for image processing) 49 | 50 | Installation by platform: 51 | 52 | - Ubuntu/Debian: `sudo apt-get install ffmpeg libvips` 53 | - macOS: `brew install ffmpeg vips` 54 | - Windows: 55 | - Download FFmpeg from [ffmpeg.org](https://ffmpeg.org/download.html) 56 | - Follow [libvips Windows installation
guide](https://docs.moondream.ai/quick-start) 57 | 58 | ## Installation 59 | 60 | 1. Clone this repository and create a new virtual environment: 61 | 62 | ```bash 63 | git clone https://github.com/vikhyat/moondream.git && cd moondream/recipes/promptable-content-moderation 64 | python -m venv .venv 65 | source .venv/bin/activate # On Windows: .venv\Scripts\activate 66 | ``` 67 | 68 | 2. Install Python dependencies: 69 | 70 | ```bash 71 | pip install -r requirements.txt 72 | ``` 73 | 74 | 3. Install ffmpeg and libvips: 75 | - On Ubuntu/Debian: `sudo apt-get install ffmpeg libvips` 76 | - On macOS: `brew install ffmpeg vips` 77 | - On Windows: Download from [ffmpeg.org](https://ffmpeg.org/download.html) 78 | 79 | > Downloading libvips for Windows requires some additional steps, see [here](https://docs.moondream.ai/quick-start) 80 | 81 | ## Usage 82 | 83 | The easiest way to use this tool is through its web interface, which provides a user-friendly experience for video content moderation. 84 | 85 | ### Web Interface 86 | 87 | 1. Start the web interface: 88 | 89 | ```bash 90 | python app.py 91 | ``` 92 | 93 | 2. Open the provided URL in your browser (typically http://localhost:7860) 94 | 95 | 3. Use the interface to: 96 | - Upload your video file 97 | - Specify content to moderate (e.g., "face", "cigarette", "gun") 98 | - Choose redaction style (default: obfuscated-pixel) 99 | - OPTIONAL: Configure advanced settings 100 | - Processing speed/quality 101 | - Grid size for detection 102 | - Test mode for quick validation (default: on, 3 seconds) 103 | - Process the video and download results 104 | - Analyze detection patterns with visualization tools 105 | 106 | ## Output Files 107 | 108 | The tool generates two types of output files in the `outputs` directory: 109 | 110 | 1. Processed Videos: 111 | - Format: `[style]_[content_type]_[original_filename].mp4` 112 | - Example: `censor_inappropriate_video.mp4` 113 | 114 | 2.
Detection Data: 115 | - Format: `[style]_[content_type]_[original_filename]_detections.json` 116 | - Contains frame-by-frame detection information 117 | - Used for visualization and analysis 118 | 119 | ## Technical Details 120 | 121 | ### Scene Detection and Tracking 122 | 123 | The tool uses advanced scene detection and object tracking: 124 | 125 | 1. Scene Detection: 126 | - Powered by PySceneDetect's ContentDetector 127 | - Automatically identifies scene changes in videos 128 | - Configurable detection threshold (default: 30.0) 129 | - Helps maintain tracking accuracy across scene boundaries 130 | 131 | 2. Object Tracking: 132 | - DeepSORT tracking for consistent object identification 133 | - Automatic tracker reset at scene changes 134 | - Maintains object identity within scenes 135 | - Prevents tracking errors across scene boundaries 136 | 137 | 3. Integration Benefits: 138 | - More accurate object tracking 139 | - Better handling of scene transitions 140 | - Reduced false positives in tracking 141 | - Improved tracking consistency 142 | 143 | ## Best Practices 144 | 145 | - Use test mode for initial configuration 146 | - Enable grid-based detection for complex scenes 147 | - Choose appropriate redaction style based on content type: 148 | - Censor: Complete content blocking 149 | - Blur styles: Less intrusive moderation 150 | - Bounding Box: Content review and analysis 151 | - Monitor system resources during processing 152 | - Use appropriate processing quality settings based on your needs 153 | 154 | ## Notes 155 | 156 | - Processing time depends on video length, resolution, GPU availability, and chosen settings 157 | - GPU is strongly recommended for faster processing 158 | - Grid-based detection increases accuracy but requires more processing time (each grid cell is processed independently) 159 | - Test mode processes only first X seconds (default: 3 seconds) for quick validation 160 | 
-------------------------------------------------------------------------------- /recipes/promptable-content-moderation/deep_sort_integration.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from deep_sort_realtime.deepsort_tracker import DeepSort 4 | from datetime import datetime 5 | 6 | 7 | class DeepSORTTracker: 8 | def __init__(self, max_age=5): 9 | """Initialize DeepSORT tracker.""" 10 | self.max_age = max_age 11 | self.tracker = self._create_tracker() 12 | 13 | def _create_tracker(self): 14 | """Create a new instance of DeepSort tracker.""" 15 | return DeepSort( 16 | max_age=self.max_age, 17 | embedder="mobilenet", # Using default MobileNetV2 embedder 18 | today=datetime.now().date(), # For track naming and daily ID reset 19 | ) 20 | 21 | def reset(self): 22 | """Reset the tracker state by creating a new instance.""" 23 | print("Resetting DeepSORT tracker...") 24 | self.tracker = self._create_tracker() 25 | 26 | def update(self, frame, detections): 27 | """Update tracking with new detections.
28 | 29 | Args: 30 | frame: Current video frame (numpy array) 31 | detections: List of (box, keyword) tuples where box is [x1, y1, x2, y2] normalized 32 | 33 | Returns: 34 | List of (box, keyword, track_id) tuples 35 | """ 36 | if not detections: 37 | return [] # NOTE(review): update_tracks() is not called on frames without detections, so existing tracks are not stepped toward max_age here — confirm this is intended 38 | 39 | height, width = frame.shape[:2] 40 | 41 | # Convert normalized coordinates to absolute and format detections 42 | detection_list = [] 43 | for box, keyword in detections: 44 | x1 = int(box[0] * width) 45 | y1 = int(box[1] * height) 46 | x2 = int(box[2] * width) 47 | y2 = int(box[3] * height) 48 | w = x2 - x1 49 | h = y2 - y1 50 | 51 | # Format: ([left,top,w,h], confidence, detection_class) 52 | detection_list.append(([x1, y1, w, h], 1.0, keyword)) 53 | 54 | # Update tracker 55 | tracks = self.tracker.update_tracks(detection_list, frame=frame) 56 | 57 | # Convert back to normalized coordinates with track IDs 58 | tracked_objects = [] 59 | for track in tracks: 60 | if not track.is_confirmed(): 61 | continue 62 | 63 | ltrb = track.to_ltrb() # Get [left,top,right,bottom] format 64 | x1, y1, x2, y2 = ltrb 65 | 66 | # Normalize coordinates 67 | x1 = max(0.0, min(1.0, x1 / width)) 68 | y1 = max(0.0, min(1.0, y1 / height)) 69 | x2 = max(0.0, min(1.0, x2 / width)) 70 | y2 = max(0.0, min(1.0, y2 / height)) 71 | 72 | tracked_objects.append(([x1, y1, x2, y2], track.det_class, track.track_id)) 73 | 74 | return tracked_objects 75 | -------------------------------------------------------------------------------- /recipes/promptable-content-moderation/packages.txt: -------------------------------------------------------------------------------- 1 | libvips 2 | ffmpeg -------------------------------------------------------------------------------- /recipes/promptable-content-moderation/persistence.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | 5 | def save_detection_data(data, output_file): 6 | """ 7 | Saves the detection data to a JSON
file. 8 | 9 | Args: 10 | data (dict): The complete detection data structure. 11 | output_file (str): Path to the output JSON file. 12 | """ 13 | try: 14 | # Create the parent directory if needed; a bare filename has no directory part, so fall back to the CWD 15 | os.makedirs(os.path.dirname(output_file) or ".", exist_ok=True) 16 | 17 | with open(output_file, "w") as f: 18 | json.dump(data, f, indent=4) 19 | print(f"Detection data saved to {output_file}") 20 | return True 21 | except Exception as e: 22 | print(f"Error saving data: {str(e)}") 23 | return False 24 | 25 | 26 | def load_detection_data(input_file): 27 | """ 28 | Loads the detection data from a JSON file. 29 | 30 | Args: 31 | input_file (str): Path to the JSON file. 32 | 33 | Returns: 34 | dict: The loaded detection data, or None if there was an error. 35 | """ 36 | try: 37 | with open(input_file, "r") as f: 38 | return json.load(f) 39 | except Exception as e: 40 | print(f"Error loading data: {str(e)}") 41 | return None 42 | -------------------------------------------------------------------------------- /recipes/promptable-content-moderation/requirements.txt: -------------------------------------------------------------------------------- 1 | gradio>=4.0.0 2 | torch>=2.0.0 3 | # if on windows: pip install torch==2.5.1+cu121 torchvision==0.20.1+cu121 --index-url https://download.pytorch.org/whl/cu121 4 | transformers>=4.36.0 5 | opencv-python>=4.8.0 6 | pillow>=10.0.0 7 | numpy>=1.24.0 8 | tqdm>=4.66.0 9 | ffmpeg-python 10 | einops 11 | pyvips-binary 12 | pyvips 13 | accelerate 14 | # for spaces 15 | --extra-index-url https://download.pytorch.org/whl/cu113 16 | spaces 17 | # SAM dependencies 18 | torchvision>=0.20.1 19 | matplotlib>=3.7.0 20 | pandas>=2.0.0 21 | plotly 22 | # DeepSORT dependencies 23 | deep-sort-realtime>=1.3.2 24 | scikit-learn # Required for deep-sort-realtime 25 | # Scene detection dependencies (for intelligent scene-aware tracking) 26 | scenedetect[opencv]>=0.6.2 # Provides scene change detection capabilities
-------------------------------------------------------------------------------- /recipes/promptable-content-moderation/visualization.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import matplotlib.pyplot as plt 3 | from persistence import load_detection_data 4 | import argparse 5 | 6 | 7 | def visualize_detections(json_path): 8 | """ 9 | Visualize detection data from a JSON file. 10 | 11 | Args: 12 | json_path (str): Path to the JSON file containing detection data. 13 | """ 14 | # Load the persisted JSON data 15 | data = load_detection_data(json_path) 16 | if not data: 17 | return 18 | 19 | # Convert the frame detections to a DataFrame 20 | rows = [] 21 | for frame_data in data["frame_detections"]: 22 | frame = frame_data["frame"] 23 | timestamp = frame_data["timestamp"] 24 | for obj in frame_data["objects"]: 25 | rows.append( 26 | { 27 | "frame": frame, 28 | "timestamp": timestamp, 29 | "keyword": obj["keyword"], 30 | "x1": obj["bbox"][0], 31 | "y1": obj["bbox"][1], 32 | "x2": obj["bbox"][2], 33 | "y2": obj["bbox"][3], 34 | "area": (obj["bbox"][2] - obj["bbox"][0]) 35 | * (obj["bbox"][3] - obj["bbox"][1]), 36 | } 37 | ) 38 | 39 | if not rows: 40 | print("No detections found in the data") 41 | return 42 | 43 | df = pd.DataFrame(rows) 44 | 45 | # Create a figure with multiple subplots 46 | plt.figure(figsize=(15, 10)) 47 | 48 | # Plot 1: Number of detections per frame 49 | plt.subplot(2, 2, 1) 50 | detections_per_frame = df.groupby("frame").size() 51 | plt.plot(detections_per_frame.index, detections_per_frame.values) 52 | plt.xlabel("Frame") 53 | plt.ylabel("Number of Detections") 54 | plt.title("Detections Per Frame") 55 | 56 | # Plot 2: Distribution of detection areas 57 | plt.subplot(2, 2, 2) 58 | df["area"].hist(bins=30) 59 | plt.xlabel("Detection Area (normalized)") 60 | plt.ylabel("Count") 61 | plt.title("Distribution of Detection Areas") 62 | 63 | # Plot 3: Average detection area over
time 64 | plt.subplot(2, 2, 3) 65 | avg_area = df.groupby("frame")["area"].mean() 66 | plt.plot(avg_area.index, avg_area.values) 67 | plt.xlabel("Frame") 68 | plt.ylabel("Average Detection Area") 69 | plt.title("Average Detection Area Over Time") 70 | 71 | # Plot 4: Heatmap of detection centers 72 | plt.subplot(2, 2, 4) 73 | df["center_x"] = (df["x1"] + df["x2"]) / 2 74 | df["center_y"] = (df["y1"] + df["y2"]) / 2 75 | plt.hist2d(df["center_x"], df["center_y"], bins=30) 76 | plt.colorbar() 77 | plt.xlabel("X Position") 78 | plt.ylabel("Y Position") 79 | plt.title("Detection Center Heatmap") 80 | 81 | # Adjust layout and display 82 | plt.tight_layout() 83 | plt.show() 84 | 85 | # Print summary statistics 86 | print("\nSummary Statistics:") 87 | print(f"Total frames analyzed: {len(data['frame_detections'])}") 88 | print(f"Total detections: {len(df)}") 89 | print( 90 | f"Average detections per frame: {len(df) / len(data['frame_detections']):.2f}" 91 | ) 92 | print("\nVideo metadata:") 93 | for key, value in data.get("video_metadata", {}).items(): 94 | print(f"{key}: {value}") 95 | 96 | 97 | def main(): 98 | parser = argparse.ArgumentParser(description="Visualize object detection data") 99 | parser.add_argument( 100 | "json_file", help="Path to the JSON file containing detection data" 101 | ) 102 | args = parser.parse_args() 103 | 104 | visualize_detections(args.json_file) 105 | 106 | 107 | if __name__ == "__main__": 108 | main() 109 | -------------------------------------------------------------------------------- /recipes/promptable-video-redaction/.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.so 6 | .Python 7 | build/ 8 | develop-eggs/ 9 | dist/ 10 | downloads/ 11 | eggs/ 12 | .eggs/ 13 | lib/ 14 | lib64/ 15 | parts/ 16 | sdist/ 17 | var/ 18 | wheels/ 19 | *.egg-info/ 20 | .installed.cfg 21 | *.egg 22 | 23 | # Virtual Environment 24 | venv/ 25 | env/ 26 | ENV/ 27
| .venv/ 28 | 29 | # IDE 30 | .idea/ 31 | .vscode/ 32 | *.swp 33 | *.swo 34 | 35 | # Project specific 36 | inputs/* 37 | outputs/* 38 | !inputs/.gitkeep 39 | !outputs/.gitkeep 40 | inputs/ 41 | outputs/ 42 | 43 | # Model files 44 | *.pth 45 | *.onnx 46 | *.pt 47 | 48 | # Logs 49 | *.log 50 | 51 | certificate.pem -------------------------------------------------------------------------------- /recipes/promptable-video-redaction/README.md: -------------------------------------------------------------------------------- 1 | # Promptable Video Redaction with Moondream 2 | 3 | This tool uses Moondream 2B, a powerful yet lightweight vision-language model, to detect and redact objects from videos. Moondream can recognize a wide variety of objects, people, 4 | text, and more with high accuracy while being much smaller than traditional models. 5 | 6 | [Try it now.](https://huggingface.co/spaces/moondream/promptable-video-redaction) 7 | 8 | ## About Moondream 9 | 10 | Moondream is a tiny yet powerful vision-language model that can analyze images and answer questions about them. It's designed to be lightweight and efficient while maintaining high 11 | accuracy. 
Some key features: 12 | 13 | - Only 2B parameters 14 | - Fast inference with minimal resource requirements 15 | - Supports CPU and GPU execution 16 | - Open source and free to use 17 | - Can detect almost anything you can describe in natural language 18 | 19 | Links: 20 | 21 | - [GitHub Repository](https://github.com/vikhyat/moondream) 22 | - [Hugging Face](https://huggingface.co/vikhyatk/moondream2) 23 | - [Build with Moondream](http://docs.moondream.ai/) 24 | 25 | ## Features 26 | 27 | - Real-time object detection in videos using Moondream 28 | - Multiple visualization styles: 29 | - Censor: Black boxes over detected objects 30 | - Bounding Box: Traditional bounding boxes with labels 31 | - Hitmarker: Call of Duty style crosshair markers 32 | - Optional grid-based detection for improved accuracy 33 | - Flexible object type detection using natural language 34 | - Frame-by-frame processing with IoU-based merging 35 | - Batch processing of multiple videos 36 | - Web-compatible output format 37 | - User-friendly web interface 38 | - Command-line interface for automation 39 | 40 | ## Requirements 41 | 42 | - Python 3.8+ 43 | - OpenCV (cv2) 44 | - PyTorch 45 | - Transformers 46 | - Pillow (PIL) 47 | - tqdm 48 | - ffmpeg 49 | - numpy 50 | - gradio (for web interface) 51 | 52 | ## Installation 53 | 54 | 1. Clone this repository and create a new virtual environment: 55 | 56 | ```bash 57 | git clone https://github.com/vikhyat/moondream.git && cd moondream/recipes/promptable-video-redaction 58 | python -m venv .venv 59 | source .venv/bin/activate 60 | ``` 61 | 62 | 2. Install the required packages: 63 | 64 | ```bash 65 | pip install -r requirements.txt 66 | ``` 67 | 68 | 3.
Install ffmpeg and libvips: 69 | - On Ubuntu/Debian: `sudo apt-get install ffmpeg libvips` 70 | - On macOS: `brew install ffmpeg vips` 71 | - On Windows: Download from [ffmpeg.org](https://ffmpeg.org/download.html) 72 | > Installing libvips on Windows requires some additional steps; see the [Moondream quick-start guide](https://docs.moondream.ai/quick-start). 73 | 74 | ## Usage 75 | 76 | ### Web Interface 77 | 78 | 1. Start the web interface: 79 | 80 | ```bash 81 | python app.py 82 | ``` 83 | 84 | 2. Open the provided URL in your browser 85 | 86 | 3. Use the interface to: 87 | - Upload your video 88 | - Specify what to censor (e.g., face, logo, text) 89 | - Adjust processing speed and quality 90 | - Configure grid size for detection 91 | - Process and download the censored video 92 | 93 | ### Command Line Interface 94 | 95 | 1. Create an `inputs` directory in the same folder as the script: 96 | 97 | ```bash 98 | mkdir inputs 99 | ``` 100 | 101 | 2. Place your video files in the `inputs` directory. Supported formats: 102 | 103 | - .mp4 104 | - .avi 105 | - .mov 106 | - .mkv 107 | - .webm 108 | 109 | 3. Run the script: 110 | 111 | ```bash 112 | python main.py 113 | ``` 114 | 115 | ### Optional Arguments 116 | 117 | - `--test`: Process only the first 3 seconds of each video (useful for testing detection settings) 118 | 119 | ```bash 120 | python main.py --test 121 | ``` 122 | 123 | - `--preset`: Choose FFmpeg encoding preset (affects output quality vs.
speed) 124 | 125 | ```bash 126 | python main.py --preset ultrafast # Fastest, lower quality 127 | python main.py --preset veryslow # Slowest, highest quality 128 | ``` 129 | 130 | - `--detect`: Specify what object type to detect (using natural language) 131 | 132 | ```bash 133 | python main.py --detect person # Detect people 134 | python main.py --detect "red car" # Detect red cars 135 | python main.py --detect "person wearing a hat" # Detect people with hats 136 | ``` 137 | 138 | - `--box-style`: Choose visualization style 139 | 140 | ```bash 141 | python main.py --box-style censor # Black boxes (default) 142 | python main.py --box-style bounding-box # Bounding box-style boxes with labels 143 | python main.py --box-style hitmarker # COD-style hitmarkers 144 | ``` 145 | 146 | - `--rows` and `--cols`: Enable grid-based detection by splitting frames 147 | 148 | ```bash 149 | python main.py --rows 2 --cols 2 # Split each frame into 2x2 grid 150 | python main.py --rows 3 --cols 3 # Split each frame into 3x3 grid 151 | ``` 152 | 153 | You can combine arguments: 154 | 155 | ```bash 156 | python main.py --detect "person wearing sunglasses" --box-style bounding-box --test --preset "fast" --rows 2 --cols 2 157 | ``` 158 | 159 | ### Visualization Styles 160 | 161 | The tool supports three different visualization styles for detected objects: 162 | 163 | 1. **Censor** (default) 164 | 165 | - Places solid black rectangles over detected objects 166 | - Best for privacy and content moderation 167 | - Completely obscures the detected region 168 | 169 | 2. **Bounding Box** 170 | 171 | - Traditional object detection style 172 | - Red bounding box around detected objects 173 | - Label showing object type above the box 174 | - Good for analysis and debugging 175 | 176 | 3. 
**Hitmarker** 177 | - Call of Duty inspired visualization 178 | - White crosshair marker at center of detected objects 179 | - Small label above the marker 180 | - Stylistic choice for gaming-inspired visualization 181 | 182 | Choose the style that best fits your use case using the `--box-style` argument. 183 | 184 | ## Output 185 | 186 | Processed videos will be saved in the `outputs` directory with the format: `[style]_[object_type]_[original_filename].mp4` 187 | 188 | For example: 189 | 190 | - `censor_face_video.mp4` 191 | - `bounding-box_person_video.mp4` 192 | - `hitmarker_car_video.mp4` 193 | 194 | The output videos will include: 195 | 196 | - Original video content 197 | - Selected visualization style for detected objects 198 | - Web-compatible H.264 encoding 199 | 200 | ## Notes 201 | 202 | - Processing time depends on video length, grid size, and GPU availability 203 | - GPU is strongly recommended for faster processing 204 | - Requires sufficient disk space for temporary files 205 | - Detection quality varies based on video quality and Moondream's ability to recognize the specified object 206 | - Grid-based detection impacts performance significantly - use only when needed 207 | - Web interface shows progress updates and errors 208 | - Choose visualization style based on your use case 209 | - Moondream can detect almost anything you can describe in natural language 210 | -------------------------------------------------------------------------------- /recipes/promptable-video-redaction/packages.txt: -------------------------------------------------------------------------------- 1 | libvips 2 | ffmpeg -------------------------------------------------------------------------------- /recipes/promptable-video-redaction/requirements.txt: -------------------------------------------------------------------------------- 1 | gradio>=4.0.0 2 | torch 3 | transformers 4 | opencv-python 5 | pillow 6 | numpy 7 | tqdm 8 | ffmpeg-python 9 | einops 10 | pyvips 11 | 
accelerate -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.32.1 2 | huggingface-hub==0.24.0 3 | Pillow==10.4.0 4 | pyvips-binary==8.16.0 5 | pyvips==2.2.3 6 | torch==2.5.1 7 | transformers==4.44.0 8 | gradio==4.38.1 9 | 10 | # Needed for running evals 11 | datasets==3.2.0 12 | editdistance==0.8.1 13 | -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from queue import Queue 3 | from threading import Thread 4 | 5 | import torch 6 | from PIL import Image 7 | from transformers import AutoTokenizer, TextIteratorStreamer 8 | 9 | from moondream.hf import LATEST_REVISION, Moondream, detect_device 10 | 11 | if __name__ == "__main__": 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--image", type=str, required=True) 14 | parser.add_argument("--prompt", type=str, required=False) 15 | parser.add_argument("--caption", action="store_true") 16 | parser.add_argument("--cpu", action="store_true") 17 | args = parser.parse_args() 18 | 19 | if args.cpu: 20 | device = torch.device("cpu") 21 | dtype = torch.float32 22 | else: 23 | device, dtype = detect_device() 24 | if device != torch.device("cpu"): 25 | print("Using device:", device) 26 | print("If you run into issues, pass the `--cpu` flag to this script.") 27 | print() 28 | 29 | image_path = args.image 30 | prompt = args.prompt 31 | 32 | model_id = "vikhyatk/moondream2" 33 | tokenizer = AutoTokenizer.from_pretrained(model_id, revision=LATEST_REVISION) 34 | moondream = Moondream.from_pretrained( 35 | model_id, 36 | revision=LATEST_REVISION, 37 | torch_dtype=dtype, 38 | ).to(device=device) 39 | moondream.eval() 40 | 41 | image = Image.open(image_path) 42 | 43 | if args.caption: 44 | print(moondream.caption(images=[image], 
tokenizer=tokenizer)[0]) 45 | else: 46 | image_embeds = moondream.encode_image(image) 47 | 48 | if prompt is None: 49 | chat_history = "" 50 | 51 | while True: 52 | question = input("> ") 53 | 54 | result_queue = Queue() 55 | 56 | streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True) 57 | 58 | # Separate direct arguments from keyword arguments 59 | thread_args = (image_embeds, question, tokenizer, chat_history) 60 | thread_kwargs = {"streamer": streamer, "result_queue": result_queue} 61 | 62 | thread = Thread( 63 | target=moondream.answer_question, 64 | args=thread_args, 65 | kwargs=thread_kwargs, 66 | ) 67 | thread.start() 68 | 69 | buffer = "" 70 | for new_text in streamer: 71 | buffer += new_text 72 | if not new_text.endswith("<") and not new_text.endswith("END"): 73 | print(buffer, end="", flush=True) 74 | buffer = "" 75 | print(buffer) 76 | 77 | thread.join() 78 | 79 | answer = result_queue.get() 80 | chat_history += f"Question: {question}\n\nAnswer: {answer}\n\n" 81 | else: 82 | print(">", prompt) 83 | answer = moondream.answer_question(image_embeds, prompt, tokenizer) 84 | print(answer) 85 | -------------------------------------------------------------------------------- /tests/test_image_crops.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from moondream.torch.image_crops import overlap_crop_image, reconstruct_from_crops 4 | 5 | 6 | def test_overlap_crop_basic(): 7 | # Create a test image 8 | test_image = np.zeros((800, 600, 3), dtype=np.uint8) 9 | # Add a recognizable pattern - white rectangle in the middle 10 | test_image[300:500, 200:400] = 255 11 | 12 | result = overlap_crop_image(test_image, overlap_margin=4, max_crops=12) 13 | 14 | # Check basic properties 15 | assert result["crops"][0].shape == (378, 378, 3) 16 | assert len(result["crops"]) > 1 17 | assert all(crop.shape == (378, 378, 3) for crop in result["crops"]) 18 | assert len(result["tiling"]) == 2 
19 | 20 | 21 | def test_overlap_crop_small_image(): 22 | # Test with image smaller than crop size 23 | test_image = np.zeros((300, 200, 3), dtype=np.uint8) 24 | result = overlap_crop_image(test_image, overlap_margin=4, max_crops=12) 25 | 26 | # Should still produce valid output 27 | assert result["crops"][0].shape == (378, 378, 3) 28 | assert len(result["crops"]) == 2 29 | assert result["tiling"] == (1, 1) 30 | 31 | 32 | def test_reconstruction(): 33 | # Create a test image 34 | test_image = np.zeros((800, 600, 3), dtype=np.uint8) 35 | # Add a recognizable pattern 36 | test_image[300:500, 200:400] = 255 37 | 38 | # Crop and reconstruct 39 | result = overlap_crop_image(test_image, overlap_margin=4, max_crops=12) 40 | crops_tensor = [torch.from_numpy(crop) for crop in result["crops"][1:]] 41 | reconstructed = reconstruct_from_crops( 42 | crops_tensor, result["tiling"], overlap_margin=4 43 | ) 44 | 45 | # Convert back to numpy for comparison 46 | reconstructed_np = reconstructed.numpy() 47 | 48 | # The reconstructed image should be similar to the input 49 | # We can't expect exact equality due to resizing operations 50 | # but the white rectangle should still be visible in the middle 51 | center_reconstructed = reconstructed_np[ 52 | reconstructed_np.shape[0] // 2 - 100 : reconstructed_np.shape[0] // 2 + 100, 53 | reconstructed_np.shape[1] // 2 - 100 : reconstructed_np.shape[1] // 2 + 100, 54 | ].mean() 55 | 56 | # The center region should be significantly brighter than the edges 57 | assert center_reconstructed > reconstructed_np[:100, :100].mean() + 100 58 | -------------------------------------------------------------------------------- /webcam_gradio_demo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | from threading import Thread 4 | 5 | import gradio as gr 6 | import torch 7 | from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer 8 | 9 | from moondream.hf 
import LATEST_REVISION, detect_device 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--cpu", action="store_true") 13 | args = parser.parse_args() 14 | 15 | if args.cpu: 16 | device = torch.device("cpu") 17 | dtype = torch.float32 18 | else: 19 | device, dtype = detect_device() 20 | if device != torch.device("cpu"): 21 | print("Using device:", device) 22 | print("If you run into issues, pass the `--cpu` flag to this script.") 23 | print() 24 | 25 | model_id = "vikhyatk/moondream2" 26 | tokenizer = AutoTokenizer.from_pretrained(model_id, revision=LATEST_REVISION) 27 | moondream = AutoModelForCausalLM.from_pretrained( 28 | model_id, trust_remote_code=True, revision=LATEST_REVISION 29 | ).to(device=device, dtype=dtype) 30 | moondream.eval() 31 | 32 | 33 | def answer_question(img, prompt): 34 | image_embeds = moondream.encode_image(img) 35 | streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True) 36 | thread = Thread( 37 | target=moondream.answer_question, 38 | kwargs={ 39 | "image_embeds": image_embeds, 40 | "question": prompt, 41 | "tokenizer": tokenizer, 42 | "streamer": streamer, 43 | }, 44 | ) 45 | thread.start() 46 | 47 | buffer = "" 48 | for new_text in streamer: 49 | buffer += new_text 50 | yield buffer 51 | 52 | 53 | with gr.Blocks() as demo: 54 | gr.Markdown("# 🌔 moondream") 55 | 56 | gr.HTML( 57 | """ 58 | 64 | """ 65 | ) 66 | 67 | with gr.Row(): 68 | prompt = gr.Textbox( 69 | label="Prompt", 70 | value="What's going on? 
Respond with a single sentence.", 71 | interactive=True, 72 | ) 73 | with gr.Row(): 74 | img = gr.Image(type="pil", label="Upload an Image", streaming=True) 75 | output = gr.Markdown(elem_classes=["md_output"]) 76 | 77 | latest_img = None 78 | latest_prompt = prompt.value 79 | 80 | @img.change(inputs=[img]) 81 | def img_change(img): 82 | global latest_img 83 | latest_img = img 84 | 85 | @prompt.change(inputs=[prompt]) 86 | def prompt_change(prompt): 87 | global latest_prompt 88 | latest_prompt = prompt 89 | 90 | @demo.load(outputs=[output]) 91 | def live_video(): 92 | while True: 93 | if latest_img is None: 94 | time.sleep(0.1) 95 | else: 96 | for text in answer_question(latest_img, latest_prompt): 97 | if len(text) > 0: 98 | yield text 99 | 100 | 101 | demo.queue().launch(debug=True) 102 | --------------------------------------------------------------------------------