├── .dockerignore ├── .gitignore ├── Dockerfile ├── README.md ├── download_ckpts.sh ├── enviroment.yaml ├── image_processing.py ├── only_video_process.json ├── requirements.txt ├── runpod_handler.py ├── sam2_configs ├── __init__.py ├── sam2_hiera_b+.yaml ├── sam2_hiera_l.yaml ├── sam2_hiera_s.yaml └── sam2_hiera_t.yaml ├── sam2_processor.py ├── test_input.json ├── test_input2.json ├── test_input3.json ├── test_runpod.py └── unittest_sam2.py /.dockerignore: -------------------------------------------------------------------------------- 1 | # Ignore git directory 2 | .git 3 | 4 | # Ignore any .log files 5 | *.log 6 | 7 | # Ignore python cache files 8 | __pycache__ 9 | *.pyc 10 | *.pyo 11 | *.pyd 12 | 13 | # Ignore checkpoints directory (if you're downloading these during the build process) 14 | checkpoints/ 15 | 16 | # Ignore any local configuration files 17 | *.env 18 | 19 | # Ignore any temporary files 20 | *.tmp 21 | 22 | # Ignore Docker files themselves 23 | Dockerfile 24 | .dockerignore 25 | 26 | # Ignore any other files or directories that are not needed in the image 27 | # Add more patterns as needed 28 | 29 | .vscode/ 30 | .DS_Store 31 | __pycache__/ 32 | *-checkpoint.ipynb 33 | .venv 34 | *.egg* 35 | build/* 36 | _C.* 37 | outputs/* 38 | checkpoints/*.pt 39 | temp_frames_* 40 | static/* 41 | debug_frames* 42 | debug_single* 43 | 44 | test_input.json -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | .DS_Store 3 | __pycache__/ 4 | *-checkpoint.ipynb 5 | .venv 6 | *.egg* 7 | build/* 8 | _C.* 9 | outputs/* 10 | checkpoints/*.pt 11 | temp_frames_* 12 | static/* 13 | debug_frames* 14 | debug_single* -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Use an official Python runtime as a parent image 2 | FROM python:3.10-slim 3 | 4 | # Set the working directory in the container 5 | WORKDIR /app 6 | 7 | # Copy the current directory contents into the container at /app 8 | COPY . 
/app 9 | 10 | # Install system dependencies: git, OpenCV build and runtime libraries, GStreamer, and ffmpeg 11 | RUN apt-get update && apt-get install -y \ 12 | wget \ 13 | git \ 14 | libgl1-mesa-glx \ 15 | libglib2.0-0 \ 16 | libsm6 \ 17 | libxext6 \ 18 | libxrender-dev \ 19 | build-essential cmake python3-dev python3-numpy \ 20 | libavcodec-dev libavformat-dev libswscale-dev \ 21 | libgstreamer-plugins-base1.0-dev \ 22 | libgstreamer1.0-dev libgtk-3-dev \ 23 | libpng-dev libjpeg-dev libopenexr-dev libtiff-dev libwebp-dev \ 24 | libopencv-dev x264 libx264-dev libssl-dev ffmpeg \ 25 | && rm -rf /var/lib/apt/lists/* 26 | 27 | # Install Python dependencies 28 | RUN pip install --no-cache-dir -r requirements.txt 29 | 30 | # Build OpenCV from source via pip instead of using the prebuilt wheel 31 | RUN python -m pip install --no-binary opencv-python opencv-python 32 | 33 | # Create the checkpoints directory 34 | RUN mkdir -p checkpoints 35 | 36 | # Make the download script executable 37 | RUN chmod +x download_ckpts.sh 38 | 39 | # Run the checkpoint download script 40 | RUN ./download_ckpts.sh 41 | 42 | # Move the downloaded checkpoints to the checkpoints directory 43 | RUN mv *.pt checkpoints/ 44 | 45 | # Start the RunPod serverless handler 46 | CMD ["python", "runpod_handler.py"] -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RunPod Video and Image Processor 2 | 3 | This project is a serverless application that uses RunPod to process videos and single images with the SAM2 (Segment Anything Model 2) processor. 4 | 5 | ## Features 6 | 7 | - Process videos using SAM2 8 | - Process single images using SAM2 9 | - Serverless architecture using RunPod 10 | 11 | ## Dependencies 12 | - SAM2 13 | - [Bytescale](https://bytescale.com/) file uploader (a BYTESCALE_API_KEY API key is required) 14 | 15 | ## Requirements 16 | 17 | - Python 3.10 (as used by the Docker image and enviroment.yaml) 18 | - RunPod SDK 19 | 20 | ## Installation 21 | 22 | 1. Clone this repository 23 | 2. Install the required dependencies: 24 | ``` 25 | pip install -r requirements.txt 26 | ``` 27 | 3. Download the SAM2 checkpoints from the [official repository](https://github.com/facebookresearch/segment-anything-2): 28 | ``` 29 | ./download_ckpts.sh 30 | mv *.pt checkpoints/ 31 | ``` 32 | 4. Set the BYTESCALE_API_KEY environment variable to your Bytescale API key 33 | 34 | ## Usage 35 | 36 | The main handler function in `runpod_handler.py` processes incoming events and routes them to the appropriate function based on the `action` parameter. 37 | 38 | To run the handler locally for testing, use the following command: 39 | 40 | ``` 41 | python runpod_handler.py 42 | ``` 43 | 44 | It will load the file `test_input.json` from the repository root and process that event.
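For reference, the bundled `test_input.json` at the repository root contains exactly such an event, a `process_video` request against a sample video:

```json
{
  "input": {
    "action": "process_video",
    "input_video_url": "https://upcdn.io/FW25b7k/raw/uploads/test.mp4",
    "points": [[603, 866]],
    "labels": [1],
    "ann_frame_idx": 0,
    "ann_obj_id": 1,
    "mode": "mask_only"
  }
}
```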
45 | 46 | ### Processing a Video 47 | 48 | Send an event with the following structure: 49 | 50 | ```json 51 | { 52 | "input": { 53 | "action": "process_video", 54 | "input_video_url": "https://example.com/video.mp4", 55 | "points": [[603, 866]], 56 | "labels": [1], 57 | "ann_frame_idx": 0, 58 | "ann_obj_id": 1, 59 | "mode": "overlayer" 60 | } 61 | } 62 | ``` 63 | 64 | `points` and `labels` are the SAM2 click prompts, `ann_frame_idx` is the frame the clicks refer to, and `ann_obj_id` identifies the tracked object. `mode` selects the output style: `overlayer`, `masked_image`, or `mask_only`. Instead of `input_video_url`, you may pass the `session_id` returned by a previous job to reuse already-extracted frames. 65 | 66 | ### Processing a Single Image 67 | 68 | Send an event with the following structure: 69 | 70 | ```json 71 | { 72 | "input": { 73 | "action": "process_single_image", 74 | "input_image_url": "https://example.com/image.jpg", 75 | "points": [[603, 866]], 76 | "labels": [1] 77 | } 78 | } 79 | ``` 80 | 81 | Alternatively, pass `input_video_url` together with `ann_frame_idx` to segment a single frame extracted from a video (see `test_input3.json`). 82 | 83 | ## File Structure 84 | 85 | - `runpod_handler.py`: Main handler for RunPod serverless functions 86 | - `sam2_processor.py`: Contains the `process_video` and `process_single_image` functions 87 | - `image_processing.py`: Mask drawing, frame annotation, video encoding, and Bytescale upload helpers 88 | 89 | ## Contributing 90 | 91 | Contributions are welcome! Please feel free to submit a Pull Request. 92 | 93 | ## License 94 | 95 | MIT License -------------------------------------------------------------------------------- /download_ckpts.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | 10 | # Define the URLs for the checkpoints 11 | BASE_URL="https://dl.fbaipublicfiles.com/segment_anything_2/072824/" 12 | sam2_hiera_t_url="${BASE_URL}sam2_hiera_tiny.pt" 13 | sam2_hiera_s_url="${BASE_URL}sam2_hiera_small.pt" 14 | sam2_hiera_b_plus_url="${BASE_URL}sam2_hiera_base_plus.pt" 15 | sam2_hiera_l_url="${BASE_URL}sam2_hiera_large.pt" 16 | 17 | 18 | # Download each of the four checkpoints using wget 19 | echo "Downloading sam2_hiera_tiny.pt checkpoint..." 20 | wget $sam2_hiera_t_url || { echo "Failed to download checkpoint from $sam2_hiera_t_url"; exit 1; } 21 | 22 | echo "Downloading sam2_hiera_small.pt checkpoint..." 23 | wget $sam2_hiera_s_url || { echo "Failed to download checkpoint from $sam2_hiera_s_url"; exit 1; } 24 | 25 | echo "Downloading sam2_hiera_base_plus.pt checkpoint..." 26 | wget $sam2_hiera_b_plus_url || { echo "Failed to download checkpoint from $sam2_hiera_b_plus_url"; exit 1; } 27 | 28 | echo "Downloading sam2_hiera_large.pt checkpoint..." 29 | wget $sam2_hiera_l_url || { echo "Failed to download checkpoint from $sam2_hiera_l_url"; exit 1; } 30 | 31 | echo "All checkpoints are downloaded successfully."
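Note that `sam2_processor.py` only loads `./checkpoints/sam2_hiera_large.pt`, while the script above (and therefore the Docker build) downloads all four checkpoints. If download time or image size matters, a minimal sketch that fetches just the large checkpoint, reusing the same `BASE_URL` as the script above, could look like this:

```bash
#!/bin/bash
# Sketch: fetch only the checkpoint that sam2_processor.py actually loads.
BASE_URL="https://dl.fbaipublicfiles.com/segment_anything_2/072824/"
mkdir -p checkpoints
wget -O checkpoints/sam2_hiera_large.pt "${BASE_URL}sam2_hiera_large.pt" \
    || { echo "Failed to download sam2_hiera_large.pt"; exit 1; }
```

If you go this route, the `RUN mv *.pt checkpoints/` step in the Dockerfile becomes unnecessary, since the file is written directly into `checkpoints/`.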
32 | -------------------------------------------------------------------------------- /enviroment.yaml: -------------------------------------------------------------------------------- 1 | name: sam2_env 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - python=3.10 # Assuming Python 3.10, adjust if necessary 9 | - pip 10 | - pytorch=2.4.0 11 | - torchvision=0.19.0 12 | - cudatoolkit=12.1 13 | - numpy=2.0.1 14 | - pillow=10.4.0 15 | - tqdm=4.66.5 16 | - pip: 17 | - aiodns==3.2.0 18 | - aiohappyeyeballs==2.3.5 19 | - aiohttp==3.10.1 20 | - aiohttp-retry==2.8.3 21 | - aiosignal==1.3.1 22 | - annotated-types==0.7.0 23 | - antlr4-python3-runtime==4.9.3 24 | - anyio==4.4.0 25 | - async-timeout==4.0.3 26 | - attrs==24.2.0 27 | - backoff==2.2.1 28 | - bcrypt==4.2.0 29 | - boto3==1.34.156 30 | - botocore==1.34.156 31 | - Brotli==1.1.0 32 | - certifi==2024.7.4 33 | - cffi==1.17.0 34 | - charset-normalizer==3.3.2 35 | - click==8.1.7 36 | - colorama==0.4.6 37 | - cryptography==42.0.8 38 | - dnspython==2.6.1 39 | - email_validator==2.2.0 40 | - exceptiongroup==1.2.2 41 | - fastapi==0.112.0 42 | - fastapi-cli==0.0.5 43 | - filelock==3.15.4 44 | - frozenlist==1.4.1 45 | - fsspec==2024.6.1 46 | - h11==0.14.0 47 | - httpcore==1.0.5 48 | - httptools==0.6.1 49 | - httpx==0.27.0 50 | - hydra-core==1.3.2 51 | - idna==3.7 52 | - inquirerpy==0.3.4 53 | - iopath==0.1.10 54 | - itsdangerous==2.2.0 55 | - Jinja2==3.1.4 56 | - jmespath==1.0.1 57 | - markdown-it-py==3.0.0 58 | - MarkupSafe==2.1.5 59 | - mdurl==0.1.2 60 | - mpmath==1.3.0 61 | - multidict==6.0.5 62 | - networkx==3.3 63 | - omegaconf==2.3.0 64 | - opencv-python==4.10.0.84 65 | - orjson==3.10.6 66 | - packaging==24.1 67 | - paramiko==3.4.0 68 | - pfzy==0.3.4 69 | - portalocker==2.10.1 70 | - prettytable==3.10.2 71 | - prompt_toolkit==3.0.47 72 | - py-cpuinfo==9.0.0 73 | - pycares==4.4.0 74 | - pycparser==2.22 75 | - pydantic==2.8.2 76 | - pydantic-extra-types==2.9.0 77 | - pydantic-settings==2.4.0 78 | - pydantic_core==2.20.1 79 | - Pygments==2.18.0 80 | - PyNaCl==1.5.0 81 | - python-dateutil==2.9.0.post0 82 | - python-dotenv==1.0.1 83 | - python-multipart==0.0.9 84 | - PyYAML==6.0.2 85 | - requests==2.32.3 86 | - rich==13.7.1 87 | - runpod==1.7.0 88 | - s3transfer==0.10.2 89 | - git+https://github.com/facebookresearch/segment-anything-2.git@6186d1529a9c26f7b6e658f3e704d4bee386d9ba 90 | - shellingham==1.5.4 91 | - six==1.16.0 92 | - sniffio==1.3.1 93 | - starlette==0.37.2 94 | - sympy==1.13.1 95 | - tomli==2.0.1 96 | - tomlkit==0.13.0 97 | - tqdm-loggable==0.2 98 | - triton==3.0.0 99 | - typer==0.12.3 100 | - typing_extensions==4.12.2 101 | - ujson==5.10.0 102 | - urllib3==2.2.2 103 | - uvicorn==0.30.5 104 | - uvloop==0.19.0 105 | - watchdog==4.0.1 106 | - watchfiles==0.23.0 107 | - wcwidth==0.2.13 108 | - websockets==12.0 109 | - yarl==1.9.4 -------------------------------------------------------------------------------- /image_processing.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import os 4 | import base64 5 | from PIL import Image 6 | from io import BytesIO 7 | import requests 8 | import cv2 9 | from tqdm import tqdm 10 | import uuid 11 | import urllib.request 12 | import json 13 | import av 14 | import runpod 15 | 16 | def show_mask(mask, image, obj_id=None, random_color=False): 17 | if random_color: 18 | color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) 19 | else: 20 | colors = [ 21 | (255, 0, 0), 
(0, 255, 0), (0, 0, 255), 22 | (255, 255, 0), (255, 0, 255), (0, 255, 255), 23 | (128, 0, 0), (0, 128, 0), (0, 0, 128), 24 | (128, 128, 0) 25 | ] 26 | color_idx = 0 if obj_id is None else obj_id % len(colors) 27 | color = colors[color_idx] + (153,) # 153 is roughly 0.6 * 255 for alpha 28 | 29 | h, w = mask.shape[-2:] 30 | mask_image = mask.reshape(h, w, 1) * np.array(color).reshape(1, 1, 4) / 255.0 31 | mask_image = (mask_image * 255).astype(np.uint8) 32 | 33 | # Convert mask_image to BGR for blending 34 | mask_image_bgr = cv2.cvtColor(mask_image, cv2.COLOR_RGBA2BGR) 35 | 36 | # Blend the mask with the original image 37 | cv2.addWeighted(mask_image_bgr, 0.6, image, 1, 0, image) 38 | 39 | return image 40 | 41 | def draw_single_image(mask, image, random_color=False, borders=True): 42 | if random_color: 43 | color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) 44 | else: 45 | color = np.array([30/255, 144/255, 255/255, 0.6]) 46 | h, w = mask.shape[-2:] 47 | mask = mask.astype(np.uint8) 48 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) 49 | 50 | if borders: 51 | contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) 52 | contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours] 53 | mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2) 54 | 55 | # Convert mask_image to uint8 and ensure it has the same number of channels as the original image 56 | mask_image = (mask_image * 255).astype(np.uint8) 57 | if mask_image.shape[-1] == 4 and image.shape[-1] == 3: 58 | mask_image = cv2.cvtColor(mask_image, cv2.COLOR_RGBA2RGB) 59 | elif mask_image.shape[-1] == 3 and image.shape[-1] == 4: 60 | mask_image = cv2.cvtColor(mask_image, cv2.COLOR_RGB2RGBA) 61 | 62 | # Ensure mask_image has the same shape as the original image 63 | mask_image = cv2.resize(mask_image, (image.shape[1], image.shape[0])) 64 | 65 | return cv2.addWeighted(image, 1, mask_image, 0.5, 0) 66 | 67 | def apply_mask(mask, image, obj_id=0, random_color=False, borders=True): 68 | # Determine color 69 | if random_color: 70 | color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) 71 | else: 72 | colors = [ 73 | (70, 130, 180), # Steel Blue 74 | (255, 160, 122), # Light Salmon 75 | (152, 251, 152), # Pale Green 76 | (221, 160, 221), # Plum 77 | (176, 196, 222), # Light Steel Blue 78 | (255, 182, 193), # Light Pink 79 | (240, 230, 140), # Khaki 80 | (216, 191, 216), # Thistle 81 | (173, 216, 230), # Light Blue 82 | (255, 228, 196) # Bisque 83 | ] 84 | color_idx = obj_id % len(colors) 85 | color = np.array(colors[color_idx] + (153,)) / 255.0 # 153 is roughly 0.6 * 255 for alpha 86 | 87 | # Create mask image 88 | if mask.ndim == 2: 89 | h, w = mask.shape 90 | elif mask.ndim == 3: 91 | h, w = mask.shape[1:] 92 | mask = mask.squeeze(0) # Remove the first dimension if it's (1, h, w) 93 | else: 94 | raise ValueError("Unexpected mask shape") 95 | 96 | mask = mask.astype(np.uint8) 97 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) 98 | 99 | 100 | # Draw borders if requested 101 | if borders: 102 | contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) 103 | contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours] 104 | mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2) 105 | 106 | # Convert mask_image to uint8 and ensure it has the same number of channels as the original image 107 | mask_image = (mask_image * 
255).astype(np.uint8) 108 | if mask_image.shape[-1] == 4 and image.shape[-1] == 3: 109 | mask_image = cv2.cvtColor(mask_image, cv2.COLOR_RGBA2RGB) 110 | elif mask_image.shape[-1] == 3 and image.shape[-1] == 4: 111 | mask_image = cv2.cvtColor(mask_image, cv2.COLOR_RGB2RGBA) 112 | 113 | # Ensure mask_image has the same shape as the original image 114 | mask_image = cv2.resize(mask_image, (image.shape[1], image.shape[0])) 115 | 116 | # Blend the mask with the original image 117 | return cv2.addWeighted(image, 1, mask_image, 0.5, 0) 118 | 119 | 120 | def show_points(coords, labels, image, marker_size=20): 121 | pos_points = coords[labels==1] 122 | neg_points = coords[labels==0] 123 | 124 | for point in pos_points: 125 | cv2.circle(image, tuple(point.astype(int)), marker_size // 2, (0, 255, 0), -1) 126 | cv2.circle(image, tuple(point.astype(int)), marker_size // 2, (255, 255, 255), 2) 127 | for point in neg_points: 128 | cv2.circle(image, tuple(point.astype(int)), marker_size // 2, (0, 0, 255), -1) 129 | cv2.circle(image, tuple(point.astype(int)), marker_size // 2, (255, 255, 255), 2) 130 | 131 | def process_mask(mask, img_shape, color): 132 | # Process mask 133 | if mask.ndim == 2: 134 | h, w = mask.shape 135 | elif mask.ndim == 3: 136 | h, w = mask.shape[1:] 137 | mask = mask.squeeze(0) # Remove the first dimension if it's (1, h, w) 138 | else: 139 | raise ValueError("Unexpected mask shape") 140 | 141 | # Convert boolean mask to uint8 before resizing 142 | mask = mask.astype(np.uint8) * 255 143 | 144 | # Ensure mask has the same shape as the image 145 | mask = cv2.resize(mask, (img_shape[1], img_shape[0])) 146 | 147 | # Normalize the mask back to 0-1 range 148 | mask = mask / 255.0 149 | 150 | # Create colored mask 151 | colored_mask = (mask[:, :, np.newaxis] * color).astype(np.uint8) 152 | 153 | # Create alpha channel 154 | alpha = (mask * 255).astype(np.uint8) 155 | 156 | return colored_mask, alpha 157 | 158 | def annotate_frame(frame_idx, frame_names, video_dir, mode, masks=None, points=None, labels=None): 159 | # Load the image 160 | img_path = os.path.join(video_dir, frame_names[frame_idx]) 161 | img = cv2.imread(img_path) 162 | 163 | # Display points if provided 164 | if points is not None and labels is not None: 165 | show_points(points, labels, img) 166 | 167 | if masks is not None: 168 | if mode == "overlayer": 169 | # Current logic: mask applied to image 170 | for obj_id, mask in masks.items(): 171 | img = apply_mask(mask, img, obj_id=obj_id) 172 | elif mode == "masked_image": 173 | # Only show the masked region, other regions are transparent 174 | result = np.zeros((img.shape[0], img.shape[1], 4), dtype=np.uint8) 175 | for obj_id, mask in masks.items(): 176 | color = np.array([255, 255, 255]) # Use white color to preserve original image colors 177 | colored_mask, alpha = process_mask(mask, img.shape, color) 178 | # Apply the mask to the original image 179 | masked_region = cv2.bitwise_and(img, colored_mask) 180 | result[:, :, :3] += masked_region 181 | result[:, :, 3] += alpha 182 | img = result 183 | elif mode == "mask_only": 184 | # Only show the mask 185 | result = np.zeros((img.shape[0], img.shape[1], 4), dtype=np.uint8) 186 | for obj_id, mask in masks.items(): 187 | color = np.array(get_color(obj_id)) 188 | colored_mask, alpha = process_mask(mask, img.shape, color) 189 | result[:, :, :3] += colored_mask 190 | result[:, :, 3] += alpha 191 | img = result 192 | else: 193 | raise ValueError(f"Unknown mode: {mode}") 194 | 195 | # Save the image locally for debugging 196 | 
debug_output_dir = "./debug_frames" 197 | os.makedirs(debug_output_dir, exist_ok=True) 198 | debug_frame_path = os.path.join(debug_output_dir, f"frame_{frame_idx}_{mode}.png") 199 | cv2.imwrite(debug_frame_path, img) 200 | 201 | return img 202 | 203 | def get_color(obj_id): 204 | colors = [ 205 | (70, 130, 180), # Steel Blue 206 | (255, 160, 122), # Light Salmon 207 | (152, 251, 152), # Pale Green 208 | (221, 160, 221), # Plum 209 | (176, 196, 222), # Light Steel Blue 210 | (255, 182, 193), # Light Pink 211 | (240, 230, 140), # Khaki 212 | (216, 191, 216), # Thistle 213 | (173, 216, 230), # Light Blue 214 | (255, 228, 196) # Bisque 215 | ] 216 | return colors[obj_id % len(colors)] 217 | 218 | def load_image_from_url(image_url): 219 | response = requests.get(image_url) 220 | response.raise_for_status() 221 | return Image.open(BytesIO(response.content)) 222 | 223 | def extract_frame_from_video(video_url, frame_index): 224 | # Download the video to a temporary file 225 | temp_video_path = f"temp_video_{uuid.uuid4()}.mp4" 226 | urllib.request.urlretrieve(video_url, temp_video_path) 227 | 228 | # Open the video file 229 | cap = cv2.VideoCapture(temp_video_path) 230 | 231 | # Set the frame position 232 | cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index) 233 | 234 | # Read the frame 235 | ret, frame = cap.read() 236 | 237 | # Release the video capture object and delete the temporary file 238 | cap.release() 239 | os.remove(temp_video_path) 240 | 241 | if not ret: 242 | raise Exception(f"Failed to extract frame {frame_index} from video") 243 | 244 | # Convert BGR to RGB 245 | frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 246 | 247 | return frame_rgb 248 | 249 | def encode_image(image): 250 | _, buffer = cv2.imencode('.png', cv2.cvtColor(image, cv2.COLOR_RGB2BGR)) 251 | return buffer 252 | 253 | def upload_to_bytescale(image_buffer): 254 | upload_url = "https://api.bytescale.com/v2/accounts/FW25b7k/uploads/binary" 255 | headers = { 256 | "Authorization": f"Bearer {os.environ.get('BYTESCALE_API_KEY', 'public_FW25b7k33rVdd9MShz7yH28Z1HWr')}", 257 | "Content-Type": "image/png" 258 | } 259 | response = requests.post(upload_url, headers=headers, data=image_buffer.tobytes()) 260 | response.raise_for_status() 261 | return response.json().get('fileUrl') 262 | 263 | def create_output_video(job, session_id, frame_names, video_dir, video_segments, mode): 264 | 265 | # Read video information from the JSON file 266 | video_info_path = os.path.join(video_dir, "video_settings.json") 267 | with open(video_info_path, 'r') as f: 268 | video_info = json.load(f) 269 | 270 | # Extract video dimensions and fps 271 | width = video_info['width'] 272 | height = video_info['height'] 273 | fps = video_info['fps'] 274 | 275 | # Ensure the dimensions are even (required by some codecs) 276 | width = width if width % 2 == 0 else width + 1 277 | height = height if height % 2 == 0 else height + 1 278 | 279 | # Create output container 280 | output_video_path = f"static/segmented_video_{session_id}.mp4" 281 | output = av.open(output_video_path, mode='w') 282 | 283 | # Use H.264 for other modes 284 | stream = output.add_stream('h264', rate='{0:.4f}'.format(fps)) 285 | stream.pix_fmt = 'yuv420p' 286 | stream.options = { 287 | 'crf': '23', # Default CRF value, good balance between quality and file size 288 | 'preset': 'medium' # Default preset, balances encoding speed and compression efficiency 289 | } 290 | 291 | stream.width = width 292 | stream.height = height 293 | 294 | vis_frame_stride = 1 295 | total_frames = len(range(0, 
len(frame_names), vis_frame_stride)) 296 | 297 | with tqdm(total=total_frames, desc="Writing video", unit="frame") as pbar: 298 | for out_frame_idx in range(0, len(frame_names), vis_frame_stride): 299 | if out_frame_idx in video_segments: 300 | annotated_frame = annotate_frame(out_frame_idx, frame_names, video_dir, mode, masks=video_segments[out_frame_idx]) 301 | else: 302 | annotated_frame = annotate_frame(out_frame_idx, frame_names, video_dir, mode) 303 | 304 | # Convert BGR to RGBA for masked_image and mask_only modes 305 | frame_rgb = cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB) 306 | 307 | # Create PyAV video frame 308 | frame = av.VideoFrame.from_ndarray(frame_rgb, format='rgb24') 309 | 310 | # Encode and write the frame 311 | for packet in stream.encode(frame): 312 | output.mux(packet) 313 | 314 | pbar.update(1) 315 | if os.environ.get('RUN_ENV') == 'production': 316 | runpod.serverless.progress_update(job, f"Writing video update {out_frame_idx}/{len(frame_names)} (3/3)") 317 | 318 | # Flush the stream 319 | for packet in stream.encode(): 320 | output.mux(packet) 321 | 322 | # Close the output container 323 | output.close() 324 | 325 | return output_video_path 326 | 327 | def upload_video_to_bytescale(video_path): 328 | upload_url = "https://api.bytescale.com/v2/accounts/FW25b7k/uploads/binary" 329 | headers = { 330 | "Authorization": f"Bearer {os.environ.get('BYTESCALE_API_KEY', 'public_FW25b7k33rVdd9MShz7yH28Z1HWr')}", 331 | "Content-Type": "video/mp4" if video_path.lower().endswith('.mp4') else "video/quicktime" if video_path.lower().endswith('.mov') else f"video/{os.path.splitext(video_path)[1][1:]}" 332 | } 333 | 334 | with open(video_path, 'rb') as video_file: 335 | response = requests.post(upload_url, headers=headers, data=video_file) 336 | 337 | if response.status_code != 200: 338 | raise Exception(f"Failed to upload video to Bytescale. 
Status code: {response.status_code}") 339 | 340 | bytescale_response = response.json() 341 | return bytescale_response.get('fileUrl') 342 | 343 | def upload_video(video_url): 344 | if not video_url: 345 | return {"error": "Missing video_url parameter"} 346 | 347 | # Generate a unique ID for this video processing session 348 | session_id = str(uuid.uuid4()) 349 | video_dir = f"./temp_frames_{session_id}" 350 | os.makedirs(video_dir, exist_ok=True) 351 | video_path = os.path.join(video_dir, "input_video.mp4") 352 | 353 | try: 354 | # Download video 355 | urllib.request.urlretrieve(video_url, video_path) 356 | 357 | # Extract frames 358 | vidcap = cv2.VideoCapture(video_path) 359 | success, image = vidcap.read() 360 | count = 0 361 | while success: 362 | cv2.imwrite(os.path.join(video_dir, f"{count}.jpg"), image) 363 | success, image = vidcap.read() 364 | count += 1 365 | 366 | # Extract video information 367 | fps = vidcap.get(cv2.CAP_PROP_FPS) 368 | width = int(vidcap.get(cv2.CAP_PROP_FRAME_WIDTH)) 369 | height = int(vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 370 | fourcc = int(vidcap.get(cv2.CAP_PROP_FOURCC)) 371 | codec = "".join([chr((fourcc >> 8 * i) & 0xFF) for i in range(4)]) 372 | 373 | # Create a dictionary with video settings 374 | video_settings = { 375 | "fps": fps, 376 | "width": width, 377 | "height": height, 378 | "codec": codec, 379 | "total_frames": count 380 | } 381 | 382 | # Write the video settings to a JSON file 383 | import json 384 | with open(os.path.join(video_dir, "video_settings.json"), "w") as f: 385 | json.dump(video_settings, f, indent=4) 386 | 387 | vidcap.release() 388 | except Exception as e: 389 | # Cleanup in case of an error 390 | if os.path.exists(video_dir): 391 | for file in os.listdir(video_dir): 392 | os.remove(os.path.join(video_dir, file)) 393 | os.rmdir(video_dir) 394 | return {"error": str(e)} 395 | 396 | return { 397 | "message": "Video uploaded, frames extracted, and inference state initialized successfully", 398 | "session_id": session_id, 399 | "video_settings": video_settings 400 | } -------------------------------------------------------------------------------- /only_video_process.json: -------------------------------------------------------------------------------- 1 | { 2 | "session_id": "56371269-1b54-4d13-9721-18768a9c21a9", 3 | "clicks": [ 4 | { 5 | "points": [[ 6 | 603, 7 | 866 8 | ]], 9 | "labels": [1], 10 | "ann_frame_idx": 0, 11 | "ann_obj_id": 1 12 | } 13 | ] 14 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aiodns==3.2.0 2 | aiohappyeyeballs==2.3.5 3 | aiohttp==3.10.1 4 | aiohttp-retry==2.8.3 5 | aiosignal==1.3.1 6 | annotated-types==0.7.0 7 | antlr4-python3-runtime==4.9.3 8 | anyio==4.4.0 9 | async-timeout==4.0.3 10 | attrs==24.2.0 11 | backoff==2.2.1 12 | bcrypt==4.2.0 13 | boto3==1.34.156 14 | botocore==1.34.156 15 | Brotli==1.1.0 16 | certifi==2024.7.4 17 | cffi==1.17.0 18 | charset-normalizer==3.3.2 19 | click==8.1.7 20 | colorama==0.4.6 21 | cryptography==42.0.8 22 | dnspython==2.6.1 23 | email_validator==2.2.0 24 | exceptiongroup==1.2.2 25 | fastapi==0.112.0 26 | fastapi-cli==0.0.5 27 | filelock==3.15.4 28 | frozenlist==1.4.1 29 | fsspec==2024.6.1 30 | h11==0.14.0 31 | httpcore==1.0.5 32 | httptools==0.6.1 33 | httpx==0.27.0 34 | hydra-core==1.3.2 35 | idna==3.7 36 | inquirerpy==0.3.4 37 | iopath==0.1.10 38 | itsdangerous==2.2.0 39 | Jinja2==3.1.4 40 | jmespath==1.0.1 41 | 
markdown-it-py==3.0.0 42 | MarkupSafe==2.1.5 43 | mdurl==0.1.2 44 | mpmath==1.3.0 45 | multidict==6.0.5 46 | networkx==3.3 47 | numpy==2.0.1 48 | nvidia-cublas-cu12==12.1.3.1 49 | nvidia-cuda-cupti-cu12==12.1.105 50 | nvidia-cuda-nvrtc-cu12==12.1.105 51 | nvidia-cuda-runtime-cu12==12.1.105 52 | nvidia-cudnn-cu12==9.1.0.70 53 | nvidia-cufft-cu12==11.0.2.54 54 | nvidia-curand-cu12==10.3.2.106 55 | nvidia-cusolver-cu12==11.4.5.107 56 | nvidia-cusparse-cu12==12.1.0.106 57 | nvidia-nccl-cu12==2.20.5 58 | nvidia-nvjitlink-cu12==12.6.20 59 | nvidia-nvtx-cu12==12.1.105 60 | omegaconf==2.3.0 61 | opencv-python==4.10.0.84 62 | orjson==3.10.6 63 | packaging==24.1 64 | paramiko==3.4.0 65 | pfzy==0.3.4 66 | pillow==10.4.0 67 | portalocker==2.10.1 68 | prettytable==3.10.2 69 | prompt_toolkit==3.0.47 70 | py-cpuinfo==9.0.0 71 | pycares==4.4.0 72 | pycparser==2.22 73 | pydantic==2.8.2 74 | pydantic-extra-types==2.9.0 75 | pydantic-settings==2.4.0 76 | pydantic_core==2.20.1 77 | Pygments==2.18.0 78 | PyNaCl==1.5.0 79 | python-dateutil==2.9.0.post0 80 | python-dotenv==1.0.1 81 | python-multipart==0.0.9 82 | PyYAML==6.0.2 83 | requests==2.32.3 84 | rich==13.7.1 85 | runpod==1.7.0 86 | s3transfer==0.10.2 87 | git+https://github.com/facebookresearch/segment-anything-2.git@6186d1529a9c26f7b6e658f3e704d4bee386d9ba 88 | shellingham==1.5.4 89 | six==1.16.0 90 | sniffio==1.3.1 91 | starlette==0.37.2 92 | sympy==1.13.1 93 | tomli==2.0.1 94 | tomlkit==0.13.0 95 | torch==2.4.0 96 | torchvision==0.19.0 97 | tqdm==4.66.5 98 | tqdm-loggable==0.2 99 | triton==3.0.0 100 | typer==0.12.3 101 | typing_extensions==4.12.2 102 | ujson==5.10.0 103 | urllib3==2.2.2 104 | uvicorn==0.30.5 105 | uvloop==0.19.0 106 | watchdog==4.0.1 107 | watchfiles==0.23.0 108 | wcwidth==0.2.13 109 | websockets==12.0 110 | yarl==1.9.4 111 | av==12.3.0 -------------------------------------------------------------------------------- /runpod_handler.py: -------------------------------------------------------------------------------- 1 | import runpod 2 | from sam2_processor import process_video, process_single_image 3 | 4 | def handler(event): 5 | if 'input' not in event: 6 | return {"error": "No input provided"} 7 | 8 | action = event['input'].get('action', 'process_video') 9 | 10 | if action == 'process_video': 11 | return process_video(event) 12 | elif action == 'process_single_image': 13 | return {"refresh_worker": True, "job_results": process_single_image(event)} 14 | else: 15 | return {"error": f"Unknown action: {action}"} 16 | 17 | 18 | runpod.serverless.start({"handler": handler}) -------------------------------------------------------------------------------- /sam2_configs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /sam2_configs/sam2_hiera_b+.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 112 12 | num_heads: 2 13 | neck: 14 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 15 | position_encoding: 16 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 17 | num_pos_feats: 256 18 | normalize: true 19 | scale: null 20 | temperature: 10000 21 | d_model: 256 22 | backbone_channel_list: [896, 448, 224, 112] 23 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 24 | fpn_interp_model: nearest 25 | 26 | memory_attention: 27 | _target_: sam2.modeling.memory_attention.MemoryAttention 28 | d_model: 256 29 | pos_enc_at_input: true 30 | layer: 31 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 32 | activation: relu 33 | dim_feedforward: 2048 34 | dropout: 0.1 35 | pos_enc_at_attn: false 36 | self_attention: 37 | _target_: sam2.modeling.sam.transformer.RoPEAttention 38 | rope_theta: 10000.0 39 | feat_sizes: [32, 32] 40 | embedding_dim: 256 41 | num_heads: 1 42 | downsample_rate: 1 43 | dropout: 0.1 44 | d_model: 256 45 | pos_enc_at_cross_attn_keys: true 46 | pos_enc_at_cross_attn_queries: false 47 | cross_attention: 48 | _target_: sam2.modeling.sam.transformer.RoPEAttention 49 | rope_theta: 10000.0 50 | feat_sizes: [32, 32] 51 | rope_k_repeat: True 52 | embedding_dim: 256 53 | num_heads: 1 54 | downsample_rate: 1 55 | dropout: 0.1 56 | kv_in_dim: 64 57 | num_layers: 4 58 | 59 | memory_encoder: 60 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 61 | out_dim: 64 62 | position_encoding: 63 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 64 | num_pos_feats: 64 65 | normalize: true 66 | scale: null 67 | temperature: 10000 68 | mask_downsampler: 69 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 70 | kernel_size: 3 71 | stride: 2 72 | padding: 1 73 | fuser: 74 | _target_: sam2.modeling.memory_encoder.Fuser 75 | layer: 76 | _target_: sam2.modeling.memory_encoder.CXBlock 77 | dim: 256 78 | kernel_size: 7 79 | padding: 3 80 | layer_scale_init_value: 1e-6 81 | use_dwconv: True # depth-wise convs 82 | num_layers: 2 83 | 84 | num_maskmem: 7 85 | image_size: 1024 86 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 87 | sigmoid_scale_for_mem_enc: 20.0 88 | sigmoid_bias_for_mem_enc: -10.0 89 | use_mask_input_as_output_without_sam: true 90 | # Memory 91 | directly_add_no_mem_embed: true 92 | # use high-resolution feature map in the SAM mask decoder 93 | use_high_res_features_in_sam: true 94 | # output 3 masks on the first click on initial conditioning frames 95 | multimask_output_in_sam: true 96 | # SAM heads 97 | iou_prediction_use_sigmoid: True 98 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 99 | use_obj_ptrs_in_encoder: true 100 | add_tpos_enc_to_obj_ptrs: false 101 | only_obj_ptrs_in_the_past_for_eval: true 102 | # object occlusion prediction 103 | pred_obj_scores: true 104 | pred_obj_scores_mlp: true 105 | fixed_no_obj_ptr: true 106 | # multimask tracking settings 107 | multimask_output_for_tracking: true 108 | 
use_multimask_token_for_obj_ptr: true 109 | multimask_min_pt_num: 0 110 | multimask_max_pt_num: 1 111 | use_mlp_for_obj_ptr_proj: true 112 | # Compilation flag 113 | compile_image_encoder: False 114 | -------------------------------------------------------------------------------- /sam2_configs/sam2_hiera_l.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 144 12 | num_heads: 2 13 | stages: [2, 6, 36, 4] 14 | global_att_blocks: [23, 33, 43] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | window_spec: [8, 4, 16, 8] 17 | neck: 18 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 19 | position_encoding: 20 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 21 | num_pos_feats: 256 22 | normalize: true 23 | scale: null 24 | temperature: 10000 25 | d_model: 256 26 | backbone_channel_list: [1152, 576, 288, 144] 27 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 28 | fpn_interp_model: nearest 29 | 30 | memory_attention: 31 | _target_: sam2.modeling.memory_attention.MemoryAttention 32 | d_model: 256 33 | pos_enc_at_input: true 34 | layer: 35 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 36 | activation: relu 37 | dim_feedforward: 2048 38 | dropout: 0.1 39 | pos_enc_at_attn: false 40 | self_attention: 41 | _target_: sam2.modeling.sam.transformer.RoPEAttention 42 | rope_theta: 10000.0 43 | feat_sizes: [32, 32] 44 | embedding_dim: 256 45 | num_heads: 1 46 | downsample_rate: 1 47 | dropout: 0.1 48 | d_model: 256 49 | pos_enc_at_cross_attn_keys: true 50 | pos_enc_at_cross_attn_queries: false 51 | cross_attention: 52 | _target_: sam2.modeling.sam.transformer.RoPEAttention 53 | rope_theta: 10000.0 54 | feat_sizes: [32, 32] 55 | rope_k_repeat: True 56 | embedding_dim: 256 57 | num_heads: 1 58 | downsample_rate: 1 59 | dropout: 0.1 60 | kv_in_dim: 64 61 | num_layers: 4 62 | 63 | memory_encoder: 64 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 65 | out_dim: 64 66 | position_encoding: 67 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 68 | num_pos_feats: 64 69 | normalize: true 70 | scale: null 71 | temperature: 10000 72 | mask_downsampler: 73 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 74 | kernel_size: 3 75 | stride: 2 76 | padding: 1 77 | fuser: 78 | _target_: sam2.modeling.memory_encoder.Fuser 79 | layer: 80 | _target_: sam2.modeling.memory_encoder.CXBlock 81 | dim: 256 82 | kernel_size: 7 83 | padding: 3 84 | layer_scale_init_value: 1e-6 85 | use_dwconv: True # depth-wise convs 86 | num_layers: 2 87 | 88 | num_maskmem: 7 89 | image_size: 1024 90 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 91 | sigmoid_scale_for_mem_enc: 20.0 92 | sigmoid_bias_for_mem_enc: -10.0 93 | use_mask_input_as_output_without_sam: true 94 | # Memory 95 | directly_add_no_mem_embed: true 96 | # use high-resolution feature map in the SAM mask decoder 97 | use_high_res_features_in_sam: true 98 | # output 3 masks on the first click on initial conditioning frames 99 | multimask_output_in_sam: true 100 | # SAM heads 101 | iou_prediction_use_sigmoid: True 102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the 
encoder 103 | use_obj_ptrs_in_encoder: true 104 | add_tpos_enc_to_obj_ptrs: false 105 | only_obj_ptrs_in_the_past_for_eval: true 106 | # object occlusion prediction 107 | pred_obj_scores: true 108 | pred_obj_scores_mlp: true 109 | fixed_no_obj_ptr: true 110 | # multimask tracking settings 111 | multimask_output_for_tracking: true 112 | use_multimask_token_for_obj_ptr: true 113 | multimask_min_pt_num: 0 114 | multimask_max_pt_num: 1 115 | use_mlp_for_obj_ptr_proj: true 116 | # Compilation flag 117 | compile_image_encoder: False 118 | -------------------------------------------------------------------------------- /sam2_configs/sam2_hiera_s.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 96 12 | num_heads: 1 13 | stages: [1, 2, 11, 2] 14 | global_att_blocks: [7, 10, 13] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | neck: 17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 18 | position_encoding: 19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 20 | num_pos_feats: 256 21 | normalize: true 22 | scale: null 23 | temperature: 10000 24 | d_model: 256 25 | backbone_channel_list: [768, 384, 192, 96] 26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 27 | fpn_interp_model: nearest 28 | 29 | memory_attention: 30 | _target_: sam2.modeling.memory_attention.MemoryAttention 31 | d_model: 256 32 | pos_enc_at_input: true 33 | layer: 34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 35 | activation: relu 36 | dim_feedforward: 2048 37 | dropout: 0.1 38 | pos_enc_at_attn: false 39 | self_attention: 40 | _target_: sam2.modeling.sam.transformer.RoPEAttention 41 | rope_theta: 10000.0 42 | feat_sizes: [32, 32] 43 | embedding_dim: 256 44 | num_heads: 1 45 | downsample_rate: 1 46 | dropout: 0.1 47 | d_model: 256 48 | pos_enc_at_cross_attn_keys: true 49 | pos_enc_at_cross_attn_queries: false 50 | cross_attention: 51 | _target_: sam2.modeling.sam.transformer.RoPEAttention 52 | rope_theta: 10000.0 53 | feat_sizes: [32, 32] 54 | rope_k_repeat: True 55 | embedding_dim: 256 56 | num_heads: 1 57 | downsample_rate: 1 58 | dropout: 0.1 59 | kv_in_dim: 64 60 | num_layers: 4 61 | 62 | memory_encoder: 63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 64 | out_dim: 64 65 | position_encoding: 66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 67 | num_pos_feats: 64 68 | normalize: true 69 | scale: null 70 | temperature: 10000 71 | mask_downsampler: 72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 73 | kernel_size: 3 74 | stride: 2 75 | padding: 1 76 | fuser: 77 | _target_: sam2.modeling.memory_encoder.Fuser 78 | layer: 79 | _target_: sam2.modeling.memory_encoder.CXBlock 80 | dim: 256 81 | kernel_size: 7 82 | padding: 3 83 | layer_scale_init_value: 1e-6 84 | use_dwconv: True # depth-wise convs 85 | num_layers: 2 86 | 87 | num_maskmem: 7 88 | image_size: 1024 89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 90 | sigmoid_scale_for_mem_enc: 20.0 91 | sigmoid_bias_for_mem_enc: -10.0 92 | use_mask_input_as_output_without_sam: true 93 | # Memory 94 | directly_add_no_mem_embed: true 95 | # use high-resolution feature map in the SAM mask decoder 
96 | use_high_res_features_in_sam: true 97 | # output 3 masks on the first click on initial conditioning frames 98 | multimask_output_in_sam: true 99 | # SAM heads 100 | iou_prediction_use_sigmoid: True 101 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 102 | use_obj_ptrs_in_encoder: true 103 | add_tpos_enc_to_obj_ptrs: false 104 | only_obj_ptrs_in_the_past_for_eval: true 105 | # object occlusion prediction 106 | pred_obj_scores: true 107 | pred_obj_scores_mlp: true 108 | fixed_no_obj_ptr: true 109 | # multimask tracking settings 110 | multimask_output_for_tracking: true 111 | use_multimask_token_for_obj_ptr: true 112 | multimask_min_pt_num: 0 113 | multimask_max_pt_num: 1 114 | use_mlp_for_obj_ptr_proj: true 115 | # Compilation flag 116 | compile_image_encoder: False 117 | -------------------------------------------------------------------------------- /sam2_configs/sam2_hiera_t.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 96 12 | num_heads: 1 13 | stages: [1, 2, 7, 2] 14 | global_att_blocks: [5, 7, 9] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | neck: 17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 18 | position_encoding: 19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 20 | num_pos_feats: 256 21 | normalize: true 22 | scale: null 23 | temperature: 10000 24 | d_model: 256 25 | backbone_channel_list: [768, 384, 192, 96] 26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 27 | fpn_interp_model: nearest 28 | 29 | memory_attention: 30 | _target_: sam2.modeling.memory_attention.MemoryAttention 31 | d_model: 256 32 | pos_enc_at_input: true 33 | layer: 34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 35 | activation: relu 36 | dim_feedforward: 2048 37 | dropout: 0.1 38 | pos_enc_at_attn: false 39 | self_attention: 40 | _target_: sam2.modeling.sam.transformer.RoPEAttention 41 | rope_theta: 10000.0 42 | feat_sizes: [32, 32] 43 | embedding_dim: 256 44 | num_heads: 1 45 | downsample_rate: 1 46 | dropout: 0.1 47 | d_model: 256 48 | pos_enc_at_cross_attn_keys: true 49 | pos_enc_at_cross_attn_queries: false 50 | cross_attention: 51 | _target_: sam2.modeling.sam.transformer.RoPEAttention 52 | rope_theta: 10000.0 53 | feat_sizes: [32, 32] 54 | rope_k_repeat: True 55 | embedding_dim: 256 56 | num_heads: 1 57 | downsample_rate: 1 58 | dropout: 0.1 59 | kv_in_dim: 64 60 | num_layers: 4 61 | 62 | memory_encoder: 63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 64 | out_dim: 64 65 | position_encoding: 66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 67 | num_pos_feats: 64 68 | normalize: true 69 | scale: null 70 | temperature: 10000 71 | mask_downsampler: 72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 73 | kernel_size: 3 74 | stride: 2 75 | padding: 1 76 | fuser: 77 | _target_: sam2.modeling.memory_encoder.Fuser 78 | layer: 79 | _target_: sam2.modeling.memory_encoder.CXBlock 80 | dim: 256 81 | kernel_size: 7 82 | padding: 3 83 | layer_scale_init_value: 1e-6 84 | use_dwconv: True # depth-wise convs 85 | num_layers: 2 86 | 87 | num_maskmem: 7 88 | image_size: 1024 89 | # apply scaled sigmoid on mask logits for 
memory encoder, and directly feed input mask as output mask 90 | # SAM decoder 91 | sigmoid_scale_for_mem_enc: 20.0 92 | sigmoid_bias_for_mem_enc: -10.0 93 | use_mask_input_as_output_without_sam: true 94 | # Memory 95 | directly_add_no_mem_embed: true 96 | # use high-resolution feature map in the SAM mask decoder 97 | use_high_res_features_in_sam: true 98 | # output 3 masks on the first click on initial conditioning frames 99 | multimask_output_in_sam: true 100 | # SAM heads 101 | iou_prediction_use_sigmoid: True 102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 103 | use_obj_ptrs_in_encoder: true 104 | add_tpos_enc_to_obj_ptrs: false 105 | only_obj_ptrs_in_the_past_for_eval: true 106 | # object occlusion prediction 107 | pred_obj_scores: true 108 | pred_obj_scores_mlp: true 109 | fixed_no_obj_ptr: true 110 | # multimask tracking settings 111 | multimask_output_for_tracking: true 112 | use_multimask_token_for_obj_ptr: true 113 | multimask_min_pt_num: 0 114 | multimask_max_pt_num: 1 115 | use_mlp_for_obj_ptr_proj: true 116 | # Compilation flag 117 | # HieraT does not currently support compilation, should always be set to False 118 | compile_image_encoder: False 119 | -------------------------------------------------------------------------------- /sam2_processor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import urllib.request 5 | import uuid 6 | import cv2 7 | import logging 8 | import base64 9 | import requests 10 | from sam2.build_sam import build_sam2_video_predictor 11 | from sam2.sam2_image_predictor import SAM2ImagePredictor 12 | from sam2.build_sam import build_sam2 13 | 14 | from tqdm import tqdm 15 | import runpod 16 | 17 | 18 | # Import the image processing functions 19 | from image_processing import ( 20 | show_points, apply_mask, 21 | load_image_from_url, extract_frame_from_video, encode_image, upload_to_bytescale, 22 | create_output_video, upload_video_to_bytescale, upload_video # Add upload_video here 23 | ) 24 | 25 | # Set up logging 26 | logging.basicConfig(level=logging.DEBUG) 27 | logger = logging.getLogger(__name__) 28 | 29 | # Create the static directory if it doesn't exist 30 | os.makedirs('static', exist_ok=True) 31 | 32 | # Initialize the model 33 | logger.debug("Initializing the model...") 34 | torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__() 35 | 36 | if torch.cuda.get_device_properties(0).major >= 8: 37 | torch.backends.cuda.matmul.allow_tf32 = True 38 | torch.backends.cudnn.allow_tf32 = True 39 | 40 | sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt" 41 | model_cfg = "sam2_hiera_l.yaml" 42 | 43 | predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint) 44 | logger.debug("Model initialized successfully.") 45 | 46 | sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda") 47 | image_predictor = SAM2ImagePredictor(sam2_model) 48 | 49 | def process_video(job): 50 | job_input = job["input"] 51 | session_id = job_input.get("session_id") 52 | points = np.array(job_input["points"], dtype=np.float32) 53 | labels = np.array(job_input["labels"], dtype=np.int32) 54 | ann_frame_idx = job_input["ann_frame_idx"] 55 | ann_obj_id = job_input["ann_obj_id"] 56 | input_video_url = job_input.get("input_video_url") 57 | mode = job_input.get("mode") 58 | 59 | # Validate that either session_id or input_video_url is provided 60 | if session_id is None and input_video_url is None: 61 | return 
{"error": "Either session_id or input_video_url must be provided"} 62 | 63 | # If both are provided, prioritize session_id 64 | if session_id is not None and input_video_url is not None: 65 | logger.warning("Both session_id and input_video_url provided. Using existing session.") 66 | input_video_url = None 67 | 68 | if input_video_url: 69 | upload_response = upload_video(input_video_url) 70 | if "error" in upload_response: 71 | return upload_response 72 | session_id = upload_response["session_id"] 73 | 74 | video_dir = f"./temp_frames_{session_id}" 75 | 76 | if not os.path.exists(video_dir): 77 | return {"error": "Invalid session ID"} 78 | 79 | try: 80 | torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__() 81 | if torch.cuda.get_device_properties(0).major >= 8: 82 | torch.backends.cuda.matmul.allow_tf32 = True 83 | torch.backends.cudnn.allow_tf32 = True 84 | # Load inference_state 85 | 86 | if os.environ.get('RUN_ENV') == 'production': 87 | runpod.serverless.progress_update(job, f"Initializing inference state (1/3)") 88 | inference_state = predictor.init_state(video_path=video_dir) 89 | except FileNotFoundError: 90 | return {"error": "Inference state not found. Please upload the video first."} 91 | except Exception as e: 92 | return {"error": str(e)} 93 | 94 | # scan all the JPEG frame names in this directory 95 | frame_names = [ 96 | p for p in os.listdir(video_dir) 97 | if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] 98 | ] 99 | frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) 100 | 101 | # Check if ann_frame_idx is valid 102 | if ann_frame_idx < 0 or ann_frame_idx >= len(frame_names): 103 | return {"error": f"Invalid ann_frame_idx. Must be between 0 and {len(frame_names) - 1}"} 104 | 105 | _, out_obj_ids, out_mask_logits = predictor.add_new_points( 106 | inference_state=inference_state, 107 | frame_idx=ann_frame_idx, 108 | obj_id=ann_obj_id, 109 | points=points, 110 | labels=labels, 111 | ) 112 | 113 | # run propagation throughout the video and collect the results in a dict 114 | video_segments = {} # video_segments contains the per-frame segmentation results 115 | for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state): 116 | video_segments[out_frame_idx] = { 117 | out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy() 118 | for i, out_obj_id in enumerate(out_obj_ids) 119 | } 120 | if os.environ.get('RUN_ENV') == 'production': 121 | runpod.serverless.progress_update(job, f"Segment update {out_frame_idx}/{len(frame_names)} (2/3)") 122 | 123 | # Create output videos 124 | output_video_path = create_output_video(job, session_id, frame_names, video_dir, video_segments, mode) 125 | 126 | # Upload video to Bytescale API 127 | try: 128 | bytescale_video_url = upload_video_to_bytescale(output_video_path) 129 | except Exception as e: 130 | return {"error": f"Failed to upload video to Bytescale: {str(e)}"} 131 | 132 | return { 133 | "output_video_url": bytescale_video_url, 134 | "session_id": session_id, 135 | } 136 | 137 | def process_single_image(job): 138 | job_input = job["input"] 139 | image_url = job_input.get("input_image_url") 140 | video_url = job_input.get("input_video_url") 141 | frame_index = job_input.get("ann_frame_idx") 142 | points = job_input.get("points") 143 | labels = job_input.get("labels") 144 | 145 | if not (image_url or (video_url and frame_index is not None)): 146 | return {"error": "Missing image_url or video_url with frame_index"} 147 | if points is None or labels is None: 148 | return 
{"error": "Missing points or labels parameter"} 149 | 150 | try: 151 | points = np.array(points, dtype=np.float32) 152 | labels = np.array(labels, dtype=np.int32) 153 | except ValueError: 154 | return {"error": "Invalid format for points or labels"} 155 | 156 | if video_url and frame_index is not None: 157 | try: 158 | image = extract_frame_from_video(video_url, frame_index) 159 | except Exception as e: 160 | return {"error": f"Failed to extract frame from video: {str(e)}"} 161 | else: 162 | try: 163 | image = load_image_from_url(image_url) 164 | except requests.RequestException as e: 165 | return {"error": f"Failed to download image: {str(e)}"} 166 | except IOError: 167 | return {"error": "Failed to open image"} 168 | 169 | if image is None: 170 | return {"error": "Failed to obtain image"} 171 | 172 | logger.debug("image predictor initialized successfully.") 173 | 174 | image_np = np.array(image) 175 | image_predictor.set_image(image_np) 176 | 177 | try: 178 | masks, scores, _ = image_predictor.predict( 179 | point_coords=points, 180 | point_labels=labels, 181 | multimask_output=True, 182 | ) 183 | except Exception as e: 184 | return {"error": f"Prediction failed: {str(e)}"} 185 | 186 | sorted_ind = np.argsort(scores)[::-1] 187 | masks = masks[sorted_ind] 188 | scores = scores[sorted_ind] 189 | 190 | annotated_image = image_np.copy() 191 | for mask in masks: 192 | annotated_image = apply_mask(mask, annotated_image.copy(), random_color=True) 193 | 194 | # Add points to the final annotated image 195 | show_points(points, labels, annotated_image) 196 | 197 | try: 198 | annotated_buffer = encode_image(annotated_image) 199 | combined_mask = np.any(masks, axis=0).astype(np.uint8) * 255 200 | mask_buffer = encode_image(combined_mask) 201 | except Exception as e: 202 | return {"error": f"Failed to encode output images: {str(e)}"} 203 | 204 | try: 205 | bytescale_image_url = upload_to_bytescale(annotated_buffer) 206 | bytescale_mask_url = upload_to_bytescale(mask_buffer) 207 | except requests.RequestException as e: 208 | return {"error": f"Failed to upload images to Bytescale: {str(e)}"} 209 | 210 | return { 211 | "bytescale_image_url": bytescale_image_url, 212 | "bytescale_mask_url": bytescale_mask_url, 213 | "scores": scores.tolist() 214 | } -------------------------------------------------------------------------------- /test_input.json: -------------------------------------------------------------------------------- 1 | { 2 | "input": { 3 | "action": "process_video", 4 | "input_video_url": "https://upcdn.io/FW25b7k/raw/uploads/test.mp4", 5 | "points": [[603, 866]], 6 | "labels": [1], 7 | "ann_frame_idx": 0, 8 | "ann_obj_id": 1, 9 | "mode": "mask_only" 10 | } 11 | } 12 | 13 | -------------------------------------------------------------------------------- /test_input2.json: -------------------------------------------------------------------------------- 1 | { 2 | "input": { 3 | "action": "process_single_image", 4 | "input_image_url": "https://upcdn.io/FW25b7k/raw/uploads/test.jpg", 5 | "points": [[603, 866]], 6 | "labels": [1] 7 | } 8 | } -------------------------------------------------------------------------------- /test_input3.json: -------------------------------------------------------------------------------- 1 | { 2 | "input": { 3 | "action": "process_single_image", 4 | "input_video_url": "https://upcdn.io/FW25b7k/raw/uploads/test.mp4", 5 | "points": [[603, 866]], 6 | "labels": [1], 7 | "ann_frame_idx": 0, 8 | "ann_obj_id": 1 9 | } 10 | } 11 | 12 | 
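The three sample payloads above can be exercised in two ways: locally through the handler (which, as the README notes, picks up `test_input.json`), or against a deployed endpoint with `test_runpod.py` below, which takes the payload file as its only argument and polls the job until it finishes. For example:

```bash
# Run the handler locally against the bundled test_input.json
python runpod_handler.py

# Submit a sample payload to a deployed endpoint and poll until it completes
# (set ENDPOINT_ID and API_KEY at the top of test_runpod.py first)
python test_runpod.py test_input.json
python test_runpod.py test_input2.json
```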
-------------------------------------------------------------------------------- /test_runpod.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import json 3 | import time 4 | import sys 5 | 6 | # Replace these with your actual values 7 | ENDPOINT_ID = "kswcaqyooco6t7" 8 | API_KEY = "WUN4F25QVDDGOKBDPGKT78XYUBT5WH3Z25N4OSIP" 9 | 10 | BASE_URL = f"https://api.runpod.ai/v2/{ENDPOINT_ID}" 11 | HEADERS = { 12 | 'Content-Type': 'application/json', 13 | 'Authorization': f'Bearer {API_KEY}' 14 | } 15 | 16 | def send_request(input_data): 17 | url = f"{BASE_URL}/run" 18 | response = requests.post(url, headers=HEADERS, json=input_data) 19 | print(response) 20 | return response.json() 21 | 22 | def check_status(job_id): 23 | url = f"{BASE_URL}/status/{job_id}" 24 | response = requests.get(url, headers=HEADERS) 25 | return response.json() 26 | 27 | def main(input_file): 28 | # Read input from JSON file 29 | with open(input_file, 'r') as f: 30 | input_data = json.load(f) 31 | 32 | print("Input data:", json.dumps(input_data, indent=2)) 33 | print("\nSending initial request...") 34 | response = send_request(input_data) 35 | print("Initial response:", json.dumps(response, indent=2)) 36 | 37 | if 'id' not in response: 38 | print("Failed to get a job ID. Exiting.") 39 | return 40 | 41 | job_id = response['id'] 42 | print(f"\nJob ID: {job_id}") 43 | 44 | while True: 45 | time.sleep(5) # Wait for 5 seconds before checking status 46 | status_response = check_status(job_id) 47 | print("\nStatus response:", json.dumps(status_response, indent=2)) 48 | 49 | if 'status' in status_response: 50 | if status_response['status'] == 'COMPLETED': 51 | print("\nJob completed successfully!") 52 | if 'output' in status_response: 53 | print("\nOutput:") 54 | print(json.dumps(status_response['output'], indent=2)) 55 | break 56 | elif status_response['status'] in ['FAILED', 'CANCELLED']: 57 | print(f"\nJob {status_response['status'].lower()}.") 58 | if 'error' in status_response: 59 | print("\nError:") 60 | print(json.dumps(status_response['error'], indent=2)) 61 | break 62 | else: 63 | print("Unexpected response format. 
Continuing to poll...") 64 | 65 | if __name__ == "__main__": 66 | if len(sys.argv) != 2: 67 | print("Usage: python script.py ") 68 | sys.exit(1) 69 | 70 | input_file = sys.argv[1] 71 | main(input_file) -------------------------------------------------------------------------------- /unittest_sam2.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest.mock import patch, MagicMock 3 | import numpy as np 4 | from sam2_processor import process_video 5 | 6 | class TestProcessVideo(unittest.TestCase): 7 | 8 | @patch('sam2_processor.upload_video') 9 | @patch('sam2_processor.predictor') 10 | @patch('sam2_processor.os.path.exists') 11 | @patch('sam2_processor.os.listdir') 12 | @patch('sam2_processor.annotate_frame') 13 | @patch('sam2_processor.cv2.VideoWriter') 14 | @patch('sam2_processor.requests.post') 15 | def test_process_video(self, mock_post, mock_video_writer, mock_annotate_frame, 16 | mock_listdir, mock_exists, mock_predictor, mock_upload_video): 17 | # Mock the necessary functions and objects 18 | mock_upload_video.return_value = {"session_id": "test_session"} 19 | mock_exists.return_value = True 20 | mock_listdir.return_value = ['0.jpg', '1.jpg', '2.jpg'] 21 | mock_predictor.init_state.return_value = MagicMock() 22 | mock_predictor.add_new_points.return_value = (None, [1], [np.array([[True, False], [False, True]])]) 23 | mock_predictor.propagate_in_video.return_value = [ 24 | (0, [1], [np.array([[True, False], [False, True]])]), 25 | (1, [1], [np.array([[False, True], [True, False]])]), 26 | (2, [1], [np.array([[True, True], [False, False]])]) 27 | ] 28 | mock_annotate_frame.return_value = np.zeros((100, 100, 3), dtype=np.uint8) 29 | mock_video_writer.return_value = MagicMock() 30 | mock_post.return_value.status_code = 200 31 | mock_post.return_value.json.return_value = {"fileUrl": "https://upcdn.io/FW25b7k/raw/uploads/processed_test.mp4"} 32 | 33 | # Use the provided test case 34 | job = { 35 | "input": { 36 | "action": "process_video", 37 | "input_video_url": "https://upcdn.io/FW25b7k/raw/uploads/test.mp4", 38 | "points": [[603, 866]], 39 | "labels": [1], 40 | "ann_frame_idx": 0, 41 | "ann_obj_id": 1 42 | } 43 | } 44 | 45 | # Call the function 46 | result = process_video(job) 47 | 48 | # Assert the expected results 49 | self.assertIn("output_video_url", result) 50 | # self.assertEqual(result["output_video_url"], "https://upcdn.io/FW25b7k/raw/uploads/processed_test.mp4") 51 | # self.assertEqual(result["session_id"], "test_session") 52 | 53 | # Verify that the mocked functions were called with correct arguments 54 | mock_upload_video.assert_called_once_with("https://upcdn.io/FW25b7k/raw/uploads/test.mp4") 55 | mock_predictor.init_state.assert_called_once() 56 | mock_predictor.add_new_points.assert_called_once_with( 57 | inference_state=mock_predictor.init_state.return_value, 58 | frame_idx=0, 59 | obj_id=1, 60 | points=np.array([[603, 866]], dtype=np.float32), 61 | labels=np.array([1], dtype=np.int32) 62 | ) 63 | mock_predictor.propagate_in_video.assert_called_once() 64 | mock_video_writer.assert_called_once() 65 | mock_post.assert_called_once() 66 | 67 | if __name__ == '__main__': 68 | unittest.main() --------------------------------------------------------------------------------
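To run this test, keep in mind that `unittest_sam2.py` imports `sam2_processor`, which builds the real SAM2 model at import time, so the large checkpoint and a CUDA device must be available even though the predictor calls themselves are mocked:

```bash
# Requires ./checkpoints/sam2_hiera_large.pt and a CUDA-capable GPU,
# because sam2_processor initializes the SAM2 model as soon as it is imported.
python unittest_sam2.py
```

Also note that `@patch('sam2_processor.annotate_frame')` targets a name `sam2_processor` never imports (frame annotation lives in `image_processing.py` and is only used inside `create_output_video`), so `mock.patch` will raise an `AttributeError`; repointing that mock (and the unused `cv2.VideoWriter` one) at `sam2_processor.create_output_video` and `sam2_processor.upload_video_to_bytescale` is likely needed before the test can pass.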