├── requirements.txt
├── requirements-test.txt
├── cache.db
├── requirements-base.txt
├── start.sh
├── .dockerignore
├── test_api.py
├── docker-compose.yml
├── Dockerfile
├── _wsgi.py
├── README.md
└── model.py
/requirements.txt:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/requirements-test.txt:
--------------------------------------------------------------------------------
1 | pytest
2 | pytest-cov
--------------------------------------------------------------------------------
/cache.db:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YiuChoi/label-studio-sam2/master/cache.db
--------------------------------------------------------------------------------
/requirements-base.txt:
--------------------------------------------------------------------------------
1 | gunicorn==22.0.0
2 | label-studio-ml @ git+https://github.com/HumanSignal/label-studio-ml-backend.git
--------------------------------------------------------------------------------
/start.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Launch the WSGI app with gunicorn; PORT/WORKERS/THREADS come from the environment, and --timeout 0 disables the worker timeout so slow model loads don't kill workers
4 | exec gunicorn --bind :${PORT:-9090} --workers ${WORKERS:-1} --threads ${THREADS:-4} --timeout 0 --pythonpath '/app' _wsgi:app
5 |
--------------------------------------------------------------------------------
/.dockerignore:
--------------------------------------------------------------------------------
1 | # Exclude everything
2 | **
3 |
4 | # Include Dockerfile and docker-compose for reference (optional, decide based on your use case)
5 | !Dockerfile
6 | !docker-compose.yml
7 |
8 | # Include Python application files
9 | !*.py
10 |
11 | # Include requirements files
12 | !requirements*.txt
13 |
14 | # Include script
15 | !*.sh
16 |
17 | # Exclude specific requirements if necessary
18 | # requirements-test.txt (Uncomment if you decide to exclude this)
19 |
--------------------------------------------------------------------------------
/test_api.py:
--------------------------------------------------------------------------------
1 | """
2 | This file contains tests for the API of your model. You can run these tests by installing test requirements:
3 |
4 | ```bash
5 | pip install -r requirements-test.txt
6 | ```
7 | Then execute `pytest` in the directory of this file.
8 |
9 | - Change `NewModel` to the name of the class in your model.py file.
10 | - Change the `request` and `expected_response` variables to match the input and output of your model.
11 | """
12 |
13 | import pytest
14 | import json
15 | from model import NewModel
16 |
17 |
18 | @pytest.fixture
19 | def client():
20 | from _wsgi import init_app
21 | app = init_app(model_class=NewModel)
22 | app.config['TESTING'] = True
23 | with app.test_client() as client:
24 | yield client
25 |
26 |
27 | def test_predict(client):
28 | request = {
29 | 'tasks': [{
30 | 'data': {
31 | # Your input test data here
32 | }
33 | }],
34 | # Your labeling configuration here
35 | 'label_config': ''
36 | }
37 |
38 | expected_response = {
39 | 'results': [{
40 | # Your expected result here
41 | }]
42 | }
43 |
44 | response = client.post('/predict', data=json.dumps(request), content_type='application/json')
45 | assert response.status_code == 200
46 | response = json.loads(response.data)
47 | assert response == expected_response
48 |
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | version: "3.8"
2 |
3 | services:
4 | ml-backend:
5 | container_name: ml-backend
6 | image: humansignal/ml-backend:v0
7 | build:
8 | context: .
9 | args:
10 | TEST_ENV: ${TEST_ENV}
11 |
12 | # deploy:
13 | # resources:
14 | # reservations:
15 | # devices:
16 | # - driver: nvidia
17 | # count: 1
18 | # capabilities: [ gpu ]
19 |
20 |
21 | environment:
22 | # specify these parameters if you want to use basic auth for the model server
23 | - BASIC_AUTH_USER=
24 | - BASIC_AUTH_PASS=
25 | # set the log level for the model server
26 | - LOG_LEVEL=DEBUG
27 | # any other parameters that you want to pass to the model server
28 | - ANY=PARAMETER
29 | # specify the number of workers and threads for the model server
30 | - WORKERS=1
31 | - THREADS=8
32 | # specify the model directory (likely you don't need to change this)
33 | - MODEL_DIR=/data/models
34 | # specify device
35 | - DEVICE=cuda # or 'cpu' (coming soon)
36 | # SAM2 model config
37 | - MODEL_CONFIG=configs/sam2.1/sam2.1_hiera_l.yaml
38 | # SAM2 checkpoint
39 | - MODEL_CHECKPOINT=sam2.1_hiera_large.pt
40 |
41 | # Specify the Label Studio URL and API key to access
42 | # uploaded, local storage and cloud storage files.
43 | # Do not use 'localhost' as it does not work within Docker containers.
44 | # Use prefix 'http://' or 'https://' for the URL always.
45 | # Determine the actual IP using 'ifconfig' (Linux/Mac) or 'ipconfig' (Windows).
46 | - LABEL_STUDIO_URL=
47 | - LABEL_STUDIO_API_KEY=
48 | ports:
49 | - "9090:9090"
50 | volumes:
51 | - "./data/server:/data"
52 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM pytorch/pytorch:2.5.1-cuda12.4-cudnn9-runtime
2 | ARG DEBIAN_FRONTEND=noninteractive
3 | ARG TEST_ENV
4 |
5 | WORKDIR /app
6 |
7 | RUN mamba update conda -y
8 |
9 | RUN --mount=type=cache,target="/var/cache/apt",sharing=locked \
10 | --mount=type=cache,target="/var/lib/apt/lists",sharing=locked \
11 | apt-get -y update \
12 | && apt-get install -y git \
13 | && apt-get install -y wget \
14 | && apt-get install -y g++ freeglut3-dev build-essential libx11-dev \
15 | libxmu-dev libxi-dev libglu1-mesa libglu1-mesa-dev libfreeimage-dev \
16 | && apt-get -y install ffmpeg libsm6 libxext6 libffi-dev python3-dev python3-pip gcc
17 |
18 | ENV PYTHONUNBUFFERED=1 \
19 | PYTHONDONTWRITEBYTECODE=1 \
20 | PIP_CACHE_DIR=/.cache \
21 | PORT=9090 \
22 | WORKERS=2 \
23 | THREADS=4 \
24 | CUDA_HOME=/usr/local/cuda
25 |
26 | # CUDA toolkit via conda: the runtime base image ships without nvcc (needed to build SAM2's CUDA extension); the toolkit lands under the conda prefix, hence CUDA_HOME=/opt/conda below
27 | RUN mamba install nvidia/label/cuda-12.4.0::cuda -y
28 | ENV CUDA_HOME=/opt/conda \
29 |     TORCH_CUDA_ARCH_LIST="6.0;6.1;7.0;7.5;8.0;8.6+PTX;8.9;9.0"
30 |
31 | # install base requirements
32 | COPY requirements-base.txt .
33 | RUN --mount=type=cache,target=${PIP_CACHE_DIR},sharing=locked \
34 | pip install -r requirements-base.txt
35 |
36 | COPY requirements.txt .
37 | RUN --mount=type=cache,target=${PIP_CACHE_DIR},sharing=locked \
38 | pip3 install -r requirements.txt
39 |
40 | # install segment-anything-2
41 | RUN cd / && git clone --depth 1 --branch main --single-branch https://github.com/facebookresearch/sam2.git
42 | WORKDIR /sam2
43 | RUN --mount=type=cache,target=${PIP_CACHE_DIR},sharing=locked \
44 | pip3 install -e .
45 | RUN cd checkpoints && ./download_ckpts.sh
46 |
47 | WORKDIR /app
48 |
49 | # install test requirements if needed
50 | COPY requirements-test.txt .
51 | # build only when TEST_ENV="true"
52 | RUN --mount=type=cache,target=${PIP_CACHE_DIR},sharing=locked \
53 | if [ "$TEST_ENV" = "true" ]; then \
54 | pip3 install -r requirements-test.txt; \
55 | fi
56 |
57 | COPY . ./
58 |
59 | # SAM2 resolves MODEL_CONFIG relative to the sam2 repo (see the README note on Meta's breaking changes), so serve from /sam2
60 | WORKDIR /sam2
61 | CMD ["/app/start.sh"]
--------------------------------------------------------------------------------
/_wsgi.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import json
4 | import logging
5 | import logging.config
6 |
7 | logging.config.dictConfig({
8 | "version": 1,
9 | "disable_existing_loggers": False,
10 | "formatters": {
11 | "standard": {
12 | "format": "[%(asctime)s] [%(levelname)s] [%(name)s::%(funcName)s::%(lineno)d] %(message)s"
13 | }
14 | },
15 | "handlers": {
16 | "console": {
17 | "class": "logging.StreamHandler",
18 | "level": os.getenv('LOG_LEVEL'),
19 | "stream": "ext://sys.stdout",
20 | "formatter": "standard"
21 | }
22 | },
23 | "root": {
24 | "level": os.getenv('LOG_LEVEL'),
25 | "handlers": [
26 | "console"
27 | ],
28 | "propagate": True
29 | }
30 | })
31 |
32 | from label_studio_ml.api import init_app
33 | from model import NewModel
34 |
35 |
36 | _DEFAULT_CONFIG_PATH = os.path.join(os.path.dirname(__file__), 'config.json')
37 |
38 |
39 | def get_kwargs_from_config(config_path=_DEFAULT_CONFIG_PATH):
40 | if not os.path.exists(config_path):
41 | return dict()
42 | with open(config_path) as f:
43 | config = json.load(f)
44 | assert isinstance(config, dict)
45 | return config
46 |
47 |
48 | if __name__ == "__main__":
49 | parser = argparse.ArgumentParser(description='Label studio')
50 | parser.add_argument(
51 | '-p', '--port', dest='port', type=int, default=9090,
52 | help='Server port')
53 | parser.add_argument(
54 | '--host', dest='host', type=str, default='0.0.0.0',
55 | help='Server host')
56 | parser.add_argument(
57 | '--kwargs', '--with', dest='kwargs', metavar='KEY=VAL', nargs='+', type=lambda kv: kv.split('='),
58 | help='Additional LabelStudioMLBase model initialization kwargs')
59 | parser.add_argument(
60 | '-d', '--debug', dest='debug', action='store_true',
61 | help='Switch debug mode')
62 | parser.add_argument(
63 | '--log-level', dest='log_level', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'], default=None,
64 | help='Logging level')
65 | parser.add_argument(
66 | '--model-dir', dest='model_dir', default=os.path.dirname(__file__),
67 | help='Directory where models are stored (relative to the project directory)')
68 | parser.add_argument(
69 | '--check', dest='check', action='store_true',
70 | help='Validate model instance before launching server')
71 | parser.add_argument('--basic-auth-user',
72 | default=os.environ.get('ML_SERVER_BASIC_AUTH_USER', None),
73 | help='Basic auth user')
74 |
75 | parser.add_argument('--basic-auth-pass',
76 | default=os.environ.get('ML_SERVER_BASIC_AUTH_PASS', None),
77 | help='Basic auth pass')
78 |
79 | args = parser.parse_args()
80 |
81 | # setup logging level
82 | if args.log_level:
83 | logging.root.setLevel(args.log_level)
84 |
85 | def isfloat(value):
86 | try:
87 | float(value)
88 | return True
89 | except ValueError:
90 | return False
91 |
92 | def parse_kwargs():
93 | param = dict()
94 | for k, v in args.kwargs:
95 | if v.isdigit():
96 | param[k] = int(v)
97 | elif v == 'True' or v == 'true':
98 | param[k] = True
99 | elif v == 'False' or v == 'false':
100 | param[k] = False
101 | elif isfloat(v):
102 | param[k] = float(v)
103 | else:
104 | param[k] = v
105 | return param
106 |
107 | kwargs = get_kwargs_from_config()
108 |
109 | if args.kwargs:
110 | kwargs.update(parse_kwargs())
111 |
112 | if args.check:
113 | print('Check "' + NewModel.__name__ + '" instance creation..')
114 | model = NewModel(**kwargs)
115 |
116 | app = init_app(model_class=NewModel, basic_auth_user=args.basic_auth_user, basic_auth_pass=args.basic_auth_pass)
117 |
118 | app.run(host=args.host, port=args.port, debug=args.debug)
119 |
120 | else:
121 |     # module-level app for WSGI servers such as gunicorn (see start.sh)
122 | app = init_app(model_class=NewModel)
123 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
18 |
19 | # Automated image labeling with SAM2 for object detection
20 |
21 | Segment Anything 2, or SAM 2, is a model released by Meta in July 2024. An update to the original Segment Anything Model,
22 | SAM 2 provides even better object segmentation for both images and video. In this guide, we'll show you how to use
23 | SAM 2 for better image labeling with label studio.
24 |
25 | Click on the image below to watch our ML Evangelist Micaela Kaplan explain how to link SAM 2 to your Label Studio Project.
26 | You'll need to follow the instructions below to stand up an instance of SAM2 before you can link your model!
27 |
28 | [](https://www.youtube.com/watch?v=FTg8P8z4RgY)
29 |
30 | ## Before you begin
31 |
32 | Before you begin, you must install the [Label Studio ML backend](https://github.com/HumanSignal/label-studio-ml-backend?tab=readme-ov-file#quickstart).
33 |
34 | This tutorial uses the [`segment_anything_2_image` example](https://github.com/HumanSignal/label-studio-ml-backend/tree/master/label_studio_ml/examples/segment_anything_2_image).
35 |
36 | Note that as of 8/1/2024, SAM2 only runs on GPU.
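
A quick way to confirm your environment actually exposes a CUDA device (plain PyTorch, nothing backend-specific):

```python
import torch

# The backend defaults to DEVICE=cuda, so a visible GPU is required.
assert torch.cuda.is_available(), "SAM2 backend currently requires a CUDA GPU"
print(torch.cuda.get_device_name(0))
```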
37 |
38 | ## Labeling configuration
39 |
40 | The current implementation of the Label Studio SAM2 ML backend works in Interactive mode. The user-guided inputs are:
41 | - `KeypointLabels`
42 | - `RectangleLabels`
43 |
44 | SAM2 then returns its predictions as `RectangleLabels`.
45 |
46 | Both control tags (plus the `Image` object tag) should therefore be represented in your labeling configuration, for example (the label values below are placeholders; adapt them to your classes):
47 |
48 | ```xml
49 | <View>
50 |     <Image name="image" value="$image" zoom="true"/>
51 |
52 |     <!-- Keypoint prompts: click to place points on the object -->
53 |     <KeyPointLabels name="KeyPointLabels" toName="image" smart="true">
54 |         <Label value="object" smart="true" background="#000000" showInline="true"/>
55 |     </KeyPointLabels>
56 |
57 |     <!-- Rectangle prompts: draw a box around the object.      -->
58 |     <!-- SAM2 also returns its predictions as RectangleLabels, -->
59 |     <!-- so the same control tag receives the model output.    -->
60 |     <RectangleLabels name="RectangleLabels" toName="image" smart="true">
61 |         <Label value="object" background="#0d14d3" showInline="true"/>
62 |     </RectangleLabels>
63 |
64 | </View>
65 | ```
66 |
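For reference, here is a sketch of the interactive request the backend receives when a user places a keypoint. The field names mirror what `model.py` reads from `context['result']`; the task id, URL, and image size are illustrative:

```python
# Hypothetical /predict payload (values are examples, not real data)
payload = {
    "tasks": [{"id": 1, "data": {"image": "https://example.com/cat.jpg"}}],
    "label_config": "<View>...</View>",
    "context": {
        "result": [{
            "type": "keypointlabels",   # or "rectanglelabels" for a box prompt
            "is_positive": True,        # True = positive (foreground) click
            "original_width": 1280,
            "original_height": 720,
            "value": {
                "x": 50.0,              # percent of image width
                "y": 40.0,              # percent of image height
                "keypointlabels": ["object"],
            },
        }]
    },
}
```
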
67 | ## Running from source
68 |
69 | 1. To run the ML backend without Docker, you have to clone the repository and install all dependencies using pip:
70 |
71 | ```bash
72 | git clone https://github.com/HumanSignal/label-studio-ml-backend.git
73 | cd label-studio-ml-backend
74 | pip install -e .
75 | cd label_studio_ml/examples/segment_anything_2_image
76 | pip install -r requirements.txt
77 | ```
78 |
79 | 2. Download the [`segment-anything-2` repo](https://github.com/facebookresearch/sam2) into the root directory. Install the SAM2 package and download the checkpoints following [the official Meta documentation](https://github.com/facebookresearch/sam2?tab=readme-ov-file#installation).
80 | You should now have the following folder structure:
81 |
82 |
83 |     | root directory
84 |     |   | label-studio-ml-backend
85 |     |   |   | label_studio_ml
86 |     |   |   |   | examples
87 |     |   |   |   |   | segment_anything_2_image
88 |     |   | sam2
89 |     |   |   | sam2
90 |     |   |   | checkpoints
91 |
92 |
93 | 3. Then you can start the ML backend on the default port `9090`:
94 |
95 | ```bash
96 | cd ~/sam2
97 | label-studio-ml start ../label-studio-ml-backend/label_studio_ml/examples/segment_anything_2_image
98 | ```
99 |
100 | Due to [breaking changes from Meta](https://github.com/facebookresearch/sam2/blob/c2ec8e14a185632b0a5d8b161928ceb50197eddc/sam2/build_sam.py#L20) in how model configs are resolved, it is crucial that you run this command from the `sam2` directory in your root directory.
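
If in doubt, you can sanity-check the working directory before starting the backend by building the model the same way `model.py` does. A minimal sketch, assuming the sam2.1 large checkpoint downloaded by `download_ckpts.sh`:

```python
import os

from sam2.build_sam import build_sam2

# Run this from the sam2 repo root, mirroring how the backend resolves paths.
ckpt = os.path.join(os.getcwd(), "checkpoints", "sam2.1_hiera_large.pt")
model = build_sam2("configs/sam2.1/sam2.1_hiera_l.yaml", ckpt, device="cuda")
print(type(model).__name__)
```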
101 |
102 | 4. Connect the running ML backend server to Label Studio: go to your project `Settings -> Machine Learning -> Add Model` and specify `http://localhost:9090` as the URL. Read more in the official [Label Studio documentation](https://labelstud.io/guide/ml#Connect-the-model-to-Label-Studio).
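
You can also register the backend programmatically through the Label Studio REST API instead of the UI; a minimal sketch using `requests` (the URL, API key, and project id below are placeholders):

```python
import requests

LABEL_STUDIO_URL = "http://localhost:8080"  # your Label Studio instance
API_KEY = "<your-api-key>"                  # Account & Settings -> Access Token
PROJECT_ID = 1                              # target project id

# POST /api/ml attaches an ML backend to a project.
resp = requests.post(
    f"{LABEL_STUDIO_URL}/api/ml",
    headers={"Authorization": f"Token {API_KEY}"},
    json={"url": "http://localhost:9090", "project": PROJECT_ID},
)
resp.raise_for_status()
print(resp.json())
```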
103 |
104 | ## Running with Docker
105 |
106 | 1. Start the Machine Learning backend on `http://localhost:9090` with the prebuilt image:
107 |
108 | ```bash
109 | docker-compose up
110 | ```
111 |
112 | 2. Validate that the backend is running:
113 |
114 | ```bash
115 | $ curl http://localhost:9090/
116 | {"status":"UP"}
117 | ```
118 |
119 | 3. Connect to the backend from Label Studio running on the same host: go to your project `Settings -> Machine Learning -> Add Model` and specify `http://localhost:9090` as the URL.
120 |
121 |
122 | ## Configuration
123 | Parameters can be set in `docker-compose.yml` before running the container.
124 |
125 |
126 | The following common parameters are available:
127 | - `DEVICE` - specify the device for the model server (currently only `cuda` is supported, `cpu` is coming soon)
128 | - `MODEL_CONFIG` - SAM2 model configuration file (`configs/sam2.1/sam2.1_hiera_l.yaml` by default)
129 | - `MODEL_CHECKPOINT` - SAM2 model checkpoint file (`sam2.1_hiera_large.pt` by default)
130 | - `BASIC_AUTH_USER` - specify the basic auth user for the model server
131 | - `BASIC_AUTH_PASS` - specify the basic auth password for the model server
132 | - `LOG_LEVEL` - set the log level for the model server
133 | - `WORKERS` - specify the number of workers for the model server
134 | - `THREADS` - specify the number of threads for the model server
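
These variables are read once at import time in `model.py`; the relevant defaults are:

```python
import os

# Defaults as defined in model.py; override them via docker-compose.yml
DEVICE = os.getenv('DEVICE', 'cuda')
MODEL_CONFIG = os.getenv('MODEL_CONFIG', 'configs/sam2.1/sam2.1_hiera_l.yaml')
MODEL_CHECKPOINT = os.getenv('MODEL_CHECKPOINT', 'sam2.1_hiera_large.pt')
```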
135 |
136 | ## Customization
137 |
138 | The ML backend can be customized by adding your own models and logic inside the `segment_anything_2_image` directory, starting with `model.py`.
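
A custom backend only needs to subclass `LabelStudioMLBase` and implement `predict`, the same interface `NewModel` uses; a minimal skeleton:

```python
from typing import Dict, List, Optional

from label_studio_ml.model import LabelStudioMLBase
from label_studio_ml.response import ModelResponse


class MyModel(LabelStudioMLBase):
    """Bare-bones custom backend; replace the body with your own logic."""

    def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> ModelResponse:
        # An empty list means "no predictions" for the given tasks.
        return ModelResponse(predictions=[])
```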
139 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import os
4 | import sys
5 | import pathlib
6 | import cv2
7 | from typing import List, Dict, Optional
8 | from uuid import uuid4
9 | from label_studio_ml.model import LabelStudioMLBase
10 | from label_studio_ml.response import ModelResponse
11 | from label_studio_sdk._extensions.label_studio_tools.core.utils.io import get_local_path
12 | from PIL import Image
13 |
14 | ROOT_DIR = os.getcwd()
15 | sys.path.insert(0, ROOT_DIR)
16 | from sam2.build_sam import build_sam2
17 | from sam2.sam2_image_predictor import SAM2ImagePredictor
18 |
19 | # Environment variable configuration
20 | DEVICE = os.getenv('DEVICE', 'cuda')
21 | MODEL_CONFIG = os.getenv('MODEL_CONFIG', 'configs/sam2.1/sam2.1_hiera_l.yaml')
22 | MODEL_CHECKPOINT = os.getenv('MODEL_CHECKPOINT', 'sam2.1_hiera_large.pt')
23 |
24 | # CUDA settings
25 | if DEVICE == 'cuda':
26 | torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
27 | if torch.cuda.get_device_properties(0).major >= 8:
28 | torch.backends.cuda.matmul.allow_tf32 = True
29 | torch.backends.cudnn.allow_tf32 = True
30 |
31 | # Build the model checkpoint path
32 | sam2_checkpoint = str(os.path.join(ROOT_DIR, "checkpoints", MODEL_CHECKPOINT))
33 |
34 | # Initialize the SAM2 model globally; each gunicorn worker loads its own copy
35 | sam2_model = build_sam2(MODEL_CONFIG, sam2_checkpoint, device=DEVICE)
36 |
37 | class NewModel(LabelStudioMLBase):
38 | """Custom ML Backend model for generating rectanglelabels"""
39 |
40 | def __init__(self, **kwargs):
41 | super(NewModel, self).__init__(**kwargs)
42 |         # Define value, from_name and to_name
43 |         self.value = kwargs.get('image_key', 'image')  # key of the image URL in task data, defaults to 'image'
44 |         self.from_name = kwargs.get('from_name', 'label')  # control tag name
45 |         self.to_name = kwargs.get('to_name', 'image')  # object tag name
46 |         # Initialize the SAM2 predictor
47 | self.predictor = SAM2ImagePredictor(sam2_model)
48 |
49 | def set_image(self, image_url, task_id):
50 | """加载并设置图像到 SAM2 预测器"""
51 | image_path = get_local_path(image_url, task_id=task_id)
52 | image = Image.open(image_path)
53 | image = np.array(image.convert("RGB"))
54 | self.predictor.set_image(image)
55 |
56 | def _sam_predict(self, img_url, point_coords=None, point_labels=None, input_box=None, task=None):
57 | """使用 SAM2 进行预测"""
58 | self.set_image(img_url, task.get('id'))
59 | point_coords = np.array(point_coords, dtype=np.float32) if point_coords else None
60 | point_labels = np.array(point_labels, dtype=np.float32) if point_labels else None
61 | input_box = np.array(input_box, dtype=np.float32) if input_box else None
62 |
63 | masks, scores, _ = self.predictor.predict(
64 | point_coords=point_coords,
65 | point_labels=point_labels,
66 | box=input_box,
67 |             multimask_output=False  # return a single mask only
68 | )
69 | return {
70 | 'masks': masks,
71 | 'scores': scores
72 | }
73 |
74 | def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> ModelResponse:
75 | """基于用户交互预测 rectanglelabels"""
76 | results = []
77 |
78 |         # Get from_name, to_name and value from the labeling config
79 | from_name, to_name, value = self.get_first_tag_occurence('RectangleLabels', 'Image')
80 |
81 |         # Without context (no user interaction yet), return empty predictions
82 | if not context or not context.get('result'):
83 | return ModelResponse(predictions=[])
84 |
85 |         # Get image dimensions
86 | image_width = context['result'][0]['original_width']
87 | image_height = context['result'][0]['original_height']
88 |
89 |         # Collect user interaction info
90 | point_coords = []
91 | point_labels = []
92 | input_box = None
93 | selected_label = None
94 | for ctx in context['result']:
95 | x = ctx['value']['x'] * image_width / 100
96 | y = ctx['value']['y'] * image_height / 100
97 | ctx_type = ctx['type']
98 | selected_label = ctx['value'][ctx_type][0] if ctx_type in ctx['value'] else 'object'
99 | if ctx_type == 'keypointlabels':
100 | point_labels.append(int(ctx.get('is_positive', 0)))
101 | point_coords.append([int(x), int(y)])
102 | elif ctx_type == 'rectanglelabels':
103 | box_width = ctx['value']['width'] * image_width / 100
104 | box_height = ctx['value']['height'] * image_height / 100
105 | input_box = [int(x), int(y), int(box_width + x), int(box_height + y)]
106 |
107 | print(f'Point coords: {point_coords}, Point labels: {point_labels}, Input box: {input_box}')
108 |
109 |         # Get the image URL
110 | img_url = tasks[0]['data'][value]
111 |
112 |         # Run SAM2 prediction
113 | predictor_results = self._sam_predict(
114 | img_url=img_url,
115 | point_coords=point_coords or None,
116 | point_labels=point_labels or None,
117 | input_box=input_box,
118 | task=tasks[0]
119 | )
120 |
121 |         # Convert predictions into rectanglelabels
122 | masks = predictor_results['masks']
123 | scores = predictor_results['scores']
124 |
125 | for mask, score in zip(masks, scores):
126 |             # Convert the mask to a binary image
127 | mask = mask.astype(np.uint8) * 255
128 |
129 |             # Find the mask contours
130 | contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
131 | if not contours:
132 | continue
133 |
134 |             # Take the largest contour
135 | contour = max(contours, key=cv2.contourArea)
136 |
137 |             # Compute the bounding box
138 | x, y, w, h = cv2.boundingRect(contour)
139 |
140 |             # Convert to Label Studio percent coordinates
141 | x_percent = (x / image_width) * 100
142 | y_percent = (y / image_height) * 100
143 | width_percent = (w / image_width) * 100
144 | height_percent = (h / image_height) * 100
145 |
146 |             # Build the rectanglelabels result
147 | result = {
148 | 'id': str(uuid4())[:4],
149 | 'from_name': from_name,
150 | 'to_name': to_name,
151 | 'type': 'rectanglelabels',
152 | 'value': {
153 | 'x': x_percent,
154 | 'y': y_percent,
155 | 'width': width_percent,
156 | 'height': height_percent,
157 | 'rectanglelabels': [selected_label]
158 | },
159 | 'score': float(score),
160 | 'original_width': image_width,
161 | 'original_height': image_height,
162 | 'image_rotation': 0
163 | }
164 | results.append(result)
165 |
166 | return ModelResponse(predictions=[{
167 | 'result': results,
168 |             'model_version': '1.0',  # hard-coded model_version
169 | 'score': sum([r['score'] for r in results]) / max(len(results), 1)
170 | }])
--------------------------------------------------------------------------------