├── .gitignore ├── Inference.py ├── LICENSE ├── MORE_USAGES.md ├── README.md ├── app_gradio.py ├── assets ├── Overview.png ├── anomaly.png ├── building.png ├── dog_clip.png ├── eightpic.pdf ├── eightpic.png ├── head_fig.png ├── hf_everything_mode.png ├── hf_points_mode.png ├── logo.png ├── more_usages │ ├── box_prompt.png │ ├── draw_edge.png │ ├── everything_mode.png │ ├── everything_mode_without_retina.png │ ├── more_points.png │ └── text_prompt_cat.png ├── replicate-1.png ├── replicate-2.png ├── replicate-3.png └── salient.png ├── cog.yaml ├── examples ├── dogs.jpg ├── sa_10039.jpg ├── sa_11025.jpg ├── sa_1309.jpg ├── sa_192.jpg ├── sa_414.jpg ├── sa_561.jpg ├── sa_862.jpg └── sa_8776.jpg ├── fastsam ├── __init__.py ├── decoder.py ├── model.py ├── predict.py ├── prompt.py └── utils.py ├── images ├── cat.jpg └── dogs.jpg ├── output ├── cat.jpg └── dogs.jpg ├── predict.py ├── requirements.txt ├── segpredict.py ├── setup.py ├── ultralytics ├── .pre-commit-config.yaml ├── __init__.py ├── assets │ ├── bus.jpg │ └── zidane.jpg ├── datasets │ ├── Argoverse.yaml │ ├── GlobalWheat2020.yaml │ ├── ImageNet.yaml │ ├── Objects365.yaml │ ├── SKU-110K.yaml │ ├── VOC.yaml │ ├── VisDrone.yaml │ ├── coco-pose.yaml │ ├── coco.yaml │ ├── coco128-seg.yaml │ ├── coco128.yaml │ ├── coco8-pose.yaml │ ├── coco8-seg.yaml │ ├── coco8.yaml │ └── xView.yaml ├── hub │ ├── __init__.py │ ├── auth.py │ ├── session.py │ └── utils.py ├── models │ ├── README.md │ ├── rt-detr │ │ ├── rtdetr-l.yaml │ │ └── rtdetr-x.yaml │ ├── v3 │ │ ├── yolov3-spp.yaml │ │ ├── yolov3-tiny.yaml │ │ └── yolov3.yaml │ ├── v5 │ │ ├── yolov5-p6.yaml │ │ └── yolov5.yaml │ ├── v6 │ │ └── yolov6.yaml │ └── v8 │ │ ├── yolov8-cls.yaml │ │ ├── yolov8-p2.yaml │ │ ├── yolov8-p6.yaml │ │ ├── yolov8-pose-p6.yaml │ │ ├── yolov8-pose.yaml │ │ ├── yolov8-rtdetr.yaml │ │ ├── yolov8-seg.yaml │ │ └── yolov8.yaml ├── nn │ ├── __init__.py │ ├── autobackend.py │ ├── autoshape.py │ ├── modules │ │ ├── __init__.py │ │ ├── block.py │ │ ├── conv.py │ │ ├── head.py │ │ ├── transformer.py │ │ └── utils.py │ └── tasks.py ├── tracker │ ├── README.md │ ├── __init__.py │ ├── cfg │ │ ├── botsort.yaml │ │ └── bytetrack.yaml │ ├── track.py │ ├── trackers │ │ ├── __init__.py │ │ ├── basetrack.py │ │ ├── bot_sort.py │ │ └── byte_tracker.py │ └── utils │ │ ├── __init__.py │ │ ├── gmc.py │ │ ├── kalman_filter.py │ │ └── matching.py ├── vit │ ├── __init__.py │ ├── rtdetr │ │ ├── __init__.py │ │ ├── model.py │ │ ├── predict.py │ │ ├── train.py │ │ └── val.py │ ├── sam │ │ ├── __init__.py │ │ ├── amg.py │ │ ├── autosize.py │ │ ├── build.py │ │ ├── model.py │ │ ├── modules │ │ │ ├── __init__.py │ │ │ ├── decoders.py │ │ │ ├── encoders.py │ │ │ ├── mask_generator.py │ │ │ ├── prompt_predictor.py │ │ │ ├── sam.py │ │ │ └── transformer.py │ │ └── predict.py │ └── utils │ │ ├── __init__.py │ │ ├── loss.py │ │ └── ops.py └── yolo │ ├── __init__.py │ ├── cfg │ ├── __init__.py │ └── default.yaml │ ├── data │ ├── __init__.py │ ├── annotator.py │ ├── augment.py │ ├── base.py │ ├── build.py │ ├── converter.py │ ├── dataloaders │ │ ├── __init__.py │ │ ├── stream_loaders.py │ │ ├── v5augmentations.py │ │ └── v5loader.py │ ├── dataset.py │ ├── dataset_wrappers.py │ ├── scripts │ │ ├── download_weights.sh │ │ ├── get_coco.sh │ │ ├── get_coco128.sh │ │ └── get_imagenet.sh │ └── utils.py │ ├── engine │ ├── __init__.py │ ├── exporter.py │ ├── model.py │ ├── predictor.py │ ├── results.py │ ├── trainer.py │ └── validator.py │ ├── nas │ ├── __init__.py │ ├── model.py │ ├── predict.py │ └── val.py │ ├── 
utils │ ├── __init__.py │ ├── autobatch.py │ ├── benchmarks.py │ ├── callbacks │ │ ├── __init__.py │ │ ├── base.py │ │ ├── clearml.py │ │ ├── comet.py │ │ ├── dvc.py │ │ ├── hub.py │ │ ├── mlflow.py │ │ ├── neptune.py │ │ ├── raytune.py │ │ ├── tensorboard.py │ │ └── wb.py │ ├── checks.py │ ├── dist.py │ ├── downloads.py │ ├── errors.py │ ├── files.py │ ├── instance.py │ ├── loss.py │ ├── metrics.py │ ├── ops.py │ ├── patches.py │ ├── plotting.py │ ├── tal.py │ ├── torch_utils.py │ └── tuner.py │ └── v8 │ ├── __init__.py │ ├── classify │ ├── __init__.py │ ├── predict.py │ ├── train.py │ └── val.py │ ├── detect │ ├── __init__.py │ ├── predict.py │ ├── train.py │ └── val.py │ ├── pose │ ├── __init__.py │ ├── predict.py │ ├── train.py │ └── val.py │ └── segment │ ├── __init__.py │ ├── predict.py │ ├── train.py │ └── val.py └── utils ├── __init__.py ├── tools.py └── tools_gradio.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.pyo 3 | *.pyd 4 | .DS_Store 5 | .idea 6 | weights 7 | build/ 8 | *.egg-info/ 9 | gradio_cached_examples -------------------------------------------------------------------------------- /Inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from fastsam import FastSAM, FastSAMPrompt 3 | import ast 4 | import torch 5 | from PIL import Image 6 | from utils.tools import convert_box_xywh_to_xyxy 7 | 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument( 12 | "--model_path", type=str, default="./weights/FastSAM.pt", help="model" 13 | ) 14 | parser.add_argument( 15 | "--img_path", type=str, default="./images/dogs.jpg", help="path to image file" 16 | ) 17 | parser.add_argument("--imgsz", type=int, default=1024, help="image size") 18 | parser.add_argument( 19 | "--iou", 20 | type=float, 21 | default=0.9, 22 | help="iou threshold for filtering the annotations", 23 | ) 24 | parser.add_argument( 25 | "--text_prompt", type=str, default=None, help='use text prompt eg: "a dog"' 26 | ) 27 | parser.add_argument( 28 | "--conf", type=float, default=0.4, help="object confidence threshold" 29 | ) 30 | parser.add_argument( 31 | "--output", type=str, default="./output/", help="image save path" 32 | ) 33 | parser.add_argument( 34 | "--randomcolor", type=bool, default=True, help="mask random color" 35 | ) 36 | parser.add_argument( 37 | "--point_prompt", type=str, default="[[0,0]]", help="[[x1,y1],[x2,y2]]" 38 | ) 39 | parser.add_argument( 40 | "--point_label", 41 | type=str, 42 | default="[0]", 43 | help="[1,0] 0:background, 1:foreground", 44 | ) 45 | parser.add_argument("--box_prompt", type=str, default="[[0,0,0,0]]", help="[[x,y,w,h],[x2,y2,w2,h2]] support multiple boxes") 46 | parser.add_argument( 47 | "--better_quality", 48 | type=str, 49 | default=False, 50 | help="better quality using morphologyEx", 51 | ) 52 | device = torch.device( 53 | "cuda" 54 | if torch.cuda.is_available() 55 | else "mps" 56 | if torch.backends.mps.is_available() 57 | else "cpu" 58 | ) 59 | parser.add_argument( 60 | "--device", type=str, default=device, help="cuda:[0,1,2,3,4] or cpu" 61 | ) 62 | parser.add_argument( 63 | "--retina", 64 | type=bool, 65 | default=True, 66 | help="draw high-resolution segmentation masks", 67 | ) 68 | parser.add_argument( 69 | "--withContours", type=bool, default=False, help="draw the edges of the masks" 70 | ) 71 | return parser.parse_args() 72 | 73 | 74 | def main(args): 75 | # load model 76 | model = 
FastSAM(args.model_path) 77 | args.point_prompt = ast.literal_eval(args.point_prompt) 78 | args.box_prompt = convert_box_xywh_to_xyxy(ast.literal_eval(args.box_prompt)) 79 | args.point_label = ast.literal_eval(args.point_label) 80 | input = Image.open(args.img_path) 81 | input = input.convert("RGB") 82 | everything_results = model( 83 | input, 84 | device=args.device, 85 | retina_masks=args.retina, 86 | imgsz=args.imgsz, 87 | conf=args.conf, 88 | iou=args.iou 89 | ) 90 | bboxes = None 91 | points = None 92 | point_label = None 93 | prompt_process = FastSAMPrompt(input, everything_results, device=args.device) 94 | if args.box_prompt[0][2] != 0 and args.box_prompt[0][3] != 0: 95 | ann = prompt_process.box_prompt(bboxes=args.box_prompt) 96 | bboxes = args.box_prompt 97 | elif args.text_prompt != None: 98 | ann = prompt_process.text_prompt(text=args.text_prompt) 99 | elif args.point_prompt[0] != [0, 0]: 100 | ann = prompt_process.point_prompt( 101 | points=args.point_prompt, pointlabel=args.point_label 102 | ) 103 | points = args.point_prompt 104 | point_label = args.point_label 105 | else: 106 | ann = prompt_process.everything_prompt() 107 | prompt_process.plot( 108 | annotations=ann, 109 | output_path=args.output+args.img_path.split("/")[-1], 110 | bboxes = bboxes, 111 | points = points, 112 | point_label = point_label, 113 | withContours=args.withContours, 114 | better_quality=args.better_quality, 115 | ) 116 | 117 | 118 | 119 | 120 | if __name__ == "__main__": 121 | args = parse_args() 122 | main(args) 123 | -------------------------------------------------------------------------------- /MORE_USAGES.md: -------------------------------------------------------------------------------- 1 | # MORE_USAGES 2 | 3 | 4 | 5 | ### Everything mode 6 | Use `--imgsz` to change the input size. 7 | 8 | ```shell 9 | python Inference.py --model_path ./weights/FastSAM.pt \ 10 | --img_path ./images/dogs.jpg \ 11 | --imgsz 720 12 | ``` 13 | ![everything mode](assets/more_usages/everything_mode.png) 14 | 15 | 16 | 17 | ### Use more points 18 | Use `--point_prompt` to pass multiple points and `--point_label` to mark each point as foreground (1) or background (0). 19 | ```shell 20 | python Inference.py --model_path ./weights/FastSAM.pt \ 21 | --img_path ./images/dogs.jpg \ 22 | --point_prompt "[[520,360],[620,300],[520,300],[620,360]]" \ 23 | --point_label "[1,0,1,0]" 24 | ``` 25 | ![points prompt](assets/more_usages/more_points.png) 26 | ### Draw mask edge 27 | Use `--withContours True` to draw the edge of the mask. 28 | 29 | When `--better_quality True` is set, the edges will be smoother. 
30 | 31 | ```shell 32 | python Inference.py --model_path ./weights/FastSAM.pt \ 33 | --img_path ./images/dogs.jpg \ 34 | --point_prompt "[[620,360]]" \ 35 | --point_label "[1]" \ 36 | --withContours True \ 37 | --better_quality True 38 | ``` 39 | 40 | ![Draw Edge](assets/more_usages/draw_edge.png) 41 | ### Use box prompt 42 | Use `--box_prompt "[[x,y,w,h]]"` to specify the bounding box of the foreground object; multiple boxes are supported. 43 | ```shell 44 | python Inference.py --model_path ./weights/FastSAM.pt \ 45 | --img_path ./images/dogs.jpg \ 46 | --box_prompt "[[570,200,230,400]]" 47 | ``` 48 | ![box prompt](assets/more_usages/box_prompt.png) 49 | 50 | ### Use text prompt 51 | Use `--text_prompt "text"` to specify the text prompt. 52 | ```shell 53 | python Inference.py --model_path ./weights/FastSAM.pt \ 54 | --img_path ./images/cat.jpg \ 55 | --text_prompt "cat" \ 56 | --better_quality True \ 57 | --withContours True 58 | ``` 59 | ![text prompt](assets/more_usages/text_prompt_cat.png) 60 | -------------------------------------------------------------------------------- /assets/Overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/assets/Overview.png -------------------------------------------------------------------------------- /assets/anomaly.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/assets/anomaly.png -------------------------------------------------------------------------------- /assets/building.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/assets/building.png -------------------------------------------------------------------------------- /assets/dog_clip.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/assets/dog_clip.png -------------------------------------------------------------------------------- /assets/eightpic.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/assets/eightpic.pdf -------------------------------------------------------------------------------- /assets/eightpic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/assets/eightpic.png -------------------------------------------------------------------------------- /assets/head_fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/assets/head_fig.png -------------------------------------------------------------------------------- /assets/hf_everything_mode.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/assets/hf_everything_mode.png -------------------------------------------------------------------------------- /assets/hf_points_mode.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/assets/hf_points_mode.png -------------------------------------------------------------------------------- /assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/assets/logo.png -------------------------------------------------------------------------------- /assets/more_usages/box_prompt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/assets/more_usages/box_prompt.png -------------------------------------------------------------------------------- /assets/more_usages/draw_edge.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/assets/more_usages/draw_edge.png -------------------------------------------------------------------------------- /assets/more_usages/everything_mode.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/assets/more_usages/everything_mode.png -------------------------------------------------------------------------------- /assets/more_usages/everything_mode_without_retina.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/assets/more_usages/everything_mode_without_retina.png -------------------------------------------------------------------------------- /assets/more_usages/more_points.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/assets/more_usages/more_points.png -------------------------------------------------------------------------------- /assets/more_usages/text_prompt_cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/assets/more_usages/text_prompt_cat.png -------------------------------------------------------------------------------- /assets/replicate-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/assets/replicate-1.png -------------------------------------------------------------------------------- /assets/replicate-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/assets/replicate-2.png -------------------------------------------------------------------------------- /assets/replicate-3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/assets/replicate-3.png -------------------------------------------------------------------------------- /assets/salient.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/assets/salient.png -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md 3 | # Thanks for chenxwh. 4 | 5 | build: 6 | # set to true if your model requires a GPU 7 | gpu: true 8 | cuda: "11.7" 9 | system_packages: 10 | - "libgl1-mesa-glx" 11 | - "libglib2.0-0" 12 | python_version: "3.8" 13 | python_packages: 14 | - "matplotlib==3.7.1" 15 | - "opencv-python==4.7.0.72" 16 | - "Pillow==9.5.0" 17 | - "PyYAML==6.0" 18 | - "requests==2.31.0" 19 | - "scipy==1.10.1" 20 | - "torch==2.0.1" 21 | - "torchvision==0.15.2" 22 | - "tqdm==4.65.0" 23 | - "pandas==2.0.2" 24 | - "seaborn==0.12.0" 25 | - "ultralytics==8.0.121" 26 | - git+https://github.com/openai/CLIP.git 27 | predict: "predict.py:Predictor" 28 | -------------------------------------------------------------------------------- /examples/dogs.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/examples/dogs.jpg -------------------------------------------------------------------------------- /examples/sa_10039.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/examples/sa_10039.jpg -------------------------------------------------------------------------------- /examples/sa_11025.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/examples/sa_11025.jpg -------------------------------------------------------------------------------- /examples/sa_1309.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/examples/sa_1309.jpg -------------------------------------------------------------------------------- /examples/sa_192.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/examples/sa_192.jpg -------------------------------------------------------------------------------- /examples/sa_414.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/examples/sa_414.jpg -------------------------------------------------------------------------------- /examples/sa_561.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/examples/sa_561.jpg -------------------------------------------------------------------------------- /examples/sa_862.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/examples/sa_862.jpg 
-------------------------------------------------------------------------------- /examples/sa_8776.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/examples/sa_8776.jpg -------------------------------------------------------------------------------- /fastsam/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from .model import FastSAM 4 | from .predict import FastSAMPredictor 5 | from .prompt import FastSAMPrompt 6 | # from .val import FastSAMValidator 7 | from .decoder import FastSAMDecoder 8 | 9 | __all__ = 'FastSAMPredictor', 'FastSAM', 'FastSAMPrompt', 'FastSAMDecoder' 10 | -------------------------------------------------------------------------------- /fastsam/decoder.py: -------------------------------------------------------------------------------- 1 | from .model import FastSAM 2 | import numpy as np 3 | from PIL import Image 4 | from typing import Optional, List, Tuple, Union 5 | 6 | 7 | class FastSAMDecoder: 8 | def __init__( 9 | self, 10 | model: FastSAM, 11 | device: str='cpu', 12 | conf: float=0.4, 13 | iou: float=0.9, 14 | imgsz: int=1024, 15 | retina_masks: bool=True, 16 | ): 17 | self.model = model 18 | self.device = device 19 | self.retina_masks = retina_masks 20 | self.imgsz = imgsz 21 | self.conf = conf 22 | self.iou = iou 23 | self.image = None 24 | self.image_embedding = None 25 | 26 | def run_encoder(self, image): 27 | if isinstance(image,str): 28 | image = np.array(Image.open(image)) 29 | self.image = image 30 | image_embedding = self.model( 31 | self.image, 32 | device=self.device, 33 | retina_masks=self.retina_masks, 34 | imgsz=self.imgsz, 35 | conf=self.conf, 36 | iou=self.iou 37 | ) 38 | return image_embedding[0].numpy() 39 | 40 | def run_decoder( 41 | self, 42 | image_embedding, 43 | point_prompt: Optional[np.ndarray]=None, 44 | point_label: Optional[np.ndarray]=None, 45 | box_prompt: Optional[np.ndarray]=None, 46 | text_prompt: Optional[str]=None, 47 | )->np.ndarray: 48 | self.image_embedding = image_embedding 49 | if point_prompt is not None: 50 | ann = self.point_prompt(points=point_prompt, pointlabel=point_label) 51 | return ann 52 | elif box_prompt is not None: 53 | ann = self.box_prompt(bbox=box_prompt) 54 | return ann 55 | elif text_prompt is not None: 56 | ann = self.text_prompt(text=text_prompt) 57 | return ann 58 | else: 59 | return None 60 | 61 | def box_prompt(self, bbox): 62 | assert (bbox[2] != 0 and bbox[3] != 0) 63 | masks = self.image_embedding.masks.data 64 | target_height = self.image.shape[0] 65 | target_width = self.image.shape[1] 66 | h = masks.shape[1] 67 | w = masks.shape[2] 68 | if h != target_height or w != target_width: 69 | bbox = [ 70 | int(bbox[0] * w / target_width), 71 | int(bbox[1] * h / target_height), 72 | int(bbox[2] * w / target_width), 73 | int(bbox[3] * h / target_height), ] 74 | bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0 75 | bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0 76 | bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w 77 | bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h 78 | 79 | # IoUs = torch.zeros(len(masks), dtype=torch.float32) 80 | bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0]) 81 | 82 | masks_area = np.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], axis=(1, 2)) 83 | orig_masks_area = np.sum(masks, axis=(1, 2)) 84 | 85 | union = bbox_area + 
orig_masks_area - masks_area 86 | IoUs = masks_area / union 87 | max_iou_index = np.argmax(IoUs) 88 | 89 | return np.array([masks[max_iou_index].cpu().numpy()]) 90 | 91 | def point_prompt(self, points, pointlabel): # numpy 92 | 93 | masks = self._format_results(self.image_embedding[0], 0) 94 | target_height = self.image.shape[0] 95 | target_width = self.image.shape[1] 96 | h = masks[0]['segmentation'].shape[0] 97 | w = masks[0]['segmentation'].shape[1] 98 | if h != target_height or w != target_width: 99 | points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points] 100 | onemask = np.zeros((h, w)) 101 | masks = sorted(masks, key=lambda x: x['area'], reverse=True) 102 | for i, annotation in enumerate(masks): 103 | if type(annotation) == dict: 104 | mask = annotation['segmentation'] 105 | else: 106 | mask = annotation 107 | for i, point in enumerate(points): 108 | if mask[point[1], point[0]] == 1 and pointlabel[i] == 1: 109 | onemask[mask] = 1 110 | if mask[point[1], point[0]] == 1 and pointlabel[i] == 0: 111 | onemask[mask] = 0 112 | onemask = onemask >= 1 113 | return np.array([onemask]) 114 | 115 | def _format_results(self, result, filter=0): 116 | annotations = [] 117 | n = len(result.masks.data) 118 | for i in range(n): 119 | annotation = {} 120 | mask = result.masks.data[i] == 1.0 121 | 122 | if np.sum(mask) < filter: 123 | continue 124 | annotation['id'] = i 125 | annotation['segmentation'] = mask 126 | annotation['bbox'] = result.boxes.data[i] 127 | annotation['score'] = result.boxes.conf[i] 128 | annotation['area'] = annotation['segmentation'].sum() 129 | annotations.append(annotation) 130 | return annotations 131 | -------------------------------------------------------------------------------- /fastsam/model.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | """ 3 | FastSAM model interface. 4 | 5 | Usage - Predict: 6 | from ultralytics import FastSAM 7 | 8 | model = FastSAM('last.pt') 9 | results = model.predict('ultralytics/assets/bus.jpg') 10 | """ 11 | 12 | from ultralytics.yolo.cfg import get_cfg 13 | from ultralytics.yolo.engine.exporter import Exporter 14 | from ultralytics.yolo.engine.model import YOLO 15 | from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, ROOT, is_git_dir 16 | from ultralytics.yolo.utils.checks import check_imgsz 17 | 18 | from ultralytics.yolo.utils.torch_utils import model_info, smart_inference_mode 19 | from .predict import FastSAMPredictor 20 | 21 | 22 | class FastSAM(YOLO): 23 | 24 | @smart_inference_mode() 25 | def predict(self, source=None, stream=False, **kwargs): 26 | """ 27 | Perform prediction using the YOLO model. 28 | 29 | Args: 30 | source (str | int | PIL | np.ndarray): The source of the image to make predictions on. 31 | Accepts all source types accepted by the YOLO model. 32 | stream (bool): Whether to stream the predictions or not. Defaults to False. 33 | **kwargs : Additional keyword arguments passed to the predictor. 34 | Check the 'configuration' section in the documentation for all available options. 35 | 36 | Returns: 37 | (List[ultralytics.yolo.engine.results.Results]): The prediction results. 38 | """ 39 | if source is None: 40 | source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg' 41 | LOGGER.warning(f"WARNING ⚠️ 'source' is missing. 
Using 'source={source}'.") 42 | overrides = self.overrides.copy() 43 | overrides['conf'] = 0.25 44 | overrides.update(kwargs) # prefer kwargs 45 | overrides['mode'] = kwargs.get('mode', 'predict') 46 | assert overrides['mode'] in ['track', 'predict'] 47 | overrides['save'] = kwargs.get('save', False) # do not save by default if called in Python 48 | self.predictor = FastSAMPredictor(overrides=overrides) 49 | self.predictor.setup_model(model=self.model, verbose=False) 50 | try: 51 | return self.predictor(source, stream=stream) 52 | except Exception as e: 53 | return None 54 | 55 | def train(self, **kwargs): 56 | """Function trains models but raises an error as FastSAM models do not support training.""" 57 | raise NotImplementedError("Currently, the training codes are on the way.") 58 | 59 | def val(self, **kwargs): 60 | """Run validation given dataset.""" 61 | overrides = dict(task='segment', mode='val') 62 | overrides.update(kwargs) # prefer kwargs 63 | args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides) 64 | args.imgsz = check_imgsz(args.imgsz, max_dim=1) 65 | validator = FastSAM(args=args) 66 | validator(model=self.model) 67 | self.metrics = validator.metrics 68 | return validator.metrics 69 | 70 | @smart_inference_mode() 71 | def export(self, **kwargs): 72 | """ 73 | Export model. 74 | 75 | Args: 76 | **kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs 77 | """ 78 | overrides = dict(task='detect') 79 | overrides.update(kwargs) 80 | overrides['mode'] = 'export' 81 | args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides) 82 | args.task = self.task 83 | if args.imgsz == DEFAULT_CFG.imgsz: 84 | args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed 85 | if args.batch == DEFAULT_CFG.batch: 86 | args.batch = 1 # default to 1 if not modified 87 | return Exporter(overrides=args)(model=self.model) 88 | 89 | def info(self, detailed=False, verbose=True): 90 | """ 91 | Logs model info. 92 | 93 | Args: 94 | detailed (bool): Show detailed information about model. 95 | verbose (bool): Controls verbosity. 96 | """ 97 | return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640) 98 | 99 | def __call__(self, source=None, stream=False, **kwargs): 100 | """Calls the 'predict' function with given arguments to perform object detection.""" 101 | return self.predict(source, stream, **kwargs) 102 | 103 | def __getattr__(self, attr): 104 | """Raises error if object has no requested attribute.""" 105 | name = self.__class__.__name__ 106 | raise AttributeError(f"'{name}' object has no attribute '{attr}'. 
See valid attributes below.\n{self.__doc__}") 107 | -------------------------------------------------------------------------------- /fastsam/predict.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ultralytics.yolo.engine.results import Results 4 | from ultralytics.yolo.utils import DEFAULT_CFG, ops 5 | from ultralytics.yolo.v8.detect.predict import DetectionPredictor 6 | from .utils import bbox_iou 7 | 8 | class FastSAMPredictor(DetectionPredictor): 9 | 10 | def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): 11 | super().__init__(cfg, overrides, _callbacks) 12 | self.args.task = 'segment' 13 | 14 | def postprocess(self, preds, img, orig_imgs): 15 | """TODO: filter by classes.""" 16 | p = ops.non_max_suppression(preds[0], 17 | self.args.conf, 18 | self.args.iou, 19 | agnostic=self.args.agnostic_nms, 20 | max_det=self.args.max_det, 21 | nc=len(self.model.names), 22 | classes=self.args.classes) 23 | 24 | results = [] 25 | if len(p) == 0 or len(p[0]) == 0: 26 | print("No object detected.") 27 | return results 28 | 29 | full_box = torch.zeros_like(p[0][0]) 30 | full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0 31 | full_box = full_box.view(1, -1) 32 | critical_iou_index = bbox_iou(full_box[0][:4], p[0][:, :4], iou_thres=0.9, image_shape=img.shape[2:]) 33 | if critical_iou_index.numel() != 0: 34 | full_box[0][4] = p[0][critical_iou_index][:,4] 35 | full_box[0][6:] = p[0][critical_iou_index][:,6:] 36 | p[0][critical_iou_index] = full_box 37 | 38 | proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported 39 | for i, pred in enumerate(p): 40 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs 41 | path = self.batch[0] 42 | img_path = path[i] if isinstance(path, list) else path 43 | if not len(pred): # save empty boxes 44 | results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6])) 45 | continue 46 | if self.args.retina_masks: 47 | if not isinstance(orig_imgs, torch.Tensor): 48 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) 49 | masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC 50 | else: 51 | masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC 52 | if not isinstance(orig_imgs, torch.Tensor): 53 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) 54 | results.append( 55 | Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks)) 56 | return results 57 | -------------------------------------------------------------------------------- /fastsam/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from PIL import Image 4 | 5 | 6 | def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20): 7 | '''Adjust bounding boxes to stick to image border if they are within a certain threshold. 
8 | Args: 9 | boxes: (n, 4) 10 | image_shape: (height, width) 11 | threshold: pixel threshold 12 | Returns: 13 | adjusted_boxes: adjusted bounding boxes 14 | ''' 15 | 16 | # Image dimensions 17 | h, w = image_shape 18 | 19 | # Adjust boxes 20 | boxes[:, 0] = torch.where(boxes[:, 0] < threshold, torch.tensor( 21 | 0, dtype=torch.float, device=boxes.device), boxes[:, 0]) # x1 22 | boxes[:, 1] = torch.where(boxes[:, 1] < threshold, torch.tensor( 23 | 0, dtype=torch.float, device=boxes.device), boxes[:, 1]) # y1 24 | boxes[:, 2] = torch.where(boxes[:, 2] > w - threshold, torch.tensor( 25 | w, dtype=torch.float, device=boxes.device), boxes[:, 2]) # x2 26 | boxes[:, 3] = torch.where(boxes[:, 3] > h - threshold, torch.tensor( 27 | h, dtype=torch.float, device=boxes.device), boxes[:, 3]) # y2 28 | 29 | return boxes 30 | 31 | 32 | 33 | def convert_box_xywh_to_xyxy(box): 34 | x1 = box[0] 35 | y1 = box[1] 36 | x2 = box[0] + box[2] 37 | y2 = box[1] + box[3] 38 | return [x1, y1, x2, y2] 39 | 40 | 41 | def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=False): 42 | '''Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes. 43 | Args: 44 | box1: (4, ) 45 | boxes: (n, 4) 46 | Returns: 47 | high_iou_indices: Indices of boxes with IoU > thres 48 | ''' 49 | boxes = adjust_bboxes_to_image_border(boxes, image_shape) 50 | # obtain coordinates for intersections 51 | x1 = torch.max(box1[0], boxes[:, 0]) 52 | y1 = torch.max(box1[1], boxes[:, 1]) 53 | x2 = torch.min(box1[2], boxes[:, 2]) 54 | y2 = torch.min(box1[3], boxes[:, 3]) 55 | 56 | # compute the area of intersection 57 | intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0) 58 | 59 | # compute the area of both individual boxes 60 | box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1]) 61 | box2_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) 62 | 63 | # compute the area of union 64 | union = box1_area + box2_area - intersection 65 | 66 | # compute the IoU 67 | iou = intersection / union # Should be shape (n, ) 68 | if raw_output: 69 | if iou.numel() == 0: 70 | return 0 71 | return iou 72 | 73 | # get indices of boxes with IoU > thres 74 | high_iou_indices = torch.nonzero(iou > iou_thres).flatten() 75 | 76 | return high_iou_indices 77 | 78 | 79 | def image_to_np_ndarray(image): 80 | if type(image) is str: 81 | return np.array(Image.open(image)) 82 | elif issubclass(type(image), Image.Image): 83 | return np.array(image) 84 | elif type(image) is np.ndarray: 85 | return image 86 | return None 87 | -------------------------------------------------------------------------------- /images/cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/images/cat.jpg -------------------------------------------------------------------------------- /images/dogs.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/images/dogs.jpg -------------------------------------------------------------------------------- /output/cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/output/cat.jpg -------------------------------------------------------------------------------- /output/dogs.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/output/dogs.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Base----------------------------------- 2 | matplotlib>=3.2.2 3 | opencv-python>=4.6.0 4 | Pillow>=7.1.2 5 | PyYAML>=5.3.1 6 | requests>=2.23.0 7 | scipy>=1.4.1 8 | torch>=1.7.0 9 | torchvision>=0.8.1 10 | tqdm>=4.64.0 11 | 12 | pandas>=1.1.4 13 | seaborn>=0.11.0 14 | 15 | gradio==3.35.2 16 | 17 | # Ultralytics----------------------------------- 18 | # ultralytics == 8.0.120 19 | 20 | -------------------------------------------------------------------------------- /segpredict.py: -------------------------------------------------------------------------------- 1 | from fastsam import FastSAM, FastSAMPrompt 2 | import torch 3 | 4 | model = FastSAM('FastSAM.pt') 5 | IMAGE_PATH = './images/dogs.jpg' 6 | DEVICE = torch.device( 7 | "cuda" 8 | if torch.cuda.is_available() 9 | else "mps" 10 | if torch.backends.mps.is_available() 11 | else "cpu" 12 | ) 13 | everything_results = model( 14 | IMAGE_PATH, 15 | device=DEVICE, 16 | retina_masks=True, 17 | imgsz=1024, 18 | conf=0.4, 19 | iou=0.9, 20 | ) 21 | prompt_process = FastSAMPrompt(IMAGE_PATH, everything_results, device=DEVICE) 22 | 23 | # # everything prompt 24 | ann = prompt_process.everything_prompt() 25 | 26 | # # bbox prompt 27 | # # bbox default shape [0,0,0,0] -> [x1,y1,x2,y2] 28 | # bboxes default shape [[0,0,0,0]] -> [[x1,y1,x2,y2]] 29 | # ann = prompt_process.box_prompt(bbox=[200, 200, 300, 300]) 30 | # ann = prompt_process.box_prompt(bboxes=[[200, 200, 300, 300], [500, 500, 600, 600]]) 31 | 32 | # # text prompt 33 | # ann = prompt_process.text_prompt(text='a photo of a dog') 34 | 35 | # # point prompt 36 | # # points default [[0,0]] [[x1,y1],[x2,y2]] 37 | # # point_label default [0] [1,0] 0:background, 1:foreground 38 | # ann = prompt_process.point_prompt(points=[[620, 360]], pointlabel=[1]) 39 | 40 | # point prompt 41 | # points default [[0,0]] [[x1,y1],[x2,y2]] 42 | # point_label default [0] [1,0] 0:background, 1:foreground 43 | ann = prompt_process.point_prompt(points=[[620, 360]], pointlabel=[1]) 44 | 45 | prompt_process.plot( 46 | annotations=ann, 47 | output='./output/', 48 | mask_random_color=True, 49 | better_quality=True, 50 | retina=False, 51 | withContours=True, 52 | ) 53 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # This source code is licensed under the license found in the 2 | # LICENSE file in the root directory of this source tree. 
3 | 4 | from setuptools import find_packages, setup 5 | 6 | REQUIREMENTS = [i.strip() for i in open("requirements.txt").readlines()] 7 | REQUIREMENTS += [ 8 | "CLIP @ git+https://github.com/openai/CLIP.git@a1d071733d7111c9c014f024669f959182114e33#egg=CLIP" 9 | ] 10 | 11 | setup( 12 | name="fastsam", 13 | version="0.1.1", 14 | install_requires=REQUIREMENTS, 15 | packages=["fastsam", "fastsam_tools"], 16 | package_dir= { 17 | "fastsam": "fastsam", 18 | "fastsam_tools": "utils", 19 | }, 20 | url="https://github.com/CASIA-IVA-Lab/FastSAM" 21 | ) 22 | -------------------------------------------------------------------------------- /ultralytics/.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # Pre-commit hooks. For more information see https://github.com/pre-commit/pre-commit-hooks/blob/main/README.md 3 | 4 | exclude: 'docs/' 5 | # Define bot property if installed via https://github.com/marketplace/pre-commit-ci 6 | ci: 7 | autofix_prs: true 8 | autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions' 9 | autoupdate_schedule: monthly 10 | # submodules: true 11 | 12 | repos: 13 | - repo: https://github.com/pre-commit/pre-commit-hooks 14 | rev: v4.4.0 15 | hooks: 16 | - id: end-of-file-fixer 17 | - id: trailing-whitespace 18 | - id: check-case-conflict 19 | # - id: check-yaml 20 | - id: check-docstring-first 21 | - id: double-quote-string-fixer 22 | - id: detect-private-key 23 | 24 | - repo: https://github.com/asottile/pyupgrade 25 | rev: v3.4.0 26 | hooks: 27 | - id: pyupgrade 28 | name: Upgrade code 29 | 30 | - repo: https://github.com/PyCQA/isort 31 | rev: 5.12.0 32 | hooks: 33 | - id: isort 34 | name: Sort imports 35 | 36 | - repo: https://github.com/google/yapf 37 | rev: v0.33.0 38 | hooks: 39 | - id: yapf 40 | name: YAPF formatting 41 | 42 | - repo: https://github.com/executablebooks/mdformat 43 | rev: 0.7.16 44 | hooks: 45 | - id: mdformat 46 | name: MD formatting 47 | additional_dependencies: 48 | - mdformat-gfm 49 | - mdformat-black 50 | # exclude: "README.md|README.zh-CN.md|CONTRIBUTING.md" 51 | 52 | - repo: https://github.com/PyCQA/flake8 53 | rev: 6.0.0 54 | hooks: 55 | - id: flake8 56 | name: PEP8 57 | 58 | - repo: https://github.com/codespell-project/codespell 59 | rev: v2.2.4 60 | hooks: 61 | - id: codespell 62 | args: 63 | - --ignore-words-list=crate,nd,strack,dota 64 | 65 | # - repo: https://github.com/asottile/yesqa 66 | # rev: v1.4.0 67 | # hooks: 68 | # - id: yesqa 69 | 70 | # - repo: https://github.com/asottile/dead 71 | # rev: v1.5.0 72 | # hooks: 73 | # - id: dead 74 | -------------------------------------------------------------------------------- /ultralytics/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | __version__ = '8.0.120' 4 | 5 | from ultralytics.hub import start 6 | from ultralytics.vit.rtdetr import RTDETR 7 | from ultralytics.vit.sam import SAM 8 | from ultralytics.yolo.engine.model import YOLO 9 | from ultralytics.yolo.nas import NAS 10 | from ultralytics.yolo.utils.checks import check_yolo as checks 11 | 12 | __all__ = '__version__', 'YOLO', 'NAS', 'SAM', 'RTDETR', 'checks', 'start' # allow simpler import 13 | -------------------------------------------------------------------------------- /ultralytics/assets/bus.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/ultralytics/assets/bus.jpg -------------------------------------------------------------------------------- /ultralytics/assets/zidane.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/ultralytics/assets/zidane.jpg -------------------------------------------------------------------------------- /ultralytics/datasets/Argoverse.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # Argoverse-HD dataset (ring-front-center camera) http://www.cs.cmu.edu/~mengtial/proj/streaming/ by Argo AI 3 | # Example usage: yolo train data=Argoverse.yaml 4 | # parent 5 | # ├── ultralytics 6 | # └── datasets 7 | # └── Argoverse ← downloads here (31.3 GB) 8 | 9 | 10 | # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 11 | path: ../datasets/Argoverse # dataset root dir 12 | train: Argoverse-1.1/images/train/ # train images (relative to 'path') 39384 images 13 | val: Argoverse-1.1/images/val/ # val images (relative to 'path') 15062 images 14 | test: Argoverse-1.1/images/test/ # test images (optional) https://eval.ai/web/challenges/challenge-page/800/overview 15 | 16 | # Classes 17 | names: 18 | 0: person 19 | 1: bicycle 20 | 2: car 21 | 3: motorcycle 22 | 4: bus 23 | 5: truck 24 | 6: traffic_light 25 | 7: stop_sign 26 | 27 | 28 | # Download script/URL (optional) --------------------------------------------------------------------------------------- 29 | download: | 30 | import json 31 | from tqdm import tqdm 32 | from ultralytics.yolo.utils.downloads import download 33 | from pathlib import Path 34 | 35 | def argoverse2yolo(set): 36 | labels = {} 37 | a = json.load(open(set, "rb")) 38 | for annot in tqdm(a['annotations'], desc=f"Converting {set} to YOLOv5 format..."): 39 | img_id = annot['image_id'] 40 | img_name = a['images'][img_id]['name'] 41 | img_label_name = f'{img_name[:-3]}txt' 42 | 43 | cls = annot['category_id'] # instance class id 44 | x_center, y_center, width, height = annot['bbox'] 45 | x_center = (x_center + width / 2) / 1920.0 # offset and scale 46 | y_center = (y_center + height / 2) / 1200.0 # offset and scale 47 | width /= 1920.0 # scale 48 | height /= 1200.0 # scale 49 | 50 | img_dir = set.parents[2] / 'Argoverse-1.1' / 'labels' / a['seq_dirs'][a['images'][annot['image_id']]['sid']] 51 | if not img_dir.exists(): 52 | img_dir.mkdir(parents=True, exist_ok=True) 53 | 54 | k = str(img_dir / img_label_name) 55 | if k not in labels: 56 | labels[k] = [] 57 | labels[k].append(f"{cls} {x_center} {y_center} {width} {height}\n") 58 | 59 | for k in labels: 60 | with open(k, "w") as f: 61 | f.writelines(labels[k]) 62 | 63 | 64 | # Download 65 | dir = Path(yaml['path']) # dataset root dir 66 | urls = ['https://argoverse-hd.s3.us-east-2.amazonaws.com/Argoverse-HD-Full.zip'] 67 | download(urls, dir=dir) 68 | 69 | # Convert 70 | annotations_dir = 'Argoverse-HD/annotations/' 71 | (dir / 'Argoverse-1.1' / 'tracking').rename(dir / 'Argoverse-1.1' / 'images') # rename 'tracking' to 'images' 72 | for d in "train.json", "val.json": 73 | argoverse2yolo(dir / annotations_dir / d) # convert VisDrone annotations to YOLO labels 74 | -------------------------------------------------------------------------------- 
/ultralytics/datasets/GlobalWheat2020.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # Global Wheat 2020 dataset http://www.global-wheat.com/ by University of Saskatchewan 3 | # Example usage: yolo train data=GlobalWheat2020.yaml 4 | # parent 5 | # ├── ultralytics 6 | # └── datasets 7 | # └── GlobalWheat2020 ← downloads here (7.0 GB) 8 | 9 | 10 | # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 11 | path: ../datasets/GlobalWheat2020 # dataset root dir 12 | train: # train images (relative to 'path') 3422 images 13 | - images/arvalis_1 14 | - images/arvalis_2 15 | - images/arvalis_3 16 | - images/ethz_1 17 | - images/rres_1 18 | - images/inrae_1 19 | - images/usask_1 20 | val: # val images (relative to 'path') 748 images (WARNING: train set contains ethz_1) 21 | - images/ethz_1 22 | test: # test images (optional) 1276 images 23 | - images/utokyo_1 24 | - images/utokyo_2 25 | - images/nau_1 26 | - images/uq_1 27 | 28 | # Classes 29 | names: 30 | 0: wheat_head 31 | 32 | 33 | # Download script/URL (optional) --------------------------------------------------------------------------------------- 34 | download: | 35 | from ultralytics.yolo.utils.downloads import download 36 | from pathlib import Path 37 | 38 | # Download 39 | dir = Path(yaml['path']) # dataset root dir 40 | urls = ['https://zenodo.org/record/4298502/files/global-wheat-codalab-official.zip', 41 | 'https://github.com/ultralytics/yolov5/releases/download/v1.0/GlobalWheat2020_labels.zip'] 42 | download(urls, dir=dir) 43 | 44 | # Make Directories 45 | for p in 'annotations', 'images', 'labels': 46 | (dir / p).mkdir(parents=True, exist_ok=True) 47 | 48 | # Move 49 | for p in 'arvalis_1', 'arvalis_2', 'arvalis_3', 'ethz_1', 'rres_1', 'inrae_1', 'usask_1', \ 50 | 'utokyo_1', 'utokyo_2', 'nau_1', 'uq_1': 51 | (dir / 'global-wheat-codalab-official' / p).rename(dir / 'images' / p) # move to /images 52 | f = (dir / 'global-wheat-codalab-official' / p).with_suffix('.json') # json file 53 | if f.exists(): 54 | f.rename((dir / 'annotations' / p).with_suffix('.json')) # move to /annotations 55 | -------------------------------------------------------------------------------- /ultralytics/datasets/SKU-110K.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # SKU-110K retail items dataset https://github.com/eg4000/SKU110K_CVPR19 by Trax Retail 3 | # Example usage: yolo train data=SKU-110K.yaml 4 | # parent 5 | # ├── ultralytics 6 | # └── datasets 7 | # └── SKU-110K ← downloads here (13.6 GB) 8 | 9 | 10 | # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
11 | path: ../datasets/SKU-110K # dataset root dir 12 | train: train.txt # train images (relative to 'path') 8219 images 13 | val: val.txt # val images (relative to 'path') 588 images 14 | test: test.txt # test images (optional) 2936 images 15 | 16 | # Classes 17 | names: 18 | 0: object 19 | 20 | 21 | # Download script/URL (optional) --------------------------------------------------------------------------------------- 22 | download: | 23 | import shutil 24 | from pathlib import Path 25 | 26 | import numpy as np 27 | import pandas as pd 28 | from tqdm import tqdm 29 | 30 | from ultralytics.yolo.utils.downloads import download 31 | from ultralytics.yolo.utils.ops import xyxy2xywh 32 | 33 | # Download 34 | dir = Path(yaml['path']) # dataset root dir 35 | parent = Path(dir.parent) # download dir 36 | urls = ['http://trax-geometry.s3.amazonaws.com/cvpr_challenge/SKU110K_fixed.tar.gz'] 37 | download(urls, dir=parent) 38 | 39 | # Rename directories 40 | if dir.exists(): 41 | shutil.rmtree(dir) 42 | (parent / 'SKU110K_fixed').rename(dir) # rename dir 43 | (dir / 'labels').mkdir(parents=True, exist_ok=True) # create labels dir 44 | 45 | # Convert labels 46 | names = 'image', 'x1', 'y1', 'x2', 'y2', 'class', 'image_width', 'image_height' # column names 47 | for d in 'annotations_train.csv', 'annotations_val.csv', 'annotations_test.csv': 48 | x = pd.read_csv(dir / 'annotations' / d, names=names).values # annotations 49 | images, unique_images = x[:, 0], np.unique(x[:, 0]) 50 | with open((dir / d).with_suffix('.txt').__str__().replace('annotations_', ''), 'w') as f: 51 | f.writelines(f'./images/{s}\n' for s in unique_images) 52 | for im in tqdm(unique_images, desc=f'Converting {dir / d}'): 53 | cls = 0 # single-class dataset 54 | with open((dir / 'labels' / im).with_suffix('.txt'), 'a') as f: 55 | for r in x[images == im]: 56 | w, h = r[6], r[7] # image width, height 57 | xywh = xyxy2xywh(np.array([[r[1] / w, r[2] / h, r[3] / w, r[4] / h]]))[0] # instance 58 | f.write(f"{cls} {xywh[0]:.5f} {xywh[1]:.5f} {xywh[2]:.5f} {xywh[3]:.5f}\n") # write label 59 | -------------------------------------------------------------------------------- /ultralytics/datasets/VOC.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # PASCAL VOC dataset http://host.robots.ox.ac.uk/pascal/VOC by University of Oxford 3 | # Example usage: yolo train data=VOC.yaml 4 | # parent 5 | # ├── ultralytics 6 | # └── datasets 7 | # └── VOC ← downloads here (2.8 GB) 8 | 9 | 10 | # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
11 | path: ../datasets/VOC 12 | train: # train images (relative to 'path') 16551 images 13 | - images/train2012 14 | - images/train2007 15 | - images/val2012 16 | - images/val2007 17 | val: # val images (relative to 'path') 4952 images 18 | - images/test2007 19 | test: # test images (optional) 20 | - images/test2007 21 | 22 | # Classes 23 | names: 24 | 0: aeroplane 25 | 1: bicycle 26 | 2: bird 27 | 3: boat 28 | 4: bottle 29 | 5: bus 30 | 6: car 31 | 7: cat 32 | 8: chair 33 | 9: cow 34 | 10: diningtable 35 | 11: dog 36 | 12: horse 37 | 13: motorbike 38 | 14: person 39 | 15: pottedplant 40 | 16: sheep 41 | 17: sofa 42 | 18: train 43 | 19: tvmonitor 44 | 45 | 46 | # Download script/URL (optional) --------------------------------------------------------------------------------------- 47 | download: | 48 | import xml.etree.ElementTree as ET 49 | 50 | from tqdm import tqdm 51 | from ultralytics.yolo.utils.downloads import download 52 | from pathlib import Path 53 | 54 | def convert_label(path, lb_path, year, image_id): 55 | def convert_box(size, box): 56 | dw, dh = 1. / size[0], 1. / size[1] 57 | x, y, w, h = (box[0] + box[1]) / 2.0 - 1, (box[2] + box[3]) / 2.0 - 1, box[1] - box[0], box[3] - box[2] 58 | return x * dw, y * dh, w * dw, h * dh 59 | 60 | in_file = open(path / f'VOC{year}/Annotations/{image_id}.xml') 61 | out_file = open(lb_path, 'w') 62 | tree = ET.parse(in_file) 63 | root = tree.getroot() 64 | size = root.find('size') 65 | w = int(size.find('width').text) 66 | h = int(size.find('height').text) 67 | 68 | names = list(yaml['names'].values()) # names list 69 | for obj in root.iter('object'): 70 | cls = obj.find('name').text 71 | if cls in names and int(obj.find('difficult').text) != 1: 72 | xmlbox = obj.find('bndbox') 73 | bb = convert_box((w, h), [float(xmlbox.find(x).text) for x in ('xmin', 'xmax', 'ymin', 'ymax')]) 74 | cls_id = names.index(cls) # class id 75 | out_file.write(" ".join([str(a) for a in (cls_id, *bb)]) + '\n') 76 | 77 | 78 | # Download 79 | dir = Path(yaml['path']) # dataset root dir 80 | url = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/' 81 | urls = [f'{url}VOCtrainval_06-Nov-2007.zip', # 446MB, 5012 images 82 | f'{url}VOCtest_06-Nov-2007.zip', # 438MB, 4953 images 83 | f'{url}VOCtrainval_11-May-2012.zip'] # 1.95GB, 17126 images 84 | download(urls, dir=dir / 'images', curl=True, threads=3) 85 | 86 | # Convert 87 | path = dir / 'images/VOCdevkit' 88 | for year, image_set in ('2012', 'train'), ('2012', 'val'), ('2007', 'train'), ('2007', 'val'), ('2007', 'test'): 89 | imgs_path = dir / 'images' / f'{image_set}{year}' 90 | lbs_path = dir / 'labels' / f'{image_set}{year}' 91 | imgs_path.mkdir(exist_ok=True, parents=True) 92 | lbs_path.mkdir(exist_ok=True, parents=True) 93 | 94 | with open(path / f'VOC{year}/ImageSets/Main/{image_set}.txt') as f: 95 | image_ids = f.read().strip().split() 96 | for id in tqdm(image_ids, desc=f'{image_set}{year}'): 97 | f = path / f'VOC{year}/JPEGImages/{id}.jpg' # old img path 98 | lb_path = (lbs_path / f.name).with_suffix('.txt') # new label path 99 | f.rename(imgs_path / f.name) # move image 100 | convert_label(path, lb_path, year, id) # convert labels to YOLO format 101 | -------------------------------------------------------------------------------- /ultralytics/datasets/VisDrone.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # VisDrone2019-DET dataset https://github.com/VisDrone/VisDrone-Dataset by Tianjin University 3 | # 
Example usage: yolo train data=VisDrone.yaml 4 | # parent 5 | # ├── ultralytics 6 | # └── datasets 7 | # └── VisDrone ← downloads here (2.3 GB) 8 | 9 | 10 | # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 11 | path: ../datasets/VisDrone # dataset root dir 12 | train: VisDrone2019-DET-train/images # train images (relative to 'path') 6471 images 13 | val: VisDrone2019-DET-val/images # val images (relative to 'path') 548 images 14 | test: VisDrone2019-DET-test-dev/images # test images (optional) 1610 images 15 | 16 | # Classes 17 | names: 18 | 0: pedestrian 19 | 1: people 20 | 2: bicycle 21 | 3: car 22 | 4: van 23 | 5: truck 24 | 6: tricycle 25 | 7: awning-tricycle 26 | 8: bus 27 | 9: motor 28 | 29 | 30 | # Download script/URL (optional) --------------------------------------------------------------------------------------- 31 | download: | 32 | import os 33 | from pathlib import Path 34 | 35 | from ultralytics.yolo.utils.downloads import download 36 | 37 | def visdrone2yolo(dir): 38 | from PIL import Image 39 | from tqdm import tqdm 40 | 41 | def convert_box(size, box): 42 | # Convert VisDrone box to YOLO xywh box 43 | dw = 1. / size[0] 44 | dh = 1. / size[1] 45 | return (box[0] + box[2] / 2) * dw, (box[1] + box[3] / 2) * dh, box[2] * dw, box[3] * dh 46 | 47 | (dir / 'labels').mkdir(parents=True, exist_ok=True) # make labels directory 48 | pbar = tqdm((dir / 'annotations').glob('*.txt'), desc=f'Converting {dir}') 49 | for f in pbar: 50 | img_size = Image.open((dir / 'images' / f.name).with_suffix('.jpg')).size 51 | lines = [] 52 | with open(f, 'r') as file: # read annotation.txt 53 | for row in [x.split(',') for x in file.read().strip().splitlines()]: 54 | if row[4] == '0': # VisDrone 'ignored regions' class 0 55 | continue 56 | cls = int(row[5]) - 1 57 | box = convert_box(img_size, tuple(map(int, row[:4]))) 58 | lines.append(f"{cls} {' '.join(f'{x:.6f}' for x in box)}\n") 59 | with open(str(f).replace(f'{os.sep}annotations{os.sep}', f'{os.sep}labels{os.sep}'), 'w') as fl: 60 | fl.writelines(lines) # write label.txt 61 | 62 | 63 | # Download 64 | dir = Path(yaml['path']) # dataset root dir 65 | urls = ['https://github.com/ultralytics/yolov5/releases/download/v1.0/VisDrone2019-DET-train.zip', 66 | 'https://github.com/ultralytics/yolov5/releases/download/v1.0/VisDrone2019-DET-val.zip', 67 | 'https://github.com/ultralytics/yolov5/releases/download/v1.0/VisDrone2019-DET-test-dev.zip', 68 | 'https://github.com/ultralytics/yolov5/releases/download/v1.0/VisDrone2019-DET-test-challenge.zip'] 69 | download(urls, dir=dir, curl=True, threads=4) 70 | 71 | # Convert 72 | for d in 'VisDrone2019-DET-train', 'VisDrone2019-DET-val', 'VisDrone2019-DET-test-dev': 73 | visdrone2yolo(dir / d) # convert VisDrone annotations to YOLO labels 74 | -------------------------------------------------------------------------------- /ultralytics/datasets/coco-pose.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # COCO 2017 dataset http://cocodataset.org by Microsoft 3 | # Example usage: yolo train data=coco-pose.yaml 4 | # parent 5 | # ├── ultralytics 6 | # └── datasets 7 | # └── coco-pose ← downloads here (20.1 GB) 8 | 9 | 10 | # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
11 | path: ../datasets/coco-pose # dataset root dir 12 | train: train2017.txt # train images (relative to 'path') 118287 images 13 | val: val2017.txt # val images (relative to 'path') 5000 images 14 | test: test-dev2017.txt # 20288 of 40670 images, submit to https://competitions.codalab.org/competitions/20794 15 | 16 | # Keypoints 17 | kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible) 18 | flip_idx: [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15] 19 | 20 | # Classes 21 | names: 22 | 0: person 23 | 24 | # Download script/URL (optional) 25 | download: | 26 | from ultralytics.yolo.utils.downloads import download 27 | from pathlib import Path 28 | 29 | # Download labels 30 | dir = Path(yaml['path']) # dataset root dir 31 | url = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/' 32 | urls = [url + 'coco2017labels-pose.zip'] # labels 33 | download(urls, dir=dir.parent) 34 | # Download data 35 | urls = ['http://images.cocodataset.org/zips/train2017.zip', # 19G, 118k images 36 | 'http://images.cocodataset.org/zips/val2017.zip', # 1G, 5k images 37 | 'http://images.cocodataset.org/zips/test2017.zip'] # 7G, 41k images (optional) 38 | download(urls, dir=dir / 'images', threads=3) 39 | -------------------------------------------------------------------------------- /ultralytics/datasets/coco.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # COCO 2017 dataset http://cocodataset.org by Microsoft 3 | # Example usage: yolo train data=coco.yaml 4 | # parent 5 | # ├── ultralytics 6 | # └── datasets 7 | # └── coco ← downloads here (20.1 GB) 8 | 9 | 10 | # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
11 | path: ../datasets/coco # dataset root dir 12 | train: train2017.txt # train images (relative to 'path') 118287 images 13 | val: val2017.txt # val images (relative to 'path') 5000 images 14 | test: test-dev2017.txt # 20288 of 40670 images, submit to https://competitions.codalab.org/competitions/20794 15 | 16 | # Classes 17 | names: 18 | 0: person 19 | 1: bicycle 20 | 2: car 21 | 3: motorcycle 22 | 4: airplane 23 | 5: bus 24 | 6: train 25 | 7: truck 26 | 8: boat 27 | 9: traffic light 28 | 10: fire hydrant 29 | 11: stop sign 30 | 12: parking meter 31 | 13: bench 32 | 14: bird 33 | 15: cat 34 | 16: dog 35 | 17: horse 36 | 18: sheep 37 | 19: cow 38 | 20: elephant 39 | 21: bear 40 | 22: zebra 41 | 23: giraffe 42 | 24: backpack 43 | 25: umbrella 44 | 26: handbag 45 | 27: tie 46 | 28: suitcase 47 | 29: frisbee 48 | 30: skis 49 | 31: snowboard 50 | 32: sports ball 51 | 33: kite 52 | 34: baseball bat 53 | 35: baseball glove 54 | 36: skateboard 55 | 37: surfboard 56 | 38: tennis racket 57 | 39: bottle 58 | 40: wine glass 59 | 41: cup 60 | 42: fork 61 | 43: knife 62 | 44: spoon 63 | 45: bowl 64 | 46: banana 65 | 47: apple 66 | 48: sandwich 67 | 49: orange 68 | 50: broccoli 69 | 51: carrot 70 | 52: hot dog 71 | 53: pizza 72 | 54: donut 73 | 55: cake 74 | 56: chair 75 | 57: couch 76 | 58: potted plant 77 | 59: bed 78 | 60: dining table 79 | 61: toilet 80 | 62: tv 81 | 63: laptop 82 | 64: mouse 83 | 65: remote 84 | 66: keyboard 85 | 67: cell phone 86 | 68: microwave 87 | 69: oven 88 | 70: toaster 89 | 71: sink 90 | 72: refrigerator 91 | 73: book 92 | 74: clock 93 | 75: vase 94 | 76: scissors 95 | 77: teddy bear 96 | 78: hair drier 97 | 79: toothbrush 98 | 99 | 100 | # Download script/URL (optional) 101 | download: | 102 | from ultralytics.yolo.utils.downloads import download 103 | from pathlib import Path 104 | 105 | # Download labels 106 | segments = True # segment or box labels 107 | dir = Path(yaml['path']) # dataset root dir 108 | url = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/' 109 | urls = [url + ('coco2017labels-segments.zip' if segments else 'coco2017labels.zip')] # labels 110 | download(urls, dir=dir.parent) 111 | # Download data 112 | urls = ['http://images.cocodataset.org/zips/train2017.zip', # 19G, 118k images 113 | 'http://images.cocodataset.org/zips/val2017.zip', # 1G, 5k images 114 | 'http://images.cocodataset.org/zips/test2017.zip'] # 7G, 41k images (optional) 115 | download(urls, dir=dir / 'images', threads=3) 116 | -------------------------------------------------------------------------------- /ultralytics/datasets/coco128-seg.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # COCO128-seg dataset https://www.kaggle.com/ultralytics/coco128 (first 128 images from COCO train2017) by Ultralytics 3 | # Example usage: yolo train data=coco128.yaml 4 | # parent 5 | # ├── ultralytics 6 | # └── datasets 7 | # └── coco128-seg ← downloads here (7 MB) 8 | 9 | 10 | # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
11 | path: ../datasets/coco128-seg # dataset root dir 12 | train: images/train2017 # train images (relative to 'path') 128 images 13 | val: images/train2017 # val images (relative to 'path') 128 images 14 | test: # test images (optional) 15 | 16 | # Classes 17 | names: 18 | 0: person 19 | 1: bicycle 20 | 2: car 21 | 3: motorcycle 22 | 4: airplane 23 | 5: bus 24 | 6: train 25 | 7: truck 26 | 8: boat 27 | 9: traffic light 28 | 10: fire hydrant 29 | 11: stop sign 30 | 12: parking meter 31 | 13: bench 32 | 14: bird 33 | 15: cat 34 | 16: dog 35 | 17: horse 36 | 18: sheep 37 | 19: cow 38 | 20: elephant 39 | 21: bear 40 | 22: zebra 41 | 23: giraffe 42 | 24: backpack 43 | 25: umbrella 44 | 26: handbag 45 | 27: tie 46 | 28: suitcase 47 | 29: frisbee 48 | 30: skis 49 | 31: snowboard 50 | 32: sports ball 51 | 33: kite 52 | 34: baseball bat 53 | 35: baseball glove 54 | 36: skateboard 55 | 37: surfboard 56 | 38: tennis racket 57 | 39: bottle 58 | 40: wine glass 59 | 41: cup 60 | 42: fork 61 | 43: knife 62 | 44: spoon 63 | 45: bowl 64 | 46: banana 65 | 47: apple 66 | 48: sandwich 67 | 49: orange 68 | 50: broccoli 69 | 51: carrot 70 | 52: hot dog 71 | 53: pizza 72 | 54: donut 73 | 55: cake 74 | 56: chair 75 | 57: couch 76 | 58: potted plant 77 | 59: bed 78 | 60: dining table 79 | 61: toilet 80 | 62: tv 81 | 63: laptop 82 | 64: mouse 83 | 65: remote 84 | 66: keyboard 85 | 67: cell phone 86 | 68: microwave 87 | 69: oven 88 | 70: toaster 89 | 71: sink 90 | 72: refrigerator 91 | 73: book 92 | 74: clock 93 | 75: vase 94 | 76: scissors 95 | 77: teddy bear 96 | 78: hair drier 97 | 79: toothbrush 98 | 99 | 100 | # Download script/URL (optional) 101 | download: https://ultralytics.com/assets/coco128-seg.zip 102 | -------------------------------------------------------------------------------- /ultralytics/datasets/coco128.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # COCO128 dataset https://www.kaggle.com/ultralytics/coco128 (first 128 images from COCO train2017) by Ultralytics 3 | # Example usage: yolo train data=coco128.yaml 4 | # parent 5 | # ├── ultralytics 6 | # └── datasets 7 | # └── coco128 ← downloads here (7 MB) 8 | 9 | 10 | # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
11 | path: ../datasets/coco128 # dataset root dir 12 | train: images/train2017 # train images (relative to 'path') 128 images 13 | val: images/train2017 # val images (relative to 'path') 128 images 14 | test: # test images (optional) 15 | 16 | # Classes 17 | names: 18 | 0: person 19 | 1: bicycle 20 | 2: car 21 | 3: motorcycle 22 | 4: airplane 23 | 5: bus 24 | 6: train 25 | 7: truck 26 | 8: boat 27 | 9: traffic light 28 | 10: fire hydrant 29 | 11: stop sign 30 | 12: parking meter 31 | 13: bench 32 | 14: bird 33 | 15: cat 34 | 16: dog 35 | 17: horse 36 | 18: sheep 37 | 19: cow 38 | 20: elephant 39 | 21: bear 40 | 22: zebra 41 | 23: giraffe 42 | 24: backpack 43 | 25: umbrella 44 | 26: handbag 45 | 27: tie 46 | 28: suitcase 47 | 29: frisbee 48 | 30: skis 49 | 31: snowboard 50 | 32: sports ball 51 | 33: kite 52 | 34: baseball bat 53 | 35: baseball glove 54 | 36: skateboard 55 | 37: surfboard 56 | 38: tennis racket 57 | 39: bottle 58 | 40: wine glass 59 | 41: cup 60 | 42: fork 61 | 43: knife 62 | 44: spoon 63 | 45: bowl 64 | 46: banana 65 | 47: apple 66 | 48: sandwich 67 | 49: orange 68 | 50: broccoli 69 | 51: carrot 70 | 52: hot dog 71 | 53: pizza 72 | 54: donut 73 | 55: cake 74 | 56: chair 75 | 57: couch 76 | 58: potted plant 77 | 59: bed 78 | 60: dining table 79 | 61: toilet 80 | 62: tv 81 | 63: laptop 82 | 64: mouse 83 | 65: remote 84 | 66: keyboard 85 | 67: cell phone 86 | 68: microwave 87 | 69: oven 88 | 70: toaster 89 | 71: sink 90 | 72: refrigerator 91 | 73: book 92 | 74: clock 93 | 75: vase 94 | 76: scissors 95 | 77: teddy bear 96 | 78: hair drier 97 | 79: toothbrush 98 | 99 | 100 | # Download script/URL (optional) 101 | download: https://ultralytics.com/assets/coco128.zip 102 | -------------------------------------------------------------------------------- /ultralytics/datasets/coco8-pose.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # COCO8-pose dataset (first 8 images from COCO train2017) by Ultralytics 3 | # Example usage: yolo train data=coco8-pose.yaml 4 | # parent 5 | # ├── ultralytics 6 | # └── datasets 7 | # └── coco8-pose ← downloads here (1 MB) 8 | 9 | 10 | # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 11 | path: ../datasets/coco8-pose # dataset root dir 12 | train: images/train # train images (relative to 'path') 4 images 13 | val: images/val # val images (relative to 'path') 4 images 14 | test: # test images (optional) 15 | 16 | # Keypoints 17 | kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible) 18 | flip_idx: [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15] 19 | 20 | # Classes 21 | names: 22 | 0: person 23 | 24 | # Download script/URL (optional) 25 | download: https://ultralytics.com/assets/coco8-pose.zip 26 | -------------------------------------------------------------------------------- /ultralytics/datasets/coco8-seg.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # COCO8-seg dataset (first 8 images from COCO train2017) by Ultralytics 3 | # Example usage: yolo train data=coco8-seg.yaml 4 | # parent 5 | # ├── ultralytics 6 | # └── datasets 7 | # └── coco8-seg ← downloads here (1 MB) 8 | 9 | 10 | # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
11 | path: ../datasets/coco8-seg # dataset root dir 12 | train: images/train # train images (relative to 'path') 4 images 13 | val: images/val # val images (relative to 'path') 4 images 14 | test: # test images (optional) 15 | 16 | # Classes 17 | names: 18 | 0: person 19 | 1: bicycle 20 | 2: car 21 | 3: motorcycle 22 | 4: airplane 23 | 5: bus 24 | 6: train 25 | 7: truck 26 | 8: boat 27 | 9: traffic light 28 | 10: fire hydrant 29 | 11: stop sign 30 | 12: parking meter 31 | 13: bench 32 | 14: bird 33 | 15: cat 34 | 16: dog 35 | 17: horse 36 | 18: sheep 37 | 19: cow 38 | 20: elephant 39 | 21: bear 40 | 22: zebra 41 | 23: giraffe 42 | 24: backpack 43 | 25: umbrella 44 | 26: handbag 45 | 27: tie 46 | 28: suitcase 47 | 29: frisbee 48 | 30: skis 49 | 31: snowboard 50 | 32: sports ball 51 | 33: kite 52 | 34: baseball bat 53 | 35: baseball glove 54 | 36: skateboard 55 | 37: surfboard 56 | 38: tennis racket 57 | 39: bottle 58 | 40: wine glass 59 | 41: cup 60 | 42: fork 61 | 43: knife 62 | 44: spoon 63 | 45: bowl 64 | 46: banana 65 | 47: apple 66 | 48: sandwich 67 | 49: orange 68 | 50: broccoli 69 | 51: carrot 70 | 52: hot dog 71 | 53: pizza 72 | 54: donut 73 | 55: cake 74 | 56: chair 75 | 57: couch 76 | 58: potted plant 77 | 59: bed 78 | 60: dining table 79 | 61: toilet 80 | 62: tv 81 | 63: laptop 82 | 64: mouse 83 | 65: remote 84 | 66: keyboard 85 | 67: cell phone 86 | 68: microwave 87 | 69: oven 88 | 70: toaster 89 | 71: sink 90 | 72: refrigerator 91 | 73: book 92 | 74: clock 93 | 75: vase 94 | 76: scissors 95 | 77: teddy bear 96 | 78: hair drier 97 | 79: toothbrush 98 | 99 | 100 | # Download script/URL (optional) 101 | download: https://ultralytics.com/assets/coco8-seg.zip 102 | -------------------------------------------------------------------------------- /ultralytics/datasets/coco8.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # COCO8 dataset (first 8 images from COCO train2017) by Ultralytics 3 | # Example usage: yolo train data=coco8.yaml 4 | # parent 5 | # ├── ultralytics 6 | # └── datasets 7 | # └── coco8 ← downloads here (1 MB) 8 | 9 | 10 | # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
11 | path: ../datasets/coco8 # dataset root dir 12 | train: images/train # train images (relative to 'path') 4 images 13 | val: images/val # val images (relative to 'path') 4 images 14 | test: # test images (optional) 15 | 16 | # Classes 17 | names: 18 | 0: person 19 | 1: bicycle 20 | 2: car 21 | 3: motorcycle 22 | 4: airplane 23 | 5: bus 24 | 6: train 25 | 7: truck 26 | 8: boat 27 | 9: traffic light 28 | 10: fire hydrant 29 | 11: stop sign 30 | 12: parking meter 31 | 13: bench 32 | 14: bird 33 | 15: cat 34 | 16: dog 35 | 17: horse 36 | 18: sheep 37 | 19: cow 38 | 20: elephant 39 | 21: bear 40 | 22: zebra 41 | 23: giraffe 42 | 24: backpack 43 | 25: umbrella 44 | 26: handbag 45 | 27: tie 46 | 28: suitcase 47 | 29: frisbee 48 | 30: skis 49 | 31: snowboard 50 | 32: sports ball 51 | 33: kite 52 | 34: baseball bat 53 | 35: baseball glove 54 | 36: skateboard 55 | 37: surfboard 56 | 38: tennis racket 57 | 39: bottle 58 | 40: wine glass 59 | 41: cup 60 | 42: fork 61 | 43: knife 62 | 44: spoon 63 | 45: bowl 64 | 46: banana 65 | 47: apple 66 | 48: sandwich 67 | 49: orange 68 | 50: broccoli 69 | 51: carrot 70 | 52: hot dog 71 | 53: pizza 72 | 54: donut 73 | 55: cake 74 | 56: chair 75 | 57: couch 76 | 58: potted plant 77 | 59: bed 78 | 60: dining table 79 | 61: toilet 80 | 62: tv 81 | 63: laptop 82 | 64: mouse 83 | 65: remote 84 | 66: keyboard 85 | 67: cell phone 86 | 68: microwave 87 | 69: oven 88 | 70: toaster 89 | 71: sink 90 | 72: refrigerator 91 | 73: book 92 | 74: clock 93 | 75: vase 94 | 76: scissors 95 | 77: teddy bear 96 | 78: hair drier 97 | 79: toothbrush 98 | 99 | 100 | # Download script/URL (optional) 101 | download: https://ultralytics.com/assets/coco8.zip 102 | -------------------------------------------------------------------------------- /ultralytics/hub/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | import requests 4 | 5 | from ultralytics.hub.auth import Auth 6 | from ultralytics.hub.utils import PREFIX 7 | from ultralytics.yolo.data.utils import HUBDatasetStats 8 | from ultralytics.yolo.utils import LOGGER, SETTINGS, USER_CONFIG_DIR, yaml_save 9 | 10 | 11 | def login(api_key=''): 12 | """ 13 | Log in to the Ultralytics HUB API using the provided API key. 14 | 15 | Args: 16 | api_key (str, optional): May be an API key or a combination API key and model ID, i.e. key_id 17 | 18 | Example: 19 | from ultralytics import hub 20 | hub.login('API_KEY') 21 | """ 22 | Auth(api_key, verbose=True) 23 | 24 | 25 | def logout(): 26 | """ 27 | Log out of Ultralytics HUB by removing the API key from the settings file. To log in again, use 'yolo hub login'. 28 | 29 | Example: 30 | from ultralytics import hub 31 | hub.logout() 32 | """ 33 | SETTINGS['api_key'] = '' 34 | yaml_save(USER_CONFIG_DIR / 'settings.yaml', SETTINGS) 35 | LOGGER.info(f"{PREFIX}logged out ✅. To log in again, use 'yolo hub login'.") 36 | 37 | 38 | def start(key=''): 39 | """ 40 | Start training models with Ultralytics HUB (DEPRECATED). 41 | 42 | Args: 43 | key (str, optional): A string containing either the API key and model ID combination (apikey_modelid), 44 | or the full model URL (https://hub.ultralytics.com/models/apikey_modelid). 45 | """ 46 | api_key, model_id = key.split('_') 47 | LOGGER.warning(f""" 48 | WARNING ⚠️ ultralytics.start() is deprecated after 8.0.60. 
Updated usage to train Ultralytics HUB models is: 49 | 50 | from ultralytics import YOLO, hub 51 | 52 | hub.login('{api_key}') 53 | model = YOLO('https://hub.ultralytics.com/models/{model_id}') 54 | model.train()""") 55 | 56 | 57 | def reset_model(model_id=''): 58 | """Reset a trained model to an untrained state.""" 59 | r = requests.post('https://api.ultralytics.com/model-reset', json={'apiKey': Auth().api_key, 'modelId': model_id}) 60 | if r.status_code == 200: 61 | LOGGER.info(f'{PREFIX}Model reset successfully') 62 | return 63 | LOGGER.warning(f'{PREFIX}Model reset failure {r.status_code} {r.reason}') 64 | 65 | 66 | def export_fmts_hub(): 67 | """Returns a list of HUB-supported export formats.""" 68 | from ultralytics.yolo.engine.exporter import export_formats 69 | return list(export_formats()['Argument'][1:]) + ['ultralytics_tflite', 'ultralytics_coreml'] 70 | 71 | 72 | def export_model(model_id='', format='torchscript'): 73 | """Export a model to all formats.""" 74 | assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}" 75 | r = requests.post(f'https://api.ultralytics.com/v1/models/{model_id}/export', 76 | json={'format': format}, 77 | headers={'x-api-key': Auth().api_key}) 78 | assert r.status_code == 200, f'{PREFIX}{format} export failure {r.status_code} {r.reason}' 79 | LOGGER.info(f'{PREFIX}{format} export started ✅') 80 | 81 | 82 | def get_export(model_id='', format='torchscript'): 83 | """Get an exported model dictionary with download URL.""" 84 | assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}" 85 | r = requests.post('https://api.ultralytics.com/get-export', 86 | json={ 87 | 'apiKey': Auth().api_key, 88 | 'modelId': model_id, 89 | 'format': format}) 90 | assert r.status_code == 200, f'{PREFIX}{format} get_export failure {r.status_code} {r.reason}' 91 | return r.json() 92 | 93 | 94 | def check_dataset(path='', task='detect'): 95 | """ 96 | Function for error-checking HUB dataset Zip file before upload. It checks a dataset for errors before it is 97 | uploaded to the HUB. Usage examples are given below. 98 | 99 | Args: 100 | path (str, optional): Path to data.zip (with data.yaml inside data.zip). Defaults to ''. 101 | task (str, optional): Dataset task. Options are 'detect', 'segment', 'pose', 'classify'. Defaults to 'detect'. 102 | 103 | Example: 104 | ```python 105 | from ultralytics.hub import check_dataset 106 | 107 | check_dataset('path/to/coco8.zip', task='detect') # detect dataset 108 | check_dataset('path/to/coco8-seg.zip', task='segment') # segment dataset 109 | check_dataset('path/to/coco8-pose.zip', task='pose') # pose dataset 110 | ``` 111 | """ 112 | HUBDatasetStats(path=path, task=task).get_json() 113 | LOGGER.info('Checks completed correctly ✅. Upload this dataset to https://hub.ultralytics.com/datasets/.') 114 | 115 | 116 | if __name__ == '__main__': 117 | start() 118 | -------------------------------------------------------------------------------- /ultralytics/models/README.md: -------------------------------------------------------------------------------- 1 | ## Models 2 | 3 | Welcome to the Ultralytics Models directory! Here you will find a wide variety of pre-configured model configuration 4 | files (`*.yaml`s) that can be used to create custom YOLO models. 
The models in this directory have been expertly crafted 5 | and fine-tuned by the Ultralytics team to provide the best performance for a wide range of object detection and image 6 | segmentation tasks. 7 | 8 | These model configurations cover a wide range of scenarios, from simple object detection to more complex tasks like 9 | instance segmentation and object tracking. They are also designed to run efficiently on a variety of hardware platforms, 10 | from CPUs to GPUs. Whether you are a seasoned machine learning practitioner or just getting started with YOLO, this 11 | directory provides a great starting point for your custom model development needs. 12 | 13 | To get started, simply browse through the models in this directory and find one that best suits your needs. Once you've 14 | selected a model, you can use the provided `*.yaml` file to train and deploy your custom YOLO model with ease. See full 15 | details at the Ultralytics [Docs](https://docs.ultralytics.com/models), and if you need help or have any questions, feel free 16 | to reach out to the Ultralytics team for support. So, don't wait, start creating your custom YOLO model now! 17 | 18 | ### Usage 19 | 20 | Model `*.yaml` files may be used directly in the Command Line Interface (CLI) with a `yolo` command: 21 | 22 | ```bash 23 | yolo task=detect mode=train model=yolov8n.yaml data=coco128.yaml epochs=100 24 | ``` 25 | 26 | They may also be used directly in a Python environment, and accept the same 27 | [arguments](https://docs.ultralytics.com/usage/cfg/) as in the CLI example above: 28 | 29 | ```python 30 | from ultralytics import YOLO 31 | 32 | model = YOLO("model.yaml") # build a model from a YAML config 33 | # YOLO("model.pt") use pre-trained model if available 34 | model.info() # display model information 35 | model.train(data="coco128.yaml", epochs=100) # train the model 36 | ``` 37 | 38 | ## Pre-trained Model Architectures 39 | 40 | Ultralytics supports many model architectures. Visit https://docs.ultralytics.com/models to view detailed information 41 | and usage. Any of these models can be used by loading their configs or pretrained checkpoints if available. 42 | 43 | ## Contributing New Models 44 | 45 | If you've developed a new model architecture or have improvements for existing models that you'd like to contribute to the Ultralytics community, please submit your contribution in a new Pull Request. For more details, visit our [Contributing Guide](https://docs.ultralytics.com/help/contributing). 46 | -------------------------------------------------------------------------------- /ultralytics/models/rt-detr/rtdetr-l.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # RT-DETR-l object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr 3 | 4 | # Parameters 5 | nc: 80 # number of classes 6 | scales: # model compound scaling constants, i.e.
'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n' 7 | # [depth, width, max_channels] 8 | l: [1.00, 1.00, 1024] 9 | 10 | backbone: 11 | # [from, repeats, module, args] 12 | - [-1, 1, HGStem, [32, 48]] # 0-P2/4 13 | - [-1, 6, HGBlock, [48, 128, 3]] # stage 1 14 | 15 | - [-1, 1, DWConv, [128, 3, 2, 1, False]] # 2-P3/8 16 | - [-1, 6, HGBlock, [96, 512, 3]] # stage 2 17 | 18 | - [-1, 1, DWConv, [512, 3, 2, 1, False]] # 4-P3/16 19 | - [-1, 6, HGBlock, [192, 1024, 5, True, False]] # cm, c2, k, light, shortcut 20 | - [-1, 6, HGBlock, [192, 1024, 5, True, True]] 21 | - [-1, 6, HGBlock, [192, 1024, 5, True, True]] # stage 3 22 | 23 | - [-1, 1, DWConv, [1024, 3, 2, 1, False]] # 8-P4/32 24 | - [-1, 6, HGBlock, [384, 2048, 5, True, False]] # stage 4 25 | 26 | head: 27 | - [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 10 input_proj.2 28 | - [-1, 1, AIFI, [1024, 8]] 29 | - [-1, 1, Conv, [256, 1, 1]] # 12, Y5, lateral_convs.0 30 | 31 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 32 | - [7, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 14 input_proj.1 33 | - [[-2, -1], 1, Concat, [1]] 34 | - [-1, 3, RepC3, [256]] # 16, fpn_blocks.0 35 | - [-1, 1, Conv, [256, 1, 1]] # 17, Y4, lateral_convs.1 36 | 37 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 38 | - [3, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 19 input_proj.0 39 | - [[-2, -1], 1, Concat, [1]] # cat backbone P4 40 | - [-1, 3, RepC3, [256]] # X3 (21), fpn_blocks.1 41 | 42 | - [-1, 1, Conv, [256, 3, 2]] # 22, downsample_convs.0 43 | - [[-1, 17], 1, Concat, [1]] # cat Y4 44 | - [-1, 3, RepC3, [256]] # F4 (24), pan_blocks.0 45 | 46 | - [-1, 1, Conv, [256, 3, 2]] # 25, downsample_convs.1 47 | - [[-1, 12], 1, Concat, [1]] # cat Y5 48 | - [-1, 3, RepC3, [256]] # F5 (27), pan_blocks.1 49 | 50 | - [[21, 24, 27], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5) 51 | -------------------------------------------------------------------------------- /ultralytics/models/rt-detr/rtdetr-x.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # RT-DETR-x object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr 3 | 4 | # Parameters 5 | nc: 80 # number of classes 6 | scales: # model compound scaling constants, i.e. 
'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n' 7 | # [depth, width, max_channels] 8 | x: [1.00, 1.00, 2048] 9 | 10 | backbone: 11 | # [from, repeats, module, args] 12 | - [-1, 1, HGStem, [32, 64]] # 0-P2/4 13 | - [-1, 6, HGBlock, [64, 128, 3]] # stage 1 14 | 15 | - [-1, 1, DWConv, [128, 3, 2, 1, False]] # 2-P3/8 16 | - [-1, 6, HGBlock, [128, 512, 3]] 17 | - [-1, 6, HGBlock, [128, 512, 3, False, True]] # 4-stage 2 18 | 19 | - [-1, 1, DWConv, [512, 3, 2, 1, False]] # 5-P3/16 20 | - [-1, 6, HGBlock, [256, 1024, 5, True, False]] # cm, c2, k, light, shortcut 21 | - [-1, 6, HGBlock, [256, 1024, 5, True, True]] 22 | - [-1, 6, HGBlock, [256, 1024, 5, True, True]] 23 | - [-1, 6, HGBlock, [256, 1024, 5, True, True]] 24 | - [-1, 6, HGBlock, [256, 1024, 5, True, True]] # 10-stage 3 25 | 26 | - [-1, 1, DWConv, [1024, 3, 2, 1, False]] # 11-P4/32 27 | - [-1, 6, HGBlock, [512, 2048, 5, True, False]] 28 | - [-1, 6, HGBlock, [512, 2048, 5, True, True]] # 13-stage 4 29 | 30 | head: 31 | - [-1, 1, Conv, [384, 1, 1, None, 1, 1, False]] # 14 input_proj.2 32 | - [-1, 1, AIFI, [2048, 8]] 33 | - [-1, 1, Conv, [384, 1, 1]] # 16, Y5, lateral_convs.0 34 | 35 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 36 | - [10, 1, Conv, [384, 1, 1, None, 1, 1, False]] # 18 input_proj.1 37 | - [[-2, -1], 1, Concat, [1]] 38 | - [-1, 3, RepC3, [384]] # 20, fpn_blocks.0 39 | - [-1, 1, Conv, [384, 1, 1]] # 21, Y4, lateral_convs.1 40 | 41 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 42 | - [4, 1, Conv, [384, 1, 1, None, 1, 1, False]] # 23 input_proj.0 43 | - [[-2, -1], 1, Concat, [1]] # cat backbone P4 44 | - [-1, 3, RepC3, [384]] # X3 (25), fpn_blocks.1 45 | 46 | - [-1, 1, Conv, [384, 3, 2]] # 26, downsample_convs.0 47 | - [[-1, 21], 1, Concat, [1]] # cat Y4 48 | - [-1, 3, RepC3, [384]] # F4 (28), pan_blocks.0 49 | 50 | - [-1, 1, Conv, [384, 3, 2]] # 29, downsample_convs.1 51 | - [[-1, 16], 1, Concat, [1]] # cat Y5 52 | - [-1, 3, RepC3, [384]] # F5 (31), pan_blocks.1 53 | 54 | - [[25, 28, 31], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5) 55 | -------------------------------------------------------------------------------- /ultralytics/models/v3/yolov3-spp.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # YOLOv3-SPP object detection model with P3-P5 outputs. 
For details see https://docs.ultralytics.com/models/yolov3 3 | 4 | # Parameters 5 | nc: 80 # number of classes 6 | depth_multiple: 1.0 # model depth multiple 7 | width_multiple: 1.0 # layer channel multiple 8 | 9 | # darknet53 backbone 10 | backbone: 11 | # [from, number, module, args] 12 | [[-1, 1, Conv, [32, 3, 1]], # 0 13 | [-1, 1, Conv, [64, 3, 2]], # 1-P1/2 14 | [-1, 1, Bottleneck, [64]], 15 | [-1, 1, Conv, [128, 3, 2]], # 3-P2/4 16 | [-1, 2, Bottleneck, [128]], 17 | [-1, 1, Conv, [256, 3, 2]], # 5-P3/8 18 | [-1, 8, Bottleneck, [256]], 19 | [-1, 1, Conv, [512, 3, 2]], # 7-P4/16 20 | [-1, 8, Bottleneck, [512]], 21 | [-1, 1, Conv, [1024, 3, 2]], # 9-P5/32 22 | [-1, 4, Bottleneck, [1024]], # 10 23 | ] 24 | 25 | # YOLOv3-SPP head 26 | head: 27 | [[-1, 1, Bottleneck, [1024, False]], 28 | [-1, 1, SPP, [512, [5, 9, 13]]], 29 | [-1, 1, Conv, [1024, 3, 1]], 30 | [-1, 1, Conv, [512, 1, 1]], 31 | [-1, 1, Conv, [1024, 3, 1]], # 15 (P5/32-large) 32 | 33 | [-2, 1, Conv, [256, 1, 1]], 34 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 35 | [[-1, 8], 1, Concat, [1]], # cat backbone P4 36 | [-1, 1, Bottleneck, [512, False]], 37 | [-1, 1, Bottleneck, [512, False]], 38 | [-1, 1, Conv, [256, 1, 1]], 39 | [-1, 1, Conv, [512, 3, 1]], # 22 (P4/16-medium) 40 | 41 | [-2, 1, Conv, [128, 1, 1]], 42 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 43 | [[-1, 6], 1, Concat, [1]], # cat backbone P3 44 | [-1, 1, Bottleneck, [256, False]], 45 | [-1, 2, Bottleneck, [256, False]], # 27 (P3/8-small) 46 | 47 | [[27, 22, 15], 1, Detect, [nc]], # Detect(P3, P4, P5) 48 | ] 49 | -------------------------------------------------------------------------------- /ultralytics/models/v3/yolov3-tiny.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # YOLOv3-tiny object detection model with P4-P5 outputs. For details see https://docs.ultralytics.com/models/yolov3 3 | 4 | # Parameters 5 | nc: 80 # number of classes 6 | depth_multiple: 1.0 # model depth multiple 7 | width_multiple: 1.0 # layer channel multiple 8 | 9 | # YOLOv3-tiny backbone 10 | backbone: 11 | # [from, number, module, args] 12 | [[-1, 1, Conv, [16, 3, 1]], # 0 13 | [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 1-P1/2 14 | [-1, 1, Conv, [32, 3, 1]], 15 | [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 3-P2/4 16 | [-1, 1, Conv, [64, 3, 1]], 17 | [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 5-P3/8 18 | [-1, 1, Conv, [128, 3, 1]], 19 | [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 7-P4/16 20 | [-1, 1, Conv, [256, 3, 1]], 21 | [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 9-P5/32 22 | [-1, 1, Conv, [512, 3, 1]], 23 | [-1, 1, nn.ZeroPad2d, [[0, 1, 0, 1]]], # 11 24 | [-1, 1, nn.MaxPool2d, [2, 1, 0]], # 12 25 | ] 26 | 27 | # YOLOv3-tiny head 28 | head: 29 | [[-1, 1, Conv, [1024, 3, 1]], 30 | [-1, 1, Conv, [256, 1, 1]], 31 | [-1, 1, Conv, [512, 3, 1]], # 15 (P5/32-large) 32 | 33 | [-2, 1, Conv, [128, 1, 1]], 34 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 35 | [[-1, 8], 1, Concat, [1]], # cat backbone P4 36 | [-1, 1, Conv, [256, 3, 1]], # 19 (P4/16-medium) 37 | 38 | [[19, 15], 1, Detect, [nc]], # Detect(P4, P5) 39 | ] 40 | -------------------------------------------------------------------------------- /ultralytics/models/v3/yolov3.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # YOLOv3 object detection model with P3-P5 outputs. 
For details see https://docs.ultralytics.com/models/yolov3 3 | 4 | # Parameters 5 | nc: 80 # number of classes 6 | depth_multiple: 1.0 # model depth multiple 7 | width_multiple: 1.0 # layer channel multiple 8 | 9 | # darknet53 backbone 10 | backbone: 11 | # [from, number, module, args] 12 | [[-1, 1, Conv, [32, 3, 1]], # 0 13 | [-1, 1, Conv, [64, 3, 2]], # 1-P1/2 14 | [-1, 1, Bottleneck, [64]], 15 | [-1, 1, Conv, [128, 3, 2]], # 3-P2/4 16 | [-1, 2, Bottleneck, [128]], 17 | [-1, 1, Conv, [256, 3, 2]], # 5-P3/8 18 | [-1, 8, Bottleneck, [256]], 19 | [-1, 1, Conv, [512, 3, 2]], # 7-P4/16 20 | [-1, 8, Bottleneck, [512]], 21 | [-1, 1, Conv, [1024, 3, 2]], # 9-P5/32 22 | [-1, 4, Bottleneck, [1024]], # 10 23 | ] 24 | 25 | # YOLOv3 head 26 | head: 27 | [[-1, 1, Bottleneck, [1024, False]], 28 | [-1, 1, Conv, [512, 1, 1]], 29 | [-1, 1, Conv, [1024, 3, 1]], 30 | [-1, 1, Conv, [512, 1, 1]], 31 | [-1, 1, Conv, [1024, 3, 1]], # 15 (P5/32-large) 32 | 33 | [-2, 1, Conv, [256, 1, 1]], 34 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 35 | [[-1, 8], 1, Concat, [1]], # cat backbone P4 36 | [-1, 1, Bottleneck, [512, False]], 37 | [-1, 1, Bottleneck, [512, False]], 38 | [-1, 1, Conv, [256, 1, 1]], 39 | [-1, 1, Conv, [512, 3, 1]], # 22 (P4/16-medium) 40 | 41 | [-2, 1, Conv, [128, 1, 1]], 42 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 43 | [[-1, 6], 1, Concat, [1]], # cat backbone P3 44 | [-1, 1, Bottleneck, [256, False]], 45 | [-1, 2, Bottleneck, [256, False]], # 27 (P3/8-small) 46 | 47 | [[27, 22, 15], 1, Detect, [nc]], # Detect(P3, P4, P5) 48 | ] 49 | -------------------------------------------------------------------------------- /ultralytics/models/v5/yolov5-p6.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # YOLOv5 object detection model with P3-P6 outputs. For details see https://docs.ultralytics.com/models/yolov5 3 | 4 | # Parameters 5 | nc: 80 # number of classes 6 | scales: # model compound scaling constants, i.e. 
'model=yolov5n-p6.yaml' will call yolov5-p6.yaml with scale 'n' 7 | # [depth, width, max_channels] 8 | n: [0.33, 0.25, 1024] 9 | s: [0.33, 0.50, 1024] 10 | m: [0.67, 0.75, 1024] 11 | l: [1.00, 1.00, 1024] 12 | x: [1.33, 1.25, 1024] 13 | 14 | # YOLOv5 v6.0 backbone 15 | backbone: 16 | # [from, number, module, args] 17 | [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 18 | [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 19 | [-1, 3, C3, [128]], 20 | [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 21 | [-1, 6, C3, [256]], 22 | [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 23 | [-1, 9, C3, [512]], 24 | [-1, 1, Conv, [768, 3, 2]], # 7-P5/32 25 | [-1, 3, C3, [768]], 26 | [-1, 1, Conv, [1024, 3, 2]], # 9-P6/64 27 | [-1, 3, C3, [1024]], 28 | [-1, 1, SPPF, [1024, 5]], # 11 29 | ] 30 | 31 | # YOLOv5 v6.0 head 32 | head: 33 | [[-1, 1, Conv, [768, 1, 1]], 34 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 35 | [[-1, 8], 1, Concat, [1]], # cat backbone P5 36 | [-1, 3, C3, [768, False]], # 15 37 | 38 | [-1, 1, Conv, [512, 1, 1]], 39 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 40 | [[-1, 6], 1, Concat, [1]], # cat backbone P4 41 | [-1, 3, C3, [512, False]], # 19 42 | 43 | [-1, 1, Conv, [256, 1, 1]], 44 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 45 | [[-1, 4], 1, Concat, [1]], # cat backbone P3 46 | [-1, 3, C3, [256, False]], # 23 (P3/8-small) 47 | 48 | [-1, 1, Conv, [256, 3, 2]], 49 | [[-1, 20], 1, Concat, [1]], # cat head P4 50 | [-1, 3, C3, [512, False]], # 26 (P4/16-medium) 51 | 52 | [-1, 1, Conv, [512, 3, 2]], 53 | [[-1, 16], 1, Concat, [1]], # cat head P5 54 | [-1, 3, C3, [768, False]], # 29 (P5/32-large) 55 | 56 | [-1, 1, Conv, [768, 3, 2]], 57 | [[-1, 12], 1, Concat, [1]], # cat head P6 58 | [-1, 3, C3, [1024, False]], # 32 (P6/64-xlarge) 59 | 60 | [[23, 26, 29, 32], 1, Detect, [nc]], # Detect(P3, P4, P5, P6) 61 | ] 62 | -------------------------------------------------------------------------------- /ultralytics/models/v5/yolov5.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # YOLOv5 object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/yolov5 3 | 4 | # Parameters 5 | nc: 80 # number of classes 6 | scales: # model compound scaling constants, i.e. 
'model=yolov5n.yaml' will call yolov5.yaml with scale 'n' 7 | # [depth, width, max_channels] 8 | n: [0.33, 0.25, 1024] 9 | s: [0.33, 0.50, 1024] 10 | m: [0.67, 0.75, 1024] 11 | l: [1.00, 1.00, 1024] 12 | x: [1.33, 1.25, 1024] 13 | 14 | # YOLOv5 v6.0 backbone 15 | backbone: 16 | # [from, number, module, args] 17 | [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 18 | [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 19 | [-1, 3, C3, [128]], 20 | [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 21 | [-1, 6, C3, [256]], 22 | [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 23 | [-1, 9, C3, [512]], 24 | [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 25 | [-1, 3, C3, [1024]], 26 | [-1, 1, SPPF, [1024, 5]], # 9 27 | ] 28 | 29 | # YOLOv5 v6.0 head 30 | head: 31 | [[-1, 1, Conv, [512, 1, 1]], 32 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 33 | [[-1, 6], 1, Concat, [1]], # cat backbone P4 34 | [-1, 3, C3, [512, False]], # 13 35 | 36 | [-1, 1, Conv, [256, 1, 1]], 37 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 38 | [[-1, 4], 1, Concat, [1]], # cat backbone P3 39 | [-1, 3, C3, [256, False]], # 17 (P3/8-small) 40 | 41 | [-1, 1, Conv, [256, 3, 2]], 42 | [[-1, 14], 1, Concat, [1]], # cat head P4 43 | [-1, 3, C3, [512, False]], # 20 (P4/16-medium) 44 | 45 | [-1, 1, Conv, [512, 3, 2]], 46 | [[-1, 10], 1, Concat, [1]], # cat head P5 47 | [-1, 3, C3, [1024, False]], # 23 (P5/32-large) 48 | 49 | [[17, 20, 23], 1, Detect, [nc]], # Detect(P3, P4, P5) 50 | ] 51 | -------------------------------------------------------------------------------- /ultralytics/models/v6/yolov6.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # YOLOv6 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/models/yolov6 3 | 4 | # Parameters 5 | nc: 80 # number of classes 6 | activation: nn.ReLU() # (optional) model default activation function 7 | scales: # model compound scaling constants, i.e. 
'model=yolov6n.yaml' will call yolov8.yaml with scale 'n' 8 | # [depth, width, max_channels] 9 | n: [0.33, 0.25, 1024] 10 | s: [0.33, 0.50, 1024] 11 | m: [0.67, 0.75, 768] 12 | l: [1.00, 1.00, 512] 13 | x: [1.00, 1.25, 512] 14 | 15 | # YOLOv6-3.0s backbone 16 | backbone: 17 | # [from, repeats, module, args] 18 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 19 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 20 | - [-1, 6, Conv, [128, 3, 1]] 21 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 22 | - [-1, 12, Conv, [256, 3, 1]] 23 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 24 | - [-1, 18, Conv, [512, 3, 1]] 25 | - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 26 | - [-1, 6, Conv, [1024, 3, 1]] 27 | - [-1, 1, SPPF, [1024, 5]] # 9 28 | 29 | # YOLOv6-3.0s head 30 | head: 31 | - [-1, 1, Conv, [256, 1, 1]] 32 | - [-1, 1, nn.ConvTranspose2d, [256, 2, 2, 0]] 33 | - [[-1, 6], 1, Concat, [1]] # cat backbone P4 34 | - [-1, 1, Conv, [256, 3, 1]] 35 | - [-1, 9, Conv, [256, 3, 1]] # 14 36 | 37 | - [-1, 1, Conv, [128, 1, 1]] 38 | - [-1, 1, nn.ConvTranspose2d, [128, 2, 2, 0]] 39 | - [[-1, 4], 1, Concat, [1]] # cat backbone P3 40 | - [-1, 1, Conv, [128, 3, 1]] 41 | - [-1, 9, Conv, [128, 3, 1]] # 19 42 | 43 | - [-1, 1, Conv, [128, 3, 2]] 44 | - [[-1, 15], 1, Concat, [1]] # cat head P4 45 | - [-1, 1, Conv, [256, 3, 1]] 46 | - [-1, 9, Conv, [256, 3, 1]] # 23 47 | 48 | - [-1, 1, Conv, [256, 3, 2]] 49 | - [[-1, 10], 1, Concat, [1]] # cat head P5 50 | - [-1, 1, Conv, [512, 3, 1]] 51 | - [-1, 9, Conv, [512, 3, 1]] # 27 52 | 53 | - [[19, 23, 27], 1, Detect, [nc]] # Detect(P3, P4, P5) 54 | -------------------------------------------------------------------------------- /ultralytics/models/v8/yolov8-cls.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # YOLOv8-cls image classification model. For Usage examples see https://docs.ultralytics.com/tasks/classify 3 | 4 | # Parameters 5 | nc: 1000 # number of classes 6 | scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n' 7 | # [depth, width, max_channels] 8 | n: [0.33, 0.25, 1024] 9 | s: [0.33, 0.50, 1024] 10 | m: [0.67, 0.75, 1024] 11 | l: [1.00, 1.00, 1024] 12 | x: [1.00, 1.25, 1024] 13 | 14 | # YOLOv8.0n backbone 15 | backbone: 16 | # [from, repeats, module, args] 17 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 18 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 19 | - [-1, 3, C2f, [128, True]] 20 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 21 | - [-1, 6, C2f, [256, True]] 22 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 23 | - [-1, 6, C2f, [512, True]] 24 | - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 25 | - [-1, 3, C2f, [1024, True]] 26 | 27 | # YOLOv8.0n head 28 | head: 29 | - [-1, 1, Classify, [nc]] # Classify 30 | -------------------------------------------------------------------------------- /ultralytics/models/v8/yolov8-p2.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # YOLOv8 object detection model with P2-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect 3 | 4 | # Parameters 5 | nc: 80 # number of classes 6 | scales: # model compound scaling constants, i.e. 
'model=yolov8n.yaml' will call yolov8.yaml with scale 'n' 7 | # [depth, width, max_channels] 8 | n: [0.33, 0.25, 1024] 9 | s: [0.33, 0.50, 1024] 10 | m: [0.67, 0.75, 768] 11 | l: [1.00, 1.00, 512] 12 | x: [1.00, 1.25, 512] 13 | 14 | # YOLOv8.0 backbone 15 | backbone: 16 | # [from, repeats, module, args] 17 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 18 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 19 | - [-1, 3, C2f, [128, True]] 20 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 21 | - [-1, 6, C2f, [256, True]] 22 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 23 | - [-1, 6, C2f, [512, True]] 24 | - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 25 | - [-1, 3, C2f, [1024, True]] 26 | - [-1, 1, SPPF, [1024, 5]] # 9 27 | 28 | # YOLOv8.0-p2 head 29 | head: 30 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 31 | - [[-1, 6], 1, Concat, [1]] # cat backbone P4 32 | - [-1, 3, C2f, [512]] # 12 33 | 34 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 35 | - [[-1, 4], 1, Concat, [1]] # cat backbone P3 36 | - [-1, 3, C2f, [256]] # 15 (P3/8-small) 37 | 38 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 39 | - [[-1, 2], 1, Concat, [1]] # cat backbone P2 40 | - [-1, 3, C2f, [128]] # 18 (P2/4-xsmall) 41 | 42 | - [-1, 1, Conv, [128, 3, 2]] 43 | - [[-1, 15], 1, Concat, [1]] # cat head P3 44 | - [-1, 3, C2f, [256]] # 21 (P3/8-small) 45 | 46 | - [-1, 1, Conv, [256, 3, 2]] 47 | - [[-1, 12], 1, Concat, [1]] # cat head P4 48 | - [-1, 3, C2f, [512]] # 24 (P4/16-medium) 49 | 50 | - [-1, 1, Conv, [512, 3, 2]] 51 | - [[-1, 9], 1, Concat, [1]] # cat head P5 52 | - [-1, 3, C2f, [1024]] # 27 (P5/32-large) 53 | 54 | - [[18, 21, 24, 27], 1, Detect, [nc]] # Detect(P2, P3, P4, P5) 55 | -------------------------------------------------------------------------------- /ultralytics/models/v8/yolov8-p6.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # YOLOv8 object detection model with P3-P6 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect 3 | 4 | # Parameters 5 | nc: 80 # number of classes 6 | scales: # model compound scaling constants, i.e. 
'model=yolov8n-p6.yaml' will call yolov8-p6.yaml with scale 'n' 7 | # [depth, width, max_channels] 8 | n: [0.33, 0.25, 1024] 9 | s: [0.33, 0.50, 1024] 10 | m: [0.67, 0.75, 768] 11 | l: [1.00, 1.00, 512] 12 | x: [1.00, 1.25, 512] 13 | 14 | # YOLOv8.0x6 backbone 15 | backbone: 16 | # [from, repeats, module, args] 17 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 18 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 19 | - [-1, 3, C2f, [128, True]] 20 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 21 | - [-1, 6, C2f, [256, True]] 22 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 23 | - [-1, 6, C2f, [512, True]] 24 | - [-1, 1, Conv, [768, 3, 2]] # 7-P5/32 25 | - [-1, 3, C2f, [768, True]] 26 | - [-1, 1, Conv, [1024, 3, 2]] # 9-P6/64 27 | - [-1, 3, C2f, [1024, True]] 28 | - [-1, 1, SPPF, [1024, 5]] # 11 29 | 30 | # YOLOv8.0x6 head 31 | head: 32 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 33 | - [[-1, 8], 1, Concat, [1]] # cat backbone P5 34 | - [-1, 3, C2, [768, False]] # 14 35 | 36 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 37 | - [[-1, 6], 1, Concat, [1]] # cat backbone P4 38 | - [-1, 3, C2, [512, False]] # 17 39 | 40 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 41 | - [[-1, 4], 1, Concat, [1]] # cat backbone P3 42 | - [-1, 3, C2, [256, False]] # 20 (P3/8-small) 43 | 44 | - [-1, 1, Conv, [256, 3, 2]] 45 | - [[-1, 17], 1, Concat, [1]] # cat head P4 46 | - [-1, 3, C2, [512, False]] # 23 (P4/16-medium) 47 | 48 | - [-1, 1, Conv, [512, 3, 2]] 49 | - [[-1, 14], 1, Concat, [1]] # cat head P5 50 | - [-1, 3, C2, [768, False]] # 26 (P5/32-large) 51 | 52 | - [-1, 1, Conv, [768, 3, 2]] 53 | - [[-1, 11], 1, Concat, [1]] # cat head P6 54 | - [-1, 3, C2, [1024, False]] # 29 (P6/64-xlarge) 55 | 56 | - [[20, 23, 26, 29], 1, Detect, [nc]] # Detect(P3, P4, P5, P6) 57 | -------------------------------------------------------------------------------- /ultralytics/models/v8/yolov8-pose-p6.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # YOLOv8-pose keypoints/pose estimation model. For Usage examples see https://docs.ultralytics.com/tasks/pose 3 | 4 | # Parameters 5 | nc: 1 # number of classes 6 | kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible) 7 | scales: # model compound scaling constants, i.e. 
'model=yolov8n-p6.yaml' will call yolov8-p6.yaml with scale 'n' 8 | # [depth, width, max_channels] 9 | n: [0.33, 0.25, 1024] 10 | s: [0.33, 0.50, 1024] 11 | m: [0.67, 0.75, 768] 12 | l: [1.00, 1.00, 512] 13 | x: [1.00, 1.25, 512] 14 | 15 | # YOLOv8.0x6 backbone 16 | backbone: 17 | # [from, repeats, module, args] 18 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 19 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 20 | - [-1, 3, C2f, [128, True]] 21 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 22 | - [-1, 6, C2f, [256, True]] 23 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 24 | - [-1, 6, C2f, [512, True]] 25 | - [-1, 1, Conv, [768, 3, 2]] # 7-P5/32 26 | - [-1, 3, C2f, [768, True]] 27 | - [-1, 1, Conv, [1024, 3, 2]] # 9-P6/64 28 | - [-1, 3, C2f, [1024, True]] 29 | - [-1, 1, SPPF, [1024, 5]] # 11 30 | 31 | # YOLOv8.0x6 head 32 | head: 33 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 34 | - [[-1, 8], 1, Concat, [1]] # cat backbone P5 35 | - [-1, 3, C2, [768, False]] # 14 36 | 37 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 38 | - [[-1, 6], 1, Concat, [1]] # cat backbone P4 39 | - [-1, 3, C2, [512, False]] # 17 40 | 41 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 42 | - [[-1, 4], 1, Concat, [1]] # cat backbone P3 43 | - [-1, 3, C2, [256, False]] # 20 (P3/8-small) 44 | 45 | - [-1, 1, Conv, [256, 3, 2]] 46 | - [[-1, 17], 1, Concat, [1]] # cat head P4 47 | - [-1, 3, C2, [512, False]] # 23 (P4/16-medium) 48 | 49 | - [-1, 1, Conv, [512, 3, 2]] 50 | - [[-1, 14], 1, Concat, [1]] # cat head P5 51 | - [-1, 3, C2, [768, False]] # 26 (P5/32-large) 52 | 53 | - [-1, 1, Conv, [768, 3, 2]] 54 | - [[-1, 11], 1, Concat, [1]] # cat head P6 55 | - [-1, 3, C2, [1024, False]] # 29 (P6/64-xlarge) 56 | 57 | - [[20, 23, 26, 29], 1, Pose, [nc, kpt_shape]] # Pose(P3, P4, P5, P6) 58 | -------------------------------------------------------------------------------- /ultralytics/models/v8/yolov8-pose.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # YOLOv8-pose keypoints/pose estimation model. For Usage examples see https://docs.ultralytics.com/tasks/pose 3 | 4 | # Parameters 5 | nc: 1 # number of classes 6 | kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible) 7 | scales: # model compound scaling constants, i.e. 
'model=yolov8n-pose.yaml' will call yolov8-pose.yaml with scale 'n' 8 | # [depth, width, max_channels] 9 | n: [0.33, 0.25, 1024] 10 | s: [0.33, 0.50, 1024] 11 | m: [0.67, 0.75, 768] 12 | l: [1.00, 1.00, 512] 13 | x: [1.00, 1.25, 512] 14 | 15 | # YOLOv8.0n backbone 16 | backbone: 17 | # [from, repeats, module, args] 18 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 19 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 20 | - [-1, 3, C2f, [128, True]] 21 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 22 | - [-1, 6, C2f, [256, True]] 23 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 24 | - [-1, 6, C2f, [512, True]] 25 | - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 26 | - [-1, 3, C2f, [1024, True]] 27 | - [-1, 1, SPPF, [1024, 5]] # 9 28 | 29 | # YOLOv8.0n head 30 | head: 31 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 32 | - [[-1, 6], 1, Concat, [1]] # cat backbone P4 33 | - [-1, 3, C2f, [512]] # 12 34 | 35 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 36 | - [[-1, 4], 1, Concat, [1]] # cat backbone P3 37 | - [-1, 3, C2f, [256]] # 15 (P3/8-small) 38 | 39 | - [-1, 1, Conv, [256, 3, 2]] 40 | - [[-1, 12], 1, Concat, [1]] # cat head P4 41 | - [-1, 3, C2f, [512]] # 18 (P4/16-medium) 42 | 43 | - [-1, 1, Conv, [512, 3, 2]] 44 | - [[-1, 9], 1, Concat, [1]] # cat head P5 45 | - [-1, 3, C2f, [1024]] # 21 (P5/32-large) 46 | 47 | - [[15, 18, 21], 1, Pose, [nc, kpt_shape]] # Pose(P3, P4, P5) 48 | -------------------------------------------------------------------------------- /ultralytics/models/v8/yolov8-rtdetr.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect 3 | 4 | # Parameters 5 | nc: 80 # number of classes 6 | scales: # model compound scaling constants, i.e. 
'model=yolov8n.yaml' will call yolov8.yaml with scale 'n' 7 | # [depth, width, max_channels] 8 | n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs 9 | s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs 10 | m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs 11 | l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs 12 | x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs 13 | 14 | # YOLOv8.0n backbone 15 | backbone: 16 | # [from, repeats, module, args] 17 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 18 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 19 | - [-1, 3, C2f, [128, True]] 20 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 21 | - [-1, 6, C2f, [256, True]] 22 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 23 | - [-1, 6, C2f, [512, True]] 24 | - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 25 | - [-1, 3, C2f, [1024, True]] 26 | - [-1, 1, SPPF, [1024, 5]] # 9 27 | 28 | # YOLOv8.0n head 29 | head: 30 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 31 | - [[-1, 6], 1, Concat, [1]] # cat backbone P4 32 | - [-1, 3, C2f, [512]] # 12 33 | 34 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 35 | - [[-1, 4], 1, Concat, [1]] # cat backbone P3 36 | - [-1, 3, C2f, [256]] # 15 (P3/8-small) 37 | 38 | - [-1, 1, Conv, [256, 3, 2]] 39 | - [[-1, 12], 1, Concat, [1]] # cat head P4 40 | - [-1, 3, C2f, [512]] # 18 (P4/16-medium) 41 | 42 | - [-1, 1, Conv, [512, 3, 2]] 43 | - [[-1, 9], 1, Concat, [1]] # cat head P5 44 | - [-1, 3, C2f, [1024]] # 21 (P5/32-large) 45 | 46 | - [[15, 18, 21], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5) 47 | -------------------------------------------------------------------------------- /ultralytics/models/v8/yolov8-seg.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # YOLOv8-seg instance segmentation model. For Usage examples see https://docs.ultralytics.com/tasks/segment 3 | 4 | # Parameters 5 | nc: 80 # number of classes 6 | scales: # model compound scaling constants, i.e. 
'model=yolov8n-seg.yaml' will call yolov8-seg.yaml with scale 'n' 7 | # [depth, width, max_channels] 8 | n: [0.33, 0.25, 1024] 9 | s: [0.33, 0.50, 1024] 10 | m: [0.67, 0.75, 768] 11 | l: [1.00, 1.00, 512] 12 | x: [1.00, 1.25, 512] 13 | 14 | # YOLOv8.0n backbone 15 | backbone: 16 | # [from, repeats, module, args] 17 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 18 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 19 | - [-1, 3, C2f, [128, True]] 20 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 21 | - [-1, 6, C2f, [256, True]] 22 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 23 | - [-1, 6, C2f, [512, True]] 24 | - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 25 | - [-1, 3, C2f, [1024, True]] 26 | - [-1, 1, SPPF, [1024, 5]] # 9 27 | 28 | # YOLOv8.0n head 29 | head: 30 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 31 | - [[-1, 6], 1, Concat, [1]] # cat backbone P4 32 | - [-1, 3, C2f, [512]] # 12 33 | 34 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 35 | - [[-1, 4], 1, Concat, [1]] # cat backbone P3 36 | - [-1, 3, C2f, [256]] # 15 (P3/8-small) 37 | 38 | - [-1, 1, Conv, [256, 3, 2]] 39 | - [[-1, 12], 1, Concat, [1]] # cat head P4 40 | - [-1, 3, C2f, [512]] # 18 (P4/16-medium) 41 | 42 | - [-1, 1, Conv, [512, 3, 2]] 43 | - [[-1, 9], 1, Concat, [1]] # cat head P5 44 | - [-1, 3, C2f, [1024]] # 21 (P5/32-large) 45 | 46 | - [[15, 18, 21], 1, Segment, [nc, 32, 256]] # Segment(P3, P4, P5) 47 | -------------------------------------------------------------------------------- /ultralytics/models/v8/yolov8.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect 3 | 4 | # Parameters 5 | nc: 80 # number of classes 6 | scales: # model compound scaling constants, i.e. 
'model=yolov8n.yaml' will call yolov8.yaml with scale 'n' 7 | # [depth, width, max_channels] 8 | n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs 9 | s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs 10 | m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs 11 | l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs 12 | x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs 13 | 14 | # YOLOv8.0n backbone 15 | backbone: 16 | # [from, repeats, module, args] 17 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 18 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 19 | - [-1, 3, C2f, [128, True]] 20 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 21 | - [-1, 6, C2f, [256, True]] 22 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 23 | - [-1, 6, C2f, [512, True]] 24 | - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 25 | - [-1, 3, C2f, [1024, True]] 26 | - [-1, 1, SPPF, [1024, 5]] # 9 27 | 28 | # YOLOv8.0n head 29 | head: 30 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 31 | - [[-1, 6], 1, Concat, [1]] # cat backbone P4 32 | - [-1, 3, C2f, [512]] # 12 33 | 34 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 35 | - [[-1, 4], 1, Concat, [1]] # cat backbone P3 36 | - [-1, 3, C2f, [256]] # 15 (P3/8-small) 37 | 38 | - [-1, 1, Conv, [256, 3, 2]] 39 | - [[-1, 12], 1, Concat, [1]] # cat head P4 40 | - [-1, 3, C2f, [512]] # 18 (P4/16-medium) 41 | 42 | - [-1, 1, Conv, [512, 3, 2]] 43 | - [[-1, 9], 1, Concat, [1]] # cat head P5 44 | - [-1, 3, C2f, [1024]] # 21 (P5/32-large) 45 | 46 | - [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5) 47 | -------------------------------------------------------------------------------- /ultralytics/nn/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from .tasks import (BaseModel, ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight, 4 | attempt_load_weights, guess_model_scale, guess_model_task, parse_model, torch_safe_load, 5 | yaml_model_load) 6 | 7 | __all__ = ('attempt_load_one_weight', 'attempt_load_weights', 'parse_model', 'yaml_model_load', 'guess_model_task', 8 | 'guess_model_scale', 'torch_safe_load', 'DetectionModel', 'SegmentationModel', 'ClassificationModel', 9 | 'BaseModel') 10 | -------------------------------------------------------------------------------- /ultralytics/nn/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | """ 3 | Ultralytics modules. 
Visualize with: 4 | 5 | from ultralytics.nn.modules import * 6 | import torch 7 | import os 8 | 9 | x = torch.ones(1, 128, 40, 40) 10 | m = Conv(128, 128) 11 | f = f'{m._get_name()}.onnx' 12 | torch.onnx.export(m, x, f) 13 | os.system(f'onnxsim {f} {f} && open {f}') 14 | """ 15 | 16 | from .block import (C1, C2, C3, C3TR, DFL, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, GhostBottleneck, 17 | HGBlock, HGStem, Proto, RepC3) 18 | from .conv import (CBAM, ChannelAttention, Concat, Conv, Conv2, ConvTranspose, DWConv, DWConvTranspose2d, Focus, 19 | GhostConv, LightConv, RepConv, SpatialAttention) 20 | from .head import Classify, Detect, Pose, RTDETRDecoder, Segment 21 | from .transformer import (AIFI, MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer, LayerNorm2d, 22 | MLPBlock, MSDeformAttn, TransformerBlock, TransformerEncoderLayer, TransformerLayer) 23 | 24 | __all__ = ('Conv', 'Conv2', 'LightConv', 'RepConv', 'DWConv', 'DWConvTranspose2d', 'ConvTranspose', 'Focus', 25 | 'GhostConv', 'ChannelAttention', 'SpatialAttention', 'CBAM', 'Concat', 'TransformerLayer', 26 | 'TransformerBlock', 'MLPBlock', 'LayerNorm2d', 'DFL', 'HGBlock', 'HGStem', 'SPP', 'SPPF', 'C1', 'C2', 'C3', 27 | 'C2f', 'C3x', 'C3TR', 'C3Ghost', 'GhostBottleneck', 'Bottleneck', 'BottleneckCSP', 'Proto', 'Detect', 28 | 'Segment', 'Pose', 'Classify', 'TransformerEncoderLayer', 'RepC3', 'RTDETRDecoder', 'AIFI', 29 | 'DeformableTransformerDecoder', 'DeformableTransformerDecoderLayer', 'MSDeformAttn', 'MLP') 30 | -------------------------------------------------------------------------------- /ultralytics/nn/modules/utils.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | """ 3 | Module utils 4 | """ 5 | 6 | import copy 7 | import math 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from torch.nn.init import uniform_ 14 | 15 | __all__ = 'multi_scale_deformable_attn_pytorch', 'inverse_sigmoid' 16 | 17 | 18 | def _get_clones(module, n): 19 | return nn.ModuleList([copy.deepcopy(module) for _ in range(n)]) 20 | 21 | 22 | def bias_init_with_prob(prior_prob=0.01): 23 | """initialize conv/fc bias value according to a given probability value.""" 24 | return float(-np.log((1 - prior_prob) / prior_prob)) # return bias_init 25 | 26 | 27 | def linear_init_(module): 28 | bound = 1 / math.sqrt(module.weight.shape[0]) 29 | uniform_(module.weight, -bound, bound) 30 | if hasattr(module, 'bias') and module.bias is not None: 31 | uniform_(module.bias, -bound, bound) 32 | 33 | 34 | def inverse_sigmoid(x, eps=1e-5): 35 | x = x.clamp(min=0, max=1) 36 | x1 = x.clamp(min=eps) 37 | x2 = (1 - x).clamp(min=eps) 38 | return torch.log(x1 / x2) 39 | 40 | 41 | def multi_scale_deformable_attn_pytorch(value: torch.Tensor, value_spatial_shapes: torch.Tensor, 42 | sampling_locations: torch.Tensor, 43 | attention_weights: torch.Tensor) -> torch.Tensor: 44 | """ 45 | Multi-scale deformable attention. 
46 | https://github.com/IDEA-Research/detrex/blob/main/detrex/layers/multi_scale_deform_attn.py 47 | """ 48 | 49 | bs, _, num_heads, embed_dims = value.shape 50 | _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape 51 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) 52 | sampling_grids = 2 * sampling_locations - 1 53 | sampling_value_list = [] 54 | for level, (H_, W_) in enumerate(value_spatial_shapes): 55 | # bs, H_*W_, num_heads, embed_dims -> 56 | # bs, H_*W_, num_heads*embed_dims -> 57 | # bs, num_heads*embed_dims, H_*W_ -> 58 | # bs*num_heads, embed_dims, H_, W_ 59 | value_l_ = (value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_)) 60 | # bs, num_queries, num_heads, num_points, 2 -> 61 | # bs, num_heads, num_queries, num_points, 2 -> 62 | # bs*num_heads, num_queries, num_points, 2 63 | sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1) 64 | # bs*num_heads, embed_dims, num_queries, num_points 65 | sampling_value_l_ = F.grid_sample(value_l_, 66 | sampling_grid_l_, 67 | mode='bilinear', 68 | padding_mode='zeros', 69 | align_corners=False) 70 | sampling_value_list.append(sampling_value_l_) 71 | # (bs, num_queries, num_heads, num_levels, num_points) -> 72 | # (bs, num_heads, num_queries, num_levels, num_points) -> 73 | # (bs, num_heads, 1, num_queries, num_levels*num_points) 74 | attention_weights = attention_weights.transpose(1, 2).reshape(bs * num_heads, 1, num_queries, 75 | num_levels * num_points) 76 | output = ((torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view( 77 | bs, num_heads * embed_dims, num_queries)) 78 | return output.transpose(1, 2).contiguous() 79 | -------------------------------------------------------------------------------- /ultralytics/tracker/README.md: -------------------------------------------------------------------------------- 1 | # Tracker 2 | 3 | ## Supported Trackers 4 | 5 | - [x] ByteTracker 6 | - [x] BoT-SORT 7 | 8 | ## Usage 9 | 10 | ### Python interface 11 | 12 | You can use the Python interface to track objects using the YOLO model. 13 | 14 | ```python 15 | from ultralytics import YOLO 16 | 17 | model = YOLO("yolov8n.pt") # or a segmentation model, e.g. yolov8n-seg.pt 18 | model.track( 19 | source="video/streams", 20 | stream=True, 21 | tracker="botsort.yaml", # or 'bytetrack.yaml' 22 | show=True, 23 | ) 24 | ``` 25 | 26 | You can get the IDs of the tracked objects using the following code: 27 | 28 | ```python 29 | from ultralytics import YOLO 30 | 31 | model = YOLO("yolov8n.pt") 32 | 33 | for result in model.track(source="video.mp4"): 34 | print( 35 | result.boxes.id.cpu().numpy().astype(int) 36 | ) # this will print the IDs of the tracked objects in the frame 37 | ``` 38 | 39 | If you run the tracker on a folder of images, or loop over video frames yourself, pass the `persist` parameter so the model knows that consecutive calls belong to the same sequence and keeps the IDs stable for the same objects. Without `persist`, a new tracker is created for every call, so the IDs change from frame to frame.
40 | 41 | ```python 42 | import cv2 43 | from ultralytics import YOLO 44 | 45 | cap = cv2.VideoCapture("video.mp4") 46 | model = YOLO("yolov8n.pt") 47 | while True: 48 | ret, frame = cap.read() 49 | if not ret: 50 | break 51 | results = model.track(frame, persist=True) 52 | boxes = results[0].boxes.xyxy.cpu().numpy().astype(int) 53 | ids = results[0].boxes.id.cpu().numpy().astype(int) 54 | for box, id in zip(boxes, ids): 55 | cv2.rectangle(frame, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), 2) 56 | cv2.putText( 57 | frame, 58 | f"Id {id}", 59 | (box[0], box[1]), 60 | cv2.FONT_HERSHEY_SIMPLEX, 61 | 1, 62 | (0, 0, 255), 63 | 2, 64 | ) 65 | cv2.imshow("frame", frame) 66 | if cv2.waitKey(1) & 0xFF == ord("q"): 67 | break 68 | ``` 69 | 70 | ## Change tracker parameters 71 | 72 | You can change the tracker parameters by editing the corresponding config file (`botsort.yaml` or `bytetrack.yaml`) in the `ultralytics/tracker/cfg` folder. 73 | 74 | ## Command Line Interface (CLI) 75 | 76 | You can also use the command line interface to track objects using the YOLO model. 77 | 78 | ```bash 79 | yolo detect track source=... tracker=... 80 | yolo segment track source=... tracker=... 81 | yolo pose track source=... tracker=... 82 | ``` 83 | 84 | By default, trackers use the configurations in `ultralytics/tracker/cfg`. 85 | You can also pass a modified copy of one of these config files to the `tracker` argument, 86 | as in the sketch below.
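A minimal sketch, assuming you have saved a modified copy of `bytetrack.yaml` as `custom_tracker.yaml` next to your script:

```python
from ultralytics import YOLO

model = YOLO("yolov8n.pt")
# Point the tracker argument at the custom config instead of the bundled botsort.yaml/bytetrack.yaml
model.track(source="video.mp4", tracker="custom_tracker.yaml", show=True)
```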
87 | -------------------------------------------------------------------------------- /ultralytics/tracker/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from .track import register_tracker 4 | from .trackers import BOTSORT, BYTETracker 5 | 6 | __all__ = 'register_tracker', 'BOTSORT', 'BYTETracker' # allow simpler import 7 | -------------------------------------------------------------------------------- /ultralytics/tracker/cfg/botsort.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # Default YOLO tracker settings for BoT-SORT tracker https://github.com/NirAharon/BoT-SORT 3 | 4 | tracker_type: botsort # tracker type, ['botsort', 'bytetrack'] 5 | track_high_thresh: 0.5 # threshold for the first association 6 | track_low_thresh: 0.1 # threshold for the second association 7 | new_track_thresh: 0.6 # threshold for init new track if the detection does not match any tracks 8 | track_buffer: 30 # buffer to calculate the time when to remove tracks 9 | match_thresh: 0.8 # threshold for matching tracks 10 | # min_box_area: 10 # threshold for min box areas(for tracker evaluation, not used for now) 11 | # mot20: False # for tracker evaluation(not used for now) 12 | 13 | # BoT-SORT settings 14 | cmc_method: sparseOptFlow # method of global motion compensation 15 | # ReID model related thresh (not supported yet) 16 | proximity_thresh: 0.5 17 | appearance_thresh: 0.25 18 | with_reid: False 19 | -------------------------------------------------------------------------------- /ultralytics/tracker/cfg/bytetrack.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # Default YOLO tracker settings for ByteTrack tracker https://github.com/ifzhang/ByteTrack 3 | 4 | tracker_type: bytetrack # tracker type, ['botsort', 'bytetrack'] 5 | track_high_thresh: 0.5 # threshold for the first association 6 | track_low_thresh: 0.1 # threshold for the second association 7 | new_track_thresh: 0.6 # threshold for init new track if the detection does not match any tracks 8 | track_buffer: 30 # buffer to calculate the time when to remove tracks 9 | match_thresh: 0.8 # threshold for matching tracks 10 | # min_box_area: 10 # threshold for min box areas(for tracker evaluation, not used for now) 11 | # mot20: False # for tracker evaluation(not used for now) 12 | -------------------------------------------------------------------------------- /ultralytics/tracker/track.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from functools import partial 4 | 5 | import torch 6 | 7 | from ultralytics.yolo.utils import IterableSimpleNamespace, yaml_load 8 | from ultralytics.yolo.utils.checks import check_yaml 9 | 10 | from .trackers import BOTSORT, BYTETracker 11 | 12 | TRACKER_MAP = {'bytetrack': BYTETracker, 'botsort': BOTSORT} 13 | 14 | 15 | def on_predict_start(predictor, persist=False): 16 | """ 17 | Initialize trackers for object tracking during prediction. 18 | 19 | Args: 20 | predictor (object): The predictor object to initialize trackers for. 21 | persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False. 22 | 23 | Raises: 24 | AssertionError: If the tracker_type is not 'bytetrack' or 'botsort'. 
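Note:
    This function is not usually called directly; `register_tracker()` at the bottom of this file attaches it
    as the 'on_predict_start' callback, so it runs automatically when `model.track(...)` begins predicting.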
25 | """ 26 | if hasattr(predictor, 'trackers') and persist: 27 | return 28 | tracker = check_yaml(predictor.args.tracker) 29 | cfg = IterableSimpleNamespace(**yaml_load(tracker)) 30 | assert cfg.tracker_type in ['bytetrack', 'botsort'], \ 31 | f"Only support 'bytetrack' and 'botsort' for now, but got '{cfg.tracker_type}'" 32 | trackers = [] 33 | for _ in range(predictor.dataset.bs): 34 | tracker = TRACKER_MAP[cfg.tracker_type](args=cfg, frame_rate=30) 35 | trackers.append(tracker) 36 | predictor.trackers = trackers 37 | 38 | 39 | def on_predict_postprocess_end(predictor): 40 | """Postprocess detected boxes and update with object tracking.""" 41 | bs = predictor.dataset.bs 42 | im0s = predictor.batch[1] 43 | for i in range(bs): 44 | det = predictor.results[i].boxes.cpu().numpy() 45 | if len(det) == 0: 46 | continue 47 | tracks = predictor.trackers[i].update(det, im0s[i]) 48 | if len(tracks) == 0: 49 | continue 50 | idx = tracks[:, -1].astype(int) 51 | predictor.results[i] = predictor.results[i][idx] 52 | predictor.results[i].update(boxes=torch.as_tensor(tracks[:, :-1])) 53 | 54 | 55 | def register_tracker(model, persist): 56 | """ 57 | Register tracking callbacks to the model for object tracking during prediction. 58 | 59 | Args: 60 | model (object): The model object to register tracking callbacks for. 61 | persist (bool): Whether to persist the trackers if they already exist. 62 | 63 | """ 64 | model.add_callback('on_predict_start', partial(on_predict_start, persist=persist)) 65 | model.add_callback('on_predict_postprocess_end', on_predict_postprocess_end) 66 | -------------------------------------------------------------------------------- /ultralytics/tracker/trackers/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from .bot_sort import BOTSORT 4 | from .byte_tracker import BYTETracker 5 | 6 | __all__ = 'BOTSORT', 'BYTETracker' # allow simpler import 7 | -------------------------------------------------------------------------------- /ultralytics/tracker/trackers/basetrack.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from collections import OrderedDict 4 | 5 | import numpy as np 6 | 7 | 8 | class TrackState: 9 | """Enumeration of possible object tracking states.""" 10 | 11 | New = 0 12 | Tracked = 1 13 | Lost = 2 14 | Removed = 3 15 | 16 | 17 | class BaseTrack: 18 | """Base class for object tracking, handling basic track attributes and operations.""" 19 | 20 | _count = 0 21 | 22 | track_id = 0 23 | is_activated = False 24 | state = TrackState.New 25 | 26 | history = OrderedDict() 27 | features = [] 28 | curr_feature = None 29 | score = 0 30 | start_frame = 0 31 | frame_id = 0 32 | time_since_update = 0 33 | 34 | # Multi-camera 35 | location = (np.inf, np.inf) 36 | 37 | @property 38 | def end_frame(self): 39 | """Return the last frame ID of the track.""" 40 | return self.frame_id 41 | 42 | @staticmethod 43 | def next_id(): 44 | """Increment and return the global track ID counter.""" 45 | BaseTrack._count += 1 46 | return BaseTrack._count 47 | 48 | def activate(self, *args): 49 | """Activate the track with the provided arguments.""" 50 | raise NotImplementedError 51 | 52 | def predict(self): 53 | """Predict the next state of the track.""" 54 | raise NotImplementedError 55 | 56 | def update(self, *args, **kwargs): 57 | """Update the track with new observations.""" 58 | raise NotImplementedError 59 | 60 | 
def mark_lost(self): 61 | """Mark the track as lost.""" 62 | self.state = TrackState.Lost 63 | 64 | def mark_removed(self): 65 | """Mark the track as removed.""" 66 | self.state = TrackState.Removed 67 | 68 | @staticmethod 69 | def reset_id(): 70 | """Reset the global track ID counter.""" 71 | BaseTrack._count = 0 72 | -------------------------------------------------------------------------------- /ultralytics/tracker/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/ultralytics/tracker/utils/__init__.py -------------------------------------------------------------------------------- /ultralytics/vit/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from .rtdetr import RTDETR 4 | from .sam import SAM 5 | 6 | __all__ = 'RTDETR', 'SAM' # allow simpler import 7 | -------------------------------------------------------------------------------- /ultralytics/vit/rtdetr/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from .model import RTDETR 4 | from .predict import RTDETRPredictor 5 | from .val import RTDETRValidator 6 | 7 | __all__ = 'RTDETRPredictor', 'RTDETRValidator', 'RTDETR' 8 | -------------------------------------------------------------------------------- /ultralytics/vit/rtdetr/predict.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | import torch 4 | 5 | from ultralytics.yolo.data.augment import LetterBox 6 | from ultralytics.yolo.engine.predictor import BasePredictor 7 | from ultralytics.yolo.engine.results import Results 8 | from ultralytics.yolo.utils import ops 9 | 10 | 11 | class RTDETRPredictor(BasePredictor): 12 | 13 | def postprocess(self, preds, img, orig_imgs): 14 | """Postprocess predictions and returns a list of Results objects.""" 15 | bboxes, scores = preds[:2] # (1, bs, 300, 4), (1, bs, 300, nc) 16 | bboxes, scores = bboxes.squeeze_(0), scores.squeeze_(0) 17 | results = [] 18 | for i, bbox in enumerate(bboxes): # (300, 4) 19 | bbox = ops.xywh2xyxy(bbox) 20 | score, cls = scores[i].max(-1, keepdim=True) # (300, 1) 21 | idx = score.squeeze(-1) > self.args.conf # (300, ) 22 | if self.args.classes is not None: 23 | idx = (cls == torch.tensor(self.args.classes, device=cls.device)).any(1) & idx 24 | pred = torch.cat([bbox, score, cls], dim=-1)[idx] # filter 25 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs 26 | oh, ow = orig_img.shape[:2] 27 | if not isinstance(orig_imgs, torch.Tensor): 28 | pred[..., [0, 2]] *= ow 29 | pred[..., [1, 3]] *= oh 30 | path = self.batch[0] 31 | img_path = path[i] if isinstance(path, list) else path 32 | results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred)) 33 | return results 34 | 35 | def pre_transform(self, im): 36 | """Pre-transform input image before inference. 37 | 38 | Args: 39 | im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list. 40 | 41 | Return: A list of transformed imgs. 42 | """ 43 | # The size must be square(640) and scaleFilled. 
44 | return [LetterBox(self.imgsz, auto=False, scaleFill=True)(image=x) for x in im] 45 | -------------------------------------------------------------------------------- /ultralytics/vit/rtdetr/train.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from copy import copy 4 | 5 | import torch 6 | 7 | from ultralytics.nn.tasks import RTDETRDetectionModel 8 | from ultralytics.yolo.utils import DEFAULT_CFG, RANK, colorstr 9 | from ultralytics.yolo.v8.detect import DetectionTrainer 10 | 11 | from .val import RTDETRDataset, RTDETRValidator 12 | 13 | 14 | class RTDETRTrainer(DetectionTrainer): 15 | 16 | def get_model(self, cfg=None, weights=None, verbose=True): 17 | """Return a YOLO detection model.""" 18 | model = RTDETRDetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1) 19 | if weights: 20 | model.load(weights) 21 | return model 22 | 23 | def build_dataset(self, img_path, mode='val', batch=None): 24 | """Build RTDETR Dataset 25 | 26 | Args: 27 | img_path (str): Path to the folder containing images. 28 | mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode. 29 | batch (int, optional): Size of batches, this is for `rect`. Defaults to None. 30 | """ 31 | return RTDETRDataset( 32 | img_path=img_path, 33 | imgsz=self.args.imgsz, 34 | batch_size=batch, 35 | augment=mode == 'train', # no augmentation 36 | hyp=self.args, 37 | rect=False, # no rect 38 | cache=self.args.cache or None, 39 | prefix=colorstr(f'{mode}: '), 40 | data=self.data) 41 | 42 | def get_validator(self): 43 | """Returns a DetectionValidator for RTDETR model validation.""" 44 | self.loss_names = 'giou_loss', 'cls_loss', 'l1_loss' 45 | return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args)) 46 | 47 | def preprocess_batch(self, batch): 48 | """Preprocesses a batch of images by scaling and converting to float.""" 49 | batch = super().preprocess_batch(batch) 50 | bs = len(batch['img']) 51 | batch_idx = batch['batch_idx'] 52 | gt_bbox, gt_class = [], [] 53 | for i in range(bs): 54 | gt_bbox.append(batch['bboxes'][batch_idx == i].to(batch_idx.device)) 55 | gt_class.append(batch['cls'][batch_idx == i].to(device=batch_idx.device, dtype=torch.long)) 56 | return batch 57 | 58 | 59 | def train(cfg=DEFAULT_CFG, use_python=False): 60 | """Train and optimize RTDETR model given training data and device.""" 61 | model = 'rtdetr-l.yaml' 62 | data = cfg.data or 'coco128.yaml' # or yolo.ClassificationDataset("mnist") 63 | device = cfg.device if cfg.device is not None else '' 64 | 65 | # NOTE: F.grid_sample which is in rt-detr does not support deterministic=True 66 | # NOTE: amp training causes nan outputs and end with error while doing bipartite graph matching 67 | args = dict(model=model, 68 | data=data, 69 | device=device, 70 | imgsz=640, 71 | exist_ok=True, 72 | batch=4, 73 | deterministic=False, 74 | amp=False) 75 | trainer = RTDETRTrainer(overrides=args) 76 | trainer.train() 77 | 78 | 79 | if __name__ == '__main__': 80 | train() 81 | -------------------------------------------------------------------------------- /ultralytics/vit/sam/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from .build import build_sam # noqa 4 | from .model import SAM # noqa 5 | from .modules.prompt_predictor import PromptPredictor # noqa 6 | 
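These exports give two entry points: the high-level `SAM` wrapper (see `model.py` below) and the lower-level `build_sam` + `PromptPredictor` pair used by `yolo/data/annotator.py`. A minimal sketch, assuming a local `sam_b.pt` checkpoint is available:

```python
import cv2

from ultralytics.vit.sam import SAM, PromptPredictor, build_sam

# High-level interface: segment everything in a single image
model = SAM('sam_b.pt')
results = model('images/dogs.jpg')

# Lower-level interface: set an image once, then prompt it
# (auto_annotate in yolo/data/annotator.py prompts with boxes this way)
predictor = PromptPredictor(build_sam('sam_b.pt'))
predictor.set_image(cv2.imread('images/dogs.jpg'))
```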
-------------------------------------------------------------------------------- /ultralytics/vit/sam/autosize.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | from copy import deepcopy 10 | from typing import Tuple 11 | 12 | import numpy as np 13 | import torch 14 | from torch.nn import functional as F 15 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 16 | 17 | 18 | class ResizeLongestSide: 19 | """ 20 | Resizes images to the longest side 'target_length', as well as provides 21 | methods for resizing coordinates and boxes. Provides methods for 22 | transforming both numpy array and batched torch tensors. 23 | """ 24 | 25 | def __init__(self, target_length: int) -> None: 26 | self.target_length = target_length 27 | 28 | def apply_image(self, image: np.ndarray) -> np.ndarray: 29 | """ 30 | Expects a numpy array with shape HxWxC in uint8 format. 31 | """ 32 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 33 | return np.array(resize(to_pil_image(image), target_size)) 34 | 35 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 36 | """ 37 | Expects a numpy array of length 2 in the final dimension. Requires the 38 | original image size in (H, W) format. 39 | """ 40 | old_h, old_w = original_size 41 | new_h, new_w = self.get_preprocess_shape(original_size[0], original_size[1], self.target_length) 42 | coords = deepcopy(coords).astype(float) 43 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 44 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 45 | return coords 46 | 47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 48 | """ 49 | Expects a numpy array shape Bx4. Requires the original image size 50 | in (H, W) format. 51 | """ 52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 53 | return boxes.reshape(-1, 4) 54 | 55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Expects batched images with shape BxCxHxW and float format. This 58 | transformation may not exactly match apply_image. apply_image is 59 | the transformation expected by the model. 60 | """ 61 | # Expects an image in BCHW format. May not exactly match apply_image. 62 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) 63 | return F.interpolate(image, target_size, mode='bilinear', align_corners=False, antialias=True) 64 | 65 | def apply_coords_torch(self, coords: torch.Tensor, original_size: Tuple[int, ...]) -> torch.Tensor: 66 | """ 67 | Expects a torch tensor with length 2 in the last dimension. Requires the 68 | original image size in (H, W) format. 69 | """ 70 | old_h, old_w = original_size 71 | new_h, new_w = self.get_preprocess_shape(original_size[0], original_size[1], self.target_length) 72 | coords = deepcopy(coords).to(torch.float) 73 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 74 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 75 | return coords 76 | 77 | def apply_boxes_torch(self, boxes: torch.Tensor, original_size: Tuple[int, ...]) -> torch.Tensor: 78 | """ 79 | Expects a torch tensor with shape Bx4. 
Requires the original image 80 | size in (H, W) format. 81 | """ 82 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 83 | return boxes.reshape(-1, 4) 84 | 85 | @staticmethod 86 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 87 | """ 88 | Compute the output size given input size and target long side length. 89 | """ 90 | scale = long_side_length * 1.0 / max(oldh, oldw) 91 | newh, neww = oldh * scale, oldw * scale 92 | neww = int(neww + 0.5) 93 | newh = int(newh + 0.5) 94 | return (newh, neww) 95 | -------------------------------------------------------------------------------- /ultralytics/vit/sam/build.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | from functools import partial 10 | 11 | import torch 12 | 13 | from ...yolo.utils.downloads import attempt_download_asset 14 | from .modules.decoders import MaskDecoder 15 | from .modules.encoders import ImageEncoderViT, PromptEncoder 16 | from .modules.sam import Sam 17 | from .modules.transformer import TwoWayTransformer 18 | 19 | 20 | def build_sam_vit_h(checkpoint=None): 21 | """Build and return a Segment Anything Model (SAM) h-size model.""" 22 | return _build_sam( 23 | encoder_embed_dim=1280, 24 | encoder_depth=32, 25 | encoder_num_heads=16, 26 | encoder_global_attn_indexes=[7, 15, 23, 31], 27 | checkpoint=checkpoint, 28 | ) 29 | 30 | 31 | def build_sam_vit_l(checkpoint=None): 32 | """Build and return a Segment Anything Model (SAM) l-size model.""" 33 | return _build_sam( 34 | encoder_embed_dim=1024, 35 | encoder_depth=24, 36 | encoder_num_heads=16, 37 | encoder_global_attn_indexes=[5, 11, 17, 23], 38 | checkpoint=checkpoint, 39 | ) 40 | 41 | 42 | def build_sam_vit_b(checkpoint=None): 43 | """Build and return a Segment Anything Model (SAM) b-size model.""" 44 | return _build_sam( 45 | encoder_embed_dim=768, 46 | encoder_depth=12, 47 | encoder_num_heads=12, 48 | encoder_global_attn_indexes=[2, 5, 8, 11], 49 | checkpoint=checkpoint, 50 | ) 51 | 52 | 53 | def _build_sam( 54 | encoder_embed_dim, 55 | encoder_depth, 56 | encoder_num_heads, 57 | encoder_global_attn_indexes, 58 | checkpoint=None, 59 | ): 60 | """Builds the selected SAM model architecture.""" 61 | prompt_embed_dim = 256 62 | image_size = 1024 63 | vit_patch_size = 16 64 | image_embedding_size = image_size // vit_patch_size 65 | sam = Sam( 66 | image_encoder=ImageEncoderViT( 67 | depth=encoder_depth, 68 | embed_dim=encoder_embed_dim, 69 | img_size=image_size, 70 | mlp_ratio=4, 71 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 72 | num_heads=encoder_num_heads, 73 | patch_size=vit_patch_size, 74 | qkv_bias=True, 75 | use_rel_pos=True, 76 | global_attn_indexes=encoder_global_attn_indexes, 77 | window_size=14, 78 | out_chans=prompt_embed_dim, 79 | ), 80 | prompt_encoder=PromptEncoder( 81 | embed_dim=prompt_embed_dim, 82 | image_embedding_size=(image_embedding_size, image_embedding_size), 83 | input_image_size=(image_size, image_size), 84 | mask_in_chans=16, 85 | ), 86 | mask_decoder=MaskDecoder( 87 | num_multimask_outputs=3, 88 | transformer=TwoWayTransformer( 89 | depth=2, 90 | embedding_dim=prompt_embed_dim, 91 | mlp_dim=2048, 92 | num_heads=8, 93 | ), 94 | transformer_dim=prompt_embed_dim, 95 | 
iou_head_depth=3, 96 | iou_head_hidden_dim=256, 97 | ), 98 | pixel_mean=[123.675, 116.28, 103.53], 99 | pixel_std=[58.395, 57.12, 57.375], 100 | ) 101 | sam.eval() 102 | if checkpoint is not None: 103 | attempt_download_asset(checkpoint) 104 | with open(checkpoint, 'rb') as f: 105 | state_dict = torch.load(f) 106 | sam.load_state_dict(state_dict) 107 | return sam 108 | 109 | 110 | sam_model_map = { 111 | # "default": build_sam_vit_h, 112 | 'sam_h.pt': build_sam_vit_h, 113 | 'sam_l.pt': build_sam_vit_l, 114 | 'sam_b.pt': build_sam_vit_b, } 115 | 116 | 117 | def build_sam(ckpt='sam_b.pt'): 118 | """Build a SAM model specified by ckpt.""" 119 | model_builder = None 120 | for k in sam_model_map.keys(): 121 | if ckpt.endswith(k): 122 | model_builder = sam_model_map.get(k) 123 | 124 | if not model_builder: 125 | raise FileNotFoundError(f'{ckpt} is not a supported sam model. Available models are: \n {sam_model_map.keys()}') 126 | 127 | return model_builder(ckpt) 128 | -------------------------------------------------------------------------------- /ultralytics/vit/sam/model.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | """ 3 | SAM model interface 4 | """ 5 | 6 | from ultralytics.yolo.cfg import get_cfg 7 | 8 | from ...yolo.utils.torch_utils import model_info 9 | from .build import build_sam 10 | from .predict import Predictor 11 | 12 | 13 | class SAM: 14 | 15 | def __init__(self, model='sam_b.pt') -> None: 16 | if model and not model.endswith('.pt') and not model.endswith('.pth'): 17 | # Should raise AssertionError instead? 18 | raise NotImplementedError('Segment anything prediction requires pre-trained checkpoint') 19 | self.model = build_sam(model) 20 | self.task = 'segment' # required 21 | self.predictor = None # reuse predictor 22 | 23 | def predict(self, source, stream=False, **kwargs): 24 | """Predicts and returns segmentation masks for given image or video source.""" 25 | overrides = dict(conf=0.25, task='segment', mode='predict') 26 | overrides.update(kwargs) # prefer kwargs 27 | if not self.predictor: 28 | self.predictor = Predictor(overrides=overrides) 29 | self.predictor.setup_model(model=self.model) 30 | else: # only update args if predictor is already setup 31 | self.predictor.args = get_cfg(self.predictor.args, overrides) 32 | return self.predictor(source, stream=stream) 33 | 34 | def train(self, **kwargs): 35 | """Function trains models but raises an error as SAM models do not support training.""" 36 | raise NotImplementedError("SAM models don't support training") 37 | 38 | def val(self, **kwargs): 39 | """Run validation given dataset.""" 40 | raise NotImplementedError("SAM models don't support validation") 41 | 42 | def __call__(self, source=None, stream=False, **kwargs): 43 | """Calls the 'predict' function with given arguments to perform object detection.""" 44 | return self.predict(source, stream, **kwargs) 45 | 46 | def __getattr__(self, attr): 47 | """Raises error if object has no requested attribute.""" 48 | name = self.__class__.__name__ 49 | raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}") 50 | 51 | def info(self, detailed=False, verbose=True): 52 | """ 53 | Logs model info. 54 | 55 | Args: 56 | detailed (bool): Show detailed information about model. 57 | verbose (bool): Controls verbosity. 
58 | """ 59 | return model_info(self.model, detailed=detailed, verbose=verbose) 60 | -------------------------------------------------------------------------------- /ultralytics/vit/sam/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | -------------------------------------------------------------------------------- /ultralytics/vit/sam/predict.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from ultralytics.yolo.engine.predictor import BasePredictor 7 | from ultralytics.yolo.engine.results import Results 8 | from ultralytics.yolo.utils.torch_utils import select_device 9 | 10 | from .modules.mask_generator import SamAutomaticMaskGenerator 11 | 12 | 13 | class Predictor(BasePredictor): 14 | 15 | def preprocess(self, im): 16 | """Prepares input image for inference.""" 17 | # TODO: Only support bs=1 for now 18 | # im = ResizeLongestSide(1024).apply_image(im[0]) 19 | # im = torch.as_tensor(im, device=self.device) 20 | # im = im.permute(2, 0, 1).contiguous()[None, :, :, :] 21 | return im[0] 22 | 23 | def setup_model(self, model): 24 | """Set up YOLO model with specified thresholds and device.""" 25 | device = select_device(self.args.device) 26 | model.eval() 27 | self.model = SamAutomaticMaskGenerator(model.to(device), 28 | pred_iou_thresh=self.args.conf, 29 | box_nms_thresh=self.args.iou) 30 | self.device = device 31 | # TODO: Temporary settings for compatibility 32 | self.model.pt = False 33 | self.model.triton = False 34 | self.model.stride = 32 35 | self.model.fp16 = False 36 | self.done_warmup = True 37 | 38 | def postprocess(self, preds, path, orig_imgs): 39 | """Postprocesses inference output predictions to create detection masks for objects.""" 40 | names = dict(enumerate(list(range(len(preds))))) 41 | results = [] 42 | # TODO 43 | for i, pred in enumerate([preds]): 44 | masks = torch.from_numpy(np.stack([p['segmentation'] for p in pred], axis=0)) 45 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs 46 | path = self.batch[0] 47 | img_path = path[i] if isinstance(path, list) else path 48 | results.append(Results(orig_img=orig_img, path=img_path, names=names, masks=masks)) 49 | return results 50 | 51 | # def __call__(self, source=None, model=None, stream=False): 52 | # frame = cv2.imread(source) 53 | # preds = self.model.generate(frame) 54 | # return self.postprocess(preds, source, frame) 55 | -------------------------------------------------------------------------------- /ultralytics/vit/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | -------------------------------------------------------------------------------- /ultralytics/yolo/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from . 
import v8 4 | 5 | __all__ = 'v8', # tuple or list 6 | -------------------------------------------------------------------------------- /ultralytics/yolo/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from .base import BaseDataset 4 | from .build import build_dataloader, build_yolo_dataset, load_inference_source 5 | from .dataset import ClassificationDataset, SemanticDataset, YOLODataset 6 | from .dataset_wrappers import MixAndRectDataset 7 | 8 | __all__ = ('BaseDataset', 'ClassificationDataset', 'MixAndRectDataset', 'SemanticDataset', 'YOLODataset', 9 | 'build_yolo_dataset', 'build_dataloader', 'load_inference_source') 10 | -------------------------------------------------------------------------------- /ultralytics/yolo/data/annotator.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from ultralytics import YOLO 4 | from ultralytics.vit.sam import PromptPredictor, build_sam 5 | from ultralytics.yolo.utils.torch_utils import select_device 6 | 7 | 8 | def auto_annotate(data, det_model='yolov8x.pt', sam_model='sam_b.pt', device='', output_dir=None): 9 | """ 10 | Automatically annotates images using a YOLO object detection model and a SAM segmentation model. 11 | Args: 12 | data (str): Path to a folder containing images to be annotated. 13 | det_model (str, optional): Pre-trained YOLO detection model. Defaults to 'yolov8x.pt'. 14 | sam_model (str, optional): Pre-trained SAM segmentation model. Defaults to 'sam_b.pt'. 15 | device (str, optional): Device to run the models on. Defaults to an empty string (CPU or GPU, if available). 16 | output_dir (str | None | optional): Directory to save the annotated results. 17 | Defaults to a 'labels' folder in the same directory as 'data'. 
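    Example (illustrative; the image folder path is a placeholder):
        from ultralytics.yolo.data.annotator import auto_annotate

        auto_annotate(data='path/to/images', det_model='yolov8x.pt', sam_model='sam_b.pt')
        # writes one YOLO-format segmentation label .txt per image into 'path/to/labels'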
18 | """ 19 | device = select_device(device) 20 | det_model = YOLO(det_model) 21 | sam_model = build_sam(sam_model) 22 | det_model.to(device) 23 | sam_model.to(device) 24 | 25 | if not output_dir: 26 | output_dir = Path(str(data)).parent / 'labels' 27 | Path(output_dir).mkdir(exist_ok=True, parents=True) 28 | 29 | prompt_predictor = PromptPredictor(sam_model) 30 | det_results = det_model(data, stream=True) 31 | 32 | for result in det_results: 33 | boxes = result.boxes.xyxy # Boxes object for bbox outputs 34 | class_ids = result.boxes.cls.int().tolist() # noqa 35 | if len(class_ids): 36 | prompt_predictor.set_image(result.orig_img) 37 | masks, _, _ = prompt_predictor.predict_torch( 38 | point_coords=None, 39 | point_labels=None, 40 | boxes=prompt_predictor.transform.apply_boxes_torch(boxes, result.orig_shape[:2]), 41 | multimask_output=False, 42 | ) 43 | 44 | result.update(masks=masks.squeeze(1)) 45 | segments = result.masks.xyn # noqa 46 | 47 | with open(str(Path(output_dir) / Path(result.path).stem) + '.txt', 'w') as f: 48 | for i in range(len(segments)): 49 | s = segments[i] 50 | if len(s) == 0: 51 | continue 52 | segment = map(str, segments[i].reshape(-1).tolist()) 53 | f.write(f'{class_ids[i]} ' + ' '.join(segment) + '\n') 54 | -------------------------------------------------------------------------------- /ultralytics/yolo/data/dataloaders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/ultralytics/yolo/data/dataloaders/__init__.py -------------------------------------------------------------------------------- /ultralytics/yolo/data/dataset_wrappers.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | import collections 4 | from copy import deepcopy 5 | 6 | from .augment import LetterBox 7 | 8 | 9 | class MixAndRectDataset: 10 | """ 11 | A dataset class that applies mosaic and mixup transformations as well as rectangular training. 12 | 13 | Attributes: 14 | dataset: The base dataset. 15 | imgsz: The size of the images in the dataset. 16 | """ 17 | 18 | def __init__(self, dataset): 19 | """ 20 | Args: 21 | dataset (BaseDataset): The base dataset to apply transformations to. 22 | """ 23 | self.dataset = dataset 24 | self.imgsz = dataset.imgsz 25 | 26 | def __len__(self): 27 | """Returns the number of items in the dataset.""" 28 | return len(self.dataset) 29 | 30 | def __getitem__(self, index): 31 | """ 32 | Applies mosaic, mixup and rectangular training transformations to an item in the dataset. 33 | 34 | Args: 35 | index (int): Index of the item in the dataset. 36 | 37 | Returns: 38 | (dict): A dictionary containing the transformed item data. 
39 | """ 40 | labels = deepcopy(self.dataset[index]) 41 | for transform in self.dataset.transforms.tolist(): 42 | # Mosaic and mixup 43 | if hasattr(transform, 'get_indexes'): 44 | indexes = transform.get_indexes(self.dataset) 45 | if not isinstance(indexes, collections.abc.Sequence): 46 | indexes = [indexes] 47 | labels['mix_labels'] = [deepcopy(self.dataset[index]) for index in indexes] 48 | if self.dataset.rect and isinstance(transform, LetterBox): 49 | transform.new_shape = self.dataset.batch_shapes[self.dataset.batch[index]] 50 | labels = transform(labels) 51 | if 'mix_labels' in labels: 52 | labels.pop('mix_labels') 53 | return labels 54 | -------------------------------------------------------------------------------- /ultralytics/yolo/data/scripts/download_weights.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Ultralytics YOLO 🚀, AGPL-3.0 license 3 | # Download latest models from https://github.com/ultralytics/assets/releases 4 | # Example usage: bash ultralytics/yolo/data/scripts/download_weights.sh 5 | # parent 6 | # └── weights 7 | # ├── yolov8n.pt ← downloads here 8 | # ├── yolov8s.pt 9 | # └── ... 10 | 11 | python - < None: 29 | # Load or create new NAS model 30 | import super_gradients 31 | 32 | self.predictor = None 33 | suffix = Path(model).suffix 34 | if suffix == '.pt': 35 | self._load(model) 36 | elif suffix == '': 37 | self.model = super_gradients.training.models.get(model, pretrained_weights='coco') 38 | self.task = 'detect' 39 | self.model.args = DEFAULT_CFG_DICT # attach args to model 40 | 41 | # Standardize model 42 | self.model.fuse = lambda verbose=True: self.model 43 | self.model.stride = torch.tensor([32]) 44 | self.model.names = dict(enumerate(self.model._class_names)) 45 | self.model.is_fused = lambda: False # for info() 46 | self.model.yaml = {} # for info() 47 | self.model.pt_path = model # for export() 48 | self.model.task = 'detect' # for export() 49 | self.info() 50 | 51 | @smart_inference_mode() 52 | def _load(self, weights: str): 53 | self.model = torch.load(weights) 54 | 55 | @smart_inference_mode() 56 | def predict(self, source=None, stream=False, **kwargs): 57 | """ 58 | Perform prediction using the YOLO model. 59 | 60 | Args: 61 | source (str | int | PIL | np.ndarray): The source of the image to make predictions on. 62 | Accepts all source types accepted by the YOLO model. 63 | stream (bool): Whether to stream the predictions or not. Defaults to False. 64 | **kwargs : Additional keyword arguments passed to the predictor. 65 | Check the 'configuration' section in the documentation for all available options. 66 | 67 | Returns: 68 | (List[ultralytics.yolo.engine.results.Results]): The prediction results. 69 | """ 70 | if source is None: 71 | source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg' 72 | LOGGER.warning(f"WARNING ⚠️ 'source' is missing. 
Using 'source={source}'.") 73 | overrides = dict(conf=0.25, task='detect', mode='predict') 74 | overrides.update(kwargs) # prefer kwargs 75 | if not self.predictor: 76 | self.predictor = NASPredictor(overrides=overrides) 77 | self.predictor.setup_model(model=self.model) 78 | else: # only update args if predictor is already setup 79 | self.predictor.args = get_cfg(self.predictor.args, overrides) 80 | return self.predictor(source, stream=stream) 81 | 82 | def train(self, **kwargs): 83 | """Function trains models but raises an error as NAS models do not support training.""" 84 | raise NotImplementedError("NAS models don't support training") 85 | 86 | def val(self, **kwargs): 87 | """Run validation given dataset.""" 88 | overrides = dict(task='detect', mode='val') 89 | overrides.update(kwargs) # prefer kwargs 90 | args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides) 91 | args.imgsz = check_imgsz(args.imgsz, max_dim=1) 92 | validator = NASValidator(args=args) 93 | validator(model=self.model) 94 | self.metrics = validator.metrics 95 | return validator.metrics 96 | 97 | @smart_inference_mode() 98 | def export(self, **kwargs): 99 | """ 100 | Export model. 101 | 102 | Args: 103 | **kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs 104 | """ 105 | overrides = dict(task='detect') 106 | overrides.update(kwargs) 107 | overrides['mode'] = 'export' 108 | args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides) 109 | args.task = self.task 110 | if args.imgsz == DEFAULT_CFG.imgsz: 111 | args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed 112 | if args.batch == DEFAULT_CFG.batch: 113 | args.batch = 1 # default to 1 if not modified 114 | return Exporter(overrides=args)(model=self.model) 115 | 116 | def info(self, detailed=False, verbose=True): 117 | """ 118 | Logs model info. 119 | 120 | Args: 121 | detailed (bool): Show detailed information about model. 122 | verbose (bool): Controls verbosity. 123 | """ 124 | return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640) 125 | 126 | def __call__(self, source=None, stream=False, **kwargs): 127 | """Calls the 'predict' function with given arguments to perform object detection.""" 128 | return self.predict(source, stream, **kwargs) 129 | 130 | def __getattr__(self, attr): 131 | """Raises error if object has no requested attribute.""" 132 | name = self.__class__.__name__ 133 | raise AttributeError(f"'{name}' object has no attribute '{attr}'. 
See valid attributes below.\n{self.__doc__}") 134 | -------------------------------------------------------------------------------- /ultralytics/yolo/nas/predict.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | import torch 4 | 5 | from ultralytics.yolo.engine.predictor import BasePredictor 6 | from ultralytics.yolo.engine.results import Results 7 | from ultralytics.yolo.utils import ops 8 | from ultralytics.yolo.utils.ops import xyxy2xywh 9 | 10 | 11 | class NASPredictor(BasePredictor): 12 | 13 | def postprocess(self, preds_in, img, orig_imgs): 14 | """Postprocesses predictions and returns a list of Results objects.""" 15 | 16 | # Cat boxes and class scores 17 | boxes = xyxy2xywh(preds_in[0][0]) 18 | preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1) 19 | 20 | preds = ops.non_max_suppression(preds, 21 | self.args.conf, 22 | self.args.iou, 23 | agnostic=self.args.agnostic_nms, 24 | max_det=self.args.max_det, 25 | classes=self.args.classes) 26 | 27 | results = [] 28 | for i, pred in enumerate(preds): 29 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs 30 | if not isinstance(orig_imgs, torch.Tensor): 31 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) 32 | path = self.batch[0] 33 | img_path = path[i] if isinstance(path, list) else path 34 | results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred)) 35 | return results 36 | -------------------------------------------------------------------------------- /ultralytics/yolo/nas/val.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | import torch 4 | 5 | from ultralytics.yolo.utils import ops 6 | from ultralytics.yolo.utils.ops import xyxy2xywh 7 | from ultralytics.yolo.v8.detect import DetectionValidator 8 | 9 | __all__ = ['NASValidator'] 10 | 11 | 12 | class NASValidator(DetectionValidator): 13 | 14 | def postprocess(self, preds_in): 15 | """Apply Non-maximum suppression to prediction outputs.""" 16 | boxes = xyxy2xywh(preds_in[0][0]) 17 | preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1) 18 | return ops.non_max_suppression(preds, 19 | self.args.conf, 20 | self.args.iou, 21 | labels=self.lb, 22 | multi_label=False, 23 | agnostic=self.args.single_cls, 24 | max_det=self.args.max_det, 25 | max_time_img=0.5) 26 | -------------------------------------------------------------------------------- /ultralytics/yolo/utils/autobatch.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | """ 3 | Functions for estimating the best YOLO batch size to use a fraction of the available CUDA memory in PyTorch. 4 | """ 5 | 6 | from copy import deepcopy 7 | 8 | import numpy as np 9 | import torch 10 | 11 | from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, colorstr 12 | from ultralytics.yolo.utils.torch_utils import profile 13 | 14 | 15 | def check_train_batch_size(model, imgsz=640, amp=True): 16 | """ 17 | Check YOLO training batch size using the autobatch() function. 18 | 19 | Args: 20 | model (torch.nn.Module): YOLO model to check batch size for. 21 | imgsz (int): Image size used for training. 22 | amp (bool): If True, use automatic mixed precision (AMP) for training. 23 | 24 | Returns: 25 | (int): Optimal batch size computed using the autobatch() function. 
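    Example (illustrative; assumes a CUDA device and a local 'yolov8n.pt' checkpoint):
        >>> from ultralytics import YOLO
        >>> from ultralytics.yolo.utils.autobatch import check_train_batch_size
        >>> model = YOLO('yolov8n.pt').model  # underlying torch.nn.Module
        >>> batch = check_train_batch_size(model, imgsz=640, amp=True)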
26 | """ 27 | 28 | with torch.cuda.amp.autocast(amp): 29 | return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size 30 | 31 | 32 | def autobatch(model, imgsz=640, fraction=0.67, batch_size=DEFAULT_CFG.batch): 33 | """ 34 | Automatically estimate the best YOLO batch size to use a fraction of the available CUDA memory. 35 | 36 | Args: 37 | model (torch.nn.module): YOLO model to compute batch size for. 38 | imgsz (int, optional): The image size used as input for the YOLO model. Defaults to 640. 39 | fraction (float, optional): The fraction of available CUDA memory to use. Defaults to 0.67. 40 | batch_size (int, optional): The default batch size to use if an error is detected. Defaults to 16. 41 | 42 | Returns: 43 | (int): The optimal batch size. 44 | """ 45 | 46 | # Check device 47 | prefix = colorstr('AutoBatch: ') 48 | LOGGER.info(f'{prefix}Computing optimal batch size for imgsz={imgsz}') 49 | device = next(model.parameters()).device # get model device 50 | if device.type == 'cpu': 51 | LOGGER.info(f'{prefix}CUDA not detected, using default CPU batch-size {batch_size}') 52 | return batch_size 53 | if torch.backends.cudnn.benchmark: 54 | LOGGER.info(f'{prefix} ⚠️ Requires torch.backends.cudnn.benchmark=False, using default batch-size {batch_size}') 55 | return batch_size 56 | 57 | # Inspect CUDA memory 58 | gb = 1 << 30 # bytes to GiB (1024 ** 3) 59 | d = str(device).upper() # 'CUDA:0' 60 | properties = torch.cuda.get_device_properties(device) # device properties 61 | t = properties.total_memory / gb # GiB total 62 | r = torch.cuda.memory_reserved(device) / gb # GiB reserved 63 | a = torch.cuda.memory_allocated(device) / gb # GiB allocated 64 | f = t - (r + a) # GiB free 65 | LOGGER.info(f'{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free') 66 | 67 | # Profile batch sizes 68 | batch_sizes = [1, 2, 4, 8, 16] 69 | try: 70 | img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes] 71 | results = profile(img, model, n=3, device=device) 72 | 73 | # Fit a solution 74 | y = [x[2] for x in results if x] # memory [2] 75 | p = np.polyfit(batch_sizes[:len(y)], y, deg=1) # first degree polynomial fit 76 | b = int((f * fraction - p[1]) / p[0]) # y intercept (optimal batch size) 77 | if None in results: # some sizes failed 78 | i = results.index(None) # first fail index 79 | if b >= batch_sizes[i]: # y intercept above failure point 80 | b = batch_sizes[max(i - 1, 0)] # select prior safe point 81 | if b < 1 or b > 1024: # b outside of safe range 82 | b = batch_size 83 | LOGGER.info(f'{prefix}WARNING ⚠️ CUDA anomaly detected, using default batch-size {batch_size}.') 84 | 85 | fraction = (np.polyval(p, b) + r + a) / t # actual fraction predicted 86 | LOGGER.info(f'{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅') 87 | return b 88 | except Exception as e: 89 | LOGGER.warning(f'{prefix}WARNING ⚠️ error detected: {e}, using default batch-size {batch_size}.') 90 | return batch_size 91 | -------------------------------------------------------------------------------- /ultralytics/yolo/utils/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from .base import add_integration_callbacks, default_callbacks, get_default_callbacks 4 | 5 | __all__ = 'add_integration_callbacks', 'default_callbacks', 'get_default_callbacks' 6 | 
-------------------------------------------------------------------------------- /ultralytics/yolo/utils/callbacks/dvc.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, GPL-3.0 license 2 | import os 3 | 4 | import pkg_resources as pkg 5 | 6 | from ultralytics.yolo.utils import LOGGER, TESTS_RUNNING 7 | from ultralytics.yolo.utils.torch_utils import model_info_for_loggers 8 | 9 | try: 10 | from importlib.metadata import version 11 | 12 | import dvclive 13 | 14 | assert not TESTS_RUNNING # do not log pytest 15 | 16 | ver = version('dvclive') 17 | if pkg.parse_version(ver) < pkg.parse_version('2.11.0'): 18 | LOGGER.debug(f'DVCLive is detected but version {ver} is incompatible (>=2.11 required).') 19 | dvclive = None # noqa: F811 20 | except (ImportError, AssertionError, TypeError): 21 | dvclive = None 22 | 23 | # DVCLive logger instance 24 | live = None 25 | _processed_plots = {} 26 | 27 | # `on_fit_epoch_end` is called on final validation (probably need to be fixed) 28 | # for now this is the way we distinguish final evaluation of the best model vs 29 | # last epoch validation 30 | _training_epoch = False 31 | 32 | 33 | def _logger_disabled(): 34 | return os.getenv('ULTRALYTICS_DVC_DISABLED', 'false').lower() == 'true' 35 | 36 | 37 | def _log_images(image_path, prefix=''): 38 | if live: 39 | live.log_image(os.path.join(prefix, image_path.name), image_path) 40 | 41 | 42 | def _log_plots(plots, prefix=''): 43 | for name, params in plots.items(): 44 | timestamp = params['timestamp'] 45 | if _processed_plots.get(name) != timestamp: 46 | _log_images(name, prefix) 47 | _processed_plots[name] = timestamp 48 | 49 | 50 | def _log_confusion_matrix(validator): 51 | targets = [] 52 | preds = [] 53 | matrix = validator.confusion_matrix.matrix 54 | names = list(validator.names.values()) 55 | if validator.confusion_matrix.task == 'detect': 56 | names += ['background'] 57 | 58 | for ti, pred in enumerate(matrix.T.astype(int)): 59 | for pi, num in enumerate(pred): 60 | targets.extend([names[ti]] * num) 61 | preds.extend([names[pi]] * num) 62 | 63 | live.log_sklearn_plot('confusion_matrix', targets, preds, name='cf.json', normalized=True) 64 | 65 | 66 | def on_pretrain_routine_start(trainer): 67 | try: 68 | global live 69 | if not _logger_disabled(): 70 | live = dvclive.Live(save_dvc_exp=True) 71 | LOGGER.info( 72 | 'DVCLive is detected and auto logging is enabled (can be disabled with `ULTRALYTICS_DVC_DISABLED=true`).' 73 | ) 74 | else: 75 | LOGGER.debug('DVCLive is detected and auto logging is disabled via `ULTRALYTICS_DVC_DISABLED`.') 76 | live = None 77 | except Exception as e: 78 | LOGGER.warning(f'WARNING ⚠️ DVCLive installed but not initialized correctly, not logging this run. 
{e}') 79 | 80 | 81 | def on_pretrain_routine_end(trainer): 82 | _log_plots(trainer.plots, 'train') 83 | 84 | 85 | def on_train_start(trainer): 86 | if live: 87 | live.log_params(trainer.args) 88 | 89 | 90 | def on_train_epoch_start(trainer): 91 | global _training_epoch 92 | _training_epoch = True 93 | 94 | 95 | def on_fit_epoch_end(trainer): 96 | global _training_epoch 97 | if live and _training_epoch: 98 | all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics, **trainer.lr} 99 | for metric, value in all_metrics.items(): 100 | live.log_metric(metric, value) 101 | 102 | if trainer.epoch == 0: 103 | for metric, value in model_info_for_loggers(trainer).items(): 104 | live.log_metric(metric, value, plot=False) 105 | 106 | _log_plots(trainer.plots, 'train') 107 | _log_plots(trainer.validator.plots, 'val') 108 | 109 | live.next_step() 110 | _training_epoch = False 111 | 112 | 113 | def on_train_end(trainer): 114 | if live: 115 | # At the end log the best metrics. It runs validator on the best model internally. 116 | all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics, **trainer.lr} 117 | for metric, value in all_metrics.items(): 118 | live.log_metric(metric, value, plot=False) 119 | 120 | _log_plots(trainer.plots, 'eval') 121 | _log_plots(trainer.validator.plots, 'eval') 122 | _log_confusion_matrix(trainer.validator) 123 | 124 | if trainer.best.exists(): 125 | live.log_artifact(trainer.best, copy=True) 126 | 127 | live.end() 128 | 129 | 130 | callbacks = { 131 | 'on_pretrain_routine_start': on_pretrain_routine_start, 132 | 'on_pretrain_routine_end': on_pretrain_routine_end, 133 | 'on_train_start': on_train_start, 134 | 'on_train_epoch_start': on_train_epoch_start, 135 | 'on_fit_epoch_end': on_fit_epoch_end, 136 | 'on_train_end': on_train_end} if dvclive else {} 137 | -------------------------------------------------------------------------------- /ultralytics/yolo/utils/callbacks/hub.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | import json 4 | from time import time 5 | 6 | from ultralytics.hub.utils import PREFIX, events 7 | from ultralytics.yolo.utils import LOGGER 8 | from ultralytics.yolo.utils.torch_utils import model_info_for_loggers 9 | 10 | 11 | def on_pretrain_routine_end(trainer): 12 | """Logs info before starting timer for upload rate limit.""" 13 | session = getattr(trainer, 'hub_session', None) 14 | if session: 15 | # Start timer for upload rate limit 16 | LOGGER.info(f'{PREFIX}View model at https://hub.ultralytics.com/models/{session.model_id} 🚀') 17 | session.timers = {'metrics': time(), 'ckpt': time()} # start timer on session.rate_limit 18 | 19 | 20 | def on_fit_epoch_end(trainer): 21 | """Uploads training progress metrics at the end of each epoch.""" 22 | session = getattr(trainer, 'hub_session', None) 23 | if session: 24 | # Upload metrics after val end 25 | all_plots = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics} 26 | if trainer.epoch == 0: 27 | all_plots = {**all_plots, **model_info_for_loggers(trainer)} 28 | session.metrics_queue[trainer.epoch] = json.dumps(all_plots) 29 | if time() - session.timers['metrics'] > session.rate_limits['metrics']: 30 | session.upload_metrics() 31 | session.timers['metrics'] = time() # reset timer 32 | session.metrics_queue = {} # reset queue 33 | 34 | 35 | def on_model_save(trainer): 36 | """Saves checkpoints to Ultralytics HUB with rate 
limiting.""" 37 | session = getattr(trainer, 'hub_session', None) 38 | if session: 39 | # Upload checkpoints with rate limiting 40 | is_best = trainer.best_fitness == trainer.fitness 41 | if time() - session.timers['ckpt'] > session.rate_limits['ckpt']: 42 | LOGGER.info(f'{PREFIX}Uploading checkpoint https://hub.ultralytics.com/models/{session.model_id}') 43 | session.upload_model(trainer.epoch, trainer.last, is_best) 44 | session.timers['ckpt'] = time() # reset timer 45 | 46 | 47 | def on_train_end(trainer): 48 | """Upload final model and metrics to Ultralytics HUB at the end of training.""" 49 | session = getattr(trainer, 'hub_session', None) 50 | if session: 51 | # Upload final model and metrics with exponential standoff 52 | LOGGER.info(f'{PREFIX}Syncing final model...') 53 | session.upload_model(trainer.epoch, trainer.best, map=trainer.metrics.get('metrics/mAP50-95(B)', 0), final=True) 54 | session.alive = False # stop heartbeats 55 | LOGGER.info(f'{PREFIX}Done ✅\n' 56 | f'{PREFIX}View model at https://hub.ultralytics.com/models/{session.model_id} 🚀') 57 | 58 | 59 | def on_train_start(trainer): 60 | """Run events on train start.""" 61 | events(trainer.args) 62 | 63 | 64 | def on_val_start(validator): 65 | """Runs events on validation start.""" 66 | events(validator.args) 67 | 68 | 69 | def on_predict_start(predictor): 70 | """Run events on predict start.""" 71 | events(predictor.args) 72 | 73 | 74 | def on_export_start(exporter): 75 | """Run events on export start.""" 76 | events(exporter.args) 77 | 78 | 79 | callbacks = { 80 | 'on_pretrain_routine_end': on_pretrain_routine_end, 81 | 'on_fit_epoch_end': on_fit_epoch_end, 82 | 'on_model_save': on_model_save, 83 | 'on_train_end': on_train_end, 84 | 'on_train_start': on_train_start, 85 | 'on_val_start': on_val_start, 86 | 'on_predict_start': on_predict_start, 87 | 'on_export_start': on_export_start} 88 | -------------------------------------------------------------------------------- /ultralytics/yolo/utils/callbacks/mlflow.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | import os 4 | import re 5 | from pathlib import Path 6 | 7 | from ultralytics.yolo.utils import LOGGER, TESTS_RUNNING, colorstr 8 | 9 | try: 10 | import mlflow 11 | 12 | assert not TESTS_RUNNING # do not log pytest 13 | assert hasattr(mlflow, '__version__') # verify package is not directory 14 | except (ImportError, AssertionError): 15 | mlflow = None 16 | 17 | 18 | def on_pretrain_routine_end(trainer): 19 | """Logs training parameters to MLflow.""" 20 | global mlflow, run, run_id, experiment_name 21 | 22 | if os.environ.get('MLFLOW_TRACKING_URI') is None: 23 | mlflow = None 24 | 25 | if mlflow: 26 | mlflow_location = os.environ['MLFLOW_TRACKING_URI'] # "http://192.168.xxx.xxx:5000" 27 | mlflow.set_tracking_uri(mlflow_location) 28 | 29 | experiment_name = os.environ.get('MLFLOW_EXPERIMENT') or trainer.args.project or '/Shared/YOLOv8' 30 | experiment = mlflow.get_experiment_by_name(experiment_name) 31 | if experiment is None: 32 | mlflow.create_experiment(experiment_name) 33 | mlflow.set_experiment(experiment_name) 34 | 35 | prefix = colorstr('MLFlow: ') 36 | try: 37 | run, active_run = mlflow, mlflow.active_run() 38 | if not active_run: 39 | active_run = mlflow.start_run(experiment_id=experiment.experiment_id) 40 | run_id = active_run.info.run_id 41 | LOGGER.info(f'{prefix}Using run_id({run_id}) at {mlflow_location}') 42 | run.log_params(vars(trainer.model.args)) 43 | except Exception 
as err: 44 | LOGGER.error(f'{prefix}Failing init - {repr(err)}') 45 | LOGGER.warning(f'{prefix}Continuing without Mlflow') 46 | 47 | 48 | def on_fit_epoch_end(trainer): 49 | """Logs training metrics to Mlflow.""" 50 | if mlflow: 51 | metrics_dict = {f"{re.sub('[()]', '', k)}": float(v) for k, v in trainer.metrics.items()} 52 | run.log_metrics(metrics=metrics_dict, step=trainer.epoch) 53 | 54 | 55 | def on_train_end(trainer): 56 | """Called at end of train loop to log model artifact info.""" 57 | if mlflow: 58 | root_dir = Path(__file__).resolve().parents[3] 59 | run.log_artifact(trainer.last) 60 | run.log_artifact(trainer.best) 61 | run.pyfunc.log_model(artifact_path=experiment_name, 62 | code_path=[str(root_dir)], 63 | artifacts={'model_path': str(trainer.save_dir)}, 64 | python_model=run.pyfunc.PythonModel()) 65 | 66 | 67 | callbacks = { 68 | 'on_pretrain_routine_end': on_pretrain_routine_end, 69 | 'on_fit_epoch_end': on_fit_epoch_end, 70 | 'on_train_end': on_train_end} if mlflow else {} 71 | -------------------------------------------------------------------------------- /ultralytics/yolo/utils/callbacks/neptune.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | import matplotlib.image as mpimg 4 | import matplotlib.pyplot as plt 5 | 6 | from ultralytics.yolo.utils import LOGGER, TESTS_RUNNING 7 | from ultralytics.yolo.utils.torch_utils import model_info_for_loggers 8 | 9 | try: 10 | import neptune 11 | from neptune.types import File 12 | 13 | assert not TESTS_RUNNING # do not log pytest 14 | assert hasattr(neptune, '__version__') 15 | except (ImportError, AssertionError): 16 | neptune = None 17 | 18 | run = None # NeptuneAI experiment logger instance 19 | 20 | 21 | def _log_scalars(scalars, step=0): 22 | """Log scalars to the NeptuneAI experiment logger.""" 23 | if run: 24 | for k, v in scalars.items(): 25 | run[k].append(value=v, step=step) 26 | 27 | 28 | def _log_images(imgs_dict, group=''): 29 | """Log scalars to the NeptuneAI experiment logger.""" 30 | if run: 31 | for k, v in imgs_dict.items(): 32 | run[f'{group}/{k}'].upload(File(v)) 33 | 34 | 35 | def _log_plot(title, plot_path): 36 | """Log plots to the NeptuneAI experiment logger.""" 37 | """ 38 | Log image as plot in the plot section of NeptuneAI 39 | 40 | arguments: 41 | title (str) Title of the plot 42 | plot_path (PosixPath or str) Path to the saved image file 43 | """ 44 | img = mpimg.imread(plot_path) 45 | fig = plt.figure() 46 | ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[]) # no ticks 47 | ax.imshow(img) 48 | run[f'Plots/{title}'].upload(fig) 49 | 50 | 51 | def on_pretrain_routine_start(trainer): 52 | """Callback function called before the training routine starts.""" 53 | try: 54 | global run 55 | run = neptune.init_run(project=trainer.args.project or 'YOLOv8', name=trainer.args.name, tags=['YOLOv8']) 56 | run['Configuration/Hyperparameters'] = {k: '' if v is None else v for k, v in vars(trainer.args).items()} 57 | except Exception as e: 58 | LOGGER.warning(f'WARNING ⚠️ NeptuneAI installed but not initialized correctly, not logging this run. 
{e}') 59 | 60 | 61 | def on_train_epoch_end(trainer): 62 | """Callback function called at end of each training epoch.""" 63 | _log_scalars(trainer.label_loss_items(trainer.tloss, prefix='train'), trainer.epoch + 1) 64 | _log_scalars(trainer.lr, trainer.epoch + 1) 65 | if trainer.epoch == 1: 66 | _log_images({f.stem: str(f) for f in trainer.save_dir.glob('train_batch*.jpg')}, 'Mosaic') 67 | 68 | 69 | def on_fit_epoch_end(trainer): 70 | """Callback function called at end of each fit (train+val) epoch.""" 71 | if run and trainer.epoch == 0: 72 | run['Configuration/Model'] = model_info_for_loggers(trainer) 73 | _log_scalars(trainer.metrics, trainer.epoch + 1) 74 | 75 | 76 | def on_val_end(validator): 77 | """Callback function called at end of each validation.""" 78 | if run: 79 | # Log val_labels and val_pred 80 | _log_images({f.stem: str(f) for f in validator.save_dir.glob('val*.jpg')}, 'Validation') 81 | 82 | 83 | def on_train_end(trainer): 84 | """Callback function called at end of training.""" 85 | if run: 86 | # Log final results, CM matrix + PR plots 87 | files = [ 88 | 'results.png', 'confusion_matrix.png', 'confusion_matrix_normalized.png', 89 | *(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))] 90 | files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter 91 | for f in files: 92 | _log_plot(title=f.stem, plot_path=f) 93 | # Log the final model 94 | run[f'weights/{trainer.args.name or trainer.args.task}/{str(trainer.best.name)}'].upload(File(str( 95 | trainer.best))) 96 | 97 | 98 | callbacks = { 99 | 'on_pretrain_routine_start': on_pretrain_routine_start, 100 | 'on_train_epoch_end': on_train_epoch_end, 101 | 'on_fit_epoch_end': on_fit_epoch_end, 102 | 'on_val_end': on_val_end, 103 | 'on_train_end': on_train_end} if neptune else {} 104 | -------------------------------------------------------------------------------- /ultralytics/yolo/utils/callbacks/raytune.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | try: 4 | import ray 5 | from ray import tune 6 | from ray.air import session 7 | except (ImportError, AssertionError): 8 | tune = None 9 | 10 | 11 | def on_fit_epoch_end(trainer): 12 | """Sends training metrics to Ray Tune at end of each epoch.""" 13 | if ray.tune.is_session_enabled(): 14 | metrics = trainer.metrics 15 | metrics['epoch'] = trainer.epoch 16 | session.report(metrics) 17 | 18 | 19 | callbacks = { 20 | 'on_fit_epoch_end': on_fit_epoch_end, } if tune else {} 21 | -------------------------------------------------------------------------------- /ultralytics/yolo/utils/callbacks/tensorboard.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from ultralytics.yolo.utils import LOGGER, TESTS_RUNNING, colorstr 4 | 5 | try: 6 | from torch.utils.tensorboard import SummaryWriter 7 | 8 | assert not TESTS_RUNNING # do not log pytest 9 | except (ImportError, AssertionError): 10 | SummaryWriter = None 11 | 12 | writer = None # TensorBoard SummaryWriter instance 13 | 14 | 15 | def _log_scalars(scalars, step=0): 16 | """Logs scalar values to TensorBoard.""" 17 | if writer: 18 | for k, v in scalars.items(): 19 | writer.add_scalar(k, v, step) 20 | 21 | 22 | def on_pretrain_routine_start(trainer): 23 | """Initialize TensorBoard logging with SummaryWriter.""" 24 | if SummaryWriter: 25 | try: 26 | global writer 27 | writer = SummaryWriter(str(trainer.save_dir)) 28 | prefix = 
colorstr('TensorBoard: ') 29 | LOGGER.info(f"{prefix}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/") 30 | except Exception as e: 31 | LOGGER.warning(f'WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}') 32 | 33 | 34 | def on_batch_end(trainer): 35 | """Logs scalar statistics at the end of a training batch.""" 36 | _log_scalars(trainer.label_loss_items(trainer.tloss, prefix='train'), trainer.epoch + 1) 37 | 38 | 39 | def on_fit_epoch_end(trainer): 40 | """Logs epoch metrics at end of training epoch.""" 41 | _log_scalars(trainer.metrics, trainer.epoch + 1) 42 | 43 | 44 | callbacks = { 45 | 'on_pretrain_routine_start': on_pretrain_routine_start, 46 | 'on_fit_epoch_end': on_fit_epoch_end, 47 | 'on_batch_end': on_batch_end} 48 | -------------------------------------------------------------------------------- /ultralytics/yolo/utils/callbacks/wb.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | from ultralytics.yolo.utils import TESTS_RUNNING 3 | from ultralytics.yolo.utils.torch_utils import model_info_for_loggers 4 | 5 | try: 6 | import wandb as wb 7 | 8 | assert hasattr(wb, '__version__') 9 | assert not TESTS_RUNNING # do not log pytest 10 | except (ImportError, AssertionError): 11 | wb = None 12 | 13 | _processed_plots = {} 14 | 15 | 16 | def _log_plots(plots, step): 17 | for name, params in plots.items(): 18 | timestamp = params['timestamp'] 19 | if _processed_plots.get(name, None) != timestamp: 20 | wb.run.log({name.stem: wb.Image(str(name))}, step=step) 21 | _processed_plots[name] = timestamp 22 | 23 | 24 | def on_pretrain_routine_start(trainer): 25 | """Initiate and start project if module is present.""" 26 | wb.run or wb.init(project=trainer.args.project or 'YOLOv8', name=trainer.args.name, config=vars(trainer.args)) 27 | 28 | 29 | def on_fit_epoch_end(trainer): 30 | """Logs training metrics and model information at the end of an epoch.""" 31 | wb.run.log(trainer.metrics, step=trainer.epoch + 1) 32 | _log_plots(trainer.plots, step=trainer.epoch + 1) 33 | _log_plots(trainer.validator.plots, step=trainer.epoch + 1) 34 | if trainer.epoch == 0: 35 | wb.run.log(model_info_for_loggers(trainer), step=trainer.epoch + 1) 36 | 37 | 38 | def on_train_epoch_end(trainer): 39 | """Log metrics and save images at the end of each training epoch.""" 40 | wb.run.log(trainer.label_loss_items(trainer.tloss, prefix='train'), step=trainer.epoch + 1) 41 | wb.run.log(trainer.lr, step=trainer.epoch + 1) 42 | if trainer.epoch == 1: 43 | _log_plots(trainer.plots, step=trainer.epoch + 1) 44 | 45 | 46 | def on_train_end(trainer): 47 | """Save the best model as an artifact at end of training.""" 48 | _log_plots(trainer.validator.plots, step=trainer.epoch + 1) 49 | _log_plots(trainer.plots, step=trainer.epoch + 1) 50 | art = wb.Artifact(type='model', name=f'run_{wb.run.id}_model') 51 | if trainer.best.exists(): 52 | art.add_file(trainer.best) 53 | wb.run.log_artifact(art) 54 | 55 | 56 | callbacks = { 57 | 'on_pretrain_routine_start': on_pretrain_routine_start, 58 | 'on_train_epoch_end': on_train_epoch_end, 59 | 'on_fit_epoch_end': on_fit_epoch_end, 60 | 'on_train_end': on_train_end} if wb else {} 61 | -------------------------------------------------------------------------------- /ultralytics/yolo/utils/dist.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | import os 4 | import re 5 | 
import shutil 6 | import socket 7 | import sys 8 | import tempfile 9 | from pathlib import Path 10 | 11 | from . import USER_CONFIG_DIR 12 | from .torch_utils import TORCH_1_9 13 | 14 | 15 | def find_free_network_port() -> int: 16 | """Finds a free port on localhost. 17 | 18 | It is useful in single-node training when we don't want to connect to a real main node but have to set the 19 | `MASTER_PORT` environment variable. 20 | """ 21 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 22 | s.bind(('127.0.0.1', 0)) 23 | return s.getsockname()[1] # port 24 | 25 | 26 | def generate_ddp_file(trainer): 27 | """Generates a DDP file and returns its file name.""" 28 | module, name = f'{trainer.__class__.__module__}.{trainer.__class__.__name__}'.rsplit('.', 1) 29 | 30 | content = f'''overrides = {vars(trainer.args)} \nif __name__ == "__main__": 31 | from {module} import {name} 32 | from ultralytics.yolo.utils import DEFAULT_CFG_DICT 33 | 34 | cfg = DEFAULT_CFG_DICT.copy() 35 | cfg.update(save_dir='') # handle the extra key 'save_dir' 36 | trainer = {name}(cfg=cfg, overrides=overrides) 37 | trainer.train()''' 38 | (USER_CONFIG_DIR / 'DDP').mkdir(exist_ok=True) 39 | with tempfile.NamedTemporaryFile(prefix='_temp_', 40 | suffix=f'{id(trainer)}.py', 41 | mode='w+', 42 | encoding='utf-8', 43 | dir=USER_CONFIG_DIR / 'DDP', 44 | delete=False) as file: 45 | file.write(content) 46 | return file.name 47 | 48 | 49 | def generate_ddp_command(world_size, trainer): 50 | """Generates and returns command for distributed training.""" 51 | import __main__ # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218 52 | if not trainer.resume: 53 | shutil.rmtree(trainer.save_dir) # remove the save_dir 54 | file = str(Path(sys.argv[0]).resolve()) 55 | safe_pattern = re.compile(r'^[a-zA-Z0-9_. /\\-]{1,128}$') # allowed characters and maximum of 100 characters 56 | if not (safe_pattern.match(file) and Path(file).exists() and file.endswith('.py')): # using CLI 57 | file = generate_ddp_file(trainer) 58 | dist_cmd = 'torch.distributed.run' if TORCH_1_9 else 'torch.distributed.launch' 59 | port = find_free_network_port() 60 | cmd = [sys.executable, '-m', dist_cmd, '--nproc_per_node', f'{world_size}', '--master_port', f'{port}', file] 61 | return cmd, file 62 | 63 | 64 | def ddp_cleanup(trainer, file): 65 | """Delete temp file if created.""" 66 | if f'{id(trainer)}.py' in file: # if temp_file suffix in file 67 | os.remove(file) 68 | -------------------------------------------------------------------------------- /ultralytics/yolo/utils/errors.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from ultralytics.yolo.utils import emojis 4 | 5 | 6 | class HUBModelError(Exception): 7 | 8 | def __init__(self, message='Model not found. 
Please check model URL and try again.'): 9 | """Create an exception for when a model is not found.""" 10 | super().__init__(emojis(message)) 11 | -------------------------------------------------------------------------------- /ultralytics/yolo/utils/files.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | import contextlib 4 | import glob 5 | import os 6 | import shutil 7 | from datetime import datetime 8 | from pathlib import Path 9 | 10 | 11 | class WorkingDirectory(contextlib.ContextDecorator): 12 | """Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager.""" 13 | 14 | def __init__(self, new_dir): 15 | """Sets the working directory to 'new_dir' upon instantiation.""" 16 | self.dir = new_dir # new dir 17 | self.cwd = Path.cwd().resolve() # current dir 18 | 19 | def __enter__(self): 20 | """Changes the current directory to the specified directory.""" 21 | os.chdir(self.dir) 22 | 23 | def __exit__(self, exc_type, exc_val, exc_tb): 24 | """Restore the current working directory on context exit.""" 25 | os.chdir(self.cwd) 26 | 27 | 28 | def increment_path(path, exist_ok=False, sep='', mkdir=False): 29 | """ 30 | Increments a file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc. 31 | 32 | If the path exists and exist_ok is not set to True, the path will be incremented by appending a number and sep to 33 | the end of the path. If the path is a file, the file extension will be preserved. If the path is a directory, the 34 | number will be appended directly to the end of the path. If mkdir is set to True, the path will be created as a 35 | directory if it does not already exist. 36 | 37 | Args: 38 | path (str, pathlib.Path): Path to increment. 39 | exist_ok (bool, optional): If True, the path will not be incremented and returned as-is. Defaults to False. 40 | sep (str, optional): Separator to use between the path and the incrementation number. Defaults to ''. 41 | mkdir (bool, optional): Create a directory if it does not exist. Defaults to False. 42 | 43 | Returns: 44 | (pathlib.Path): Incremented path. 45 | """ 46 | path = Path(path) # os-agnostic 47 | if path.exists() and not exist_ok: 48 | path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '') 49 | 50 | # Method 1 51 | for n in range(2, 9999): 52 | p = f'{path}{sep}{n}{suffix}' # increment path 53 | if not os.path.exists(p): # 54 | break 55 | path = Path(p) 56 | 57 | if mkdir: 58 | path.mkdir(parents=True, exist_ok=True) # make directory 59 | 60 | return path 61 | 62 | 63 | def file_age(path=__file__): 64 | """Return days since last file update.""" 65 | dt = (datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime)) # delta 66 | return dt.days # + dt.seconds / 86400 # fractional days 67 | 68 | 69 | def file_date(path=__file__): 70 | """Return human-readable file modification date, i.e. 
'2021-3-26'.""" 71 | t = datetime.fromtimestamp(Path(path).stat().st_mtime) 72 | return f'{t.year}-{t.month}-{t.day}' 73 | 74 | 75 | def file_size(path): 76 | """Return file/dir size (MB).""" 77 | if isinstance(path, (str, Path)): 78 | mb = 1 << 20 # bytes to MiB (1024 ** 2) 79 | path = Path(path) 80 | if path.is_file(): 81 | return path.stat().st_size / mb 82 | elif path.is_dir(): 83 | return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / mb 84 | return 0.0 85 | 86 | 87 | def get_latest_run(search_dir='.'): 88 | """Return path to most recent 'last.pt' in /runs (i.e. to --resume from).""" 89 | last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True) 90 | return max(last_list, key=os.path.getctime) if last_list else '' 91 | 92 | 93 | def make_dirs(dir='new_dir/'): 94 | # Create folders 95 | dir = Path(dir) 96 | if dir.exists(): 97 | shutil.rmtree(dir) # delete dir 98 | for p in dir, dir / 'labels', dir / 'images': 99 | p.mkdir(parents=True, exist_ok=True) # make dir 100 | return dir 101 | -------------------------------------------------------------------------------- /ultralytics/yolo/utils/patches.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | """ 3 | Monkey patches to update/extend functionality of existing functions 4 | """ 5 | 6 | from pathlib import Path 7 | 8 | import cv2 9 | import numpy as np 10 | import torch 11 | 12 | # OpenCV Multilanguage-friendly functions ------------------------------------------------------------------------------ 13 | _imshow = cv2.imshow # copy to avoid recursion errors 14 | 15 | 16 | def imread(filename, flags=cv2.IMREAD_COLOR): 17 | return cv2.imdecode(np.fromfile(filename, np.uint8), flags) 18 | 19 | 20 | def imwrite(filename, img): 21 | try: 22 | cv2.imencode(Path(filename).suffix, img)[1].tofile(filename) 23 | return True 24 | except Exception: 25 | return False 26 | 27 | 28 | def imshow(path, im): 29 | _imshow(path.encode('unicode_escape').decode(), im) 30 | 31 | 32 | # PyTorch functions ---------------------------------------------------------------------------------------------------- 33 | _torch_save = torch.save # copy to avoid recursion errors 34 | 35 | 36 | def torch_save(*args, **kwargs): 37 | # Use dill (if exists) to serialize the lambda functions where pickle does not do this 38 | try: 39 | import dill as pickle 40 | except ImportError: 41 | import pickle 42 | 43 | if 'pickle_module' not in kwargs: 44 | kwargs['pickle_module'] = pickle 45 | return _torch_save(*args, **kwargs) 46 | -------------------------------------------------------------------------------- /ultralytics/yolo/utils/tuner.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from ultralytics.yolo.utils import LOGGER 4 | 5 | try: 6 | from ray import tune 7 | from ray.air import RunConfig, session # noqa 8 | from ray.air.integrations.wandb import WandbLoggerCallback # noqa 9 | from ray.tune.schedulers import ASHAScheduler # noqa 10 | from ray.tune.schedulers import AsyncHyperBandScheduler as AHB # noqa 11 | 12 | except ImportError: 13 | LOGGER.info("Tuning hyperparameters requires ray/tune. 
Install using `pip install 'ray[tune]'`") 14 | tune = None 15 | 16 | default_space = { 17 | # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']), 18 | 'lr0': tune.uniform(1e-5, 1e-1), 19 | 'lrf': tune.uniform(0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf) 20 | 'momentum': tune.uniform(0.6, 0.98), # SGD momentum/Adam beta1 21 | 'weight_decay': tune.uniform(0.0, 0.001), # optimizer weight decay 5e-4 22 | 'warmup_epochs': tune.uniform(0.0, 5.0), # warmup epochs (fractions ok) 23 | 'warmup_momentum': tune.uniform(0.0, 0.95), # warmup initial momentum 24 | 'box': tune.uniform(0.02, 0.2), # box loss gain 25 | 'cls': tune.uniform(0.2, 4.0), # cls loss gain (scale with pixels) 26 | 'hsv_h': tune.uniform(0.0, 0.1), # image HSV-Hue augmentation (fraction) 27 | 'hsv_s': tune.uniform(0.0, 0.9), # image HSV-Saturation augmentation (fraction) 28 | 'hsv_v': tune.uniform(0.0, 0.9), # image HSV-Value augmentation (fraction) 29 | 'degrees': tune.uniform(0.0, 45.0), # image rotation (+/- deg) 30 | 'translate': tune.uniform(0.0, 0.9), # image translation (+/- fraction) 31 | 'scale': tune.uniform(0.0, 0.9), # image scale (+/- gain) 32 | 'shear': tune.uniform(0.0, 10.0), # image shear (+/- deg) 33 | 'perspective': tune.uniform(0.0, 0.001), # image perspective (+/- fraction), range 0-0.001 34 | 'flipud': tune.uniform(0.0, 1.0), # image flip up-down (probability) 35 | 'fliplr': tune.uniform(0.0, 1.0), # image flip left-right (probability) 36 | 'mosaic': tune.uniform(0.0, 1.0), # image mixup (probability) 37 | 'mixup': tune.uniform(0.0, 1.0), # image mixup (probability) 38 | 'copy_paste': tune.uniform(0.0, 1.0)} # segment copy-paste (probability) 39 | 40 | task_metric_map = { 41 | 'detect': 'metrics/mAP50-95(B)', 42 | 'segment': 'metrics/mAP50-95(M)', 43 | 'classify': 'metrics/accuracy_top1', 44 | 'pose': 'metrics/mAP50-95(P)'} 45 | -------------------------------------------------------------------------------- /ultralytics/yolo/v8/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from ultralytics.yolo.v8 import classify, detect, pose, segment 4 | 5 | __all__ = 'classify', 'segment', 'detect', 'pose' 6 | -------------------------------------------------------------------------------- /ultralytics/yolo/v8/classify/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from ultralytics.yolo.v8.classify.predict import ClassificationPredictor, predict 4 | from ultralytics.yolo.v8.classify.train import ClassificationTrainer, train 5 | from ultralytics.yolo.v8.classify.val import ClassificationValidator, val 6 | 7 | __all__ = 'ClassificationPredictor', 'predict', 'ClassificationTrainer', 'train', 'ClassificationValidator', 'val' 8 | -------------------------------------------------------------------------------- /ultralytics/yolo/v8/classify/predict.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | import torch 4 | 5 | from ultralytics.yolo.engine.predictor import BasePredictor 6 | from ultralytics.yolo.engine.results import Results 7 | from ultralytics.yolo.utils import DEFAULT_CFG, ROOT 8 | 9 | 10 | class ClassificationPredictor(BasePredictor): 11 | 12 | def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): 13 | super().__init__(cfg, overrides, _callbacks) 14 | self.args.task = 'classify' 15 | 16 
| def preprocess(self, img): 17 | """Converts input image to model-compatible data type.""" 18 | if not isinstance(img, torch.Tensor): 19 | img = torch.stack([self.transforms(im) for im in img], dim=0) 20 | img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device) 21 | return img.half() if self.model.fp16 else img.float() # uint8 to fp16/32 22 | 23 | def postprocess(self, preds, img, orig_imgs): 24 | """Postprocesses predictions to return Results objects.""" 25 | results = [] 26 | for i, pred in enumerate(preds): 27 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs 28 | path = self.batch[0] 29 | img_path = path[i] if isinstance(path, list) else path 30 | results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, probs=pred)) 31 | 32 | return results 33 | 34 | 35 | def predict(cfg=DEFAULT_CFG, use_python=False): 36 | """Run YOLO model predictions on input images/videos.""" 37 | model = cfg.model or 'yolov8n-cls.pt' # or "resnet18" 38 | source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \ 39 | else 'https://ultralytics.com/images/bus.jpg' 40 | 41 | args = dict(model=model, source=source) 42 | if use_python: 43 | from ultralytics import YOLO 44 | YOLO(model)(**args) 45 | else: 46 | predictor = ClassificationPredictor(overrides=args) 47 | predictor.predict_cli() 48 | 49 | 50 | if __name__ == '__main__': 51 | predict() 52 | -------------------------------------------------------------------------------- /ultralytics/yolo/v8/classify/val.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | import torch 4 | 5 | from ultralytics.yolo.data import ClassificationDataset, build_dataloader 6 | from ultralytics.yolo.engine.validator import BaseValidator 7 | from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER 8 | from ultralytics.yolo.utils.metrics import ClassifyMetrics, ConfusionMatrix 9 | from ultralytics.yolo.utils.plotting import plot_images 10 | 11 | 12 | class ClassificationValidator(BaseValidator): 13 | 14 | def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None): 15 | """Initializes ClassificationValidator instance with args, dataloader, save_dir, and progress bar.""" 16 | super().__init__(dataloader, save_dir, pbar, args, _callbacks) 17 | self.args.task = 'classify' 18 | self.metrics = ClassifyMetrics() 19 | 20 | def get_desc(self): 21 | """Returns a formatted string summarizing classification metrics.""" 22 | return ('%22s' + '%11s' * 2) % ('classes', 'top1_acc', 'top5_acc') 23 | 24 | def init_metrics(self, model): 25 | """Initialize confusion matrix, class names, and top-1 and top-5 accuracy.""" 26 | self.names = model.names 27 | self.nc = len(model.names) 28 | self.confusion_matrix = ConfusionMatrix(nc=self.nc, task='classify') 29 | self.pred = [] 30 | self.targets = [] 31 | 32 | def preprocess(self, batch): 33 | """Preprocesses input batch and returns it.""" 34 | batch['img'] = batch['img'].to(self.device, non_blocking=True) 35 | batch['img'] = batch['img'].half() if self.args.half else batch['img'].float() 36 | batch['cls'] = batch['cls'].to(self.device) 37 | return batch 38 | 39 | def update_metrics(self, preds, batch): 40 | """Updates running metrics with model predictions and batch targets.""" 41 | n5 = min(len(self.model.names), 5) 42 | self.pred.append(preds.argsort(1, descending=True)[:, :n5]) 43 | self.targets.append(batch['cls']) 44 | 45 | 
def finalize_metrics(self, *args, **kwargs): 46 | """Finalizes metrics of the model such as confusion_matrix and speed.""" 47 | self.confusion_matrix.process_cls_preds(self.pred, self.targets) 48 | if self.args.plots: 49 | for normalize in True, False: 50 | self.confusion_matrix.plot(save_dir=self.save_dir, 51 | names=self.names.values(), 52 | normalize=normalize, 53 | on_plot=self.on_plot) 54 | self.metrics.speed = self.speed 55 | self.metrics.confusion_matrix = self.confusion_matrix 56 | 57 | def get_stats(self): 58 | """Returns a dictionary of metrics obtained by processing targets and predictions.""" 59 | self.metrics.process(self.targets, self.pred) 60 | return self.metrics.results_dict 61 | 62 | def build_dataset(self, img_path): 63 | return ClassificationDataset(root=img_path, args=self.args, augment=False) 64 | 65 | def get_dataloader(self, dataset_path, batch_size): 66 | """Builds and returns a data loader for classification tasks with given parameters.""" 67 | dataset = self.build_dataset(dataset_path) 68 | return build_dataloader(dataset, batch_size, self.args.workers, rank=-1) 69 | 70 | def print_results(self): 71 | """Prints evaluation metrics for YOLO object detection model.""" 72 | pf = '%22s' + '%11.3g' * len(self.metrics.keys) # print format 73 | LOGGER.info(pf % ('all', self.metrics.top1, self.metrics.top5)) 74 | 75 | def plot_val_samples(self, batch, ni): 76 | """Plot validation image samples.""" 77 | plot_images(images=batch['img'], 78 | batch_idx=torch.arange(len(batch['img'])), 79 | cls=batch['cls'].squeeze(-1), 80 | fname=self.save_dir / f'val_batch{ni}_labels.jpg', 81 | names=self.names, 82 | on_plot=self.on_plot) 83 | 84 | def plot_predictions(self, batch, preds, ni): 85 | """Plots predicted bounding boxes on input images and saves the result.""" 86 | plot_images(batch['img'], 87 | batch_idx=torch.arange(len(batch['img'])), 88 | cls=torch.argmax(preds, dim=1), 89 | fname=self.save_dir / f'val_batch{ni}_pred.jpg', 90 | names=self.names, 91 | on_plot=self.on_plot) # pred 92 | 93 | 94 | def val(cfg=DEFAULT_CFG, use_python=False): 95 | """Validate YOLO model using custom data.""" 96 | model = cfg.model or 'yolov8n-cls.pt' # or "resnet18" 97 | data = cfg.data or 'mnist160' 98 | 99 | args = dict(model=model, data=data) 100 | if use_python: 101 | from ultralytics import YOLO 102 | YOLO(model).val(**args) 103 | else: 104 | validator = ClassificationValidator(args=args) 105 | validator(model=args['model']) 106 | 107 | 108 | if __name__ == '__main__': 109 | val() 110 | -------------------------------------------------------------------------------- /ultralytics/yolo/v8/detect/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from .predict import DetectionPredictor, predict 4 | from .train import DetectionTrainer, train 5 | from .val import DetectionValidator, val 6 | 7 | __all__ = 'DetectionPredictor', 'predict', 'DetectionTrainer', 'train', 'DetectionValidator', 'val' 8 | -------------------------------------------------------------------------------- /ultralytics/yolo/v8/detect/predict.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | import torch 4 | 5 | from ultralytics.yolo.engine.predictor import BasePredictor 6 | from ultralytics.yolo.engine.results import Results 7 | from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, ops 8 | 9 | 10 | class DetectionPredictor(BasePredictor): 11 | 12 | 
def postprocess(self, preds, img, orig_imgs): 13 | """Postprocesses predictions and returns a list of Results objects.""" 14 | preds = ops.non_max_suppression(preds, 15 | self.args.conf, 16 | self.args.iou, 17 | agnostic=self.args.agnostic_nms, 18 | max_det=self.args.max_det, 19 | classes=self.args.classes) 20 | 21 | results = [] 22 | for i, pred in enumerate(preds): 23 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs 24 | if not isinstance(orig_imgs, torch.Tensor): 25 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) 26 | path = self.batch[0] 27 | img_path = path[i] if isinstance(path, list) else path 28 | results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred)) 29 | return results 30 | 31 | 32 | def predict(cfg=DEFAULT_CFG, use_python=False): 33 | """Runs YOLO model inference on input image(s).""" 34 | model = cfg.model or 'yolov8n.pt' 35 | source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \ 36 | else 'https://ultralytics.com/images/bus.jpg' 37 | 38 | args = dict(model=model, source=source) 39 | if use_python: 40 | from ultralytics import YOLO 41 | YOLO(model)(**args) 42 | else: 43 | predictor = DetectionPredictor(overrides=args) 44 | predictor.predict_cli() 45 | 46 | 47 | if __name__ == '__main__': 48 | predict() 49 | -------------------------------------------------------------------------------- /ultralytics/yolo/v8/pose/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from .predict import PosePredictor, predict 4 | from .train import PoseTrainer, train 5 | from .val import PoseValidator, val 6 | 7 | __all__ = 'PoseTrainer', 'train', 'PoseValidator', 'val', 'PosePredictor', 'predict' 8 | -------------------------------------------------------------------------------- /ultralytics/yolo/v8/pose/predict.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from ultralytics.yolo.engine.results import Results 4 | from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, ops 5 | from ultralytics.yolo.v8.detect.predict import DetectionPredictor 6 | 7 | 8 | class PosePredictor(DetectionPredictor): 9 | 10 | def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): 11 | super().__init__(cfg, overrides, _callbacks) 12 | self.args.task = 'pose' 13 | 14 | def postprocess(self, preds, img, orig_imgs): 15 | """Return detection results for a given input image or list of images.""" 16 | preds = ops.non_max_suppression(preds, 17 | self.args.conf, 18 | self.args.iou, 19 | agnostic=self.args.agnostic_nms, 20 | max_det=self.args.max_det, 21 | classes=self.args.classes, 22 | nc=len(self.model.names)) 23 | 24 | results = [] 25 | for i, pred in enumerate(preds): 26 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs 27 | shape = orig_img.shape 28 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round() 29 | pred_kpts = pred[:, 6:].view(len(pred), *self.model.kpt_shape) if len(pred) else pred[:, 6:] 30 | pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, shape) 31 | path = self.batch[0] 32 | img_path = path[i] if isinstance(path, list) else path 33 | results.append( 34 | Results(orig_img=orig_img, 35 | path=img_path, 36 | names=self.model.names, 37 | boxes=pred[:, :6], 38 | keypoints=pred_kpts)) 39 | return results 40 | 41 | 42 | def 
predict(cfg=DEFAULT_CFG, use_python=False): 43 | """Runs YOLO to predict objects in an image or video.""" 44 | model = cfg.model or 'yolov8n-pose.pt' 45 | source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \ 46 | else 'https://ultralytics.com/images/bus.jpg' 47 | 48 | args = dict(model=model, source=source) 49 | if use_python: 50 | from ultralytics import YOLO 51 | YOLO(model)(**args) 52 | else: 53 | predictor = PosePredictor(overrides=args) 54 | predictor.predict_cli() 55 | 56 | 57 | if __name__ == '__main__': 58 | predict() 59 | -------------------------------------------------------------------------------- /ultralytics/yolo/v8/pose/train.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from copy import copy 4 | 5 | from ultralytics.nn.tasks import PoseModel 6 | from ultralytics.yolo import v8 7 | from ultralytics.yolo.utils import DEFAULT_CFG 8 | from ultralytics.yolo.utils.plotting import plot_images, plot_results 9 | 10 | 11 | # BaseTrainer python usage 12 | class PoseTrainer(v8.detect.DetectionTrainer): 13 | 14 | def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): 15 | """Initialize a PoseTrainer object with specified configurations and overrides.""" 16 | if overrides is None: 17 | overrides = {} 18 | overrides['task'] = 'pose' 19 | super().__init__(cfg, overrides, _callbacks) 20 | 21 | def get_model(self, cfg=None, weights=None, verbose=True): 22 | """Get pose estimation model with specified configuration and weights.""" 23 | model = PoseModel(cfg, ch=3, nc=self.data['nc'], data_kpt_shape=self.data['kpt_shape'], verbose=verbose) 24 | if weights: 25 | model.load(weights) 26 | 27 | return model 28 | 29 | def set_model_attributes(self): 30 | """Sets keypoints shape attribute of PoseModel.""" 31 | super().set_model_attributes() 32 | self.model.kpt_shape = self.data['kpt_shape'] 33 | 34 | def get_validator(self): 35 | """Returns an instance of the PoseValidator class for validation.""" 36 | self.loss_names = 'box_loss', 'pose_loss', 'kobj_loss', 'cls_loss', 'dfl_loss' 37 | return v8.pose.PoseValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args)) 38 | 39 | def plot_training_samples(self, batch, ni): 40 | """Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints.""" 41 | images = batch['img'] 42 | kpts = batch['keypoints'] 43 | cls = batch['cls'].squeeze(-1) 44 | bboxes = batch['bboxes'] 45 | paths = batch['im_file'] 46 | batch_idx = batch['batch_idx'] 47 | plot_images(images, 48 | batch_idx, 49 | cls, 50 | bboxes, 51 | kpts=kpts, 52 | paths=paths, 53 | fname=self.save_dir / f'train_batch{ni}.jpg', 54 | on_plot=self.on_plot) 55 | 56 | def plot_metrics(self): 57 | """Plots training/val metrics.""" 58 | plot_results(file=self.csv, pose=True, on_plot=self.on_plot) # save results.png 59 | 60 | 61 | def train(cfg=DEFAULT_CFG, use_python=False): 62 | """Train the YOLO model on the given data and device.""" 63 | model = cfg.model or 'yolov8n-pose.yaml' 64 | data = cfg.data or 'coco8-pose.yaml' 65 | device = cfg.device if cfg.device is not None else '' 66 | 67 | args = dict(model=model, data=data, device=device) 68 | if use_python: 69 | from ultralytics import YOLO 70 | YOLO(model).train(**args) 71 | else: 72 | trainer = PoseTrainer(overrides=args) 73 | trainer.train() 74 | 75 | 76 | if __name__ == '__main__': 77 | train() 78 | 
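The classify, detect, and pose task modules above all follow the same entry-point pattern: a module-level predict()/train()/val() function that either routes through the high-level YOLO wrapper when use_python=True or instantiates the task-specific Predictor/Trainer/Validator directly for CLI use. The short sketch below illustrates the Python-API path only; it assumes the pretrained yolov8n-pose.pt checkpoint and the bundled coco8-pose.yaml dataset referenced in the defaults above are available, and it is an illustration of the calling convention rather than a file from this repository.

# Hedged usage sketch (not a repository file): drive the pose task through the
# high-level YOLO API, which dispatches to the PoseTrainer and PosePredictor
# classes defined above.
from ultralytics import YOLO

model = YOLO('yolov8n-pose.pt')                # load pretrained pose weights
model.train(data='coco8-pose.yaml', epochs=1)  # PoseTrainer runs the training loop
results = model('https://ultralytics.com/images/bus.jpg')  # PosePredictor.postprocess returns Results objects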
-------------------------------------------------------------------------------- /ultralytics/yolo/v8/segment/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from .predict import SegmentationPredictor, predict 4 | from .train import SegmentationTrainer, train 5 | from .val import SegmentationValidator, val 6 | 7 | __all__ = 'SegmentationPredictor', 'predict', 'SegmentationTrainer', 'train', 'SegmentationValidator', 'val' 8 | -------------------------------------------------------------------------------- /ultralytics/yolo/v8/segment/predict.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | import torch 4 | 5 | from ultralytics.yolo.engine.results import Results 6 | from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, ops 7 | from ultralytics.yolo.v8.detect.predict import DetectionPredictor 8 | 9 | 10 | class SegmentationPredictor(DetectionPredictor): 11 | 12 | def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): 13 | super().__init__(cfg, overrides, _callbacks) 14 | self.args.task = 'segment' 15 | 16 | def postprocess(self, preds, img, orig_imgs): 17 | """TODO: filter by classes.""" 18 | p = ops.non_max_suppression(preds[0], 19 | self.args.conf, 20 | self.args.iou, 21 | agnostic=self.args.agnostic_nms, 22 | max_det=self.args.max_det, 23 | nc=len(self.model.names), 24 | classes=self.args.classes) 25 | results = [] 26 | proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported 27 | for i, pred in enumerate(p): 28 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs 29 | path = self.batch[0] 30 | img_path = path[i] if isinstance(path, list) else path 31 | if not len(pred): # save empty boxes 32 | results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6])) 33 | continue 34 | if self.args.retina_masks: 35 | if not isinstance(orig_imgs, torch.Tensor): 36 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) 37 | masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC 38 | else: 39 | masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC 40 | if not isinstance(orig_imgs, torch.Tensor): 41 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) 42 | results.append( 43 | Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks)) 44 | return results 45 | 46 | 47 | def predict(cfg=DEFAULT_CFG, use_python=False): 48 | """Runs YOLO object detection on an image or video source.""" 49 | model = cfg.model or 'yolov8n-seg.pt' 50 | source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \ 51 | else 'https://ultralytics.com/images/bus.jpg' 52 | 53 | args = dict(model=model, source=source) 54 | if use_python: 55 | from ultralytics import YOLO 56 | YOLO(model)(**args) 57 | else: 58 | predictor = SegmentationPredictor(overrides=args) 59 | predictor.predict_cli() 60 | 61 | 62 | if __name__ == '__main__': 63 | predict() 64 | -------------------------------------------------------------------------------- /ultralytics/yolo/v8/segment/train.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | from copy import copy 3 | 4 | from 
ultralytics.nn.tasks import SegmentationModel 5 | from ultralytics.yolo import v8 6 | from ultralytics.yolo.utils import DEFAULT_CFG, RANK 7 | from ultralytics.yolo.utils.plotting import plot_images, plot_results 8 | 9 | 10 | # BaseTrainer python usage 11 | class SegmentationTrainer(v8.detect.DetectionTrainer): 12 | 13 | def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): 14 | """Initialize a SegmentationTrainer object with given arguments.""" 15 | if overrides is None: 16 | overrides = {} 17 | overrides['task'] = 'segment' 18 | super().__init__(cfg, overrides, _callbacks) 19 | 20 | def get_model(self, cfg=None, weights=None, verbose=True): 21 | """Return SegmentationModel initialized with specified config and weights.""" 22 | model = SegmentationModel(cfg, ch=3, nc=self.data['nc'], verbose=verbose and RANK == -1) 23 | if weights: 24 | model.load(weights) 25 | 26 | return model 27 | 28 | def get_validator(self): 29 | """Return an instance of SegmentationValidator for validation of YOLO model.""" 30 | self.loss_names = 'box_loss', 'seg_loss', 'cls_loss', 'dfl_loss' 31 | return v8.segment.SegmentationValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args)) 32 | 33 | def plot_training_samples(self, batch, ni): 34 | """Creates a plot of training sample images with labels and box coordinates.""" 35 | plot_images(batch['img'], 36 | batch['batch_idx'], 37 | batch['cls'].squeeze(-1), 38 | batch['bboxes'], 39 | batch['masks'], 40 | paths=batch['im_file'], 41 | fname=self.save_dir / f'train_batch{ni}.jpg', 42 | on_plot=self.on_plot) 43 | 44 | def plot_metrics(self): 45 | """Plots training/val metrics.""" 46 | plot_results(file=self.csv, segment=True, on_plot=self.on_plot) # save results.png 47 | 48 | 49 | def train(cfg=DEFAULT_CFG, use_python=False): 50 | """Train a YOLO segmentation model based on passed arguments.""" 51 | model = cfg.model or 'yolov8n-seg.pt' 52 | data = cfg.data or 'coco128-seg.yaml' # or yolo.ClassificationDataset("mnist") 53 | device = cfg.device if cfg.device is not None else '' 54 | 55 | args = dict(model=model, data=data, device=device) 56 | if use_python: 57 | from ultralytics import YOLO 58 | YOLO(model).train(**args) 59 | else: 60 | trainer = SegmentationTrainer(overrides=args) 61 | trainer.train() 62 | 63 | 64 | if __name__ == '__main__': 65 | train() 66 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CASIA-IVA-Lab/FastSAM/b4ed20c2fed75eadc5aa7d8b09fedd137b873b52/utils/__init__.py --------------------------------------------------------------------------------
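The logger integrations shown earlier in this section (hub, mlflow, neptune, raytune, tensorboard, wb) all end with the same convention: a module-level callbacks dict mapping event names such as 'on_fit_epoch_end' to handler functions, with the optional integrations collapsing to {} when their dependency fails to import. The sketch below shows how a trainer-like consumer could merge those dicts; the add_integration_callbacks function and the instance.add_callback method are illustrative assumptions here, not necessarily the repository's exact API.

# Hedged consumer-side sketch (assumed, not copied from the repository):
# gather each integration's `callbacks` dict and register its handlers.
from ultralytics.yolo.utils.callbacks import mlflow, neptune, tensorboard, wb

def add_integration_callbacks(instance):
    """Attach every available integration's event handlers to `instance`."""
    for module in (mlflow, neptune, tensorboard, wb):
        # An integration whose dependency is missing exposes callbacks == {},
        # so it is skipped here without any explicit availability check.
        for event, func in module.callbacks.items():
            instance.add_callback(event, func)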