├── .github
│   └── workflows
│       ├── node-client-tests.yml
│       ├── pylint.yml
│       ├── python-client-tests.yml
│       └── test.yml
├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── demo-1.jpg
│   └── demo-2.jpg
├── batch_generate_example.py
├── clients
│   ├── node
│   │   ├── .gitignore
│   │   ├── .npmignore
│   │   ├── README.MD
│   │   ├── jest.config.js
│   │   ├── package-lock.json
│   │   ├── package.json
│   │   ├── src
│   │   │   ├── __tests__
│   │   │   │   ├── moondream.integration.test.ts
│   │   │   │   ├── moondream.test.ts
│   │   │   │   └── setup.ts
│   │   │   ├── moondream.ts
│   │   │   └── types.ts
│   │   └── tsconfig.json
│   └── python
│       ├── README.md
│       ├── moondream
│       │   ├── __init__.py
│       │   ├── cli.py
│       │   ├── cloud_vl.py
│       │   ├── moonfile.py
│       │   ├── onnx_vl.py
│       │   ├── preprocess.py
│       │   ├── server.py
│       │   ├── torch_vl.py
│       │   ├── types.py
│       │   └── version.py
│       ├── pyproject.toml
│       ├── scripts
│       │   ├── test.py
│       │   ├── test_cloud_parity.py
│       │   └── test_local_server.py
│       └── tests
│           ├── test_api_inference.py
│           ├── test_local_inference.py
│           └── test_local_torch_inference.py
├── gradio_demo.py
├── moondream
│   ├── __init__.py
│   ├── config
│   │   ├── config_md05.json
│   │   └── config_md2.json
│   ├── eval
│   │   ├── chartqa.py
│   │   ├── coco_map.py
│   │   ├── countbenchqa.py
│   │   ├── docvqa.py
│   │   ├── eval_all.py
│   │   ├── gazefollow.py
│   │   ├── mmstar.py
│   │   ├── naturalbench.py
│   │   ├── pope.py
│   │   ├── realworldqa.py
│   │   ├── tallyqa.py
│   │   ├── textvqa.py
│   │   ├── utils.py
│   │   └── waste_detection.py
│   ├── finetune
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── finetune_region.py
│   │   └── finetune_text.py
│   └── torch
│       ├── config.py
│       ├── hf_moondream.py
│       ├── hf_release.py
│       ├── image_crops.py
│       ├── layers.py
│       ├── moondream.py
│       ├── region.py
│       ├── rope.py
│       ├── sample.py
│       ├── text.py
│       ├── utils.py
│       ├── vision.py
│       └── weights.py
├── notebooks
│   └── RepEng.ipynb
├── recipes
│   ├── gaze-detection-video
│   │   ├── .gitignore
│   │   ├── README.md
│   │   ├── gaze-detection-video.py
│   │   ├── input
│   │   │   └── .gitkeep
│   │   ├── output
│   │   │   └── .gitkeep
│   │   ├── requirements.txt
│   │   └── temp
│   │       └── .gitkeep
│   ├── promptable-content-moderation
│   │   ├── .gitignore
│   │   ├── README.md
│   │   ├── app.py
│   │   ├── deep_sort_integration.py
│   │   ├── main.py
│   │   ├── packages.txt
│   │   ├── persistence.py
│   │   ├── requirements.txt
│   │   ├── video_visualization.py
│   │   └── visualization.py
│   └── promptable-video-redaction
│       ├── .gitignore
│       ├── README.md
│       ├── app.py
│       ├── main.py
│       ├── packages.txt
│       └── requirements.txt
├── requirements.txt
├── sample.py
├── tests
│   └── test_image_crops.py
└── webcam_gradio_demo.py
/.github/workflows/node-client-tests.yml:
--------------------------------------------------------------------------------
1 | name: Node.js Client Tests
2 |
3 | on:
4 |   # Only run on PRs to avoid duplicate runs
5 |   pull_request:
6 |     paths:
7 |       - 'clients/node/**'
8 |       - '.github/workflows/node-client-tests.yml'
9 |
10 | permissions:
11 |   contents: read
12 | jobs:
13 |   test:
14 |     runs-on: ubuntu-latest
15 |     strategy:
16 |       matrix:
17 |         node-version: [18.x, 20.x]
18 |
19 |     steps:
20 |       - uses: actions/checkout@v4
21 |
22 |       - name: Set up Node.js ${{ matrix.node-version }}
23 |         uses: actions/setup-node@v4
24 |         with:
25 |           node-version: ${{ matrix.node-version }}
26 |           cache: 'npm'
27 |           cache-dependency-path: clients/node/package-lock.json
28 |
29 |       - name: Install dependencies
30 |         working-directory: ./clients/node
31 |         run: npm ci
32 |
33 |       - name: Run unit tests
34 |         working-directory: ./clients/node
35 |         run: npm test -- src/__tests__/moondream.test.ts
36 |
37 |       - name: Run integration tests
38 |         working-directory: ./clients/node
39 |         env:
40 |           MOONDREAM_API_KEY: ${{ secrets.MOONDREAM_API_KEY }}
41 |         run: npm test -- src/__tests__/moondream.integration.test.ts
42 |
--------------------------------------------------------------------------------
/.github/workflows/pylint.yml:
--------------------------------------------------------------------------------
1 | name: Lint
2 |
3 | on:
4 |   push:
5 |     branches: [ main ]
6 |   pull_request:
7 |     branches: [ main ]
8 |
9 | permissions:
10 |   contents: read
11 | jobs:
12 |   build:
13 |     runs-on: ubuntu-latest
14 |     permissions:
15 |       contents: read
16 |     strategy:
17 |       matrix:
18 |         python-version: ["3.12"] # Run lint checks only on latest Python version
19 |     steps:
20 |       - uses: actions/checkout@v4
21 |       - name: Set up Python ${{ matrix.python-version }}
22 |         uses: actions/setup-python@v4
23 |         with:
24 |           python-version: ${{ matrix.python-version }}
25 |       - name: Install dependencies
26 |         run: |
27 |           python -m pip install --upgrade pip
28 |           pip install autoflake black
29 |       - name: Checking for unused imports
30 |         run: |
31 |           autoflake -c -r .
32 |       - name: Checking code style
33 |         run: |
34 |           black --check .
35 |
--------------------------------------------------------------------------------
/.github/workflows/python-client-tests.yml:
--------------------------------------------------------------------------------
1 | name: Python Client Tests
2 |
3 | on:
4 |   # Only run on PRs to avoid duplicate runs
5 |   pull_request:
6 |     paths:
7 |       - 'clients/python/**'
8 |       - '.github/workflows/python-client-tests.yml'
9 |
10 | permissions:
11 |   contents: read
12 | jobs:
13 |   test:
14 |     runs-on: ubuntu-latest
15 |     strategy:
16 |       matrix:
17 |         python-version: ['3.10', '3.11', '3.12']
18 |
19 |     steps:
20 |       - uses: actions/checkout@v4
21 |
22 |       - name: Set up Python ${{ matrix.python-version }}
23 |         uses: actions/setup-python@v4
24 |         with:
25 |           python-version: ${{ matrix.python-version }}
26 |           cache: 'pip'
27 |
28 |       - name: Install Poetry
29 |         run: |
30 |           curl -sSL https://install.python-poetry.org | python3 -
31 |
32 |       - name: Cache Poetry dependencies
33 |         uses: actions/cache@v3
34 |         with:
35 |           path: ~/.cache/pypoetry
36 |           key: ${{ runner.os }}-poetry-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}
37 |           restore-keys: |
38 |             ${{ runner.os }}-poetry-${{ matrix.python-version }}-
39 |
40 |       - name: Install dependencies
41 |         working-directory: ./clients/python
42 |         run: |
43 |           poetry install --all-extras
44 |
45 |       - name: Format code
46 |         working-directory: ./clients/python
47 |         run: |
48 |           poetry run pip install black
49 |           poetry run black tests/test_local_inference.py --check
50 |
51 |       - name: Run tests
52 |         working-directory: ./clients/python
53 |         env:
54 |           MOONDREAM_API_KEY: ${{ secrets.MOONDREAM_API_KEY }}
55 |         run: |
56 |           poetry run pip install pytest pytest-asyncio
57 |           poetry run pytest tests/test_api_inference.py -v
58 |
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: Tests
2 |
3 | on:
4 |   push:
5 |     branches: [ main ]
6 |   pull_request:
7 |     branches: [ main ]
8 |
9 | jobs:
10 |   test:
11 |     runs-on: ubuntu-latest
12 |     strategy:
13 |       matrix:
14 |         python-version: ["3.9", "3.10", "3.11", "3.12"]
15 |
16 |     steps:
17 |       - uses: actions/checkout@v4
18 |       - name: Set up Python ${{ matrix.python-version }}
19 |         uses: actions/setup-python@v5
20 |         with:
21 |           python-version: ${{ matrix.python-version }}
22 |       - name: Install dependencies
23 |         run: |
24 |           python -m pip install --upgrade pip
25 |           pip install pytest
26 |           pip install -r requirements.txt
27 |       - name: Run tests
28 |         run: |
29 |           python -m pytest tests/test_image_crops.py -v
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .venv
2 | __pycache__
3 | checkpoints
4 | data
5 | /pyproject.toml
6 | poetry.lock
7 | dist
8 | clients/python/moondream/torch
9 | wandb/
10 | moondream_finetune.safetensors
11 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 🌔 moondream
2 |
3 | a tiny vision language model that kicks ass and runs anywhere
4 |
5 | [Website](https://moondream.ai/) | [Demo](https://moondream.ai/playground)
6 |
7 | ## Examples
8 |
9 | | Image | Example |
10 | | ---------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
11 | |  | **What is the girl doing?** The girl is sitting at a table and eating a large hamburger. **What color is the girl's hair?** The girl's hair is white. |
12 | |  | **What is this?** This is a computer server rack, which is a device used to store and manage multiple computer servers. The rack is filled with various computer servers, each with their own dedicated space and power supply. The servers are connected to the rack via multiple cables, indicating that they are part of a larger system. The rack is placed on a carpeted floor, and there is a couch nearby, suggesting that the setup is in a living or entertainment area. **What is behind the stand?** Behind the stand, there is a brick wall. |
13 |
14 | ## About
15 |
16 | Moondream is a highly efficient open-source vision language model that combines powerful image understanding capabilities with a remarkably small footprint. It's designed to be versatile and accessible, capable of running on a wide range of devices and platforms.
17 |
18 | The project offers two model variants:
19 |
20 | - **Moondream 2B**: The primary model with 2 billion parameters, offering robust performance for general-purpose image understanding tasks including captioning, visual question answering, and object detection.
21 | - **Moondream 0.5B**: A compact 500 million parameter model specifically optimized as a distillation target for edge devices, enabling efficient deployment on resource-constrained hardware while maintaining impressive capabilities.
22 |
23 | ## How to use
24 |
25 | Moondream can be run locally, or in the cloud. Please refer to the [Getting Started](https://moondream.ai/c/docs/quickstart) page for details.
26 |
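For example, with the Python client (documented in `clients/python`), the same calls work against either the cloud API or locally downloaded weights. A minimal sketch; the API key and file paths below are placeholders:

```python
import moondream as md
from PIL import Image

# Cloud inference: requires an API key from console.moondream.ai
model = md.vl(api_key="your-api-key")

# Local inference instead: point at downloaded weights (placeholder path)
# model = md.vl(model="path/to/moondream-2b-int8.mf")

image = Image.open("path/to/image.jpg")
print(model.query(image, "What's in this image?")["answer"])
```
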
27 | ## Special thanks
28 |
29 | * [Modal](https://modal.com/?utm_source=github&utm_medium=github&utm_campaign=moondream) - Modal lets you run jobs in the cloud, by just writing a few lines of Python. Here's an [example of how to run Moondream on Modal](https://github.com/m87-labs/moondream-examples/tree/main/quickstart/modal).
30 |
--------------------------------------------------------------------------------
/assets/demo-1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vikhyat/moondream/ce59395b386eb956e19bc1cba5f97deb9befcd05/assets/demo-1.jpg
--------------------------------------------------------------------------------
/assets/demo-2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vikhyat/moondream/ce59395b386eb956e19bc1cba5f97deb9befcd05/assets/demo-2.jpg
--------------------------------------------------------------------------------
/batch_generate_example.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | from transformers import AutoTokenizer
3 |
4 | from moondream.hf import LATEST_REVISION, Moondream, detect_device
5 |
6 | device, dtype = detect_device()
7 |
8 | model_id = "vikhyatk/moondream2"
9 | tokenizer = AutoTokenizer.from_pretrained(model_id, revision=LATEST_REVISION)
10 | moondream = Moondream.from_pretrained(
11 | model_id,
12 | revision=LATEST_REVISION,
13 | torch_dtype=dtype,
14 | ).to(device=device)
15 | moondream.eval()
16 |
17 | image1 = Image.open("assets/demo-1.jpg")
18 | image2 = Image.open("assets/demo-2.jpg")
19 | prompts = [
20 | "What is the girl doing?",
21 | "What color is the girl's hair?",
22 | "What is this?",
23 | "What is behind the stand?",
24 | ]
25 |
26 | answers = moondream.batch_answer(
27 | images=[image1, image1, image2, image2],
28 | prompts=prompts,
29 | tokenizer=tokenizer,
30 | )
31 |
32 | for question, answer in zip(prompts, answers):
33 | print(f"Q: {question}")
34 | print(f"A: {answer}")
35 | print()
36 |
--------------------------------------------------------------------------------
/clients/node/.gitignore:
--------------------------------------------------------------------------------
1 | node_modules
2 | coverage
3 | .env
4 | .env.local
5 | .env.development
6 | .env.test
7 | .env.production
8 | dist/
9 | build/
10 | lib/
11 | out/
12 | *.log
13 | npm-debug.log*
14 | yarn-debug.log*
15 | yarn-error.log*
16 | .DS_Store
17 | .idea/
18 | .vscode/
19 | *.swp
20 | *.swo
21 | .next/
22 | .cache/
23 | .parcel-cache/
24 | tsconfig.tsbuildinfo
--------------------------------------------------------------------------------
/clients/node/.npmignore:
--------------------------------------------------------------------------------
1 | /src/
2 | __tests__/
3 |
--------------------------------------------------------------------------------
/clients/node/README.MD:
--------------------------------------------------------------------------------
1 | # Moondream NodeJS Client Library
2 |
3 | Official NodeJS client library for Moondream, a tiny vision language model that can
4 | analyze images and answer questions about them. This client library provides easy
5 | access to Moondream's API endpoints for image analysis.
6 |
7 | ## Installation
8 |
9 | Install the package using npm:
10 |
11 | ```bash
12 | npm install moondream
13 | ```
14 |
15 | Or using yarn:
16 |
17 | ```bash
18 | yarn add moondream
19 | ```
20 |
21 | ## Quick Start
22 |
23 | Before using this client library, you'll need an API key to access Moondream's hosted service.
24 | You can get a free API key from [console.moondream.ai](https://console.moondream.ai).
25 |
26 | ### Cloud
27 |
28 | ```javascript
29 | import { vl } from "moondream";
30 | import fs from "fs";
31 |
32 | // Initialize the client
33 | const model = new vl({
34 | apiKey: "your-api-key",
35 | });
36 |
37 | // Read an image file
38 | const image = fs.readFileSync("path/to/image.jpg");
39 |
40 | // Basic usage examples
41 | async function main() {
42 | // Generate a caption for the image
43 | const caption = await model.caption({
44 | image: image,
45 | length: "normal",
46 | stream: false
47 | });
48 | console.log("Caption:", caption);
49 |
50 | // Ask a question about the image
51 | const answer = await model.query({
52 | image: image,
53 | question: "What's in this image?",
54 | stream: false
55 | });
56 | console.log("Answer:", answer);
57 |
58 | // Stream the response
59 | const stream = await model.caption({
60 | image: image,
61 | length: "normal",
62 | stream: true
63 | });
64 | for await (const chunk of stream.caption) {
65 | process.stdout.write(chunk);
66 | }
67 | }
68 |
69 | main();
70 | ```
71 |
72 | ### Local Inference
73 |
74 | - Install the `moondream` CLI: `pip install moondream`
75 | - Run the local server: `moondream serve --model <path-to-model>`
76 | - Set the `apiUrl` parameter to the URL of the local server (the default is `http://localhost:3475`)
77 |
78 | ```javascript
79 | const model = new vl({
80 | apiUrl: "http://localhost:3475",
81 | });
82 |
83 | const image = fs.readFileSync("path/to/image.jpg");
84 |
85 | // Basic usage examples
86 | async function main() {
87 | // Generate a caption for the image
88 | const caption = await model.caption({
89 | image: image,
90 | length: "normal",
91 | stream: false
92 | });
93 | console.log("Caption:", caption);
94 |
95 | // Ask a question about the image
96 | const answer = await model.query({
97 | image: image,
98 | question: "What's in this image?",
99 | stream: false
100 | });
101 | console.log("Answer:", answer);
102 |
103 | // Stream the response
104 | const stream = await model.caption({
105 | image: image,
106 | length: "normal",
107 | stream: true
108 | });
109 | for await (const chunk of stream.caption) {
110 | process.stdout.write(chunk);
111 | }
112 | }
113 |
114 | main();
115 | ```
116 |
117 | ## Features
118 |
119 | - **caption**: Generate descriptive captions for images
120 | - **query**: Ask questions about image content
121 | - **detect**: Find bounding boxes around objects in images
122 | - **point**: Identify the center location of specified objects in images
123 |
124 | ## API Reference
125 |
126 | ### Constructor
127 |
128 | ```javascript
129 | // for cloud inference
130 | const model = new vl({
131 | apiKey: "your-api-key",
132 | });
133 |
134 | // or for local inference
135 | const model = new vl({
136 | apiUrl: "http://localhost:3475",
137 | });
138 | ```
139 |
140 | ### Methods
141 |
142 | #### caption({ image: string, length: string, stream?: boolean })
143 |
144 | Generate a caption for an image.
145 |
146 | ```javascript
147 | const result = await model.caption({
148 | image: image,
149 | length: "normal",
150 | stream: false
151 | });
152 |
153 | // or with streaming
154 | const stream = await model.caption({
155 | image: image,
156 | length: "normal",
157 | stream: true
158 | });
159 | ```
160 |
161 | #### query({ image: string, question: string, stream?: boolean })
162 |
163 | Ask a question about an image.
164 |
165 | ```javascript
166 | const result = await model.query({
167 | image: image,
168 | question: "What's in this image?",
169 | stream: false
170 | });
171 |
172 | // or with streaming
173 | const stream = await model.query({
174 | image: image,
175 | question: "What's in this image?",
176 | stream: true
177 | });
178 | ```
179 |
180 | #### detect({ image: string, object: string })
181 |
182 | Detect specific objects in an image.
183 |
184 | ```javascript
185 | const result = await model.detect({
186 | image: image,
187 | object: "car"
188 | });
189 | ```
190 |
191 | #### point({ image: string, object: string })
192 |
193 | Get coordinates of specific objects in an image.
194 |
195 | ```javascript
196 | const result = await model.point({
197 | image: image,
198 | object: "person"
199 | });
200 | ```
201 |
202 | ### Input Types
203 |
204 | - Images can be provided as:
205 | - Buffer: Raw image data
206 | - Base64EncodedImage: `{ imageUrl: string }`
207 |
208 | ### Response Types
209 |
210 | All methods return promises that resolve to typed responses:
211 |
212 | - CaptionOutput: `{ caption: string | AsyncGenerator<string, void, unknown> }`
213 | - QueryOutput: `{ answer: string | AsyncGenerator<string, void, unknown> }`
214 | - DetectOutput: `{ objects: Array<DetectedObject> }`
215 | - PointOutput: `{ points: Array<Point> }`
216 |
217 | ## Links
218 |
219 | - [Website](https://moondream.ai/)
220 | - [Demo](https://moondream.ai/playground)
221 |
--------------------------------------------------------------------------------
/clients/node/jest.config.js:
--------------------------------------------------------------------------------
1 | // jest.config.js
2 | module.exports = {
3 |   preset: 'ts-jest',
4 |   testEnvironment: 'node',
5 |   roots: ['<rootDir>/src'],
6 |   testMatch: ['**/__tests__/**/*.test.ts'],
7 |   moduleFileExtensions: ['ts', 'js', 'json', 'node'],
8 |   collectCoverage: false,
9 |   setupFiles: ['<rootDir>/src/__tests__/setup.ts']
10 | };
--------------------------------------------------------------------------------
/clients/node/package.json:
--------------------------------------------------------------------------------
1 | {
2 |   "name": "moondream",
3 |   "version": "0.0.5",
4 |   "description": "TypeScript client for the Moondream AI API",
5 |   "main": "dist/src/moondream.js",
6 |   "types": "dist/src/moondream.d.ts",
7 |   "scripts": {
8 |     "build": "tsc",
9 |     "test": "jest",
10 |     "prepare": "npm run build"
11 |   },
12 |   "keywords": [
13 |     "moondream",
14 |     "ai",
15 |     "vision",
16 |     "client"
17 |   ],
18 |   "author": "",
19 |   "license": "",
20 |   "dependencies": {
21 |     "sharp": "^0.32.1"
22 |   },
23 |   "devDependencies": {
24 |     "@types/jest": "^29.0.0",
25 |     "@types/node": "^18.0.0",
26 |     "@types/node-fetch": "^2.6.1",
27 |     "dotenv": "^16.4.7",
28 |     "jest": "^29.0.0",
29 |     "ts-jest": "^29.0.0",
30 |     "typescript": "^4.9.5"
31 |   }
32 | }
33 |
--------------------------------------------------------------------------------
/clients/node/src/__tests__/setup.ts:
--------------------------------------------------------------------------------
1 | // Jest setup file - currently empty as we don't need any global setup
--------------------------------------------------------------------------------
/clients/node/src/moondream.ts:
--------------------------------------------------------------------------------
1 | import { Buffer } from 'buffer';
2 | import sharp from 'sharp';
3 | import http from 'http';
4 | import https from 'https';
5 | import { version } from '../package.json';
6 | import {
7 | Base64EncodedImage,
8 | CaptionOutput,
9 | QueryOutput,
10 | DetectOutput,
11 | PointOutput,
12 | CaptionRequest,
13 | QueryRequest,
14 | DetectRequest,
15 | PointRequest,
16 | } from './types';
17 |
18 | export interface MoondreamVLConfig {
19 | apiKey?: string;
20 | apiUrl?: string;
21 | }
22 | const DEFAULT_API_URL = 'https://api.moondream.ai/v1';
23 |
24 | export class vl {
25 | private apiKey: string;
26 | private apiUrl: string;
27 |
28 | constructor(config: MoondreamVLConfig) {
29 | this.apiKey = config.apiKey || '';
30 | this.apiUrl = config.apiUrl || DEFAULT_API_URL;
31 | if (this.apiKey === '' && this.apiUrl === DEFAULT_API_URL) {
32 | throw new Error(
33 | 'An apiKey is required for cloud inference. '
34 | );
35 | }
36 | }
37 |
38 | private async encodeImage(
39 | image: Buffer | Base64EncodedImage
40 | ): Promise<Base64EncodedImage> {
41 | if ('imageUrl' in image) {
42 | return image;
43 | }
44 |
45 | try {
46 | const MAX_SIZE = 768;
47 |
48 | // Process image with Sharp
49 | const metadata = await sharp(image).metadata();
50 |
51 | if (!metadata.width || !metadata.height) {
52 | throw new Error('Unable to get image dimensions');
53 | }
54 |
55 | const scale = MAX_SIZE / Math.max(metadata.width, metadata.height);
56 | let processedImage = sharp(image);
57 |
58 | if (scale < 1) {
59 | processedImage = processedImage.resize(
60 | Math.round(metadata.width * scale),
61 | Math.round(metadata.height * scale),
62 | {
63 | fit: 'inside',
64 | withoutEnlargement: true,
65 | }
66 | );
67 | }
68 |
69 | const buffer = await processedImage
70 | .toFormat('jpeg', { quality: 95 })
71 | .toBuffer();
72 |
73 | const base64Image = buffer.toString('base64');
74 | return {
75 | imageUrl: `data:image/jpeg;base64,${base64Image}`,
76 | };
77 | } catch (error) {
78 | throw new Error(
79 | `Failed to convert image to JPEG: ${(error as Error).message}`
80 | );
81 | }
82 | }
83 |
84 | private makeRequest(path: string, body: any, stream: boolean = false): Promise<any> {
85 | return new Promise((resolve, reject) => {
86 | const url = new URL(this.apiUrl + path);
87 | const requestBody = JSON.stringify(body);
88 |
89 | const options = {
90 | method: 'POST',
91 | headers: {
92 | 'X-Moondream-Auth': this.apiKey,
93 | 'Content-Type': 'application/json',
94 | 'User-Agent': `moondream-node/${version}`,
95 | 'Content-Length': Buffer.byteLength(requestBody)
96 | }
97 | };
98 |
99 | const client = url.protocol === 'https:' ? https : http;
100 | const req = client.request(url, options, (res) => {
101 | if (stream) {
102 | resolve(res);
103 | return;
104 | }
105 |
106 | let data = '';
107 | res.on('data', chunk => data += chunk);
108 | res.on('end', () => {
109 | if (res.statusCode !== 200) {
110 | reject(new Error(`HTTP error! status: ${res.statusCode}`));
111 | return;
112 | }
113 | try {
114 | resolve(JSON.parse(data));
115 | } catch (error) {
116 | reject(new Error(`Failed to parse JSON response: ${(error as Error).message}`));
117 | }
118 | });
119 | });
120 |
121 | req.on('error', (error) => {
122 | reject(error);
123 | });
124 |
125 | req.write(requestBody);
126 | req.end();
127 | });
128 | }
129 |
130 | private async* streamResponse(response: any): AsyncGenerator<string, void, unknown> {
131 | let buffer = '';
132 |
133 | try {
134 | for await (const chunk of response) {
135 | buffer += chunk.toString();
136 | const lines = buffer.split('\n');
137 | buffer = lines.pop() || '';
138 |
139 | for (const line of lines) {
140 | if (line.startsWith('data: ')) {
141 | try {
142 | const data = JSON.parse(line.slice(6));
143 | if ('chunk' in data) {
144 | yield data.chunk;
145 | }
146 | if (data.completed) {
147 | return;
148 | }
149 | } catch (error) {
150 | throw new Error(`Failed to parse JSON response from server: ${(error as Error).message}`);
151 | }
152 | }
153 | }
154 | }
155 |
156 | // Handle any remaining data in the buffer
157 | if (buffer) {
158 | const lines = buffer.split('\n');
159 | for (const line of lines) {
160 | if (line.startsWith('data: ')) {
161 | try {
162 | const data = JSON.parse(line.slice(6));
163 | if ('chunk' in data) {
164 | yield data.chunk;
165 | }
166 | } catch (error) {
167 | throw new Error(`Failed to parse JSON response from server: ${(error as Error).message}`);
168 | }
169 | }
170 | }
171 | }
172 | } catch (error) {
173 | throw new Error(`Failed to stream response: ${(error as Error).message}`);
174 | }
175 | }
176 |
177 | public async caption(
178 | request: CaptionRequest
179 | ): Promise<CaptionOutput> {
180 | const encodedImage = await this.encodeImage(request.image);
181 |
182 | const response = await this.makeRequest('/caption', {
183 | image_url: encodedImage.imageUrl,
184 | length: request.length,
185 | stream: request.stream,
186 | }, request.stream);
187 |
188 | if (request.stream) {
189 | return { caption: this.streamResponse(response) };
190 | }
191 |
192 | return { caption: response.caption };
193 | }
194 |
195 | public async query(
196 | request: QueryRequest
197 | ): Promise<QueryOutput> {
198 | const encodedImage = await this.encodeImage(request.image);
199 |
200 | const response = await this.makeRequest('/query', {
201 | image_url: encodedImage.imageUrl,
202 | question: request.question,
203 | stream: request.stream,
204 | }, request.stream);
205 |
206 | if (request.stream) {
207 | return { answer: this.streamResponse(response) };
208 | }
209 |
210 | return { answer: response.answer };
211 | }
212 |
213 | public async detect(
214 | request: DetectRequest
215 | ): Promise<DetectOutput> {
216 | const encodedImage = await this.encodeImage(request.image);
217 |
218 | const response = await this.makeRequest('/detect', {
219 | image_url: encodedImage.imageUrl,
220 | object: request.object,
221 | });
222 |
223 | return { objects: response.objects };
224 | }
225 |
226 | public async point(
227 | request: PointRequest
228 | ): Promise<PointOutput> {
229 | const encodedImage = await this.encodeImage(request.image);
230 |
231 | const response = await this.makeRequest('/point', {
232 | image_url: encodedImage.imageUrl,
233 | object: request.object,
234 | });
235 |
236 | return { points: response.points };
237 | }
238 | }
--------------------------------------------------------------------------------
/clients/node/src/types.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * Base interface for encoded images
3 | */
4 | export interface Base64EncodedImage {
5 | imageUrl: string;
6 | }
7 |
8 | /**
9 | * Length options for caption generation
10 | */
11 | export type Length = 'normal' | 'short';
12 |
13 | /**
14 | * Settings for controlling the model's generation behavior
15 | */
16 | export interface SamplingSettings {
17 | maxTokens?: number;
18 | }
19 |
20 | /**
21 | * Request structure for image caption requests
22 | */
23 | export interface CaptionRequest {
24 | image: Buffer | Base64EncodedImage;
25 | length?: Length;
26 | stream?: boolean;
27 | settings?: SamplingSettings;
28 | }
29 | /**
30 | * Response structure for image caption requests
31 | */
32 | export interface CaptionOutput {
33 | caption: string | AsyncGenerator<string, void, unknown>;
34 | }
35 |
36 | /**
37 | * Request structure for image query requests
38 | */
39 | export interface QueryRequest {
40 | image: Buffer | Base64EncodedImage;
41 | question: string;
42 | stream?: boolean;
43 | settings?: SamplingSettings;
44 | }
45 | /**
46 | * Response structure for image query requests
47 | */
48 | export interface QueryOutput {
49 | answer: string | AsyncGenerator<string, void, unknown>;
50 | }
51 |
52 | /**
53 | * Request structure for object detection requests
54 | */
55 | export interface DetectRequest {
56 | image: Buffer | Base64EncodedImage;
57 | object: string;
58 | }
59 | /**
60 | * Response structure for object detection requests
61 | */
62 | export interface DetectOutput {
63 | objects: DetectedObject[];
64 | }
65 |
66 | /**
67 | * Object detection result
68 | */
69 | export interface DetectedObject {
70 | x_min: number;
71 | y_min: number;
72 | x_max: number;
73 | y_max: number;
74 | }
75 |
83 | /**
84 | * Error response from the API
85 | */
86 | export interface ApiError {
87 | error: {
88 | message: string;
89 | code?: string;
90 | details?: unknown;
91 | };
92 | }
93 |
94 | /**
95 | * Configuration options for the client
96 | */
97 | export interface ClientConfig {
98 | apiKey: string;
99 | apiUrl?: string;
100 | timeout?: number;
101 | retries?: number;
102 | }
103 |
104 | /**
105 | * API response for streaming requests
106 | */
107 | export interface StreamResponse {
108 | chunk?: string;
109 | completed?: boolean;
110 | error?: string;
111 | }
112 |
113 | /**
114 | * Options for image processing
115 | */
116 | export interface ImageProcessingOptions {
117 | maxSize?: number;
118 | quality?: number;
119 | format?: 'jpeg' | 'png';
120 | }
121 |
122 | /**
123 | * Common response type for all API endpoints
124 | */
125 | export type ApiResponse<T> = {
126 | success: boolean;
127 | data?: T;
128 | error?: ApiError;
129 | timestamp?: string;
130 | requestId?: string;
131 | }
132 |
133 | /**
134 | * Pointing request structure
135 | */
136 | export interface PointRequest {
137 | image: Buffer | Base64EncodedImage;
138 | object: string;
139 | }
140 | /**
141 | * Point coordinates for object location
142 | */
143 | export interface Point {
144 | x: number;
145 | y: number;
146 | }
147 |
148 | /**
149 | * Response structure for point requests
150 | */
151 | export interface PointOutput {
152 | points: Point[];
153 | }
--------------------------------------------------------------------------------
/clients/node/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 |   "compilerOptions": {
3 |     "target": "es2020",
4 |     "module": "commonjs",
5 |     "declaration": true,
6 |     "outDir": "./dist",
7 |     "strict": true,
8 |     "esModuleInterop": true,
9 |     "skipLibCheck": true,
10 |     "forceConsistentCasingInFileNames": true,
11 |     "moduleResolution": "node",
12 |     "resolveJsonModule": true
13 |   },
14 |   "include": ["src"],
15 |   "exclude": ["node_modules", "dist", "test"]
16 | }
--------------------------------------------------------------------------------
/clients/python/README.md:
--------------------------------------------------------------------------------
1 | # Moondream Python Client Library
2 |
3 | Official Python client library for Moondream, a tiny vision language model that
4 | can analyze images and answer questions about them. This library supports both
5 | local inference and cloud-based API access.
6 |
7 | ## Features
8 |
9 | - **Local Inference**: Run the model directly on your machine using CPU
10 | - **Cloud API**: Access Moondream's hosted service for faster inference
11 | - **Streaming**: Stream responses token by token for real-time output
12 | - **Multiple Model Sizes**: Choose between 0.5B and 2B parameter models
13 | - **Multiple Tasks**: Caption images, answer questions, detect objects, and locate points
14 |
15 | ## Installation
16 |
17 | Install the package from PyPI:
18 |
19 | ```bash
20 | pip install moondream==0.0.6
21 | ```
22 |
23 | To install the CPU dependencies for local inference, run:
24 |
25 | ```bash
26 | pip install moondream[cpu]
27 | ```
28 |
29 | To install the GPU dependencies for local inference, run:
30 |
31 | ```bash
32 | # Copy the torch implementation from the root moondream repo into the moondream/torch directory
33 | cp -r moondream/torch clients/python/moondream/torch
34 |
35 | # Install the GPU dependencies
36 | pip install moondream[gpu]
37 | ```
38 |
39 | ## Quick Start
40 |
41 | ### Using Cloud API
42 |
43 | To use Moondream's cloud API, you'll first need an API key. Sign up for a free
44 | account at [console.moondream.ai](https://console.moondream.ai) to get your key.
45 | Once you have your key, you can use it to initialize the client as shown below.
46 |
47 | ```python
48 | import moondream as md
49 | from PIL import Image
50 |
51 | # Initialize with API key
52 | model = md.vl(api_key="your-api-key")
53 |
54 | # Load an image
55 | image = Image.open("path/to/image.jpg")
56 |
57 | # Generate a caption
58 | caption = model.caption(image)["caption"]
59 | print("Caption:", caption)
60 |
61 | # Ask a question
62 | answer = model.query(image, "What's in this image?")["answer"]
63 | print("Answer:", answer)
64 |
65 | # Stream the response
66 | for chunk in model.caption(image, stream=True)["caption"]:
67 |     print(chunk, end="", flush=True)
68 | ```
69 |
70 | ### Using Local Inference
71 |
72 | First, download the model weights. We recommend the int8 weights for most applications:
73 |
74 | | Model | Precision | Download Size | Memory Usage | Download Link |
75 | | -------------- | --------- | ------------- | ------------ | ----------------------------------------------------------------------------------------------------------------------------------------------- |
76 | | Moondream 2B | int8 | 1,733 MiB | 2,624 MiB | [Download](https://huggingface.co/vikhyatk/moondream2/resolve/9dddae84d54db4ac56fe37817aeaeb502ed083e2/moondream-2b-int8.mf.gz?download=true) |
77 | | Moondream 2B | int4 | 1,167 MiB | 2,002 MiB | [Download](https://huggingface.co/vikhyatk/moondream2/resolve/9dddae84d54db4ac56fe37817aeaeb502ed083e2/moondream-2b-int4.mf.gz?download=true) |
78 | | Moondream 0.5B | int8 | 593 MiB | 996 MiB | [Download](https://huggingface.co/vikhyatk/moondream2/resolve/9dddae84d54db4ac56fe37817aeaeb502ed083e2/moondream-0_5b-int8.mf.gz?download=true) |
79 | | Moondream 0.5B | int4 | 422 MiB | 816 MiB | [Download](https://huggingface.co/vikhyatk/moondream2/resolve/9dddae84d54db4ac56fe37817aeaeb502ed083e2/moondream-0_5b-int4.mf.gz?download=true) |
80 |
81 | Then use the model locally:
82 |
83 | ```python
84 | import moondream as md
85 | from PIL import Image
86 |
87 | # Initialize with local model path
88 | model = md.vl(model="path/to/moondream-2b-int8.mf")
89 |
90 | # Load and encode image
91 | image = Image.open("path/to/image.jpg")
92 |
93 | # Since encoding an image is computationally expensive, you can encode it once
94 | # and reuse the encoded version for multiple queries/captions/etc. This avoids
95 | # having to re-encode the same image multiple times.
96 | encoded_image = model.encode_image(image)
97 |
98 | # Generate caption
99 | caption = model.caption(encoded_image)["caption"]
100 | print("Caption:", caption)
101 |
102 | # Ask questions
103 | answer = model.query(encoded_image, "What's in this image?")["answer"]
104 | print("Answer:", answer)
105 | ```
106 |
107 | ## API Reference
108 |
109 | ### Constructor
110 |
111 | ```python
112 | model = md.vl(
113 | model="path/to/model.bin", # For local inference
114 | api_key="your-api-key" # For cloud API access
115 | )
116 | ```
117 |
118 | ### Methods
119 |
120 | #### caption(image, length="normal", stream=False, settings=None)
121 |
122 | Generate a caption for an image.
123 |
124 | ```python
125 | result = model.caption(image)
126 | # or with streaming
127 | for chunk in model.caption(image, stream=True)["caption"]:
128 | print(chunk, end="")
129 | ```
130 |
131 | #### query(image, question, stream=False, settings=None)
132 |
133 | Ask a question about an image.
134 |
135 | ```python
136 | result = model.query(image, "What's in this image?")
137 | # or with streaming
138 | for chunk in model.query(image, "What's in this image?", stream=True)["answer"]:
139 | print(chunk, end="")
140 | ```
141 |
142 | #### detect(image, object)
143 |
144 | Detect and locate specific objects in an image.
145 |
146 | ```python
147 | result = model.detect(image, "car")
148 | ```
149 |
150 | #### point(image, object)
151 |
152 | Get coordinates of specific objects in an image.
153 |
154 | ```python
155 | result = model.point(image, "person")
156 | ```
157 |
158 | ### Input Types
159 |
160 | - Images can be provided as:
161 | - PIL.Image.Image objects
162 | - Encoded image objects (from model.encode_image())
163 |
164 | ### Response Types
165 |
166 | All methods return typed dictionaries:
167 |
168 | - CaptionOutput: `{"caption": str | Generator}`
169 | - QueryOutput: `{"answer": str | Generator}`
170 | - DetectOutput: `{"objects": List[Region]}`
171 | - PointOutput: `{"points": List[Point]}`
172 |
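For instance, detection and pointing results can be consumed directly as plain dictionaries. A minimal sketch; the file paths and the "face"/"person" prompts are just placeholders:

```python
import moondream as md
from PIL import Image

model = md.vl(model="path/to/moondream-2b-int8.mf")
image = Image.open("path/to/image.jpg")

# Encode once and reuse the encoded image across multiple calls
encoded = model.encode_image(image)

# Each Region is a dict with x_min/y_min/x_max/y_max coordinates
for region in model.detect(encoded, "face")["objects"]:
    print(region["x_min"], region["y_min"], region["x_max"], region["y_max"])

# Each Point is a dict with x/y coordinates
for point in model.point(encoded, "person")["points"]:
    print(point["x"], point["y"])
```
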
173 | ## Performance Notes
174 |
175 | - Local inference currently only supports CPU execution
176 | - CUDA (GPU) and MPS (Apple Silicon) support coming soon
177 | - For optimal performance with GPU/MPS, use the PyTorch implementation for now
178 |
179 | ## Development Notes
180 |
181 | - Copy the torch implementation from the root moondream repo into the `torch` directory
182 | - Run `poetry install --extras "gpu"` to install the GPU dependencies
183 | - Run `poetry install --extras "cpu"` to install the CPU dependencies
184 |
185 | ## Links
186 |
187 | - [Website](https://moondream.ai/)
188 | - [Demo](https://moondream.ai/playground)
--------------------------------------------------------------------------------
/clients/python/moondream/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from .cloud_vl import CloudVL
4 | from .types import VLM
5 |
6 | DEFAULT_API_URL = "https://api.moondream.ai/v1"
7 |
8 |
9 | def vl(
10 |     *,
11 |     model: Optional[str] = None,
12 |     api_key: Optional[str] = None,
13 |     api_url: Optional[str] = None,
14 | ) -> VLM:
15 |     if model:
16 |         model_filetype = model.split(".")[-1]
17 |         if model_filetype == "safetensors":
18 |             from .torch_vl import TorchVL
19 |
20 |             return TorchVL(model=model)
21 |         elif model_filetype == "mf":
22 |             from .onnx_vl import OnnxVL
23 |
24 |             return OnnxVL.from_path(model)
25 |
26 |         raise ValueError(
27 |             "Unsupported model filetype. Please use a .safetensors model for GPU use or .mf model for CPU use."
28 |         )
29 |
30 |     if api_key:
31 |         if not api_url:
32 |             api_url = DEFAULT_API_URL
33 |
34 |         return CloudVL(api_key=api_key, api_url=api_url)
35 |
36 |     if api_url and api_url == DEFAULT_API_URL:
37 |         if not api_key:
38 |             raise ValueError("An api_key is required for cloud inference.")
39 |
40 |         return CloudVL(api_url=api_url)
41 |
42 |     raise ValueError("At least one of `model`, `api_key`, or `api_url` is required.")
43 |
--------------------------------------------------------------------------------
/clients/python/moondream/cli.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys
3 | from http import server
4 |
5 | from .onnx_vl import OnnxVL
6 | from .server import MoondreamHandler
7 |
8 |
9 | def main():
10 |     parser = argparse.ArgumentParser(description="Moondream CLI")
11 |     subparsers = parser.add_subparsers(dest="command", help="Command to run")
12 |
13 |     # Server command
14 |     server_parser = subparsers.add_parser("serve", help="Start the Moondream server")
15 |     server_parser.add_argument("--model", type=str, help="Path to the model file")
16 |     server_parser.add_argument(
17 |         "--host", type=str, default="localhost", help="Host to bind to"
18 |     )
19 |     server_parser.add_argument(
20 |         "--port", type=int, default=3475, help="Port to listen on"
21 |     )
22 |
23 |     args = parser.parse_args()
24 |
25 |     if args.command == "serve":
26 |         if args.model:
27 |             model = OnnxVL.from_path(args.model)
28 |         else:
29 |             parser.error("Model path is required")
30 |
31 |         MoondreamHandler.model = model
32 |         server_address = (args.host, args.port)
33 |         try:
34 |             httpd = server.HTTPServer(server_address, MoondreamHandler)
35 |             print(f"Starting Moondream server on http://{args.host}:{args.port}")
36 |             httpd.serve_forever()
37 |         except KeyboardInterrupt:
38 |             print("\nShutting down server...")
39 |             httpd.server_close()
40 |         except Exception as e:
41 |             print(f"Error: {e}", file=sys.stderr)
42 |             sys.exit(1)
43 |     else:
44 |         parser.print_help()
45 |
46 |
47 | if __name__ == "__main__":
48 |     main()
49 |
--------------------------------------------------------------------------------
/clients/python/moondream/cloud_vl.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import json
3 | import urllib.request
4 | from io import BytesIO
5 | from typing import Literal, Optional, Union
6 |
7 | from PIL import Image
8 |
9 | from .types import (
10 | VLM,
11 | Base64EncodedImage,
12 | CaptionOutput,
13 | DetectOutput,
14 | EncodedImage,
15 | PointOutput,
16 | QueryOutput,
17 | SamplingSettings,
18 | )
19 | from .version import __version__
20 |
21 |
22 | class CloudVL(VLM):
23 | def __init__(
24 | self,
25 | *,
26 | api_url: str = "https://api.moondream.ai/v1",
27 | api_key: Optional[str] = None,
28 | ):
29 | self.api_key = api_key
30 | self.api_url = api_url
31 |
32 | def encode_image(
33 | self, image: Union[Image.Image, EncodedImage]
34 | ) -> Base64EncodedImage:
35 | if isinstance(image, EncodedImage):
36 | assert type(image) == Base64EncodedImage
37 | return image
38 | try:
39 | width, height = image.size
40 | max_size = 768
41 | scale = max_size / max(width, height)
42 | if scale < 1:
43 | new_size = (int(width * scale), int(height * scale))
44 | image = image.resize(new_size, Image.Resampling.LANCZOS)
45 |
46 | if image.mode != "RGB":
47 | image = image.convert("RGB")
48 | buffered = BytesIO()
49 | image.save(buffered, format="JPEG", quality=95)
50 | img_str = base64.b64encode(buffered.getvalue()).decode()
51 | return Base64EncodedImage(image_url=f"data:image/jpeg;base64,{img_str}")
52 | except Exception as e:
53 | raise ValueError("Failed to convert image to JPEG.") from e
54 |
55 | def _stream_response(self, req):
56 | """Helper function to stream response chunks from the API."""
57 | with urllib.request.urlopen(req) as response:
58 | for line in response:
59 | if not line:
60 | continue
61 | line = line.decode("utf-8")
62 | if line.startswith("data: "):
63 | try:
64 | data = json.loads(line[6:])
65 | if "chunk" in data:
66 | yield data["chunk"]
67 | if data.get("completed"):
68 | break
69 | except json.JSONDecodeError as e:
70 | raise ValueError(
71 | "Failed to parse JSON response from server."
72 | ) from e
73 |
74 | def caption(
75 | self,
76 | image: Union[Image.Image, EncodedImage],
77 | length: Literal["normal", "short"] = "normal",
78 | stream: bool = False,
79 | settings: Optional[SamplingSettings] = None,
80 | ) -> CaptionOutput:
81 | encoded_image = self.encode_image(image)
82 | payload = {
83 | "image_url": encoded_image.image_url,
84 | "length": length,
85 | "stream": stream,
86 | }
87 |
88 | data = json.dumps(payload).encode("utf-8")
89 | headers = {
90 | "Content-Type": "application/json",
91 | "User-Agent": f"moondream-python/{__version__}",
92 | }
93 | if self.api_key:
94 | headers["X-Moondream-Auth"] = self.api_key
95 | req = urllib.request.Request(
96 | f"{self.api_url}/caption",
97 | data=data,
98 | headers=headers,
99 | )
100 |
101 | def generator():
102 | for chunk in self._stream_response(req):
103 | yield chunk
104 |
105 | if stream:
106 | return {"caption": generator()}
107 |
108 | with urllib.request.urlopen(req) as response:
109 | result = json.loads(response.read().decode("utf-8"))
110 | return {"caption": result["caption"]}
111 |
112 | def query(
113 | self,
114 | image: Union[Image.Image, EncodedImage],
115 | question: str,
116 | stream: bool = False,
117 | settings: Optional[SamplingSettings] = None,
118 | ) -> QueryOutput:
119 | encoded_image = self.encode_image(image)
120 | payload = {
121 | "image_url": encoded_image.image_url,
122 | "question": question,
123 | "stream": stream,
124 | # TODO: Pass sampling settings like max_tokens to the API.
125 | }
126 |
127 | data = json.dumps(payload).encode("utf-8")
128 | headers = {
129 | "Content-Type": "application/json",
130 | "User-Agent": f"moondream-python/{__version__}",
131 | }
132 | if self.api_key:
133 | headers["X-Moondream-Auth"] = self.api_key
134 | req = urllib.request.Request(
135 | f"{self.api_url}/query",
136 | data=data,
137 | headers=headers,
138 | )
139 |
140 | if stream:
141 | return {"answer": self._stream_response(req)}
142 |
143 | with urllib.request.urlopen(req) as response:
144 | result = json.loads(response.read().decode("utf-8"))
145 | return {"answer": result["answer"]}
146 |
147 | def detect(
148 | self,
149 | image: Union[Image.Image, EncodedImage],
150 | object: str,
151 | ) -> DetectOutput:
152 | encoded_image = self.encode_image(image)
153 | payload = {"image_url": encoded_image.image_url, "object": object}
154 |
155 | data = json.dumps(payload).encode("utf-8")
156 | headers = {
157 | "Content-Type": "application/json",
158 | "User-Agent": f"moondream-python/{__version__}",
159 | }
160 | if self.api_key:
161 | headers["X-Moondream-Auth"] = self.api_key
162 | req = urllib.request.Request(
163 | f"{self.api_url}/detect",
164 | data=data,
165 | headers=headers,
166 | )
167 |
168 | with urllib.request.urlopen(req) as response:
169 | result = json.loads(response.read().decode("utf-8"))
170 | return {"objects": result["objects"]}
171 |
172 | def point(
173 | self,
174 | image: Union[Image.Image, EncodedImage],
175 | object: str,
176 | ) -> PointOutput:
177 | encoded_image = self.encode_image(image)
178 | payload = {"image_url": encoded_image.image_url, "object": object}
179 |
180 | data = json.dumps(payload).encode("utf-8")
181 | headers = {
182 | "Content-Type": "application/json",
183 | "User-Agent": f"moondream-python/{__version__}",
184 | }
185 | if self.api_key:
186 | headers["X-Moondream-Auth"] = self.api_key
187 | req = urllib.request.Request(
188 | f"{self.api_url}/point",
189 | data=data,
190 | headers=headers,
191 | )
192 |
193 | with urllib.request.urlopen(req) as response:
194 | result = json.loads(response.read().decode("utf-8"))
195 | return {"points": result["points"]}
196 |
--------------------------------------------------------------------------------
/clients/python/moondream/moonfile.py:
--------------------------------------------------------------------------------
1 | import struct
2 | import gzip
3 | from typing import BinaryIO, Tuple, Iterator, Union
4 |
5 | MOON_MAGIC = b"MOON"
6 | MOON_VERSION = 1
7 |
8 |
9 | class MoonReader:
10 |     def __init__(self, input_path: str):
11 |         self.input_path = input_path
12 |
13 |     def _get_file_handle(self) -> Union[BinaryIO, gzip.GzipFile]:
14 |         """Returns appropriate file handle based on extension"""
15 |         if self.input_path.endswith(".gz"):
16 |             return gzip.open(self.input_path, "rb")
17 |         return open(self.input_path, "rb")
18 |
19 |     def _validate_header(self, f: Union[BinaryIO, gzip.GzipFile]) -> None:
20 |         """Validate magic bytes and version"""
21 |         magic = f.read(4)
22 |         if magic != MOON_MAGIC:
23 |             raise ValueError(f"Invalid magic bytes: {magic}")
24 |
25 |         version = struct.unpack("!B", f.read(1))[0]
26 |         if version != MOON_VERSION:
27 |             raise ValueError(f"Unsupported version: {version}")
28 |
29 |     def read_files(self) -> Iterator[Tuple[str, bytes]]:
30 |         """Read and yield (filename, content) pairs from the archive"""
31 |         with self._get_file_handle() as f:
32 |             self._validate_header(f)
33 |
34 |             while True:
35 |                 # Try to read filename length
36 |                 filename_len_bytes = f.read(4)
37 |                 if not filename_len_bytes:
38 |                     break  # End of file
39 |
40 |                 filename_len = struct.unpack("!I", filename_len_bytes)[0]
41 |
42 |                 # Read filename
43 |                 filename = f.read(filename_len).decode("utf-8")
44 |
45 |                 # Read content length and content
46 |                 content_len = struct.unpack("!Q", f.read(8))[0]
47 |                 content = f.read(content_len)
48 |
49 |                 yield filename, content
50 |
51 |
52 | def unpack(input_path: str) -> Iterator[Tuple[str, bytes]]:
53 |     """Unpack a .mf file"""
54 |     reader = MoonReader(input_path)
55 |     for filename, content in reader.read_files():
56 |         yield filename, content
57 |
--------------------------------------------------------------------------------
/clients/python/moondream/preprocess.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple
2 |
3 | import numpy as np
4 | from PIL import Image
5 |
6 |
7 | def im_resize(
8 | image: Image.Image,
9 | size: Tuple[int, int],
10 | resample: int = Image.Resampling.BICUBIC,
11 | ) -> Image.Image:
12 | return image.resize(size, resample=resample)
13 |
14 |
15 | def adaptive_avg_pool2d(x, output_size):
16 | """Applies 2D adaptive average pooling over an input signal.
17 |
18 | Resizes input to a target size by averaging values in local neighborhoods.
19 | The neighborhoods are computed to evenly cover the input image while
20 | maintaining approximately equal size. Similar to PyTorch's
21 | adaptive_avg_pool2d but expects input shape (H,W,C) rather than (N,C,H,W).
22 |
23 | Args:
24 | x: Input tensor of shape (height, width, channels)
25 | output_size: Target output size. Can be:
26 | - Single integer for square output (size, size)
27 | - Tuple of two ints (out_height, out_width)
28 |
29 | Returns:
30 | Tensor of shape (out_height, out_width, channels)
31 |
32 | Example:
33 | >>> img = np.random.randn(32, 32, 3) # 32x32 RGB image
34 | >>> pooled = adaptive_avg_pool2d(img, (7, 7)) # Resize to 7x7
35 | >>> pooled.shape
36 | (7, 7, 3)
37 | """
38 | height, width, channels = x.shape
39 |
40 | if isinstance(output_size, int):
41 | output_size = (output_size, output_size)
42 | out_height, out_width = output_size
43 |
44 | stride_h = height // out_height
45 | stride_w = width // out_width
46 | kernel_h = height - (out_height - 1) * stride_h
47 | kernel_w = width - (out_width - 1) * stride_w
48 |
49 | output = np.zeros((out_height, out_width, channels), dtype=x.dtype)
50 |
51 | for i in range(out_height):
52 | for j in range(out_width):
53 | h_start = i * stride_h
54 | h_end = h_start + kernel_h
55 | w_start = j * stride_w
56 | w_end = w_start + kernel_w
57 | output[i, j, :] = x[h_start:h_end, w_start:w_end, :].mean(axis=(0, 1))
58 |
59 | return output
60 |
61 |
62 | def normalize(
63 | image: np.ndarray,
64 | mean: List[float] = [0.5, 0.5, 0.5],
65 | std: List[float] = [0.5, 0.5, 0.5],
66 | ) -> np.ndarray:
67 | """
68 | Normalize an image array.
69 | """
70 | return (image - np.array(mean)) / np.array(std)
71 |
72 |
73 | def create_patches(
74 | image: Image.Image, image_patch_size=378
75 | ) -> Tuple[np.ndarray, Tuple[int, int]]:
76 | """
77 | Split the given image into a variable number of patches depending upon its
78 | resolution. Returns the patches as a numpy array, and the selected patching
79 | template as a tuple of (rows, cols).
80 | """
81 | image = image.convert("RGB")
82 |
83 | # Start off with the global patch.
84 | patches = [im_resize(image, (image_patch_size, image_patch_size))]
85 |
86 | # Find the closest resolution template.
87 | #
88 | # (1, 2) (2, 1) (2, 2)
89 | # +-------+-------+ +-----------+ +-------+-------+
90 | # | 1 | 2 | | 1 | | 1 | 2 |
91 | # +-------+-------+ +-----------+ +-------+-------+
92 | # | 2 | | 3 | 4 |
93 | # +-----------+ +-------+-------+
94 | res_templates = [(1, 2), (2, 1), (2, 2)]
95 |
96 | im_width, im_height = image.size
97 | max_dim = max(im_width, im_height)
98 | if max_dim < image_patch_size * 1.4:
99 | # If the image is already small, we avoid adding an extra patch
100 | # here to avoid redundant computation in the vision encoder, and
101 | # instead copy the global patch features after running the vision
102 | # encoder, before passing it through the vision projection.
103 | selected_template = (1, 1)
104 | else:
105 | aspect_ratio = im_width / im_height
106 | selected_template = min(
107 | res_templates, key=lambda size: abs((size[1] / size[0]) - aspect_ratio)
108 | )
109 |
110 | patch_width = im_width // selected_template[1]
111 | patch_height = im_height // selected_template[0]
112 |
113 | for row in range(selected_template[0]):
114 | for col in range(selected_template[1]):
115 | x_min = col * patch_width
116 | y_min = row * patch_height
117 | x_max = x_min + patch_width
118 | y_max = y_min + patch_height
119 | patches.append(
120 | im_resize(
121 | image.crop((x_min, y_min, x_max, y_max)),
122 | (image_patch_size, image_patch_size),
123 | )
124 | )
125 |
126 | return (
127 | np.stack(
128 | [
129 | normalize(
130 | (np.array(patch_img) / 255.0),
131 | mean=[0.5, 0.5, 0.5],
132 | std=[0.5, 0.5, 0.5],
133 | ).transpose(2, 0, 1)
134 | for patch_img in patches
135 | ],
136 | dtype=np.float16,
137 | ),
138 | selected_template,
139 | )
140 |
--------------------------------------------------------------------------------
/clients/python/moondream/torch_vl.py:
--------------------------------------------------------------------------------
1 | from typing import Literal, Optional, Union
2 |
3 | import torch
4 | from PIL import Image
5 |
6 | from .torch.moondream import MoondreamConfig, MoondreamModel
7 | from .torch.weights import load_weights_into_model
8 | from .types import (
9 | VLM,
10 | Base64EncodedImage,
11 | CaptionOutput,
12 | DetectOutput,
13 | EncodedImage,
14 | PointOutput,
15 | QueryOutput,
16 | SamplingSettings,
17 | )
18 | from .version import __version__
19 |
20 |
21 | class TorchVL(VLM):
22 | def __init__(
23 | self,
24 | *,
25 | model: str,
26 | ):
27 | config = MoondreamConfig()
28 | self.model = MoondreamModel(config)
29 | load_weights_into_model(model, self.model)
30 | self.model.eval()
31 | # Move model to the appropriate device
32 | if torch.cuda.is_available():
33 | self.device = "cuda"
34 | elif torch.backends.mps.is_available():
35 | self.device = "mps"
36 | else:
37 | self.device = "cpu"
38 | self.model.to(self.device)
39 |
40 | def encode_image(
41 | self, image: Union[Image.Image, EncodedImage]
42 | ) -> Base64EncodedImage:
43 | if isinstance(image, EncodedImage):
44 | assert type(image) == Base64EncodedImage
45 | return image
46 |
47 | if not self.model:
48 | raise ValueError("No local model loaded")
49 |
50 | return self.model.encode_image(image)
51 |
52 | def caption(
53 | self,
54 | image: Union[Image.Image, EncodedImage],
55 | length: Literal["normal", "short"] = "normal",
56 | stream: bool = False,
57 | settings: Optional[SamplingSettings] = None,
58 | ) -> CaptionOutput:
59 | if not self.model:
60 | raise ValueError("No local model loaded")
61 |
62 | encoded_image = (
63 | self.model.encode_image(image) if isinstance(image, Image.Image) else image
64 | )
65 | return self.model.caption(
66 | encoded_image, length=length, stream=stream, settings=settings
67 | )
68 |
69 | def query(
70 | self,
71 | image: Union[Image.Image, EncodedImage],
72 | question: str,
73 | stream: bool = False,
74 | settings: Optional[SamplingSettings] = None,
75 | ) -> QueryOutput:
76 | if not self.model:
77 | raise ValueError("No local model loaded")
78 |
79 | encoded_image = (
80 | self.model.encode_image(image) if isinstance(image, Image.Image) else image
81 | )
82 | return self.model.query(
83 | encoded_image, question, stream=stream, settings=settings
84 | )
85 |
86 | def detect(
87 | self,
88 | image: Union[Image.Image, EncodedImage],
89 | object: str,
90 | ) -> DetectOutput:
91 | if not self.model:
92 | raise ValueError("No local model loaded")
93 |
94 | encoded_image = (
95 | self.model.encode_image(image) if isinstance(image, Image.Image) else image
96 | )
97 | return self.model.detect(encoded_image, object)
98 |
99 | def point(
100 | self,
101 | image: Union[Image.Image, EncodedImage],
102 | object: str,
103 | ) -> PointOutput:
104 | if not self.model:
105 | raise ValueError("No local model loaded")
106 |
107 | encoded_image = (
108 | self.model.encode_image(image) if isinstance(image, Image.Image) else image
109 | )
110 | return self.model.point(encoded_image, object)
111 |
--------------------------------------------------------------------------------
/clients/python/moondream/types.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from abc import ABC, abstractmethod
3 | from PIL import Image
4 | from dataclasses import dataclass
5 | from typing import Generator, List, TypedDict, Union, Optional, Literal
6 |
7 |
8 | @dataclass
9 | class EncodedImage(ABC):
10 | pass
11 |
12 |
13 | @dataclass
14 | class OnnxEncodedImage(EncodedImage):
15 | pos: int
16 | kv_cache: np.ndarray
17 |
18 |
19 | @dataclass
20 | class Base64EncodedImage(EncodedImage):
21 | image_url: str
22 |
23 |
24 | SamplingSettings = TypedDict(
25 | "SamplingSettings",
26 | {"max_tokens": int},
27 | total=False,
28 | )
29 |
30 | CaptionOutput = TypedDict(
31 | "CaptionOutput", {"caption": Union[str, Generator[str, None, None]]}
32 | )
33 |
34 | QueryOutput = TypedDict(
35 | "QueryOutput", {"answer": Union[str, Generator[str, None, None]]}
36 | )
37 |
38 | Region = TypedDict(
39 |     "Region", {"x_min": float, "y_min": float, "x_max": float, "y_max": float}
40 | )
41 | DetectOutput = TypedDict("DetectOutput", {"objects": List[Region]})
42 |
43 | Point = TypedDict("Point", {"x": float, "y": float})
44 | PointOutput = TypedDict("PointOutput", {"points": List[Point]})
45 |
46 |
47 | class VLM(ABC):
48 | @abstractmethod
49 | def encode_image(self, image: Union[Image.Image, EncodedImage]) -> EncodedImage:
50 | """
51 | Preprocess the image by running it through the model. Only supported for local
52 | inference.
53 |
54 | This method is useful if the user wants to make multiple queries with the same image.
55 | The output is not guaranteed to be backward-compatible across version updates,
56 | and should not be persisted out of band.
57 |
58 | Args:
59 |             image (Union[Image.Image, EncodedImage]): The input image to be encoded.
60 |
61 | Returns:
62 | The encoded representation of the image.
63 | """
64 |
65 | @abstractmethod
66 | def caption(
67 | self,
68 | image: Union[Image.Image, EncodedImage],
69 | length: Literal["normal", "short"] = "normal",
70 | stream: bool = False,
71 | settings: Optional[SamplingSettings] = None,
72 | ) -> CaptionOutput:
73 | """
74 | Generate a caption for the input image.
75 |
76 | Args:
77 | image (Union[Image.Image, EncodedImage]): The input image to be captioned.
78 | length (str): Length of caption to generate. Can be "normal" or "short".
79 | Defaults to "normal".
80 | stream (bool): If True, returns a generator that streams the output tokens.
81 | Defaults to False.
82 | settings (Optional[SamplingSettings]): Optional settings for the caption
83 | generation. If not provided, default settings will be used.
84 |
85 | Returns:
86 | CaptionOutput: A dictionary containing the 'caption' field with either a string
87 | or generator that yields strings for the caption.
88 | """
89 |
90 | @abstractmethod
91 | def query(
92 | self,
93 | image: Union[Image.Image, EncodedImage],
94 | question: str,
95 | stream: bool = False,
96 | settings: Optional[SamplingSettings] = None,
97 | ) -> QueryOutput:
98 | """
99 | Generate an answer to the input question about the input image.
100 |
101 | Args:
102 | image (Union[Image.Image, EncodedImage]): The input image to be queried.
103 | question (str): The question to be answered.
104 | stream (bool): If True, returns a generator that streams the output tokens.
105 |                 Defaults to False.
106 | settings (Optional[SamplingSettings]): Optional settings for the query
107 | generation.
108 |
109 | Returns:
110 | QueryOutput: A dictionary containing the 'answer' field with either a string
111 | or generator that yields strings for the response.
112 | """
113 |
114 | @abstractmethod
115 | def detect(
116 | self,
117 | image: Union[Image.Image, EncodedImage],
118 | object: str,
119 | ) -> DetectOutput:
120 | """
121 | Detect and localize the specified object in the input image.
122 |
123 | Args:
124 | image (Union[Image.Image, EncodedImage]): The input image to be analyzed.
125 | object (str): The object to be detected in the image.
126 |
127 | Returns:
128 | DetectOutput: A dictionary containing:
129 | 'objects' (List[Region]): List of detected object regions, where each
130 | Region has:
131 | - x_min (float): Left boundary of detection box
132 | - y_min (float): Top boundary of detection box
133 | - x_max (float): Right boundary of detection box
134 | - y_max (float): Bottom boundary of detection box
135 | """
136 |
137 | @abstractmethod
138 | def point(
139 | self,
140 | image: Union[Image.Image, EncodedImage],
141 | object: str,
142 | ) -> PointOutput:
143 | """
144 |         Point out all instances of the given object in the input image.
145 |
146 | Args:
147 | image (Union[Image.Image, EncodedImage]): The input image to be analyzed for
148 | pointing out objects.
149 | object (str): The object type to be pointed out in the image.
150 |
151 | Returns:
152 | PointOutput: A dictionary containing:
153 | 'points' (List[Point]): List of detected points, where each Point has:
154 | - x (float): X coordinate of the point marking the object
155 | - y (float): Y coordinate of the point marking the object
156 |
157 | This method identifies instances of the specified object in the image and returns
158 | a list of coordinates marking the location of each instance found. Each point
159 | indicates the approximate center or most relevant position for that object
160 | instance.
161 | """
162 |
--------------------------------------------------------------------------------
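A short sketch of how the VLM interface above behaves from the caller's side, assuming a model handle obtained from the package's vl() factory (the model path is hypothetical). With stream=False the 'caption' / 'answer' fields are plain strings; with stream=True they are generators of string chunks, matching the Union types declared in CaptionOutput and QueryOutput. This mirrors the behavior exercised by the test scripts later in this repo.

from PIL import Image
import moondream as md

model = md.vl(model="path/to/model.mf")      # hypothetical local model file
image = Image.open("assets/demo-1.jpg")

encoded = model.encode_image(image)          # encode once, reuse for several calls

print(model.caption(encoded, length="short")["caption"])   # str
print(model.detect(encoded, "burger")["objects"])          # list of Region dicts
print(model.point(encoded, "person")["points"])            # list of Point dicts

for chunk in model.query(encoded, "What is in this image?", stream=True)["answer"]:
    print(chunk, end="", flush=True)         # generator of string chunks
print()
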
/clients/python/moondream/version.py:
--------------------------------------------------------------------------------
1 | from importlib import metadata
2 |
3 | try:
4 | __version__ = metadata.version("moondream")
5 | except Exception:
6 | __version__ = "unknown"
7 |
--------------------------------------------------------------------------------
/clients/python/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = [ "poetry-core",]
3 | build-backend = "poetry.core.masonry.api"
4 |
5 | [tool.poetry]
6 | name = "moondream"
7 | version = "0.0.2"
8 | description = "Python client library for moondream"
9 | authors = [ "M87 Labs ",]
10 | readme = "README.md"
11 | [[tool.poetry.packages]]
12 | include = "moondream"
13 | from = "."
14 |
15 | [tool.pyright]
16 | venvPath = "."
17 | venv = ".venv"
18 | reportMissingParameterType = false
19 |
20 | [tool.poetry.dependencies]
21 | python = "^3.10"
22 | pillow = "^10.4.0"
23 | numpy = "^2.1.2"
24 | onnxruntime = { version = ">=1.19.2", optional = true }
25 | tokenizers = { version = ">=0.20.1", optional = true }
26 | torch = { version = ">=2.5.0", optional = true }
27 | safetensors = { version = ">=0.4.2", optional = true }
28 | einops = { version = ">=0.7.0", optional = true }
29 | pyvips-binary = { version = ">=8.16.0", optional = true }
30 | pyvips = { version = ">=2.2.1", optional = true }
31 |
32 | [tool.poetry.extras]
33 | cpu = [
34 | "onnxruntime",
35 | "tokenizers"
36 | ]
37 | gpu = [
38 | "torch",
39 | "safetensors",
40 | "einops",
41 | "pyvips-binary",
42 | "pyvips",
43 | "tokenizers"
44 | ]
45 |
46 | [tool.poetry.scripts]
47 | moondream = "moondream.cli:main"
48 |
--------------------------------------------------------------------------------
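The extras above split the client into an ONNX-based CPU path (onnxruntime + tokenizers) and a torch-based GPU path (torch, safetensors, einops, pyvips, tokenizers). A hedged sketch of probing which optional backend is importable at runtime; the helper name is illustrative and not part of the package.

# Illustrative only: checks which optional backend from the extras table is
# installed (e.g. via `pip install "moondream[cpu]"` or `pip install "moondream[gpu]"`,
# assuming the published package exposes these extras).
import importlib.util


def available_backends() -> list[str]:
    backends = []
    if importlib.util.find_spec("onnxruntime") is not None:
        backends.append("cpu (onnxruntime)")
    if importlib.util.find_spec("torch") is not None:
        backends.append("gpu (torch)")
    return backends


print(available_backends())
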
/clients/python/scripts/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | import tracemalloc
4 |
5 | from PIL import Image, ImageDraw
6 |
7 | import moondream as md
8 |
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument("--model-path", type=str, required=True)
11 | args = parser.parse_args()
12 |
13 |
14 | class Colors:
15 | HEADER = "\033[95m" # Purple
16 | BLUE = "\033[94m"
17 | GREEN = "\033[92m"
18 | YELLOW = "\033[93m"
19 | RED = "\033[91m"
20 | ENDC = "\033[0m"
21 | BOLD = "\033[1m"
22 |
23 |
24 | def format_memory(memory_mb):
25 |     """Format a memory size in MiB"""
26 | return f"{memory_mb:.2f} MiB"
27 |
28 |
29 | def print_section(title):
30 | """Print a section header with dynamic padding to center the text"""
31 | total_width = 65
32 | text_length = len(title) + 2 # Add 2 for spaces around title
33 | total_padding = total_width - text_length
34 | left_padding = total_padding // 2
35 | right_padding = total_padding - left_padding
36 | print(
37 | f"\n{Colors.HEADER}{Colors.BOLD}{'-'*left_padding} {title} {'-'*right_padding}{Colors.ENDC}"
38 | )
39 |
40 |
41 | def print_metric(label, value, color=Colors.BLUE):
42 | """Print a metric with consistent formatting"""
43 | print(f"| {color}{label}{Colors.ENDC}: {value}")
44 |
45 |
46 | def log_memory_and_time(operation_name, start_time, start_memory):
47 | """Log memory and time differences for an operation"""
48 | end_time = time.time()
49 | current_memory = get_memory_usage()
50 | time_diff = end_time - start_time
51 | memory_diff = current_memory - start_memory
52 |
53 | print("\nStats")
54 | print_metric("Time", f"{time_diff:.2f} seconds")
55 | print_metric("Memory usage", format_memory(current_memory))
56 |
57 | # Color-code memory increase based on significance
58 | color = (
59 | Colors.GREEN
60 | if memory_diff < 10
61 | else Colors.YELLOW if memory_diff < 100 else Colors.RED
62 | )
63 | print_metric("Memory increase", format_memory(memory_diff), color)
64 |
65 | return end_time, current_memory
66 |
67 |
68 | def get_memory_usage():
69 | """Get current memory usage in MiB"""
70 | current, peak = tracemalloc.get_traced_memory()
71 | return current / 1024 / 1024
72 |
73 |
74 | # Start tracking memory
75 | tracemalloc.start()
76 |
77 | # Initial memory measurement
78 | initial_memory = get_memory_usage()
79 | print_section("Initial State")
80 | print_metric("Initial memory usage", format_memory(initial_memory))
81 |
82 | # Load image
83 | print_section("Image Loading")
84 | start_time = time.time()
85 | start_memory = get_memory_usage()
86 | image = Image.open("../../assets/demo-1.jpg")
87 | log_memory_and_time("Image Loading", start_time, start_memory)
88 |
89 | # Initialize model
90 | print_section("Model Initialization")
91 | start_time = time.time()
92 | start_memory = get_memory_usage()
93 | model = md.vl(model=args.model_path)
94 | log_memory_and_time("Model Initialization", start_time, start_memory)
95 |
96 | # Encode image
97 | print_section("Image Encoding")
98 | start_time = time.time()
99 | start_memory = get_memory_usage()
100 | encoded_image = model.encode_image(image)
101 | log_memory_and_time("Image Encoding", start_time, start_memory)
102 |
103 | # Generate caption
104 | print_section("Caption Generation")
105 | print(f"{Colors.BOLD}Caption:{Colors.ENDC}", end="", flush=True)
106 | start_time = time.time()
107 | start_memory = get_memory_usage()
108 | tokens = 0
109 | for tok in model.caption(encoded_image, stream=True)["caption"]:
110 | print(tok, end="", flush=True)
111 | tokens += 1
112 | print()
113 | end_time, end_memory = log_memory_and_time("Caption Stats", start_time, start_memory)
114 | print_metric("Token generation speed", f"{tokens / (end_time - start_time):.2f} tok/s")
115 |
116 | # Generate answer to question
117 | question = "How many people are in this image? Answer briefly."
118 | print_section("Question Answering")
119 | print(f"{Colors.BOLD}Question:{Colors.ENDC} {question}")
120 | print(f"{Colors.BOLD}Answer:{Colors.ENDC}", end="", flush=True)
121 | start_time = time.time()
122 | start_memory = get_memory_usage()
123 | tokens = 0
124 | for tok in model.query(encoded_image, question, stream=True)["answer"]:
125 | print(tok, end="", flush=True)
126 | tokens += 1
127 | print()
128 | end_time, end_memory = log_memory_and_time(
129 | "Question Answering Stats", start_time, start_memory
130 | )
131 | print_metric("Token generation speed", f"{tokens / (end_time - start_time):.2f} tok/s")
132 |
133 | # Object detection
134 | object = "burger"
135 | print_section("Object Detection")
136 | print(f"{Colors.BOLD}Detect:{Colors.ENDC} {object}")
137 | start_time = time.time()
138 | start_memory = get_memory_usage()
139 | objects = model.detect(encoded_image, object)["objects"]
140 | print(len(objects), "detected")
141 |
142 | # Draw rectangles for each detected object
143 | width, height = image.size
144 | draw = ImageDraw.Draw(image)
145 | for obj in objects:
146 | x_min = int(obj["x_min"] * width)
147 | x_max = int(obj["x_max"] * width)
148 | y_min = int(obj["y_min"] * height)
149 | y_max = int(obj["y_max"] * height)
150 | draw.rectangle([(x_min, y_min), (x_max, y_max)], outline="green", width=2)
151 | image.save("detection_output.jpg")
152 |
153 | end_time, end_memory = log_memory_and_time(
154 | "Object Detection Stats", start_time, start_memory
155 | )
156 |
157 | # Final summary
158 | print_section("Final Summary")
159 | final_memory = get_memory_usage()
160 | current, peak = tracemalloc.get_traced_memory()
161 |
162 | print_metric("Final memory usage", format_memory(final_memory))
163 | print_metric("Total memory increase", format_memory(final_memory - initial_memory))
164 | print_metric("Peak memory usage", format_memory(peak / 1024 / 1024))
165 |
166 | # Stop tracking memory
167 | tracemalloc.stop()
168 |
--------------------------------------------------------------------------------
/clients/python/scripts/test_cloud_parity.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import moondream as md
4 |
5 | from PIL import Image
6 |
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument("--model-path", type=str, required=True)
9 | args = parser.parse_args()
10 |
11 | local = md.vl(model=args.model_path)
12 | cloud = md.vl(api_key=os.environ["MOONDREAM_API_KEY"])
13 |
14 | image_path = "../../assets/demo-1.jpg"
15 | image = Image.open(image_path)
16 |
17 | print("# Pointing")
18 | object = "person"
19 | print("Local:", local.point(image, object))
20 | print("Cloud:", cloud.point(image, object))
21 |
22 | print("# Captioning")
23 | print("Local:", local.caption(image))
24 | print("Cloud:", cloud.caption(image))
25 |
26 | print("# Querying")
27 | question = "What is the character eating?"
28 | print("Local:", local.query(image, question))
29 | print("Cloud:", cloud.query(image, question))
30 |
31 | print("# Detecting")
32 | object_to_detect = "burger"
33 | print("Local:", local.detect(image, object_to_detect))
34 | print("Cloud:", cloud.detect(image, object_to_detect))
35 |
36 | print("# Streaming Caption")
37 | print("Local:")
38 | for tok in local.caption(image, stream=True)["caption"]:
39 | print(tok, end="", flush=True)
40 | print()
41 | print("Cloud:")
42 | for tok in cloud.caption(image, stream=True)["caption"]:
43 | print(tok, end="", flush=True)
44 | print()
45 |
46 | print("# Streaming Query")
47 | print("Local:")
48 | for tok in local.query(image, question, stream=True)["answer"]:
49 | print(tok, end="", flush=True)
50 | print()
51 | print("Cloud:")
52 | for tok in cloud.query(image, question, stream=True)["answer"]:
53 | print(tok, end="", flush=True)
54 | print()
55 |
--------------------------------------------------------------------------------
/clients/python/scripts/test_local_server.py:
--------------------------------------------------------------------------------
1 | import moondream as md
2 | from PIL import Image
3 |
4 | base_url = "http://localhost:3475"
5 | local_server_client = md.vl(api_url=base_url)
6 |
7 | image_path = "../../assets/demo-1.jpg"
8 | image = Image.open(image_path)
9 |
10 | print("# Pointing")
11 | object = "person"
12 | print("Local Server:", local_server_client.point(image, object))
13 |
14 | print("# Captioning")
15 | print("Local Server:", local_server_client.caption(image))
16 |
17 | print("# Querying")
18 | question = "What is the character eating?"
19 | print("Local Server:", local_server_client.query(image, question))
20 |
21 | print("# Detecting")
22 | object_to_detect = "burger"
23 | print("Local Server:", local_server_client.detect(image, object_to_detect))
24 |
25 | print("# Captioning Stream")
26 | print("Local Server:")
27 | for tok in local_server_client.caption(image, stream=True)["caption"]:
28 | print(tok, end="", flush=True)
29 | print()
30 |
31 | print("# Querying Stream")
32 | print("Local Server:")
33 | for tok in local_server_client.query(image, question, stream=True)["answer"]:
34 | print(tok, end="", flush=True)
35 | print()
36 |
--------------------------------------------------------------------------------
/clients/python/tests/test_api_inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pytest
3 | from PIL import Image
4 | import moondream as md
5 |
6 | TEST_IMAGE_PATH = os.path.join(
7 | os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))),
8 | "assets",
9 | "demo-1.jpg",
10 | )
11 |
12 |
13 | @pytest.fixture
14 | def model():
15 | api_key = os.getenv("MOONDREAM_API_KEY")
16 | if not api_key:
17 | pytest.skip("MOONDREAM_API_KEY environment variable not set")
18 | return md.vl(api_key=api_key)
19 |
20 |
21 | @pytest.fixture
22 | def test_image():
23 | return Image.open(TEST_IMAGE_PATH)
24 |
25 |
26 | def test_api_initialization(model):
27 | assert model is not None
28 | assert isinstance(model, md.cloud_vl.CloudVL)
29 |
30 |
31 | def test_image_captioning(model, test_image):
32 | # Test normal length caption
33 | result = model.caption(test_image, length="normal")
34 | assert "caption" in result
35 | assert isinstance(result["caption"], str)
36 | assert len(result["caption"]) > 0
37 |
38 | # Test short length caption
39 | result = model.caption(test_image, length="short")
40 | assert "caption" in result
41 | assert isinstance(result["caption"], str)
42 | assert len(result["caption"]) > 0
43 |
44 |
45 | def test_streaming_caption(model, test_image):
46 | result = model.caption(test_image, stream=True)
47 | assert "caption" in result
48 |
49 | # Test that we can iterate over the stream
50 | caption = ""
51 | for chunk in result["caption"]:
52 | assert isinstance(chunk, str)
53 | caption += chunk
54 |
55 | assert len(caption) > 0
56 |
57 |
58 | def test_query_answering(model, test_image):
59 | # Test basic question answering
60 | result = model.query(test_image, "What is in this image?")
61 | assert "answer" in result
62 | assert isinstance(result["answer"], str)
63 | assert len(result["answer"]) > 0
64 |
65 |
66 | def test_streaming_query(model, test_image):
67 | result = model.query(test_image, "What is in this image?", stream=True)
68 | assert "answer" in result
69 |
70 | # Test that we can iterate over the stream
71 | answer = ""
72 | for chunk in result["answer"]:
73 | assert isinstance(chunk, str)
74 | answer += chunk
75 |
76 | assert len(answer) > 0
77 |
78 |
79 | @pytest.mark.skip(
80 | reason="API handles invalid caption lengths differently than local model"
81 | )
82 | def test_invalid_caption_length(model, test_image):
83 | with pytest.raises(ValueError, match="Model does not support caption length"):
84 | model.caption(test_image, length="invalid")
85 |
86 |
87 | def test_missing_api_key():
88 | with pytest.raises(ValueError, match="An api_key is required for cloud inference"):
89 | md.vl(api_url="https://api.moondream.ai/v1")
90 |
--------------------------------------------------------------------------------
/clients/python/tests/test_local_inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pytest
3 | from PIL import Image
4 | import moondream as md
5 |
6 | MODEL_URL = "https://huggingface.co/vikhyatk/moondream2/resolve/9dddae84d54db4ac56fe37817aeaeb502ed083e2/moondream-0_5b-int8.mf.gz"
7 | MODEL_PATH = os.path.join(
8 | os.path.dirname(__file__), "test_data", "moondream-0_5b-int8.mf"
9 | )
10 | TEST_IMAGE_PATH = os.path.join(
11 | os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))),
12 | "assets",
13 | "demo-1.jpg",
14 | )
15 |
16 |
17 | @pytest.fixture(scope="session", autouse=True)
18 | def download_model():
19 | if not os.path.exists(os.path.dirname(MODEL_PATH)):
20 | os.makedirs(os.path.dirname(MODEL_PATH))
21 |
22 | if not os.path.exists(MODEL_PATH):
23 | import requests
24 | import gzip
25 | import io
26 |
27 | # Download the model file
28 | print("Downloading model file...")
29 | response = requests.get(MODEL_URL, stream=True)
30 | response.raise_for_status()
31 |
32 | # Read the gzipped content into memory
33 | content = response.content
34 |
35 | # Decompress and save
36 | print("Decompressing model file...")
37 | with gzip.open(io.BytesIO(content), "rb") as f_in:
38 | with open(MODEL_PATH, "wb") as f_out:
39 | f_out.write(f_in.read())
40 | print("Model file ready.")
41 |
42 |
43 | @pytest.fixture
44 | def model():
45 | return md.vl(model=MODEL_PATH)
46 |
47 |
48 | @pytest.fixture
49 | def test_image():
50 | return Image.open(TEST_IMAGE_PATH)
51 |
52 |
53 | def test_model_initialization(model):
54 | assert model is not None
55 | assert hasattr(model, "vision_encoder")
56 | assert hasattr(model, "text_decoder")
57 | assert hasattr(model, "tokenizer")
58 |
59 |
60 | def test_image_encoding(model, test_image):
61 | encoded_image = model.encode_image(test_image)
62 | assert encoded_image is not None
63 | assert hasattr(encoded_image, "pos")
64 | assert hasattr(encoded_image, "kv_cache")
65 |
66 |
67 | def test_image_captioning(model, test_image):
68 | # Test normal length caption
69 | result = model.caption(test_image, length="normal")
70 | assert "caption" in result
71 | assert isinstance(result["caption"], str)
72 | assert len(result["caption"]) > 0
73 |
74 | # Test short length caption
75 | result = model.caption(test_image, length="short")
76 | assert "caption" in result
77 | assert isinstance(result["caption"], str)
78 | assert len(result["caption"]) > 0
79 |
80 |
81 | def test_streaming_caption(model, test_image):
82 | result = model.caption(test_image, stream=True)
83 | assert "caption" in result
84 |
85 | # Test that we can iterate over the stream
86 | caption = ""
87 | for chunk in result["caption"]:
88 | assert isinstance(chunk, str)
89 | caption += chunk
90 |
91 | assert len(caption) > 0
92 |
93 |
94 | def test_reuse_encoded_image(model, test_image):
95 | # Test that we can reuse an encoded image for multiple operations
96 | encoded_image = model.encode_image(test_image)
97 |
98 | # Generate two captions using the same encoded image
99 | result1 = model.caption(encoded_image)
100 | result2 = model.caption(encoded_image)
101 |
102 | assert result1["caption"] == result2["caption"]
103 |
104 |
105 | def test_invalid_caption_length(model, test_image):
106 | with pytest.raises(ValueError, match="Model does not support caption length"):
107 | model.caption(test_image, length="invalid")
108 |
109 |
110 | def test_invalid_model_path():
111 | with pytest.raises(
112 | ValueError,
113 | match="Unsupported model filetype. Please use a .safetensors for GPU use or .mf for CPU use.",
114 | ):
115 | md.vl(model="invalid/path/to/model.bin")
116 |
--------------------------------------------------------------------------------
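The download_model fixture above reads the whole gzip archive into memory (response.content) before decompressing, even though the request is opened with stream=True. A hedged sketch of a lower-memory variant that decompresses directly from the response stream; the function name is illustrative and not part of the test suite.

import gzip
import shutil

import requests


def download_model_streaming(url: str, dest_path: str) -> None:
    # Decompress the .mf.gz archive chunk by chunk instead of buffering it all.
    with requests.get(url, stream=True) as response:
        response.raise_for_status()
        with gzip.open(response.raw, "rb") as f_in, open(dest_path, "wb") as f_out:
            shutil.copyfileobj(f_in, f_out)
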
/clients/python/tests/test_local_torch_inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pytest
3 | from PIL import Image
4 | import moondream as md
5 |
6 | MODEL_PATH = "/Users/caleb/Projects/moondream/moondream-playground-inf/src/ai-models/05/moondream-01-08-2025.safetensors"
7 | TEST_IMAGE_PATH = os.path.join(
8 | os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))),
9 | "assets",
10 | "demo-1.jpg",
11 | )
12 |
13 |
14 | @pytest.fixture
15 | def model():
16 | return md.vl(model=MODEL_PATH)
17 |
18 |
19 | @pytest.fixture
20 | def test_image():
21 | return Image.open(TEST_IMAGE_PATH)
22 |
23 |
24 | def test_image_captioning(model, test_image):
25 | # Test normal length caption
26 | result = model.caption(test_image, length="normal")
27 | assert "caption" in result
28 | assert isinstance(result["caption"], str)
29 | assert len(result["caption"]) > 0
30 |
31 | # Test short length caption
32 | result = model.caption(test_image, length="short")
33 | assert "caption" in result
34 | assert isinstance(result["caption"], str)
35 | assert len(result["caption"]) > 0
36 |
37 | # Test streaming caption
38 | result = model.caption(test_image, stream=True)
39 | assert "caption" in result
40 |
41 | # Test that we can iterate over the stream
42 | num_chunks = 0
43 | caption = ""
44 | for chunk in result["caption"]:
45 | assert isinstance(chunk, str)
46 | caption += chunk
47 | num_chunks += 1
48 |
49 | assert len(caption) > 0
50 | assert num_chunks > 1
51 |
52 |
53 | def test_query(model, test_image):
54 | result = model.query(test_image, "What is in this image?")
55 | assert "answer" in result
56 | assert isinstance(result["answer"], str)
57 | assert len(result["answer"]) > 0
58 |
59 | # Test streaming query
60 | result = model.query(test_image, "What is in this image?", stream=True)
61 | assert "answer" in result
62 |
63 | # Test that we can iterate over the stream
64 | num_chunks = 0
65 | answer = ""
66 | for chunk in result["answer"]:
67 | assert isinstance(chunk, str)
68 | answer += chunk
69 | num_chunks += 1
70 |
71 | assert num_chunks > 1
72 | assert len(answer) > 0
73 |
74 |
75 | def test_detect(model, test_image):
76 | result = model.detect(test_image, "person")
77 | assert "objects" in result
78 | assert isinstance(result["objects"], list)
79 |
80 |
81 | def test_point(model, test_image):
82 | result = model.point(test_image, "face")
83 | assert "points" in result
84 | assert isinstance(result["points"], list)
85 |
--------------------------------------------------------------------------------
/gradio_demo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import re
3 | from threading import Thread
4 |
5 | import gradio as gr
6 | import torch
7 | from PIL import ImageDraw
8 | from torchvision.transforms.v2 import Resize
9 | from transformers import AutoTokenizer, TextIteratorStreamer
10 |
11 | from moondream.hf import LATEST_REVISION, Moondream, detect_device
12 |
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument("--cpu", action="store_true")
15 | args = parser.parse_args()
16 |
17 | if args.cpu:
18 | device = torch.device("cpu")
19 | dtype = torch.float32
20 | else:
21 | device, dtype = detect_device()
22 | if device != torch.device("cpu"):
23 | print("Using device:", device)
24 | print("If you run into issues, pass the `--cpu` flag to this script.")
25 | print()
26 |
27 | model_id = "vikhyatk/moondream2"
28 | tokenizer = AutoTokenizer.from_pretrained(model_id, revision=LATEST_REVISION)
29 | moondream = Moondream.from_pretrained(
30 | model_id, revision=LATEST_REVISION, torch_dtype=dtype
31 | ).to(device=device)
32 | moondream.eval()
33 |
34 |
35 | def answer_question(img, prompt):
36 | image_embeds = moondream.encode_image(img)
37 | streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
38 | thread = Thread(
39 | target=moondream.answer_question,
40 | kwargs={
41 | "image_embeds": image_embeds,
42 | "question": prompt,
43 | "tokenizer": tokenizer,
44 | "streamer": streamer,
45 | },
46 | )
47 | thread.start()
48 |
49 | buffer = ""
50 | for new_text in streamer:
51 | buffer += new_text
52 | yield buffer
53 |
54 |
55 | def extract_floats(text):
56 | # Regular expression to match an array of four floating point numbers
57 | pattern = r"\[\s*(-?\d+\.\d+)\s*,\s*(-?\d+\.\d+)\s*,\s*(-?\d+\.\d+)\s*,\s*(-?\d+\.\d+)\s*\]"
58 | match = re.search(pattern, text)
59 | if match:
60 | # Extract the numbers and convert them to floats
61 | return [float(num) for num in match.groups()]
62 | return None # Return None if no match is found
63 |
64 |
65 | def extract_bbox(text):
66 | bbox = None
67 | if extract_floats(text) is not None:
68 | x1, y1, x2, y2 = extract_floats(text)
69 | bbox = (x1, y1, x2, y2)
70 | return bbox
71 |
72 |
73 | def process_answer(img, answer):
74 | if extract_bbox(answer) is not None:
75 | x1, y1, x2, y2 = extract_bbox(answer)
76 | draw_image = Resize(768)(img)
77 | width, height = draw_image.size
78 | x1, x2 = int(x1 * width), int(x2 * width)
79 | y1, y2 = int(y1 * height), int(y2 * height)
80 | bbox = (x1, y1, x2, y2)
81 | ImageDraw.Draw(draw_image).rectangle(bbox, outline="red", width=3)
82 | return gr.update(visible=True, value=draw_image)
83 |
84 | return gr.update(visible=False, value=None)
85 |
86 |
87 | with gr.Blocks() as demo:
88 | gr.Markdown(
89 | """
90 | # 🌔 moondream
91 | """
92 | )
93 | with gr.Row():
94 | prompt = gr.Textbox(label="Input Prompt", value="Describe this image.", scale=4)
95 | submit = gr.Button("Submit")
96 | with gr.Row():
97 | img = gr.Image(type="pil", label="Upload an Image")
98 | with gr.Column():
99 | output = gr.Markdown(label="Response")
100 | ann = gr.Image(visible=False, label="Annotated Image")
101 |
102 | submit.click(answer_question, [img, prompt], output)
103 | prompt.submit(answer_question, [img, prompt], output)
104 | output.change(process_answer, [img, output], ann, show_progress=False)
105 |
106 | demo.queue().launch(debug=True)
107 |
--------------------------------------------------------------------------------
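For reference, the bounding-box regex used by extract_floats in gradio_demo.py above matches exactly four decimal numbers inside square brackets. A quick standalone check (the sample string is illustrative):

import re

pattern = r"\[\s*(-?\d+\.\d+)\s*,\s*(-?\d+\.\d+)\s*,\s*(-?\d+\.\d+)\s*,\s*(-?\d+\.\d+)\s*\]"

sample = "The dog is at [0.12, 0.08, 0.55, 0.92]."
match = re.search(pattern, sample)
print([float(num) for num in match.groups()])   # [0.12, 0.08, 0.55, 0.92]
print(re.search(pattern, "no box here"))        # None
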
/moondream/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vikhyat/moondream/ce59395b386eb956e19bc1cba5f97deb9befcd05/moondream/__init__.py
--------------------------------------------------------------------------------
/moondream/config/config_md05.json:
--------------------------------------------------------------------------------
1 | {
2 | "text": {
3 | "dim": 1024,
4 | "ff_dim": 4096,
5 | "n_layers": 24,
6 | "vocab_size": 51200,
7 | "max_context": 2048,
8 | "n_heads": 16,
9 | "prefix_attn": 730
10 | },
11 | "vision": {
12 | "enc_dim": 720,
13 | "enc_patch_size": 14,
14 | "enc_n_layers": 27,
15 | "enc_ff_dim": 2690,
16 | "enc_n_heads": 10,
17 | "proj_out_dim": 1024,
18 | "crop_size": 378,
19 | "in_channels": 3,
20 | "max_crops": 12,
21 | "overlap_margin": 4,
22 | "proj_inner_dim": 8192
23 | },
24 | "region": {
25 | "dim": 1024,
26 | "coord_feat_dim": 256,
27 | "coord_out_dim": 1024,
28 | "size_feat_dim": 512,
29 | "size_out_dim": 2048,
30 | "inner_dim": 8192
31 | },
32 | "tokenizer": {
33 | "bos_id": 50256,
34 | "eos_id": 50256,
35 | "templates": {
36 | "caption": {
37 | "short": [
38 | 198,
39 | 198,
40 | 16438,
41 | 8305,
42 | 25
43 | ],
44 | "normal": [
45 | 198,
46 | 198,
47 | 24334,
48 | 1159,
49 | 25
50 | ]
51 | },
52 | "query": {
53 | "prefix": [
54 | 198,
55 | 198,
56 | 24361,
57 | 25
58 | ],
59 | "suffix": [
60 | 198,
61 | 198,
62 | 33706,
63 | 25
64 | ]
65 | },
66 | "detect": {
67 | "prefix": [
68 | 198,
69 | 198,
70 | 47504,
71 | 25
72 | ],
73 | "suffix": [
74 | 628
75 | ]
76 | },
77 | "point": {
78 | "prefix": [
79 | 198,
80 | 198,
81 | 12727,
82 | 25
83 | ],
84 | "suffix": [
85 | 628
86 | ]
87 | }
88 | }
89 | }
90 | }
--------------------------------------------------------------------------------
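A back-of-the-envelope check of the vision settings in config_md05.json above: with crop_size 378 and enc_patch_size 14, each crop is tiled into a 27 × 27 grid of patches, i.e. 729 vision tokens per crop. It seems plausible (though the config does not state it) that prefix_attn = 730 covers those 729 tokens plus one BOS token.

crop_size = 378                              # vision.crop_size
patch_size = 14                              # vision.enc_patch_size
patches_per_side = crop_size // patch_size   # 27
tokens_per_crop = patches_per_side ** 2      # 729
print(patches_per_side, tokens_per_crop)
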
/moondream/config/config_md2.json:
--------------------------------------------------------------------------------
1 | {
2 | "text": {
3 | "dim": 2048,
4 | "ff_dim": 8192,
5 | "n_layers": 24,
6 | "vocab_size": 51200,
7 | "max_context": 2048,
8 | "n_heads": 32,
9 | "prefix_attn": 730
10 | },
11 | "vision": {
12 | "enc_dim": 1152,
13 | "enc_patch_size": 14,
14 | "enc_n_layers": 27,
15 | "enc_ff_dim": 4304,
16 | "enc_n_heads": 16,
17 | "proj_out_dim": 2048,
18 | "crop_size": 378,
19 | "in_channels": 3,
20 | "max_crops": 12,
21 | "overlap_margin": 4,
22 | "proj_inner_dim": 8192
23 | },
24 | "region": {
25 | "dim": 2048,
26 | "coord_feat_dim": 256,
27 | "coord_out_dim": 1024,
28 | "size_feat_dim": 512,
29 | "size_out_dim": 2048,
30 | "inner_dim": 8192
31 | },
32 | "tokenizer": {
33 | "bos_id": 50256,
34 | "eos_id": 50256,
35 | "templates": {
36 | "caption": {
37 | "short": [
38 | 198,
39 | 198,
40 | 16438,
41 | 8305,
42 | 25
43 | ],
44 | "normal": [
45 | 198,
46 | 198,
47 | 24334,
48 | 1159,
49 | 25
50 | ]
51 | },
52 | "query": {
53 | "prefix": [
54 | 198,
55 | 198,
56 | 24361,
57 | 25
58 | ],
59 | "suffix": [
60 | 198,
61 | 198,
62 | 33706,
63 | 25
64 | ]
65 | },
66 | "detect": {
67 | "prefix": [
68 | 198,
69 | 198,
70 | 47504,
71 | 25
72 | ],
73 | "suffix": [
74 | 628
75 | ]
76 | },
77 | "point": {
78 | "prefix": [
79 | 198,
80 | 198,
81 | 12727,
82 | 25
83 | ],
84 | "suffix": [
85 | 628
86 | ]
87 | }
88 | }
89 | }
90 | }
--------------------------------------------------------------------------------
/moondream/eval/chartqa.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datasets
3 | import torch
4 |
5 | from tqdm import tqdm
6 | import json
7 |
8 | from ..torch.config import MoondreamConfig
9 | from ..torch.moondream import MoondreamModel
10 | from ..torch.weights import load_weights_into_model
11 |
12 | PREFIX = "Analyze the chart carefully, consider both visual features and data values, and provide a precise answer without any additional explanation or formatting. "
13 |
14 |
15 | # https://github.com/google-research/pix2struct/blob/main/pix2struct/metrics.py#L81
16 | def relaxed_correctness(
17 | target: str, prediction: str, max_relative_change: float = 0.05
18 | ) -> bool:
19 | """Calculates relaxed correctness.
20 |
21 |     The correctness check tolerates a certain error ratio, defined by max_relative_change.
22 | See https://arxiv.org/pdf/2203.10244.pdf, end of section 5.1:
23 | “Following Methani et al. (2020), we use a relaxed accuracy measure for the
24 | numeric answers to allow a minor inaccuracy that may result from the automatic
25 | data extraction process. We consider an answer to be correct if it is within
26 | 5% of the gold answer. For non-numeric answers, we still need an exact match
27 | to consider an answer to be correct.”
28 |
29 | Args:
30 | target: Target string.
31 | prediction: Predicted string.
32 | max_relative_change: Maximum relative change.
33 |
34 | Returns:
35 | Whether the prediction was correct given the specified tolerance.
36 | """
37 |
38 | def _to_float(text):
39 | try:
40 | if text.endswith("%"):
41 | # Convert percentages to floats.
42 | return float(text.rstrip("%")) / 100.0
43 | else:
44 | return float(text)
45 | except ValueError:
46 | return None
47 |
48 | prediction = str(prediction)
49 | target = str(target)
50 | prediction_float = _to_float(prediction)
51 | target_float = _to_float(target)
52 | if prediction_float is not None and target_float:
53 | relative_change = abs(prediction_float - target_float) / abs(target_float)
54 | return relative_change <= max_relative_change
55 | else:
56 | return prediction == target
57 |
58 |
59 | def eval_chartqa(model, debug=False):
60 | dataset = datasets.load_dataset("vikhyatk/chartqa", split="test")
61 |
62 | correct = 0
63 | total = 0
64 | human_correct = 0
65 | human_total = 0
66 | results = []
67 |
68 | for row in tqdm(dataset, disable=debug, desc="ChartQA"):
69 | image = row["image"]
70 | encoded_image = model.encode_image(image)
71 |
72 | result = []
73 | for qa in row["qa"]:
74 | question = PREFIX + qa["question"]
75 | answer = qa["answer"]
76 | model_answer = model.query(encoded_image, question)["answer"]
77 |
78 |             # Attempt to parse both answers as JSON lists of equal length.
79 | try:
80 | answer_list = json.loads(answer)
81 | model_answer_list = json.loads(model_answer)
82 | if not (
83 | isinstance(answer_list, list)
84 | and isinstance(model_answer_list, list)
85 | and len(answer_list) == len(model_answer_list)
86 | ):
87 | raise ValueError
88 |             except (ValueError, TypeError):
89 | # If parsing fails or lengths are not equal, compare the strings directly instead
90 | answer_list = [answer]
91 | model_answer_list = [model_answer]
92 |
93 | total += 1
94 | if qa["source"] == "human":
95 | human_total += 1
96 |
97 | is_correct = False
98 | if all(
99 | relaxed_correctness(
100 | str(cur_answer).strip().lower(),
101 | str(cur_model_answer).strip().lower(),
102 | )
103 | for cur_answer, cur_model_answer in zip(answer_list, model_answer_list)
104 | ):
105 | correct += 1
106 | if qa["source"] == "human":
107 | human_correct += 1
108 | is_correct = True
109 | if debug:
110 | print(
111 | f"Correct: {correct}, Total: {total}, Human Correct: {human_correct}, Human Total: {human_total}"
112 | )
113 | print(f"Human Accuracy: {human_correct * 100 / human_total:.2f}")
114 | print(f"Total Accuracy: {correct * 100 / total:.2f}")
115 | print("---------")
116 | result.append(
117 | {
118 | "question": question,
119 | "ground_truth": answer_list,
120 | "model_answer": model_answer_list,
121 | "is_correct": is_correct,
122 | "source": qa["source"],
123 | }
124 | )
125 | results.append(result)
126 |
127 | return {
128 | "human_acc": human_correct * 100 / human_total,
129 | "total_acc": correct * 100 / total,
130 | "results": results,
131 | }
132 |
133 |
134 | if __name__ == "__main__":
135 | parser = argparse.ArgumentParser()
136 | parser.add_argument("--model", type=str, required=True)
137 | parser.add_argument("--debug", action="store_true")
138 | args = parser.parse_args()
139 |
140 | if torch.cuda.is_available():
141 | torch.set_default_device("cuda")
142 | elif torch.backends.mps.is_available():
143 | torch.set_default_device("mps")
144 |
145 | config = MoondreamConfig()
146 | model = MoondreamModel(config)
147 | load_weights_into_model(args.model, model)
148 | model.compile()
149 |
150 | results = eval_chartqa(model, args.debug)
151 | print(f"Human Accuracy: {results['human_acc']:.2f}")
152 | print(f"Total Accuracy: {results['total_acc']:.2f}")
153 |
--------------------------------------------------------------------------------
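A worked example of the 5% tolerance applied by relaxed_correctness in chartqa.py above (values are illustrative):

target, prediction = 40.0, 41.0
relative_change = abs(prediction - target) / abs(target)
print(relative_change <= 0.05)   # 0.025 -> True, counted as correct

target, prediction = 40.0, 43.0
relative_change = abs(prediction - target) / abs(target)
print(relative_change <= 0.05)   # 0.075 -> False, counted as incorrect

Non-numeric answers fall back to exact string equality, after the caller lowercases and strips both sides.
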
/moondream/eval/countbenchqa.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datasets
3 | import torch
4 |
5 | from tqdm import tqdm
6 |
7 | from ..torch.config import MoondreamConfig
8 | from ..torch.moondream import MoondreamModel
9 | from ..torch.weights import load_weights_into_model
10 |
11 | PREFIX = "Look at the image carefully and count the objects. Answer with just a number, without any additional text. "
12 |
13 |
14 | def eval_countbenchqa(model, debug=False):
15 | dataset = datasets.load_dataset("vikhyatk/CountBenchQA", split="test")
16 |
17 | correct = 0
18 | total = 0
19 | results = []
20 |
21 | for row in tqdm(dataset, disable=debug, desc="CountBenchQA"):
22 | image = row["image"]
23 | encoded_image = model.encode_image(image)
24 |
25 | question = PREFIX + row["question"]
26 | answer = str(row["number"])
27 | model_answer = model.query(encoded_image, question)["answer"]
28 | is_correct = model_answer.strip().lower() == answer.strip().lower()
29 |
30 | results.append(
31 | {
32 | "question": question,
33 | "ground_truth": answer,
34 | "model_answer": model_answer,
35 | "is_correct": is_correct,
36 | }
37 | )
38 |
39 | total += 1
40 | if is_correct:
41 | correct += 1
42 | elif debug:
43 | print(f"Question: {row['question']}")
44 | print(f"Answer: {answer}")
45 | print(f"Model Answer: {model_answer}")
46 | if debug:
47 | print(f"Correct: {correct}, Total: {total}")
48 | print(f"Accuracy: {correct * 100 / total:.2f}")
49 | print("---------")
50 |
51 | return {
52 | "acc": correct * 100 / total,
53 | "correct_count": correct,
54 | "total_count": total,
55 | "results": results,
56 | }
57 |
58 |
59 | if __name__ == "__main__":
60 | parser = argparse.ArgumentParser()
61 | parser.add_argument("--model", type=str, required=True)
62 | parser.add_argument("--debug", action="store_true")
63 | args = parser.parse_args()
64 |
65 | if torch.cuda.is_available():
66 | torch.set_default_device("cuda")
67 | elif torch.backends.mps.is_available():
68 | torch.set_default_device("mps")
69 |
70 | config = MoondreamConfig()
71 | model = MoondreamModel(config)
72 | load_weights_into_model(args.model, model)
73 |
74 | result = eval_countbenchqa(model, args.debug)
75 |
76 | print(f"Accuracy: {result['acc']:.2f}")
77 | print(f"Correct: {result['correct_count']}, Total: {result['total_count']}")
78 |
--------------------------------------------------------------------------------
/moondream/eval/docvqa.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import editdistance
3 | from datasets import load_dataset
4 | from tqdm import tqdm
5 | import torch
6 |
7 | from ..torch.config import MoondreamConfig
8 | from ..torch.moondream import MoondreamModel
9 | from ..torch.weights import load_weights_into_model
10 |
11 | SUFFIX = " The answer should be a short text span taken verbatim from the document."
12 |
13 |
14 | def get_anls(s1, s2):
15 | s1 = s1.lower().strip()
16 | s2 = s2.lower().strip()
17 | iou = 1 - editdistance.eval(s1, s2) / max(len(s1), len(s2))
18 | anls = iou if iou >= 0.5 else 0.0
19 | return anls
20 |
21 |
22 | def eval_docvqa(model, debug=False):
23 | docvqa_val = load_dataset("vikhyatk/docvqa-val", split="validation")
24 |
25 | scores = []
26 | results = []
27 |
28 | for row in tqdm(docvqa_val, disable=debug, desc="DocVQA"):
29 | image = row["image"]
30 | encoded_image = model.encode_image(image)
31 |
32 | result = []
33 | for qa in row["qa"]:
34 | question = qa["question"]
35 | answers = qa["answers"]
36 | prompt = question + SUFFIX
37 |
38 | model_answer = model.query(encoded_image, prompt)["answer"]
39 | anls = max(get_anls(model_answer, gt) for gt in answers)
40 | scores.append(anls)
41 | result.append(
42 | {
43 | "question": question,
44 | "ground_truth": answers,
45 | "model_answer": model_answer,
46 | "anls": anls,
47 | }
48 | )
49 |
50 | if debug:
51 | print(f"Question: {question}")
52 | print(f"Ground Truth: {answers}")
53 | print(f"Model Answer: {model_answer}")
54 | print(f"ANLS: {anls}")
55 | print(f"Current Average ANLS: {sum(scores) / len(scores):.4f}")
56 | print("---------")
57 | results.append(result)
58 |
59 | return {
60 | "anls": sum(scores) / len(scores),
61 | "results": results,
62 | }
63 |
64 |
65 | if __name__ == "__main__":
66 | parser = argparse.ArgumentParser()
67 | parser.add_argument("--model", type=str, required=True)
68 | parser.add_argument("--debug", action="store_true")
69 | args = parser.parse_args()
70 |
71 | if torch.cuda.is_available():
72 | torch.set_default_device("cuda")
73 | elif torch.backends.mps.is_available():
74 | torch.set_default_device("mps")
75 |
76 | config = MoondreamConfig()
77 | model = MoondreamModel(config)
78 | load_weights_into_model(args.model, model)
79 | model.compile()
80 |
81 | result = eval_docvqa(model, args.debug)
82 |
83 | print(f"ANLS: {result['anls']:.4f}")
84 |
--------------------------------------------------------------------------------
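A worked example of the ANLS score computed by get_anls in docvqa.py above: it is one minus the normalized edit distance, floored to 0 when below 0.5 (the strings are illustrative):

import editdistance

s1, s2 = "receipt", "receipts"
iou = 1 - editdistance.eval(s1, s2) / max(len(s1), len(s2))   # 1 - 1/8 = 0.875
print(iou if iou >= 0.5 else 0.0)   # 0.875

s1, s2 = "receipt", "invoice"
iou = 1 - editdistance.eval(s1, s2) / max(len(s1), len(s2))
print(iou if iou >= 0.5 else 0.0)   # below the 0.5 threshold -> 0.0
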
/moondream/eval/eval_all.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 |
4 | from pprint import pprint
5 |
6 | from ..torch.config import MoondreamConfig
7 | from ..torch.moondream import MoondreamModel
8 | from ..torch.weights import load_weights_into_model
9 |
10 | from .countbenchqa import eval_countbenchqa
11 | from .pope import evaluate_pope
12 | from .realworldqa import eval_realworldqa
13 | from .chartqa import eval_chartqa
14 | from .textvqa import eval_textvqa
15 | from .docvqa import eval_docvqa
16 | from .mmstar import eval_mmstar
17 | from .coco_map import eval_coco_map
18 | from .naturalbench import eval_naturalbench
19 | from .tallyqa import eval_tallyqa
20 |
21 |
22 | def create_model(ckpt_path):
23 | config = MoondreamConfig()
24 | model = MoondreamModel(config)
25 | load_weights_into_model(ckpt_path, model)
26 | model.compile()
27 | return model
28 |
29 |
30 | def eval_all(model, skip=[]):
31 | evals = {
32 | "countbenchqa": eval_countbenchqa,
33 | "pope": evaluate_pope,
34 | "realworldqa": eval_realworldqa,
35 | "chartqa": eval_chartqa,
36 | "mmstar": eval_mmstar,
37 | "docvqa": eval_docvqa,
38 | "coco_map": eval_coco_map,
39 | "textvqa": eval_textvqa,
40 | "naturalbench": eval_naturalbench,
41 | "tallyqa": eval_tallyqa,
42 | }
43 |
44 | for b in skip:
45 | del evals[b]
46 |
47 | results = {}
48 | for name, eval_fn in evals.items():
49 | results[name] = eval_fn(model)
50 | pprint({k: v for k, v in results[name].items() if k != "results"})
51 |
52 | return results
53 |
54 |
55 | if __name__ == "__main__":
56 | parser = argparse.ArgumentParser()
57 | parser.add_argument("--model", type=str, required=True)
58 | args = parser.parse_args()
59 |
60 | if torch.cuda.is_available():
61 | torch.set_default_device("cuda")
62 | elif torch.backends.mps.is_available():
63 | torch.set_default_device("mps")
64 |
65 | model = create_model(args.model)
66 | eval_all(model)
67 |
--------------------------------------------------------------------------------
/moondream/eval/gazefollow.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import datasets
3 | import math
4 |
5 | from tqdm import tqdm
6 |
7 | from ..torch.config import MoondreamConfig
8 | from ..torch.moondream import MoondreamModel
9 | from ..torch.weights import load_weights_into_model
10 |
11 |
12 | def eval_gazefollow(model, debug=False):
13 | dataset = datasets.load_dataset("vikhyatk/gazefollow", split="test")
14 |
15 | mean_l2_error = []
16 | min_l2_error = []
17 | total = 0
18 |
19 | for i, row in tqdm(enumerate(dataset), total=len(dataset)):
20 | heads = []
21 |
22 | for gaze in row["gazes"]:
23 | head_bbox = gaze["head_bbox"] # xmin, ymin, xmax, ymax
24 | eye_coord = (gaze["eye"]["x"], gaze["eye"]["y"])
25 | mean_target_gaze = (gaze["gaze"]["x"], gaze["gaze"]["y"])
26 |
27 | # Check if a head already exists with the same approximate bbox.
28 | # If so, use that head instead of creating a new one.
29 | for head in heads:
30 | if (
31 | abs(head["head_bbox"]["xmin"] - head_bbox["xmin"]) < 0.001
32 | and abs(head["head_bbox"]["xmax"] - head_bbox["xmax"]) < 0.001
33 | and abs(head["head_bbox"]["ymin"] - head_bbox["ymin"]) < 0.001
34 | and abs(head["head_bbox"]["ymax"] - head_bbox["ymax"]) < 0.001
35 | ):
36 | head["gazes"].append(mean_target_gaze)
37 | break
38 | else:
39 | heads.append(
40 | {
41 | "head_bbox": head_bbox,
42 | "eye_coord": eye_coord,
43 | "gazes": [mean_target_gaze],
44 | }
45 | )
46 |
47 | for head in heads:
48 | pred_gaze = model.detect_gaze(
49 | row["image"],
50 | eye=head["eye_coord"],
51 | face={
52 | "x_min": head["head_bbox"]["xmin"],
53 | "y_min": head["head_bbox"]["ymin"],
54 | "x_max": head["head_bbox"]["xmax"],
55 | "y_max": head["head_bbox"]["ymax"],
56 | },
57 | unstable_settings={"force_detect": True},
58 | )["gaze"]
59 |
60 | mean_target_gaze = (
61 | sum(gaze[0] for gaze in head["gazes"]) / len(head["gazes"]),
62 | sum(gaze[1] for gaze in head["gazes"]) / len(head["gazes"]),
63 | )
64 | mean_l2 = math.sqrt(
65 | (mean_target_gaze[0] - pred_gaze["x"]) ** 2
66 | + (mean_target_gaze[1] - pred_gaze["y"]) ** 2
67 | )
68 | min_l2 = min(
69 | math.sqrt(
70 | (target_gaze[0] - pred_gaze["x"]) ** 2
71 | + (target_gaze[1] - pred_gaze["y"]) ** 2
72 | )
73 | for target_gaze in head["gazes"]
74 | )
75 |
76 | mean_l2_error.append(mean_l2)
77 | min_l2_error.append(min_l2)
78 | total += 1
79 |
80 | if i % 100 == 0 and debug:
81 | print("Mean L2 error:", sum(mean_l2_error) / total)
82 | print("Min L2 error:", sum(min_l2_error) / total)
83 |
84 | return {
85 | "mean_l2": sum(mean_l2_error) / total,
86 | "min_l2": sum(min_l2_error) / total,
87 | }
88 |
89 |
90 | if __name__ == "__main__":
91 | import argparse
92 |
93 | parser = argparse.ArgumentParser()
94 | parser.add_argument("--model", type=str, required=True)
95 |
96 | parser.add_argument("--debug", action="store_true")
97 | args = parser.parse_args()
98 |
99 | if torch.cuda.is_available():
100 | torch.set_default_device("cuda")
101 | elif torch.backends.mps.is_available():
102 | torch.set_default_device("mps")
103 |
104 | config = MoondreamConfig()
105 | model = MoondreamModel(config)
106 | load_weights_into_model(args.model, model)
107 |
108 | results = eval_gazefollow(model, debug=args.debug)
109 |
110 | print(f"Mean L2 error: {results['mean_l2']:.4f}")
111 | print(f"Min L2 error: {results['min_l2']:.4f}")
112 |
--------------------------------------------------------------------------------
/moondream/eval/mmstar.py:
--------------------------------------------------------------------------------
1 | import datasets
2 | import torch
3 |
4 | from tqdm import tqdm
5 |
6 | from ..torch.config import MoondreamConfig
7 | from ..torch.moondream import MoondreamModel
8 | from ..torch.weights import load_weights_into_model
9 |
10 | SUFFIX = " Please answer directly with only the letter of the correct option and nothing else."
11 |
12 |
13 | def eval_mmstar(model, debug=False):
14 | dataset = datasets.load_dataset("Lin-Chen/MMStar", split="val")
15 |
16 | correct = 0
17 | total = 0
18 | category_stats = {}
19 | results = []
20 |
21 | for row in tqdm(dataset, disable=debug, desc="MMStar"):
22 | image = row["image"]
23 | question = row["question"] + SUFFIX
24 | answer = row["answer"]
25 | model_answer = model.query(image, question)["answer"]
26 | is_correct = model_answer.strip().lower() == answer.strip().lower()
27 |
28 | category = f"{row['category']} / {row['l2_category']}"
29 | if category not in category_stats:
30 | category_stats[category] = {"correct": 0, "total": 0}
31 |
32 | total += 1
33 | category_stats[category]["total"] += 1
34 |
35 | results.append(
36 | {
37 | "question": question,
38 | "ground_truth": answer,
39 | "model_answer": model_answer,
40 | "is_correct": is_correct,
41 | "category": category,
42 | }
43 | )
44 |
45 | if is_correct:
46 | correct += 1
47 | category_stats[category]["correct"] += 1
48 | elif debug:
49 | print(f"Index: {row['index']}")
50 | print(f"Question: {row['question']}")
51 | print(f"Answer: {answer}")
52 | print(f"Model Answer: {model_answer}")
53 | if debug:
54 | print(f"Correct: {correct}, Total: {total}")
55 | print(f"Accuracy: {correct * 100 / total:.2f}")
56 | print("Results by category:")
57 | for category, stats in category_stats.items():
58 | acc = stats["correct"] * 100 / stats["total"]
59 | print(f"{category}: {stats['correct']}/{stats['total']} = {acc:.2f}%")
60 | print("---------")
61 |
62 | return {
63 | "acc": correct * 100 / total,
64 | "correct_count": correct,
65 | "total_count": total,
66 | "category_stats": category_stats,
67 | "results": results,
68 | }
69 |
70 |
71 | if __name__ == "__main__":
72 | import argparse
73 |
74 | parser = argparse.ArgumentParser()
75 | parser.add_argument("--model", type=str, required=True)
76 | parser.add_argument("--debug", action="store_true")
77 | args = parser.parse_args()
78 |
79 | if torch.cuda.is_available():
80 | torch.set_default_device("cuda")
81 | elif torch.backends.mps.is_available():
82 | torch.set_default_device("mps")
83 |
84 | config = MoondreamConfig()
85 | model = MoondreamModel(config)
86 | load_weights_into_model(args.model, model)
87 | model.compile()
88 |
89 | result = eval_mmstar(model, args.debug)
90 |
91 | print(f"Correct: {result['correct_count']}, Total: {result['total_count']}")
92 | print(f"Accuracy: {result['acc']:.2f}")
93 |
94 | print("\nResults by category:")
95 | for category, stats in result["category_stats"].items():
96 | acc = stats["correct"] * 100 / stats["total"]
97 | print(f"{category}: {stats['correct']}/{stats['total']} = {acc:.2f}%")
98 |
--------------------------------------------------------------------------------
/moondream/eval/naturalbench.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 | from tqdm import tqdm
3 | import torch
4 |
5 | from ..torch.config import MoondreamConfig
6 | from ..torch.moondream import MoondreamModel
7 | from ..torch.weights import load_weights_into_model
8 |
9 |
10 | def eval_naturalbench(model, debug=False):
11 | # Yes, the benchmark test set is stored in the 'train' split...
12 | dataset = load_dataset("BaiqiL/NaturalBench", split="train")
13 |
14 | acc = []
15 | q_acc = []
16 | i_acc = []
17 | g_acc = []
18 |
19 | for row in tqdm(dataset, disable=debug, desc="NaturalBench"):
20 | if row["Question_Type"] == "yes_no":
21 | suffix = " Answer yes or no."
22 | else:
23 | suffix = ""
24 |
25 | images = [row["Image_0"], row["Image_1"], row["Image_0"], row["Image_1"]]
26 | prompts = [
27 | row["Question_0"] + suffix,
28 | row["Question_0"] + suffix,
29 | row["Question_1"] + suffix,
30 | row["Question_1"] + suffix,
31 | ]
32 | expected = [
33 | row["Image_0_Question_0"].strip().lower(),
34 | row["Image_1_Question_0"].strip().lower(),
35 | row["Image_0_Question_1"].strip().lower(),
36 |             row["Image_1_Question_1"].strip().lower(),
37 | ]
38 |
39 | answers = []
40 | for img, prompt in zip(images, prompts):
41 | encoded_image = model.encode_image(img)
42 | answer = model.query(encoded_image, prompt)["answer"]
43 | answers.append(answer.strip().lower())
44 |
45 | if debug:
46 | for i, (q, a, e) in enumerate(zip(prompts, answers, expected)):
47 | print(f"Q{i}: {q}")
48 | print(f"Model: {a}")
49 | print(f"Expected: {e}")
50 | print(f"Correct: {a == e}")
51 | print("---")
52 |
53 | acc.append(answers[0] == expected[0])
54 | acc.append(answers[1] == expected[1])
55 | acc.append(answers[2] == expected[2])
56 | acc.append(answers[3] == expected[3])
57 |
58 | i_acc.append(answers[0] == expected[0] and answers[2] == expected[2])
59 | i_acc.append(answers[1] == expected[1] and answers[3] == expected[3])
60 |
61 | q_acc.append(answers[0] == expected[0] and answers[1] == expected[1])
62 | q_acc.append(answers[2] == expected[2] and answers[3] == expected[3])
63 |
64 | g_acc.append(
65 | answers[0] == expected[0]
66 | and answers[1] == expected[1]
67 | and answers[2] == expected[2]
68 | and answers[3] == expected[3]
69 | )
70 |
71 | if debug:
72 | print(f"Current Overall Accuracy: {sum(acc) / len(acc):.4f}")
73 | print(f"Current Image Accuracy: {sum(i_acc) / len(i_acc):.4f}")
74 | print(f"Current Question Accuracy: {sum(q_acc) / len(q_acc):.4f}")
75 | print(f"Current Group Accuracy: {sum(g_acc) / len(g_acc):.4f}")
76 | print("=========")
77 |
78 | return {
79 | "overall_acc": sum(acc) / len(acc),
80 | "image_acc": sum(i_acc) / len(i_acc),
81 | "question_acc": sum(q_acc) / len(q_acc),
82 | "group_acc": sum(g_acc) / len(g_acc),
83 | }
84 |
85 |
86 | if __name__ == "__main__":
87 | import argparse
88 |
89 | parser = argparse.ArgumentParser()
90 | parser.add_argument("--model", type=str, required=True)
91 | parser.add_argument("--debug", action="store_true")
92 | args = parser.parse_args()
93 |
94 | if torch.cuda.is_available():
95 | torch.set_default_device("cuda")
96 | elif torch.backends.mps.is_available():
97 | torch.set_default_device("mps")
98 |
99 | config = MoondreamConfig()
100 | model = MoondreamModel(config)
101 | load_weights_into_model(args.model, model)
102 | model.compile()
103 |
104 | results = eval_naturalbench(model, debug=args.debug)
105 |
106 | print(f"Overall Accuracy: {results['overall_acc']:.4f}")
107 | print(f"Image Accuracy: {results['image_acc']:.4f}")
108 | print(f"Question Accuracy: {results['question_acc']:.4f}")
109 | print(f"Group Accuracy: {results['group_acc']:.4f}")
110 |
--------------------------------------------------------------------------------
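To make the four accuracy variants above concrete: each NaturalBench sample yields four (image, question) pairs in the order (I0,Q0), (I1,Q0), (I0,Q1), (I1,Q1). A small illustration with a hypothetical correctness pattern:

# Hypothetical per-pair correctness for one sample, in the order used above.
pair_correct = [True, True, True, False]   # (I0,Q0), (I1,Q0), (I0,Q1), (I1,Q1)

overall = pair_correct                                    # acc:   3/4 correct
question_acc = [pair_correct[0] and pair_correct[1],      # q_acc: Q0 right on both images
                pair_correct[2] and pair_correct[3]]      #        Q1 right on both images
image_acc = [pair_correct[0] and pair_correct[2],         # i_acc: I0 right on both questions
             pair_correct[1] and pair_correct[3]]         #        I1 right on both questions
group_acc = [all(pair_correct)]                           # g_acc: all four must be correct

print(sum(overall) / len(overall))            # 0.75
print(sum(question_acc) / len(question_acc))  # 0.5
print(sum(image_acc) / len(image_acc))        # 0.5
print(sum(group_acc) / len(group_acc))        # 0.0
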
/moondream/eval/pope.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from datasets import load_dataset
3 | from tqdm import tqdm
4 | import torch
5 |
6 | from ..torch.config import MoondreamConfig
7 | from ..torch.moondream import MoondreamModel
8 | from ..torch.weights import load_weights_into_model
9 |
10 |
11 | def evaluate_pope(model, debug=False):
12 | pope_dataset = load_dataset("vikhyatk/POPE", split="test")
13 |
14 | stats = {
15 | "random": (0, 0),
16 | "popular": (0, 0),
17 | "adversarial": (0, 0),
18 | }
19 |
20 | for row in tqdm(pope_dataset, disable=debug, desc="POPE"):
21 | image = row["image"]
22 | encoded_image = model.encode_image(image)
23 | for split in ["adversarial", "popular", "random"]:
24 | for qa in row[split]:
25 | question = qa["question"]
26 | answer = qa["answer"]
27 | prompt = f"{question}\nAnswer yes or no."
28 | model_answer = model.query(encoded_image, prompt)["answer"].strip()
29 |
30 | if debug:
31 | print(f"Split: {split}")
32 | print(f"Question: {question}")
33 | print(f"Model: {model_answer}")
34 | print(f"Expected: {answer}")
35 | print(f"Correct: {model_answer.lower() == answer.lower()}")
36 | print("---")
37 |
38 | if model_answer.lower() == answer.lower():
39 | stats[split] = (stats[split][0] + 1, stats[split][1] + 1)
40 | else:
41 | stats[split] = (stats[split][0], stats[split][1] + 1)
42 |
43 | if debug:
44 | for s in stats:
45 | if stats[s][1] > 0:
46 | print(
47 | f"{s.capitalize()}: {stats[s][0]}/{stats[s][1]} = {stats[s][0] * 100.0 / stats[s][1]:.2f}%"
48 | )
49 | print("=========")
50 |
51 | return {
52 | "random": stats["random"][0] * 100.0 / stats["random"][1],
53 | "popular": stats["popular"][0] * 100.0 / stats["popular"][1],
54 | "adversarial": stats["adversarial"][0] * 100.0 / stats["adversarial"][1],
55 | }
56 |
57 |
58 | if __name__ == "__main__":
59 | parser = argparse.ArgumentParser()
60 | parser.add_argument("--model", type=str, required=True)
61 | parser.add_argument("--debug", action="store_true")
62 | args = parser.parse_args()
63 |
64 | if torch.cuda.is_available():
65 | torch.set_default_device("cuda")
66 | elif torch.backends.mps.is_available():
67 | torch.set_default_device("mps")
68 |
69 | config = MoondreamConfig()
70 | model = MoondreamModel(config)
71 | load_weights_into_model(args.model, model)
72 |
73 | result = evaluate_pope(model, args.debug)
74 |
75 | print(f"Random Accuracy: {result['random']:.2f}")
76 | print(f"Popular Accuracy: {result['popular']:.2f}")
77 | print(f"Adversarial Accuracy: {result['adversarial']:.2f}")
78 |
--------------------------------------------------------------------------------
/moondream/eval/realworldqa.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datasets
3 | import torch
4 |
5 | from tqdm import tqdm
6 |
7 | from ..torch.config import MoondreamConfig
8 | from ..torch.moondream import MoondreamModel
9 | from ..torch.weights import load_weights_into_model
10 |
11 |
12 | def eval_realworldqa(model, debug=False):
13 | dataset = datasets.load_dataset("lmms-lab/RealWorldQA", split="test")
14 |
15 | correct = 0
16 | total = 0
17 | results = []
18 |
19 | for row in tqdm(dataset, disable=debug, desc="RealWorldQA"):
20 | image = row["image"]
21 | question = row["question"]
22 | answer = row["answer"]
23 | model_answer = model.query(image, question)["answer"]
24 | is_correct = model_answer.strip().lower() == answer.strip().lower()
25 |
26 | results.append(
27 | {
28 | "question": question,
29 | "ground_truth": answer,
30 | "model_answer": model_answer,
31 | "is_correct": is_correct,
32 | }
33 | )
34 |
35 | total += 1
36 | if is_correct:
37 | correct += 1
38 | elif debug:
39 | print(f"Image: {row['image_path']}")
40 | print(f"Question: {question}")
41 | print(f"Answer: {answer}")
42 | print(f"Model Answer: {model_answer}")
43 | if debug:
44 | print(f"Correct: {correct}, Total: {total}")
45 | print(f"Accuracy: {correct * 100 / total:.2f}")
46 | print("---------")
47 |
48 | return {
49 | "acc": correct * 100 / total,
50 | "correct_count": correct,
51 | "total_count": total,
52 | "results": results,
53 | }
54 |
55 |
56 | if __name__ == "__main__":
57 | parser = argparse.ArgumentParser()
58 | parser.add_argument("--model", type=str, required=True)
59 | parser.add_argument("--debug", action="store_true")
60 | args = parser.parse_args()
61 |
62 | if torch.cuda.is_available():
63 | torch.set_default_device("cuda")
64 | elif torch.backends.mps.is_available():
65 | torch.set_default_device("mps")
66 |
67 | config = MoondreamConfig()
68 | model = MoondreamModel(config)
69 | load_weights_into_model(args.model, model)
70 | model.compile()
71 |
72 | result = eval_realworldqa(model, args.debug)
73 |
74 | print(f"Accuracy: {result['acc']:.2f}")
75 | print(f"Correct: {result['correct_count']} / {result['total_count']}")
76 |
--------------------------------------------------------------------------------
/moondream/eval/tallyqa.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datasets
3 | import torch
4 |
5 | from tqdm import tqdm
6 |
7 | from ..torch.config import MoondreamConfig
8 | from ..torch.moondream import MoondreamModel
9 | from ..torch.weights import load_weights_into_model
10 |
11 | PREFIX = "Look at the image carefully and count the objects. Answer with just a number, without any additional text. "
12 |
13 |
14 | def eval_tallyqa(model, debug=False):
15 | dataset = datasets.load_dataset(
16 | "vikhyatk/tallyqa-test",
17 | split="test",
18 | download_config=datasets.DownloadConfig(num_proc=16),
19 | )
20 |
21 | total = 0
22 | total_simple = 0
23 | correct = 0
24 | correct_simple = 0
25 |
26 |     for row in tqdm(dataset, disable=debug, desc="TallyQA"):
27 | image = row["image"]
28 | encoded_image = model.encode_image(image)
29 |
30 | for qa in row["qa"]:
31 | question = PREFIX + qa["question"]
32 | answer = str(qa["answer"])
33 | is_simple = qa["is_simple"]
34 |
35 | model_answer = model.query(encoded_image, question)["answer"]
36 |
37 | total += 1
38 | if model_answer.strip().lower() == answer.strip().lower():
39 | correct += 1
40 |             elif debug:
41 | print(f"Question: {qa['question']}")
42 | print(f"Answer: {answer}")
43 | print(f"Model Answer: {model_answer}")
44 |
45 | if is_simple:
46 | total_simple += 1
47 | if model_answer.strip().lower() == answer.strip().lower():
48 | correct_simple += 1
49 |
50 |         if debug:
51 | print(f"Simple - Correct: {correct_simple}, Total: {total_simple}")
52 | print(f"Simple Accuracy: {correct_simple * 100 / total_simple:.2f}")
53 | print(f"All - Correct: {correct}, Total: {total}")
54 | print(f"All Accuracy: {correct * 100 / total:.2f}")
55 | print("---------")
56 |
57 | return {
58 | "simple_acc": correct_simple * 100 / total_simple,
59 | "full_acc": correct * 100 / total,
60 | }
61 |
62 |
63 | if __name__ == "__main__":
64 | parser = argparse.ArgumentParser()
65 | parser.add_argument("--model", type=str, required=True)
66 | parser.add_argument("--debug", action="store_true")
67 | args = parser.parse_args()
68 |
69 | if torch.cuda.is_available():
70 | torch.set_default_device("cuda")
71 | elif torch.backends.mps.is_available():
72 | torch.set_default_device("mps")
73 |
74 | config = MoondreamConfig()
75 | model = MoondreamModel(config)
76 | load_weights_into_model(args.model, model)
77 | model.compile()
78 |
79 | result = eval_tallyqa(model, args.debug)
80 |
81 | print(f"Simple acc: {result['simple_acc']:.2f}")
82 | print(f"Full acc: {result['full_acc']:.2f}")
83 |
--------------------------------------------------------------------------------
/moondream/eval/textvqa.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datasets
3 | import torch
4 |
5 | from tqdm import tqdm
6 |
7 | from ..torch.config import MoondreamConfig
8 | from ..torch.moondream import MoondreamModel
9 | from ..torch.weights import load_weights_into_model
10 | from .utils import VQAScorer
11 |
12 | PREFIX_TEXTVQA = "Read the text in the image and provide a brief lowercase answer. Respond 'unanswerable' only if there is no plausible answer. "
13 |
14 |
15 | def eval_textvqa(model, debug=False):
16 | dataset = datasets.load_dataset("vikhyatk/textvqa_val", split="validation")
17 |
18 | scorer = VQAScorer()
19 |
20 | total_score = 0
21 | total_samples = 0
22 | results = []
23 |
24 | for row in tqdm(dataset, disable=debug, desc="TextVQA"):
25 | image = row["image"]
26 | encoded_image = model.encode_image(image)
27 | question = PREFIX_TEXTVQA + row["question"]
28 | model_answer = model.query(encoded_image, question)["answer"]
29 |
30 | score = scorer.compute_score(model_answer, row["answers"])
31 | total_score += score
32 | total_samples += 1
33 |
34 | results.append(
35 | {
36 | "question": question,
37 | "ground_truth": row["answers"],
38 | "model_answer": model_answer,
39 | "score": score,
40 | }
41 | )
42 |
43 | if debug:
44 | print(f"Question: {row['question']}")
45 | print(f"Ground Truth Answers: {row['answers']}")
46 | print(f"Model Answer: {model_answer}")
47 | print(f"Score: {score}")
48 | print(f"Running Average Score: {total_score * 100 / total_samples:.2f}")
49 | print("---------")
50 |
51 | return {"score": total_score * 100 / total_samples, "results": results}
52 |
53 |
54 | if __name__ == "__main__":
55 | parser = argparse.ArgumentParser()
56 | parser.add_argument("--model", type=str, required=True)
57 | parser.add_argument("--debug", action="store_true")
58 | args = parser.parse_args()
59 |
60 | if torch.cuda.is_available():
61 | torch.set_default_device("cuda")
62 | elif torch.backends.mps.is_available():
63 | torch.set_default_device("mps")
64 |
65 | config = MoondreamConfig()
66 | model = MoondreamModel(config)
67 | load_weights_into_model(args.model, model)
68 | model.compile()
69 |
70 | result = eval_textvqa(model, args.debug)
71 |
72 | print(f"Score: {result['score']}")
73 |
--------------------------------------------------------------------------------
/moondream/eval/waste_detection.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from collections import defaultdict
3 | from typing import Dict, List, Tuple
4 |
5 | import torch
6 | from PIL import Image
7 | from tqdm import tqdm
8 | from datasets import load_dataset
9 |
10 | from ..torch.config import MoondreamConfig
11 | from ..torch.moondream import MoondreamModel
12 | from ..torch.weights import load_weights_into_model
13 |
14 |
15 | Box = Tuple[float, float, float, float] # (x1, y1, x2, y2) – in proportion form
16 |
17 |
18 | def iou(a: Box, b: Box) -> float:
19 | """Corner-format IoU. Returns 0 when either box has zero area."""
20 | x1, y1 = max(a[0], b[0]), max(a[1], b[1])
21 | x2, y2 = min(a[2], b[2]), min(a[3], b[3])
22 | inter = max(0.0, x2 - x1) * max(0.0, y2 - y1)
23 |
24 | union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
25 | return inter / union if union else 0.0
26 |
27 |
28 | def match(gt: List[Box], pr: List[Box], iou_thr: float) -> Tuple[int, int, int]:
29 | """
30 | Greedy one-to-one matching with no confidences.
31 | Predictions are taken in the order produced by the model.
32 | """
33 | tp = fp = 0
34 | seen = [False] * len(gt)
35 |
36 | for p in pr:
37 | best, best_i = 0.0, -1
38 | for i, g in enumerate(gt):
39 | if seen[i]:
40 | continue
41 | iou_ = iou(p, g)
42 | if iou_ > best:
43 | best, best_i = iou_, i
44 | if best >= iou_thr:
45 | tp += 1
46 | seen[best_i] = True
47 | else:
48 | fp += 1
49 |
50 | fn = len(gt) - tp
51 | return tp, fp, fn
52 |
53 |
54 | class WasteDetection(torch.utils.data.Dataset):
55 | def __init__(self, name: str = "moondream/waste_detection", split: str = "test"):
56 | self.ds = load_dataset(name, split=split)
57 |
58 | def __len__(self):
59 | return len(self.ds)
60 |
61 | def __getitem__(self, idx: int) -> Dict:
62 | s = self.ds[idx]
63 | img = (
64 | s["image"]
65 | if isinstance(s["image"], Image.Image)
66 | else Image.fromarray(s["image"])
67 | )
68 | W, H = float(s.get("width", img.width)), float(s.get("height", img.height))
69 |
70 | lbl_to_boxes = defaultdict(list)
71 | for (xc, yc, bw, bh), lbl in zip(s["boxes"], s["labels"]):
72 | x1 = xc - bw / 2
73 | y1 = yc - bh / 2
74 | x2 = xc + bw / 2
75 | y2 = yc + bh / 2
76 | lbl_to_boxes[lbl].append((x1, y1, x2, y2))
77 |
78 | return {"image": img, "gt": lbl_to_boxes, "W": W, "H": H}
79 |
80 |
81 | def evaluate(
82 | model: MoondreamModel,
83 | iou_thr: float,
84 | debug: bool,
85 | ):
86 | ds = WasteDetection(split="test")
87 | TP = FP = FN = 0
88 |
89 | for s in tqdm(ds, disable=debug, desc="Waste"):
90 | img, gts = s["image"], s["gt"]
91 | enc = model.encode_image(img)
92 |
93 | for lbl, gt_boxes in gts.items():
94 | preds: List[Box] = [
95 | (
96 | o["x_min"],
97 | o["y_min"],
98 | o["x_max"],
99 | o["y_max"],
100 | )
101 | for o in model.detect(enc, lbl)["objects"]
102 | ]
103 | tp, fp, fn = match(gt_boxes, preds, iou_thr)
104 | TP += tp
105 | FP += fp
106 | FN += fn
107 |
108 | prec = TP / (TP + FP) if TP + FP else 0.0
109 | rec = TP / (TP + FN) if TP + FN else 0.0
110 | f1 = 2 * prec * rec / (prec + rec) if prec + rec else 0.0
111 | return dict(precision=prec, recall=rec, f1=f1, tp=TP, fp=FP, fn=FN)
112 |
113 |
114 | def load_model(path: str, device: torch.device) -> MoondreamModel:
115 | cfg = MoondreamConfig()
116 | model = MoondreamModel(cfg)
117 | load_weights_into_model(path, model)
118 | model.compile()
119 | model.to(device)
120 | return model
121 |
122 |
123 | def main():
124 | p = argparse.ArgumentParser()
125 | p.add_argument("--model", required=True)
126 | p.add_argument("--iou_thr", type=float, default=0.5)
127 | p.add_argument("--gpu", type=int, default=0)
128 | p.add_argument("--debug", action="store_true")
129 | args = p.parse_args()
130 |
131 | if torch.cuda.is_available():
132 | torch.cuda.set_device(args.gpu)
133 | device = torch.device(f"cuda:{args.gpu}")
134 | elif torch.backends.mps.is_available():
135 | device = torch.device("mps")
136 | else:
137 | device = torch.device("cpu")
138 |
139 | model = load_model(args.model, device)
140 | res = evaluate(model, args.iou_thr, args.debug)
141 |
142 | print(f"Precision: {res['precision']*100:.2f}%")
143 | print(f"Recall: {res['recall']*100:.2f}%")
144 | print(f"F1 Score: {res['f1']*100:.2f}%")
145 | print(f"TP: {res['tp']} FP: {res['fp']} FN: {res['fn']}")
146 |
147 |
148 | if __name__ == "__main__":
149 | """
150 | Eval to accompany finetune_region.py.
151 | """
152 | main()
153 |
--------------------------------------------------------------------------------
/moondream/finetune/README.md:
--------------------------------------------------------------------------------
1 | # Finetuning Moondream 2B
2 |
3 | This readme will walk you through the process of finetuning the text and region encoders of the Moondream 2B model.
4 |
5 | > Make sure to run all commands from the root directory of the project.
6 |
7 | ## Initial Setup
8 |
9 | ### Clone and Setup Environment
10 | ```bash
11 | git clone https://github.com/vikhyat/moondream
12 | cd moondream
13 | python -m venv .venv
14 | source .venv/bin/activate
15 | ```
16 |
17 | ### Install Dependencies
18 | ```bash
19 | # Install base requirements
20 | pip install -r requirements.txt
21 |
22 | # Install finetuning specific dependencies
23 | pip install safetensors datasets bitsandbytes tqdm wandb einops
24 | ```
25 |
26 | ## Downloading the Base Model
27 |
28 | Download `model.safetensors` from the [Hugging Face repository](https://huggingface.co/vikhyatk/moondream2/tree/main) and place it in the `models` directory as `moondream_base.safetensors`.
29 |
30 | ```bash
31 | # Create models directory
32 | mkdir -p models
33 |
34 | # Download it with wget (run from the root moondream directory)
35 | wget https://huggingface.co/vikhyatk/moondream2/resolve/main/model.safetensors -O models/moondream_base.safetensors
36 | ```
37 |
38 | ## Weights & Biases
39 |
40 | We use Weights & Biases (wandb) to track finetuning progress.
41 |
42 | To set it up to track your runs, use `wandb login`.
43 |
44 | This will take you through creating an account if you don't have one set up already. Enter your API key and you're ready to go.
45 |
46 | ## Finetuning the Text Encoder
47 |
48 | For this example, we will be teaching Moondream to describe images.
49 |
50 | Given the prompt:
51 | `\n\nQuestion: Describe this image.\n\nAnswer:`
52 |
53 | The finetuned model returns a more detailed caption of the image than you would get from the base model.
54 |
55 | 1. Double check that you've updated `MODEL_PATH` to point to the base moondream model in `moondream/finetune/finetune_text.py`.
56 | 2. Double check that the save path ends in `.safetensors`, otherwise the run will fail.
57 | > Navigate to the `save_file` call (around line 152) in `moondream/finetune/finetune_text.py`:
58 | ```python
59 | save_file(
60 |     model.state_dict(),
61 |     "moondream_finetune.safetensors",  # update this line, e.g. "models/moondream_text_finetuned.safetensors"
62 | )
63 | ```
64 |
65 | ### Start Text Finetuning
66 | ```bash
67 | python -m moondream.finetune.finetune_text
68 | ```
69 |
70 | The process will output a finetuned version of Moondream into your save path. Example output: `models/moondream_text_finetuned.safetensors`.
71 |
72 | ### Test the Finetuned Text Encoder
73 |
74 | You can test the finetuned model's performance with the following command (run from the root moondream directory).
75 |
76 | This will return the caption of the image.
77 |
78 | ```bash
79 | # Remember to update the paths
80 | python -m moondream.torch.sample --model [FINETUNED_MODEL_PATH] --image "[DATASET_DIRECTORY]/test/[IMAGE_NAME]" --prompt "\n\nQuestion: Describe this image.\n\nAnswer:"
81 | ```
82 |
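If you prefer to load the finetuned weights from Python instead of the sample script, the rough sketch below mirrors the pattern used by the eval scripts in `moondream/eval` (the image path is a placeholder, and the plain `query` call may prompt the model slightly differently than the `--prompt` string above):

```python
import torch
from PIL import Image

from moondream.torch.config import MoondreamConfig
from moondream.torch.moondream import MoondreamModel
from moondream.torch.weights import load_weights_into_model

if torch.cuda.is_available():
    torch.set_default_device("cuda")

model = MoondreamModel(MoondreamConfig())
load_weights_into_model("models/moondream_text_finetuned.safetensors", model)

image = Image.open("path/to/test_image.jpg")  # placeholder path
print(model.query(image, "Describe this image.")["answer"])
```
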
83 | ## Finetuning the Region Encoder
84 |
85 | For this example, we will be teaching Moondream to detect railroad cracks in images of a railway.
86 |
87 | Our dataset trains the model as follows.
88 |
89 | Given the prompt:
90 | `\n\nDetect: \n\n`
91 |
92 | The model returns the coordinates of a detected crack in the following format:
93 | ```{'objects': [{'x_min': [X_MIN], 'y_min': [Y_MIN], 'x_max': [X_MAX], 'y_max': [Y_MAX]}]}```
94 |
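For reference, here is a minimal sketch of one way to consume output in this format (illustrative only: it assumes the returned coordinates are normalized to the image size and uses PIL for drawing; `draw_detections` is not part of this repo):

```python
from PIL import Image, ImageDraw


def draw_detections(image: Image.Image, result: dict) -> Image.Image:
    # Draw a red box for each detected object; coordinates are assumed to be
    # normalized to [0, 1], as in the format shown above.
    draw = ImageDraw.Draw(image)
    w, h = image.size
    for obj in result["objects"]:
        box = (obj["x_min"] * w, obj["y_min"] * h, obj["x_max"] * w, obj["y_max"] * h)
        draw.rectangle(box, outline="red", width=3)
    return image
```
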
95 | ### Setup Dataset Dependencies
96 |
97 | 1. Update `MODEL_PATH` in `moondream/finetune/finetune_region.py` to point to the base moondream model.
98 | 2. Double check that the save path ends in `.safetensors`, otherwise the run will fail.
99 | > Navigate to the `save_file` call (around line 244) in `moondream/finetune/finetune_region.py`:
100 | ```python
101 | save_file(
102 |     model.state_dict(),
103 |     "moondream_finetune.safetensors",  # update this line, e.g. "models/moondream_region_finetuned.safetensors"
104 | )
105 | ```
106 |
107 | ### Start Region Finetuning
108 | ```bash
109 | python -m moondream.finetune.finetune_region
110 | ```
111 |
112 | The process will output a finetuned version of Moondream into your save path. Example output: `models/moondream_region_finetuned.safetensors`.
--------------------------------------------------------------------------------
/moondream/finetune/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vikhyat/moondream/ce59395b386eb956e19bc1cba5f97deb9befcd05/moondream/finetune/__init__.py
--------------------------------------------------------------------------------
/moondream/finetune/finetune_text.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.utils.data import Dataset
4 | import math
5 | from safetensors.torch import save_file
6 |
7 | from tqdm import tqdm
8 | from datasets import load_dataset
9 | from bitsandbytes.optim import AdamW8bit
10 | import wandb
11 |
12 | from ..torch.weights import load_weights_into_model
13 | from ..torch.moondream import MoondreamModel, MoondreamConfig, text_encoder
14 | from ..torch.text import _produce_hidden, _lm_head, TextConfig
15 |
16 | # This is intended to be a basic starting point for fine-tuning the text encoder.
17 | # Your optimal hyperparams and data may be different.
18 | MODEL_PATH = ""
19 | # Your data should end with the eos token. Here is the textual representation.
20 | ANSWER_EOS = "<|endoftext|>"
21 | LR = 3e-6
22 | EPOCHS = 3
23 | GRAD_ACCUM_STEPS = 128
24 |
25 |
26 | def lr_schedule(step, max_steps):
27 | x = step / max_steps
28 | if x < 0.1:
29 | return 0.1 * LR + 0.9 * LR * x / 0.1
30 | else:
31 | return 0.1 * LR + 0.9 * LR * (1 + math.cos(math.pi * (x - 0.1))) / 2
32 |
33 |
34 | def text_loss(
35 | inputs_embeds: torch.Tensor, w: nn.Module, labels: torch.Tensor, config: TextConfig
36 | ):
37 | _, q_len, _ = inputs_embeds.shape
38 | hidden_BTC = _produce_hidden(inputs_embeds, w, config)
39 | lm_logits = _lm_head(hidden_BTC, w)
40 |
41 | loss = None
42 | if labels is not None:
43 | _, _, l_len = labels.shape
44 | shift_index = (q_len - l_len) - 1
45 | shifted_logits = lm_logits[..., shift_index:-1, :].contiguous()
46 | shifted_labels = labels.contiguous()
47 | loss = nn.CrossEntropyLoss()(
48 | shifted_logits.view(-1, shifted_logits.size(-1)),
49 | shifted_labels.view(-1),
50 | )
51 | return loss
52 |
53 |
54 | class DocciDataset(Dataset):
55 | def __init__(self, split="train"):
56 | self.data = load_dataset("google/docci", trust_remote_code=True)[split]
57 |
58 | def __len__(self):
59 | return len(self.data)
60 |
61 | def __getitem__(self, idx):
62 | sample = self.data[idx]
63 | description = sample["description"]
64 | return {
65 | "image": sample["image"],
66 | "qa": {
67 | "question": "\n\nQuestion: Describe this image.\n\nAnswer:",
68 | "answer": f"{description}{ANSWER_EOS}",
69 | },
70 | }
71 |
72 |
73 | def main():
74 | if torch.cuda.is_available():
75 | torch.set_default_device("cuda")
76 | elif torch.backends.mps.is_available():
77 | torch.set_default_device("mps")
78 |
79 | wandb.init(
80 | project="moondream-ft",
81 | config={
82 | "EPOCHS": EPOCHS,
83 | "GRAD_ACCUM_STEPS": GRAD_ACCUM_STEPS,
84 | "LR": LR,
85 | },
86 | )
87 |
88 | config = MoondreamConfig()
89 | model = MoondreamModel(config)
90 | load_weights_into_model(MODEL_PATH, model)
91 |
92 | optimizer = AdamW8bit(
93 | [
94 | {"params": model.text.parameters()},
95 | ],
96 | lr=LR,
97 | betas=(0.9, 0.95),
98 | eps=1e-6,
99 | )
100 |
101 | dataset = DocciDataset("train")
102 |
103 | total_steps = EPOCHS * len(dataset) // GRAD_ACCUM_STEPS
104 | pbar = tqdm(total=total_steps)
105 |
106 | i = 0
107 | for epoch in range(EPOCHS):
108 | for sample in dataset:
109 | i += 1
110 | with torch.no_grad():
111 | img_emb = model._run_vision_encoder(sample["image"])
112 | bos_emb = text_encoder(
113 | torch.tensor([[model.config.tokenizer.bos_id]], device=model.device),
114 | model.text,
115 | )
116 | question_tokens = model.tokenizer.encode(sample["qa"]["question"]).ids
117 | question_emb = text_encoder(
118 | torch.tensor([[question_tokens]], device=model.device),
119 | model.text,
120 | ).squeeze(0)
121 | answer_tokens = model.tokenizer.encode(sample["qa"]["answer"]).ids
122 | answer_emb = text_encoder(
123 | torch.tensor([[answer_tokens]], device=model.device),
124 | model.text,
125 | ).squeeze(0)
126 | inputs_embeds = torch.cat(
127 | [bos_emb, img_emb[None], question_emb, answer_emb], dim=1
128 | )
129 | loss = text_loss(
130 | inputs_embeds=inputs_embeds,
131 | w=model.text,
132 | labels=torch.tensor([[answer_tokens]], device=model.device),
133 | config=config.text,
134 | )
135 |
136 | loss.backward()
137 |
138 | if i % GRAD_ACCUM_STEPS == 0:
139 | optimizer.step()
140 | optimizer.zero_grad()
141 |
142 | lr = lr_schedule(i / GRAD_ACCUM_STEPS, total_steps)
143 | for param_group in optimizer.param_groups:
144 | param_group["lr"] = lr
145 | pbar.set_postfix({"step": i // GRAD_ACCUM_STEPS, "loss": loss.item()})
146 | pbar.update(1)
147 | wandb.log(
148 | {"loss/train": loss.item(), "lr": optimizer.param_groups[0]["lr"]}
149 | )
150 | wandb.finish()
151 |     # Update the save path below, e.g. "models/moondream_text_finetuned.safetensors"
152 | save_file(
153 | model.state_dict(),
154 | "moondream_finetune.safetensors",
155 | )
156 |
157 |
158 | if __name__ == "__main__":
159 | """
160 | Replace paths with your appropriate paths.
161 | To run: python -m moondream.finetune.finetune_text
162 | """
163 | main()
164 |
--------------------------------------------------------------------------------
/moondream/torch/config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import Dict, List, Optional
3 |
4 |
5 | @dataclass(frozen=True)
6 | class TextConfig:
7 | dim: int = 2048
8 | ff_dim: int = 8192
9 | n_layers: int = 24
10 | vocab_size: int = 51200
11 | max_context: int = 2048
12 | n_heads: int = 32
13 | n_kv_heads: int = 32
14 | prefix_attn: int = 730
15 | group_size: Optional[int] = None
16 |
17 |
18 | @dataclass(frozen=True)
19 | class VisionConfig:
20 | enc_dim: int = 1152
21 | enc_patch_size: int = 14
22 | enc_n_layers: int = 27
23 | enc_ff_dim: int = 4304
24 | enc_n_heads: int = 16
25 | proj_out_dim: int = 2048
26 | crop_size: int = 378
27 | in_channels: int = 3
28 | max_crops: int = 12
29 | overlap_margin: int = 4
30 | proj_inner_dim: int = 8192
31 |
32 |
33 | @dataclass(frozen=True)
34 | class RegionConfig:
35 | dim: int = 2048
36 | coord_feat_dim: int = 256
37 | coord_out_dim: int = 1024
38 | size_feat_dim: int = 512
39 | size_out_dim: int = 2048
40 | inner_dim: int = 8192
41 | group_size: Optional[int] = None
42 |
43 |
44 | @dataclass(frozen=True)
45 | class TokenizerConfig:
46 | bos_id: int = 0
47 | eos_id: int = 0
48 | coord_id: int = 5
49 | size_id: int = 6
50 | templates: Dict[str, Optional[Dict[str, List[int]]]] = field(
51 | default_factory=lambda: {
52 | "caption": {
53 | "short": [1, 32708, 2, 12492, 3],
54 | "normal": [1, 32708, 2, 6382, 3],
55 | "long": [1, 32708, 2, 4059, 3],
56 | },
57 | "query": {"prefix": [1, 15381, 2], "suffix": [3]},
58 | "detect": {"prefix": [1, 7235, 476, 2], "suffix": [3]},
59 | "point": {"prefix": [1, 2581, 2], "suffix": [3]},
60 | }
61 | )
62 |
63 |
64 | @dataclass(frozen=True)
65 | class MoondreamConfig:
66 | text: TextConfig = TextConfig()
67 | vision: VisionConfig = VisionConfig()
68 | region: RegionConfig = RegionConfig()
69 | tokenizer: TokenizerConfig = TokenizerConfig()
70 |
71 | @classmethod
72 | def from_dict(cls, config_dict: dict):
73 | text_config = TextConfig(**config_dict.get("text", {}))
74 | vision_config = VisionConfig(**config_dict.get("vision", {}))
75 | region_config = RegionConfig(**config_dict.get("region", {}))
76 | tokenizer_config = TokenizerConfig(**config_dict.get("tokenizer", {}))
77 | return cls(
78 | text=text_config,
79 | vision=vision_config,
80 | region=region_config,
81 | tokenizer=tokenizer_config,
82 | )
83 |
84 | def to_dict(self):
85 | return {
86 | "text": self.text.__dict__,
87 | "vision": self.vision.__dict__,
88 | "region": self.region.__dict__,
89 | "tokenizer": self.tokenizer.__dict__,
90 | }
91 |
--------------------------------------------------------------------------------
/moondream/torch/hf_moondream.py:
--------------------------------------------------------------------------------
1 | from transformers import PreTrainedModel, PretrainedConfig
2 |
3 | from .config import MoondreamConfig
4 | from .moondream import MoondreamModel
5 |
6 | # Files sometimes don't get loaded without these...
7 | from .image_crops import *
8 | from .vision import *
9 | from .text import *
10 | from .region import *
11 | from .utils import *
12 |
13 |
14 | def extract_question(text):
15 | prefix = "\n\nQuestion: "
16 | suffix = "\n\nAnswer:"
17 |
18 | if text.startswith(prefix) and text.endswith(suffix):
19 | return text[len(prefix) : -len(suffix)]
20 | else:
21 | return None
22 |
23 |
24 | class HfConfig(PretrainedConfig):
25 | _auto_class = "AutoConfig"
26 | model_type = "moondream1"
27 |
28 | def __init__(self, **kwargs):
29 | super().__init__(**kwargs)
30 | self.config = {}
31 |
32 |
33 | class HfMoondream(PreTrainedModel):
34 | _auto_class = "AutoModelForCausalLM"
35 | config_class = HfConfig
36 |
37 | def __init__(self, config):
38 | super().__init__(config)
39 | self.model = MoondreamModel(
40 | MoondreamConfig.from_dict(config.config), setup_caches=False
41 | )
42 | self._is_kv_cache_setup = False
43 |
44 | def _setup_caches(self):
45 | if not self._is_kv_cache_setup:
46 | self.model._setup_caches()
47 | self._is_kv_cache_setup = True
48 |
49 | @property
50 | def encode_image(self):
51 | self._setup_caches()
52 | return self.model.encode_image
53 |
54 | @property
55 | def query(self):
56 | self._setup_caches()
57 | return self.model.query
58 |
59 | @property
60 | def caption(self):
61 | self._setup_caches()
62 | return self.model.caption
63 |
64 | @property
65 | def detect(self):
66 | self._setup_caches()
67 | return self.model.detect
68 |
69 | @property
70 | def point(self):
71 | self._setup_caches()
72 | return self.model.point
73 |
74 | @property
75 | def detect_gaze(self):
76 | self._setup_caches()
77 | return self.model.detect_gaze
78 |
79 | def answer_question(
80 | self,
81 | image_embeds,
82 | question,
83 | tokenizer=None,
84 | chat_history="",
85 | result_queue=None,
86 | max_new_tokens=256,
87 | **kwargs
88 | ):
89 | answer = self.query(image_embeds, question)["answer"].strip()
90 |
91 | if result_queue is not None:
92 | result_queue.put(answer)
93 | return answer
94 |
95 | def batch_answer(self, images, prompts, tokenizer=None, **kwargs):
96 | answers = []
97 | for image, prompt in zip(images, prompts):
98 | answers.append(self.query(image, prompt)["answer"].strip())
99 | return answers
100 |
101 | def _unsupported_exception(self):
102 | raise NotImplementedError(
103 | "This method is not supported in the latest version of moondream. "
104 | "Consider upgrading to the updated API spec, or alternately pin "
105 | "to 'revision=2024-08-26'."
106 | )
107 |
108 | def generate(self, image_embeds, prompt, tokenizer, max_new_tokens=128, **kwargs):
109 | """
110 | Function definition remains unchanged for backwards compatibility.
111 |         Be aware that tokenizer, max_new_tokens, and kwargs are ignored.
112 | """
113 | prompt_extracted = extract_question(prompt)
114 | if prompt_extracted is not None:
115 | answer = self.model.query(
116 | image=image_embeds, question=prompt_extracted, stream=False
117 | )["answer"]
118 | else:
119 | image_embeds = self.encode_image(image_embeds)
120 | prompt_tokens = torch.tensor(
121 | [self.model.tokenizer.encode(prompt).ids],
122 | device=self.device,
123 | )
124 |
125 | def generator():
126 | for token in self.model._generate_text(
127 | prompt_tokens,
128 | image_embeds.kv_cache,
129 | image_embeds.pos,
130 | max_new_tokens,
131 | ):
132 | yield token
133 |
134 | answer = "".join(list(generator()))
135 |
136 | return [answer]
137 |
138 | def get_input_embeddings(self) -> nn.Embedding:
139 | """
140 | Lazily wrap the raw parameter `self.model.text.wte` in a real
141 | `nn.Embedding` layer so that HF mix-ins recognise it. The wrapper
142 | **shares** the weight tensor—no copy is made.
143 | """
144 | if not hasattr(self, "_input_embeddings"):
145 | self._input_embeddings = nn.Embedding.from_pretrained(
146 | self.model.text.wte, # tensor created in text.py
147 | freeze=True, # set to False if you need it trainable
148 | )
149 | return self._input_embeddings
150 |
151 | def set_input_embeddings(self, value: Union[nn.Embedding, nn.Module]) -> None:
152 | """
153 | Lets HF functions (e.g. `resize_token_embeddings`) replace or resize the
154 | embeddings and keeps everything tied to `self.model.text.wte`.
155 | """
156 | # 1. point the low-level parameter to the new weight matrix
157 | self.model.text.wte = value.weight
158 | # 2. keep a reference for get_input_embeddings()
159 | self._input_embeddings = value
160 |
161 | def input_embeds(
162 | self,
163 | input_ids: Union[torch.LongTensor, list, tuple],
164 | *,
165 | device: torch.device | None = None
166 | ) -> torch.FloatTensor:
167 | """
168 | Back-compat wrapper that turns token IDs into embeddings.
169 |
170 | Example:
171 | ids = torch.tensor([[1, 2, 3]])
172 | embeds = model.input_embeds(ids) # (1, 3, hidden_dim)
173 | """
174 | if not torch.is_tensor(input_ids):
175 | input_ids = torch.as_tensor(input_ids)
176 | if device is not None:
177 | input_ids = input_ids.to(device)
178 |
179 | return self.get_input_embeddings()(input_ids)
180 |
--------------------------------------------------------------------------------
/moondream/torch/hf_release.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 |
4 | from .weights import load_weights_into_model
5 | from .hf_moondream import HfConfig, HfMoondream
6 |
7 | if __name__ == "__main__":
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument("--model-name", type=str, default="vikhyatk/moondream-next")
10 | parser.add_argument("--ckpt", type=str, required=True)
11 | args = parser.parse_args()
12 |
13 | config = HfConfig()
14 | model = HfMoondream(config)
15 | load_weights_into_model(args.ckpt, model.model)
16 |
17 | model.push_to_hub(args.model_name, config=config)
18 |
--------------------------------------------------------------------------------
/moondream/torch/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from dataclasses import dataclass
6 | from typing import Literal
7 |
8 | try:
9 | from torchao import quantize_
10 | from torchao.quantization import int4_weight_only
11 | except ImportError:
12 |
13 | def quantize_(model, quant_mode):
14 | raise ImportError(
15 | "torchao is not installed. Please install it with `pip install torchao`."
16 | )
17 |
18 | def int4_weight_only(group_size):
19 | raise ImportError(
20 | "torchao is not installed. Please install it with `pip install torchao`."
21 | )
22 |
23 |
24 | def gelu_approx(x):
25 | return F.gelu(x, approximate="tanh")
26 |
27 |
28 | @dataclass
29 | class LinearWeights:
30 | weight: torch.Tensor
31 | bias: torch.Tensor
32 |
33 |
34 | def linear(x: torch.Tensor, w: LinearWeights) -> torch.Tensor:
35 | return F.linear(x, w.weight, w.bias)
36 |
37 |
38 | def dequantize_tensor(W_q, scale, zero, orig_shape, dtype=torch.bfloat16):
39 | _step = W_q.shape[0]
40 | W_r = torch.empty([2 * _step, W_q.shape[1]], dtype=dtype, device=W_q.device)
41 | W_r[:_step] = (W_q & 0b11110000) >> 4
42 | W_r[_step:] = W_q & 0b00001111
43 | W_r.sub_(zero).mul_(scale)
44 | return W_r.reshape(orig_shape)
45 |
46 |
47 | class QuantizedLinear(nn.Module):
48 | def __init__(
49 | self,
50 | in_features: int,
51 | out_features: int,
52 | dtype: torch.dtype,
53 | ):
54 | # TODO: Take group_size as an input instead of hardcoding it here.
55 | super().__init__()
56 | self.in_features = in_features
57 | self.out_features = out_features
58 | self.weight = nn.ParameterDict(
59 | {
60 | "packed": nn.Parameter(
61 | torch.empty(
62 | out_features * in_features // (128 * 2), 128, dtype=torch.uint8
63 | ),
64 | requires_grad=False,
65 | ),
66 | "scale": nn.Parameter(
67 | torch.empty(out_features * in_features // 128, 1),
68 | requires_grad=False,
69 | ),
70 | "zero_point": nn.Parameter(
71 | torch.empty(out_features * in_features // 128, 1),
72 | requires_grad=False,
73 | ),
74 | }
75 | )
76 | self.bias = nn.Parameter(torch.empty(out_features), requires_grad=False)
77 | self.unpacked = False
78 |
79 | def unpack(self):
80 | if self.unpacked:
81 | return
82 |
83 | self.weight = nn.Parameter(
84 | dequantize_tensor(
85 | self.weight["packed"],
86 | self.weight["scale"],
87 | self.weight["zero_point"],
88 | (self.out_features, self.in_features),
89 | torch.bfloat16,
90 | )
91 | )
92 | with torch.device("meta"):
93 | self.linear = nn.Linear(
94 | self.in_features, self.out_features, dtype=torch.bfloat16
95 | )
96 | self.linear.weight = self.weight
97 | self.linear.bias = nn.Parameter(
98 | self.bias.to(torch.bfloat16), requires_grad=False
99 | )
100 |
101 | del self.weight, self.bias
102 | quantize_(self, int4_weight_only(group_size=128))
103 | self.unpacked = True
104 | torch.cuda.empty_cache()
105 |
106 | def forward(self, x: torch.Tensor) -> torch.Tensor:
107 | if not self.unpacked:
108 | self.unpack()
109 | return self.linear(x)
110 |
111 |
112 | @dataclass
113 | class LayerNormWeights:
114 | weight: torch.Tensor
115 | bias: torch.Tensor
116 |
117 |
118 | def layer_norm(x: torch.Tensor, w: LayerNormWeights) -> torch.Tensor:
119 | return F.layer_norm(x, w.bias.shape, w.weight, w.bias)
120 |
121 |
122 | @dataclass
123 | class MLPWeights:
124 | fc1: LinearWeights
125 | fc2: LinearWeights
126 | act: Literal["gelu_approx"] = "gelu_approx"
127 |
128 |
129 | def mlp(x: torch.Tensor, w: MLPWeights) -> torch.Tensor:
130 | x = w.fc1(x)
131 | x = gelu_approx(x)
132 | x = w.fc2(x)
133 | return x
134 |
135 |
136 | @dataclass
137 | class AttentionWeights:
138 | qkv: LinearWeights
139 | proj: LinearWeights
140 |
141 |
142 | def attn(x: torch.Tensor, w: AttentionWeights, n_heads: int) -> torch.Tensor:
143 | bsz, q_len, d_model = x.shape
144 | head_dim = d_model // n_heads
145 |
146 | q, k, v = [
147 | t.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
148 | for t in linear(x, w.qkv).chunk(3, dim=-1)
149 | ]
150 | out = F.scaled_dot_product_attention(q, k, v)
151 | out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
152 | out = linear(out, w.proj)
153 | return out
154 |
--------------------------------------------------------------------------------
/moondream/torch/region.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 |
5 | from typing import List, Tuple, Union
6 |
7 | from .layers import mlp
8 |
9 | SpatialRefs = List[Union[Tuple[float, float], Tuple[float, float, float, float]]]
10 |
11 |
12 | def fourier_features(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
13 | """
14 | Applies Fourier feature mapping to input tensor x using frequency matrix w. This
15 | projects inputs through sinusoidal functions to create higher dimensional features
16 | that help mitigate spectral bias - the tendency of neural networks to learn
17 | low-frequency functions more easily than high-frequency ones. By explicitly
18 | mapping inputs to higher frequencies through sin/cos transformations, we enable
19 | better learning of fine details and higher frequency patterns.
20 |
21 | Args:
22 | x: Input tensor to transform
23 | w: Matrix of frequencies for the Fourier features transformation
24 |
25 | Returns:
26 | Concatenated cosine and sine transformed features as a tensor
27 | """
28 | f = 2 * math.pi * x @ w
29 | return torch.cat([f.cos(), f.sin()], dim=-1)
30 |
31 |
32 | def encode_coordinate(coord: torch.Tensor, w: nn.Module) -> torch.Tensor:
33 | """
34 | Takes as input a tensor containing a single float coordinate value (x or y)
35 | and encodes it into hidden states for input to the text model.
36 |
37 | Args:
38 | coord: Tensor with single float coordinate value
39 |
40 | Returns:
41 | Encoded hidden states tensor for input to text model
42 | """
43 | return w.coord_encoder(fourier_features(coord, w.coord_features))
44 |
45 |
46 | def decode_coordinate(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
47 | """
48 | Takes as input the last hidden state from the text model and outputs a single logit
49 | representing either an x or y coordinate prediction.
50 |
51 | Args:
52 | hidden_state: The final hidden state tensor from the text model.
53 |
54 | Returns:
55 | A single logit representing the predicted coordinate value (x or y)
56 | """
57 | return mlp(hidden_state, w.coord_decoder)
58 |
59 |
60 | def encode_size(size: torch.Tensor, w: nn.Module) -> torch.Tensor:
61 | """
62 | Takes a tensor containing width and height values and encodes them into
63 | hidden states for input to the text model.
64 |
65 | Args:
66 | size: Tensor with two floats for width and height
67 |
68 | Returns:
69 | Encoded hidden states tensor for input to text model
70 | """
71 | return w.size_encoder(fourier_features(size, w.size_features))
72 |
73 |
74 | def decode_size(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
75 | """
76 | Takes as input the last hidden state from the text model and outputs logits
77 | for 1024 bins representing width and height in log-scale.
78 |
79 | The bins are distributed according to the formula:
80 | bin = (log2(size) + 10.0) / 10.0 * 1023.0
81 | where size values are clamped to be at least 1/1024.
82 |
83 | To convert from bin back to size:
84 | size = 2^((bin / 1023.0) * 10.0 - 10.0)
85 |
86 | Args:
87 | hidden_state: The final hidden state tensor from the text model.
88 |
89 | Returns:
90 | A tensor containing logits for 1024 bins for width and height.
91 | Shape is (2, 1024) where the first dimension corresponds to width and height.
92 | """
93 | return mlp(hidden_state, w.size_decoder).view(2, -1)
94 |
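# Worked example of the binning above (illustrative): a normalized size of 0.25
# maps to bin (log2(0.25) + 10.0) / 10.0 * 1023.0 = 818.4, and decoding that bin
# recovers 2 ** ((818.4 / 1023.0) * 10.0 - 10.0) = 0.25.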
95 |
96 | def encode_spatial_refs(spatial_refs: SpatialRefs, w: nn.Module) -> dict:
97 | """
98 | Takes a list of spatial references (points or regions) and encodes them into
99 | hidden states for input to the text model.
100 |
101 | Args:
102 | spatial_refs: List of spatial references (points or boxes)
103 | - Points are represented as normalized (x, y) tuples
104 | - Boxes are represented as normalized (x_min, y_min, x_max, y_max) tuples
105 |
106 | Returns:
107 | {"coords": torch.Tensor, "sizes": Optional[torch.Tensor]}
108 | """
109 | coords, sizes = [], []
110 | for ref in spatial_refs:
111 | if len(ref) == 2:
112 | coords.append(ref[0])
113 | coords.append(ref[1])
114 | else:
115 | x_c = (ref[0] + ref[2]) / 2
116 | y_c = (ref[1] + ref[3]) / 2
117 | width = ref[2] - ref[0]
118 | height = ref[3] - ref[1]
119 | coords.append(x_c)
120 | coords.append(y_c)
121 | sizes.append([width, height])
122 |
123 | coords = torch.tensor(
124 | coords, device=w.coord_features.device, dtype=w.coord_features.dtype
125 | ).view(-1, 1)
126 | coords = encode_coordinate(coords, w)
127 |
128 | if sizes:
129 | sizes = torch.tensor(
130 | sizes, device=w.size_features.device, dtype=w.size_features.dtype
131 | )
132 | sizes = encode_size(sizes, w)
133 | else:
134 | sizes = None
135 |
136 | return {"coords": coords, "sizes": sizes}
137 |
--------------------------------------------------------------------------------
/moondream/torch/rope.py:
--------------------------------------------------------------------------------
1 | # Ethically sourced from https://github.com/xjdr-alt/entropix
2 |
3 | import torch
4 |
5 |
6 | def precompute_freqs_cis(
7 | dim: int,
8 | end: int,
9 | theta: float = 10000.0,
10 | use_scaled: bool = False,
11 | dtype: torch.dtype = torch.float32,
12 | ) -> torch.Tensor:
13 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=dtype)[: (dim // 2)] / dim))
14 | t = torch.arange(end, dtype=dtype).unsqueeze(1)
15 | freqs = t * freqs.unsqueeze(0)
16 | freqs = torch.exp(1j * freqs)
17 | return torch.stack([freqs.real, freqs.imag], dim=-1)
18 |
19 |
20 | def apply_rotary_emb(
21 | x: torch.Tensor,
22 | freqs_cis: torch.Tensor,
23 | position_ids: torch.Tensor,
24 | num_heads: int,
25 | rot_dim: int = 32,
26 | interleave: bool = False,
27 | ) -> torch.Tensor:
28 | assert rot_dim == freqs_cis.shape[-2] * 2
29 | assert num_heads == x.shape[1]
30 |
31 | x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]
32 |
33 | if interleave:
34 | xq_r = x_rot.float().reshape(*x_rot.shape[:-1], -1, 2)[..., 0]
35 | xq_i = x_rot.float().reshape(*x_rot.shape[:-1], -1, 2)[..., 1]
36 | else:
37 | d_q = x_rot.shape[-1] // 2
38 | xq_r, xq_i = x_rot[..., :d_q], x_rot[..., d_q:]
39 |
40 | freqs_cos = freqs_cis[..., 0][position_ids, :].unsqueeze(0).unsqueeze(0)
41 | freqs_sin = freqs_cis[..., 1][position_ids, :].unsqueeze(0).unsqueeze(0)
42 |
43 | # Complex multiplication: (a + bi) * (c + di) = (ac - bd) + (ad + bc)i
44 | xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
45 | xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
46 | xq_out = torch.stack((xq_out_r, xq_out_i), dim=-1).flatten(-2)
47 |
48 | return torch.cat([xq_out.to(x.dtype), x_pass], dim=-1)
49 |
--------------------------------------------------------------------------------
/moondream/torch/text.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from torch.nn import functional as F
5 |
6 | from .layers import layer_norm, mlp, QuantizedLinear
7 | from .rope import apply_rotary_emb, precompute_freqs_cis
8 | from .config import TextConfig
9 |
10 |
11 | def text_encoder(input_ids: torch.Tensor, w: nn.Module):
12 | return F.embedding(input_ids, w.wte)
13 |
14 |
15 | def attn(
16 | x: torch.Tensor,
17 | w: nn.Module,
18 | freqs_cis: torch.Tensor,
19 | kv_cache: nn.Module,
20 | attn_mask: torch.Tensor,
21 | n_heads: int,
22 | n_kv_heads: int,
23 | position_ids: torch.Tensor,
24 | ):
25 | bsz, q_len, d_model = x.shape
26 | head_dim = d_model // n_heads
27 |
28 | qkv_out = w.qkv(x) # shape: (bsz, q_len, (n_heads + 2*n_kv_heads)*head_dim)
29 | q_dim = n_heads * head_dim
30 | kv_dim = n_kv_heads * head_dim
31 | q, k, v = qkv_out.split([q_dim, kv_dim, kv_dim], dim=-1)
32 | del qkv_out
33 |
34 | q = q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
35 | k = k.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
36 | v = v.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
37 |
38 | q = apply_rotary_emb(q, freqs_cis, position_ids, n_heads)
39 | k = apply_rotary_emb(k, freqs_cis, position_ids, n_kv_heads)
40 |
41 | if kv_cache is not None:
42 | k, v = kv_cache.update(position_ids, k, v)
43 |
44 | out = F.scaled_dot_product_attention(
45 | q, k, v, attn_mask=attn_mask, enable_gqa=n_heads != n_kv_heads
46 | )
47 | out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
48 | out = w.proj(out)
49 | return out
50 |
51 |
52 | def _attn(
53 | x: torch.Tensor,
54 | w: torch.Tensor,
55 | freqs_cis: torch.Tensor,
56 | attn_mask: torch.Tensor,
57 | n_heads: int,
58 | n_kv_heads: int,
59 | ):
60 | bsz, q_len, d_model = x.shape
61 | head_dim = d_model // n_heads
62 | pos = 0
63 |
64 | qkv_out = w.qkv(x) # shape: (bsz, q_len, (n_heads + 2*n_kv_heads)*head_dim)
65 | q_dim = n_heads * head_dim
66 | kv_dim = n_kv_heads * head_dim
67 |
68 | q = qkv_out[..., :q_dim].view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
69 | k = (
70 | qkv_out[..., q_dim : q_dim + kv_dim]
71 | .view(bsz, q_len, n_kv_heads, head_dim)
72 | .transpose(1, 2)
73 | )
74 | v = (
75 | qkv_out[..., q_dim + kv_dim :]
76 | .view(bsz, q_len, n_kv_heads, head_dim)
77 | .transpose(1, 2)
78 | )
79 |
80 | position_ids = torch.arange(pos, pos + q_len, dtype=torch.long)
81 | q = apply_rotary_emb(q, freqs_cis, position_ids, n_heads)
82 | k = apply_rotary_emb(k, freqs_cis, position_ids, n_kv_heads)
83 | out = F.scaled_dot_product_attention(
84 | q, k, v, attn_mask=attn_mask, enable_gqa=n_heads != n_kv_heads
85 | )
86 | out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
87 | out = w.proj(out)
88 | return out
89 |
90 |
91 | def _produce_hidden(inputs_embeds: torch.Tensor, w: nn.Module, config: TextConfig):
92 | hidden_BTC = inputs_embeds
93 |
94 | bsz, q_len, d_model = inputs_embeds.shape
95 | attn_mask = torch.zeros(q_len, q_len)
96 | attn_mask[:730, :730] = 1
97 | for i in range(730, q_len):
98 | attn_mask[i, : i + 1] = 1
99 | attn_mask = attn_mask.to(dtype=torch.bool)
100 |
101 | for i, block in enumerate(w.blocks):
102 | l_in = layer_norm(hidden_BTC, block.ln)
103 | l_attn = _attn(
104 | x=l_in,
105 | w=block.attn,
106 | freqs_cis=w.freqs_cis,
107 | attn_mask=attn_mask,
108 | n_heads=config.n_heads,
109 | n_kv_heads=config.n_kv_heads,
110 | )
111 | l_mlp = mlp(l_in, block.mlp)
112 | hidden_BTC = hidden_BTC + l_attn + l_mlp
113 |
114 | return hidden_BTC
115 |
116 |
117 | def text_decoder(
118 | x: torch.Tensor,
119 | w: nn.Module,
120 | attn_mask: torch.Tensor,
121 | position_ids: torch.Tensor,
122 | config: TextConfig,
123 | ):
124 | for block in w.blocks:
125 | l_in = layer_norm(x, block.ln)
126 | l_attn = attn(
127 | l_in,
128 | block.attn,
129 | freqs_cis=w.freqs_cis,
130 | kv_cache=block.kv_cache,
131 | attn_mask=attn_mask,
132 | n_heads=config.n_heads,
133 | n_kv_heads=config.n_kv_heads,
134 | position_ids=position_ids,
135 | )
136 | l_mlp = mlp(l_in, block.mlp)
137 | x = x + l_attn + l_mlp
138 |
139 | return x
140 |
141 |
142 | def lm_head(hidden_BTC: torch.Tensor, w: nn.Module):
143 | hidden_BC = hidden_BTC[:, -1, :]
144 | hidden_BC = layer_norm(hidden_BC, w.post_ln)
145 | logits = w.lm_head(hidden_BC)
146 | return logits
147 |
148 |
149 | def _lm_head(hidden_BTC: torch.Tensor, w: nn.Module):
150 | hidden_BTC = layer_norm(hidden_BTC, w.post_ln)
151 | logits = w.lm_head(hidden_BTC)
152 | return logits
153 |
154 |
155 | def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
156 | qkv_dim = int(config.dim * (1 + 2 * config.n_kv_heads / config.n_heads))
157 | linear_cls = QuantizedLinear if config.group_size is not None else nn.Linear
158 |
159 | text = nn.ModuleDict(
160 | {
161 | "blocks": nn.ModuleList(
162 | [
163 | nn.ModuleDict(
164 | {
165 | "ln": nn.LayerNorm(config.dim, dtype=dtype),
166 | "attn": nn.ModuleDict(
167 | {
168 | "qkv": linear_cls(config.dim, qkv_dim, dtype=dtype),
169 | "proj": linear_cls(
170 | config.dim, config.dim, dtype=dtype
171 | ),
172 | }
173 | ),
174 | "mlp": nn.ModuleDict(
175 | {
176 | "fc1": linear_cls(
177 | config.dim, config.ff_dim, dtype=dtype
178 | ),
179 | "fc2": linear_cls(
180 | config.ff_dim, config.dim, dtype=dtype
181 | ),
182 | }
183 | ),
184 | }
185 | )
186 | for _ in range(config.n_layers)
187 | ]
188 | ),
189 | "post_ln": nn.LayerNorm(config.dim, dtype=dtype),
190 | "lm_head": nn.Linear(config.dim, config.vocab_size, dtype=dtype),
191 | }
192 | )
193 | text.wte = nn.Parameter(torch.empty(config.vocab_size, config.dim, dtype=dtype))
194 | text.register_buffer(
195 | "freqs_cis",
196 | precompute_freqs_cis(config.dim // (2 * config.n_heads), config.max_context),
197 | persistent=False,
198 | )
199 |
200 | return text
201 |
--------------------------------------------------------------------------------
/moondream/torch/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def remove_outlier_points(points_tuples, k_nearest=2, threshold=2.0):
5 | """
6 | Robust outlier detection for list of (x,y) tuples.
7 | Only requires numpy.
8 |
9 | Args:
10 | points_tuples: list of (x,y) tuples
11 | k_nearest: number of neighbors to consider
12 | threshold: multiplier for median distance
13 |
14 | Returns:
15 |         list: filtered list of (x,y) tuples with outliers removed (the
16 |         boolean keep-mask is computed internally but is not returned)
17 | """
18 | points = np.array(points_tuples)
19 | n_points = len(points)
20 |
21 | # Calculate pairwise distances manually
22 | dist_matrix = np.zeros((n_points, n_points))
23 | for i in range(n_points):
24 | for j in range(i + 1, n_points):
25 | # Euclidean distance between points i and j
26 | dist = np.sqrt(np.sum((points[i] - points[j]) ** 2))
27 | dist_matrix[i, j] = dist
28 | dist_matrix[j, i] = dist
29 |
30 |     # Distances to the k nearest points (the zero self-distance is included)
31 | k = min(k_nearest, n_points - 1)
32 | neighbor_distances = np.partition(dist_matrix, k, axis=1)[:, :k]
33 | avg_neighbor_dist = np.mean(neighbor_distances, axis=1)
34 |
35 | # Calculate mask using median distance
36 | median_dist = np.median(avg_neighbor_dist)
37 | mask = avg_neighbor_dist <= threshold * median_dist
38 |
39 |     # Return only the filtered tuples
40 | filtered_tuples = [t for t, m in zip(points_tuples, mask) if m]
41 | return filtered_tuples
42 |
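# Illustrative usage (default k_nearest=2, threshold=2.0): the distant point below
# is dropped while the tight cluster is kept. Note that the k smallest distances
# include the zero self-distance, which uniformly halves the per-point averages.
#
#   remove_outlier_points([(0, 0), (1, 0), (0, 1), (10, 10)])
#   # -> [(0, 0), (1, 0), (0, 1)]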
--------------------------------------------------------------------------------
/moondream/torch/vision.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | from typing import Union, Tuple
7 | from PIL import Image
8 |
9 | from .layers import attn, layer_norm, mlp
10 | from .image_crops import overlap_crop_image
11 | from .config import VisionConfig
12 |
13 | if torch.backends.mps.is_available():
14 | # Non-divisible input sizes are not implemented on MPS device yet.
15 | # https://github.com/pytorch/pytorch/issues/96056
16 | def adaptive_avg_pool2d(input, output_size):
17 | return F.adaptive_avg_pool2d(input.to("cpu"), output_size).to("mps")
18 |
19 | else:
20 | adaptive_avg_pool2d = F.adaptive_avg_pool2d
21 |
22 | DeviceLike = Union[str, torch.device, int]
23 |
24 |
25 | def prepare_crops(
26 | image: Image.Image, config: VisionConfig, device: DeviceLike
27 | ) -> Tuple[torch.Tensor, Tuple[int, int]]:
28 | np_image = np.array(image.convert("RGB"))
29 | overlap_crops = overlap_crop_image(
30 | np_image, max_crops=config.max_crops, overlap_margin=config.overlap_margin
31 | )
32 | all_crops = overlap_crops["crops"]
33 | all_crops = np.transpose(all_crops, (0, 3, 1, 2))
34 | all_crops = (
35 | torch.from_numpy(all_crops)
36 | .to(device=device, dtype=torch.bfloat16)
37 | .div_(255.0)
38 | .sub_(0.5)
39 | .div_(0.5)
40 | )
41 | return all_crops, overlap_crops["tiling"]
42 |
43 |
44 | def create_patches(x, patch_size):
45 | # Original shape: [B, C, H, W]
46 | B, C, H, W = x.shape
47 | P1 = P2 = patch_size
48 |
49 | # Step 1: Split H and W dimensions into patches
50 | # [B, C, H/P1, P1, W/P2, P2]
51 | x = x.reshape(B, C, H // P1, P1, W // P2, P2)
52 |
53 | # Step 2: Rearrange dimensions to match target shape
54 | # [B, H/P1, W/P2, C, P1, P2]
55 | x = x.permute(0, 2, 4, 1, 3, 5)
56 |
57 | # Step 3: Combine dimensions to get final shape
58 | # [B, (H/P1)*(W/P2), C*P1*P2]
59 | x = x.reshape(B, (H // P1) * (W // P2), C * P1 * P2)
60 |
61 | return x
62 |
63 |
64 | def vision_encoder(input_BCHW: torch.Tensor, w: nn.Module, config: VisionConfig):
65 | x = create_patches(input_BCHW, config.enc_patch_size)
66 |
67 | x = w.patch_emb(x)
68 | x = x + w.pos_emb
69 | for block in w.blocks:
70 | x = x + attn(layer_norm(x, block.ln1), block.attn, n_heads=config.enc_n_heads)
71 | x = x + mlp(layer_norm(x, block.ln2), block.mlp)
72 | x = layer_norm(x, w.post_ln)
73 |
74 | return x
75 |
76 |
77 | def vision_projection(
78 | global_features: torch.Tensor,
79 | reconstructed: torch.Tensor,
80 | w: nn.Module,
81 | config: VisionConfig,
82 | ):
83 | reconstructed = reconstructed.permute(2, 0, 1)
84 | reconstructed = adaptive_avg_pool2d(
85 | reconstructed, output_size=(config.enc_n_layers, config.enc_n_layers)
86 | )
87 | reconstructed = reconstructed.permute(1, 2, 0).view(729, config.enc_dim)
88 | final_features = torch.cat([global_features, reconstructed], dim=-1)
89 | return mlp(final_features, w.proj_mlp)
90 |
91 |
92 | def build_vision_model(config: VisionConfig, dtype: torch.dtype):
93 | patch_dim = config.enc_patch_size * config.enc_patch_size * config.in_channels
94 | grid_size = config.crop_size // config.enc_patch_size
95 | num_patches = grid_size * grid_size
96 |
97 | vision = nn.ModuleDict(
98 | {
99 | "patch_emb": nn.Linear(patch_dim, config.enc_dim, dtype=dtype),
100 | "blocks": nn.ModuleList(
101 | [
102 | nn.ModuleDict(
103 | {
104 | "ln1": nn.LayerNorm(config.enc_dim, dtype=dtype),
105 | "attn": nn.ModuleDict(
106 | {
107 | "qkv": nn.Linear(
108 | config.enc_dim, 3 * config.enc_dim, dtype=dtype
109 | ),
110 | "proj": nn.Linear(
111 | config.enc_dim, config.enc_dim, dtype=dtype
112 | ),
113 | }
114 | ),
115 | "ln2": nn.LayerNorm(config.enc_dim, dtype=dtype),
116 | "mlp": nn.ModuleDict(
117 | {
118 | "fc1": nn.Linear(
119 | config.enc_dim, config.enc_ff_dim, dtype=dtype
120 | ),
121 | "fc2": nn.Linear(
122 | config.enc_ff_dim, config.enc_dim, dtype=dtype
123 | ),
124 | }
125 | ),
126 | }
127 | )
128 | for _ in range(config.enc_n_layers)
129 | ]
130 | ),
131 | "post_ln": nn.LayerNorm(config.enc_dim, dtype=dtype),
132 | "proj_mlp": nn.ModuleDict(
133 | {
134 | "fc1": nn.Linear(
135 | config.enc_dim * 2, config.proj_inner_dim, dtype=dtype
136 | ),
137 | "fc2": nn.Linear(
138 | config.proj_inner_dim, config.proj_out_dim, dtype=dtype
139 | ),
140 | }
141 | ),
142 | }
143 | )
144 | vision.pos_emb = nn.Parameter(
145 | torch.zeros(1, num_patches, config.enc_dim, dtype=dtype)
146 | )
147 | return vision
148 |
--------------------------------------------------------------------------------
/recipes/gaze-detection-video/.gitignore:
--------------------------------------------------------------------------------
1 | # Python
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | *.so
6 | .Python
7 | env/
8 | build/
9 | develop-eggs/
10 | dist/
11 | downloads/
12 | eggs/
13 | .eggs/
14 | lib/
15 | lib64/
16 | parts/
17 | sdist/
18 | var/
19 | wheels/
20 | *.egg-info/
21 | .installed.cfg
22 | *.egg
23 |
24 | # Virtual Environment
25 | venv/
26 | ENV/
27 |
28 | # IDE
29 | .idea/
30 | .vscode/
31 | *.swp
32 | *.swo
33 |
34 | # Project specific
35 | # input/*
36 | # !input/.gitkeep
37 | # output/*
38 | # !output/.gitkeep
39 | # temp/*
40 | # !temp/.gitkeep
41 |
42 | # Model files
43 | *.pt
44 | *.pth
45 | *.ckpt
46 |
47 | # Logs
48 | *.log
49 |
50 | # OS specific
51 | .DS_Store
52 | Thumbs.db
53 |
--------------------------------------------------------------------------------
/recipes/gaze-detection-video/README.md:
--------------------------------------------------------------------------------
1 | # Gaze Detection Video Processor
2 |
3 | > **⚠️ IMPORTANT:** This project currently uses Moondream 2B (2025-01-09 release) via the Hugging Face Transformers library. We will migrate to the official Moondream client
4 | > libraries once they become available for this version.
5 |
6 | ## Table of Contents
7 |
8 | - [Overview](#overview)
9 | - [Sample Output](#sample-output)
10 | - [Features](#features)
11 | - [Prerequisites](#prerequisites)
12 | - [Installation](#installation)
13 | - [Linux/macOS Installation](#linuxmacos-installation)
14 | - [Windows Installation](#windows-installation)
15 | - [Usage](#usage)
16 | - [Output](#output)
17 | - [Troubleshooting](#troubleshooting)
18 | - [Performance Notes](#performance-notes)
19 | - [Dependencies](#dependencies)
20 | - [Model Details](#model-details)
21 | - [License](#license)
22 |
23 | ## Overview
24 |
25 | This project uses the Moondream 2B model to detect faces and their gaze directions in videos. It processes videos frame by frame, visualizing face detections and gaze directions.
26 |
27 | ## Sample Output
28 |
29 | | Input Video | Processed Output |
30 | | :-----------------------------------: | :-----------------------------------------: |
31 | |  |  |
32 |
33 | ## Features
34 |
35 | - Face detection in video frames
36 | - Gaze direction tracking
37 | - Real-time visualization with:
38 | - Colored bounding boxes for faces
39 | - Gradient lines showing gaze direction
40 | - Gaze target points
41 | - Supports multiple faces per frame
42 | - Processes all common video formats (.mp4, .avi, .mov, .mkv)
43 | - Uses Moondream 2 (2025-01-09 release) via Hugging Face Transformers
44 | - Note: Will be migrated to official client libraries in future updates
45 | - No authentication required
46 |
47 | ## Prerequisites
48 |
49 | 1. Python 3.8 or later
50 | 2. CUDA-capable GPU recommended (but CPU mode works too)
51 | 3. FFmpeg installed on your system
52 |
53 | ## Installation
54 |
55 | ### Linux/macOS Installation
56 |
57 | 1. Install system dependencies:
58 |
59 | ```bash
60 | # Ubuntu/Debian
61 | sudo apt-get update && sudo apt-get install -y libvips42 libvips-dev ffmpeg
62 |
63 | # CentOS/RHEL
64 | sudo yum install vips vips-devel ffmpeg
65 |
66 | # macOS
67 | brew install vips ffmpeg
68 | ```
69 |
70 | 2. Clone and setup the project:
71 | ```bash
72 | git clone https://github.com/vikhyat/moondream.git
73 | cd moondream/recipes/gaze-detection-video
74 | python3 -m venv venv
75 | source venv/bin/activate
76 | pip install -r requirements.txt
77 | ```
78 |
79 | ### Windows Installation
80 |
81 | Windows setup requires a few additional steps for proper GPU support and libvips installation.
82 |
83 | 1. Clone the repository:
84 |
85 | ```bash
86 | git clone https://github.com/vikhyat/moondream.git
87 | cd moondream/recipes/gaze-detection-video
88 | ```
89 |
90 | 2. Create and activate virtual environment:
91 |
92 | ```bash
93 | python -m venv venv
94 | .\venv\Scripts\activate
95 | ```
96 |
97 | 3. Install PyTorch with CUDA support:
98 |
99 | ```bash
100 | # For NVIDIA GPUs
101 | pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
102 | ```
103 |
104 | 4. Install libvips: Download the appropriate version based on your system architecture:
105 |
106 | | Architecture | VIPS Version to Download |
107 | | ------------ | ------------------------ |
108 | | 32-bit x86 | vips-dev-w32-all-8.16.0.zip |
109 | | 64-bit x64 | vips-dev-w64-all-8.16.0.zip |
110 |
111 | - Extract the ZIP file
112 | - Copy all DLL files from `vips-dev-8.16\bin` to either:
113 | - Your project's root directory (easier) OR
114 | - `C:\Windows\System32` (requires admin privileges)
115 | - Add to PATH:
116 | 1. Open System Properties → Advanced → Environment Variables
117 | 2. Under System Variables, find PATH
118 | 3. Add the full path to the `vips-dev-8.16\bin` directory
119 |
120 | 5. Install FFmpeg:
121 |
122 | - Download from https://ffmpeg.org/download.html#build-windows
123 | - Extract and add the `bin` folder to your system PATH (similar to step 4) or to the project root directory
124 |
125 | 6. Install other dependencies:
126 | ```bash
127 | pip install -r requirements.txt
128 | ```
129 |
130 | ## Usage
131 |
132 | 1. Place your input videos in the `input` directory
133 |
134 | - Supported formats: .mp4, .avi, .mov, .mkv
135 | - The directory will be created automatically if it doesn't exist
136 |
137 | 2. Run the script:
138 |
139 | ```bash
140 | python gaze-detection-video.py
141 | ```
142 |
143 | 3. The script will:
144 | - Process all videos in the input directory
145 | - Show progress bars for each video
146 | - Save processed videos to the `output` directory with prefix 'processed\_'
147 |
148 | ## Output
149 |
150 | - Processed videos are saved as `output/processed_[original_name].[ext]`
151 | - Each frame in the output video shows (see the drawing sketch after this list):
152 | - Colored boxes around detected faces
153 | - Lines indicating gaze direction
154 | - Points showing where each person is looking
155 |
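For reference, here is a minimal sketch of how overlays like these can be drawn with OpenCV. It is an illustration only, not the recipe's actual rendering code (which draws gradient lines); the box and point below are made-up normalized coordinates:

```python
import cv2
import numpy as np

# Hypothetical detections in normalized [0, 1] coordinates.
frame = np.zeros((480, 640, 3), dtype=np.uint8)
face_box = (0.40, 0.30, 0.60, 0.55)  # x1, y1, x2, y2
gaze_point = (0.80, 0.45)            # where this person is looking

h, w = frame.shape[:2]
x1, y1 = int(face_box[0] * w), int(face_box[1] * h)
x2, y2 = int(face_box[2] * w), int(face_box[3] * h)
gx, gy = int(gaze_point[0] * w), int(gaze_point[1] * h)

cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)   # colored face box
face_center = ((x1 + x2) // 2, (y1 + y2) // 2)
cv2.line(frame, face_center, (gx, gy), (0, 200, 255), 2)   # gaze direction line
cv2.circle(frame, (gx, gy), 6, (0, 0, 255), -1)            # gaze target point

cv2.imwrite("overlay_example.png", frame)
```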
156 | ## Troubleshooting
157 |
158 | 1. CUDA/GPU Issues:
159 |
160 | - Ensure you have CUDA installed for GPU support
161 | - The script will automatically fall back to CPU if no GPU is available
162 |
163 | 2. Memory Issues:
164 |
165 | - If processing large videos, ensure you have enough RAM
166 | - Consider reducing video resolution if needed
167 |
168 | 3. libvips Errors:
169 |
170 | - Make sure libvips is properly installed for your OS
171 | - Check system PATH includes libvips
172 |
173 | 4. Video Format Issues:
174 | - Ensure FFmpeg is installed and in your system PATH
175 | - Try converting problematic videos to MP4 format
176 |
177 | ## Performance Notes
178 |
179 | - GPU processing is significantly faster than CPU
180 | - Processing time depends on:
181 | - Video resolution
182 | - Number of faces per frame
183 | - Frame rate
184 | - Available computing power
185 |
186 | ## Dependencies
187 |
188 | - transformers (for Moondream 2 model access)
189 | - torch
190 | - opencv-python
191 | - pillow
192 | - matplotlib
193 | - numpy
194 | - tqdm
195 | - pyvips
196 | - accelerate
197 | - einops
198 |
199 | ## Model Details
200 |
201 | > **⚠️ IMPORTANT:** This project currently uses Moondream 2 (2025-01-09 release) via the Hugging Face Transformers library. We will migrate to the official Moondream client
202 | > libraries once they become available for this version.
203 |
204 | The model is loaded using:
205 |
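A minimal sketch of that loading step is shown below, mirroring the Transformers pattern used elsewhere in this repository; the pinned revision string is an assumption based on the release date noted above:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "vikhyatk/moondream2"
REVISION = "2025-01-09"  # assumed: the Moondream release this recipe targets

device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, revision=REVISION)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    revision=REVISION,
    trust_remote_code=True,  # Moondream ships custom modeling code on the Hub
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)
model.eval()
```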
--------------------------------------------------------------------------------
/recipes/gaze-detection-video/input/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vikhyat/moondream/ce59395b386eb956e19bc1cba5f97deb9befcd05/recipes/gaze-detection-video/input/.gitkeep
--------------------------------------------------------------------------------
/recipes/gaze-detection-video/output/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vikhyat/moondream/ce59395b386eb956e19bc1cba5f97deb9befcd05/recipes/gaze-detection-video/output/.gitkeep
--------------------------------------------------------------------------------
/recipes/gaze-detection-video/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=2.0.0
2 | transformers>=4.36.0
3 | opencv-python>=4.8.0
4 | pillow>=10.0.0
5 | matplotlib>=3.7.0
6 | numpy>=1.24.0
7 | tqdm>=4.65.0
8 | pyvips
9 | accelerate>=0.26.0
10 | einops
--------------------------------------------------------------------------------
/recipes/gaze-detection-video/temp/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vikhyat/moondream/ce59395b386eb956e19bc1cba5f97deb9befcd05/recipes/gaze-detection-video/temp/.gitkeep
--------------------------------------------------------------------------------
/recipes/promptable-content-moderation/.gitignore:
--------------------------------------------------------------------------------
1 | # Python
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | *.so
6 | .Python
7 | build/
8 | develop-eggs/
9 | dist/
10 | downloads/
11 | eggs/
12 | .eggs/
13 | lib/
14 | lib64/
15 | parts/
16 | sdist/
17 | var/
18 | wheels/
19 | *.egg-info/
20 | .installed.cfg
21 | *.egg
22 | *.dll
23 |
24 | # Virtual Environment
25 | venv/
26 | env/
27 | ENV/
28 | .venv/
29 |
30 | # IDE
31 | .idea/
32 | .vscode/
33 | *.swp
34 | *.swo
35 |
36 | # Project specific
37 | inputs/*
38 | outputs/*
39 | !inputs/.gitkeep
40 | !outputs/.gitkeep
41 | inputs/
42 | outputs/
43 |
44 | # Model files
45 | *.pth
46 | *.onnx
47 | *.pt
48 |
49 | # Logs
50 | *.log
51 |
52 | certificate.pem
--------------------------------------------------------------------------------
/recipes/promptable-content-moderation/README.md:
--------------------------------------------------------------------------------
1 | # Promptable Content Moderation with Moondream
2 |
3 | Welcome to the future of content moderation with Moondream 2B, a powerful and lightweight vision-language model that enables detection and moderation of video content using natural language prompts.
4 |
5 | [Try it now.](https://huggingface.co/spaces/moondream/content-moderation)
6 |
7 | ## Features
8 |
9 | - Content moderation through natural language prompts
10 | - Multiple visualization styles
11 | - Intelligent scene detection and tracking:
12 | - DeepSORT tracking with scene-aware reset
13 | - Persistent moderation across frames
14 | - Smart tracker reset at scene boundaries
15 | - Optional grid-based detection for improved accuracy on complex scenes
16 | - Frame-by-frame processing with IoU-based merging (see the sketch after this list)
17 | - Web-compatible output format
18 | - Test mode (process only first X seconds)
19 | - Advanced moderation analysis with multiple visualization plots
20 |
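The IoU-based merging mentioned above can be sketched as follows; this illustrates the technique in isolation and is not the exact logic in `main.py`. Boxes are assumed to be normalized `[x1, y1, x2, y2]` lists:

```python
def iou(a, b):
    """Intersection-over-union of two [x1, y1, x2, y2] boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def merge_detections(boxes, iou_threshold=0.5):
    """Greedily merge overlapping boxes, e.g. duplicates from adjacent grid cells."""
    merged = []
    for box in boxes:
        for i, kept in enumerate(merged):
            if iou(box, kept) >= iou_threshold:
                # Replace the kept box with the union of the two overlapping boxes.
                merged[i] = [min(box[0], kept[0]), min(box[1], kept[1]),
                             max(box[2], kept[2]), max(box[3], kept[3])]
                break
        else:
            merged.append(list(box))
    return merged
```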
21 | ## Examples
22 |
23 | | Prompt | Output |
24 | |--------|-----------------|
25 | | "white cigarette" |  |
26 | | "gun" |  |
27 | | "confederate flag" |  |
28 |
29 | ## Requirements
30 |
31 | ### Python Dependencies
32 |
33 | For Windows users, before installing other requirements, first install PyTorch with CUDA support:
34 |
35 | ```bash
36 | pip install torch==2.5.1+cu121 torchvision==0.20.1+cu121 --index-url https://download.pytorch.org/whl/cu121
37 | ```
38 |
39 | Then install the remaining dependencies:
40 |
41 | ```bash
42 | pip install -r requirements.txt
43 | ```
44 |
45 | ### System Requirements
46 |
47 | - FFmpeg (required for video processing)
48 | - libvips (required for image processing)
49 |
50 | Installation by platform:
51 |
52 | - Ubuntu/Debian: `sudo apt-get install ffmpeg libvips`
53 | - macOS: `brew install ffmpeg libvips`
54 | - Windows:
55 | - Download FFmpeg from [ffmpeg.org](https://ffmpeg.org/download.html)
56 | - Follow [libvips Windows installation guide](https://docs.moondream.ai/quick-start)
57 |
58 | ## Installation
59 |
60 | 1. Clone this repository and create a new virtual environment:
61 |
62 | ```bash
63 | git clone https://github.com/vikhyat/moondream.git && cd moondream/recipes/promptable-content-moderation
64 | python -m venv .venv
65 | source .venv/bin/activate # On Windows: .venv\Scripts\activate
66 | ```
67 |
68 | 2. Install Python dependencies:
69 |
70 | ```bash
71 | pip install -r requirements.txt
72 | ```
73 |
74 | 3. Install ffmpeg and libvips:
75 | - On Ubuntu/Debian: `sudo apt-get install ffmpeg libvips`
76 |    - On macOS: `brew install ffmpeg libvips`
77 | - On Windows: Download from [ffmpeg.org](https://ffmpeg.org/download.html)
78 |
79 | > Downloading libvips for Windows requires some additional steps, see [here](https://docs.moondream.ai/quick-start)
80 |
81 | ## Usage
82 |
83 | The easiest way to use this tool is through its web interface, which provides a user-friendly experience for video content moderation.
84 |
85 | ### Web Interface
86 |
87 | 1. Start the web interface:
88 |
89 | ```bash
90 | python app.py
91 | ```
92 |
93 | 2. Open the provided URL in your browser (typically `http://127.0.0.1:7860`)
94 |
95 | 3. Use the interface to:
96 | - Upload your video file
97 | - Specify content to moderate (e.g., "face", "cigarette", "gun")
98 | - Choose redaction style (default: obfuscated-pixel)
99 | - OPTIONAL: Configure advanced settings
100 | - Processing speed/quality
101 | - Grid size for detection
102 | - Test mode for quick validation (default: on, 3 seconds)
103 | - Process the video and download results
104 | - Analyze detection patterns with visualization tools
105 |
106 | ## Output Files
107 |
108 | The tool generates two types of output files in the `outputs` directory:
109 |
110 | 1. Processed Videos:
111 | - Format: `[style]_[content_type]_[original_filename].mp4`
112 | - Example: `censor_inappropriate_video.mp4`
113 |
114 | 2. Detection Data:
115 | - Format: `[style]_[content_type]_[original_filename]_detections.json`
116 | - Contains frame-by-frame detection information
117 | - Used for visualization and analysis
118 |
119 | ## Technical Details
120 |
121 | ### Scene Detection and Tracking
122 |
123 | The tool uses advanced scene detection and object tracking (a minimal code sketch follows this list):
124 |
125 | 1. Scene Detection:
126 | - Powered by PySceneDetect's ContentDetector
127 | - Automatically identifies scene changes in videos
128 | - Configurable detection threshold (default: 30.0)
129 | - Helps maintain tracking accuracy across scene boundaries
130 |
131 | 2. Object Tracking:
132 | - DeepSORT tracking for consistent object identification
133 | - Automatic tracker reset at scene changes
134 | - Maintains object identity within scenes
135 | - Prevents tracking errors across scene boundaries
136 |
137 | 3. Integration Benefits:
138 | - More accurate object tracking
139 | - Better handling of scene transitions
140 | - Reduced false positives in tracking
141 | - Improved tracking consistency
142 |
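Below is a minimal sketch of the scene-aware reset pattern, combining PySceneDetect with the `DeepSORTTracker` wrapper from `deep_sort_integration.py`. The video path, prompt, and `run_moondream_detection` stub are placeholders; the actual wiring in `main.py` may differ:

```python
import cv2
from scenedetect import ContentDetector, detect

from deep_sort_integration import DeepSORTTracker

VIDEO_PATH = "inputs/example.mp4"  # placeholder path
PROMPT = "cigarette"               # content to moderate


def run_moondream_detection(frame, prompt):
    """Hypothetical stand-in for the Moondream detection call.
    Should return [([x1, y1, x2, y2], keyword), ...] with normalized coordinates."""
    return []


# 1. Find scene boundaries up front (frame numbers where each new scene starts).
scenes = detect(VIDEO_PATH, ContentDetector(threshold=30.0))
scene_starts = {start.get_frames() for start, _ in scenes}

tracker = DeepSORTTracker(max_age=5)
cap = cv2.VideoCapture(VIDEO_PATH)
frame_idx = 0

while True:
    ok, frame = cap.read()
    if not ok:
        break

    # 2. Reset the tracker at scene boundaries so identities don't persist across cuts.
    if frame_idx in scene_starts:
        tracker.reset()

    # 3. Detect, then track; each tracked item is (box, keyword, track_id).
    detections = run_moondream_detection(frame, PROMPT)
    tracked = tracker.update(frame, detections)

    frame_idx += 1

cap.release()
```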
143 | ## Best Practices
144 |
145 | - Use test mode for initial configuration
146 | - Enable grid-based detection for complex scenes
147 | - Choose appropriate redaction style based on content type:
148 | - Censor: Complete content blocking
149 | - Blur styles: Less intrusive moderation
150 | - Bounding Box: Content review and analysis
151 | - Monitor system resources during processing
152 | - Use appropriate processing quality settings based on your needs
153 |
154 | ## Notes
155 |
156 | - Processing time depends on video length, resolution, GPU availability, and chosen settings
157 | - GPU is strongly recommended for faster processing
158 | - Grid-based detection increases accuracy but requires more processing time (each grid cell is processed independently)
159 | - Test mode processes only the first X seconds (default: 3 seconds) for quick validation
160 |
--------------------------------------------------------------------------------
/recipes/promptable-content-moderation/deep_sort_integration.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from deep_sort_realtime.deepsort_tracker import DeepSort
4 | from datetime import datetime
5 |
6 |
7 | class DeepSORTTracker:
8 | def __init__(self, max_age=5):
9 | """Initialize DeepSORT tracker."""
10 | self.max_age = max_age
11 | self.tracker = self._create_tracker()
12 |
13 | def _create_tracker(self):
14 | """Create a new instance of DeepSort tracker."""
15 | return DeepSort(
16 | max_age=self.max_age,
17 | embedder="mobilenet", # Using default MobileNetV2 embedder
18 | today=datetime.now().date(), # For track naming and daily ID reset
19 | )
20 |
21 | def reset(self):
22 | """Reset the tracker state by creating a new instance."""
23 | print("Resetting DeepSORT tracker...")
24 | self.tracker = self._create_tracker()
25 |
26 | def update(self, frame, detections):
27 | """Update tracking with new detections.
28 |
29 | Args:
30 | frame: Current video frame (numpy array)
31 | detections: List of (box, keyword) tuples where box is [x1, y1, x2, y2] normalized
32 |
33 | Returns:
34 | List of (box, keyword, track_id) tuples
35 | """
36 | if not detections:
37 | return []
38 |
39 | height, width = frame.shape[:2]
40 |
41 | # Convert normalized coordinates to absolute and format detections
42 | detection_list = []
43 | for box, keyword in detections:
44 | x1 = int(box[0] * width)
45 | y1 = int(box[1] * height)
46 | x2 = int(box[2] * width)
47 | y2 = int(box[3] * height)
48 | w = x2 - x1
49 | h = y2 - y1
50 |
51 | # Format: ([left,top,w,h], confidence, detection_class)
52 | detection_list.append(([x1, y1, w, h], 1.0, keyword))
53 |
54 | # Update tracker
55 | tracks = self.tracker.update_tracks(detection_list, frame=frame)
56 |
57 | # Convert back to normalized coordinates with track IDs
58 | tracked_objects = []
59 | for track in tracks:
60 | if not track.is_confirmed():
61 | continue
62 |
63 | ltrb = track.to_ltrb() # Get [left,top,right,bottom] format
64 | x1, y1, x2, y2 = ltrb
65 |
66 | # Normalize coordinates
67 | x1 = max(0.0, min(1.0, x1 / width))
68 | y1 = max(0.0, min(1.0, y1 / height))
69 | x2 = max(0.0, min(1.0, x2 / width))
70 | y2 = max(0.0, min(1.0, y2 / height))
71 |
72 | tracked_objects.append(([x1, y1, x2, y2], track.det_class, track.track_id))
73 |
74 | return tracked_objects
75 |
--------------------------------------------------------------------------------
/recipes/promptable-content-moderation/packages.txt:
--------------------------------------------------------------------------------
1 | libvips
2 | ffmpeg
--------------------------------------------------------------------------------
/recipes/promptable-content-moderation/persistence.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 |
5 | def save_detection_data(data, output_file):
6 | """
7 | Saves the detection data to a JSON file.
8 |
9 | Args:
10 | data (dict): The complete detection data structure.
11 | output_file (str): Path to the output JSON file.
12 | """
13 | try:
14 |         # Create the output directory if it doesn't exist ("." when output_file has no directory part)
15 |         os.makedirs(os.path.dirname(output_file) or ".", exist_ok=True)
16 |
17 | with open(output_file, "w") as f:
18 | json.dump(data, f, indent=4)
19 | print(f"Detection data saved to {output_file}")
20 | return True
21 | except Exception as e:
22 | print(f"Error saving data: {str(e)}")
23 | return False
24 |
25 |
26 | def load_detection_data(input_file):
27 | """
28 | Loads the detection data from a JSON file.
29 |
30 | Args:
31 | input_file (str): Path to the JSON file.
32 |
33 | Returns:
34 | dict: The loaded detection data, or None if there was an error.
35 | """
36 | try:
37 | with open(input_file, "r") as f:
38 | return json.load(f)
39 | except Exception as e:
40 | print(f"Error loading data: {str(e)}")
41 | return None
42 |
--------------------------------------------------------------------------------
/recipes/promptable-content-moderation/requirements.txt:
--------------------------------------------------------------------------------
1 | gradio>=4.0.0
2 | torch>=2.0.0
3 | # if on windows: pip install torch==2.5.1+cu121 torchvision==0.20.1+cu121 --index-url https://download.pytorch.org/whl/cu121
4 | transformers>=4.36.0
5 | opencv-python>=4.8.0
6 | pillow>=10.0.0
7 | numpy>=1.24.0
8 | tqdm>=4.66.0
9 | ffmpeg-python
10 | einops
11 | pyvips-binary
12 | pyvips
13 | accelerate
14 | # for spaces
15 | --extra-index-url https://download.pytorch.org/whl/cu113
16 | spaces
17 | # SAM dependencies
18 | torchvision>=0.20.1
19 | matplotlib>=3.7.0
20 | pandas>=2.0.0
21 | plotly
22 | # DeepSORT dependencies
23 | deep-sort-realtime>=1.3.2
24 | scikit-learn # Required for deep-sort-realtime
25 | # Scene detection dependencies (for intelligent scene-aware tracking)
26 | scenedetect[opencv]>=0.6.2 # Provides scene change detection capabilities
--------------------------------------------------------------------------------
/recipes/promptable-content-moderation/visualization.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import matplotlib.pyplot as plt
3 | from persistence import load_detection_data
4 | import argparse
5 |
6 |
7 | def visualize_detections(json_path):
8 | """
9 | Visualize detection data from a JSON file.
10 |
11 | Args:
12 | json_path (str): Path to the JSON file containing detection data.
13 | """
14 | # Load the persisted JSON data
15 | data = load_detection_data(json_path)
16 | if not data:
17 | return
18 |
19 | # Convert the frame detections to a DataFrame
20 | rows = []
21 | for frame_data in data["frame_detections"]:
22 | frame = frame_data["frame"]
23 | timestamp = frame_data["timestamp"]
24 | for obj in frame_data["objects"]:
25 | rows.append(
26 | {
27 | "frame": frame,
28 | "timestamp": timestamp,
29 | "keyword": obj["keyword"],
30 | "x1": obj["bbox"][0],
31 | "y1": obj["bbox"][1],
32 | "x2": obj["bbox"][2],
33 | "y2": obj["bbox"][3],
34 | "area": (obj["bbox"][2] - obj["bbox"][0])
35 | * (obj["bbox"][3] - obj["bbox"][1]),
36 | }
37 | )
38 |
39 | if not rows:
40 | print("No detections found in the data")
41 | return
42 |
43 | df = pd.DataFrame(rows)
44 |
45 | # Create a figure with multiple subplots
46 | fig = plt.figure(figsize=(15, 10))
47 |
48 | # Plot 1: Number of detections per frame
49 | plt.subplot(2, 2, 1)
50 | detections_per_frame = df.groupby("frame").size()
51 | plt.plot(detections_per_frame.index, detections_per_frame.values)
52 | plt.xlabel("Frame")
53 | plt.ylabel("Number of Detections")
54 | plt.title("Detections Per Frame")
55 |
56 | # Plot 2: Distribution of detection areas
57 | plt.subplot(2, 2, 2)
58 | df["area"].hist(bins=30)
59 | plt.xlabel("Detection Area (normalized)")
60 | plt.ylabel("Count")
61 | plt.title("Distribution of Detection Areas")
62 |
63 | # Plot 3: Average detection area over time
64 | plt.subplot(2, 2, 3)
65 | avg_area = df.groupby("frame")["area"].mean()
66 | plt.plot(avg_area.index, avg_area.values)
67 | plt.xlabel("Frame")
68 | plt.ylabel("Average Detection Area")
69 | plt.title("Average Detection Area Over Time")
70 |
71 | # Plot 4: Heatmap of detection centers
72 | plt.subplot(2, 2, 4)
73 | df["center_x"] = (df["x1"] + df["x2"]) / 2
74 | df["center_y"] = (df["y1"] + df["y2"]) / 2
75 | plt.hist2d(df["center_x"], df["center_y"], bins=30)
76 | plt.colorbar()
77 | plt.xlabel("X Position")
78 | plt.ylabel("Y Position")
79 | plt.title("Detection Center Heatmap")
80 |
81 | # Adjust layout and display
82 | plt.tight_layout()
83 | plt.show()
84 |
85 | # Print summary statistics
86 | print("\nSummary Statistics:")
87 | print(f"Total frames analyzed: {len(data['frame_detections'])}")
88 | print(f"Total detections: {len(df)}")
89 | print(
90 | f"Average detections per frame: {len(df) / len(data['frame_detections']):.2f}"
91 | )
92 | print(f"\nVideo metadata:")
93 | for key, value in data["video_metadata"].items():
94 | print(f"{key}: {value}")
95 |
96 |
97 | def main():
98 | parser = argparse.ArgumentParser(description="Visualize object detection data")
99 | parser.add_argument(
100 | "json_file", help="Path to the JSON file containing detection data"
101 | )
102 | args = parser.parse_args()
103 |
104 | visualize_detections(args.json_file)
105 |
106 |
107 | if __name__ == "__main__":
108 | main()
109 |
--------------------------------------------------------------------------------
/recipes/promptable-video-redaction/.gitignore:
--------------------------------------------------------------------------------
1 | # Python
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | *.so
6 | .Python
7 | build/
8 | develop-eggs/
9 | dist/
10 | downloads/
11 | eggs/
12 | .eggs/
13 | lib/
14 | lib64/
15 | parts/
16 | sdist/
17 | var/
18 | wheels/
19 | *.egg-info/
20 | .installed.cfg
21 | *.egg
22 |
23 | # Virtual Environment
24 | venv/
25 | env/
26 | ENV/
27 | .venv/
28 |
29 | # IDE
30 | .idea/
31 | .vscode/
32 | *.swp
33 | *.swo
34 |
35 | # Project specific
36 | inputs/*
37 | outputs/*
38 | !inputs/.gitkeep
39 | !outputs/.gitkeep
40 | inputs/
41 | outputs/
42 |
43 | # Model files
44 | *.pth
45 | *.onnx
46 | *.pt
47 |
48 | # Logs
49 | *.log
50 |
51 | certificate.pem
--------------------------------------------------------------------------------
/recipes/promptable-video-redaction/README.md:
--------------------------------------------------------------------------------
1 | # Promptable Video Redaction with Moondream
2 |
3 | This tool uses Moondream 2B, a powerful yet lightweight vision-language model, to detect and redact objects from videos. Moondream can recognize a wide variety of objects, people,
4 | text, and more with high accuracy while being much smaller than traditional models.
5 |
6 | [Try it now.](https://huggingface.co/spaces/moondream/promptable-video-redaction)
7 |
8 | ## About Moondream
9 |
10 | Moondream is a tiny yet powerful vision-language model that can analyze images and answer questions about them. It's designed to be lightweight and efficient while maintaining high
11 | accuracy. Some key features:
12 |
13 | - Only 2B parameters
14 | - Fast inference with minimal resource requirements
15 | - Supports CPU and GPU execution
16 | - Open source and free to use
17 | - Can detect almost anything you can describe in natural language
18 |
19 | Links:
20 |
21 | - [GitHub Repository](https://github.com/vikhyat/moondream)
22 | - [Hugging Face](https://huggingface.co/vikhyatk/moondream2)
23 | - [Build with Moondream](http://docs.moondream.ai/)
24 |
25 | ## Features
26 |
27 | - Real-time object detection in videos using Moondream
28 | - Multiple visualization styles:
29 | - Censor: Black boxes over detected objects
30 | - Bounding Box: Traditional bounding boxes with labels
31 | - Hitmarker: Call of Duty style crosshair markers
32 | - Optional grid-based detection for improved accuracy
33 | - Flexible object type detection using natural language
34 | - Frame-by-frame processing with IoU-based merging
35 | - Batch processing of multiple videos
36 | - Web-compatible output format
37 | - User-friendly web interface
38 | - Command-line interface for automation
39 |
40 | ## Requirements
41 |
42 | - Python 3.8+
43 | - OpenCV (cv2)
44 | - PyTorch
45 | - Transformers
46 | - Pillow (PIL)
47 | - tqdm
48 | - ffmpeg
49 | - numpy
50 | - gradio (for web interface)
51 |
52 | ## Installation
53 |
54 | 1. Clone this repository and create a new virtual environment
55 |
56 | ```bash
57 | git clone https://github.com/vikhyat/moondream.git && cd moondream/recipes/promptable-video-redaction
58 | python -m venv .venv
59 | source .venv/bin/activate
60 | ```
61 |
62 | 2. Install the required packages:
63 |
64 | ```bash
65 | pip install -r requirements.txt
66 | ```
67 |
68 | 3. Install ffmpeg and libvips:
69 |    - On Ubuntu/Debian: `sudo apt-get install ffmpeg libvips`
70 |    - On macOS: `brew install ffmpeg libvips`
71 | - On Windows: Download from [ffmpeg.org](https://ffmpeg.org/download.html)
72 | > Downloading libvips for Windows requires some additional steps, see [here](https://docs.moondream.ai/quick-start)
73 |
74 | ## Usage
75 |
76 | ### Web Interface
77 |
78 | 1. Start the web interface:
79 |
80 | ```bash
81 | python app.py
82 | ```
83 |
84 | 2. Open the provided URL in your browser
85 |
86 | 3. Use the interface to:
87 | - Upload your video
88 | - Specify what to censor (e.g., face, logo, text)
89 | - Adjust processing speed and quality
90 | - Configure grid size for detection
91 | - Process and download the censored video
92 |
93 | ### Command Line Interface
94 |
95 | 1. Create an `inputs` directory in the same folder as the script:
96 |
97 | ```bash
98 | mkdir inputs
99 | ```
100 |
101 | 2. Place your video files in the `inputs` directory. Supported formats:
102 |
103 | - .mp4
104 | - .avi
105 | - .mov
106 | - .mkv
107 | - .webm
108 |
109 | 3. Run the script:
110 |
111 | ```bash
112 | python main.py
113 | ```
114 |
115 | ### Optional Arguments
116 |
117 | - `--test`: Process only the first 3 seconds of each video (useful for testing detection settings)
118 |
119 | ```bash
120 | python main.py --test
121 | ```
122 |
123 | - `--preset`: Choose FFmpeg encoding preset (affects output quality vs. speed)
124 |
125 | ```bash
126 | python main.py --preset ultrafast # Fastest, lower quality
127 | python main.py --preset veryslow # Slowest, highest quality
128 | ```
129 |
130 | - `--detect`: Specify what object type to detect (using natural language)
131 |
132 | ```bash
133 | python main.py --detect person # Detect people
134 | python main.py --detect "red car" # Detect red cars
135 | python main.py --detect "person wearing a hat" # Detect people with hats
136 | ```
137 |
138 | - `--box-style`: Choose visualization style
139 |
140 | ```bash
141 | python main.py --box-style censor # Black boxes (default)
142 | python main.py --box-style bounding-box # Red bounding boxes with labels
143 | python main.py --box-style hitmarker # COD-style hitmarkers
144 | ```
145 |
146 | - `--rows` and `--cols`: Enable grid-based detection by splitting frames
147 |
148 | ```bash
149 | python main.py --rows 2 --cols 2 # Split each frame into 2x2 grid
150 | python main.py --rows 3 --cols 3 # Split each frame into 3x3 grid
151 | ```
152 |
153 | You can combine arguments:
154 |
155 | ```bash
156 | python main.py --detect "person wearing sunglasses" --box-style bounding-box --test --preset "fast" --rows 2 --cols 2
157 | ```
158 |
159 | ### Visualization Styles
160 |
161 | The tool supports three different visualization styles for detected objects:
162 |
163 | 1. **Censor** (default)
164 |
165 | - Places solid black rectangles over detected objects
166 | - Best for privacy and content moderation
167 | - Completely obscures the detected region
168 |
169 | 2. **Bounding Box**
170 |
171 | - Traditional object detection style
172 | - Red bounding box around detected objects
173 | - Label showing object type above the box
174 | - Good for analysis and debugging
175 |
176 | 3. **Hitmarker**
177 | - Call of Duty inspired visualization
178 | - White crosshair marker at center of detected objects
179 | - Small label above the marker
180 | - Stylistic choice for gaming-inspired visualization
181 |
182 | Choose the style that best fits your use case using the `--box-style` argument.
183 |
184 | ## Output
185 |
186 | Processed videos will be saved in the `outputs` directory with the format: `[style]_[object_type]_[original_filename].mp4`
187 |
188 | For example:
189 |
190 | - `censor_face_video.mp4`
191 | - `bounding-box_person_video.mp4`
192 | - `hitmarker_car_video.mp4`
193 |
194 | The output videos will include:
195 |
196 | - Original video content
197 | - Selected visualization style for detected objects
198 | - Web-compatible H.264 encoding
199 |
200 | ## Notes
201 |
202 | - Processing time depends on video length, grid size, and GPU availability
203 | - GPU is strongly recommended for faster processing
204 | - Requires sufficient disk space for temporary files
205 | - Detection quality varies based on video quality and Moondream's ability to recognize the specified object
206 | - Grid-based detection impacts performance significantly - use only when needed
207 | - Web interface shows progress updates and errors
208 | - Choose visualization style based on your use case
209 | - Moondream can detect almost anything you can describe in natural language
210 |
--------------------------------------------------------------------------------
/recipes/promptable-video-redaction/packages.txt:
--------------------------------------------------------------------------------
1 | libvips
2 | ffmpeg
--------------------------------------------------------------------------------
/recipes/promptable-video-redaction/requirements.txt:
--------------------------------------------------------------------------------
1 | gradio>=4.0.0
2 | torch
3 | transformers
4 | opencv-python
5 | pillow
6 | numpy
7 | tqdm
8 | ffmpeg-python
9 | einops
10 | pyvips
11 | accelerate
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.32.1
2 | huggingface-hub==0.24.0
3 | Pillow==10.4.0
4 | pyvips-binary==8.16.0
5 | pyvips==2.2.3
6 | torch==2.5.1
7 | transformers==4.44.0
8 | gradio==4.38.1
9 |
10 | # Needed for running evals
11 | datasets==3.2.0
12 | editdistance==0.8.1
13 |
--------------------------------------------------------------------------------
/sample.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from queue import Queue
3 | from threading import Thread
4 |
5 | import torch
6 | from PIL import Image
7 | from transformers import AutoTokenizer, TextIteratorStreamer
8 |
9 | from moondream.hf import LATEST_REVISION, Moondream, detect_device
10 |
11 | if __name__ == "__main__":
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument("--image", type=str, required=True)
14 | parser.add_argument("--prompt", type=str, required=False)
15 | parser.add_argument("--caption", action="store_true")
16 | parser.add_argument("--cpu", action="store_true")
17 | args = parser.parse_args()
18 |
19 | if args.cpu:
20 | device = torch.device("cpu")
21 | dtype = torch.float32
22 | else:
23 | device, dtype = detect_device()
24 | if device != torch.device("cpu"):
25 | print("Using device:", device)
26 | print("If you run into issues, pass the `--cpu` flag to this script.")
27 | print()
28 |
29 | image_path = args.image
30 | prompt = args.prompt
31 |
32 | model_id = "vikhyatk/moondream2"
33 | tokenizer = AutoTokenizer.from_pretrained(model_id, revision=LATEST_REVISION)
34 | moondream = Moondream.from_pretrained(
35 | model_id,
36 | revision=LATEST_REVISION,
37 | torch_dtype=dtype,
38 | ).to(device=device)
39 | moondream.eval()
40 |
41 | image = Image.open(image_path)
42 |
43 | if args.caption:
44 | print(moondream.caption(images=[image], tokenizer=tokenizer)[0])
45 | else:
46 | image_embeds = moondream.encode_image(image)
47 |
48 | if prompt is None:
49 | chat_history = ""
50 |
51 | while True:
52 | question = input("> ")
53 |
54 | result_queue = Queue()
55 |
56 | streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
57 |
58 | # Separate direct arguments from keyword arguments
59 | thread_args = (image_embeds, question, tokenizer, chat_history)
60 | thread_kwargs = {"streamer": streamer, "result_queue": result_queue}
61 |
62 | thread = Thread(
63 | target=moondream.answer_question,
64 | args=thread_args,
65 | kwargs=thread_kwargs,
66 | )
67 | thread.start()
68 |
69 | buffer = ""
70 | for new_text in streamer:
71 | buffer += new_text
72 | if not new_text.endswith("<") and not new_text.endswith("END"):
73 | print(buffer, end="", flush=True)
74 | buffer = ""
75 | print(buffer)
76 |
77 | thread.join()
78 |
79 | answer = result_queue.get()
80 | chat_history += f"Question: {question}\n\nAnswer: {answer}\n\n"
81 | else:
82 | print(">", prompt)
83 | answer = moondream.answer_question(image_embeds, prompt, tokenizer)
84 | print(answer)
85 |
--------------------------------------------------------------------------------
/tests/test_image_crops.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from moondream.torch.image_crops import overlap_crop_image, reconstruct_from_crops
4 |
5 |
6 | def test_overlap_crop_basic():
7 | # Create a test image
8 | test_image = np.zeros((800, 600, 3), dtype=np.uint8)
9 | # Add a recognizable pattern - white rectangle in the middle
10 | test_image[300:500, 200:400] = 255
11 |
12 | result = overlap_crop_image(test_image, overlap_margin=4, max_crops=12)
13 |
14 | # Check basic properties
15 | assert result["crops"][0].shape == (378, 378, 3)
16 | assert len(result["crops"]) > 1
17 | assert all(crop.shape == (378, 378, 3) for crop in result["crops"])
18 | assert len(result["tiling"]) == 2
19 |
20 |
21 | def test_overlap_crop_small_image():
22 | # Test with image smaller than crop size
23 | test_image = np.zeros((300, 200, 3), dtype=np.uint8)
24 | result = overlap_crop_image(test_image, overlap_margin=4, max_crops=12)
25 |
26 | # Should still produce valid output
27 | assert result["crops"][0].shape == (378, 378, 3)
28 | assert len(result["crops"]) == 2
29 | assert result["tiling"] == (1, 1)
30 |
31 |
32 | def test_reconstruction():
33 | # Create a test image
34 | test_image = np.zeros((800, 600, 3), dtype=np.uint8)
35 | # Add a recognizable pattern
36 | test_image[300:500, 200:400] = 255
37 |
38 | # Crop and reconstruct
39 | result = overlap_crop_image(test_image, overlap_margin=4, max_crops=12)
40 | crops_tensor = [torch.from_numpy(crop) for crop in result["crops"][1:]]
41 | reconstructed = reconstruct_from_crops(
42 | crops_tensor, result["tiling"], overlap_margin=4
43 | )
44 |
45 | # Convert back to numpy for comparison
46 | reconstructed_np = reconstructed.numpy()
47 |
48 | # The reconstructed image should be similar to the input
49 | # We can't expect exact equality due to resizing operations
50 | # but the white rectangle should still be visible in the middle
51 | center_reconstructed = reconstructed_np[
52 | reconstructed_np.shape[0] // 2 - 100 : reconstructed_np.shape[0] // 2 + 100,
53 | reconstructed_np.shape[1] // 2 - 100 : reconstructed_np.shape[1] // 2 + 100,
54 | ].mean()
55 |
56 | # The center region should be significantly brighter than the edges
57 | assert center_reconstructed > reconstructed_np[:100, :100].mean() + 100
58 |
--------------------------------------------------------------------------------
/webcam_gradio_demo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | from threading import Thread
4 |
5 | import gradio as gr
6 | import torch
7 | from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
8 |
9 | from moondream.hf import LATEST_REVISION, detect_device
10 |
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument("--cpu", action="store_true")
13 | args = parser.parse_args()
14 |
15 | if args.cpu:
16 | device = torch.device("cpu")
17 | dtype = torch.float32
18 | else:
19 | device, dtype = detect_device()
20 | if device != torch.device("cpu"):
21 | print("Using device:", device)
22 | print("If you run into issues, pass the `--cpu` flag to this script.")
23 | print()
24 |
25 | model_id = "vikhyatk/moondream2"
26 | tokenizer = AutoTokenizer.from_pretrained(model_id, revision=LATEST_REVISION)
27 | moondream = AutoModelForCausalLM.from_pretrained(
28 | model_id, trust_remote_code=True, revision=LATEST_REVISION
29 | ).to(device=device, dtype=dtype)
30 | moondream.eval()
31 |
32 |
33 | def answer_question(img, prompt):
34 | image_embeds = moondream.encode_image(img)
35 | streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
36 | thread = Thread(
37 | target=moondream.answer_question,
38 | kwargs={
39 | "image_embeds": image_embeds,
40 | "question": prompt,
41 | "tokenizer": tokenizer,
42 | "streamer": streamer,
43 | },
44 | )
45 | thread.start()
46 |
47 | buffer = ""
48 | for new_text in streamer:
49 | buffer += new_text
50 | yield buffer
51 |
52 |
53 | with gr.Blocks() as demo:
54 | gr.Markdown("# 🌔 moondream")
55 |
56 | gr.HTML(
57 | """
58 |
64 | """
65 | )
66 |
67 | with gr.Row():
68 | prompt = gr.Textbox(
69 | label="Prompt",
70 | value="What's going on? Respond with a single sentence.",
71 | interactive=True,
72 | )
73 | with gr.Row():
74 | img = gr.Image(type="pil", label="Upload an Image", streaming=True)
75 | output = gr.Markdown(elem_classes=["md_output"])
76 |
77 | latest_img = None
78 | latest_prompt = prompt.value
79 |
80 | @img.change(inputs=[img])
81 | def img_change(img):
82 | global latest_img
83 | latest_img = img
84 |
85 | @prompt.change(inputs=[prompt])
86 | def prompt_change(prompt):
87 | global latest_prompt
88 | latest_prompt = prompt
89 |
90 | @demo.load(outputs=[output])
91 | def live_video():
92 | while True:
93 | if latest_img is None:
94 | time.sleep(0.1)
95 | else:
96 | for text in answer_question(latest_img, latest_prompt):
97 | if len(text) > 0:
98 | yield text
99 |
100 |
101 | demo.queue().launch(debug=True)
102 |
--------------------------------------------------------------------------------