├── .gitignore
├── Inference.py
├── LICENSE
├── MORE_USAGES.md
├── README.md
├── app_gradio.py
├── assets
│   ├── Overview.png
│   ├── anomaly.png
│   ├── building.png
│   ├── dog_clip.png
│   ├── eightpic.pdf
│   ├── eightpic.png
│   ├── head_fig.png
│   ├── hf_everything_mode.png
│   ├── hf_points_mode.png
│   ├── logo.png
│   ├── more_usages
│   │   ├── box_prompt.png
│   │   ├── draw_edge.png
│   │   ├── everything_mode.png
│   │   ├── everything_mode_without_retina.png
│   │   ├── more_points.png
│   │   └── text_prompt_cat.png
│   ├── replicate-1.png
│   ├── replicate-2.png
│   ├── replicate-3.png
│   └── salient.png
├── cog.yaml
├── examples
│   ├── dogs.jpg
│   ├── sa_10039.jpg
│   ├── sa_11025.jpg
│   ├── sa_1309.jpg
│   ├── sa_192.jpg
│   ├── sa_414.jpg
│   ├── sa_561.jpg
│   ├── sa_862.jpg
│   └── sa_8776.jpg
├── fastsam
│   ├── __init__.py
│   ├── decoder.py
│   ├── model.py
│   ├── predict.py
│   ├── prompt.py
│   └── utils.py
├── images
│   ├── cat.jpg
│   └── dogs.jpg
├── output
│   ├── cat.jpg
│   └── dogs.jpg
├── predict.py
├── requirements.txt
├── segpredict.py
├── setup.py
├── ultralytics
│   ├── .pre-commit-config.yaml
│   ├── __init__.py
│   ├── assets
│   │   ├── bus.jpg
│   │   └── zidane.jpg
│   ├── datasets
│   │   ├── Argoverse.yaml
│   │   ├── GlobalWheat2020.yaml
│   │   ├── ImageNet.yaml
│   │   ├── Objects365.yaml
│   │   ├── SKU-110K.yaml
│   │   ├── VOC.yaml
│   │   ├── VisDrone.yaml
│   │   ├── coco-pose.yaml
│   │   ├── coco.yaml
│   │   ├── coco128-seg.yaml
│   │   ├── coco128.yaml
│   │   ├── coco8-pose.yaml
│   │   ├── coco8-seg.yaml
│   │   ├── coco8.yaml
│   │   └── xView.yaml
│   ├── hub
│   │   ├── __init__.py
│   │   ├── auth.py
│   │   ├── session.py
│   │   └── utils.py
│   ├── models
│   │   ├── README.md
│   │   ├── rt-detr
│   │   │   ├── rtdetr-l.yaml
│   │   │   └── rtdetr-x.yaml
│   │   ├── v3
│   │   │   ├── yolov3-spp.yaml
│   │   │   ├── yolov3-tiny.yaml
│   │   │   └── yolov3.yaml
│   │   ├── v5
│   │   │   ├── yolov5-p6.yaml
│   │   │   └── yolov5.yaml
│   │   ├── v6
│   │   │   └── yolov6.yaml
│   │   └── v8
│   │       ├── yolov8-cls.yaml
│   │       ├── yolov8-p2.yaml
│   │       ├── yolov8-p6.yaml
│   │       ├── yolov8-pose-p6.yaml
│   │       ├── yolov8-pose.yaml
│   │       ├── yolov8-rtdetr.yaml
│   │       ├── yolov8-seg.yaml
│   │       └── yolov8.yaml
│   ├── nn
│   │   ├── __init__.py
│   │   ├── autobackend.py
│   │   ├── autoshape.py
│   │   ├── modules
│   │   │   ├── __init__.py
│   │   │   ├── block.py
│   │   │   ├── conv.py
│   │   │   ├── head.py
│   │   │   ├── transformer.py
│   │   │   └── utils.py
│   │   └── tasks.py
│   ├── tracker
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── cfg
│   │   │   ├── botsort.yaml
│   │   │   └── bytetrack.yaml
│   │   ├── track.py
│   │   ├── trackers
│   │   │   ├── __init__.py
│   │   │   ├── basetrack.py
│   │   │   ├── bot_sort.py
│   │   │   └── byte_tracker.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── gmc.py
│   │       ├── kalman_filter.py
│   │       └── matching.py
│   ├── vit
│   │   ├── __init__.py
│   │   ├── rtdetr
│   │   │   ├── __init__.py
│   │   │   ├── model.py
│   │   │   ├── predict.py
│   │   │   ├── train.py
│   │   │   └── val.py
│   │   ├── sam
│   │   │   ├── __init__.py
│   │   │   ├── amg.py
│   │   │   ├── autosize.py
│   │   │   ├── build.py
│   │   │   ├── model.py
│   │   │   ├── modules
│   │   │   │   ├── __init__.py
│   │   │   │   ├── decoders.py
│   │   │   │   ├── encoders.py
│   │   │   │   ├── mask_generator.py
│   │   │   │   ├── prompt_predictor.py
│   │   │   │   ├── sam.py
│   │   │   │   └── transformer.py
│   │   │   └── predict.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── loss.py
│   │       └── ops.py
│   └── yolo
│       ├── __init__.py
│       ├── cfg
│       │   ├── __init__.py
│       │   └── default.yaml
│       ├── data
│       │   ├── __init__.py
│       │   ├── annotator.py
│       │   ├── augment.py
│       │   ├── base.py
│       │   ├── build.py
│       │   ├── converter.py
│       │   ├── dataloaders
│       │   │   ├── __init__.py
│       │   │   ├── stream_loaders.py
│       │   │   ├── v5augmentations.py
│       │   │   └── v5loader.py
│       │   ├── dataset.py
│       │   ├── dataset_wrappers.py
│       │   ├── scripts
│       │   │   ├── download_weights.sh
│       │   │   ├── get_coco.sh
│       │   │   ├── get_coco128.sh
│       │   │   └── get_imagenet.sh
│       │   └── utils.py
│       ├── engine
│       │   ├── __init__.py
│       │   ├── exporter.py
│       │   ├── model.py
│       │   ├── predictor.py
│       │   ├── results.py
│       │   ├── trainer.py
│       │   └── validator.py
│       ├── nas
│       │   ├── __init__.py
│       │   ├── model.py
│       │   ├── predict.py
│       │   └── val.py
│       ├── utils
│       │   ├── __init__.py
│       │   ├── autobatch.py
│       │   ├── benchmarks.py
│       │   ├── callbacks
│       │   │   ├── __init__.py
│       │   │   ├── base.py
│       │   │   ├── clearml.py
│       │   │   ├── comet.py
│       │   │   ├── dvc.py
│       │   │   ├── hub.py
│       │   │   ├── mlflow.py
│       │   │   ├── neptune.py
│       │   │   ├── raytune.py
│       │   │   ├── tensorboard.py
│       │   │   └── wb.py
│       │   ├── checks.py
│       │   ├── dist.py
│       │   ├── downloads.py
│       │   ├── errors.py
│       │   ├── files.py
│       │   ├── instance.py
│       │   ├── loss.py
│       │   ├── metrics.py
│       │   ├── ops.py
│       │   ├── patches.py
│       │   ├── plotting.py
│       │   ├── tal.py
│       │   ├── torch_utils.py
│       │   └── tuner.py
│       └── v8
│           ├── __init__.py
│           ├── classify
│           │   ├── __init__.py
│           │   ├── predict.py
│           │   ├── train.py
│           │   └── val.py
│           ├── detect
│           │   ├── __init__.py
│           │   ├── predict.py
│           │   ├── train.py
│           │   └── val.py
│           ├── pose
│           │   ├── __init__.py
│           │   ├── predict.py
│           │   ├── train.py
│           │   └── val.py
│           └── segment
│               ├── __init__.py
│               ├── predict.py
│               ├── train.py
│               └── val.py
└── utils
    ├── __init__.py
    ├── tools.py
    └── tools_gradio.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.pyo
3 | *.pyd
4 | .DS_Store
5 | .idea
6 | weights
7 | build/
8 | *.egg-info/
9 | gradio_cached_examples
--------------------------------------------------------------------------------
/Inference.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from fastsam import FastSAM, FastSAMPrompt
3 | import ast
4 | import torch
5 | from PIL import Image
6 | from utils.tools import convert_box_xywh_to_xyxy
7 |
8 |
9 | def parse_args():
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument(
12 | "--model_path", type=str, default="./weights/FastSAM.pt", help="model"
13 | )
14 | parser.add_argument(
15 | "--img_path", type=str, default="./images/dogs.jpg", help="path to image file"
16 | )
17 | parser.add_argument("--imgsz", type=int, default=1024, help="image size")
18 | parser.add_argument(
19 | "--iou",
20 | type=float,
21 | default=0.9,
22 | help="iou threshold for filtering the annotations",
23 | )
24 | parser.add_argument(
25 | "--text_prompt", type=str, default=None, help='use text prompt eg: "a dog"'
26 | )
27 | parser.add_argument(
28 | "--conf", type=float, default=0.4, help="object confidence threshold"
29 | )
30 | parser.add_argument(
31 | "--output", type=str, default="./output/", help="image save path"
32 | )
33 | parser.add_argument(
34 | "--randomcolor", type=bool, default=True, help="mask random color"
35 | )
36 | parser.add_argument(
37 | "--point_prompt", type=str, default="[[0,0]]", help="[[x1,y1],[x2,y2]]"
38 | )
39 | parser.add_argument(
40 | "--point_label",
41 | type=str,
42 | default="[0]",
43 | help="[1,0] 0:background, 1:foreground",
44 | )
45 | parser.add_argument("--box_prompt", type=str, default="[[0,0,0,0]]", help="[[x,y,w,h],[x2,y2,w2,h2]] support multiple boxes")
46 | parser.add_argument(
47 | "--better_quality",
48 | type=str,
49 | default=False,
50 | help="better quality using morphologyEx",
51 | )
52 | device = torch.device(
53 | "cuda"
54 | if torch.cuda.is_available()
55 | else "mps"
56 | if torch.backends.mps.is_available()
57 | else "cpu"
58 | )
59 | parser.add_argument(
60 | "--device", type=str, default=device, help="cuda:[0,1,2,3,4] or cpu"
61 | )
62 | parser.add_argument(
63 | "--retina",
64 | type=bool,
65 | default=True,
66 | help="draw high-resolution segmentation masks",
67 | )
68 | parser.add_argument(
69 | "--withContours", type=bool, default=False, help="draw the edges of the masks"
70 | )
71 | return parser.parse_args()
72 |
73 |
74 | def main(args):
75 | # load model
76 | model = FastSAM(args.model_path)
77 | args.point_prompt = ast.literal_eval(args.point_prompt)
78 | args.box_prompt = convert_box_xywh_to_xyxy(ast.literal_eval(args.box_prompt))
79 | args.point_label = ast.literal_eval(args.point_label)
80 | input = Image.open(args.img_path)
81 | input = input.convert("RGB")
82 | everything_results = model(
83 | input,
84 | device=args.device,
85 | retina_masks=args.retina,
86 | imgsz=args.imgsz,
87 | conf=args.conf,
88 | iou=args.iou
89 | )
90 | bboxes = None
91 | points = None
92 | point_label = None
93 | prompt_process = FastSAMPrompt(input, everything_results, device=args.device)
94 | if args.box_prompt[0][2] != 0 and args.box_prompt[0][3] != 0:
95 | ann = prompt_process.box_prompt(bboxes=args.box_prompt)
96 | bboxes = args.box_prompt
97 | elif args.text_prompt != None:
98 | ann = prompt_process.text_prompt(text=args.text_prompt)
99 | elif args.point_prompt[0] != [0, 0]:
100 | ann = prompt_process.point_prompt(
101 | points=args.point_prompt, pointlabel=args.point_label
102 | )
103 | points = args.point_prompt
104 | point_label = args.point_label
105 | else:
106 | ann = prompt_process.everything_prompt()
107 | prompt_process.plot(
108 | annotations=ann,
109 | output_path=args.output+args.img_path.split("/")[-1],
110 | bboxes = bboxes,
111 | points = points,
112 | point_label = point_label,
113 | withContours=args.withContours,
114 | better_quality=args.better_quality,
115 | )
116 |
117 |
118 |
119 |
120 | if __name__ == "__main__":
121 | args = parse_args()
122 | main(args)
123 |
--------------------------------------------------------------------------------
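A note on the prompt arguments above: `--point_prompt`, `--point_label`, and `--box_prompt` are passed as Python-literal strings, parsed with `ast.literal_eval`, and boxes are converted from `[x, y, w, h]` to `[x1, y1, x2, y2]` before being handed to `FastSAMPrompt`. A minimal sketch of that parsing and conversion (values are illustrative; it assumes `utils.tools.convert_box_xywh_to_xyxy` applies the usual `x + w`, `y + h` mapping, as the single-box version in `fastsam/utils.py` does):

```python
import ast

# CLI strings as Inference.py receives them (illustrative values)
box_prompt = ast.literal_eval("[[570,200,230,400]]")     # [[x, y, w, h]]
point_prompt = ast.literal_eval("[[520,360],[620,300]]")  # [[x1, y1], [x2, y2]]

# xywh -> xyxy, one box per entry
boxes_xyxy = [[x, y, x + w, y + h] for x, y, w, h in box_prompt]
print(boxes_xyxy)    # [[570, 200, 800, 600]]
print(point_prompt)  # [[520, 360], [620, 300]]
```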
/MORE_USAGES.md:
--------------------------------------------------------------------------------
1 | # MORE_USAGES
2 |
3 |
4 |
5 | ### Everything mode
6 | Use `--imgsz` to change the input size.
7 |
8 | ```shell
9 | python Inference.py --model_path ./weights/FastSAM.pt \
10 | --img_path ./images/dogs.jpg \
11 | --imgsz 720
12 | ```
13 | 
14 |
15 |
16 |
17 | ### Use more points
18 | Use `--point_prompt` with several points and `--point_label` to mark each one as foreground (1) or background (0).
19 | ```shell
20 | python Inference.py --model_path ./weights/FastSAM.pt \
21 | --img_path ./images/dogs.jpg \
22 | --point_prompt "[[520,360],[620,300],[520,300],[620,360]]" \
23 | --point_label "[1,0,1,0]"
24 | ```
25 | 
26 | ### Draw mask edge
27 | Use `--withContours True` to draw the edges of the masks.
28 |
29 | When `--better_quality True` is set, the edges will be smoother.
30 |
31 | ```shell
32 | python Inference.py --model_path ./weights/FastSAM.pt \
33 | --img_path ./images/dogs.jpg \
34 | --point_prompt "[[620,360]]" \
35 | --point_label "[1]" \
36 | --withContours True \
37 | --better_quality True
38 | ```
39 |
40 | 
41 | ### Use box prompt
42 | Use `--box_prompt "[[x,y,w,h]]"` to specify one or more bounding boxes for the foreground objects.
43 | ```shell
44 | python Inference.py --model_path ./weights/FastSAM.pt \
45 | --img_path ./images/dogs.jpg \
46 | --box_prompt "[[570,200,230,400]]"
47 | ```
48 | 
49 |
50 | ### Use text prompt
51 | Use `--text_prompt "text"` to segment objects matching a text prompt.
52 | ```shell
53 | python Inference.py --model_path ./weights/FastSAM.pt \
54 | --img_path ./images/cat.jpg \
55 | --text_prompt "cat" \
56 | --better_quality True \
57 | --withContours True
58 | ```
59 | 
60 |
--------------------------------------------------------------------------------
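The same prompt modes are also available from Python. A minimal sketch following `segpredict.py` (device, paths, and coordinates are illustrative):

```python
from fastsam import FastSAM, FastSAMPrompt

model = FastSAM('./weights/FastSAM.pt')
everything_results = model(
    './images/dogs.jpg', device='cpu', retina_masks=True, imgsz=1024, conf=0.4, iou=0.9
)
prompt_process = FastSAMPrompt('./images/dogs.jpg', everything_results, device='cpu')

ann = prompt_process.point_prompt(points=[[620, 360]], pointlabel=[1])  # point prompt
# ann = prompt_process.box_prompt(bbox=[570, 200, 800, 600])            # box prompt (x1, y1, x2, y2)
# ann = prompt_process.text_prompt(text='a photo of a dog')             # text prompt
prompt_process.plot(annotations=ann, output='./output/', withContours=True, better_quality=True)
```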
/assets/Overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/assets/Overview.png
--------------------------------------------------------------------------------
/assets/anomaly.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/assets/anomaly.png
--------------------------------------------------------------------------------
/assets/building.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/assets/building.png
--------------------------------------------------------------------------------
/assets/dog_clip.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/assets/dog_clip.png
--------------------------------------------------------------------------------
/assets/eightpic.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/assets/eightpic.pdf
--------------------------------------------------------------------------------
/assets/eightpic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/assets/eightpic.png
--------------------------------------------------------------------------------
/assets/head_fig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/assets/head_fig.png
--------------------------------------------------------------------------------
/assets/hf_everything_mode.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/assets/hf_everything_mode.png
--------------------------------------------------------------------------------
/assets/hf_points_mode.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/assets/hf_points_mode.png
--------------------------------------------------------------------------------
/assets/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/assets/logo.png
--------------------------------------------------------------------------------
/assets/more_usages/box_prompt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/assets/more_usages/box_prompt.png
--------------------------------------------------------------------------------
/assets/more_usages/draw_edge.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/assets/more_usages/draw_edge.png
--------------------------------------------------------------------------------
/assets/more_usages/everything_mode.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/assets/more_usages/everything_mode.png
--------------------------------------------------------------------------------
/assets/more_usages/everything_mode_without_retina.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/assets/more_usages/everything_mode_without_retina.png
--------------------------------------------------------------------------------
/assets/more_usages/more_points.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/assets/more_usages/more_points.png
--------------------------------------------------------------------------------
/assets/more_usages/text_prompt_cat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/assets/more_usages/text_prompt_cat.png
--------------------------------------------------------------------------------
/assets/replicate-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/assets/replicate-1.png
--------------------------------------------------------------------------------
/assets/replicate-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/assets/replicate-2.png
--------------------------------------------------------------------------------
/assets/replicate-3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/assets/replicate-3.png
--------------------------------------------------------------------------------
/assets/salient.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/assets/salient.png
--------------------------------------------------------------------------------
/cog.yaml:
--------------------------------------------------------------------------------
1 | # Configuration for Cog ⚙️
2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
3 | # Thanks to chenxwh.
4 |
5 | build:
6 | # set to true if your model requires a GPU
7 | gpu: true
8 | cuda: "11.7"
9 | system_packages:
10 | - "libgl1-mesa-glx"
11 | - "libglib2.0-0"
12 | python_version: "3.8"
13 | python_packages:
14 | - "matplotlib==3.7.1"
15 | - "opencv-python==4.7.0.72"
16 | - "Pillow==9.5.0"
17 | - "PyYAML==6.0"
18 | - "requests==2.31.0"
19 | - "scipy==1.10.1"
20 | - "torch==2.0.1"
21 | - "torchvision==0.15.2"
22 | - "tqdm==4.65.0"
23 | - "pandas==2.0.2"
24 | - "seaborn==0.12.0"
25 | - "ultralytics==8.0.121"
26 | - git+https://github.com/openai/CLIP.git
27 | predict: "predict.py:Predictor"
28 |
--------------------------------------------------------------------------------
/examples/dogs.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/examples/dogs.jpg
--------------------------------------------------------------------------------
/examples/sa_10039.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/examples/sa_10039.jpg
--------------------------------------------------------------------------------
/examples/sa_11025.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/examples/sa_11025.jpg
--------------------------------------------------------------------------------
/examples/sa_1309.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/examples/sa_1309.jpg
--------------------------------------------------------------------------------
/examples/sa_192.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/examples/sa_192.jpg
--------------------------------------------------------------------------------
/examples/sa_414.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/examples/sa_414.jpg
--------------------------------------------------------------------------------
/examples/sa_561.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/examples/sa_561.jpg
--------------------------------------------------------------------------------
/examples/sa_862.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/examples/sa_862.jpg
--------------------------------------------------------------------------------
/examples/sa_8776.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/examples/sa_8776.jpg
--------------------------------------------------------------------------------
/fastsam/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from .model import FastSAM
4 | from .predict import FastSAMPredictor
5 | from .prompt import FastSAMPrompt
6 | # from .val import FastSAMValidator
7 | from .decoder import FastSAMDecoder
8 |
9 | __all__ = 'FastSAMPredictor', 'FastSAM', 'FastSAMPrompt', 'FastSAMDecoder'
10 |
--------------------------------------------------------------------------------
/fastsam/decoder.py:
--------------------------------------------------------------------------------
1 | from .model import FastSAM
2 | import numpy as np
3 | from PIL import Image
4 | from typing import Optional, List, Tuple, Union
5 |
6 |
7 | class FastSAMDecoder:
8 | def __init__(
9 | self,
10 | model: FastSAM,
11 | device: str='cpu',
12 | conf: float=0.4,
13 | iou: float=0.9,
14 | imgsz: int=1024,
15 | retina_masks: bool=True,
16 | ):
17 | self.model = model
18 | self.device = device
19 | self.retina_masks = retina_masks
20 | self.imgsz = imgsz
21 | self.conf = conf
22 | self.iou = iou
23 | self.image = None
24 | self.image_embedding = None
25 |
26 | def run_encoder(self, image):
27 | if isinstance(image,str):
28 | image = np.array(Image.open(image))
29 | self.image = image
30 | image_embedding = self.model(
31 | self.image,
32 | device=self.device,
33 | retina_masks=self.retina_masks,
34 | imgsz=self.imgsz,
35 | conf=self.conf,
36 | iou=self.iou
37 | )
38 | return image_embedding[0].numpy()
39 |
40 | def run_decoder(
41 | self,
42 | image_embedding,
43 | point_prompt: Optional[np.ndarray]=None,
44 | point_label: Optional[np.ndarray]=None,
45 | box_prompt: Optional[np.ndarray]=None,
46 | text_prompt: Optional[str]=None,
47 | )->np.ndarray:
48 | self.image_embedding = image_embedding
49 | if point_prompt is not None:
50 | ann = self.point_prompt(points=point_prompt, pointlabel=point_label)
51 | return ann
52 | elif box_prompt is not None:
53 | ann = self.box_prompt(bbox=box_prompt)
54 | return ann
55 | elif text_prompt is not None:
56 | ann = self.text_prompt(text=text_prompt)
57 | return ann
58 | else:
59 | return None
60 |
61 | def box_prompt(self, bbox):
62 | assert (bbox[2] != 0 and bbox[3] != 0)
63 | masks = self.image_embedding.masks.data
64 | target_height = self.image.shape[0]
65 | target_width = self.image.shape[1]
66 | h = masks.shape[1]
67 | w = masks.shape[2]
68 | if h != target_height or w != target_width:
69 | bbox = [
70 | int(bbox[0] * w / target_width),
71 | int(bbox[1] * h / target_height),
72 | int(bbox[2] * w / target_width),
73 | int(bbox[3] * h / target_height), ]
74 | bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0
75 | bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0
76 | bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w
77 | bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h
78 |
79 | # IoUs = torch.zeros(len(masks), dtype=torch.float32)
80 | bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
81 |
82 | masks_area = np.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], axis=(1, 2))
83 | orig_masks_area = np.sum(masks, axis=(1, 2))
84 |
85 | union = bbox_area + orig_masks_area - masks_area
86 | IoUs = masks_area / union
87 | max_iou_index = np.argmax(IoUs)
88 |
89 | return np.array([masks[max_iou_index].cpu().numpy()])
90 |
91 | def point_prompt(self, points, pointlabel): # numpy
92 |
93 | masks = self._format_results(self.image_embedding[0], 0)
94 | target_height = self.image.shape[0]
95 | target_width = self.image.shape[1]
96 | h = masks[0]['segmentation'].shape[0]
97 | w = masks[0]['segmentation'].shape[1]
98 | if h != target_height or w != target_width:
99 | points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
100 | onemask = np.zeros((h, w))
101 | masks = sorted(masks, key=lambda x: x['area'], reverse=True)
102 | for i, annotation in enumerate(masks):
103 | if type(annotation) == dict:
104 | mask = annotation['segmentation']
105 | else:
106 | mask = annotation
107 | for i, point in enumerate(points):
108 | if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
109 | onemask[mask] = 1
110 | if mask[point[1], point[0]] == 1 and pointlabel[i] == 0:
111 | onemask[mask] = 0
112 | onemask = onemask >= 1
113 | return np.array([onemask])
114 |
115 | def _format_results(self, result, filter=0):
116 | annotations = []
117 | n = len(result.masks.data)
118 | for i in range(n):
119 | annotation = {}
120 | mask = result.masks.data[i] == 1.0
121 |
122 | if np.sum(mask) < filter:
123 | continue
124 | annotation['id'] = i
125 | annotation['segmentation'] = mask
126 | annotation['bbox'] = result.boxes.data[i]
127 | annotation['score'] = result.boxes.conf[i]
128 | annotation['area'] = annotation['segmentation'].sum()
129 | annotations.append(annotation)
130 | return annotations
131 |
--------------------------------------------------------------------------------
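`FastSAMDecoder` splits the pipeline into an encode step (run the model once per image) and a decode step (apply a point, box, or text prompt to the cached result). A minimal usage sketch of its API (weight and image paths are illustrative):

```python
from fastsam import FastSAM, FastSAMDecoder

model = FastSAM('./weights/FastSAM.pt')
decoder = FastSAMDecoder(model, device='cpu', conf=0.4, iou=0.9, imgsz=1024, retina_masks=True)

embedding = decoder.run_encoder('./images/dogs.jpg')  # run the model once per image
masks = decoder.run_decoder(embedding,                # then prompt the cached result
                            point_prompt=[[620, 360]], point_label=[1])
```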
/fastsam/model.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | """
3 | FastSAM model interface.
4 |
5 | Usage - Predict:
6 | from fastsam import FastSAM
7 |
8 | model = FastSAM('last.pt')
9 | results = model.predict('ultralytics/assets/bus.jpg')
10 | """
11 |
12 | from ultralytics.yolo.cfg import get_cfg
13 | from ultralytics.yolo.engine.exporter import Exporter
14 | from ultralytics.yolo.engine.model import YOLO
15 | from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, ROOT, is_git_dir
16 | from ultralytics.yolo.utils.checks import check_imgsz
17 |
18 | from ultralytics.yolo.utils.torch_utils import model_info, smart_inference_mode
19 | from .predict import FastSAMPredictor
20 |
21 |
22 | class FastSAM(YOLO):
23 |
24 | @smart_inference_mode()
25 | def predict(self, source=None, stream=False, **kwargs):
26 | """
27 | Perform prediction using the YOLO model.
28 |
29 | Args:
30 | source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
31 | Accepts all source types accepted by the YOLO model.
32 | stream (bool): Whether to stream the predictions or not. Defaults to False.
33 | **kwargs : Additional keyword arguments passed to the predictor.
34 | Check the 'configuration' section in the documentation for all available options.
35 |
36 | Returns:
37 | (List[ultralytics.yolo.engine.results.Results]): The prediction results.
38 | """
39 | if source is None:
40 | source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
41 | LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
42 | overrides = self.overrides.copy()
43 | overrides['conf'] = 0.25
44 | overrides.update(kwargs) # prefer kwargs
45 | overrides['mode'] = kwargs.get('mode', 'predict')
46 | assert overrides['mode'] in ['track', 'predict']
47 | overrides['save'] = kwargs.get('save', False) # do not save by default if called in Python
48 | self.predictor = FastSAMPredictor(overrides=overrides)
49 | self.predictor.setup_model(model=self.model, verbose=False)
50 | try:
51 | return self.predictor(source, stream=stream)
52 | except Exception as e:
53 | return None
54 |
55 | def train(self, **kwargs):
56 | """Function trains models but raises an error as FastSAM models do not support training."""
57 | raise NotImplementedError("FastSAM models do not currently support training.")
58 |
59 | def val(self, **kwargs):
60 | """Run validation given dataset."""
61 | overrides = dict(task='segment', mode='val')
62 | overrides.update(kwargs) # prefer kwargs
63 | args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
64 | args.imgsz = check_imgsz(args.imgsz, max_dim=1)
65 | validator = FastSAM(args=args)
66 | validator(model=self.model)
67 | self.metrics = validator.metrics
68 | return validator.metrics
69 |
70 | @smart_inference_mode()
71 | def export(self, **kwargs):
72 | """
73 | Export model.
74 |
75 | Args:
76 | **kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
77 | """
78 | overrides = dict(task='detect')
79 | overrides.update(kwargs)
80 | overrides['mode'] = 'export'
81 | args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
82 | args.task = self.task
83 | if args.imgsz == DEFAULT_CFG.imgsz:
84 | args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
85 | if args.batch == DEFAULT_CFG.batch:
86 | args.batch = 1 # default to 1 if not modified
87 | return Exporter(overrides=args)(model=self.model)
88 |
89 | def info(self, detailed=False, verbose=True):
90 | """
91 | Logs model info.
92 |
93 | Args:
94 | detailed (bool): Show detailed information about model.
95 | verbose (bool): Controls verbosity.
96 | """
97 | return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)
98 |
99 | def __call__(self, source=None, stream=False, **kwargs):
100 | """Calls the 'predict' function with given arguments to perform object detection."""
101 | return self.predict(source, stream, **kwargs)
102 |
103 | def __getattr__(self, attr):
104 | """Raises error if object has no requested attribute."""
105 | name = self.__class__.__name__
106 | raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
107 |
--------------------------------------------------------------------------------
/fastsam/predict.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from ultralytics.yolo.engine.results import Results
4 | from ultralytics.yolo.utils import DEFAULT_CFG, ops
5 | from ultralytics.yolo.v8.detect.predict import DetectionPredictor
6 | from .utils import bbox_iou
7 |
8 | class FastSAMPredictor(DetectionPredictor):
9 |
10 | def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
11 | super().__init__(cfg, overrides, _callbacks)
12 | self.args.task = 'segment'
13 |
14 | def postprocess(self, preds, img, orig_imgs):
15 | """TODO: filter by classes."""
16 | p = ops.non_max_suppression(preds[0],
17 | self.args.conf,
18 | self.args.iou,
19 | agnostic=self.args.agnostic_nms,
20 | max_det=self.args.max_det,
21 | nc=len(self.model.names),
22 | classes=self.args.classes)
23 |
24 | results = []
25 | if len(p) == 0 or len(p[0]) == 0:
26 | print("No object detected.")
27 | return results
28 |
29 | full_box = torch.zeros_like(p[0][0])
30 | full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0
31 | full_box = full_box.view(1, -1)
32 | critical_iou_index = bbox_iou(full_box[0][:4], p[0][:, :4], iou_thres=0.9, image_shape=img.shape[2:])
33 | if critical_iou_index.numel() != 0:
34 | full_box[0][4] = p[0][critical_iou_index][:,4]
35 | full_box[0][6:] = p[0][critical_iou_index][:,6:]
36 | p[0][critical_iou_index] = full_box
37 |
38 | proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported
39 | for i, pred in enumerate(p):
40 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
41 | path = self.batch[0]
42 | img_path = path[i] if isinstance(path, list) else path
43 | if not len(pred): # save empty boxes
44 | results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6]))
45 | continue
46 | if self.args.retina_masks:
47 | if not isinstance(orig_imgs, torch.Tensor):
48 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
49 | masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC
50 | else:
51 | masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC
52 | if not isinstance(orig_imgs, torch.Tensor):
53 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
54 | results.append(
55 | Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
56 | return results
57 |
--------------------------------------------------------------------------------
/fastsam/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from PIL import Image
4 |
5 |
6 | def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
7 | '''Adjust bounding boxes to stick to image border if they are within a certain threshold.
8 | Args:
9 | boxes: (n, 4)
10 | image_shape: (height, width)
11 | threshold: pixel threshold
12 | Returns:
13 | adjusted_boxes: adjusted bounding boxes
14 | '''
15 |
16 | # Image dimensions
17 | h, w = image_shape
18 |
19 | # Adjust boxes
20 | boxes[:, 0] = torch.where(boxes[:, 0] < threshold, torch.tensor(
21 | 0, dtype=torch.float, device=boxes.device), boxes[:, 0]) # x1
22 | boxes[:, 1] = torch.where(boxes[:, 1] < threshold, torch.tensor(
23 | 0, dtype=torch.float, device=boxes.device), boxes[:, 1]) # y1
24 | boxes[:, 2] = torch.where(boxes[:, 2] > w - threshold, torch.tensor(
25 | w, dtype=torch.float, device=boxes.device), boxes[:, 2]) # x2
26 | boxes[:, 3] = torch.where(boxes[:, 3] > h - threshold, torch.tensor(
27 | h, dtype=torch.float, device=boxes.device), boxes[:, 3]) # y2
28 |
29 | return boxes
30 |
31 |
32 |
33 | def convert_box_xywh_to_xyxy(box):
34 | x1 = box[0]
35 | y1 = box[1]
36 | x2 = box[0] + box[2]
37 | y2 = box[1] + box[3]
38 | return [x1, y1, x2, y2]
39 |
40 |
41 | def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=False):
42 | '''Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes.
43 | Args:
44 | box1: (4, )
45 | boxes: (n, 4)
46 | Returns:
47 | high_iou_indices: Indices of boxes with IoU > thres
48 | '''
49 | boxes = adjust_bboxes_to_image_border(boxes, image_shape)
50 | # obtain coordinates for intersections
51 | x1 = torch.max(box1[0], boxes[:, 0])
52 | y1 = torch.max(box1[1], boxes[:, 1])
53 | x2 = torch.min(box1[2], boxes[:, 2])
54 | y2 = torch.min(box1[3], boxes[:, 3])
55 |
56 | # compute the area of intersection
57 | intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
58 |
59 | # compute the area of both individual boxes
60 | box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
61 | box2_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
62 |
63 | # compute the area of union
64 | union = box1_area + box2_area - intersection
65 |
66 | # compute the IoU
67 | iou = intersection / union # Should be shape (n, )
68 | if raw_output:
69 | if iou.numel() == 0:
70 | return 0
71 | return iou
72 |
73 | # get indices of boxes with IoU > thres
74 | high_iou_indices = torch.nonzero(iou > iou_thres).flatten()
75 |
76 | return high_iou_indices
77 |
78 |
79 | def image_to_np_ndarray(image):
80 | if type(image) is str:
81 | return np.array(Image.open(image))
82 | elif issubclass(type(image), Image.Image):
83 | return np.array(image)
84 | elif type(image) is np.ndarray:
85 | return image
86 | return None
87 |
--------------------------------------------------------------------------------
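A small usage sketch of `bbox_iou` (illustrative values): candidate boxes within `threshold` pixels of the image border are first snapped to it by `adjust_bboxes_to_image_border`, and the indices of candidates whose IoU with `box1` exceeds `iou_thres` are returned.

```python
import torch
from fastsam.utils import bbox_iou

box1 = torch.tensor([0.0, 0.0, 640.0, 640.0])         # reference box (x1, y1, x2, y2)
boxes = torch.tensor([[5.0, 5.0, 635.0, 635.0],       # near-border box, snapped to the full image
                      [100.0, 100.0, 200.0, 200.0]])  # small interior box
idx = bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640))
print(idx)  # tensor([0]) -- only the first candidate overlaps box1 with IoU > 0.9
```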
/images/cat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/images/cat.jpg
--------------------------------------------------------------------------------
/images/dogs.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/images/dogs.jpg
--------------------------------------------------------------------------------
/output/cat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/output/cat.jpg
--------------------------------------------------------------------------------
/output/dogs.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/output/dogs.jpg
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # Base-----------------------------------
2 | matplotlib>=3.2.2
3 | opencv-python>=4.6.0
4 | Pillow>=7.1.2
5 | PyYAML>=5.3.1
6 | requests>=2.23.0
7 | scipy>=1.4.1
8 | torch>=1.7.0
9 | torchvision>=0.8.1
10 | tqdm>=4.64.0
11 |
12 | pandas>=1.1.4
13 | seaborn>=0.11.0
14 |
15 | gradio==3.35.2
16 |
17 | # Ultralytics-----------------------------------
18 | # ultralytics == 8.0.120
19 |
20 |
--------------------------------------------------------------------------------
/segpredict.py:
--------------------------------------------------------------------------------
1 | from fastsam import FastSAM, FastSAMPrompt
2 | import torch
3 |
4 | model = FastSAM('FastSAM.pt')
5 | IMAGE_PATH = './images/dogs.jpg'
6 | DEVICE = torch.device(
7 | "cuda"
8 | if torch.cuda.is_available()
9 | else "mps"
10 | if torch.backends.mps.is_available()
11 | else "cpu"
12 | )
13 | everything_results = model(
14 | IMAGE_PATH,
15 | device=DEVICE,
16 | retina_masks=True,
17 | imgsz=1024,
18 | conf=0.4,
19 | iou=0.9,
20 | )
21 | prompt_process = FastSAMPrompt(IMAGE_PATH, everything_results, device=DEVICE)
22 |
23 | # # everything prompt
24 | ann = prompt_process.everything_prompt()
25 |
26 | # # bbox prompt
27 | # # bbox default shape [0,0,0,0] -> [x1,y1,x2,y2]
28 | # bboxes default shape [[0,0,0,0]] -> [[x1,y1,x2,y2]]
29 | # ann = prompt_process.box_prompt(bbox=[200, 200, 300, 300])
30 | # ann = prompt_process.box_prompt(bboxes=[[200, 200, 300, 300], [500, 500, 600, 600]])
31 |
32 | # # text prompt
33 | # ann = prompt_process.text_prompt(text='a photo of a dog')
34 |
35 | # # point prompt
36 | # # points default [[0,0]] [[x1,y1],[x2,y2]]
37 | # # point_label default [0] [1,0] 0:background, 1:foreground
38 | # ann = prompt_process.point_prompt(points=[[620, 360]], pointlabel=[1])
39 |
40 | # point prompt
41 | # points default [[0,0]] [[x1,y1],[x2,y2]]
42 | # point_label default [0] [1,0] 0:background, 1:foreground
43 | ann = prompt_process.point_prompt(points=[[620, 360]], pointlabel=[1])
44 |
45 | prompt_process.plot(
46 | annotations=ann,
47 | output='./output/',
48 | mask_random_color=True,
49 | better_quality=True,
50 | retina=False,
51 | withContours=True,
52 | )
53 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # This source code is licensed under the license found in the
2 | # LICENSE file in the root directory of this source tree.
3 |
4 | from setuptools import find_packages, setup
5 |
6 | REQUIREMENTS = [i.strip() for i in open("requirements.txt").readlines()]
7 | REQUIREMENTS += [
8 | "CLIP @ git+https://github.com/openai/CLIP.git@a1d071733d7111c9c014f024669f959182114e33#egg=CLIP"
9 | ]
10 |
11 | setup(
12 | name="fastsam",
13 | version="0.1.1",
14 | install_requires=REQUIREMENTS,
15 | packages=["fastsam", "fastsam_tools"],
16 | package_dir= {
17 | "fastsam": "fastsam",
18 | "fastsam_tools": "utils",
19 | },
20 | url="https://github.com/CASIA-IVA-Lab/FastSAM"
21 | )
22 |
--------------------------------------------------------------------------------
/ultralytics/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # Pre-commit hooks. For more information see https://github.com/pre-commit/pre-commit-hooks/blob/main/README.md
3 |
4 | exclude: 'docs/'
5 | # Define bot property if installed via https://github.com/marketplace/pre-commit-ci
6 | ci:
7 | autofix_prs: true
8 | autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions'
9 | autoupdate_schedule: monthly
10 | # submodules: true
11 |
12 | repos:
13 | - repo: https://github.com/pre-commit/pre-commit-hooks
14 | rev: v4.4.0
15 | hooks:
16 | - id: end-of-file-fixer
17 | - id: trailing-whitespace
18 | - id: check-case-conflict
19 | # - id: check-yaml
20 | - id: check-docstring-first
21 | - id: double-quote-string-fixer
22 | - id: detect-private-key
23 |
24 | - repo: https://github.com/asottile/pyupgrade
25 | rev: v3.4.0
26 | hooks:
27 | - id: pyupgrade
28 | name: Upgrade code
29 |
30 | - repo: https://github.com/PyCQA/isort
31 | rev: 5.12.0
32 | hooks:
33 | - id: isort
34 | name: Sort imports
35 |
36 | - repo: https://github.com/google/yapf
37 | rev: v0.33.0
38 | hooks:
39 | - id: yapf
40 | name: YAPF formatting
41 |
42 | - repo: https://github.com/executablebooks/mdformat
43 | rev: 0.7.16
44 | hooks:
45 | - id: mdformat
46 | name: MD formatting
47 | additional_dependencies:
48 | - mdformat-gfm
49 | - mdformat-black
50 | # exclude: "README.md|README.zh-CN.md|CONTRIBUTING.md"
51 |
52 | - repo: https://github.com/PyCQA/flake8
53 | rev: 6.0.0
54 | hooks:
55 | - id: flake8
56 | name: PEP8
57 |
58 | - repo: https://github.com/codespell-project/codespell
59 | rev: v2.2.4
60 | hooks:
61 | - id: codespell
62 | args:
63 | - --ignore-words-list=crate,nd,strack,dota
64 |
65 | # - repo: https://github.com/asottile/yesqa
66 | # rev: v1.4.0
67 | # hooks:
68 | # - id: yesqa
69 |
70 | # - repo: https://github.com/asottile/dead
71 | # rev: v1.5.0
72 | # hooks:
73 | # - id: dead
74 |
--------------------------------------------------------------------------------
/ultralytics/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | __version__ = '8.0.120'
4 |
5 | from ultralytics.hub import start
6 | from ultralytics.vit.rtdetr import RTDETR
7 | from ultralytics.vit.sam import SAM
8 | from ultralytics.yolo.engine.model import YOLO
9 | from ultralytics.yolo.nas import NAS
10 | from ultralytics.yolo.utils.checks import check_yolo as checks
11 |
12 | __all__ = '__version__', 'YOLO', 'NAS', 'SAM', 'RTDETR', 'checks', 'start' # allow simpler import
13 |
--------------------------------------------------------------------------------
/ultralytics/assets/bus.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/ultralytics/assets/bus.jpg
--------------------------------------------------------------------------------
/ultralytics/assets/zidane.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/ultralytics/assets/zidane.jpg
--------------------------------------------------------------------------------
/ultralytics/datasets/Argoverse.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # Argoverse-HD dataset (ring-front-center camera) http://www.cs.cmu.edu/~mengtial/proj/streaming/ by Argo AI
3 | # Example usage: yolo train data=Argoverse.yaml
4 | # parent
5 | # ├── ultralytics
6 | # └── datasets
7 | # └── Argoverse ← downloads here (31.3 GB)
8 |
9 |
10 | # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
11 | path: ../datasets/Argoverse # dataset root dir
12 | train: Argoverse-1.1/images/train/ # train images (relative to 'path') 39384 images
13 | val: Argoverse-1.1/images/val/ # val images (relative to 'path') 15062 images
14 | test: Argoverse-1.1/images/test/ # test images (optional) https://eval.ai/web/challenges/challenge-page/800/overview
15 |
16 | # Classes
17 | names:
18 | 0: person
19 | 1: bicycle
20 | 2: car
21 | 3: motorcycle
22 | 4: bus
23 | 5: truck
24 | 6: traffic_light
25 | 7: stop_sign
26 |
27 |
28 | # Download script/URL (optional) ---------------------------------------------------------------------------------------
29 | download: |
30 | import json
31 | from tqdm import tqdm
32 | from ultralytics.yolo.utils.downloads import download
33 | from pathlib import Path
34 |
35 | def argoverse2yolo(set):
36 | labels = {}
37 | a = json.load(open(set, "rb"))
38 | for annot in tqdm(a['annotations'], desc=f"Converting {set} to YOLOv5 format..."):
39 | img_id = annot['image_id']
40 | img_name = a['images'][img_id]['name']
41 | img_label_name = f'{img_name[:-3]}txt'
42 |
43 | cls = annot['category_id'] # instance class id
44 | x_center, y_center, width, height = annot['bbox']
45 | x_center = (x_center + width / 2) / 1920.0 # offset and scale
46 | y_center = (y_center + height / 2) / 1200.0 # offset and scale
47 | width /= 1920.0 # scale
48 | height /= 1200.0 # scale
49 |
50 | img_dir = set.parents[2] / 'Argoverse-1.1' / 'labels' / a['seq_dirs'][a['images'][annot['image_id']]['sid']]
51 | if not img_dir.exists():
52 | img_dir.mkdir(parents=True, exist_ok=True)
53 |
54 | k = str(img_dir / img_label_name)
55 | if k not in labels:
56 | labels[k] = []
57 | labels[k].append(f"{cls} {x_center} {y_center} {width} {height}\n")
58 |
59 | for k in labels:
60 | with open(k, "w") as f:
61 | f.writelines(labels[k])
62 |
63 |
64 | # Download
65 | dir = Path(yaml['path']) # dataset root dir
66 | urls = ['https://argoverse-hd.s3.us-east-2.amazonaws.com/Argoverse-HD-Full.zip']
67 | download(urls, dir=dir)
68 |
69 | # Convert
70 | annotations_dir = 'Argoverse-HD/annotations/'
71 | (dir / 'Argoverse-1.1' / 'tracking').rename(dir / 'Argoverse-1.1' / 'images') # rename 'tracking' to 'images'
72 | for d in "train.json", "val.json":
73 | argoverse2yolo(dir / annotations_dir / d) # convert Argoverse annotations to YOLO labels
74 |
--------------------------------------------------------------------------------
/ultralytics/datasets/GlobalWheat2020.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # Global Wheat 2020 dataset http://www.global-wheat.com/ by University of Saskatchewan
3 | # Example usage: yolo train data=GlobalWheat2020.yaml
4 | # parent
5 | # ├── ultralytics
6 | # └── datasets
7 | # └── GlobalWheat2020 ← downloads here (7.0 GB)
8 |
9 |
10 | # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
11 | path: ../datasets/GlobalWheat2020 # dataset root dir
12 | train: # train images (relative to 'path') 3422 images
13 | - images/arvalis_1
14 | - images/arvalis_2
15 | - images/arvalis_3
16 | - images/ethz_1
17 | - images/rres_1
18 | - images/inrae_1
19 | - images/usask_1
20 | val: # val images (relative to 'path') 748 images (WARNING: train set contains ethz_1)
21 | - images/ethz_1
22 | test: # test images (optional) 1276 images
23 | - images/utokyo_1
24 | - images/utokyo_2
25 | - images/nau_1
26 | - images/uq_1
27 |
28 | # Classes
29 | names:
30 | 0: wheat_head
31 |
32 |
33 | # Download script/URL (optional) ---------------------------------------------------------------------------------------
34 | download: |
35 | from ultralytics.yolo.utils.downloads import download
36 | from pathlib import Path
37 |
38 | # Download
39 | dir = Path(yaml['path']) # dataset root dir
40 | urls = ['https://zenodo.org/record/4298502/files/global-wheat-codalab-official.zip',
41 | 'https://github.com/ultralytics/yolov5/releases/download/v1.0/GlobalWheat2020_labels.zip']
42 | download(urls, dir=dir)
43 |
44 | # Make Directories
45 | for p in 'annotations', 'images', 'labels':
46 | (dir / p).mkdir(parents=True, exist_ok=True)
47 |
48 | # Move
49 | for p in 'arvalis_1', 'arvalis_2', 'arvalis_3', 'ethz_1', 'rres_1', 'inrae_1', 'usask_1', \
50 | 'utokyo_1', 'utokyo_2', 'nau_1', 'uq_1':
51 | (dir / 'global-wheat-codalab-official' / p).rename(dir / 'images' / p) # move to /images
52 | f = (dir / 'global-wheat-codalab-official' / p).with_suffix('.json') # json file
53 | if f.exists():
54 | f.rename((dir / 'annotations' / p).with_suffix('.json')) # move to /annotations
55 |
--------------------------------------------------------------------------------
/ultralytics/datasets/SKU-110K.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # SKU-110K retail items dataset https://github.com/eg4000/SKU110K_CVPR19 by Trax Retail
3 | # Example usage: yolo train data=SKU-110K.yaml
4 | # parent
5 | # ├── ultralytics
6 | # └── datasets
7 | # └── SKU-110K ← downloads here (13.6 GB)
8 |
9 |
10 | # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
11 | path: ../datasets/SKU-110K # dataset root dir
12 | train: train.txt # train images (relative to 'path') 8219 images
13 | val: val.txt # val images (relative to 'path') 588 images
14 | test: test.txt # test images (optional) 2936 images
15 |
16 | # Classes
17 | names:
18 | 0: object
19 |
20 |
21 | # Download script/URL (optional) ---------------------------------------------------------------------------------------
22 | download: |
23 | import shutil
24 | from pathlib import Path
25 |
26 | import numpy as np
27 | import pandas as pd
28 | from tqdm import tqdm
29 |
30 | from ultralytics.yolo.utils.downloads import download
31 | from ultralytics.yolo.utils.ops import xyxy2xywh
32 |
33 | # Download
34 | dir = Path(yaml['path']) # dataset root dir
35 | parent = Path(dir.parent) # download dir
36 | urls = ['http://trax-geometry.s3.amazonaws.com/cvpr_challenge/SKU110K_fixed.tar.gz']
37 | download(urls, dir=parent)
38 |
39 | # Rename directories
40 | if dir.exists():
41 | shutil.rmtree(dir)
42 | (parent / 'SKU110K_fixed').rename(dir) # rename dir
43 | (dir / 'labels').mkdir(parents=True, exist_ok=True) # create labels dir
44 |
45 | # Convert labels
46 | names = 'image', 'x1', 'y1', 'x2', 'y2', 'class', 'image_width', 'image_height' # column names
47 | for d in 'annotations_train.csv', 'annotations_val.csv', 'annotations_test.csv':
48 | x = pd.read_csv(dir / 'annotations' / d, names=names).values # annotations
49 | images, unique_images = x[:, 0], np.unique(x[:, 0])
50 | with open((dir / d).with_suffix('.txt').__str__().replace('annotations_', ''), 'w') as f:
51 | f.writelines(f'./images/{s}\n' for s in unique_images)
52 | for im in tqdm(unique_images, desc=f'Converting {dir / d}'):
53 | cls = 0 # single-class dataset
54 | with open((dir / 'labels' / im).with_suffix('.txt'), 'a') as f:
55 | for r in x[images == im]:
56 | w, h = r[6], r[7] # image width, height
57 | xywh = xyxy2xywh(np.array([[r[1] / w, r[2] / h, r[3] / w, r[4] / h]]))[0] # instance
58 | f.write(f"{cls} {xywh[0]:.5f} {xywh[1]:.5f} {xywh[2]:.5f} {xywh[3]:.5f}\n") # write label
59 |
--------------------------------------------------------------------------------
/ultralytics/datasets/VOC.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # PASCAL VOC dataset http://host.robots.ox.ac.uk/pascal/VOC by University of Oxford
3 | # Example usage: yolo train data=VOC.yaml
4 | # parent
5 | # ├── ultralytics
6 | # └── datasets
7 | # └── VOC ← downloads here (2.8 GB)
8 |
9 |
10 | # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
11 | path: ../datasets/VOC
12 | train: # train images (relative to 'path') 16551 images
13 | - images/train2012
14 | - images/train2007
15 | - images/val2012
16 | - images/val2007
17 | val: # val images (relative to 'path') 4952 images
18 | - images/test2007
19 | test: # test images (optional)
20 | - images/test2007
21 |
22 | # Classes
23 | names:
24 | 0: aeroplane
25 | 1: bicycle
26 | 2: bird
27 | 3: boat
28 | 4: bottle
29 | 5: bus
30 | 6: car
31 | 7: cat
32 | 8: chair
33 | 9: cow
34 | 10: diningtable
35 | 11: dog
36 | 12: horse
37 | 13: motorbike
38 | 14: person
39 | 15: pottedplant
40 | 16: sheep
41 | 17: sofa
42 | 18: train
43 | 19: tvmonitor
44 |
45 |
46 | # Download script/URL (optional) ---------------------------------------------------------------------------------------
47 | download: |
48 | import xml.etree.ElementTree as ET
49 |
50 | from tqdm import tqdm
51 | from ultralytics.yolo.utils.downloads import download
52 | from pathlib import Path
53 |
54 | def convert_label(path, lb_path, year, image_id):
55 | def convert_box(size, box):
56 | dw, dh = 1. / size[0], 1. / size[1]
57 | x, y, w, h = (box[0] + box[1]) / 2.0 - 1, (box[2] + box[3]) / 2.0 - 1, box[1] - box[0], box[3] - box[2]
58 | return x * dw, y * dh, w * dw, h * dh
59 |
60 | in_file = open(path / f'VOC{year}/Annotations/{image_id}.xml')
61 | out_file = open(lb_path, 'w')
62 | tree = ET.parse(in_file)
63 | root = tree.getroot()
64 | size = root.find('size')
65 | w = int(size.find('width').text)
66 | h = int(size.find('height').text)
67 |
68 | names = list(yaml['names'].values()) # names list
69 | for obj in root.iter('object'):
70 | cls = obj.find('name').text
71 | if cls in names and int(obj.find('difficult').text) != 1:
72 | xmlbox = obj.find('bndbox')
73 | bb = convert_box((w, h), [float(xmlbox.find(x).text) for x in ('xmin', 'xmax', 'ymin', 'ymax')])
74 | cls_id = names.index(cls) # class id
75 | out_file.write(" ".join([str(a) for a in (cls_id, *bb)]) + '\n')
76 |
77 |
78 | # Download
79 | dir = Path(yaml['path']) # dataset root dir
80 | url = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/'
81 | urls = [f'{url}VOCtrainval_06-Nov-2007.zip', # 446MB, 5012 images
82 | f'{url}VOCtest_06-Nov-2007.zip', # 438MB, 4953 images
83 | f'{url}VOCtrainval_11-May-2012.zip'] # 1.95GB, 17126 images
84 | download(urls, dir=dir / 'images', curl=True, threads=3)
85 |
86 | # Convert
87 | path = dir / 'images/VOCdevkit'
88 | for year, image_set in ('2012', 'train'), ('2012', 'val'), ('2007', 'train'), ('2007', 'val'), ('2007', 'test'):
89 | imgs_path = dir / 'images' / f'{image_set}{year}'
90 | lbs_path = dir / 'labels' / f'{image_set}{year}'
91 | imgs_path.mkdir(exist_ok=True, parents=True)
92 | lbs_path.mkdir(exist_ok=True, parents=True)
93 |
94 | with open(path / f'VOC{year}/ImageSets/Main/{image_set}.txt') as f:
95 | image_ids = f.read().strip().split()
96 | for id in tqdm(image_ids, desc=f'{image_set}{year}'):
97 | f = path / f'VOC{year}/JPEGImages/{id}.jpg' # old img path
98 | lb_path = (lbs_path / f.name).with_suffix('.txt') # new label path
99 | f.rename(imgs_path / f.name) # move image
100 | convert_label(path, lb_path, year, id) # convert labels to YOLO format
101 |
--------------------------------------------------------------------------------
/ultralytics/datasets/VisDrone.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # VisDrone2019-DET dataset https://github.com/VisDrone/VisDrone-Dataset by Tianjin University
3 | # Example usage: yolo train data=VisDrone.yaml
4 | # parent
5 | # ├── ultralytics
6 | # └── datasets
7 | # └── VisDrone ← downloads here (2.3 GB)
8 |
9 |
10 | # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
11 | path: ../datasets/VisDrone # dataset root dir
12 | train: VisDrone2019-DET-train/images # train images (relative to 'path') 6471 images
13 | val: VisDrone2019-DET-val/images # val images (relative to 'path') 548 images
14 | test: VisDrone2019-DET-test-dev/images # test images (optional) 1610 images
15 |
16 | # Classes
17 | names:
18 | 0: pedestrian
19 | 1: people
20 | 2: bicycle
21 | 3: car
22 | 4: van
23 | 5: truck
24 | 6: tricycle
25 | 7: awning-tricycle
26 | 8: bus
27 | 9: motor
28 |
29 |
30 | # Download script/URL (optional) ---------------------------------------------------------------------------------------
31 | download: |
32 | import os
33 | from pathlib import Path
34 |
35 | from ultralytics.yolo.utils.downloads import download
36 |
37 | def visdrone2yolo(dir):
38 | from PIL import Image
39 | from tqdm import tqdm
40 |
41 | def convert_box(size, box):
42 | # Convert VisDrone box to YOLO xywh box
43 | dw = 1. / size[0]
44 | dh = 1. / size[1]
45 | return (box[0] + box[2] / 2) * dw, (box[1] + box[3] / 2) * dh, box[2] * dw, box[3] * dh
46 |
47 | (dir / 'labels').mkdir(parents=True, exist_ok=True) # make labels directory
48 | pbar = tqdm((dir / 'annotations').glob('*.txt'), desc=f'Converting {dir}')
49 | for f in pbar:
50 | img_size = Image.open((dir / 'images' / f.name).with_suffix('.jpg')).size
51 | lines = []
52 | with open(f, 'r') as file: # read annotation.txt
53 | for row in [x.split(',') for x in file.read().strip().splitlines()]:
54 | if row[4] == '0': # VisDrone 'ignored regions' class 0
55 | continue
56 | cls = int(row[5]) - 1
57 | box = convert_box(img_size, tuple(map(int, row[:4])))
58 | lines.append(f"{cls} {' '.join(f'{x:.6f}' for x in box)}\n")
59 | with open(str(f).replace(f'{os.sep}annotations{os.sep}', f'{os.sep}labels{os.sep}'), 'w') as fl:
60 | fl.writelines(lines) # write label.txt
61 |
62 |
63 | # Download
64 | dir = Path(yaml['path']) # dataset root dir
65 | urls = ['https://github.com/ultralytics/yolov5/releases/download/v1.0/VisDrone2019-DET-train.zip',
66 | 'https://github.com/ultralytics/yolov5/releases/download/v1.0/VisDrone2019-DET-val.zip',
67 | 'https://github.com/ultralytics/yolov5/releases/download/v1.0/VisDrone2019-DET-test-dev.zip',
68 | 'https://github.com/ultralytics/yolov5/releases/download/v1.0/VisDrone2019-DET-test-challenge.zip']
69 | download(urls, dir=dir, curl=True, threads=4)
70 |
71 | # Convert
72 | for d in 'VisDrone2019-DET-train', 'VisDrone2019-DET-val', 'VisDrone2019-DET-test-dev':
73 | visdrone2yolo(dir / d) # convert VisDrone annotations to YOLO labels
74 |
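VisDrone annotations store boxes as `left, top, width, height` in pixels; the converter above drops rows whose score field is 0 (ignored regions) and shifts category ids down by one, so VisDrone category 1 (pedestrian) becomes class 0. A simplified sketch of the box conversion (illustrative helper only, mirroring `convert_box` above):

```python
def visdrone_to_yolo(size, box):
    # size: (image_width, image_height); box: (left, top, width, height) in pixels
    dw, dh = 1.0 / size[0], 1.0 / size[1]
    return ((box[0] + box[2] / 2) * dw,  # normalized x-center
            (box[1] + box[3] / 2) * dh,  # normalized y-center
            box[2] * dw,                 # normalized width
            box[3] * dh)                 # normalized height

# A 200x100 px box with top-left corner (400, 300) in a 1920x1080 frame:
print([round(v, 4) for v in visdrone_to_yolo((1920, 1080), (400, 300, 200, 100))])
# -> [0.2604, 0.3241, 0.1042, 0.0926]
```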
--------------------------------------------------------------------------------
/ultralytics/datasets/coco-pose.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # COCO 2017 dataset http://cocodataset.org by Microsoft
3 | # Example usage: yolo train data=coco-pose.yaml
4 | # parent
5 | # ├── ultralytics
6 | # └── datasets
7 | # └── coco-pose ← downloads here (20.1 GB)
8 |
9 |
10 | # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
11 | path: ../datasets/coco-pose # dataset root dir
12 | train: train2017.txt # train images (relative to 'path') 118287 images
13 | val: val2017.txt # val images (relative to 'path') 5000 images
14 | test: test-dev2017.txt # 20288 of 40670 images, submit to https://competitions.codalab.org/competitions/20794
15 |
16 | # Keypoints
17 | kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
18 | flip_idx: [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
19 |
20 | # Classes
21 | names:
22 | 0: person
23 |
24 | # Download script/URL (optional)
25 | download: |
26 | from ultralytics.yolo.utils.downloads import download
27 | from pathlib import Path
28 |
29 | # Download labels
30 | dir = Path(yaml['path']) # dataset root dir
31 | url = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/'
32 | urls = [url + 'coco2017labels-pose.zip'] # labels
33 | download(urls, dir=dir.parent)
34 | # Download data
35 | urls = ['http://images.cocodataset.org/zips/train2017.zip', # 19G, 118k images
36 | 'http://images.cocodataset.org/zips/val2017.zip', # 1G, 5k images
37 | 'http://images.cocodataset.org/zips/test2017.zip'] # 7G, 41k images (optional)
38 | download(urls, dir=dir / 'images', threads=3)
39 |
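`flip_idx` supports horizontal-flip augmentation: when an image is mirrored, left and right keypoints must swap, and `flip_idx[i]` names which original keypoint lands in slot `i` (in COCO ordering keypoint 1 is the left eye and 2 is the right eye, hence the `..., 2, 1, ...` pattern). A rough sketch of such a remap (illustrative only; the training pipeline applies this internally):

```python
flip_idx = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]

def flip_keypoints(kpts, img_w, flip_idx):
    # kpts: list of (x, y, visible) tuples in pixel coordinates, COCO keypoint order
    mirrored = [(img_w - x, y, v) for x, y, v in kpts]  # mirror x (ignores sub-pixel details)
    return [mirrored[i] for i in flip_idx]              # swap left/right partners
```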
--------------------------------------------------------------------------------
/ultralytics/datasets/coco.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # COCO 2017 dataset http://cocodataset.org by Microsoft
3 | # Example usage: yolo train data=coco.yaml
4 | # parent
5 | # ├── ultralytics
6 | # └── datasets
7 | # └── coco ← downloads here (20.1 GB)
8 |
9 |
10 | # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
11 | path: ../datasets/coco # dataset root dir
12 | train: train2017.txt # train images (relative to 'path') 118287 images
13 | val: val2017.txt # val images (relative to 'path') 5000 images
14 | test: test-dev2017.txt # 20288 of 40670 images, submit to https://competitions.codalab.org/competitions/20794
15 |
16 | # Classes
17 | names:
18 | 0: person
19 | 1: bicycle
20 | 2: car
21 | 3: motorcycle
22 | 4: airplane
23 | 5: bus
24 | 6: train
25 | 7: truck
26 | 8: boat
27 | 9: traffic light
28 | 10: fire hydrant
29 | 11: stop sign
30 | 12: parking meter
31 | 13: bench
32 | 14: bird
33 | 15: cat
34 | 16: dog
35 | 17: horse
36 | 18: sheep
37 | 19: cow
38 | 20: elephant
39 | 21: bear
40 | 22: zebra
41 | 23: giraffe
42 | 24: backpack
43 | 25: umbrella
44 | 26: handbag
45 | 27: tie
46 | 28: suitcase
47 | 29: frisbee
48 | 30: skis
49 | 31: snowboard
50 | 32: sports ball
51 | 33: kite
52 | 34: baseball bat
53 | 35: baseball glove
54 | 36: skateboard
55 | 37: surfboard
56 | 38: tennis racket
57 | 39: bottle
58 | 40: wine glass
59 | 41: cup
60 | 42: fork
61 | 43: knife
62 | 44: spoon
63 | 45: bowl
64 | 46: banana
65 | 47: apple
66 | 48: sandwich
67 | 49: orange
68 | 50: broccoli
69 | 51: carrot
70 | 52: hot dog
71 | 53: pizza
72 | 54: donut
73 | 55: cake
74 | 56: chair
75 | 57: couch
76 | 58: potted plant
77 | 59: bed
78 | 60: dining table
79 | 61: toilet
80 | 62: tv
81 | 63: laptop
82 | 64: mouse
83 | 65: remote
84 | 66: keyboard
85 | 67: cell phone
86 | 68: microwave
87 | 69: oven
88 | 70: toaster
89 | 71: sink
90 | 72: refrigerator
91 | 73: book
92 | 74: clock
93 | 75: vase
94 | 76: scissors
95 | 77: teddy bear
96 | 78: hair drier
97 | 79: toothbrush
98 |
99 |
100 | # Download script/URL (optional)
101 | download: |
102 | from ultralytics.yolo.utils.downloads import download
103 | from pathlib import Path
104 |
105 | # Download labels
106 | segments = True # segment or box labels
107 | dir = Path(yaml['path']) # dataset root dir
108 | url = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/'
109 | urls = [url + ('coco2017labels-segments.zip' if segments else 'coco2017labels.zip')] # labels
110 | download(urls, dir=dir.parent)
111 | # Download data
112 | urls = ['http://images.cocodataset.org/zips/train2017.zip', # 19G, 118k images
113 | 'http://images.cocodataset.org/zips/val2017.zip', # 1G, 5k images
114 | 'http://images.cocodataset.org/zips/test2017.zip'] # 7G, 41k images (optional)
115 | download(urls, dir=dir / 'images', threads=3)
116 |
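The `download:` value here is a Python snippet rather than a plain URL; roughly speaking it is executed with the parsed config in scope as `yaml`, which is why `yaml['path']` resolves to the dataset root, and flipping `segments` to False selects plain box labels instead of segmentation polygons. A rough illustration of that execution model (assumed behavior, simplified; the local path below is hypothetical):

```python
import yaml as pyyaml

with open('datasets/coco.yaml') as f:   # hypothetical local copy of this config
    data = pyyaml.safe_load(f)
exec(data['download'], {'yaml': data})  # the embedded script sees the parsed config as `yaml`
```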
--------------------------------------------------------------------------------
/ultralytics/datasets/coco128-seg.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # COCO128-seg dataset https://www.kaggle.com/ultralytics/coco128 (first 128 images from COCO train2017) by Ultralytics
3 | # Example usage: yolo train data=coco128.yaml
4 | # parent
5 | # ├── ultralytics
6 | # └── datasets
7 | # └── coco128-seg ← downloads here (7 MB)
8 |
9 |
10 | # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
11 | path: ../datasets/coco128-seg # dataset root dir
12 | train: images/train2017 # train images (relative to 'path') 128 images
13 | val: images/train2017 # val images (relative to 'path') 128 images
14 | test: # test images (optional)
15 |
16 | # Classes
17 | names:
18 | 0: person
19 | 1: bicycle
20 | 2: car
21 | 3: motorcycle
22 | 4: airplane
23 | 5: bus
24 | 6: train
25 | 7: truck
26 | 8: boat
27 | 9: traffic light
28 | 10: fire hydrant
29 | 11: stop sign
30 | 12: parking meter
31 | 13: bench
32 | 14: bird
33 | 15: cat
34 | 16: dog
35 | 17: horse
36 | 18: sheep
37 | 19: cow
38 | 20: elephant
39 | 21: bear
40 | 22: zebra
41 | 23: giraffe
42 | 24: backpack
43 | 25: umbrella
44 | 26: handbag
45 | 27: tie
46 | 28: suitcase
47 | 29: frisbee
48 | 30: skis
49 | 31: snowboard
50 | 32: sports ball
51 | 33: kite
52 | 34: baseball bat
53 | 35: baseball glove
54 | 36: skateboard
55 | 37: surfboard
56 | 38: tennis racket
57 | 39: bottle
58 | 40: wine glass
59 | 41: cup
60 | 42: fork
61 | 43: knife
62 | 44: spoon
63 | 45: bowl
64 | 46: banana
65 | 47: apple
66 | 48: sandwich
67 | 49: orange
68 | 50: broccoli
69 | 51: carrot
70 | 52: hot dog
71 | 53: pizza
72 | 54: donut
73 | 55: cake
74 | 56: chair
75 | 57: couch
76 | 58: potted plant
77 | 59: bed
78 | 60: dining table
79 | 61: toilet
80 | 62: tv
81 | 63: laptop
82 | 64: mouse
83 | 65: remote
84 | 66: keyboard
85 | 67: cell phone
86 | 68: microwave
87 | 69: oven
88 | 70: toaster
89 | 71: sink
90 | 72: refrigerator
91 | 73: book
92 | 74: clock
93 | 75: vase
94 | 76: scissors
95 | 77: teddy bear
96 | 78: hair drier
97 | 79: toothbrush
98 |
99 |
100 | # Download script/URL (optional)
101 | download: https://ultralytics.com/assets/coco128-seg.zip
102 |
--------------------------------------------------------------------------------
/ultralytics/datasets/coco128.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # COCO128 dataset https://www.kaggle.com/ultralytics/coco128 (first 128 images from COCO train2017) by Ultralytics
3 | # Example usage: yolo train data=coco128.yaml
4 | # parent
5 | # ├── ultralytics
6 | # └── datasets
7 | # └── coco128 ← downloads here (7 MB)
8 |
9 |
10 | # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
11 | path: ../datasets/coco128 # dataset root dir
12 | train: images/train2017 # train images (relative to 'path') 128 images
13 | val: images/train2017 # val images (relative to 'path') 128 images
14 | test: # test images (optional)
15 |
16 | # Classes
17 | names:
18 | 0: person
19 | 1: bicycle
20 | 2: car
21 | 3: motorcycle
22 | 4: airplane
23 | 5: bus
24 | 6: train
25 | 7: truck
26 | 8: boat
27 | 9: traffic light
28 | 10: fire hydrant
29 | 11: stop sign
30 | 12: parking meter
31 | 13: bench
32 | 14: bird
33 | 15: cat
34 | 16: dog
35 | 17: horse
36 | 18: sheep
37 | 19: cow
38 | 20: elephant
39 | 21: bear
40 | 22: zebra
41 | 23: giraffe
42 | 24: backpack
43 | 25: umbrella
44 | 26: handbag
45 | 27: tie
46 | 28: suitcase
47 | 29: frisbee
48 | 30: skis
49 | 31: snowboard
50 | 32: sports ball
51 | 33: kite
52 | 34: baseball bat
53 | 35: baseball glove
54 | 36: skateboard
55 | 37: surfboard
56 | 38: tennis racket
57 | 39: bottle
58 | 40: wine glass
59 | 41: cup
60 | 42: fork
61 | 43: knife
62 | 44: spoon
63 | 45: bowl
64 | 46: banana
65 | 47: apple
66 | 48: sandwich
67 | 49: orange
68 | 50: broccoli
69 | 51: carrot
70 | 52: hot dog
71 | 53: pizza
72 | 54: donut
73 | 55: cake
74 | 56: chair
75 | 57: couch
76 | 58: potted plant
77 | 59: bed
78 | 60: dining table
79 | 61: toilet
80 | 62: tv
81 | 63: laptop
82 | 64: mouse
83 | 65: remote
84 | 66: keyboard
85 | 67: cell phone
86 | 68: microwave
87 | 69: oven
88 | 70: toaster
89 | 71: sink
90 | 72: refrigerator
91 | 73: book
92 | 74: clock
93 | 75: vase
94 | 76: scissors
95 | 77: teddy bear
96 | 78: hair drier
97 | 79: toothbrush
98 |
99 |
100 | # Download script/URL (optional)
101 | download: https://ultralytics.com/assets/coco128.zip
102 |
--------------------------------------------------------------------------------
/ultralytics/datasets/coco8-pose.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # COCO8-pose dataset (first 8 images from COCO train2017) by Ultralytics
3 | # Example usage: yolo train data=coco8-pose.yaml
4 | # parent
5 | # ├── ultralytics
6 | # └── datasets
7 | # └── coco8-pose ← downloads here (1 MB)
8 |
9 |
10 | # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
11 | path: ../datasets/coco8-pose # dataset root dir
12 | train: images/train # train images (relative to 'path') 4 images
13 | val: images/val # val images (relative to 'path') 4 images
14 | test: # test images (optional)
15 |
16 | # Keypoints
17 | kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
18 | flip_idx: [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
19 |
20 | # Classes
21 | names:
22 | 0: person
23 |
24 | # Download script/URL (optional)
25 | download: https://ultralytics.com/assets/coco8-pose.zip
26 |
--------------------------------------------------------------------------------
/ultralytics/datasets/coco8-seg.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # COCO8-seg dataset (first 8 images from COCO train2017) by Ultralytics
3 | # Example usage: yolo train data=coco8-seg.yaml
4 | # parent
5 | # ├── ultralytics
6 | # └── datasets
7 | # └── coco8-seg ← downloads here (1 MB)
8 |
9 |
10 | # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
11 | path: ../datasets/coco8-seg # dataset root dir
12 | train: images/train # train images (relative to 'path') 4 images
13 | val: images/val # val images (relative to 'path') 4 images
14 | test: # test images (optional)
15 |
16 | # Classes
17 | names:
18 | 0: person
19 | 1: bicycle
20 | 2: car
21 | 3: motorcycle
22 | 4: airplane
23 | 5: bus
24 | 6: train
25 | 7: truck
26 | 8: boat
27 | 9: traffic light
28 | 10: fire hydrant
29 | 11: stop sign
30 | 12: parking meter
31 | 13: bench
32 | 14: bird
33 | 15: cat
34 | 16: dog
35 | 17: horse
36 | 18: sheep
37 | 19: cow
38 | 20: elephant
39 | 21: bear
40 | 22: zebra
41 | 23: giraffe
42 | 24: backpack
43 | 25: umbrella
44 | 26: handbag
45 | 27: tie
46 | 28: suitcase
47 | 29: frisbee
48 | 30: skis
49 | 31: snowboard
50 | 32: sports ball
51 | 33: kite
52 | 34: baseball bat
53 | 35: baseball glove
54 | 36: skateboard
55 | 37: surfboard
56 | 38: tennis racket
57 | 39: bottle
58 | 40: wine glass
59 | 41: cup
60 | 42: fork
61 | 43: knife
62 | 44: spoon
63 | 45: bowl
64 | 46: banana
65 | 47: apple
66 | 48: sandwich
67 | 49: orange
68 | 50: broccoli
69 | 51: carrot
70 | 52: hot dog
71 | 53: pizza
72 | 54: donut
73 | 55: cake
74 | 56: chair
75 | 57: couch
76 | 58: potted plant
77 | 59: bed
78 | 60: dining table
79 | 61: toilet
80 | 62: tv
81 | 63: laptop
82 | 64: mouse
83 | 65: remote
84 | 66: keyboard
85 | 67: cell phone
86 | 68: microwave
87 | 69: oven
88 | 70: toaster
89 | 71: sink
90 | 72: refrigerator
91 | 73: book
92 | 74: clock
93 | 75: vase
94 | 76: scissors
95 | 77: teddy bear
96 | 78: hair drier
97 | 79: toothbrush
98 |
99 |
100 | # Download script/URL (optional)
101 | download: https://ultralytics.com/assets/coco8-seg.zip
102 |
--------------------------------------------------------------------------------
/ultralytics/datasets/coco8.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # COCO8 dataset (first 8 images from COCO train2017) by Ultralytics
3 | # Example usage: yolo train data=coco8.yaml
4 | # parent
5 | # ├── ultralytics
6 | # └── datasets
7 | # └── coco8 ← downloads here (1 MB)
8 |
9 |
10 | # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
11 | path: ../datasets/coco8 # dataset root dir
12 | train: images/train # train images (relative to 'path') 4 images
13 | val: images/val # val images (relative to 'path') 4 images
14 | test: # test images (optional)
15 |
16 | # Classes
17 | names:
18 | 0: person
19 | 1: bicycle
20 | 2: car
21 | 3: motorcycle
22 | 4: airplane
23 | 5: bus
24 | 6: train
25 | 7: truck
26 | 8: boat
27 | 9: traffic light
28 | 10: fire hydrant
29 | 11: stop sign
30 | 12: parking meter
31 | 13: bench
32 | 14: bird
33 | 15: cat
34 | 16: dog
35 | 17: horse
36 | 18: sheep
37 | 19: cow
38 | 20: elephant
39 | 21: bear
40 | 22: zebra
41 | 23: giraffe
42 | 24: backpack
43 | 25: umbrella
44 | 26: handbag
45 | 27: tie
46 | 28: suitcase
47 | 29: frisbee
48 | 30: skis
49 | 31: snowboard
50 | 32: sports ball
51 | 33: kite
52 | 34: baseball bat
53 | 35: baseball glove
54 | 36: skateboard
55 | 37: surfboard
56 | 38: tennis racket
57 | 39: bottle
58 | 40: wine glass
59 | 41: cup
60 | 42: fork
61 | 43: knife
62 | 44: spoon
63 | 45: bowl
64 | 46: banana
65 | 47: apple
66 | 48: sandwich
67 | 49: orange
68 | 50: broccoli
69 | 51: carrot
70 | 52: hot dog
71 | 53: pizza
72 | 54: donut
73 | 55: cake
74 | 56: chair
75 | 57: couch
76 | 58: potted plant
77 | 59: bed
78 | 60: dining table
79 | 61: toilet
80 | 62: tv
81 | 63: laptop
82 | 64: mouse
83 | 65: remote
84 | 66: keyboard
85 | 67: cell phone
86 | 68: microwave
87 | 69: oven
88 | 70: toaster
89 | 71: sink
90 | 72: refrigerator
91 | 73: book
92 | 74: clock
93 | 75: vase
94 | 76: scissors
95 | 77: teddy bear
96 | 78: hair drier
97 | 79: toothbrush
98 |
99 |
100 | # Download script/URL (optional)
101 | download: https://ultralytics.com/assets/coco8.zip
102 |
--------------------------------------------------------------------------------
/ultralytics/hub/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | import requests
4 |
5 | from ultralytics.hub.auth import Auth
6 | from ultralytics.hub.utils import PREFIX
7 | from ultralytics.yolo.data.utils import HUBDatasetStats
8 | from ultralytics.yolo.utils import LOGGER, SETTINGS, USER_CONFIG_DIR, yaml_save
9 |
10 |
11 | def login(api_key=''):
12 | """
13 | Log in to the Ultralytics HUB API using the provided API key.
14 |
15 | Args:
16 | api_key (str, optional): May be an API key or a combination API key and model ID, i.e. key_id
17 |
18 | Example:
19 | from ultralytics import hub
20 | hub.login('API_KEY')
21 | """
22 | Auth(api_key, verbose=True)
23 |
24 |
25 | def logout():
26 | """
27 | Log out of Ultralytics HUB by removing the API key from the settings file. To log in again, use 'yolo hub login'.
28 |
29 | Example:
30 | from ultralytics import hub
31 | hub.logout()
32 | """
33 | SETTINGS['api_key'] = ''
34 | yaml_save(USER_CONFIG_DIR / 'settings.yaml', SETTINGS)
35 | LOGGER.info(f"{PREFIX}logged out ✅. To log in again, use 'yolo hub login'.")
36 |
37 |
38 | def start(key=''):
39 | """
40 | Start training models with Ultralytics HUB (DEPRECATED).
41 |
42 | Args:
43 | key (str, optional): A string containing either the API key and model ID combination (apikey_modelid),
44 | or the full model URL (https://hub.ultralytics.com/models/apikey_modelid).
45 | """
46 | api_key, model_id = key.split('_')
47 | LOGGER.warning(f"""
48 | WARNING ⚠️ ultralytics.start() is deprecated after 8.0.60. Updated usage to train Ultralytics HUB models is:
49 |
50 | from ultralytics import YOLO, hub
51 |
52 | hub.login('{api_key}')
53 | model = YOLO('https://hub.ultralytics.com/models/{model_id}')
54 | model.train()""")
55 |
56 |
57 | def reset_model(model_id=''):
58 | """Reset a trained model to an untrained state."""
59 | r = requests.post('https://api.ultralytics.com/model-reset', json={'apiKey': Auth().api_key, 'modelId': model_id})
60 | if r.status_code == 200:
61 | LOGGER.info(f'{PREFIX}Model reset successfully')
62 | return
63 | LOGGER.warning(f'{PREFIX}Model reset failure {r.status_code} {r.reason}')
64 |
65 |
66 | def export_fmts_hub():
67 | """Returns a list of HUB-supported export formats."""
68 | from ultralytics.yolo.engine.exporter import export_formats
69 | return list(export_formats()['Argument'][1:]) + ['ultralytics_tflite', 'ultralytics_coreml']
70 |
71 |
72 | def export_model(model_id='', format='torchscript'):
73 | """Export a model to all formats."""
74 | assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
75 | r = requests.post(f'https://api.ultralytics.com/v1/models/{model_id}/export',
76 | json={'format': format},
77 | headers={'x-api-key': Auth().api_key})
78 | assert r.status_code == 200, f'{PREFIX}{format} export failure {r.status_code} {r.reason}'
79 | LOGGER.info(f'{PREFIX}{format} export started ✅')
80 |
81 |
82 | def get_export(model_id='', format='torchscript'):
83 | """Get an exported model dictionary with download URL."""
84 | assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
85 | r = requests.post('https://api.ultralytics.com/get-export',
86 | json={
87 | 'apiKey': Auth().api_key,
88 | 'modelId': model_id,
89 | 'format': format})
90 | assert r.status_code == 200, f'{PREFIX}{format} get_export failure {r.status_code} {r.reason}'
91 | return r.json()
92 |
93 |
94 | def check_dataset(path='', task='detect'):
95 | """
96 | Error-check a HUB dataset Zip file before it is uploaded to the HUB, so problems can be fixed locally first.
97 | Usage examples are given below.
98 |
99 | Args:
100 | path (str, optional): Path to data.zip (with data.yaml inside data.zip). Defaults to ''.
101 | task (str, optional): Dataset task. Options are 'detect', 'segment', 'pose', 'classify'. Defaults to 'detect'.
102 |
103 | Example:
104 | ```python
105 | from ultralytics.hub import check_dataset
106 |
107 | check_dataset('path/to/coco8.zip', task='detect') # detect dataset
108 | check_dataset('path/to/coco8-seg.zip', task='segment') # segment dataset
109 | check_dataset('path/to/coco8-pose.zip', task='pose') # pose dataset
110 | ```
111 | """
112 | HUBDatasetStats(path=path, task=task).get_json()
113 | LOGGER.info('Checks completed correctly ✅. Upload this dataset to https://hub.ultralytics.com/datasets/.')
114 |
115 |
116 | if __name__ == '__main__':
117 | start()
118 |
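Taken together, these helpers cover a typical HUB export round-trip: authenticate, request a server-side export in a supported format, then fetch the resulting download URL. A minimal sketch, assuming a valid API key and an existing HUB model id (the placeholder strings below are hypothetical):

```python
from ultralytics import hub

hub.login('YOUR_API_KEY')                                        # authenticate once per session
print(hub.export_fmts_hub())                                     # formats accepted by export_model()
hub.export_model(model_id='YOUR_MODEL_ID', format='onnx')        # request a server-side export
info = hub.get_export(model_id='YOUR_MODEL_ID', format='onnx')   # dict containing the download URL
```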
--------------------------------------------------------------------------------
/ultralytics/models/README.md:
--------------------------------------------------------------------------------
1 | ## Models
2 |
3 | Welcome to the Ultralytics Models directory! Here you will find a wide variety of pre-configured model configuration
4 | files (`*.yaml`) that can be used to create custom YOLO models. The models in this directory have been expertly crafted
5 | and fine-tuned by the Ultralytics team to provide the best performance for a wide range of object detection and image
6 | segmentation tasks.
7 |
8 | These model configurations cover a wide range of scenarios, from simple object detection to more complex tasks like
9 | instance segmentation and object tracking. They are also designed to run efficiently on a variety of hardware platforms,
10 | from CPUs to GPUs. Whether you are a seasoned machine learning practitioner or just getting started with YOLO, this
11 | directory provides a great starting point for your custom model development needs.
12 |
13 | To get started, simply browse through the models in this directory and find one that best suits your needs. Once you've
14 | selected a model, you can use the provided `*.yaml` file to train and deploy your custom YOLO model with ease. See full
15 | details at the Ultralytics [Docs](https://docs.ultralytics.com/models), and if you need help or have any questions, feel free
16 | to reach out to the Ultralytics team for support. So don't wait: start creating your custom YOLO model now!
17 |
18 | ### Usage
19 |
20 | Model `*.yaml` files may be used directly in the Command Line Interface (CLI) with a `yolo` command:
21 |
22 | ```bash
23 | yolo task=detect mode=train model=yolov8n.yaml data=coco128.yaml epochs=100
24 | ```
25 |
26 | They may also be used directly in a Python environment, and they accept the same
27 | [arguments](https://docs.ultralytics.com/usage/cfg/) as in the CLI example above:
28 |
29 | ```python
30 | from ultralytics import YOLO
31 |
32 | model = YOLO("model.yaml") # build a YOLOv8n model from scratch
33 | # YOLO("model.pt") use pre-trained model if available
34 | model.info() # display model information
35 | model.train(data="coco128.yaml", epochs=100) # train the model
36 | ```
37 |
38 | ## Pre-trained Model Architectures
39 |
40 | Ultralytics supports many model architectures. Visit https://docs.ultralytics.com/models to view detailed information
41 | and usage. Any of these models can be used by loading their configs or pretrained checkpoints if available.
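
For example, the same `YOLO` class accepts either form (illustrative snippet; the `.pt` checkpoint is downloaded automatically if it is not already cached):

```python
from ultralytics import YOLO

model_from_yaml = YOLO("yolov8n.yaml")  # build the architecture from a config, randomly initialized
model_from_ckpt = YOLO("yolov8n.pt")    # load a published pretrained checkpoint
results = model_from_ckpt("https://ultralytics.com/images/bus.jpg")  # run inference on a sample image
```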
42 |
43 | ## Contributing New Models
44 |
45 | If you've developed a new model architecture or have improvements for existing models that you'd like to contribute to the Ultralytics community, please submit your contribution in a new Pull Request. For more details, visit our [Contributing Guide](https://docs.ultralytics.com/help/contributing).
46 |
--------------------------------------------------------------------------------
/ultralytics/models/rt-detr/rtdetr-l.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # RT-DETR-l object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr
3 |
4 | # Parameters
5 | nc: 80 # number of classes
6 | scales: # model compound scaling constants, i.e. 'model=rtdetr-l.yaml' will call rtdetr-l.yaml with scale 'l'
7 | # [depth, width, max_channels]
8 | l: [1.00, 1.00, 1024]
9 |
10 | backbone:
11 | # [from, repeats, module, args]
12 | - [-1, 1, HGStem, [32, 48]] # 0-P2/4
13 | - [-1, 6, HGBlock, [48, 128, 3]] # stage 1
14 |
15 | - [-1, 1, DWConv, [128, 3, 2, 1, False]] # 2-P3/8
16 | - [-1, 6, HGBlock, [96, 512, 3]] # stage 2
17 |
18 | - [-1, 1, DWConv, [512, 3, 2, 1, False]] # 4-P3/16
19 | - [-1, 6, HGBlock, [192, 1024, 5, True, False]] # cm, c2, k, light, shortcut
20 | - [-1, 6, HGBlock, [192, 1024, 5, True, True]]
21 | - [-1, 6, HGBlock, [192, 1024, 5, True, True]] # stage 3
22 |
23 | - [-1, 1, DWConv, [1024, 3, 2, 1, False]] # 8-P4/32
24 | - [-1, 6, HGBlock, [384, 2048, 5, True, False]] # stage 4
25 |
26 | head:
27 | - [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 10 input_proj.2
28 | - [-1, 1, AIFI, [1024, 8]]
29 | - [-1, 1, Conv, [256, 1, 1]] # 12, Y5, lateral_convs.0
30 |
31 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
32 | - [7, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 14 input_proj.1
33 | - [[-2, -1], 1, Concat, [1]]
34 | - [-1, 3, RepC3, [256]] # 16, fpn_blocks.0
35 | - [-1, 1, Conv, [256, 1, 1]] # 17, Y4, lateral_convs.1
36 |
37 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
38 | - [3, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 19 input_proj.0
39 | - [[-2, -1], 1, Concat, [1]] # cat backbone P4
40 | - [-1, 3, RepC3, [256]] # X3 (21), fpn_blocks.1
41 |
42 | - [-1, 1, Conv, [256, 3, 2]] # 22, downsample_convs.0
43 | - [[-1, 17], 1, Concat, [1]] # cat Y4
44 | - [-1, 3, RepC3, [256]] # F4 (24), pan_blocks.0
45 |
46 | - [-1, 1, Conv, [256, 3, 2]] # 25, downsample_convs.1
47 | - [[-1, 12], 1, Concat, [1]] # cat Y5
48 | - [-1, 3, RepC3, [256]] # F5 (27), pan_blocks.1
49 |
50 | - [[21, 24, 27], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5)
51 |
--------------------------------------------------------------------------------
/ultralytics/models/rt-detr/rtdetr-x.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # RT-DETR-x object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr
3 |
4 | # Parameters
5 | nc: 80 # number of classes
6 | scales: # model compound scaling constants, i.e. 'model=rtdetr-x.yaml' will call rtdetr-x.yaml with scale 'x'
7 | # [depth, width, max_channels]
8 | x: [1.00, 1.00, 2048]
9 |
10 | backbone:
11 | # [from, repeats, module, args]
12 | - [-1, 1, HGStem, [32, 64]] # 0-P2/4
13 | - [-1, 6, HGBlock, [64, 128, 3]] # stage 1
14 |
15 | - [-1, 1, DWConv, [128, 3, 2, 1, False]] # 2-P3/8
16 | - [-1, 6, HGBlock, [128, 512, 3]]
17 | - [-1, 6, HGBlock, [128, 512, 3, False, True]] # 4-stage 2
18 |
19 | - [-1, 1, DWConv, [512, 3, 2, 1, False]] # 5-P3/16
20 | - [-1, 6, HGBlock, [256, 1024, 5, True, False]] # cm, c2, k, light, shortcut
21 | - [-1, 6, HGBlock, [256, 1024, 5, True, True]]
22 | - [-1, 6, HGBlock, [256, 1024, 5, True, True]]
23 | - [-1, 6, HGBlock, [256, 1024, 5, True, True]]
24 | - [-1, 6, HGBlock, [256, 1024, 5, True, True]] # 10-stage 3
25 |
26 | - [-1, 1, DWConv, [1024, 3, 2, 1, False]] # 11-P4/32
27 | - [-1, 6, HGBlock, [512, 2048, 5, True, False]]
28 | - [-1, 6, HGBlock, [512, 2048, 5, True, True]] # 13-stage 4
29 |
30 | head:
31 | - [-1, 1, Conv, [384, 1, 1, None, 1, 1, False]] # 14 input_proj.2
32 | - [-1, 1, AIFI, [2048, 8]]
33 | - [-1, 1, Conv, [384, 1, 1]] # 16, Y5, lateral_convs.0
34 |
35 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
36 | - [10, 1, Conv, [384, 1, 1, None, 1, 1, False]] # 18 input_proj.1
37 | - [[-2, -1], 1, Concat, [1]]
38 | - [-1, 3, RepC3, [384]] # 20, fpn_blocks.0
39 | - [-1, 1, Conv, [384, 1, 1]] # 21, Y4, lateral_convs.1
40 |
41 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
42 | - [4, 1, Conv, [384, 1, 1, None, 1, 1, False]] # 23 input_proj.0
43 | - [[-2, -1], 1, Concat, [1]] # cat backbone P4
44 | - [-1, 3, RepC3, [384]] # X3 (25), fpn_blocks.1
45 |
46 | - [-1, 1, Conv, [384, 3, 2]] # 26, downsample_convs.0
47 | - [[-1, 21], 1, Concat, [1]] # cat Y4
48 | - [-1, 3, RepC3, [384]] # F4 (28), pan_blocks.0
49 |
50 | - [-1, 1, Conv, [384, 3, 2]] # 29, downsample_convs.1
51 | - [[-1, 16], 1, Concat, [1]] # cat Y5
52 | - [-1, 3, RepC3, [384]] # F5 (31), pan_blocks.1
53 |
54 | - [[25, 28, 31], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5)
55 |
--------------------------------------------------------------------------------
/ultralytics/models/v3/yolov3-spp.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # YOLOv3-SPP object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/yolov3
3 |
4 | # Parameters
5 | nc: 80 # number of classes
6 | depth_multiple: 1.0 # model depth multiple
7 | width_multiple: 1.0 # layer channel multiple
8 |
9 | # darknet53 backbone
10 | backbone:
11 | # [from, number, module, args]
12 | [[-1, 1, Conv, [32, 3, 1]], # 0
13 | [-1, 1, Conv, [64, 3, 2]], # 1-P1/2
14 | [-1, 1, Bottleneck, [64]],
15 | [-1, 1, Conv, [128, 3, 2]], # 3-P2/4
16 | [-1, 2, Bottleneck, [128]],
17 | [-1, 1, Conv, [256, 3, 2]], # 5-P3/8
18 | [-1, 8, Bottleneck, [256]],
19 | [-1, 1, Conv, [512, 3, 2]], # 7-P4/16
20 | [-1, 8, Bottleneck, [512]],
21 | [-1, 1, Conv, [1024, 3, 2]], # 9-P5/32
22 | [-1, 4, Bottleneck, [1024]], # 10
23 | ]
24 |
25 | # YOLOv3-SPP head
26 | head:
27 | [[-1, 1, Bottleneck, [1024, False]],
28 | [-1, 1, SPP, [512, [5, 9, 13]]],
29 | [-1, 1, Conv, [1024, 3, 1]],
30 | [-1, 1, Conv, [512, 1, 1]],
31 | [-1, 1, Conv, [1024, 3, 1]], # 15 (P5/32-large)
32 |
33 | [-2, 1, Conv, [256, 1, 1]],
34 | [-1, 1, nn.Upsample, [None, 2, 'nearest']],
35 | [[-1, 8], 1, Concat, [1]], # cat backbone P4
36 | [-1, 1, Bottleneck, [512, False]],
37 | [-1, 1, Bottleneck, [512, False]],
38 | [-1, 1, Conv, [256, 1, 1]],
39 | [-1, 1, Conv, [512, 3, 1]], # 22 (P4/16-medium)
40 |
41 | [-2, 1, Conv, [128, 1, 1]],
42 | [-1, 1, nn.Upsample, [None, 2, 'nearest']],
43 | [[-1, 6], 1, Concat, [1]], # cat backbone P3
44 | [-1, 1, Bottleneck, [256, False]],
45 | [-1, 2, Bottleneck, [256, False]], # 27 (P3/8-small)
46 |
47 | [[27, 22, 15], 1, Detect, [nc]], # Detect(P3, P4, P5)
48 | ]
49 |
--------------------------------------------------------------------------------
/ultralytics/models/v3/yolov3-tiny.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # YOLOv3-tiny object detection model with P4-P5 outputs. For details see https://docs.ultralytics.com/models/yolov3
3 |
4 | # Parameters
5 | nc: 80 # number of classes
6 | depth_multiple: 1.0 # model depth multiple
7 | width_multiple: 1.0 # layer channel multiple
8 |
9 | # YOLOv3-tiny backbone
10 | backbone:
11 | # [from, number, module, args]
12 | [[-1, 1, Conv, [16, 3, 1]], # 0
13 | [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 1-P1/2
14 | [-1, 1, Conv, [32, 3, 1]],
15 | [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 3-P2/4
16 | [-1, 1, Conv, [64, 3, 1]],
17 | [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 5-P3/8
18 | [-1, 1, Conv, [128, 3, 1]],
19 | [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 7-P4/16
20 | [-1, 1, Conv, [256, 3, 1]],
21 | [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 9-P5/32
22 | [-1, 1, Conv, [512, 3, 1]],
23 | [-1, 1, nn.ZeroPad2d, [[0, 1, 0, 1]]], # 11
24 | [-1, 1, nn.MaxPool2d, [2, 1, 0]], # 12
25 | ]
26 |
27 | # YOLOv3-tiny head
28 | head:
29 | [[-1, 1, Conv, [1024, 3, 1]],
30 | [-1, 1, Conv, [256, 1, 1]],
31 | [-1, 1, Conv, [512, 3, 1]], # 15 (P5/32-large)
32 |
33 | [-2, 1, Conv, [128, 1, 1]],
34 | [-1, 1, nn.Upsample, [None, 2, 'nearest']],
35 | [[-1, 8], 1, Concat, [1]], # cat backbone P4
36 | [-1, 1, Conv, [256, 3, 1]], # 19 (P4/16-medium)
37 |
38 | [[19, 15], 1, Detect, [nc]], # Detect(P4, P5)
39 | ]
40 |
--------------------------------------------------------------------------------
/ultralytics/models/v3/yolov3.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # YOLOv3 object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/yolov3
3 |
4 | # Parameters
5 | nc: 80 # number of classes
6 | depth_multiple: 1.0 # model depth multiple
7 | width_multiple: 1.0 # layer channel multiple
8 |
9 | # darknet53 backbone
10 | backbone:
11 | # [from, number, module, args]
12 | [[-1, 1, Conv, [32, 3, 1]], # 0
13 | [-1, 1, Conv, [64, 3, 2]], # 1-P1/2
14 | [-1, 1, Bottleneck, [64]],
15 | [-1, 1, Conv, [128, 3, 2]], # 3-P2/4
16 | [-1, 2, Bottleneck, [128]],
17 | [-1, 1, Conv, [256, 3, 2]], # 5-P3/8
18 | [-1, 8, Bottleneck, [256]],
19 | [-1, 1, Conv, [512, 3, 2]], # 7-P4/16
20 | [-1, 8, Bottleneck, [512]],
21 | [-1, 1, Conv, [1024, 3, 2]], # 9-P5/32
22 | [-1, 4, Bottleneck, [1024]], # 10
23 | ]
24 |
25 | # YOLOv3 head
26 | head:
27 | [[-1, 1, Bottleneck, [1024, False]],
28 | [-1, 1, Conv, [512, 1, 1]],
29 | [-1, 1, Conv, [1024, 3, 1]],
30 | [-1, 1, Conv, [512, 1, 1]],
31 | [-1, 1, Conv, [1024, 3, 1]], # 15 (P5/32-large)
32 |
33 | [-2, 1, Conv, [256, 1, 1]],
34 | [-1, 1, nn.Upsample, [None, 2, 'nearest']],
35 | [[-1, 8], 1, Concat, [1]], # cat backbone P4
36 | [-1, 1, Bottleneck, [512, False]],
37 | [-1, 1, Bottleneck, [512, False]],
38 | [-1, 1, Conv, [256, 1, 1]],
39 | [-1, 1, Conv, [512, 3, 1]], # 22 (P4/16-medium)
40 |
41 | [-2, 1, Conv, [128, 1, 1]],
42 | [-1, 1, nn.Upsample, [None, 2, 'nearest']],
43 | [[-1, 6], 1, Concat, [1]], # cat backbone P3
44 | [-1, 1, Bottleneck, [256, False]],
45 | [-1, 2, Bottleneck, [256, False]], # 27 (P3/8-small)
46 |
47 | [[27, 22, 15], 1, Detect, [nc]], # Detect(P3, P4, P5)
48 | ]
49 |
--------------------------------------------------------------------------------
/ultralytics/models/v5/yolov5-p6.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # YOLOv5 object detection model with P3-P6 outputs. For details see https://docs.ultralytics.com/models/yolov5
3 |
4 | # Parameters
5 | nc: 80 # number of classes
6 | scales: # model compound scaling constants, i.e. 'model=yolov5n-p6.yaml' will call yolov5-p6.yaml with scale 'n'
7 | # [depth, width, max_channels]
8 | n: [0.33, 0.25, 1024]
9 | s: [0.33, 0.50, 1024]
10 | m: [0.67, 0.75, 1024]
11 | l: [1.00, 1.00, 1024]
12 | x: [1.33, 1.25, 1024]
13 |
14 | # YOLOv5 v6.0 backbone
15 | backbone:
16 | # [from, number, module, args]
17 | [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
18 | [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
19 | [-1, 3, C3, [128]],
20 | [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
21 | [-1, 6, C3, [256]],
22 | [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
23 | [-1, 9, C3, [512]],
24 | [-1, 1, Conv, [768, 3, 2]], # 7-P5/32
25 | [-1, 3, C3, [768]],
26 | [-1, 1, Conv, [1024, 3, 2]], # 9-P6/64
27 | [-1, 3, C3, [1024]],
28 | [-1, 1, SPPF, [1024, 5]], # 11
29 | ]
30 |
31 | # YOLOv5 v6.0 head
32 | head:
33 | [[-1, 1, Conv, [768, 1, 1]],
34 | [-1, 1, nn.Upsample, [None, 2, 'nearest']],
35 | [[-1, 8], 1, Concat, [1]], # cat backbone P5
36 | [-1, 3, C3, [768, False]], # 15
37 |
38 | [-1, 1, Conv, [512, 1, 1]],
39 | [-1, 1, nn.Upsample, [None, 2, 'nearest']],
40 | [[-1, 6], 1, Concat, [1]], # cat backbone P4
41 | [-1, 3, C3, [512, False]], # 19
42 |
43 | [-1, 1, Conv, [256, 1, 1]],
44 | [-1, 1, nn.Upsample, [None, 2, 'nearest']],
45 | [[-1, 4], 1, Concat, [1]], # cat backbone P3
46 | [-1, 3, C3, [256, False]], # 23 (P3/8-small)
47 |
48 | [-1, 1, Conv, [256, 3, 2]],
49 | [[-1, 20], 1, Concat, [1]], # cat head P4
50 | [-1, 3, C3, [512, False]], # 26 (P4/16-medium)
51 |
52 | [-1, 1, Conv, [512, 3, 2]],
53 | [[-1, 16], 1, Concat, [1]], # cat head P5
54 | [-1, 3, C3, [768, False]], # 29 (P5/32-large)
55 |
56 | [-1, 1, Conv, [768, 3, 2]],
57 | [[-1, 12], 1, Concat, [1]], # cat head P6
58 | [-1, 3, C3, [1024, False]], # 32 (P6/64-xlarge)
59 |
60 | [[23, 26, 29, 32], 1, Detect, [nc]], # Detect(P3, P4, P5, P6)
61 | ]
62 |
--------------------------------------------------------------------------------
/ultralytics/models/v5/yolov5.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # YOLOv5 object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/yolov5
3 |
4 | # Parameters
5 | nc: 80 # number of classes
6 | scales: # model compound scaling constants, i.e. 'model=yolov5n.yaml' will call yolov5.yaml with scale 'n'
7 | # [depth, width, max_channels]
8 | n: [0.33, 0.25, 1024]
9 | s: [0.33, 0.50, 1024]
10 | m: [0.67, 0.75, 1024]
11 | l: [1.00, 1.00, 1024]
12 | x: [1.33, 1.25, 1024]
13 |
14 | # YOLOv5 v6.0 backbone
15 | backbone:
16 | # [from, number, module, args]
17 | [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
18 | [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
19 | [-1, 3, C3, [128]],
20 | [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
21 | [-1, 6, C3, [256]],
22 | [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
23 | [-1, 9, C3, [512]],
24 | [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
25 | [-1, 3, C3, [1024]],
26 | [-1, 1, SPPF, [1024, 5]], # 9
27 | ]
28 |
29 | # YOLOv5 v6.0 head
30 | head:
31 | [[-1, 1, Conv, [512, 1, 1]],
32 | [-1, 1, nn.Upsample, [None, 2, 'nearest']],
33 | [[-1, 6], 1, Concat, [1]], # cat backbone P4
34 | [-1, 3, C3, [512, False]], # 13
35 |
36 | [-1, 1, Conv, [256, 1, 1]],
37 | [-1, 1, nn.Upsample, [None, 2, 'nearest']],
38 | [[-1, 4], 1, Concat, [1]], # cat backbone P3
39 | [-1, 3, C3, [256, False]], # 17 (P3/8-small)
40 |
41 | [-1, 1, Conv, [256, 3, 2]],
42 | [[-1, 14], 1, Concat, [1]], # cat head P4
43 | [-1, 3, C3, [512, False]], # 20 (P4/16-medium)
44 |
45 | [-1, 1, Conv, [512, 3, 2]],
46 | [[-1, 10], 1, Concat, [1]], # cat head P5
47 | [-1, 3, C3, [1024, False]], # 23 (P5/32-large)
48 |
49 | [[17, 20, 23], 1, Detect, [nc]], # Detect(P3, P4, P5)
50 | ]
51 |
--------------------------------------------------------------------------------
/ultralytics/models/v6/yolov6.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # YOLOv6 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/models/yolov6
3 |
4 | # Parameters
5 | nc: 80 # number of classes
6 | activation: nn.ReLU() # (optional) model default activation function
7 | scales: # model compound scaling constants, i.e. 'model=yolov6n.yaml' will call yolov6.yaml with scale 'n'
8 | # [depth, width, max_channels]
9 | n: [0.33, 0.25, 1024]
10 | s: [0.33, 0.50, 1024]
11 | m: [0.67, 0.75, 768]
12 | l: [1.00, 1.00, 512]
13 | x: [1.00, 1.25, 512]
14 |
15 | # YOLOv6-3.0s backbone
16 | backbone:
17 | # [from, repeats, module, args]
18 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
19 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
20 | - [-1, 6, Conv, [128, 3, 1]]
21 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
22 | - [-1, 12, Conv, [256, 3, 1]]
23 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
24 | - [-1, 18, Conv, [512, 3, 1]]
25 | - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
26 | - [-1, 6, Conv, [1024, 3, 1]]
27 | - [-1, 1, SPPF, [1024, 5]] # 9
28 |
29 | # YOLOv6-3.0s head
30 | head:
31 | - [-1, 1, Conv, [256, 1, 1]]
32 | - [-1, 1, nn.ConvTranspose2d, [256, 2, 2, 0]]
33 | - [[-1, 6], 1, Concat, [1]] # cat backbone P4
34 | - [-1, 1, Conv, [256, 3, 1]]
35 | - [-1, 9, Conv, [256, 3, 1]] # 14
36 |
37 | - [-1, 1, Conv, [128, 1, 1]]
38 | - [-1, 1, nn.ConvTranspose2d, [128, 2, 2, 0]]
39 | - [[-1, 4], 1, Concat, [1]] # cat backbone P3
40 | - [-1, 1, Conv, [128, 3, 1]]
41 | - [-1, 9, Conv, [128, 3, 1]] # 19
42 |
43 | - [-1, 1, Conv, [128, 3, 2]]
44 | - [[-1, 15], 1, Concat, [1]] # cat head P4
45 | - [-1, 1, Conv, [256, 3, 1]]
46 | - [-1, 9, Conv, [256, 3, 1]] # 23
47 |
48 | - [-1, 1, Conv, [256, 3, 2]]
49 | - [[-1, 10], 1, Concat, [1]] # cat head P5
50 | - [-1, 1, Conv, [512, 3, 1]]
51 | - [-1, 9, Conv, [512, 3, 1]] # 27
52 |
53 | - [[19, 23, 27], 1, Detect, [nc]] # Detect(P3, P4, P5)
54 |
--------------------------------------------------------------------------------
/ultralytics/models/v8/yolov8-cls.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # YOLOv8-cls image classification model. For Usage examples see https://docs.ultralytics.com/tasks/classify
3 |
4 | # Parameters
5 | nc: 1000 # number of classes
6 | scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'
7 | # [depth, width, max_channels]
8 | n: [0.33, 0.25, 1024]
9 | s: [0.33, 0.50, 1024]
10 | m: [0.67, 0.75, 1024]
11 | l: [1.00, 1.00, 1024]
12 | x: [1.00, 1.25, 1024]
13 |
14 | # YOLOv8.0n backbone
15 | backbone:
16 | # [from, repeats, module, args]
17 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
18 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
19 | - [-1, 3, C2f, [128, True]]
20 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
21 | - [-1, 6, C2f, [256, True]]
22 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
23 | - [-1, 6, C2f, [512, True]]
24 | - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
25 | - [-1, 3, C2f, [1024, True]]
26 |
27 | # YOLOv8.0n head
28 | head:
29 | - [-1, 1, Classify, [nc]] # Classify
30 |
--------------------------------------------------------------------------------
/ultralytics/models/v8/yolov8-p2.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # YOLOv8 object detection model with P2-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
3 |
4 | # Parameters
5 | nc: 80 # number of classes
6 | scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
7 | # [depth, width, max_channels]
8 | n: [0.33, 0.25, 1024]
9 | s: [0.33, 0.50, 1024]
10 | m: [0.67, 0.75, 768]
11 | l: [1.00, 1.00, 512]
12 | x: [1.00, 1.25, 512]
13 |
14 | # YOLOv8.0 backbone
15 | backbone:
16 | # [from, repeats, module, args]
17 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
18 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
19 | - [-1, 3, C2f, [128, True]]
20 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
21 | - [-1, 6, C2f, [256, True]]
22 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
23 | - [-1, 6, C2f, [512, True]]
24 | - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
25 | - [-1, 3, C2f, [1024, True]]
26 | - [-1, 1, SPPF, [1024, 5]] # 9
27 |
28 | # YOLOv8.0-p2 head
29 | head:
30 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
31 | - [[-1, 6], 1, Concat, [1]] # cat backbone P4
32 | - [-1, 3, C2f, [512]] # 12
33 |
34 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
35 | - [[-1, 4], 1, Concat, [1]] # cat backbone P3
36 | - [-1, 3, C2f, [256]] # 15 (P3/8-small)
37 |
38 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
39 | - [[-1, 2], 1, Concat, [1]] # cat backbone P2
40 | - [-1, 3, C2f, [128]] # 18 (P2/4-xsmall)
41 |
42 | - [-1, 1, Conv, [128, 3, 2]]
43 | - [[-1, 15], 1, Concat, [1]] # cat head P3
44 | - [-1, 3, C2f, [256]] # 21 (P3/8-small)
45 |
46 | - [-1, 1, Conv, [256, 3, 2]]
47 | - [[-1, 12], 1, Concat, [1]] # cat head P4
48 | - [-1, 3, C2f, [512]] # 24 (P4/16-medium)
49 |
50 | - [-1, 1, Conv, [512, 3, 2]]
51 | - [[-1, 9], 1, Concat, [1]] # cat head P5
52 | - [-1, 3, C2f, [1024]] # 27 (P5/32-large)
53 |
54 | - [[18, 21, 24, 27], 1, Detect, [nc]] # Detect(P2, P3, P4, P5)
55 |
--------------------------------------------------------------------------------
/ultralytics/models/v8/yolov8-p6.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # YOLOv8 object detection model with P3-P6 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
3 |
4 | # Parameters
5 | nc: 80 # number of classes
6 | scales: # model compound scaling constants, i.e. 'model=yolov8n-p6.yaml' will call yolov8-p6.yaml with scale 'n'
7 | # [depth, width, max_channels]
8 | n: [0.33, 0.25, 1024]
9 | s: [0.33, 0.50, 1024]
10 | m: [0.67, 0.75, 768]
11 | l: [1.00, 1.00, 512]
12 | x: [1.00, 1.25, 512]
13 |
14 | # YOLOv8.0x6 backbone
15 | backbone:
16 | # [from, repeats, module, args]
17 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
18 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
19 | - [-1, 3, C2f, [128, True]]
20 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
21 | - [-1, 6, C2f, [256, True]]
22 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
23 | - [-1, 6, C2f, [512, True]]
24 | - [-1, 1, Conv, [768, 3, 2]] # 7-P5/32
25 | - [-1, 3, C2f, [768, True]]
26 | - [-1, 1, Conv, [1024, 3, 2]] # 9-P6/64
27 | - [-1, 3, C2f, [1024, True]]
28 | - [-1, 1, SPPF, [1024, 5]] # 11
29 |
30 | # YOLOv8.0x6 head
31 | head:
32 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
33 | - [[-1, 8], 1, Concat, [1]] # cat backbone P5
34 | - [-1, 3, C2, [768, False]] # 14
35 |
36 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
37 | - [[-1, 6], 1, Concat, [1]] # cat backbone P4
38 | - [-1, 3, C2, [512, False]] # 17
39 |
40 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
41 | - [[-1, 4], 1, Concat, [1]] # cat backbone P3
42 | - [-1, 3, C2, [256, False]] # 20 (P3/8-small)
43 |
44 | - [-1, 1, Conv, [256, 3, 2]]
45 | - [[-1, 17], 1, Concat, [1]] # cat head P4
46 | - [-1, 3, C2, [512, False]] # 23 (P4/16-medium)
47 |
48 | - [-1, 1, Conv, [512, 3, 2]]
49 | - [[-1, 14], 1, Concat, [1]] # cat head P5
50 | - [-1, 3, C2, [768, False]] # 26 (P5/32-large)
51 |
52 | - [-1, 1, Conv, [768, 3, 2]]
53 | - [[-1, 11], 1, Concat, [1]] # cat head P6
54 | - [-1, 3, C2, [1024, False]] # 29 (P6/64-xlarge)
55 |
56 | - [[20, 23, 26, 29], 1, Detect, [nc]] # Detect(P3, P4, P5, P6)
57 |
--------------------------------------------------------------------------------
/ultralytics/models/v8/yolov8-pose-p6.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # YOLOv8-pose keypoints/pose estimation model. For Usage examples see https://docs.ultralytics.com/tasks/pose
3 |
4 | # Parameters
5 | nc: 1 # number of classes
6 | kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
7 | scales: # model compound scaling constants, i.e. 'model=yolov8n-pose-p6.yaml' will call yolov8-pose-p6.yaml with scale 'n'
8 | # [depth, width, max_channels]
9 | n: [0.33, 0.25, 1024]
10 | s: [0.33, 0.50, 1024]
11 | m: [0.67, 0.75, 768]
12 | l: [1.00, 1.00, 512]
13 | x: [1.00, 1.25, 512]
14 |
15 | # YOLOv8.0x6 backbone
16 | backbone:
17 | # [from, repeats, module, args]
18 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
19 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
20 | - [-1, 3, C2f, [128, True]]
21 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
22 | - [-1, 6, C2f, [256, True]]
23 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
24 | - [-1, 6, C2f, [512, True]]
25 | - [-1, 1, Conv, [768, 3, 2]] # 7-P5/32
26 | - [-1, 3, C2f, [768, True]]
27 | - [-1, 1, Conv, [1024, 3, 2]] # 9-P6/64
28 | - [-1, 3, C2f, [1024, True]]
29 | - [-1, 1, SPPF, [1024, 5]] # 11
30 |
31 | # YOLOv8.0x6 head
32 | head:
33 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
34 | - [[-1, 8], 1, Concat, [1]] # cat backbone P5
35 | - [-1, 3, C2, [768, False]] # 14
36 |
37 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
38 | - [[-1, 6], 1, Concat, [1]] # cat backbone P4
39 | - [-1, 3, C2, [512, False]] # 17
40 |
41 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
42 | - [[-1, 4], 1, Concat, [1]] # cat backbone P3
43 | - [-1, 3, C2, [256, False]] # 20 (P3/8-small)
44 |
45 | - [-1, 1, Conv, [256, 3, 2]]
46 | - [[-1, 17], 1, Concat, [1]] # cat head P4
47 | - [-1, 3, C2, [512, False]] # 23 (P4/16-medium)
48 |
49 | - [-1, 1, Conv, [512, 3, 2]]
50 | - [[-1, 14], 1, Concat, [1]] # cat head P5
51 | - [-1, 3, C2, [768, False]] # 26 (P5/32-large)
52 |
53 | - [-1, 1, Conv, [768, 3, 2]]
54 | - [[-1, 11], 1, Concat, [1]] # cat head P6
55 | - [-1, 3, C2, [1024, False]] # 29 (P6/64-xlarge)
56 |
57 | - [[20, 23, 26, 29], 1, Pose, [nc, kpt_shape]] # Pose(P3, P4, P5, P6)
58 |
--------------------------------------------------------------------------------
/ultralytics/models/v8/yolov8-pose.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # YOLOv8-pose keypoints/pose estimation model. For Usage examples see https://docs.ultralytics.com/tasks/pose
3 |
4 | # Parameters
5 | nc: 1 # number of classes
6 | kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
7 | scales: # model compound scaling constants, i.e. 'model=yolov8n-pose.yaml' will call yolov8-pose.yaml with scale 'n'
8 | # [depth, width, max_channels]
9 | n: [0.33, 0.25, 1024]
10 | s: [0.33, 0.50, 1024]
11 | m: [0.67, 0.75, 768]
12 | l: [1.00, 1.00, 512]
13 | x: [1.00, 1.25, 512]
14 |
15 | # YOLOv8.0n backbone
16 | backbone:
17 | # [from, repeats, module, args]
18 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
19 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
20 | - [-1, 3, C2f, [128, True]]
21 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
22 | - [-1, 6, C2f, [256, True]]
23 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
24 | - [-1, 6, C2f, [512, True]]
25 | - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
26 | - [-1, 3, C2f, [1024, True]]
27 | - [-1, 1, SPPF, [1024, 5]] # 9
28 |
29 | # YOLOv8.0n head
30 | head:
31 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
32 | - [[-1, 6], 1, Concat, [1]] # cat backbone P4
33 | - [-1, 3, C2f, [512]] # 12
34 |
35 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
36 | - [[-1, 4], 1, Concat, [1]] # cat backbone P3
37 | - [-1, 3, C2f, [256]] # 15 (P3/8-small)
38 |
39 | - [-1, 1, Conv, [256, 3, 2]]
40 | - [[-1, 12], 1, Concat, [1]] # cat head P4
41 | - [-1, 3, C2f, [512]] # 18 (P4/16-medium)
42 |
43 | - [-1, 1, Conv, [512, 3, 2]]
44 | - [[-1, 9], 1, Concat, [1]] # cat head P5
45 | - [-1, 3, C2f, [1024]] # 21 (P5/32-large)
46 |
47 | - [[15, 18, 21], 1, Pose, [nc, kpt_shape]] # Pose(P3, P4, P5)
48 |
--------------------------------------------------------------------------------
/ultralytics/models/v8/yolov8-rtdetr.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # YOLOv8 object detection model with P3-P5 outputs and an RT-DETR decoder head. For Usage examples see https://docs.ultralytics.com/tasks/detect
3 |
4 | # Parameters
5 | nc: 80 # number of classes
6 | scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
7 | # [depth, width, max_channels]
8 | n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs
9 | s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs
10 | m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs
11 | l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
12 | x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
13 |
14 | # YOLOv8.0n backbone
15 | backbone:
16 | # [from, repeats, module, args]
17 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
18 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
19 | - [-1, 3, C2f, [128, True]]
20 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
21 | - [-1, 6, C2f, [256, True]]
22 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
23 | - [-1, 6, C2f, [512, True]]
24 | - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
25 | - [-1, 3, C2f, [1024, True]]
26 | - [-1, 1, SPPF, [1024, 5]] # 9
27 |
28 | # YOLOv8.0n head
29 | head:
30 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
31 | - [[-1, 6], 1, Concat, [1]] # cat backbone P4
32 | - [-1, 3, C2f, [512]] # 12
33 |
34 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
35 | - [[-1, 4], 1, Concat, [1]] # cat backbone P3
36 | - [-1, 3, C2f, [256]] # 15 (P3/8-small)
37 |
38 | - [-1, 1, Conv, [256, 3, 2]]
39 | - [[-1, 12], 1, Concat, [1]] # cat head P4
40 | - [-1, 3, C2f, [512]] # 18 (P4/16-medium)
41 |
42 | - [-1, 1, Conv, [512, 3, 2]]
43 | - [[-1, 9], 1, Concat, [1]] # cat head P5
44 | - [-1, 3, C2f, [1024]] # 21 (P5/32-large)
45 |
46 | - [[15, 18, 21], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5)
47 |
--------------------------------------------------------------------------------
/ultralytics/models/v8/yolov8-seg.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # YOLOv8-seg instance segmentation model. For Usage examples see https://docs.ultralytics.com/tasks/segment
3 |
4 | # Parameters
5 | nc: 80 # number of classes
6 | scales: # model compound scaling constants, i.e. 'model=yolov8n-seg.yaml' will call yolov8-seg.yaml with scale 'n'
7 | # [depth, width, max_channels]
8 | n: [0.33, 0.25, 1024]
9 | s: [0.33, 0.50, 1024]
10 | m: [0.67, 0.75, 768]
11 | l: [1.00, 1.00, 512]
12 | x: [1.00, 1.25, 512]
13 |
14 | # YOLOv8.0n backbone
15 | backbone:
16 | # [from, repeats, module, args]
17 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
18 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
19 | - [-1, 3, C2f, [128, True]]
20 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
21 | - [-1, 6, C2f, [256, True]]
22 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
23 | - [-1, 6, C2f, [512, True]]
24 | - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
25 | - [-1, 3, C2f, [1024, True]]
26 | - [-1, 1, SPPF, [1024, 5]] # 9
27 |
28 | # YOLOv8.0n head
29 | head:
30 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
31 | - [[-1, 6], 1, Concat, [1]] # cat backbone P4
32 | - [-1, 3, C2f, [512]] # 12
33 |
34 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
35 | - [[-1, 4], 1, Concat, [1]] # cat backbone P3
36 | - [-1, 3, C2f, [256]] # 15 (P3/8-small)
37 |
38 | - [-1, 1, Conv, [256, 3, 2]]
39 | - [[-1, 12], 1, Concat, [1]] # cat head P4
40 | - [-1, 3, C2f, [512]] # 18 (P4/16-medium)
41 |
42 | - [-1, 1, Conv, [512, 3, 2]]
43 | - [[-1, 9], 1, Concat, [1]] # cat head P5
44 | - [-1, 3, C2f, [1024]] # 21 (P5/32-large)
45 |
46 | - [[15, 18, 21], 1, Segment, [nc, 32, 256]] # Segment(P3, P4, P5)
47 |
--------------------------------------------------------------------------------
/ultralytics/models/v8/yolov8.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
3 |
4 | # Parameters
5 | nc: 80 # number of classes
6 | scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
7 | # [depth, width, max_channels]
8 | n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs
9 | s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs
10 | m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs
11 | l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
12 | x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
13 |
14 | # YOLOv8.0n backbone
15 | backbone:
16 | # [from, repeats, module, args]
17 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
18 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
19 | - [-1, 3, C2f, [128, True]]
20 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
21 | - [-1, 6, C2f, [256, True]]
22 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
23 | - [-1, 6, C2f, [512, True]]
24 | - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
25 | - [-1, 3, C2f, [1024, True]]
26 | - [-1, 1, SPPF, [1024, 5]] # 9
27 |
28 | # YOLOv8.0n head
29 | head:
30 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
31 | - [[-1, 6], 1, Concat, [1]] # cat backbone P4
32 | - [-1, 3, C2f, [512]] # 12
33 |
34 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
35 | - [[-1, 4], 1, Concat, [1]] # cat backbone P3
36 | - [-1, 3, C2f, [256]] # 15 (P3/8-small)
37 |
38 | - [-1, 1, Conv, [256, 3, 2]]
39 | - [[-1, 12], 1, Concat, [1]] # cat head P4
40 | - [-1, 3, C2f, [512]] # 18 (P4/16-medium)
41 |
42 | - [-1, 1, Conv, [512, 3, 2]]
43 | - [[-1, 9], 1, Concat, [1]] # cat head P5
44 | - [-1, 3, C2f, [1024]] # 21 (P5/32-large)
45 |
46 | - [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5)
47 |
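The per-scale constants above are applied when a model is built from this config: each entry's `repeats` count is multiplied by the depth multiple (rounded, minimum 1), and its output channels are multiplied by the width multiple, capped at `max_channels` and kept a multiple of 8. A rough sketch of that scaling rule (simplified; `scale_layer` is an illustrative helper, not the actual parser):

```python
import math

def scale_layer(repeats, channels, depth, width, max_channels):
    n = max(round(repeats * depth), 1) if repeats > 1 else repeats  # scaled repeat count
    c = math.ceil(min(channels, max_channels) * width / 8) * 8      # scaled channels, kept a multiple of 8
    return n, c

# 'n' scale [0.33, 0.25, 1024] applied to the backbone entry `- [-1, 6, C2f, [256, True]]`:
print(scale_layer(6, 256, 0.33, 0.25, 1024))  # -> (2, 64)
```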
--------------------------------------------------------------------------------
/ultralytics/nn/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from .tasks import (BaseModel, ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight,
4 | attempt_load_weights, guess_model_scale, guess_model_task, parse_model, torch_safe_load,
5 | yaml_model_load)
6 |
7 | __all__ = ('attempt_load_one_weight', 'attempt_load_weights', 'parse_model', 'yaml_model_load', 'guess_model_task',
8 | 'guess_model_scale', 'torch_safe_load', 'DetectionModel', 'SegmentationModel', 'ClassificationModel',
9 | 'BaseModel')
10 |
--------------------------------------------------------------------------------
/ultralytics/nn/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | """
3 | Ultralytics modules. Visualize with:
4 |
5 | from ultralytics.nn.modules import *
6 | import torch
7 | import os
8 |
9 | x = torch.ones(1, 128, 40, 40)
10 | m = Conv(128, 128)
11 | f = f'{m._get_name()}.onnx'
12 | torch.onnx.export(m, x, f)
13 | os.system(f'onnxsim {f} {f} && open {f}')
14 | """
15 |
16 | from .block import (C1, C2, C3, C3TR, DFL, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, GhostBottleneck,
17 | HGBlock, HGStem, Proto, RepC3)
18 | from .conv import (CBAM, ChannelAttention, Concat, Conv, Conv2, ConvTranspose, DWConv, DWConvTranspose2d, Focus,
19 | GhostConv, LightConv, RepConv, SpatialAttention)
20 | from .head import Classify, Detect, Pose, RTDETRDecoder, Segment
21 | from .transformer import (AIFI, MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer, LayerNorm2d,
22 | MLPBlock, MSDeformAttn, TransformerBlock, TransformerEncoderLayer, TransformerLayer)
23 |
24 | __all__ = ('Conv', 'Conv2', 'LightConv', 'RepConv', 'DWConv', 'DWConvTranspose2d', 'ConvTranspose', 'Focus',
25 | 'GhostConv', 'ChannelAttention', 'SpatialAttention', 'CBAM', 'Concat', 'TransformerLayer',
26 | 'TransformerBlock', 'MLPBlock', 'LayerNorm2d', 'DFL', 'HGBlock', 'HGStem', 'SPP', 'SPPF', 'C1', 'C2', 'C3',
27 | 'C2f', 'C3x', 'C3TR', 'C3Ghost', 'GhostBottleneck', 'Bottleneck', 'BottleneckCSP', 'Proto', 'Detect',
28 | 'Segment', 'Pose', 'Classify', 'TransformerEncoderLayer', 'RepC3', 'RTDETRDecoder', 'AIFI',
29 | 'DeformableTransformerDecoder', 'DeformableTransformerDecoderLayer', 'MSDeformAttn', 'MLP')
30 |
--------------------------------------------------------------------------------
/ultralytics/nn/modules/utils.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | """
3 | Module utils
4 | """
5 |
6 | import copy
7 | import math
8 |
9 | import numpy as np
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | from torch.nn.init import uniform_
14 |
15 | __all__ = 'multi_scale_deformable_attn_pytorch', 'inverse_sigmoid'
16 |
17 |
18 | def _get_clones(module, n):
19 | return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])
20 |
21 |
22 | def bias_init_with_prob(prior_prob=0.01):
23 | """initialize conv/fc bias value according to a given probability value."""
24 | return float(-np.log((1 - prior_prob) / prior_prob)) # return bias_init
25 |
26 |
27 | def linear_init_(module):
28 | bound = 1 / math.sqrt(module.weight.shape[0])
29 | uniform_(module.weight, -bound, bound)
30 | if hasattr(module, 'bias') and module.bias is not None:
31 | uniform_(module.bias, -bound, bound)
32 |
33 |
34 | def inverse_sigmoid(x, eps=1e-5):
35 | x = x.clamp(min=0, max=1)
36 | x1 = x.clamp(min=eps)
37 | x2 = (1 - x).clamp(min=eps)
38 | return torch.log(x1 / x2)
39 |
40 |
41 | def multi_scale_deformable_attn_pytorch(value: torch.Tensor, value_spatial_shapes: torch.Tensor,
42 | sampling_locations: torch.Tensor,
43 | attention_weights: torch.Tensor) -> torch.Tensor:
44 | """
45 | Multi-scale deformable attention.
46 | https://github.com/IDEA-Research/detrex/blob/main/detrex/layers/multi_scale_deform_attn.py
47 | """
48 |
49 | bs, _, num_heads, embed_dims = value.shape
50 | _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
51 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
52 | sampling_grids = 2 * sampling_locations - 1
53 | sampling_value_list = []
54 | for level, (H_, W_) in enumerate(value_spatial_shapes):
55 | # bs, H_*W_, num_heads, embed_dims ->
56 | # bs, H_*W_, num_heads*embed_dims ->
57 | # bs, num_heads*embed_dims, H_*W_ ->
58 | # bs*num_heads, embed_dims, H_, W_
59 | value_l_ = (value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_))
60 | # bs, num_queries, num_heads, num_points, 2 ->
61 | # bs, num_heads, num_queries, num_points, 2 ->
62 | # bs*num_heads, num_queries, num_points, 2
63 | sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
64 | # bs*num_heads, embed_dims, num_queries, num_points
65 | sampling_value_l_ = F.grid_sample(value_l_,
66 | sampling_grid_l_,
67 | mode='bilinear',
68 | padding_mode='zeros',
69 | align_corners=False)
70 | sampling_value_list.append(sampling_value_l_)
71 | # (bs, num_queries, num_heads, num_levels, num_points) ->
72 | # (bs, num_heads, num_queries, num_levels, num_points) ->
73 | # (bs, num_heads, 1, num_queries, num_levels*num_points)
74 | attention_weights = attention_weights.transpose(1, 2).reshape(bs * num_heads, 1, num_queries,
75 | num_levels * num_points)
76 | output = ((torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(
77 | bs, num_heads * embed_dims, num_queries))
78 | return output.transpose(1, 2).contiguous()
79 |
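As a rough shape sanity check for `multi_scale_deformable_attn_pytorch` (a minimal sketch, not part of the repository; all tensor sizes below are arbitrary assumptions, and the spatial shapes are passed as plain (H, W) pairs):

```python
import torch

from ultralytics.nn.modules.utils import multi_scale_deformable_attn_pytorch

bs, num_heads, embed_dims = 2, 8, 32
num_queries, num_levels, num_points = 100, 2, 4
spatial_shapes = [(16, 16), (8, 8)]                # (H, W) per feature level
num_value = sum(h * w for h, w in spatial_shapes)  # 16*16 + 8*8 = 320

value = torch.rand(bs, num_value, num_heads, embed_dims)
sampling_locations = torch.rand(bs, num_queries, num_heads, num_levels, num_points, 2)
attention_weights = torch.rand(bs, num_queries, num_heads, num_levels, num_points)

out = multi_scale_deformable_attn_pytorch(value, spatial_shapes, sampling_locations, attention_weights)
print(out.shape)  # torch.Size([2, 100, 256]) -> (bs, num_queries, num_heads * embed_dims)
```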
--------------------------------------------------------------------------------
/ultralytics/tracker/README.md:
--------------------------------------------------------------------------------
1 | # Tracker
2 |
3 | ## Supported Trackers
4 |
5 | - [x] ByteTracker
6 | - [x] BoT-SORT
7 |
8 | ## Usage
9 |
10 | ### Python interface
11 |
12 | You can use the Python interface to track objects using the YOLO model.
13 |
14 | ```python
15 | from ultralytics import YOLO
16 |
17 | model = YOLO("yolov8n.pt")  # or a segmentation model, i.e. yolov8n-seg.pt
18 | model.track(
19 | source="video/streams",
20 | stream=True,
21 | tracker="botsort.yaml", # or 'bytetrack.yaml'
22 | show=True,
23 | )
24 | ```
25 |
26 | You can get the IDs of the tracked objects using the following code:
27 |
28 | ```python
29 | from ultralytics import YOLO
30 |
31 | model = YOLO("yolov8n.pt")
32 |
33 | for result in model.track(source="video.mp4"):
34 | print(
35 | result.boxes.id.cpu().numpy().astype(int)
36 | ) # this will print the IDs of the tracked objects in the frame
37 | ```
38 |
39 | If you use the tracker with a folder of images, or loop over the video frames yourself, pass the `persist` parameter to tell the model that consecutive frames belong to the same sequence, so the same objects keep the same IDs. Without `persist=True`, each call creates a new tracker and the IDs change from frame to frame; with it, the existing tracker is reused across calls.
40 |
41 | ```python
42 | import cv2
43 | from ultralytics import YOLO
44 |
45 | cap = cv2.VideoCapture("video.mp4")
46 | model = YOLO("yolov8n.pt")
47 | while True:
48 | ret, frame = cap.read()
49 | if not ret:
50 | break
51 | results = model.track(frame, persist=True)
52 | boxes = results[0].boxes.xyxy.cpu().numpy().astype(int)
53 | ids = results[0].boxes.id.cpu().numpy().astype(int)
54 | for box, id in zip(boxes, ids):
55 | cv2.rectangle(frame, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), 2)
56 | cv2.putText(
57 | frame,
58 | f"Id {id}",
59 | (box[0], box[1]),
60 | cv2.FONT_HERSHEY_SIMPLEX,
61 | 1,
62 | (0, 0, 255),
63 | 2,
64 | )
65 | cv2.imshow("frame", frame)
66 | if cv2.waitKey(1) & 0xFF == ord("q"):
67 | break
68 | ```
69 |
70 | ## Change tracker parameters
71 |
72 | You can change the tracker parameters by editing the tracker config file (e.g. `botsort.yaml` or `bytetrack.yaml`) located in the `ultralytics/tracker/cfg` folder.
73 |
74 | ## Command Line Interface (CLI)
75 |
76 | You can also use the command line interface to track objects using the YOLO model.
77 |
78 | ```bash
79 | yolo detect track source=... tracker=...
80 | yolo segment track source=... tracker=...
81 | yolo pose track source=... tracker=...
82 | ```
83 |
84 | By default, trackers use the configuration files in `ultralytics/tracker/cfg`.
85 | You can also pass a modified copy of one of these files via the `tracker` argument,
86 | as shown in the example below.
87 |
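For example (a minimal sketch; `custom_tracker.yaml` is a placeholder for your edited copy of `bytetrack.yaml` or `botsort.yaml`):

```python
from ultralytics import YOLO

model = YOLO("yolov8n.pt")
# point the tracker argument at the modified config file
model.track(source="video.mp4", tracker="custom_tracker.yaml", show=True)
```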
--------------------------------------------------------------------------------
/ultralytics/tracker/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from .track import register_tracker
4 | from .trackers import BOTSORT, BYTETracker
5 |
6 | __all__ = 'register_tracker', 'BOTSORT', 'BYTETracker' # allow simpler import
7 |
--------------------------------------------------------------------------------
/ultralytics/tracker/cfg/botsort.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # Default YOLO tracker settings for BoT-SORT tracker https://github.com/NirAharon/BoT-SORT
3 |
4 | tracker_type: botsort # tracker type, ['botsort', 'bytetrack']
5 | track_high_thresh: 0.5 # threshold for the first association
6 | track_low_thresh: 0.1 # threshold for the second association
7 | new_track_thresh: 0.6 # threshold for initializing a new track if the detection does not match any existing tracks
8 | track_buffer: 30 # buffer length (frames) used to decide when lost tracks are removed
9 | match_thresh: 0.8 # threshold for matching tracks
10 | # min_box_area: 10 # threshold for minimum box area (for tracker evaluation, not used for now)
11 | # mot20: False # for tracker evaluation(not used for now)
12 |
13 | # BoT-SORT settings
14 | cmc_method: sparseOptFlow # method of global motion compensation
15 | # ReID model-related thresholds (not supported yet)
16 | proximity_thresh: 0.5
17 | appearance_thresh: 0.25
18 | with_reid: False
19 |
--------------------------------------------------------------------------------
/ultralytics/tracker/cfg/bytetrack.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | # Default YOLO tracker settings for ByteTrack tracker https://github.com/ifzhang/ByteTrack
3 |
4 | tracker_type: bytetrack # tracker type, ['botsort', 'bytetrack']
5 | track_high_thresh: 0.5 # threshold for the first association
6 | track_low_thresh: 0.1 # threshold for the second association
7 | new_track_thresh: 0.6 # threshold for initializing a new track if the detection does not match any existing tracks
8 | track_buffer: 30 # buffer length (frames) used to decide when lost tracks are removed
9 | match_thresh: 0.8 # threshold for matching tracks
10 | # min_box_area: 10 # threshold for minimum box area (for tracker evaluation, not used for now)
11 | # mot20: False # for tracker evaluation(not used for now)
12 |
--------------------------------------------------------------------------------
/ultralytics/tracker/track.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from functools import partial
4 |
5 | import torch
6 |
7 | from ultralytics.yolo.utils import IterableSimpleNamespace, yaml_load
8 | from ultralytics.yolo.utils.checks import check_yaml
9 |
10 | from .trackers import BOTSORT, BYTETracker
11 |
12 | TRACKER_MAP = {'bytetrack': BYTETracker, 'botsort': BOTSORT}
13 |
14 |
15 | def on_predict_start(predictor, persist=False):
16 | """
17 | Initialize trackers for object tracking during prediction.
18 |
19 | Args:
20 | predictor (object): The predictor object to initialize trackers for.
21 | persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False.
22 |
23 | Raises:
24 | AssertionError: If the tracker_type is not 'bytetrack' or 'botsort'.
25 | """
26 | if hasattr(predictor, 'trackers') and persist:
27 | return
28 | tracker = check_yaml(predictor.args.tracker)
29 | cfg = IterableSimpleNamespace(**yaml_load(tracker))
30 | assert cfg.tracker_type in ['bytetrack', 'botsort'], \
31 | f"Only support 'bytetrack' and 'botsort' for now, but got '{cfg.tracker_type}'"
32 | trackers = []
33 | for _ in range(predictor.dataset.bs):
34 | tracker = TRACKER_MAP[cfg.tracker_type](args=cfg, frame_rate=30)
35 | trackers.append(tracker)
36 | predictor.trackers = trackers
37 |
38 |
39 | def on_predict_postprocess_end(predictor):
40 | """Postprocess detected boxes and update with object tracking."""
41 | bs = predictor.dataset.bs
42 | im0s = predictor.batch[1]
43 | for i in range(bs):
44 | det = predictor.results[i].boxes.cpu().numpy()
45 | if len(det) == 0:
46 | continue
47 | tracks = predictor.trackers[i].update(det, im0s[i])
48 | if len(tracks) == 0:
49 | continue
50 | idx = tracks[:, -1].astype(int)
51 | predictor.results[i] = predictor.results[i][idx]
52 | predictor.results[i].update(boxes=torch.as_tensor(tracks[:, :-1]))
53 |
54 |
55 | def register_tracker(model, persist):
56 | """
57 | Register tracking callbacks to the model for object tracking during prediction.
58 |
59 | Args:
60 | model (object): The model object to register tracking callbacks for.
61 | persist (bool): Whether to persist the trackers if they already exist.
62 |
63 | """
64 | model.add_callback('on_predict_start', partial(on_predict_start, persist=persist))
65 | model.add_callback('on_predict_postprocess_end', on_predict_postprocess_end)
66 |
--------------------------------------------------------------------------------
/ultralytics/tracker/trackers/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from .bot_sort import BOTSORT
4 | from .byte_tracker import BYTETracker
5 |
6 | __all__ = 'BOTSORT', 'BYTETracker' # allow simpler import
7 |
--------------------------------------------------------------------------------
/ultralytics/tracker/trackers/basetrack.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from collections import OrderedDict
4 |
5 | import numpy as np
6 |
7 |
8 | class TrackState:
9 | """Enumeration of possible object tracking states."""
10 |
11 | New = 0
12 | Tracked = 1
13 | Lost = 2
14 | Removed = 3
15 |
16 |
17 | class BaseTrack:
18 | """Base class for object tracking, handling basic track attributes and operations."""
19 |
20 | _count = 0
21 |
22 | track_id = 0
23 | is_activated = False
24 | state = TrackState.New
25 |
26 | history = OrderedDict()
27 | features = []
28 | curr_feature = None
29 | score = 0
30 | start_frame = 0
31 | frame_id = 0
32 | time_since_update = 0
33 |
34 | # Multi-camera
35 | location = (np.inf, np.inf)
36 |
37 | @property
38 | def end_frame(self):
39 | """Return the last frame ID of the track."""
40 | return self.frame_id
41 |
42 | @staticmethod
43 | def next_id():
44 | """Increment and return the global track ID counter."""
45 | BaseTrack._count += 1
46 | return BaseTrack._count
47 |
48 | def activate(self, *args):
49 | """Activate the track with the provided arguments."""
50 | raise NotImplementedError
51 |
52 | def predict(self):
53 | """Predict the next state of the track."""
54 | raise NotImplementedError
55 |
56 | def update(self, *args, **kwargs):
57 | """Update the track with new observations."""
58 | raise NotImplementedError
59 |
60 | def mark_lost(self):
61 | """Mark the track as lost."""
62 | self.state = TrackState.Lost
63 |
64 | def mark_removed(self):
65 | """Mark the track as removed."""
66 | self.state = TrackState.Removed
67 |
68 | @staticmethod
69 | def reset_id():
70 | """Reset the global track ID counter."""
71 | BaseTrack._count = 0
72 |
--------------------------------------------------------------------------------
/ultralytics/tracker/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/ultralytics/tracker/utils/__init__.py
--------------------------------------------------------------------------------
/ultralytics/vit/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from .rtdetr import RTDETR
4 | from .sam import SAM
5 |
6 | __all__ = 'RTDETR', 'SAM' # allow simpler import
7 |
--------------------------------------------------------------------------------
/ultralytics/vit/rtdetr/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from .model import RTDETR
4 | from .predict import RTDETRPredictor
5 | from .val import RTDETRValidator
6 |
7 | __all__ = 'RTDETRPredictor', 'RTDETRValidator', 'RTDETR'
8 |
--------------------------------------------------------------------------------
/ultralytics/vit/rtdetr/predict.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | import torch
4 |
5 | from ultralytics.yolo.data.augment import LetterBox
6 | from ultralytics.yolo.engine.predictor import BasePredictor
7 | from ultralytics.yolo.engine.results import Results
8 | from ultralytics.yolo.utils import ops
9 |
10 |
11 | class RTDETRPredictor(BasePredictor):
12 |
13 | def postprocess(self, preds, img, orig_imgs):
14 | """Postprocess predictions and returns a list of Results objects."""
15 | bboxes, scores = preds[:2] # (1, bs, 300, 4), (1, bs, 300, nc)
16 | bboxes, scores = bboxes.squeeze_(0), scores.squeeze_(0)
17 | results = []
18 | for i, bbox in enumerate(bboxes): # (300, 4)
19 | bbox = ops.xywh2xyxy(bbox)
20 | score, cls = scores[i].max(-1, keepdim=True) # (300, 1)
21 | idx = score.squeeze(-1) > self.args.conf # (300, )
22 | if self.args.classes is not None:
23 | idx = (cls == torch.tensor(self.args.classes, device=cls.device)).any(1) & idx
24 | pred = torch.cat([bbox, score, cls], dim=-1)[idx] # filter
25 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
26 | oh, ow = orig_img.shape[:2]
27 | if not isinstance(orig_imgs, torch.Tensor):
28 | pred[..., [0, 2]] *= ow
29 | pred[..., [1, 3]] *= oh
30 | path = self.batch[0]
31 | img_path = path[i] if isinstance(path, list) else path
32 | results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred))
33 | return results
34 |
35 | def pre_transform(self, im):
36 | """Pre-transform input image before inference.
37 |
38 | Args:
39 | im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
40 |
41 | Returns: A list of transformed images.
42 | """
43 | # The size must be square (640) and scaleFill must be used.
44 | return [LetterBox(self.imgsz, auto=False, scaleFill=True)(image=x) for x in im]
45 |
--------------------------------------------------------------------------------
/ultralytics/vit/rtdetr/train.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from copy import copy
4 |
5 | import torch
6 |
7 | from ultralytics.nn.tasks import RTDETRDetectionModel
8 | from ultralytics.yolo.utils import DEFAULT_CFG, RANK, colorstr
9 | from ultralytics.yolo.v8.detect import DetectionTrainer
10 |
11 | from .val import RTDETRDataset, RTDETRValidator
12 |
13 |
14 | class RTDETRTrainer(DetectionTrainer):
15 |
16 | def get_model(self, cfg=None, weights=None, verbose=True):
17 | """Return a YOLO detection model."""
18 | model = RTDETRDetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
19 | if weights:
20 | model.load(weights)
21 | return model
22 |
23 | def build_dataset(self, img_path, mode='val', batch=None):
24 | """Build RTDETR Dataset
25 |
26 | Args:
27 | img_path (str): Path to the folder containing images.
28 | mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
29 | batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
30 | """
31 | return RTDETRDataset(
32 | img_path=img_path,
33 | imgsz=self.args.imgsz,
34 | batch_size=batch,
35 | augment=mode == 'train',  # augment only in train mode
36 | hyp=self.args,
37 | rect=False, # no rect
38 | cache=self.args.cache or None,
39 | prefix=colorstr(f'{mode}: '),
40 | data=self.data)
41 |
42 | def get_validator(self):
43 | """Returns a DetectionValidator for RTDETR model validation."""
44 | self.loss_names = 'giou_loss', 'cls_loss', 'l1_loss'
45 | return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
46 |
47 | def preprocess_batch(self, batch):
48 | """Preprocesses a batch of images by scaling and converting to float."""
49 | batch = super().preprocess_batch(batch)
50 | bs = len(batch['img'])
51 | batch_idx = batch['batch_idx']
52 | gt_bbox, gt_class = [], []
53 | for i in range(bs):
54 | gt_bbox.append(batch['bboxes'][batch_idx == i].to(batch_idx.device))
55 | gt_class.append(batch['cls'][batch_idx == i].to(device=batch_idx.device, dtype=torch.long))
56 | return batch
57 |
58 |
59 | def train(cfg=DEFAULT_CFG, use_python=False):
60 | """Train and optimize RTDETR model given training data and device."""
61 | model = 'rtdetr-l.yaml'
62 | data = cfg.data or 'coco128.yaml' # or yolo.ClassificationDataset("mnist")
63 | device = cfg.device if cfg.device is not None else ''
64 |
65 | # NOTE: F.grid_sample which is in rt-detr does not support deterministic=True
66 | # NOTE: amp training causes nan outputs and end with error while doing bipartite graph matching
67 | args = dict(model=model,
68 | data=data,
69 | device=device,
70 | imgsz=640,
71 | exist_ok=True,
72 | batch=4,
73 | deterministic=False,
74 | amp=False)
75 | trainer = RTDETRTrainer(overrides=args)
76 | trainer.train()
77 |
78 |
79 | if __name__ == '__main__':
80 | train()
81 |
--------------------------------------------------------------------------------
/ultralytics/vit/sam/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from .build import build_sam # noqa
4 | from .model import SAM # noqa
5 | from .modules.prompt_predictor import PromptPredictor # noqa
6 |
--------------------------------------------------------------------------------
/ultralytics/vit/sam/autosize.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | # Copyright (c) Meta Platforms, Inc. and affiliates.
4 | # All rights reserved.
5 |
6 | # This source code is licensed under the license found in the
7 | # LICENSE file in the root directory of this source tree.
8 |
9 | from copy import deepcopy
10 | from typing import Tuple
11 |
12 | import numpy as np
13 | import torch
14 | from torch.nn import functional as F
15 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore
16 |
17 |
18 | class ResizeLongestSide:
19 | """
20 | Resizes images to the longest side 'target_length' and provides methods for
21 | resizing coordinates and boxes accordingly. Supports both numpy arrays and
22 | batched torch tensors.
23 | """
24 |
25 | def __init__(self, target_length: int) -> None:
26 | self.target_length = target_length
27 |
28 | def apply_image(self, image: np.ndarray) -> np.ndarray:
29 | """
30 | Expects a numpy array with shape HxWxC in uint8 format.
31 | """
32 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
33 | return np.array(resize(to_pil_image(image), target_size))
34 |
35 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
36 | """
37 | Expects a numpy array of length 2 in the final dimension. Requires the
38 | original image size in (H, W) format.
39 | """
40 | old_h, old_w = original_size
41 | new_h, new_w = self.get_preprocess_shape(original_size[0], original_size[1], self.target_length)
42 | coords = deepcopy(coords).astype(float)
43 | coords[..., 0] = coords[..., 0] * (new_w / old_w)
44 | coords[..., 1] = coords[..., 1] * (new_h / old_h)
45 | return coords
46 |
47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
48 | """
49 | Expects a numpy array shape Bx4. Requires the original image size
50 | in (H, W) format.
51 | """
52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
53 | return boxes.reshape(-1, 4)
54 |
55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
56 | """
57 | Expects batched images with shape BxCxHxW and float format. This
58 | transformation may not exactly match apply_image. apply_image is
59 | the transformation expected by the model.
60 | """
61 | # Expects an image in BCHW format. May not exactly match apply_image.
62 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
63 | return F.interpolate(image, target_size, mode='bilinear', align_corners=False, antialias=True)
64 |
65 | def apply_coords_torch(self, coords: torch.Tensor, original_size: Tuple[int, ...]) -> torch.Tensor:
66 | """
67 | Expects a torch tensor with length 2 in the last dimension. Requires the
68 | original image size in (H, W) format.
69 | """
70 | old_h, old_w = original_size
71 | new_h, new_w = self.get_preprocess_shape(original_size[0], original_size[1], self.target_length)
72 | coords = deepcopy(coords).to(torch.float)
73 | coords[..., 0] = coords[..., 0] * (new_w / old_w)
74 | coords[..., 1] = coords[..., 1] * (new_h / old_h)
75 | return coords
76 |
77 | def apply_boxes_torch(self, boxes: torch.Tensor, original_size: Tuple[int, ...]) -> torch.Tensor:
78 | """
79 | Expects a torch tensor with shape Bx4. Requires the original image
80 | size in (H, W) format.
81 | """
82 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
83 | return boxes.reshape(-1, 4)
84 |
85 | @staticmethod
86 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
87 | """
88 | Compute the output size given input size and target long side length.
89 | """
90 | scale = long_side_length * 1.0 / max(oldh, oldw)
91 | newh, neww = oldh * scale, oldw * scale
92 | neww = int(neww + 0.5)
93 | newh = int(newh + 0.5)
94 | return (newh, neww)
95 |
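A small worked example of the resizing logic (a sketch, not part of the repository; the 480x640 image is a dummy input):

```python
import numpy as np

from ultralytics.vit.sam.autosize import ResizeLongestSide

print(ResizeLongestSide.get_preprocess_shape(480, 640, 1024))  # (768, 1024): scale = 1024/640 = 1.6

resizer = ResizeLongestSide(1024)
image = np.zeros((480, 640, 3), dtype=np.uint8)  # dummy HxWxC uint8 image
print(resizer.apply_image(image).shape)          # (768, 1024, 3)

coords = np.array([[320.0, 240.0]])              # a point in the original image
print(resizer.apply_coords(coords, (480, 640)))  # [[512. 384.]] after rescaling
```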
--------------------------------------------------------------------------------
/ultralytics/vit/sam/build.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | # Copyright (c) Meta Platforms, Inc. and affiliates.
4 | # All rights reserved.
5 |
6 | # This source code is licensed under the license found in the
7 | # LICENSE file in the root directory of this source tree.
8 |
9 | from functools import partial
10 |
11 | import torch
12 |
13 | from ...yolo.utils.downloads import attempt_download_asset
14 | from .modules.decoders import MaskDecoder
15 | from .modules.encoders import ImageEncoderViT, PromptEncoder
16 | from .modules.sam import Sam
17 | from .modules.transformer import TwoWayTransformer
18 |
19 |
20 | def build_sam_vit_h(checkpoint=None):
21 | """Build and return a Segment Anything Model (SAM) h-size model."""
22 | return _build_sam(
23 | encoder_embed_dim=1280,
24 | encoder_depth=32,
25 | encoder_num_heads=16,
26 | encoder_global_attn_indexes=[7, 15, 23, 31],
27 | checkpoint=checkpoint,
28 | )
29 |
30 |
31 | def build_sam_vit_l(checkpoint=None):
32 | """Build and return a Segment Anything Model (SAM) l-size model."""
33 | return _build_sam(
34 | encoder_embed_dim=1024,
35 | encoder_depth=24,
36 | encoder_num_heads=16,
37 | encoder_global_attn_indexes=[5, 11, 17, 23],
38 | checkpoint=checkpoint,
39 | )
40 |
41 |
42 | def build_sam_vit_b(checkpoint=None):
43 | """Build and return a Segment Anything Model (SAM) b-size model."""
44 | return _build_sam(
45 | encoder_embed_dim=768,
46 | encoder_depth=12,
47 | encoder_num_heads=12,
48 | encoder_global_attn_indexes=[2, 5, 8, 11],
49 | checkpoint=checkpoint,
50 | )
51 |
52 |
53 | def _build_sam(
54 | encoder_embed_dim,
55 | encoder_depth,
56 | encoder_num_heads,
57 | encoder_global_attn_indexes,
58 | checkpoint=None,
59 | ):
60 | """Builds the selected SAM model architecture."""
61 | prompt_embed_dim = 256
62 | image_size = 1024
63 | vit_patch_size = 16
64 | image_embedding_size = image_size // vit_patch_size
65 | sam = Sam(
66 | image_encoder=ImageEncoderViT(
67 | depth=encoder_depth,
68 | embed_dim=encoder_embed_dim,
69 | img_size=image_size,
70 | mlp_ratio=4,
71 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
72 | num_heads=encoder_num_heads,
73 | patch_size=vit_patch_size,
74 | qkv_bias=True,
75 | use_rel_pos=True,
76 | global_attn_indexes=encoder_global_attn_indexes,
77 | window_size=14,
78 | out_chans=prompt_embed_dim,
79 | ),
80 | prompt_encoder=PromptEncoder(
81 | embed_dim=prompt_embed_dim,
82 | image_embedding_size=(image_embedding_size, image_embedding_size),
83 | input_image_size=(image_size, image_size),
84 | mask_in_chans=16,
85 | ),
86 | mask_decoder=MaskDecoder(
87 | num_multimask_outputs=3,
88 | transformer=TwoWayTransformer(
89 | depth=2,
90 | embedding_dim=prompt_embed_dim,
91 | mlp_dim=2048,
92 | num_heads=8,
93 | ),
94 | transformer_dim=prompt_embed_dim,
95 | iou_head_depth=3,
96 | iou_head_hidden_dim=256,
97 | ),
98 | pixel_mean=[123.675, 116.28, 103.53],
99 | pixel_std=[58.395, 57.12, 57.375],
100 | )
101 | sam.eval()
102 | if checkpoint is not None:
103 | attempt_download_asset(checkpoint)
104 | with open(checkpoint, 'rb') as f:
105 | state_dict = torch.load(f)
106 | sam.load_state_dict(state_dict)
107 | return sam
108 |
109 |
110 | sam_model_map = {
111 | # "default": build_sam_vit_h,
112 | 'sam_h.pt': build_sam_vit_h,
113 | 'sam_l.pt': build_sam_vit_l,
114 | 'sam_b.pt': build_sam_vit_b, }
115 |
116 |
117 | def build_sam(ckpt='sam_b.pt'):
118 | """Build a SAM model specified by ckpt."""
119 | model_builder = None
120 | for k in sam_model_map.keys():
121 | if ckpt.endswith(k):
122 | model_builder = sam_model_map.get(k)
123 |
124 | if not model_builder:
125 | raise FileNotFoundError(f'{ckpt} is not a supported SAM model. Available models are: \n {sam_model_map.keys()}')
126 |
127 | return model_builder(ckpt)
128 |
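The checkpoint file name selects the builder by suffix. A minimal sketch (the checkpoint is fetched via `attempt_download_asset` if it is not already present locally):

```python
from ultralytics.vit.sam import build_sam

sam = build_sam('sam_b.pt')  # matches the 'sam_b.pt' suffix -> build_sam_vit_b
print(sum(p.numel() for p in sam.parameters()))  # parameter count of the ViT-B SAM
```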
--------------------------------------------------------------------------------
/ultralytics/vit/sam/model.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | """
3 | SAM model interface
4 | """
5 |
6 | from ultralytics.yolo.cfg import get_cfg
7 |
8 | from ...yolo.utils.torch_utils import model_info
9 | from .build import build_sam
10 | from .predict import Predictor
11 |
12 |
13 | class SAM:
14 |
15 | def __init__(self, model='sam_b.pt') -> None:
16 | if model and not model.endswith('.pt') and not model.endswith('.pth'):
17 | # Should raise AssertionError instead?
18 | raise NotImplementedError('Segment anything prediction requires pre-trained checkpoint')
19 | self.model = build_sam(model)
20 | self.task = 'segment' # required
21 | self.predictor = None # reuse predictor
22 |
23 | def predict(self, source, stream=False, **kwargs):
24 | """Predicts and returns segmentation masks for given image or video source."""
25 | overrides = dict(conf=0.25, task='segment', mode='predict')
26 | overrides.update(kwargs) # prefer kwargs
27 | if not self.predictor:
28 | self.predictor = Predictor(overrides=overrides)
29 | self.predictor.setup_model(model=self.model)
30 | else: # only update args if predictor is already setup
31 | self.predictor.args = get_cfg(self.predictor.args, overrides)
32 | return self.predictor(source, stream=stream)
33 |
34 | def train(self, **kwargs):
35 | """Function trains models but raises an error as SAM models do not support training."""
36 | raise NotImplementedError("SAM models don't support training")
37 |
38 | def val(self, **kwargs):
39 | """Run validation given dataset."""
40 | raise NotImplementedError("SAM models don't support validation")
41 |
42 | def __call__(self, source=None, stream=False, **kwargs):
43 | """Calls the 'predict' function with given arguments to perform object detection."""
44 | return self.predict(source, stream, **kwargs)
45 |
46 | def __getattr__(self, attr):
47 | """Raises error if object has no requested attribute."""
48 | name = self.__class__.__name__
49 | raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
50 |
51 | def info(self, detailed=False, verbose=True):
52 | """
53 | Logs model info.
54 |
55 | Args:
56 | detailed (bool): Show detailed information about model.
57 | verbose (bool): Controls verbosity.
58 | """
59 | return model_info(self.model, detailed=detailed, verbose=verbose)
60 |
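A minimal usage sketch of this interface (the image path is a placeholder; `sam_b.pt` is the default checkpoint name, which `build_sam` will try to download if it is not present):

```python
from ultralytics.vit import SAM

model = SAM('sam_b.pt')                          # builds the ViT-B SAM via build_sam()
results = model('path/to/image.jpg', conf=0.25)  # equivalent to model.predict(...)
for r in results:
    print(r.masks)                               # automatically generated segmentation masks
```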
--------------------------------------------------------------------------------
/ultralytics/vit/sam/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
--------------------------------------------------------------------------------
/ultralytics/vit/sam/predict.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | import numpy as np
4 | import torch
5 |
6 | from ultralytics.yolo.engine.predictor import BasePredictor
7 | from ultralytics.yolo.engine.results import Results
8 | from ultralytics.yolo.utils.torch_utils import select_device
9 |
10 | from .modules.mask_generator import SamAutomaticMaskGenerator
11 |
12 |
13 | class Predictor(BasePredictor):
14 |
15 | def preprocess(self, im):
16 | """Prepares input image for inference."""
17 | # TODO: Only support bs=1 for now
18 | # im = ResizeLongestSide(1024).apply_image(im[0])
19 | # im = torch.as_tensor(im, device=self.device)
20 | # im = im.permute(2, 0, 1).contiguous()[None, :, :, :]
21 | return im[0]
22 |
23 | def setup_model(self, model):
24 | """Set up YOLO model with specified thresholds and device."""
25 | device = select_device(self.args.device)
26 | model.eval()
27 | self.model = SamAutomaticMaskGenerator(model.to(device),
28 | pred_iou_thresh=self.args.conf,
29 | box_nms_thresh=self.args.iou)
30 | self.device = device
31 | # TODO: Temporary settings for compatibility
32 | self.model.pt = False
33 | self.model.triton = False
34 | self.model.stride = 32
35 | self.model.fp16 = False
36 | self.done_warmup = True
37 |
38 | def postprocess(self, preds, path, orig_imgs):
39 | """Postprocesses inference output predictions to create detection masks for objects."""
40 | names = dict(enumerate(list(range(len(preds)))))
41 | results = []
42 | # TODO
43 | for i, pred in enumerate([preds]):
44 | masks = torch.from_numpy(np.stack([p['segmentation'] for p in pred], axis=0))
45 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
46 | path = self.batch[0]
47 | img_path = path[i] if isinstance(path, list) else path
48 | results.append(Results(orig_img=orig_img, path=img_path, names=names, masks=masks))
49 | return results
50 |
51 | # def __call__(self, source=None, model=None, stream=False):
52 | # frame = cv2.imread(source)
53 | # preds = self.model.generate(frame)
54 | # return self.postprocess(preds, source, frame)
55 |
--------------------------------------------------------------------------------
/ultralytics/vit/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
--------------------------------------------------------------------------------
/ultralytics/yolo/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from . import v8
4 |
5 | __all__ = 'v8', # tuple or list
6 |
--------------------------------------------------------------------------------
/ultralytics/yolo/data/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from .base import BaseDataset
4 | from .build import build_dataloader, build_yolo_dataset, load_inference_source
5 | from .dataset import ClassificationDataset, SemanticDataset, YOLODataset
6 | from .dataset_wrappers import MixAndRectDataset
7 |
8 | __all__ = ('BaseDataset', 'ClassificationDataset', 'MixAndRectDataset', 'SemanticDataset', 'YOLODataset',
9 | 'build_yolo_dataset', 'build_dataloader', 'load_inference_source')
10 |
--------------------------------------------------------------------------------
/ultralytics/yolo/data/annotator.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from ultralytics import YOLO
4 | from ultralytics.vit.sam import PromptPredictor, build_sam
5 | from ultralytics.yolo.utils.torch_utils import select_device
6 |
7 |
8 | def auto_annotate(data, det_model='yolov8x.pt', sam_model='sam_b.pt', device='', output_dir=None):
9 | """
10 | Automatically annotates images using a YOLO object detection model and a SAM segmentation model.
11 | Args:
12 | data (str): Path to a folder containing images to be annotated.
13 | det_model (str, optional): Pre-trained YOLO detection model. Defaults to 'yolov8x.pt'.
14 | sam_model (str, optional): Pre-trained SAM segmentation model. Defaults to 'sam_b.pt'.
15 | device (str, optional): Device to run the models on. Defaults to an empty string (CPU or GPU, if available).
16 | output_dir (str | None, optional): Directory to save the annotated results.
17 | Defaults to a 'labels' folder in the same directory as 'data'.
18 | """
19 | device = select_device(device)
20 | det_model = YOLO(det_model)
21 | sam_model = build_sam(sam_model)
22 | det_model.to(device)
23 | sam_model.to(device)
24 |
25 | if not output_dir:
26 | output_dir = Path(str(data)).parent / 'labels'
27 | Path(output_dir).mkdir(exist_ok=True, parents=True)
28 |
29 | prompt_predictor = PromptPredictor(sam_model)
30 | det_results = det_model(data, stream=True)
31 |
32 | for result in det_results:
33 | boxes = result.boxes.xyxy # Boxes object for bbox outputs
34 | class_ids = result.boxes.cls.int().tolist() # noqa
35 | if len(class_ids):
36 | prompt_predictor.set_image(result.orig_img)
37 | masks, _, _ = prompt_predictor.predict_torch(
38 | point_coords=None,
39 | point_labels=None,
40 | boxes=prompt_predictor.transform.apply_boxes_torch(boxes, result.orig_shape[:2]),
41 | multimask_output=False,
42 | )
43 |
44 | result.update(masks=masks.squeeze(1))
45 | segments = result.masks.xyn # noqa
46 |
47 | with open(str(Path(output_dir) / Path(result.path).stem) + '.txt', 'w') as f:
48 | for i in range(len(segments)):
49 | s = segments[i]
50 | if len(s) == 0:
51 | continue
52 | segment = map(str, segments[i].reshape(-1).tolist())
53 | f.write(f'{class_ids[i]} ' + ' '.join(segment) + '\n')
54 |
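A minimal usage sketch of `auto_annotate` (the folder path is a placeholder):

```python
from ultralytics.yolo.data.annotator import auto_annotate

# writes one YOLO-format segment label file per image into a 'labels' folder next to the data folder
auto_annotate(data='path/to/images', det_model='yolov8x.pt', sam_model='sam_b.pt')
```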
--------------------------------------------------------------------------------
/ultralytics/yolo/data/dataloaders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/ultralytics/yolo/data/dataloaders/__init__.py
--------------------------------------------------------------------------------
/ultralytics/yolo/data/dataset_wrappers.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | import collections
4 | from copy import deepcopy
5 |
6 | from .augment import LetterBox
7 |
8 |
9 | class MixAndRectDataset:
10 | """
11 | A dataset class that applies mosaic and mixup transformations as well as rectangular training.
12 |
13 | Attributes:
14 | dataset: The base dataset.
15 | imgsz: The size of the images in the dataset.
16 | """
17 |
18 | def __init__(self, dataset):
19 | """
20 | Args:
21 | dataset (BaseDataset): The base dataset to apply transformations to.
22 | """
23 | self.dataset = dataset
24 | self.imgsz = dataset.imgsz
25 |
26 | def __len__(self):
27 | """Returns the number of items in the dataset."""
28 | return len(self.dataset)
29 |
30 | def __getitem__(self, index):
31 | """
32 | Applies mosaic, mixup and rectangular training transformations to an item in the dataset.
33 |
34 | Args:
35 | index (int): Index of the item in the dataset.
36 |
37 | Returns:
38 | (dict): A dictionary containing the transformed item data.
39 | """
40 | labels = deepcopy(self.dataset[index])
41 | for transform in self.dataset.transforms.tolist():
42 | # Mosaic and mixup
43 | if hasattr(transform, 'get_indexes'):
44 | indexes = transform.get_indexes(self.dataset)
45 | if not isinstance(indexes, collections.abc.Sequence):
46 | indexes = [indexes]
47 | labels['mix_labels'] = [deepcopy(self.dataset[index]) for index in indexes]
48 | if self.dataset.rect and isinstance(transform, LetterBox):
49 | transform.new_shape = self.dataset.batch_shapes[self.dataset.batch[index]]
50 | labels = transform(labels)
51 | if 'mix_labels' in labels:
52 | labels.pop('mix_labels')
53 | return labels
54 |
--------------------------------------------------------------------------------
/ultralytics/yolo/data/scripts/download_weights.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Ultralytics YOLO 🚀, AGPL-3.0 license
3 | # Download latest models from https://github.com/ultralytics/assets/releases
4 | # Example usage: bash ultralytics/yolo/data/scripts/download_weights.sh
5 | # parent
6 | # └── weights
7 | # ├── yolov8n.pt ← downloads here
8 | # ├── yolov8s.pt
9 | # └── ...
10 |
11 | python - <<EOF
[...]
--------------------------------------------------------------------------------
/ultralytics/yolo/nas/model.py:
--------------------------------------------------------------------------------
[...]
28 | def __init__(self, model='yolo_nas_s.pt') -> None:
29 | # Load or create new NAS model
30 | import super_gradients
31 |
32 | self.predictor = None
33 | suffix = Path(model).suffix
34 | if suffix == '.pt':
35 | self._load(model)
36 | elif suffix == '':
37 | self.model = super_gradients.training.models.get(model, pretrained_weights='coco')
38 | self.task = 'detect'
39 | self.model.args = DEFAULT_CFG_DICT # attach args to model
40 |
41 | # Standardize model
42 | self.model.fuse = lambda verbose=True: self.model
43 | self.model.stride = torch.tensor([32])
44 | self.model.names = dict(enumerate(self.model._class_names))
45 | self.model.is_fused = lambda: False # for info()
46 | self.model.yaml = {} # for info()
47 | self.model.pt_path = model # for export()
48 | self.model.task = 'detect' # for export()
49 | self.info()
50 |
51 | @smart_inference_mode()
52 | def _load(self, weights: str):
53 | self.model = torch.load(weights)
54 |
55 | @smart_inference_mode()
56 | def predict(self, source=None, stream=False, **kwargs):
57 | """
58 | Perform prediction using the YOLO model.
59 |
60 | Args:
61 | source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
62 | Accepts all source types accepted by the YOLO model.
63 | stream (bool): Whether to stream the predictions or not. Defaults to False.
64 | **kwargs : Additional keyword arguments passed to the predictor.
65 | Check the 'configuration' section in the documentation for all available options.
66 |
67 | Returns:
68 | (List[ultralytics.yolo.engine.results.Results]): The prediction results.
69 | """
70 | if source is None:
71 | source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
72 | LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
73 | overrides = dict(conf=0.25, task='detect', mode='predict')
74 | overrides.update(kwargs) # prefer kwargs
75 | if not self.predictor:
76 | self.predictor = NASPredictor(overrides=overrides)
77 | self.predictor.setup_model(model=self.model)
78 | else: # only update args if predictor is already setup
79 | self.predictor.args = get_cfg(self.predictor.args, overrides)
80 | return self.predictor(source, stream=stream)
81 |
82 | def train(self, **kwargs):
83 | """Function trains models but raises an error as NAS models do not support training."""
84 | raise NotImplementedError("NAS models don't support training")
85 |
86 | def val(self, **kwargs):
87 | """Run validation given dataset."""
88 | overrides = dict(task='detect', mode='val')
89 | overrides.update(kwargs) # prefer kwargs
90 | args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
91 | args.imgsz = check_imgsz(args.imgsz, max_dim=1)
92 | validator = NASValidator(args=args)
93 | validator(model=self.model)
94 | self.metrics = validator.metrics
95 | return validator.metrics
96 |
97 | @smart_inference_mode()
98 | def export(self, **kwargs):
99 | """
100 | Export model.
101 |
102 | Args:
103 | **kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
104 | """
105 | overrides = dict(task='detect')
106 | overrides.update(kwargs)
107 | overrides['mode'] = 'export'
108 | args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
109 | args.task = self.task
110 | if args.imgsz == DEFAULT_CFG.imgsz:
111 | args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
112 | if args.batch == DEFAULT_CFG.batch:
113 | args.batch = 1 # default to 1 if not modified
114 | return Exporter(overrides=args)(model=self.model)
115 |
116 | def info(self, detailed=False, verbose=True):
117 | """
118 | Logs model info.
119 |
120 | Args:
121 | detailed (bool): Show detailed information about model.
122 | verbose (bool): Controls verbosity.
123 | """
124 | return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)
125 |
126 | def __call__(self, source=None, stream=False, **kwargs):
127 | """Calls the 'predict' function with given arguments to perform object detection."""
128 | return self.predict(source, stream, **kwargs)
129 |
130 | def __getattr__(self, attr):
131 | """Raises error if object has no requested attribute."""
132 | name = self.__class__.__name__
133 | raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
134 |
--------------------------------------------------------------------------------
/ultralytics/yolo/nas/predict.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | import torch
4 |
5 | from ultralytics.yolo.engine.predictor import BasePredictor
6 | from ultralytics.yolo.engine.results import Results
7 | from ultralytics.yolo.utils import ops
8 | from ultralytics.yolo.utils.ops import xyxy2xywh
9 |
10 |
11 | class NASPredictor(BasePredictor):
12 |
13 | def postprocess(self, preds_in, img, orig_imgs):
14 | """Postprocesses predictions and returns a list of Results objects."""
15 |
16 | # Cat boxes and class scores
17 | boxes = xyxy2xywh(preds_in[0][0])
18 | preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
19 |
20 | preds = ops.non_max_suppression(preds,
21 | self.args.conf,
22 | self.args.iou,
23 | agnostic=self.args.agnostic_nms,
24 | max_det=self.args.max_det,
25 | classes=self.args.classes)
26 |
27 | results = []
28 | for i, pred in enumerate(preds):
29 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
30 | if not isinstance(orig_imgs, torch.Tensor):
31 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
32 | path = self.batch[0]
33 | img_path = path[i] if isinstance(path, list) else path
34 | results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred))
35 | return results
36 |
--------------------------------------------------------------------------------
/ultralytics/yolo/nas/val.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | import torch
4 |
5 | from ultralytics.yolo.utils import ops
6 | from ultralytics.yolo.utils.ops import xyxy2xywh
7 | from ultralytics.yolo.v8.detect import DetectionValidator
8 |
9 | __all__ = ['NASValidator']
10 |
11 |
12 | class NASValidator(DetectionValidator):
13 |
14 | def postprocess(self, preds_in):
15 | """Apply Non-maximum suppression to prediction outputs."""
16 | boxes = xyxy2xywh(preds_in[0][0])
17 | preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
18 | return ops.non_max_suppression(preds,
19 | self.args.conf,
20 | self.args.iou,
21 | labels=self.lb,
22 | multi_label=False,
23 | agnostic=self.args.single_cls,
24 | max_det=self.args.max_det,
25 | max_time_img=0.5)
26 |
--------------------------------------------------------------------------------
/ultralytics/yolo/utils/autobatch.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | """
3 | Functions for estimating the best YOLO batch size to use a fraction of the available CUDA memory in PyTorch.
4 | """
5 |
6 | from copy import deepcopy
7 |
8 | import numpy as np
9 | import torch
10 |
11 | from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, colorstr
12 | from ultralytics.yolo.utils.torch_utils import profile
13 |
14 |
15 | def check_train_batch_size(model, imgsz=640, amp=True):
16 | """
17 | Check YOLO training batch size using the autobatch() function.
18 |
19 | Args:
20 | model (torch.nn.Module): YOLO model to check batch size for.
21 | imgsz (int): Image size used for training.
22 | amp (bool): If True, use automatic mixed precision (AMP) for training.
23 |
24 | Returns:
25 | (int): Optimal batch size computed using the autobatch() function.
26 | """
27 |
28 | with torch.cuda.amp.autocast(amp):
29 | return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size
30 |
31 |
32 | def autobatch(model, imgsz=640, fraction=0.67, batch_size=DEFAULT_CFG.batch):
33 | """
34 | Automatically estimate the best YOLO batch size to use a fraction of the available CUDA memory.
35 |
36 | Args:
37 | model (torch.nn.module): YOLO model to compute batch size for.
38 | imgsz (int, optional): The image size used as input for the YOLO model. Defaults to 640.
39 | fraction (float, optional): The fraction of available CUDA memory to use. Defaults to 0.67.
40 | batch_size (int, optional): The default batch size to use if an error is detected. Defaults to 16.
41 |
42 | Returns:
43 | (int): The optimal batch size.
44 | """
45 |
46 | # Check device
47 | prefix = colorstr('AutoBatch: ')
48 | LOGGER.info(f'{prefix}Computing optimal batch size for imgsz={imgsz}')
49 | device = next(model.parameters()).device # get model device
50 | if device.type == 'cpu':
51 | LOGGER.info(f'{prefix}CUDA not detected, using default CPU batch-size {batch_size}')
52 | return batch_size
53 | if torch.backends.cudnn.benchmark:
54 | LOGGER.info(f'{prefix} ⚠️ Requires torch.backends.cudnn.benchmark=False, using default batch-size {batch_size}')
55 | return batch_size
56 |
57 | # Inspect CUDA memory
58 | gb = 1 << 30 # bytes to GiB (1024 ** 3)
59 | d = str(device).upper() # 'CUDA:0'
60 | properties = torch.cuda.get_device_properties(device) # device properties
61 | t = properties.total_memory / gb # GiB total
62 | r = torch.cuda.memory_reserved(device) / gb # GiB reserved
63 | a = torch.cuda.memory_allocated(device) / gb # GiB allocated
64 | f = t - (r + a) # GiB free
65 | LOGGER.info(f'{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free')
66 |
67 | # Profile batch sizes
68 | batch_sizes = [1, 2, 4, 8, 16]
69 | try:
70 | img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes]
71 | results = profile(img, model, n=3, device=device)
72 |
73 | # Fit a solution
74 | y = [x[2] for x in results if x] # memory [2]
75 | p = np.polyfit(batch_sizes[:len(y)], y, deg=1) # first degree polynomial fit
76 | b = int((f * fraction - p[1]) / p[0])  # solve the fit for the batch size at the target memory fraction
77 | if None in results: # some sizes failed
78 | i = results.index(None) # first fail index
79 | if b >= batch_sizes[i]:  # estimated batch size is at or above the first failure point
80 | b = batch_sizes[max(i - 1, 0)] # select prior safe point
81 | if b < 1 or b > 1024: # b outside of safe range
82 | b = batch_size
83 | LOGGER.info(f'{prefix}WARNING ⚠️ CUDA anomaly detected, using default batch-size {batch_size}.')
84 |
85 | fraction = (np.polyval(p, b) + r + a) / t # actual fraction predicted
86 | LOGGER.info(f'{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅')
87 | return b
88 | except Exception as e:
89 | LOGGER.warning(f'{prefix}WARNING ⚠️ error detected: {e}, using default batch-size {batch_size}.')
90 | return batch_size
91 |
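A minimal sketch of how `check_train_batch_size` is typically invoked (the model name is an assumption; a CUDA device is required for anything other than the default result):

```python
import torch

from ultralytics import YOLO
from ultralytics.yolo.utils.autobatch import check_train_batch_size

model = YOLO('yolov8n.pt').model  # the underlying nn.Module
if torch.cuda.is_available():
    model = model.cuda()
batch = check_train_batch_size(model, imgsz=640, amp=True)  # falls back to the default batch size on CPU
print(batch)
```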
--------------------------------------------------------------------------------
/ultralytics/yolo/utils/callbacks/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from .base import add_integration_callbacks, default_callbacks, get_default_callbacks
4 |
5 | __all__ = 'add_integration_callbacks', 'default_callbacks', 'get_default_callbacks'
6 |
--------------------------------------------------------------------------------
/ultralytics/yolo/utils/callbacks/dvc.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 | import os
3 |
4 | import pkg_resources as pkg
5 |
6 | from ultralytics.yolo.utils import LOGGER, TESTS_RUNNING
7 | from ultralytics.yolo.utils.torch_utils import model_info_for_loggers
8 |
9 | try:
10 | from importlib.metadata import version
11 |
12 | import dvclive
13 |
14 | assert not TESTS_RUNNING # do not log pytest
15 |
16 | ver = version('dvclive')
17 | if pkg.parse_version(ver) < pkg.parse_version('2.11.0'):
18 | LOGGER.debug(f'DVCLive is detected but version {ver} is incompatible (>=2.11 required).')
19 | dvclive = None # noqa: F811
20 | except (ImportError, AssertionError, TypeError):
21 | dvclive = None
22 |
23 | # DVCLive logger instance
24 | live = None
25 | _processed_plots = {}
26 |
27 | # `on_fit_epoch_end` is also called on the final validation (this probably needs to be fixed);
28 | # for now, this flag is how we distinguish the final evaluation of the best model
29 | # from the last epoch's validation
30 | _training_epoch = False
31 |
32 |
33 | def _logger_disabled():
34 | return os.getenv('ULTRALYTICS_DVC_DISABLED', 'false').lower() == 'true'
35 |
36 |
37 | def _log_images(image_path, prefix=''):
38 | if live:
39 | live.log_image(os.path.join(prefix, image_path.name), image_path)
40 |
41 |
42 | def _log_plots(plots, prefix=''):
43 | for name, params in plots.items():
44 | timestamp = params['timestamp']
45 | if _processed_plots.get(name) != timestamp:
46 | _log_images(name, prefix)
47 | _processed_plots[name] = timestamp
48 |
49 |
50 | def _log_confusion_matrix(validator):
51 | targets = []
52 | preds = []
53 | matrix = validator.confusion_matrix.matrix
54 | names = list(validator.names.values())
55 | if validator.confusion_matrix.task == 'detect':
56 | names += ['background']
57 |
58 | for ti, pred in enumerate(matrix.T.astype(int)):
59 | for pi, num in enumerate(pred):
60 | targets.extend([names[ti]] * num)
61 | preds.extend([names[pi]] * num)
62 |
63 | live.log_sklearn_plot('confusion_matrix', targets, preds, name='cf.json', normalized=True)
64 |
65 |
66 | def on_pretrain_routine_start(trainer):
67 | try:
68 | global live
69 | if not _logger_disabled():
70 | live = dvclive.Live(save_dvc_exp=True)
71 | LOGGER.info(
72 | 'DVCLive is detected and auto logging is enabled (can be disabled with `ULTRALYTICS_DVC_DISABLED=true`).'
73 | )
74 | else:
75 | LOGGER.debug('DVCLive is detected and auto logging is disabled via `ULTRALYTICS_DVC_DISABLED`.')
76 | live = None
77 | except Exception as e:
78 | LOGGER.warning(f'WARNING ⚠️ DVCLive installed but not initialized correctly, not logging this run. {e}')
79 |
80 |
81 | def on_pretrain_routine_end(trainer):
82 | _log_plots(trainer.plots, 'train')
83 |
84 |
85 | def on_train_start(trainer):
86 | if live:
87 | live.log_params(trainer.args)
88 |
89 |
90 | def on_train_epoch_start(trainer):
91 | global _training_epoch
92 | _training_epoch = True
93 |
94 |
95 | def on_fit_epoch_end(trainer):
96 | global _training_epoch
97 | if live and _training_epoch:
98 | all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics, **trainer.lr}
99 | for metric, value in all_metrics.items():
100 | live.log_metric(metric, value)
101 |
102 | if trainer.epoch == 0:
103 | for metric, value in model_info_for_loggers(trainer).items():
104 | live.log_metric(metric, value, plot=False)
105 |
106 | _log_plots(trainer.plots, 'train')
107 | _log_plots(trainer.validator.plots, 'val')
108 |
109 | live.next_step()
110 | _training_epoch = False
111 |
112 |
113 | def on_train_end(trainer):
114 | if live:
115 | # At the end log the best metrics. It runs validator on the best model internally.
116 | all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics, **trainer.lr}
117 | for metric, value in all_metrics.items():
118 | live.log_metric(metric, value, plot=False)
119 |
120 | _log_plots(trainer.plots, 'eval')
121 | _log_plots(trainer.validator.plots, 'eval')
122 | _log_confusion_matrix(trainer.validator)
123 |
124 | if trainer.best.exists():
125 | live.log_artifact(trainer.best, copy=True)
126 |
127 | live.end()
128 |
129 |
130 | callbacks = {
131 | 'on_pretrain_routine_start': on_pretrain_routine_start,
132 | 'on_pretrain_routine_end': on_pretrain_routine_end,
133 | 'on_train_start': on_train_start,
134 | 'on_train_epoch_start': on_train_epoch_start,
135 | 'on_fit_epoch_end': on_fit_epoch_end,
136 | 'on_train_end': on_train_end} if dvclive else {}
137 |
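Note: DVCLive logging only activates when `dvclive>=2.11.0` is importable and tests are not running; it can also be switched off per run through the environment variable read by `_logger_disabled()`. A minimal, hedged sketch (nothing here is part of the module itself):

import os

# Disable DVCLive logging for this process before training starts;
# on_pretrain_routine_start() will then leave `live` as None and only emit a debug message.
os.environ['ULTRALYTICS_DVC_DISABLED'] = 'true'
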
--------------------------------------------------------------------------------
/ultralytics/yolo/utils/callbacks/hub.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | import json
4 | from time import time
5 |
6 | from ultralytics.hub.utils import PREFIX, events
7 | from ultralytics.yolo.utils import LOGGER
8 | from ultralytics.yolo.utils.torch_utils import model_info_for_loggers
9 |
10 |
11 | def on_pretrain_routine_end(trainer):
12 | """Logs info before starting timer for upload rate limit."""
13 | session = getattr(trainer, 'hub_session', None)
14 | if session:
15 | # Start timer for upload rate limit
16 | LOGGER.info(f'{PREFIX}View model at https://hub.ultralytics.com/models/{session.model_id} 🚀')
17 | session.timers = {'metrics': time(), 'ckpt': time()} # start timer on session.rate_limit
18 |
19 |
20 | def on_fit_epoch_end(trainer):
21 | """Uploads training progress metrics at the end of each epoch."""
22 | session = getattr(trainer, 'hub_session', None)
23 | if session:
24 | # Upload metrics after val end
25 | all_plots = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics}
26 | if trainer.epoch == 0:
27 | all_plots = {**all_plots, **model_info_for_loggers(trainer)}
28 | session.metrics_queue[trainer.epoch] = json.dumps(all_plots)
29 | if time() - session.timers['metrics'] > session.rate_limits['metrics']:
30 | session.upload_metrics()
31 | session.timers['metrics'] = time() # reset timer
32 | session.metrics_queue = {} # reset queue
33 |
34 |
35 | def on_model_save(trainer):
36 | """Saves checkpoints to Ultralytics HUB with rate limiting."""
37 | session = getattr(trainer, 'hub_session', None)
38 | if session:
39 | # Upload checkpoints with rate limiting
40 | is_best = trainer.best_fitness == trainer.fitness
41 | if time() - session.timers['ckpt'] > session.rate_limits['ckpt']:
42 | LOGGER.info(f'{PREFIX}Uploading checkpoint https://hub.ultralytics.com/models/{session.model_id}')
43 | session.upload_model(trainer.epoch, trainer.last, is_best)
44 | session.timers['ckpt'] = time() # reset timer
45 |
46 |
47 | def on_train_end(trainer):
48 | """Upload final model and metrics to Ultralytics HUB at the end of training."""
49 | session = getattr(trainer, 'hub_session', None)
50 | if session:
51 | # Upload final model and metrics with exponential backoff
52 | LOGGER.info(f'{PREFIX}Syncing final model...')
53 | session.upload_model(trainer.epoch, trainer.best, map=trainer.metrics.get('metrics/mAP50-95(B)', 0), final=True)
54 | session.alive = False # stop heartbeats
55 | LOGGER.info(f'{PREFIX}Done ✅\n'
56 | f'{PREFIX}View model at https://hub.ultralytics.com/models/{session.model_id} 🚀')
57 |
58 |
59 | def on_train_start(trainer):
60 | """Run events on train start."""
61 | events(trainer.args)
62 |
63 |
64 | def on_val_start(validator):
65 | """Runs events on validation start."""
66 | events(validator.args)
67 |
68 |
69 | def on_predict_start(predictor):
70 | """Run events on predict start."""
71 | events(predictor.args)
72 |
73 |
74 | def on_export_start(exporter):
75 | """Run events on export start."""
76 | events(exporter.args)
77 |
78 |
79 | callbacks = {
80 | 'on_pretrain_routine_end': on_pretrain_routine_end,
81 | 'on_fit_epoch_end': on_fit_epoch_end,
82 | 'on_model_save': on_model_save,
83 | 'on_train_end': on_train_end,
84 | 'on_train_start': on_train_start,
85 | 'on_val_start': on_val_start,
86 | 'on_predict_start': on_predict_start,
87 | 'on_export_start': on_export_start}
88 |
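Note: every callback above is a no-op unless a HUB session object is attached to the trainer as `trainer.hub_session`. A hedged sketch of the minimal interface they rely on; the `SimpleNamespace` stand-in, the numeric limits, and the lambda bodies are placeholders, not values from this repo:

from time import time
from types import SimpleNamespace

session = SimpleNamespace(
    model_id='<model-id>',                        # only used to build HUB URLs in log messages
    rate_limits={'metrics': 3.0, 'ckpt': 900.0},  # seconds between uploads (placeholder values)
    timers={'metrics': time(), 'ckpt': time()},   # normally initialized by on_pretrain_routine_end
    metrics_queue={},                             # epoch -> JSON metrics, flushed by on_fit_epoch_end
    alive=True,                                    # set to False by on_train_end to stop heartbeats
    upload_metrics=lambda: None,                   # stand-ins for the real upload methods
    upload_model=lambda epoch, weights, is_best=False, map=0.0, final=False: None)
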
--------------------------------------------------------------------------------
/ultralytics/yolo/utils/callbacks/mlflow.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | import os
4 | import re
5 | from pathlib import Path
6 |
7 | from ultralytics.yolo.utils import LOGGER, TESTS_RUNNING, colorstr
8 |
9 | try:
10 | import mlflow
11 |
12 | assert not TESTS_RUNNING # do not log pytest
13 | assert hasattr(mlflow, '__version__') # verify package is not directory
14 | except (ImportError, AssertionError):
15 | mlflow = None
16 |
17 |
18 | def on_pretrain_routine_end(trainer):
19 | """Logs training parameters to MLflow."""
20 | global mlflow, run, run_id, experiment_name
21 |
22 | if os.environ.get('MLFLOW_TRACKING_URI') is None:
23 | mlflow = None
24 |
25 | if mlflow:
26 | mlflow_location = os.environ['MLFLOW_TRACKING_URI'] # "http://192.168.xxx.xxx:5000"
27 | mlflow.set_tracking_uri(mlflow_location)
28 |
29 | experiment_name = os.environ.get('MLFLOW_EXPERIMENT') or trainer.args.project or '/Shared/YOLOv8'
30 | experiment = mlflow.get_experiment_by_name(experiment_name)
31 | if experiment is None:
32 | experiment = mlflow.get_experiment(mlflow.create_experiment(experiment_name))
33 | mlflow.set_experiment(experiment_name)
34 |
35 | prefix = colorstr('MLflow: ')
36 | try:
37 | run, active_run = mlflow, mlflow.active_run()
38 | if not active_run:
39 | active_run = mlflow.start_run(experiment_id=experiment.experiment_id)
40 | run_id = active_run.info.run_id
41 | LOGGER.info(f'{prefix}Using run_id({run_id}) at {mlflow_location}')
42 | run.log_params(vars(trainer.model.args))
43 | except Exception as err:
44 | LOGGER.error(f'{prefix}Failing init - {repr(err)}')
45 | LOGGER.warning(f'{prefix}Continuing without MLflow')
46 |
47 |
48 | def on_fit_epoch_end(trainer):
49 | """Logs training metrics to Mlflow."""
50 | if mlflow:
51 | metrics_dict = {f"{re.sub('[()]', '', k)}": float(v) for k, v in trainer.metrics.items()}
52 | run.log_metrics(metrics=metrics_dict, step=trainer.epoch)
53 |
54 |
55 | def on_train_end(trainer):
56 | """Called at end of train loop to log model artifact info."""
57 | if mlflow:
58 | root_dir = Path(__file__).resolve().parents[3]
59 | run.log_artifact(trainer.last)
60 | run.log_artifact(trainer.best)
61 | run.pyfunc.log_model(artifact_path=experiment_name,
62 | code_path=[str(root_dir)],
63 | artifacts={'model_path': str(trainer.save_dir)},
64 | python_model=run.pyfunc.PythonModel())
65 |
66 |
67 | callbacks = {
68 | 'on_pretrain_routine_end': on_pretrain_routine_end,
69 | 'on_fit_epoch_end': on_fit_epoch_end,
70 | 'on_train_end': on_train_end} if mlflow else {}
71 |
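Note: MLflow logging is only enabled when a tracking URI is present in the environment before `on_pretrain_routine_end` runs. A hedged setup sketch; the URI and experiment name below are placeholders:

import os

os.environ['MLFLOW_TRACKING_URI'] = 'http://127.0.0.1:5000'  # e.g. a locally running `mlflow server`
os.environ['MLFLOW_EXPERIMENT'] = 'yolov8-experiments'       # optional; falls back to args.project or '/Shared/YOLOv8'
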
--------------------------------------------------------------------------------
/ultralytics/yolo/utils/callbacks/neptune.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | import matplotlib.image as mpimg
4 | import matplotlib.pyplot as plt
5 |
6 | from ultralytics.yolo.utils import LOGGER, TESTS_RUNNING
7 | from ultralytics.yolo.utils.torch_utils import model_info_for_loggers
8 |
9 | try:
10 | import neptune
11 | from neptune.types import File
12 |
13 | assert not TESTS_RUNNING # do not log pytest
14 | assert hasattr(neptune, '__version__')
15 | except (ImportError, AssertionError):
16 | neptune = None
17 |
18 | run = None # NeptuneAI experiment logger instance
19 |
20 |
21 | def _log_scalars(scalars, step=0):
22 | """Log scalars to the NeptuneAI experiment logger."""
23 | if run:
24 | for k, v in scalars.items():
25 | run[k].append(value=v, step=step)
26 |
27 |
28 | def _log_images(imgs_dict, group=''):
29 | """Log scalars to the NeptuneAI experiment logger."""
30 | if run:
31 | for k, v in imgs_dict.items():
32 | run[f'{group}/{k}'].upload(File(v))
33 |
34 |
35 | def _log_plot(title, plot_path):
36 | """Log plots to the NeptuneAI experiment logger."""
37 | """
38 | Log image as plot in the plot section of NeptuneAI
39 |
40 | arguments:
41 | title (str) Title of the plot
42 | plot_path (PosixPath or str) Path to the saved image file
43 | """
44 | img = mpimg.imread(plot_path)
45 | fig = plt.figure()
46 | ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[]) # no ticks
47 | ax.imshow(img)
48 | run[f'Plots/{title}'].upload(fig)
49 |
50 |
51 | def on_pretrain_routine_start(trainer):
52 | """Callback function called before the training routine starts."""
53 | try:
54 | global run
55 | run = neptune.init_run(project=trainer.args.project or 'YOLOv8', name=trainer.args.name, tags=['YOLOv8'])
56 | run['Configuration/Hyperparameters'] = {k: '' if v is None else v for k, v in vars(trainer.args).items()}
57 | except Exception as e:
58 | LOGGER.warning(f'WARNING ⚠️ NeptuneAI installed but not initialized correctly, not logging this run. {e}')
59 |
60 |
61 | def on_train_epoch_end(trainer):
62 | """Callback function called at end of each training epoch."""
63 | _log_scalars(trainer.label_loss_items(trainer.tloss, prefix='train'), trainer.epoch + 1)
64 | _log_scalars(trainer.lr, trainer.epoch + 1)
65 | if trainer.epoch == 1:
66 | _log_images({f.stem: str(f) for f in trainer.save_dir.glob('train_batch*.jpg')}, 'Mosaic')
67 |
68 |
69 | def on_fit_epoch_end(trainer):
70 | """Callback function called at end of each fit (train+val) epoch."""
71 | if run and trainer.epoch == 0:
72 | run['Configuration/Model'] = model_info_for_loggers(trainer)
73 | _log_scalars(trainer.metrics, trainer.epoch + 1)
74 |
75 |
76 | def on_val_end(validator):
77 | """Callback function called at end of each validation."""
78 | if run:
79 | # Log val_labels and val_pred
80 | _log_images({f.stem: str(f) for f in validator.save_dir.glob('val*.jpg')}, 'Validation')
81 |
82 |
83 | def on_train_end(trainer):
84 | """Callback function called at end of training."""
85 | if run:
86 | # Log final results, CM matrix + PR plots
87 | files = [
88 | 'results.png', 'confusion_matrix.png', 'confusion_matrix_normalized.png',
89 | *(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))]
90 | files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter
91 | for f in files:
92 | _log_plot(title=f.stem, plot_path=f)
93 | # Log the final model
94 | run[f'weights/{trainer.args.name or trainer.args.task}/{str(trainer.best.name)}'].upload(File(str(
95 | trainer.best)))
96 |
97 |
98 | callbacks = {
99 | 'on_pretrain_routine_start': on_pretrain_routine_start,
100 | 'on_train_epoch_end': on_train_epoch_end,
101 | 'on_fit_epoch_end': on_fit_epoch_end,
102 | 'on_val_end': on_val_end,
103 | 'on_train_end': on_train_end} if neptune else {}
104 |
--------------------------------------------------------------------------------
/ultralytics/yolo/utils/callbacks/raytune.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | try:
4 | import ray
5 | from ray import tune
6 | from ray.air import session
7 | except (ImportError, AssertionError):
8 | tune = None
9 |
10 |
11 | def on_fit_epoch_end(trainer):
12 | """Sends training metrics to Ray Tune at end of each epoch."""
13 | if ray.tune.is_session_enabled():
14 | metrics = trainer.metrics
15 | metrics['epoch'] = trainer.epoch
16 | session.report(metrics)
17 |
18 |
19 | callbacks = {
20 | 'on_fit_epoch_end': on_fit_epoch_end, } if tune else {}
21 |
--------------------------------------------------------------------------------
/ultralytics/yolo/utils/callbacks/tensorboard.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from ultralytics.yolo.utils import LOGGER, TESTS_RUNNING, colorstr
4 |
5 | try:
6 | from torch.utils.tensorboard import SummaryWriter
7 |
8 | assert not TESTS_RUNNING # do not log pytest
9 | except (ImportError, AssertionError):
10 | SummaryWriter = None
11 |
12 | writer = None # TensorBoard SummaryWriter instance
13 |
14 |
15 | def _log_scalars(scalars, step=0):
16 | """Logs scalar values to TensorBoard."""
17 | if writer:
18 | for k, v in scalars.items():
19 | writer.add_scalar(k, v, step)
20 |
21 |
22 | def on_pretrain_routine_start(trainer):
23 | """Initialize TensorBoard logging with SummaryWriter."""
24 | if SummaryWriter:
25 | try:
26 | global writer
27 | writer = SummaryWriter(str(trainer.save_dir))
28 | prefix = colorstr('TensorBoard: ')
29 | LOGGER.info(f"{prefix}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/")
30 | except Exception as e:
31 | LOGGER.warning(f'WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}')
32 |
33 |
34 | def on_batch_end(trainer):
35 | """Logs scalar statistics at the end of a training batch."""
36 | _log_scalars(trainer.label_loss_items(trainer.tloss, prefix='train'), trainer.epoch + 1)
37 |
38 |
39 | def on_fit_epoch_end(trainer):
40 | """Logs epoch metrics at end of training epoch."""
41 | _log_scalars(trainer.metrics, trainer.epoch + 1)
42 |
43 |
44 | callbacks = {
45 | 'on_pretrain_routine_start': on_pretrain_routine_start,
46 | 'on_fit_epoch_end': on_fit_epoch_end,
47 | 'on_batch_end': on_batch_end}
48 |
--------------------------------------------------------------------------------
/ultralytics/yolo/utils/callbacks/wb.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | from ultralytics.yolo.utils import TESTS_RUNNING
3 | from ultralytics.yolo.utils.torch_utils import model_info_for_loggers
4 |
5 | try:
6 | import wandb as wb
7 |
8 | assert hasattr(wb, '__version__')
9 | assert not TESTS_RUNNING # do not log pytest
10 | except (ImportError, AssertionError):
11 | wb = None
12 |
13 | _processed_plots = {}
14 |
15 |
16 | def _log_plots(plots, step):
17 | for name, params in plots.items():
18 | timestamp = params['timestamp']
19 | if _processed_plots.get(name, None) != timestamp:
20 | wb.run.log({name.stem: wb.Image(str(name))}, step=step)
21 | _processed_plots[name] = timestamp
22 |
23 |
24 | def on_pretrain_routine_start(trainer):
25 | """Initiate and start project if module is present."""
26 | wb.run or wb.init(project=trainer.args.project or 'YOLOv8', name=trainer.args.name, config=vars(trainer.args))
27 |
28 |
29 | def on_fit_epoch_end(trainer):
30 | """Logs training metrics and model information at the end of an epoch."""
31 | wb.run.log(trainer.metrics, step=trainer.epoch + 1)
32 | _log_plots(trainer.plots, step=trainer.epoch + 1)
33 | _log_plots(trainer.validator.plots, step=trainer.epoch + 1)
34 | if trainer.epoch == 0:
35 | wb.run.log(model_info_for_loggers(trainer), step=trainer.epoch + 1)
36 |
37 |
38 | def on_train_epoch_end(trainer):
39 | """Log metrics and save images at the end of each training epoch."""
40 | wb.run.log(trainer.label_loss_items(trainer.tloss, prefix='train'), step=trainer.epoch + 1)
41 | wb.run.log(trainer.lr, step=trainer.epoch + 1)
42 | if trainer.epoch == 1:
43 | _log_plots(trainer.plots, step=trainer.epoch + 1)
44 |
45 |
46 | def on_train_end(trainer):
47 | """Save the best model as an artifact at end of training."""
48 | _log_plots(trainer.validator.plots, step=trainer.epoch + 1)
49 | _log_plots(trainer.plots, step=trainer.epoch + 1)
50 | art = wb.Artifact(type='model', name=f'run_{wb.run.id}_model')
51 | if trainer.best.exists():
52 | art.add_file(trainer.best)
53 | wb.run.log_artifact(art)
54 |
55 |
56 | callbacks = {
57 | 'on_pretrain_routine_start': on_pretrain_routine_start,
58 | 'on_train_epoch_end': on_train_epoch_end,
59 | 'on_fit_epoch_end': on_fit_epoch_end,
60 | 'on_train_end': on_train_end} if wb else {}
61 |
--------------------------------------------------------------------------------
/ultralytics/yolo/utils/dist.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | import os
4 | import re
5 | import shutil
6 | import socket
7 | import sys
8 | import tempfile
9 | from pathlib import Path
10 |
11 | from . import USER_CONFIG_DIR
12 | from .torch_utils import TORCH_1_9
13 |
14 |
15 | def find_free_network_port() -> int:
16 | """Finds a free port on localhost.
17 |
18 | It is useful in single-node training when we don't want to connect to a real main node but have to set the
19 | `MASTER_PORT` environment variable.
20 | """
21 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
22 | s.bind(('127.0.0.1', 0))
23 | return s.getsockname()[1] # port
24 |
25 |
26 | def generate_ddp_file(trainer):
27 | """Generates a DDP file and returns its file name."""
28 | module, name = f'{trainer.__class__.__module__}.{trainer.__class__.__name__}'.rsplit('.', 1)
29 |
30 | content = f'''overrides = {vars(trainer.args)} \nif __name__ == "__main__":
31 | from {module} import {name}
32 | from ultralytics.yolo.utils import DEFAULT_CFG_DICT
33 |
34 | cfg = DEFAULT_CFG_DICT.copy()
35 | cfg.update(save_dir='') # handle the extra key 'save_dir'
36 | trainer = {name}(cfg=cfg, overrides=overrides)
37 | trainer.train()'''
38 | (USER_CONFIG_DIR / 'DDP').mkdir(exist_ok=True)
39 | with tempfile.NamedTemporaryFile(prefix='_temp_',
40 | suffix=f'{id(trainer)}.py',
41 | mode='w+',
42 | encoding='utf-8',
43 | dir=USER_CONFIG_DIR / 'DDP',
44 | delete=False) as file:
45 | file.write(content)
46 | return file.name
47 |
48 |
49 | def generate_ddp_command(world_size, trainer):
50 | """Generates and returns command for distributed training."""
51 | import __main__ # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
52 | if not trainer.resume:
53 | shutil.rmtree(trainer.save_dir) # remove the save_dir
54 | file = str(Path(sys.argv[0]).resolve())
55 | safe_pattern = re.compile(r'^[a-zA-Z0-9_. /\\-]{1,128}$') # allowed characters and maximum of 128 characters
56 | if not (safe_pattern.match(file) and Path(file).exists() and file.endswith('.py')): # using CLI
57 | file = generate_ddp_file(trainer)
58 | dist_cmd = 'torch.distributed.run' if TORCH_1_9 else 'torch.distributed.launch'
59 | port = find_free_network_port()
60 | cmd = [sys.executable, '-m', dist_cmd, '--nproc_per_node', f'{world_size}', '--master_port', f'{port}', file]
61 | return cmd, file
62 |
63 |
64 | def ddp_cleanup(trainer, file):
65 | """Delete temp file if created."""
66 | if f'{id(trainer)}.py' in file: # if temp_file suffix in file
67 | os.remove(file)
68 |
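Note: `generate_ddp_command` wraps the helpers above into a `python -m torch.distributed.run --nproc_per_node <world_size> --master_port <port> <script>` command (or `torch.distributed.launch` on older PyTorch), and `ddp_cleanup` removes the generated temp file afterwards. A hedged sketch of the port helper on its own; the import path is assumed from this file's location:

import os

from ultralytics.yolo.utils.dist import find_free_network_port  # assumed import path for this module

# Reserve a free localhost port for single-node DDP before spawning worker processes.
os.environ['MASTER_PORT'] = str(find_free_network_port())
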
--------------------------------------------------------------------------------
/ultralytics/yolo/utils/errors.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from ultralytics.yolo.utils import emojis
4 |
5 |
6 | class HUBModelError(Exception):
7 |
8 | def __init__(self, message='Model not found. Please check model URL and try again.'):
9 | """Create an exception for when a model is not found."""
10 | super().__init__(emojis(message))
11 |
--------------------------------------------------------------------------------
/ultralytics/yolo/utils/files.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | import contextlib
4 | import glob
5 | import os
6 | import shutil
7 | from datetime import datetime
8 | from pathlib import Path
9 |
10 |
11 | class WorkingDirectory(contextlib.ContextDecorator):
12 | """Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager."""
13 |
14 | def __init__(self, new_dir):
15 | """Sets the working directory to 'new_dir' upon instantiation."""
16 | self.dir = new_dir # new dir
17 | self.cwd = Path.cwd().resolve() # current dir
18 |
19 | def __enter__(self):
20 | """Changes the current directory to the specified directory."""
21 | os.chdir(self.dir)
22 |
23 | def __exit__(self, exc_type, exc_val, exc_tb):
24 | """Restore the current working directory on context exit."""
25 | os.chdir(self.cwd)
26 |
27 |
28 | def increment_path(path, exist_ok=False, sep='', mkdir=False):
29 | """
30 | Increments a file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
31 |
32 | If the path exists and exist_ok is not set to True, the path will be incremented by appending a number and sep to
33 | the end of the path. If the path is a file, the file extension will be preserved. If the path is a directory, the
34 | number will be appended directly to the end of the path. If mkdir is set to True, the path will be created as a
35 | directory if it does not already exist.
36 |
37 | Args:
38 | path (str, pathlib.Path): Path to increment.
39 | exist_ok (bool, optional): If True, the path will not be incremented and returned as-is. Defaults to False.
40 | sep (str, optional): Separator to use between the path and the incrementation number. Defaults to ''.
41 | mkdir (bool, optional): Create a directory if it does not exist. Defaults to False.
42 |
43 | Returns:
44 | (pathlib.Path): Incremented path.
45 | """
46 | path = Path(path) # os-agnostic
47 | if path.exists() and not exist_ok:
48 | path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')
49 |
50 | # Find the next free incremented path
51 | for n in range(2, 9999):
52 | p = f'{path}{sep}{n}{suffix}' # increment path
53 | if not os.path.exists(p):
54 | break
55 | path = Path(p)
56 |
57 | if mkdir:
58 | path.mkdir(parents=True, exist_ok=True) # make directory
59 |
60 | return path
61 |
62 |
63 | def file_age(path=__file__):
64 | """Return days since last file update."""
65 | dt = (datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime)) # delta
66 | return dt.days # + dt.seconds / 86400 # fractional days
67 |
68 |
69 | def file_date(path=__file__):
70 | """Return human-readable file modification date, i.e. '2021-3-26'."""
71 | t = datetime.fromtimestamp(Path(path).stat().st_mtime)
72 | return f'{t.year}-{t.month}-{t.day}'
73 |
74 |
75 | def file_size(path):
76 | """Return file/dir size (MB)."""
77 | if isinstance(path, (str, Path)):
78 | mb = 1 << 20 # bytes to MiB (1024 ** 2)
79 | path = Path(path)
80 | if path.is_file():
81 | return path.stat().st_size / mb
82 | elif path.is_dir():
83 | return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / mb
84 | return 0.0
85 |
86 |
87 | def get_latest_run(search_dir='.'):
88 | """Return path to most recent 'last.pt' in /runs (i.e. to --resume from)."""
89 | last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
90 | return max(last_list, key=os.path.getctime) if last_list else ''
91 |
92 |
93 | def make_dirs(dir='new_dir/'):
94 | """Create a fresh directory with 'labels' and 'images' subdirectories, removing any existing one."""
95 | dir = Path(dir)
96 | if dir.exists():
97 | shutil.rmtree(dir) # delete dir
98 | for p in dir, dir / 'labels', dir / 'images':
99 | p.mkdir(parents=True, exist_ok=True) # make dir
100 | return dir
101 |
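Note: a short sketch of `increment_path` behaviour; the import path is assumed from this file's location and the paths are placeholders:

from ultralytics.yolo.utils.files import increment_path  # assumed import path for this module

p1 = increment_path('runs/exp', mkdir=True)  # -> runs/exp (returned unchanged and created, it did not exist yet)
p2 = increment_path('runs/exp', mkdir=True)  # -> runs/exp2 (runs/exp now exists, so the path is incremented)
p3 = increment_path('results.csv')           # -> results2.csv if results.csv exists; the suffix is preserved
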
--------------------------------------------------------------------------------
/ultralytics/yolo/utils/patches.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | """
3 | Monkey patches to update/extend functionality of existing functions
4 | """
5 |
6 | from pathlib import Path
7 |
8 | import cv2
9 | import numpy as np
10 | import torch
11 |
12 | # OpenCV Multilanguage-friendly functions ------------------------------------------------------------------------------
13 | _imshow = cv2.imshow # copy to avoid recursion errors
14 |
15 |
16 | def imread(filename, flags=cv2.IMREAD_COLOR):
17 | return cv2.imdecode(np.fromfile(filename, np.uint8), flags)
18 |
19 |
20 | def imwrite(filename, img):
21 | try:
22 | cv2.imencode(Path(filename).suffix, img)[1].tofile(filename)
23 | return True
24 | except Exception:
25 | return False
26 |
27 |
28 | def imshow(path, im):
29 | _imshow(path.encode('unicode_escape').decode(), im)
30 |
31 |
32 | # PyTorch functions ----------------------------------------------------------------------------------------------------
33 | _torch_save = torch.save # copy to avoid recursion errors
34 |
35 |
36 | def torch_save(*args, **kwargs):
37 | # Use dill (if exists) to serialize the lambda functions where pickle does not do this
38 | try:
39 | import dill as pickle
40 | except ImportError:
41 | import pickle
42 |
43 | if 'pickle_module' not in kwargs:
44 | kwargs['pickle_module'] = pickle
45 | return _torch_save(*args, **kwargs)
46 |
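Note: the OpenCV wrappers exist so image I/O also works with non-ASCII file paths, which `cv2.imread`/`cv2.imwrite` may fail on, and `torch_save` lets `torch.save` pickle lambdas when `dill` is installed. A hedged sketch; the import path and file names are placeholders:

from ultralytics.yolo.utils.patches import imread, imwrite  # assumed import path for this module

img = imread('测试图片.jpg')       # decodes via np.fromfile + cv2.imdecode, so non-ASCII paths work
ok = imwrite('输出图片.png', img)  # returns False on failure instead of raising
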
--------------------------------------------------------------------------------
/ultralytics/yolo/utils/tuner.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from ultralytics.yolo.utils import LOGGER
4 |
5 | try:
6 | from ray import tune
7 | from ray.air import RunConfig, session # noqa
8 | from ray.air.integrations.wandb import WandbLoggerCallback # noqa
9 | from ray.tune.schedulers import ASHAScheduler # noqa
10 | from ray.tune.schedulers import AsyncHyperBandScheduler as AHB # noqa
11 |
12 | except ImportError:
13 | LOGGER.info("Tuning hyperparameters requires ray/tune. Install using `pip install 'ray[tune]'`")
14 | tune = None
15 |
16 | default_space = {
17 | # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
18 | 'lr0': tune.uniform(1e-5, 1e-1),
19 | 'lrf': tune.uniform(0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
20 | 'momentum': tune.uniform(0.6, 0.98), # SGD momentum/Adam beta1
21 | 'weight_decay': tune.uniform(0.0, 0.001), # optimizer weight decay 5e-4
22 | 'warmup_epochs': tune.uniform(0.0, 5.0), # warmup epochs (fractions ok)
23 | 'warmup_momentum': tune.uniform(0.0, 0.95), # warmup initial momentum
24 | 'box': tune.uniform(0.02, 0.2), # box loss gain
25 | 'cls': tune.uniform(0.2, 4.0), # cls loss gain (scale with pixels)
26 | 'hsv_h': tune.uniform(0.0, 0.1), # image HSV-Hue augmentation (fraction)
27 | 'hsv_s': tune.uniform(0.0, 0.9), # image HSV-Saturation augmentation (fraction)
28 | 'hsv_v': tune.uniform(0.0, 0.9), # image HSV-Value augmentation (fraction)
29 | 'degrees': tune.uniform(0.0, 45.0), # image rotation (+/- deg)
30 | 'translate': tune.uniform(0.0, 0.9), # image translation (+/- fraction)
31 | 'scale': tune.uniform(0.0, 0.9), # image scale (+/- gain)
32 | 'shear': tune.uniform(0.0, 10.0), # image shear (+/- deg)
33 | 'perspective': tune.uniform(0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
34 | 'flipud': tune.uniform(0.0, 1.0), # image flip up-down (probability)
35 | 'fliplr': tune.uniform(0.0, 1.0), # image flip left-right (probability)
36 | 'mosaic': tune.uniform(0.0, 1.0), # image mosaic (probability)
37 | 'mixup': tune.uniform(0.0, 1.0), # image mixup (probability)
38 | 'copy_paste': tune.uniform(0.0, 1.0)} # segment copy-paste (probability)
39 |
40 | task_metric_map = {
41 | 'detect': 'metrics/mAP50-95(B)',
42 | 'segment': 'metrics/mAP50-95(M)',
43 | 'classify': 'metrics/accuracy_top1',
44 | 'pose': 'metrics/mAP50-95(P)'}
45 |
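Note: a hedged sketch of how `default_space` and `task_metric_map` could be wired into a Ray Tune search. The `trainable` below is a placeholder that only reports a dummy metric; a real one would run YOLO training with the sampled config. The import path for this module is assumed:

from ray import tune
from ray.air import session
from ray.tune.schedulers import ASHAScheduler

from ultralytics.yolo.utils.tuner import default_space, task_metric_map  # assumed import path


def trainable(config):
    # Placeholder trainable: a real one would train YOLO with `config` and report its metrics.
    session.report({task_metric_map['detect']: 0.0, 'epoch': 0})


tuner = tune.Tuner(
    trainable,
    param_space=default_space,
    tune_config=tune.TuneConfig(metric=task_metric_map['detect'], mode='max',
                                scheduler=ASHAScheduler(), num_samples=8))
results = tuner.fit()
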
--------------------------------------------------------------------------------
/ultralytics/yolo/v8/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from ultralytics.yolo.v8 import classify, detect, pose, segment
4 |
5 | __all__ = 'classify', 'segment', 'detect', 'pose'
6 |
--------------------------------------------------------------------------------
/ultralytics/yolo/v8/classify/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from ultralytics.yolo.v8.classify.predict import ClassificationPredictor, predict
4 | from ultralytics.yolo.v8.classify.train import ClassificationTrainer, train
5 | from ultralytics.yolo.v8.classify.val import ClassificationValidator, val
6 |
7 | __all__ = 'ClassificationPredictor', 'predict', 'ClassificationTrainer', 'train', 'ClassificationValidator', 'val'
8 |
--------------------------------------------------------------------------------
/ultralytics/yolo/v8/classify/predict.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | import torch
4 |
5 | from ultralytics.yolo.engine.predictor import BasePredictor
6 | from ultralytics.yolo.engine.results import Results
7 | from ultralytics.yolo.utils import DEFAULT_CFG, ROOT
8 |
9 |
10 | class ClassificationPredictor(BasePredictor):
11 |
12 | def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
13 | super().__init__(cfg, overrides, _callbacks)
14 | self.args.task = 'classify'
15 |
16 | def preprocess(self, img):
17 | """Converts input image to model-compatible data type."""
18 | if not isinstance(img, torch.Tensor):
19 | img = torch.stack([self.transforms(im) for im in img], dim=0)
20 | img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
21 | return img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
22 |
23 | def postprocess(self, preds, img, orig_imgs):
24 | """Postprocesses predictions to return Results objects."""
25 | results = []
26 | for i, pred in enumerate(preds):
27 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
28 | path = self.batch[0]
29 | img_path = path[i] if isinstance(path, list) else path
30 | results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, probs=pred))
31 |
32 | return results
33 |
34 |
35 | def predict(cfg=DEFAULT_CFG, use_python=False):
36 | """Run YOLO model predictions on input images/videos."""
37 | model = cfg.model or 'yolov8n-cls.pt' # or "resnet18"
38 | source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \
39 | else 'https://ultralytics.com/images/bus.jpg'
40 |
41 | args = dict(model=model, source=source)
42 | if use_python:
43 | from ultralytics import YOLO
44 | YOLO(model)(**args)
45 | else:
46 | predictor = ClassificationPredictor(overrides=args)
47 | predictor.predict_cli()
48 |
49 |
50 | if __name__ == '__main__':
51 | predict()
52 |
--------------------------------------------------------------------------------
/ultralytics/yolo/v8/classify/val.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | import torch
4 |
5 | from ultralytics.yolo.data import ClassificationDataset, build_dataloader
6 | from ultralytics.yolo.engine.validator import BaseValidator
7 | from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER
8 | from ultralytics.yolo.utils.metrics import ClassifyMetrics, ConfusionMatrix
9 | from ultralytics.yolo.utils.plotting import plot_images
10 |
11 |
12 | class ClassificationValidator(BaseValidator):
13 |
14 | def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
15 | """Initializes ClassificationValidator instance with args, dataloader, save_dir, and progress bar."""
16 | super().__init__(dataloader, save_dir, pbar, args, _callbacks)
17 | self.args.task = 'classify'
18 | self.metrics = ClassifyMetrics()
19 |
20 | def get_desc(self):
21 | """Returns a formatted string summarizing classification metrics."""
22 | return ('%22s' + '%11s' * 2) % ('classes', 'top1_acc', 'top5_acc')
23 |
24 | def init_metrics(self, model):
25 | """Initialize confusion matrix, class names, and top-1 and top-5 accuracy."""
26 | self.names = model.names
27 | self.nc = len(model.names)
28 | self.confusion_matrix = ConfusionMatrix(nc=self.nc, task='classify')
29 | self.pred = []
30 | self.targets = []
31 |
32 | def preprocess(self, batch):
33 | """Preprocesses input batch and returns it."""
34 | batch['img'] = batch['img'].to(self.device, non_blocking=True)
35 | batch['img'] = batch['img'].half() if self.args.half else batch['img'].float()
36 | batch['cls'] = batch['cls'].to(self.device)
37 | return batch
38 |
39 | def update_metrics(self, preds, batch):
40 | """Updates running metrics with model predictions and batch targets."""
41 | n5 = min(len(self.model.names), 5)
42 | self.pred.append(preds.argsort(1, descending=True)[:, :n5])
43 | self.targets.append(batch['cls'])
44 |
45 | def finalize_metrics(self, *args, **kwargs):
46 | """Finalizes metrics of the model such as confusion_matrix and speed."""
47 | self.confusion_matrix.process_cls_preds(self.pred, self.targets)
48 | if self.args.plots:
49 | for normalize in True, False:
50 | self.confusion_matrix.plot(save_dir=self.save_dir,
51 | names=self.names.values(),
52 | normalize=normalize,
53 | on_plot=self.on_plot)
54 | self.metrics.speed = self.speed
55 | self.metrics.confusion_matrix = self.confusion_matrix
56 |
57 | def get_stats(self):
58 | """Returns a dictionary of metrics obtained by processing targets and predictions."""
59 | self.metrics.process(self.targets, self.pred)
60 | return self.metrics.results_dict
61 |
62 | def build_dataset(self, img_path):
63 | return ClassificationDataset(root=img_path, args=self.args, augment=False)
64 |
65 | def get_dataloader(self, dataset_path, batch_size):
66 | """Builds and returns a data loader for classification tasks with given parameters."""
67 | dataset = self.build_dataset(dataset_path)
68 | return build_dataloader(dataset, batch_size, self.args.workers, rank=-1)
69 |
70 | def print_results(self):
71 | """Prints evaluation metrics for YOLO object detection model."""
72 | pf = '%22s' + '%11.3g' * len(self.metrics.keys) # print format
73 | LOGGER.info(pf % ('all', self.metrics.top1, self.metrics.top5))
74 |
75 | def plot_val_samples(self, batch, ni):
76 | """Plot validation image samples."""
77 | plot_images(images=batch['img'],
78 | batch_idx=torch.arange(len(batch['img'])),
79 | cls=batch['cls'].squeeze(-1),
80 | fname=self.save_dir / f'val_batch{ni}_labels.jpg',
81 | names=self.names,
82 | on_plot=self.on_plot)
83 |
84 | def plot_predictions(self, batch, preds, ni):
85 | """Plots predicted bounding boxes on input images and saves the result."""
86 | plot_images(batch['img'],
87 | batch_idx=torch.arange(len(batch['img'])),
88 | cls=torch.argmax(preds, dim=1),
89 | fname=self.save_dir / f'val_batch{ni}_pred.jpg',
90 | names=self.names,
91 | on_plot=self.on_plot) # pred
92 |
93 |
94 | def val(cfg=DEFAULT_CFG, use_python=False):
95 | """Validate YOLO model using custom data."""
96 | model = cfg.model or 'yolov8n-cls.pt' # or "resnet18"
97 | data = cfg.data or 'mnist160'
98 |
99 | args = dict(model=model, data=data)
100 | if use_python:
101 | from ultralytics import YOLO
102 | YOLO(model).val(**args)
103 | else:
104 | validator = ClassificationValidator(args=args)
105 | validator(model=args['model'])
106 |
107 |
108 | if __name__ == '__main__':
109 | val()
110 |
--------------------------------------------------------------------------------
/ultralytics/yolo/v8/detect/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from .predict import DetectionPredictor, predict
4 | from .train import DetectionTrainer, train
5 | from .val import DetectionValidator, val
6 |
7 | __all__ = 'DetectionPredictor', 'predict', 'DetectionTrainer', 'train', 'DetectionValidator', 'val'
8 |
--------------------------------------------------------------------------------
/ultralytics/yolo/v8/detect/predict.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | import torch
4 |
5 | from ultralytics.yolo.engine.predictor import BasePredictor
6 | from ultralytics.yolo.engine.results import Results
7 | from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, ops
8 |
9 |
10 | class DetectionPredictor(BasePredictor):
11 |
12 | def postprocess(self, preds, img, orig_imgs):
13 | """Postprocesses predictions and returns a list of Results objects."""
14 | preds = ops.non_max_suppression(preds,
15 | self.args.conf,
16 | self.args.iou,
17 | agnostic=self.args.agnostic_nms,
18 | max_det=self.args.max_det,
19 | classes=self.args.classes)
20 |
21 | results = []
22 | for i, pred in enumerate(preds):
23 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
24 | if not isinstance(orig_imgs, torch.Tensor):
25 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
26 | path = self.batch[0]
27 | img_path = path[i] if isinstance(path, list) else path
28 | results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred))
29 | return results
30 |
31 |
32 | def predict(cfg=DEFAULT_CFG, use_python=False):
33 | """Runs YOLO model inference on input image(s)."""
34 | model = cfg.model or 'yolov8n.pt'
35 | source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \
36 | else 'https://ultralytics.com/images/bus.jpg'
37 |
38 | args = dict(model=model, source=source)
39 | if use_python:
40 | from ultralytics import YOLO
41 | YOLO(model)(**args)
42 | else:
43 | predictor = DetectionPredictor(overrides=args)
44 | predictor.predict_cli()
45 |
46 |
47 | if __name__ == '__main__':
48 | predict()
49 |
--------------------------------------------------------------------------------
/ultralytics/yolo/v8/pose/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from .predict import PosePredictor, predict
4 | from .train import PoseTrainer, train
5 | from .val import PoseValidator, val
6 |
7 | __all__ = 'PoseTrainer', 'train', 'PoseValidator', 'val', 'PosePredictor', 'predict'
8 |
--------------------------------------------------------------------------------
/ultralytics/yolo/v8/pose/predict.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from ultralytics.yolo.engine.results import Results
4 | from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, ops
5 | from ultralytics.yolo.v8.detect.predict import DetectionPredictor
6 |
7 |
8 | class PosePredictor(DetectionPredictor):
9 |
10 | def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
11 | super().__init__(cfg, overrides, _callbacks)
12 | self.args.task = 'pose'
13 |
14 | def postprocess(self, preds, img, orig_imgs):
15 | """Return detection results for a given input image or list of images."""
16 | preds = ops.non_max_suppression(preds,
17 | self.args.conf,
18 | self.args.iou,
19 | agnostic=self.args.agnostic_nms,
20 | max_det=self.args.max_det,
21 | classes=self.args.classes,
22 | nc=len(self.model.names))
23 |
24 | results = []
25 | for i, pred in enumerate(preds):
26 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
27 | shape = orig_img.shape
28 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round()
29 | pred_kpts = pred[:, 6:].view(len(pred), *self.model.kpt_shape) if len(pred) else pred[:, 6:]
30 | pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, shape)
31 | path = self.batch[0]
32 | img_path = path[i] if isinstance(path, list) else path
33 | results.append(
34 | Results(orig_img=orig_img,
35 | path=img_path,
36 | names=self.model.names,
37 | boxes=pred[:, :6],
38 | keypoints=pred_kpts))
39 | return results
40 |
41 |
42 | def predict(cfg=DEFAULT_CFG, use_python=False):
43 | """Runs YOLO to predict objects in an image or video."""
44 | model = cfg.model or 'yolov8n-pose.pt'
45 | source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \
46 | else 'https://ultralytics.com/images/bus.jpg'
47 |
48 | args = dict(model=model, source=source)
49 | if use_python:
50 | from ultralytics import YOLO
51 | YOLO(model)(**args)
52 | else:
53 | predictor = PosePredictor(overrides=args)
54 | predictor.predict_cli()
55 |
56 |
57 | if __name__ == '__main__':
58 | predict()
59 |
--------------------------------------------------------------------------------
/ultralytics/yolo/v8/pose/train.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from copy import copy
4 |
5 | from ultralytics.nn.tasks import PoseModel
6 | from ultralytics.yolo import v8
7 | from ultralytics.yolo.utils import DEFAULT_CFG
8 | from ultralytics.yolo.utils.plotting import plot_images, plot_results
9 |
10 |
11 | # BaseTrainer python usage
12 | class PoseTrainer(v8.detect.DetectionTrainer):
13 |
14 | def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
15 | """Initialize a PoseTrainer object with specified configurations and overrides."""
16 | if overrides is None:
17 | overrides = {}
18 | overrides['task'] = 'pose'
19 | super().__init__(cfg, overrides, _callbacks)
20 |
21 | def get_model(self, cfg=None, weights=None, verbose=True):
22 | """Get pose estimation model with specified configuration and weights."""
23 | model = PoseModel(cfg, ch=3, nc=self.data['nc'], data_kpt_shape=self.data['kpt_shape'], verbose=verbose)
24 | if weights:
25 | model.load(weights)
26 |
27 | return model
28 |
29 | def set_model_attributes(self):
30 | """Sets keypoints shape attribute of PoseModel."""
31 | super().set_model_attributes()
32 | self.model.kpt_shape = self.data['kpt_shape']
33 |
34 | def get_validator(self):
35 | """Returns an instance of the PoseValidator class for validation."""
36 | self.loss_names = 'box_loss', 'pose_loss', 'kobj_loss', 'cls_loss', 'dfl_loss'
37 | return v8.pose.PoseValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
38 |
39 | def plot_training_samples(self, batch, ni):
40 | """Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints."""
41 | images = batch['img']
42 | kpts = batch['keypoints']
43 | cls = batch['cls'].squeeze(-1)
44 | bboxes = batch['bboxes']
45 | paths = batch['im_file']
46 | batch_idx = batch['batch_idx']
47 | plot_images(images,
48 | batch_idx,
49 | cls,
50 | bboxes,
51 | kpts=kpts,
52 | paths=paths,
53 | fname=self.save_dir / f'train_batch{ni}.jpg',
54 | on_plot=self.on_plot)
55 |
56 | def plot_metrics(self):
57 | """Plots training/val metrics."""
58 | plot_results(file=self.csv, pose=True, on_plot=self.on_plot) # save results.png
59 |
60 |
61 | def train(cfg=DEFAULT_CFG, use_python=False):
62 | """Train the YOLO model on the given data and device."""
63 | model = cfg.model or 'yolov8n-pose.yaml'
64 | data = cfg.data or 'coco8-pose.yaml'
65 | device = cfg.device if cfg.device is not None else ''
66 |
67 | args = dict(model=model, data=data, device=device)
68 | if use_python:
69 | from ultralytics import YOLO
70 | YOLO(model).train(**args)
71 | else:
72 | trainer = PoseTrainer(overrides=args)
73 | trainer.train()
74 |
75 |
76 | if __name__ == '__main__':
77 | train()
78 |
--------------------------------------------------------------------------------
/ultralytics/yolo/v8/segment/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from .predict import SegmentationPredictor, predict
4 | from .train import SegmentationTrainer, train
5 | from .val import SegmentationValidator, val
6 |
7 | __all__ = 'SegmentationPredictor', 'predict', 'SegmentationTrainer', 'train', 'SegmentationValidator', 'val'
8 |
--------------------------------------------------------------------------------
/ultralytics/yolo/v8/segment/predict.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | import torch
4 |
5 | from ultralytics.yolo.engine.results import Results
6 | from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, ops
7 | from ultralytics.yolo.v8.detect.predict import DetectionPredictor
8 |
9 |
10 | class SegmentationPredictor(DetectionPredictor):
11 |
12 | def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
13 | super().__init__(cfg, overrides, _callbacks)
14 | self.args.task = 'segment'
15 |
16 | def postprocess(self, preds, img, orig_imgs):
17 | """TODO: filter by classes."""
18 | p = ops.non_max_suppression(preds[0],
19 | self.args.conf,
20 | self.args.iou,
21 | agnostic=self.args.agnostic_nms,
22 | max_det=self.args.max_det,
23 | nc=len(self.model.names),
24 | classes=self.args.classes)
25 | results = []
26 | proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported
27 | for i, pred in enumerate(p):
28 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
29 | path = self.batch[0]
30 | img_path = path[i] if isinstance(path, list) else path
31 | if not len(pred): # save empty boxes
32 | results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6]))
33 | continue
34 | if self.args.retina_masks:
35 | if not isinstance(orig_imgs, torch.Tensor):
36 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
37 | masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC
38 | else:
39 | masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC
40 | if not isinstance(orig_imgs, torch.Tensor):
41 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
42 | results.append(
43 | Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
44 | return results
45 |
46 |
47 | def predict(cfg=DEFAULT_CFG, use_python=False):
48 | """Runs YOLO object detection on an image or video source."""
49 | model = cfg.model or 'yolov8n-seg.pt'
50 | source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \
51 | else 'https://ultralytics.com/images/bus.jpg'
52 |
53 | args = dict(model=model, source=source)
54 | if use_python:
55 | from ultralytics import YOLO
56 | YOLO(model)(**args)
57 | else:
58 | predictor = SegmentationPredictor(overrides=args)
59 | predictor.predict_cli()
60 |
61 |
62 | if __name__ == '__main__':
63 | predict()
64 |
--------------------------------------------------------------------------------
/ultralytics/yolo/v8/segment/train.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | from copy import copy
3 |
4 | from ultralytics.nn.tasks import SegmentationModel
5 | from ultralytics.yolo import v8
6 | from ultralytics.yolo.utils import DEFAULT_CFG, RANK
7 | from ultralytics.yolo.utils.plotting import plot_images, plot_results
8 |
9 |
10 | # BaseTrainer python usage
11 | class SegmentationTrainer(v8.detect.DetectionTrainer):
12 |
13 | def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
14 | """Initialize a SegmentationTrainer object with given arguments."""
15 | if overrides is None:
16 | overrides = {}
17 | overrides['task'] = 'segment'
18 | super().__init__(cfg, overrides, _callbacks)
19 |
20 | def get_model(self, cfg=None, weights=None, verbose=True):
21 | """Return SegmentationModel initialized with specified config and weights."""
22 | model = SegmentationModel(cfg, ch=3, nc=self.data['nc'], verbose=verbose and RANK == -1)
23 | if weights:
24 | model.load(weights)
25 |
26 | return model
27 |
28 | def get_validator(self):
29 | """Return an instance of SegmentationValidator for validation of YOLO model."""
30 | self.loss_names = 'box_loss', 'seg_loss', 'cls_loss', 'dfl_loss'
31 | return v8.segment.SegmentationValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
32 |
33 | def plot_training_samples(self, batch, ni):
34 | """Creates a plot of training sample images with labels and box coordinates."""
35 | plot_images(batch['img'],
36 | batch['batch_idx'],
37 | batch['cls'].squeeze(-1),
38 | batch['bboxes'],
39 | batch['masks'],
40 | paths=batch['im_file'],
41 | fname=self.save_dir / f'train_batch{ni}.jpg',
42 | on_plot=self.on_plot)
43 |
44 | def plot_metrics(self):
45 | """Plots training/val metrics."""
46 | plot_results(file=self.csv, segment=True, on_plot=self.on_plot) # save results.png
47 |
48 |
49 | def train(cfg=DEFAULT_CFG, use_python=False):
50 | """Train a YOLO segmentation model based on passed arguments."""
51 | model = cfg.model or 'yolov8n-seg.pt'
52 | data = cfg.data or 'coco128-seg.yaml' # or yolo.ClassificationDataset("mnist")
53 | device = cfg.device if cfg.device is not None else ''
54 |
55 | args = dict(model=model, data=data, device=device)
56 | if use_python:
57 | from ultralytics import YOLO
58 | YOLO(model).train(**args)
59 | else:
60 | trainer = SegmentationTrainer(overrides=args)
61 | trainer.train()
62 |
63 |
64 | if __name__ == '__main__':
65 | train()
66 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/utils/__init__.py
--------------------------------------------------------------------------------