├── .gitattributes ├── .gitignore ├── 170740.mp4 ├── 377.png ├── API.py ├── Model ├── Boat-detect-medium.pt └── Boat-detect-nano.pt ├── Output ├── output.avi └── output.jpg ├── README.md ├── Ship.py ├── detect.py ├── main.py ├── requirements.txt ├── tracker ├── basetrack.py ├── byte_tracker.py ├── kalman_filter.py └── matching.py └── tracking.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.mp4 filter=lfs diff=lfs merge=lfs -text 2 | *.avi filter=lfs diff=lfs merge=lfs -text 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | .vscode/ 162 | data.yaml -------------------------------------------------------------------------------- /170740.mp4: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:a848c76f6d23301ee5b3d8ecf14e2b29a972d44189037373a706edcdb9eddde3 3 | size 13956150 4 | -------------------------------------------------------------------------------- /377.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TriNguyen317/Ship-detection-and-tracking-Yolov8/1a521b23052f80559aa762efd5db3e90dbf8365b/377.png -------------------------------------------------------------------------------- /API.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI, File, UploadFile, Form, Response 2 | import cv2 3 | from ultralytics import YOLO 4 | from ultralytics import YOLO 5 | from PIL import Image 6 | import io 7 | import numpy as np 8 | 9 | from main import Draw 10 | app = FastAPI() 11 | 12 | model = YOLO("./Model/Boat-detect-medium.pt") 13 | 14 | # def detect(model,img, conf=0.3, iou_thresh=0.45): 15 | # result = model(img, iou=iou_thresh) 16 | 17 | # boxes = result[0].boxes # Boxes object for bbox outputs 18 | # conf_detect=boxes.conf.cpu().numpy() 19 | # box_detect=boxes.xyxy.cpu().numpy() 20 | # idx=np.where(conf_detect>conf) 21 | # return box_detect[idx], conf_detect[idx] 22 | 23 | # def Draw(model,img): 24 | # img=np.array(img) 25 | # boxes, conf =detect(model,img) 26 | # for num, i in enumerate(boxes): 27 | # img=cv2.rectangle(img,(int(i[0]),int(i[1])),(int(i[2]),int(i[3])),(255, 0, 0),2) 28 | # cv2.putText(img, str(conf[num]), (int(i[0]),int(i[1]-3)),cv2.FONT_HERSHEY_SIMPLEX, 0.75 ,(255, 0, 0) ,2,cv2.LINE_AA) 29 | 30 | # #cv2.imwrite("68-detect.png",img) 31 | # cv2.waitKey(0) 32 | # cv2.destroyAllWindows() 33 | # return img 34 | 35 | 36 | 37 | @app.post("/uploadfile/") 38 | async def create_upload_file(file: UploadFile = File(...), 39 | Path_model: str = Form(default="./Model/Boat-detect-medium.pt"), 40 | imgsz: int = Form(default=640), 41 | conf: float = Form(default=0.5), 42 | iou: float = Form(default=0.45)): 43 | 44 | data=await file.read() 45 | img= Image.open(io.BytesIO(data)).convert("RGB") 46 | img=np.array(img) 47 | detect_img=Draw(model,img) 48 | 
detect_img=cv2.cvtColor(detect_img,cv2.COLOR_RGB2BGR) # PIL decodes to RGB; cv2.imencode expects BGR 49 | _,detect_img=cv2.imencode(".png",detect_img) 50 | 51 | response = Response(content=detect_img.tobytes(), media_type="image/png") 52 | 53 | return response -------------------------------------------------------------------------------- /Model/Boat-detect-medium.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TriNguyen317/Ship-detection-and-tracking-Yolov8/1a521b23052f80559aa762efd5db3e90dbf8365b/Model/Boat-detect-medium.pt -------------------------------------------------------------------------------- /Model/Boat-detect-nano.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TriNguyen317/Ship-detection-and-tracking-Yolov8/1a521b23052f80559aa762efd5db3e90dbf8365b/Model/Boat-detect-nano.pt -------------------------------------------------------------------------------- /Output/output.avi: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:8bc610d4611dc5c47b470269d34131f522a11ca3b96f79c04764942667d554a7 3 | size 112832914 4 | -------------------------------------------------------------------------------- /Output/output.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TriNguyen317/Ship-detection-and-tracking-Yolov8/1a521b23052f80559aa762efd5db3e90dbf8365b/Output/output.jpg -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Ship Detection Project 2 | 3 | This project focuses on ship detection in images and videos using computer vision techniques and the YOLO (You Only Look Once) algorithm implemented with the Ultralytics library. It provides an API for users to upload images and receive the detected ship images as a response. 4 | 5 | ## Features 6 | 7 | - Ship detection in images and videos. 8 | - FastAPI-based API for easy integration and usage. 9 | - Utilizes the YOLO algorithm for accurate ship detection. 10 | - Supports image and video inputs. 11 | - Provides bounding box visualization of detected ships. 12 | 13 | ## Table of Contents 14 | 15 | - [Installation](#installation) 16 | - [API Documentation](#api-documentation) 17 | - [Script Documentation](#script-documentation) 18 | - [Acknowledgments](#acknowledgments) 19 | - [Contact](#contact) 20 | 21 | 22 | ## Installation 23 | 24 | 1. Clone the repository: 25 | 26 | ```bash 27 | git clone https://github.com/TriNguyen317/Ship-detection-and-tracking.git 28 | 29 | ``` 30 | 31 | 2. Install the required dependencies: `pip install -r requirements.txt` 32 | 33 | 3. Download the YOLO model weights and place them in the appropriate directory: 34 | 35 | - You can download the model weights file or use the existing weights in the `Model` directory. 36 | 37 | ## API Documentation 38 | 39 | 1. Start the API server: 40 | 41 | ```bash 42 | uvicorn API:app --host 0.0.0.0 --port 8000 43 | 44 | ``` 45 | 46 | 2. Access the API at `http://localhost:8000/docs#/default/create_upload_file_uploadfile__post` and click the `Try it out` button to upload an image file containing ships. 47 | 48 | The API will process the uploaded image, run ship detection, and return the image with bounding boxes drawn around the detected ships.
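For example, once the server is running you can call the endpoint from the command line. This is a minimal sketch that assumes the default host and port and uses the sample image shipped with the repository; the output file name is arbitrary:

```bash
curl -X POST "http://localhost:8000/uploadfile/" \
  -F "file=@377.png" \
  -o detected.png
```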
49 | 50 | The API provides the following endpoint: 51 | 52 | ### Upload File \[/uploadfile/\] 53 | 54 | - Description: Uploads an image file for ship detection. 55 | - Parameters: 56 | - `file` (file): The image file to be uploaded. 57 | - `Path_model` (string, optional): The path to the YOLO model weights file. Default: `./Model/Boat-detect-medium.pt`. 58 | - `imgsz` (integer, optional): The image size for processing. Default: `640`. 59 | - `conf` (float, optional): Confidence threshold for ship detection. Default: `0.5`. 60 | - `iou` (float, optional): IOU (Intersection over Union) threshold for ship detection. Default: `0.45`. 61 | - Response: The processed image (`image/png`) with bounding boxes drawn around the detected ships. 62 | 63 | ## Script Documentation 64 | 65 | ### Command-line Arguments 66 | The project supports the following command-line arguments: 67 | 68 | - `-imgsz`: Size of the input image (default: 640) 69 | - `-input`: Path to the input file (default: "377.png") 70 | - `-output`: Path for the output file (default: "output") 71 | - `-model`: Path to the model file (default: "./Model/Boat-detect-medium.pt") 72 | - `-conf`: Score confidence threshold (default: 0.6) 73 | - `-iou_threshold`: IOU threshold (default: 0.5) 74 | - `-video`: Flag indicating that the input is a video (default: False) 75 | - `-detect`: Activate the detection task (default: True) 76 | - `-tracking`: Activate the tracking task (default: False) 77 | - `-track_buffer`: Buffer used to decide when to remove lost tracks (default: 30) 78 | - `-match_thresh`: Matching threshold for tracking (default: 0.5) 79 | - `-time_check_state`: Time between ship-state updates, in seconds (default: 1.5) 80 | - `-train`: Run the training task (default: False) 81 | - `-epoch`: Number of training epochs (default: 50) 82 | 83 | 84 | `-detect`, `-tracking`, `-video`, and `-train` are boolean switches; detection is enabled by default, so pass `-tracking -video` to track ships in a video instead. 85 | 86 | ### Dataset layout 87 | ``` 88 | - Data 89 | - train 90 | - images 91 | - label 92 | - valid 93 | - images 94 | - label 95 | - test 96 | - images 97 | - label 98 | ``` 99 | 100 | - Dataset link: https://drive.google.com/file/d/1c46R47X17maEfEUvCb8T6snlHiNLsGzx/view?usp=sharing 101 | 102 | ### Config data.yaml file 103 | ``` 104 | train: Path to the train data 105 | val: Path to the valid data 106 | test: Path to the test data 107 | 108 | nc: number of classes 109 | names: list of class names 110 | ``` 111 | A filled-in example is given at the end of this README. 112 | 113 | ### Train 114 | 115 | ```bash 116 | python main.py -train -epoch 50 117 | ``` 118 | 119 | ### Detection and Tracking 120 | 121 | ```bash 122 | python main.py -detect -input 377.png -output output 123 | ``` 124 | 125 | ```bash 126 | python main.py -tracking -video -input 170740.mp4 -output output 127 | ``` 128 | 129 | ## Acknowledgments 130 | 131 | - Ultralytics, for the YOLOv8 implementation used for detection. 132 | - ByteTrack, for the multi-object tracking algorithm adapted in `tracker/`. 133 | - The FastAPI framework for creating the API server. 134 | 135 | ## Contact 136 | 137 | For any inquiries or questions, please contact: 138 | 139 | - Project Maintainer: Nguyen Dinh Tri (dinhtrikt11102002@gmail.com) 140 | - Project Homepage: https://github.com/TriNguyen317/Ship-detection-and-tracking 141 | 142 | Feel free to reach out with any feedback or suggestions!
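## Appendix: Example `data.yaml`

A minimal, illustrative `data.yaml` for the dataset layout above. It assumes the archive is extracted to `./Data` and contains a single ship class; the paths and class name are assumptions to adapt to your setup:

```
train: ./Data/train/images
val: ./Data/valid/images
test: ./Data/test/images

nc: 1
names: ["ship"]
```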
-------------------------------------------------------------------------------- /Ship.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | import math 4 | # Class that stores a single tracked ship 5 | 6 | class Ship(): 7 | def __init__(self, tlwh, score, track_id): 8 | self.is_move = True 9 | self.bbox = tlwh 10 | self.pre_bbox = [0, 0, 0, 0] 11 | self.timestart = 0 12 | self.track_id = track_id 13 | self.frame_start = 0 14 | self.is_activate = False 15 | self.score = score 16 | self.off_frame = 0 17 | 18 | def change_state_move(self): 19 | self.is_move = not self.is_move 20 | self.timestart = time.time() 21 | 22 | def activate(self, frame_id): 23 | self.is_activate = True 24 | self.frame_start = frame_id 25 | self.off_frame = 0 26 | 27 | def update(self, bbox, score): 28 | self.bbox = bbox 29 | self.score = score 30 | 31 | def deactivate(self): 32 | self.is_activate = False 33 | self.off_frame += 1 34 | 35 | def update_bbox(self, bbox): 36 | self.pre_bbox = bbox 37 | 38 | def __repr__(self): 39 | return 'Ship_{}_{}'.format(self.track_id, self.is_move) 40 | 41 | # Class that manages the list of tracked ships 42 | 43 | 44 | class Ship_manager(): 45 | def __init__(self, time_remove=30): 46 | self.list_ship = [] 47 | self.time_remove = time_remove 48 | 49 | # Update the ship list with the tracks of the current frame 50 | def update(self, ArrayStrack, frame_id): 51 | old_ids = np.array([i.track_id for i in self.list_ship]) 52 | new_ids = np.array([i.track_id for i in ArrayStrack]) 53 | 54 | for strack in ArrayStrack: 55 | if strack.track_id in old_ids: 56 | # Known track: refresh its bbox and score 57 | exist = np.where(old_ids == strack.track_id) 58 | self.list_ship[exist[0][0]].activate(frame_id) 59 | self.list_ship[exist[0][0]].update(strack.tlwh, strack.score) 60 | else: 61 | # New track: create and activate a new ship 62 | ship = Ship(strack.tlwh, strack.score, strack.track_id) 63 | ship.activate(frame_id) 64 | self.list_ship.append(ship) 65 | 66 | for ship in self.list_ship: 67 | if ship.track_id not in new_ids: 68 | ship.deactivate() 69 | # Drop ships that have been inactive for longer than time_remove frames; 70 | # filtering after the loop avoids mutating the list while iterating over it. 71 | self.list_ship = [s for s in self.list_ship if s.off_frame <= self.time_remove] 72 | # Check the state of each ship: a high IoU (> 0.75) between the current and previous bbox means 73 | # the ship has stopped; a large shift of the bbox center (> 15 px) means it is moving.
74 | 75 | def check_state(self): 76 | for ship in self.list_ship: 77 | iou = get_iou(ship.bbox, ship.pre_bbox) 78 | dis_center = get_discenter(ship.bbox, ship.pre_bbox) 79 | if iou > 0.75: 80 | if ship.is_move: 81 | ship.change_state_move() 82 | else: 83 | if dis_center > 15: 84 | if not ship.is_move: 85 | ship.change_state_move() 86 | 87 | # Update pre_bbox 88 | def update_bbox(self, ArrayStrack): 89 | old_ids = np.array([i.track_id for i in self.list_ship]) 90 | for strack in ArrayStrack: 91 | if strack.track_id in old_ids: 92 | exist = np.where(old_ids == strack.track_id) 93 | self.list_ship[exist[0][0]].update_bbox(strack.tlwh) 94 | 95 | # Get the distance between the centers of two boxes 96 | 97 | 98 | def get_discenter(box1, box2): 99 | x1 = box1[0]+box1[2]/2 100 | y1 = box1[1]+box1[3]/2 101 | x2 = box2[0]+box2[2]/2 102 | y2 = box2[1]+box2[3]/2 103 | dis = math.sqrt((x2-x1)**2+(y2-y1)**2) 104 | return dis 105 | 106 | # Get the IoU between two boxes 107 | 108 | def get_iou(box1, box2, epsilon=1e-5): 109 | x1 = max(box1[0], box2[0]) 110 | y1 = max(box1[1], box2[1]) 111 | x2 = min(box1[0]+box1[2], box2[0]+box2[2]) 112 | y2 = min(box1[1]+box1[3], box2[1]+box2[3]) 113 | width = (x2 - x1) 114 | height = (y2 - y1) 115 | if (width < 0) or (height < 0): 116 | return 0.0 117 | area_overlap = width * height 118 | area_a = box1[2] * box1[3] 119 | area_b = box2[2] * box2[3] 120 | area_combined = area_a + area_b - area_overlap 121 | iou = area_overlap / (area_combined+epsilon) 122 | return iou 123 | 124 | # Remove the object at position pos (standalone helper) 125 | 126 | def remove(ships, pos): 127 | array = [] 128 | for i in range(len(ships)): 129 | if i != pos: 130 | array.append(ships[i]) 131 | return array 132 | -------------------------------------------------------------------------------- /detect.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | 4 | # Task: detect ships in an image 5 | 6 | def detectImg(model,img, conf=0.5, iou_thresh=0.45): 7 | result = model(img, iou=iou_thresh) 8 | 9 | boxes = result[0].boxes # Boxes object for bbox outputs 10 | conf_detect = boxes.conf.cpu().numpy() 11 | box_detect = boxes.xyxy.cpu().numpy() 12 | idx = np.where(conf_detect > conf) 13 | return box_detect[idx], np.round(conf_detect[idx], decimals=3) 14 | 15 | 16 | # Draw detections on the image 17 | def Draw(model,img): 18 | boxes, conf = detectImg(model,img) 19 | for num, i in enumerate(boxes): 20 | img = cv2.rectangle(img, (int(i[0]), int(i[1])), 21 | (int(i[2]), int(i[3])), (255, 0, 0), 2) 22 | cv2.putText(img, str(conf[num]), (int(i[0]), 23 | int(i[1]-3)), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (255, 0, 0), 2, cv2.LINE_AA) 24 | 25 | 26 | 27 | return img 28 | 29 | 30 | # Task: detect ships in a video 31 | def detectVideo(args, model): 32 | ''' 33 | Args: 34 | imgsz (int): Size of the input image. Default: 640 35 | input (str): Path of the input data. Default: 377.png 36 | output (str): Path of the output data. Default: output 37 | model (str): Path of the model. Default: ./Model/Boat-detect-medium.pt 38 | conf (float): Score confidence. Default: 0.6 39 | iou_threshold (float): IOU threshold. Default: 0.5 40 | video (bool): Input is a video. Default: False 41 | detect (bool): Task is detection. Default: False 42 | tracking (bool): Task is tracking. Default: False 43 | track_buffer (int): Buffer used to decide when to remove lost tracks. Default: 30 44 | match_thresh (float): Matching threshold for tracking in ByteTrack.
Default: 0.5 45 | time_check_state (float): Time between ship-state updates (seconds). Default: 1.5 46 | 47 | ''' 48 | video = cv2.VideoCapture(args.input) 49 | if not video.isOpened(): 50 | print("Error reading video file") 51 | frame_width = int(video.get(3)) 52 | frame_height = int(video.get(4)) 53 | size = (frame_width, frame_height) 54 | print(size) 55 | result = cv2.VideoWriter("./Output/"+args.output+".avi", 56 | cv2.VideoWriter_fourcc(*'MJPG'), 57 | 30, size) 58 | while True: 59 | ret, frame = video.read() 60 | if ret: 61 | frame = Draw(model,frame) 62 | result.write(frame) 63 | if cv2.waitKey(1) & 0xFF == ord('s'): 64 | break 65 | else: 66 | break 67 | video.release() 68 | result.release() 69 | cv2.destroyAllWindows() 70 | print("Ship detection on the video finished") 71 | print("The output video was successfully saved") -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import numpy as np 4 | import cv2 5 | from ultralytics import YOLO 6 | 7 | from detect import * 8 | from tracking import Tracking 9 | 10 | device = ( 11 | "cuda" 12 | if torch.cuda.is_available() 13 | else "mps" 14 | if torch.backends.mps.is_available() 15 | else "cpu" 16 | ) 17 | print(f"Using {device} device") 18 | 19 | 20 | if __name__ == "__main__": 21 | ''' 22 | Args: 23 | imgsz (int): Size of the input image. Default: 640 24 | input (str): Path of the input data. Default: 377.png 25 | output (str): Path of the output data. Default: output 26 | model (str): Path of the model. Default: ./Model/Boat-detect-medium.pt 27 | conf (float): Score confidence. Default: 0.6 28 | iou_threshold (float): IOU threshold. Default: 0.5 29 | video (bool): Input is a video. Default: False 30 | detect (bool): Task is detection. Default: True 31 | tracking (bool): Task is tracking. Default: False 32 | track_buffer (int): Buffer used to decide when to remove lost tracks. Default: 30 33 | match_thresh (float): Matching threshold for tracking in ByteTrack. Default: 0.5 34 | time_check_state (float): Time between ship-state updates (seconds). Default: 1.5 35 | train (bool): Task is training. Default: False 36 | epoch (int): Number of epochs.
Default: 50 37 | ''' 38 | parser = argparse.ArgumentParser(prog='Boat-detect', 39 | description='Ship detection and tracking with YOLOv8') 40 | parser.add_argument("-imgsz", type=int, 41 | default=640, help="Input image size") 42 | parser.add_argument("-input", type=str, 43 | default="377.png", help="Path of input data") 44 | parser.add_argument("-output", type=str, 45 | default="output", help="Path of output data") 46 | parser.add_argument("-model", type=str, 47 | default="./Model/Boat-detect-medium.pt", help="Path of model") 48 | parser.add_argument("-conf", type=float, 49 | default=0.6, help="Score confidence") 50 | parser.add_argument("-iou_threshold", type=float, 51 | default=0.5, help="IOU threshold") 52 | parser.add_argument("-video", type=bool, action=argparse.BooleanOptionalAction, 53 | default=False, help="Input is a video") 54 | parser.add_argument("-detect", type=bool, action=argparse.BooleanOptionalAction, 55 | default=True, help="Activate the detection task") 56 | parser.add_argument("-tracking", type=bool, action=argparse.BooleanOptionalAction, 57 | default=False, help="Activate the tracking task") 58 | parser.add_argument("-track_buffer", type=int, 59 | default=30, help="Buffer used to decide when to remove lost tracks") 60 | parser.add_argument("-match_thresh", type=float, 61 | default=0.5, help="Matching threshold for tracking in ByteTrack") 62 | parser.add_argument("-time_check_state", type=float, 63 | default=1.5, help="Time between ship-state updates") 64 | parser.add_argument("-train", type=bool, action=argparse.BooleanOptionalAction, 65 | default=False, help="Run the training task") 66 | parser.add_argument("-epoch", type=int, 67 | default=50, help="Number of epochs") 68 | 69 | args = parser.parse_args() 70 | use_gpu = torch.cuda.is_available() 71 | if use_gpu: 72 | torch.cuda.empty_cache() 73 | # Read the model from the .pt file 74 | model = YOLO(args.model) 75 | if args.train: 76 | model.train(data="data.yaml", epochs=args.epoch, imgsz=args.imgsz, single_cls=True) 77 | else: 78 | # Detect and track on video 79 | if args.video: 80 | if args.tracking: 81 | Tracking(args, model) 82 | else: 83 | detectVideo(args, model) 84 | # Detect on image 85 | else: 86 | img = cv2.imread(args.input) 87 | detect_img = Draw(model, img) 88 | cv2.imwrite("./Output/"+args.output+".jpg", detect_img) 89 | print("Ship detection on the image finished") 90 | print("The output image was successfully saved") 91 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TriNguyen317/Ship-detection-and-tracking-Yolov8/1a521b23052f80559aa762efd5db3e90dbf8365b/requirements.txt -------------------------------------------------------------------------------- /tracker/basetrack.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import OrderedDict 3 | 4 | 5 | class TrackState(object): 6 | New = 0 7 | Tracked = 1 8 | Lost = 2 9 | Removed = 3 10 | 11 | 12 | class BaseTrack(object): 13 | _count = 0 14 | 15 | track_id = 0 16 | is_activated = False 17 | state = TrackState.New 18 | 19 | history = OrderedDict() 20 | features = [] 21 | curr_feature = None 22 | score = 0 23 | start_frame = 0 24 | frame_id = 0 25 | time_since_update = 0 26 | 27 | # multi-camera 28 | location = (np.inf, np.inf) 29 | 30 | @property 31 | def end_frame(self): 32 | return self.frame_id 33 | 34 | @staticmethod 35 | def
next_id(): 36 | BaseTrack._count += 1 37 | return BaseTrack._count 38 | 39 | def activate(self, *args): 40 | raise NotImplementedError 41 | 42 | def predict(self): 43 | raise NotImplementedError 44 | 45 | def update(self, *args, **kwargs): 46 | raise NotImplementedError 47 | 48 | def mark_lost(self): 49 | self.state = TrackState.Lost 50 | 51 | def mark_removed(self): 52 | self.state = TrackState.Removed 53 | -------------------------------------------------------------------------------- /tracker/byte_tracker.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import deque 3 | import os 4 | import os.path as osp 5 | import copy 6 | import torch 7 | import torch.nn.functional as F 8 | import time 9 | 10 | from tracker.kalman_filter import KalmanFilter 11 | import tracker.matching as matching 12 | from tracker.basetrack import BaseTrack, TrackState 13 | 14 | class STrack(BaseTrack): 15 | shared_kalman = KalmanFilter() 16 | def __init__(self, tlwh, score): 17 | 18 | # wait activate 19 | self._tlwh = np.asarray(tlwh, dtype=np.float64) 20 | self.kalman_filter = None 21 | self.mean, self.covariance = None, None 22 | self.is_activated = False 23 | 24 | self.score = score 25 | self.tracklet_len = 0 26 | 27 | def predict(self): 28 | mean_state = self.mean.copy() 29 | if self.state != TrackState.Tracked: 30 | mean_state[7] = 0 31 | self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance) 32 | 33 | @staticmethod 34 | 35 | def multi_predict(stracks): 36 | if len(stracks) > 0: 37 | multi_mean = np.asarray([st.mean.copy() for st in stracks]) 38 | multi_covariance = np.asarray([st.covariance for st in stracks]) 39 | for i, st in enumerate(stracks): 40 | if st.state != TrackState.Tracked: 41 | multi_mean[i][7] = 0 42 | multi_mean, multi_covariance = STrack.shared_kalman.multi_predict(multi_mean, multi_covariance) 43 | for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)): 44 | stracks[i].mean = mean 45 | stracks[i].covariance = cov 46 | 47 | def activate(self, kalman_filter, frame_id): 48 | """Start a new tracklet""" 49 | self.kalman_filter = kalman_filter 50 | self.track_id = self.next_id() 51 | self.mean, self.covariance = self.kalman_filter.initiate(self.tlwh_to_xyah(self._tlwh)) 52 | 53 | self.tracklet_len = 0 54 | self.state = TrackState.Tracked 55 | if frame_id == 1: 56 | self.is_activated = True 57 | # self.is_activated = True 58 | self.frame_id = frame_id 59 | self.start_frame = frame_id 60 | 61 | def re_activate(self, new_track, frame_id, new_id=False): 62 | self.mean, self.covariance = self.kalman_filter.update( 63 | self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh) 64 | ) 65 | self.tracklet_len = 0 66 | self.state = TrackState.Tracked 67 | self.is_activated = True 68 | self.frame_id = frame_id 69 | if new_id: 70 | self.track_id = self.next_id() 71 | self.score = new_track.score 72 | 73 | def update(self, new_track, frame_id): 74 | """ 75 | Update a matched track 76 | :type new_track: STrack 77 | :type frame_id: int 78 | :type update_feature: bool 79 | :return: 80 | """ 81 | self.frame_id = frame_id 82 | self.tracklet_len += 1 83 | 84 | new_tlwh = new_track.tlwh 85 | self.mean, self.covariance = self.kalman_filter.update( 86 | self.mean, self.covariance, self.tlwh_to_xyah(new_tlwh)) 87 | self.state = TrackState.Tracked 88 | self.is_activated = True 89 | 90 | self.score = new_track.score 91 | 92 | @property 93 | # @jit(nopython=True) 94 | def tlwh(self): 95 | """Get 
current position in bounding box format `(top left x, top left y, 96 | width, height)`. 97 | """ 98 | if self.mean is None: 99 | return self._tlwh.copy() 100 | ret = self.mean[:4].copy() 101 | ret[2] *= ret[3] 102 | ret[:2] -= ret[2:] / 2 103 | return ret 104 | 105 | @property 106 | # @jit(nopython=True) 107 | def tlbr(self): 108 | """Convert bounding box to format `(min x, min y, max x, max y)`, i.e., 109 | `(top left, bottom right)`. 110 | """ 111 | ret = self.tlwh.copy() 112 | ret[2:] += ret[:2] 113 | return ret 114 | 115 | @staticmethod 116 | # @jit(nopython=True) 117 | def tlwh_to_xyah(tlwh): 118 | """Convert bounding box to format `(center x, center y, aspect ratio, 119 | height)`, where the aspect ratio is `width / height`. 120 | """ 121 | ret = np.asarray(tlwh).copy() 122 | ret[:2] += ret[2:] / 2 123 | ret[2] /= ret[3] 124 | return ret 125 | 126 | def to_xyah(self): 127 | return self.tlwh_to_xyah(self.tlwh) 128 | 129 | @staticmethod 130 | # @jit(nopython=True) 131 | def tlbr_to_tlwh(tlbr): 132 | ret = np.asarray(tlbr).copy() 133 | ret[2:] -= ret[:2] 134 | return ret 135 | 136 | @staticmethod 137 | # @jit(nopython=True) 138 | def tlwh_to_tlbr(tlwh): 139 | ret = np.asarray(tlwh).copy() 140 | ret[2:] += ret[:2] 141 | return ret 142 | 143 | def __repr__(self): 144 | return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, self.end_frame) 145 | 146 | 147 | class BYTETracker(object): 148 | def __init__(self, args, frame_rate=30): 149 | self.tracked_stracks = [] # type: list[STrack] 150 | self.lost_stracks = [] # type: list[STrack] 151 | self.removed_stracks = [] # type: list[STrack] 152 | 153 | self.frame_id = 0 154 | self.args = args 155 | self.det_thresh = args.conf 156 | # self.det_thresh = args.conf + 0.1 157 | self.buffer_size = int(frame_rate / 30.0 * args.track_buffer) 158 | self.max_time_lost = self.buffer_size 159 | self.kalman_filter = KalmanFilter() 160 | 161 | def update(self, scores, bboxes, img_info, img_size): 162 | self.frame_id += 1 163 | activated_starcks = [] 164 | refind_stracks = [] 165 | lost_stracks = [] 166 | removed_stracks = [] 167 | 168 | # if output_results.shape[1] == 5: 169 | # scores = output_results[:, 4] 170 | # bboxes = output_results[:, :4] 171 | # else: 172 | # output_results = output_results.cpu().numpy() 173 | # scores = output_results[:, 4] * output_results[:, 5] 174 | # bboxes = output_results[:, :4] # x1y1x2y2 175 | img_h, img_w = img_info[0], img_info[1] 176 | scale = min(img_size[0] / float(img_h), img_size[1] / float(img_w)) 177 | bboxes /= scale 178 | 179 | remain_inds = scores > self.args.conf 180 | inds_low = scores > 0.1 181 | inds_high = scores < self.args.conf 182 | 183 | inds_second = np.logical_and(inds_low, inds_high) 184 | dets_second = bboxes[inds_second] 185 | dets = bboxes[remain_inds] 186 | scores_keep = scores[remain_inds] 187 | scores_second = scores[inds_second] 188 | 189 | if len(dets) > 0: 190 | '''Detections''' 191 | detections = [STrack(STrack.tlbr_to_tlwh(tlbr), s) for 192 | (tlbr, s) in zip(dets, scores_keep)] 193 | else: 194 | detections = [] 195 | 196 | ''' Add newly detected tracklets to tracked_stracks''' 197 | unconfirmed = [] 198 | tracked_stracks = [] # type: list[STrack] 199 | for track in self.tracked_stracks: 200 | if not track.is_activated: 201 | unconfirmed.append(track) 202 | else: 203 | tracked_stracks.append(track) 204 | 205 | ''' Step 2: First association, with high score detection boxes''' 206 | strack_pool = joint_stracks(tracked_stracks, self.lost_stracks) 207 | # Predict the current location with 
KF 208 | STrack.multi_predict(strack_pool) 209 | dists = matching.iou_distance(strack_pool, detections) 210 | # if not self.args.mot20: 211 | dists = matching.fuse_score(dists, detections) 212 | matches, u_track, u_detection = matching.linear_assignment(dists, thresh=self.args.match_thresh) 213 | 214 | for itracked, idet in matches: 215 | track = strack_pool[itracked] 216 | det = detections[idet] 217 | if track.state == TrackState.Tracked: 218 | track.update(detections[idet], self.frame_id) 219 | activated_starcks.append(track) 220 | else: 221 | track.re_activate(det, self.frame_id, new_id=False) 222 | refind_stracks.append(track) 223 | ''' Step 3: Second association, with low score detection boxes''' 224 | # association the untrack to the low score detections 225 | if len(dets_second) > 0: 226 | '''Detections''' 227 | detections_second = [STrack(STrack.tlbr_to_tlwh(tlbr), s) for 228 | (tlbr, s) in zip(dets_second, scores_second)] 229 | else: 230 | detections_second = [] 231 | r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked] 232 | dists = matching.iou_distance(r_tracked_stracks, detections_second) 233 | 234 | 235 | matches, u_track, u_detection_second = matching.linear_assignment(dists, thresh=0.5) 236 | for itracked, idet in matches: 237 | track = r_tracked_stracks[itracked] 238 | det = detections_second[idet] 239 | if track.state == TrackState.Tracked: 240 | track.update(det, self.frame_id) 241 | activated_starcks.append(track) 242 | else: 243 | track.re_activate(det, self.frame_id, new_id=False) 244 | refind_stracks.append(track) 245 | 246 | for it in u_track: 247 | track = r_tracked_stracks[it] 248 | if not track.state == TrackState.Lost: 249 | track.mark_lost() 250 | lost_stracks.append(track) 251 | 252 | '''Deal with unconfirmed tracks, usually tracks with only one beginning frame''' 253 | detections = [detections[i] for i in u_detection] 254 | dists = matching.iou_distance(unconfirmed, detections) 255 | # if not self.args.mot20: 256 | dists = matching.fuse_score(dists, detections) 257 | matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.5) 258 | for itracked, idet in matches: 259 | unconfirmed[itracked].update(detections[idet], self.frame_id) 260 | activated_starcks.append(unconfirmed[itracked]) 261 | for it in u_unconfirmed: 262 | track = unconfirmed[it] 263 | track.mark_removed() 264 | removed_stracks.append(track) 265 | 266 | """ Step 4: Init new stracks""" 267 | for inew in u_detection: 268 | track = detections[inew] 269 | if track.score < self.det_thresh: 270 | continue 271 | track.activate(self.kalman_filter, self.frame_id) 272 | activated_starcks.append(track) 273 | """ Step 5: Update state""" 274 | for track in self.lost_stracks: 275 | if self.frame_id - track.end_frame > self.max_time_lost: 276 | track.mark_removed() 277 | removed_stracks.append(track) 278 | 279 | # print('Ramained match {} s'.format(t4-t3)) 280 | 281 | self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked] 282 | self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_starcks) 283 | self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks) 284 | self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks) 285 | self.lost_stracks.extend(lost_stracks) 286 | self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks) 287 | self.removed_stracks.extend(removed_stracks) 288 | self.tracked_stracks, self.lost_stracks = 
remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks) 289 | # get scores of lost tracks 290 | output_stracks = [track for track in self.tracked_stracks]# if track.is_activated] 291 | 292 | 293 | 294 | return output_stracks 295 | 296 | 297 | def joint_stracks(tlista, tlistb): 298 | exists = {} 299 | res = [] 300 | for t in tlista: 301 | exists[t.track_id] = 1 302 | res.append(t) 303 | for t in tlistb: 304 | tid = t.track_id 305 | if not exists.get(tid, 0): 306 | exists[tid] = 1 307 | res.append(t) 308 | return res 309 | 310 | 311 | def sub_stracks(tlista, tlistb): 312 | stracks = {} 313 | for t in tlista: 314 | stracks[t.track_id] = t 315 | for t in tlistb: 316 | tid = t.track_id 317 | if stracks.get(tid, 0): 318 | del stracks[tid] 319 | return list(stracks.values()) 320 | 321 | 322 | def remove_duplicate_stracks(stracksa, stracksb): 323 | pdist = matching.iou_distance(stracksa, stracksb) 324 | pairs = np.where(pdist < 0.15) 325 | dupa, dupb = list(), list() 326 | for p, q in zip(*pairs): 327 | timep = stracksa[p].frame_id - stracksa[p].start_frame 328 | timeq = stracksb[q].frame_id - stracksb[q].start_frame 329 | if timep > timeq: 330 | dupb.append(q) 331 | else: 332 | dupa.append(p) 333 | resa = [t for i, t in enumerate(stracksa) if not i in dupa] 334 | resb = [t for i, t in enumerate(stracksb) if not i in dupb] 335 | return resa, resb 336 | -------------------------------------------------------------------------------- /tracker/kalman_filter.py: -------------------------------------------------------------------------------- 1 | # vim: expandtab:ts=4:sw=4 2 | import numpy as np 3 | import scipy.linalg 4 | 5 | 6 | """ 7 | Table for the 0.95 quantile of the chi-square distribution with N degrees of 8 | freedom (contains values for N=1, ..., 9). Taken from MATLAB/Octave's chi2inv 9 | function and used as Mahalanobis gating threshold. 10 | """ 11 | chi2inv95 = { 12 | 1: 3.8415, 13 | 2: 5.9915, 14 | 3: 7.8147, 15 | 4: 9.4877, 16 | 5: 11.070, 17 | 6: 12.592, 18 | 7: 14.067, 19 | 8: 15.507, 20 | 9: 16.919} 21 | 22 | 23 | class KalmanFilter(object): 24 | """ 25 | A simple Kalman filter for tracking bounding boxes in image space. 26 | 27 | The 8-dimensional state space 28 | 29 | x, y, a, h, vx, vy, va, vh 30 | 31 | contains the bounding box center position (x, y), aspect ratio a, height h, 32 | and their respective velocities. 33 | 34 | Object motion follows a constant velocity model. The bounding box location 35 | (x, y, a, h) is taken as direct observation of the state space (linear 36 | observation model). 37 | 38 | """ 39 | 40 | def __init__(self): 41 | ndim, dt = 4, 1. 42 | 43 | # Create Kalman filter model matrices. 44 | self._motion_mat = np.eye(2 * ndim, 2 * ndim) 45 | for i in range(ndim): 46 | self._motion_mat[i, ndim + i] = dt 47 | self._update_mat = np.eye(ndim, 2 * ndim) 48 | 49 | # Motion and observation uncertainty are chosen relative to the current 50 | # state estimate. These weights control the amount of uncertainty in 51 | # the model. This is a bit hacky. 52 | self._std_weight_position = 1. / 20 53 | self._std_weight_velocity = 1. / 160 54 | 55 | def initiate(self, measurement): 56 | """Create track from unassociated measurement. 57 | 58 | Parameters 59 | ---------- 60 | measurement : ndarray 61 | Bounding box coordinates (x, y, a, h) with center position (x, y), 62 | aspect ratio a, and height h. 63 | 64 | Returns 65 | ------- 66 | (ndarray, ndarray) 67 | Returns the mean vector (8 dimensional) and covariance matrix (8x8 68 | dimensional) of the new track. 
Unobserved velocities are initialized 69 | to 0 mean. 70 | 71 | """ 72 | mean_pos = measurement 73 | mean_vel = np.zeros_like(mean_pos) 74 | mean = np.r_[mean_pos, mean_vel] 75 | 76 | std = [ 77 | 2 * self._std_weight_position * measurement[3], 78 | 2 * self._std_weight_position * measurement[3], 79 | 1e-2, 80 | 2 * self._std_weight_position * measurement[3], 81 | 10 * self._std_weight_velocity * measurement[3], 82 | 10 * self._std_weight_velocity * measurement[3], 83 | 1e-5, 84 | 10 * self._std_weight_velocity * measurement[3]] 85 | covariance = np.diag(np.square(std)) 86 | return mean, covariance 87 | 88 | def predict(self, mean, covariance): 89 | """Run Kalman filter prediction step. 90 | 91 | Parameters 92 | ---------- 93 | mean : ndarray 94 | The 8 dimensional mean vector of the object state at the previous 95 | time step. 96 | covariance : ndarray 97 | The 8x8 dimensional covariance matrix of the object state at the 98 | previous time step. 99 | 100 | Returns 101 | ------- 102 | (ndarray, ndarray) 103 | Returns the mean vector and covariance matrix of the predicted 104 | state. Unobserved velocities are initialized to 0 mean. 105 | 106 | """ 107 | std_pos = [ 108 | self._std_weight_position * mean[3], 109 | self._std_weight_position * mean[3], 110 | 1e-2, 111 | self._std_weight_position * mean[3]] 112 | std_vel = [ 113 | self._std_weight_velocity * mean[3], 114 | self._std_weight_velocity * mean[3], 115 | 1e-5, 116 | self._std_weight_velocity * mean[3]] 117 | motion_cov = np.diag(np.square(np.r_[std_pos, std_vel])) 118 | 119 | #mean = np.dot(self._motion_mat, mean) 120 | mean = np.dot(mean, self._motion_mat.T) 121 | covariance = np.linalg.multi_dot(( 122 | self._motion_mat, covariance, self._motion_mat.T)) + motion_cov 123 | 124 | return mean, covariance 125 | 126 | def project(self, mean, covariance): 127 | """Project state distribution to measurement space. 128 | 129 | Parameters 130 | ---------- 131 | mean : ndarray 132 | The state's mean vector (8 dimensional array). 133 | covariance : ndarray 134 | The state's covariance matrix (8x8 dimensional). 135 | 136 | Returns 137 | ------- 138 | (ndarray, ndarray) 139 | Returns the projected mean and covariance matrix of the given state 140 | estimate. 141 | 142 | """ 143 | std = [ 144 | self._std_weight_position * mean[3], 145 | self._std_weight_position * mean[3], 146 | 1e-1, 147 | self._std_weight_position * mean[3]] 148 | innovation_cov = np.diag(np.square(std)) 149 | 150 | mean = np.dot(self._update_mat, mean) 151 | covariance = np.linalg.multi_dot(( 152 | self._update_mat, covariance, self._update_mat.T)) 153 | return mean, covariance + innovation_cov 154 | 155 | def multi_predict(self, mean, covariance): 156 | """Run Kalman filter prediction step (Vectorized version). 157 | Parameters 158 | ---------- 159 | mean : ndarray 160 | The Nx8 dimensional mean matrix of the object states at the previous 161 | time step. 162 | covariance : ndarray 163 | The Nx8x8 dimensional covariance matrics of the object states at the 164 | previous time step. 165 | Returns 166 | ------- 167 | (ndarray, ndarray) 168 | Returns the mean vector and covariance matrix of the predicted 169 | state. Unobserved velocities are initialized to 0 mean. 
170 | """ 171 | std_pos = [ 172 | self._std_weight_position * mean[:, 3], 173 | self._std_weight_position * mean[:, 3], 174 | 1e-2 * np.ones_like(mean[:, 3]), 175 | self._std_weight_position * mean[:, 3]] 176 | std_vel = [ 177 | self._std_weight_velocity * mean[:, 3], 178 | self._std_weight_velocity * mean[:, 3], 179 | 1e-5 * np.ones_like(mean[:, 3]), 180 | self._std_weight_velocity * mean[:, 3]] 181 | sqr = np.square(np.r_[std_pos, std_vel]).T 182 | 183 | motion_cov = [] 184 | for i in range(len(mean)): 185 | motion_cov.append(np.diag(sqr[i])) 186 | motion_cov = np.asarray(motion_cov) 187 | 188 | mean = np.dot(mean, self._motion_mat.T) 189 | left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2)) 190 | covariance = np.dot(left, self._motion_mat.T) + motion_cov 191 | 192 | return mean, covariance 193 | 194 | def update(self, mean, covariance, measurement): 195 | """Run Kalman filter correction step. 196 | 197 | Parameters 198 | ---------- 199 | mean : ndarray 200 | The predicted state's mean vector (8 dimensional). 201 | covariance : ndarray 202 | The state's covariance matrix (8x8 dimensional). 203 | measurement : ndarray 204 | The 4 dimensional measurement vector (x, y, a, h), where (x, y) 205 | is the center position, a the aspect ratio, and h the height of the 206 | bounding box. 207 | 208 | Returns 209 | ------- 210 | (ndarray, ndarray) 211 | Returns the measurement-corrected state distribution. 212 | 213 | """ 214 | projected_mean, projected_cov = self.project(mean, covariance) 215 | 216 | chol_factor, lower = scipy.linalg.cho_factor( 217 | projected_cov, lower=True, check_finite=False) 218 | kalman_gain = scipy.linalg.cho_solve( 219 | (chol_factor, lower), np.dot(covariance, self._update_mat.T).T, 220 | check_finite=False).T 221 | innovation = measurement - projected_mean 222 | 223 | new_mean = mean + np.dot(innovation, kalman_gain.T) 224 | new_covariance = covariance - np.linalg.multi_dot(( 225 | kalman_gain, projected_cov, kalman_gain.T)) 226 | return new_mean, new_covariance 227 | 228 | def gating_distance(self, mean, covariance, measurements, 229 | only_position=False, metric='maha'): 230 | """Compute gating distance between state distribution and measurements. 231 | A suitable distance threshold can be obtained from `chi2inv95`. If 232 | `only_position` is False, the chi-square distribution has 4 degrees of 233 | freedom, otherwise 2. 234 | Parameters 235 | ---------- 236 | mean : ndarray 237 | Mean vector over the state distribution (8 dimensional). 238 | covariance : ndarray 239 | Covariance of the state distribution (8x8 dimensional). 240 | measurements : ndarray 241 | An Nx4 dimensional matrix of N measurements, each in 242 | format (x, y, a, h) where (x, y) is the bounding box center 243 | position, a the aspect ratio, and h the height. 244 | only_position : Optional[bool] 245 | If True, distance computation is done with respect to the bounding 246 | box center position only. 247 | Returns 248 | ------- 249 | ndarray 250 | Returns an array of length N, where the i-th element contains the 251 | squared Mahalanobis distance between (mean, covariance) and 252 | `measurements[i]`. 
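Note that for `metric='maha'` the squared distance is computed via a Cholesky factorization and a triangular solve, which avoids explicitly inverting the projected covariance.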
253 | """ 254 | mean, covariance = self.project(mean, covariance) 255 | if only_position: 256 | mean, covariance = mean[:2], covariance[:2, :2] 257 | measurements = measurements[:, :2] 258 | 259 | d = measurements - mean 260 | if metric == 'gaussian': 261 | return np.sum(d * d, axis=1) 262 | elif metric == 'maha': 263 | cholesky_factor = np.linalg.cholesky(covariance) 264 | z = scipy.linalg.solve_triangular( 265 | cholesky_factor, d.T, lower=True, check_finite=False, 266 | overwrite_b=True) 267 | squared_maha = np.sum(z * z, axis=0) 268 | return squared_maha 269 | else: 270 | raise ValueError('invalid distance metric') -------------------------------------------------------------------------------- /tracker/matching.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import scipy 4 | import lap 5 | from scipy.spatial.distance import cdist 6 | 7 | from cython_bbox import bbox_overlaps as bbox_ious 8 | from tracker.kalman_filter import KalmanFilter, chi2inv95 9 | import time 10 | 11 | def merge_matches(m1, m2, shape): 12 | O,P,Q = shape 13 | m1 = np.asarray(m1) 14 | m2 = np.asarray(m2) 15 | 16 | M1 = scipy.sparse.coo_matrix((np.ones(len(m1)), (m1[:, 0], m1[:, 1])), shape=(O, P)) 17 | M2 = scipy.sparse.coo_matrix((np.ones(len(m2)), (m2[:, 0], m2[:, 1])), shape=(P, Q)) 18 | 19 | mask = M1*M2 20 | match = mask.nonzero() 21 | match = list(zip(match[0], match[1])) 22 | unmatched_O = tuple(set(range(O)) - set([i for i, j in match])) 23 | unmatched_Q = tuple(set(range(Q)) - set([j for i, j in match])) 24 | 25 | return match, unmatched_O, unmatched_Q 26 | 27 | 28 | def _indices_to_matches(cost_matrix, indices, thresh): 29 | matched_cost = cost_matrix[tuple(zip(*indices))] 30 | matched_mask = (matched_cost <= thresh) 31 | 32 | matches = indices[matched_mask] 33 | unmatched_a = tuple(set(range(cost_matrix.shape[0])) - set(matches[:, 0])) 34 | unmatched_b = tuple(set(range(cost_matrix.shape[1])) - set(matches[:, 1])) 35 | 36 | return matches, unmatched_a, unmatched_b 37 | 38 | 39 | def linear_assignment(cost_matrix, thresh): 40 | if cost_matrix.size == 0: 41 | return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1])) 42 | matches, unmatched_a, unmatched_b = [], [], [] 43 | cost, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh) 44 | for ix, mx in enumerate(x): 45 | if mx >= 0: 46 | matches.append([ix, mx]) 47 | unmatched_a = np.where(x < 0)[0] 48 | unmatched_b = np.where(y < 0)[0] 49 | matches = np.asarray(matches) 50 | return matches, unmatched_a, unmatched_b 51 | 52 | 53 | def ious(atlbrs, btlbrs): 54 | """ 55 | Compute cost based on IoU 56 | :type atlbrs: list[tlbr] | np.ndarray 57 | :type btlbrs: list[tlbr] | np.ndarray 58 | 59 | :rtype ious np.ndarray 60 | """ 61 | ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float64) 62 | if ious.size == 0: 63 | return ious 64 | 65 | ious = bbox_ious( 66 | np.ascontiguousarray(atlbrs, dtype=np.float64), 67 | np.ascontiguousarray(btlbrs, dtype=np.float64) 68 | ) 69 | 70 | return ious 71 | 72 | 73 | def iou_distance(atracks, btracks): 74 | """ 75 | Compute cost based on IoU 76 | :type atracks: list[STrack] 77 | :type btracks: list[STrack] 78 | 79 | :rtype cost_matrix np.ndarray 80 | """ 81 | 82 | if (len(atracks)>0 and isinstance(atracks[0], np.ndarray)) or (len(btracks) > 0 and isinstance(btracks[0], np.ndarray)): 83 | atlbrs = atracks 84 | btlbrs = btracks 85 | else: 86 | atlbrs = [track.tlbr for track in atracks] 87 | btlbrs =
[track.tlbr for track in btracks] 88 | _ious = ious(atlbrs, btlbrs) 89 | cost_matrix = 1 - _ious 90 | 91 | return cost_matrix 92 | 93 | def v_iou_distance(atracks, btracks): 94 | """ 95 | Compute cost based on IoU 96 | :type atracks: list[STrack] 97 | :type btracks: list[STrack] 98 | 99 | :rtype cost_matrix np.ndarray 100 | """ 101 | 102 | if (len(atracks)>0 and isinstance(atracks[0], np.ndarray)) or (len(btracks) > 0 and isinstance(btracks[0], np.ndarray)): 103 | atlbrs = atracks 104 | btlbrs = btracks 105 | else: 106 | atlbrs = [track.tlwh_to_tlbr(track.pred_bbox) for track in atracks] 107 | btlbrs = [track.tlwh_to_tlbr(track.pred_bbox) for track in btracks] 108 | _ious = ious(atlbrs, btlbrs) 109 | cost_matrix = 1 - _ious 110 | 111 | return cost_matrix 112 | 113 | def embedding_distance(tracks, detections, metric='cosine'): 114 | """ 115 | :param tracks: list[STrack] 116 | :param detections: list[BaseTrack] 117 | :param metric: 118 | :return: cost_matrix np.ndarray 119 | """ 120 | 121 | cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float64) 122 | if cost_matrix.size == 0: 123 | return cost_matrix 124 | det_features = np.asarray([track.curr_feat for track in detections], dtype=np.float64) 125 | #for i, track in enumerate(tracks): 126 | #cost_matrix[i, :] = np.maximum(0.0, cdist(track.smooth_feat.reshape(1,-1), det_features, metric)) 127 | track_features = np.asarray([track.smooth_feat for track in tracks], dtype=np.float64) 128 | cost_matrix = np.maximum(0.0, cdist(track_features, det_features, metric)) # Normalized features 129 | return cost_matrix 130 | 131 | 132 | def gate_cost_matrix(kf, cost_matrix, tracks, detections, only_position=False): 133 | if cost_matrix.size == 0: 134 | return cost_matrix 135 | gating_dim = 2 if only_position else 4 136 | gating_threshold = chi2inv95[gating_dim] 137 | measurements = np.asarray([det.to_xyah() for det in detections]) 138 | for row, track in enumerate(tracks): 139 | gating_distance = kf.gating_distance( 140 | track.mean, track.covariance, measurements, only_position) 141 | cost_matrix[row, gating_distance > gating_threshold] = np.inf 142 | return cost_matrix 143 | 144 | 145 | def fuse_motion(kf, cost_matrix, tracks, detections, only_position=False, lambda_=0.98): 146 | if cost_matrix.size == 0: 147 | return cost_matrix 148 | gating_dim = 2 if only_position else 4 149 | gating_threshold = chi2inv95[gating_dim] 150 | measurements = np.asarray([det.to_xyah() for det in detections]) 151 | for row, track in enumerate(tracks): 152 | gating_distance = kf.gating_distance( 153 | track.mean, track.covariance, measurements, only_position, metric='maha') 154 | cost_matrix[row, gating_distance > gating_threshold] = np.inf 155 | cost_matrix[row] = lambda_ * cost_matrix[row] + (1 - lambda_) * gating_distance 156 | return cost_matrix 157 | 158 | 159 | def fuse_iou(cost_matrix, tracks, detections): 160 | if cost_matrix.size == 0: 161 | return cost_matrix 162 | reid_sim = 1 - cost_matrix 163 | iou_dist = iou_distance(tracks, detections) 164 | iou_sim = 1 - iou_dist 165 | fuse_sim = reid_sim * (1 + iou_sim) / 2 166 | det_scores = np.array([det.score for det in detections]) 167 | det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0) 168 | #fuse_sim = fuse_sim * (1 + det_scores) / 2 169 | fuse_cost = 1 - fuse_sim 170 | return fuse_cost 171 | 172 | 173 | def fuse_score(cost_matrix, detections): 174 | if cost_matrix.size == 0: 175 | return cost_matrix 176 | iou_sim = 1 - cost_matrix 177 | 
det_scores = np.array([det.score for det in detections]) 178 | det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0) 179 | fuse_sim = iou_sim * det_scores 180 | fuse_cost = 1 - fuse_sim 181 | return fuse_cost -------------------------------------------------------------------------------- /tracking.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import time 3 | from tracker.byte_tracker import BYTETracker 4 | from Ship import Ship_manager 5 | from detect import detectImg 6 | import numpy as np 7 | 8 | 9 | # Task: tracking 10 | def Tracking(args, model): 11 | ''' 12 | Args: 13 | imgsz (int): Size of the input image. Default: 640 14 | input (str): Path of the input data. Default: 377.png 15 | output (str): Path of the output data. Default: output 16 | model (str): Path of the model. Default: ./Model/Boat-detect-medium.pt 17 | conf (float): Score confidence. Default: 0.6 18 | iou_threshold (float): IOU threshold. Default: 0.5 19 | video (bool): Input is a video. Default: False 20 | detect (bool): Task is detection. Default: False 21 | tracking (bool): Task is tracking. Default: False 22 | track_buffer (int): Buffer used to decide when to remove lost tracks. Default: 30 23 | match_thresh (float): Matching threshold for tracking in ByteTrack. Default: 0.5 24 | time_check_state (float): Time between ship-state updates (seconds). Default: 1.5 25 | 26 | ''' 27 | # Read the video 28 | tracker = BYTETracker(args) 29 | ArrayShip = Ship_manager(args.track_buffer) 30 | video = cv2.VideoCapture(args.input) 31 | FPS = 30 32 | if not video.isOpened(): 33 | print("Error reading video file") 34 | frame_width = int(video.get(3)) 35 | frame_height = int(video.get(4)) 36 | size = (frame_width, frame_height) 37 | print(size) 38 | result = cv2.VideoWriter("./Output/"+args.output+".avi", 39 | cv2.VideoWriter_fourcc(*'MJPG'), 40 | FPS, size) 41 | since = time.time() 42 | frame_count = 0 43 | 44 | while True: 45 | ret, frame = video.read() 46 | if ret: 47 | # Get bboxes and confidences from the detector 48 | boxes, conf = detectImg(model,frame, args.conf, args.iou_threshold) 49 | online_targets = tracker.update(conf, boxes, size, size) 50 | 51 | ArrayShip.update(online_targets, frame_count) 52 | # Re-evaluate the moving/stopped state every time_check_state seconds of video 53 | if frame_count % int(FPS*args.time_check_state) == 0: 54 | ArrayShip.check_state() 55 | ArrayShip.update_bbox(online_targets) 56 | frame_count += 1 57 | 58 | for i in ArrayShip.list_ship: 59 | if i.is_activate: 60 | x = int(i.bbox[0]+i.bbox[2]) 61 | y = int(i.bbox[1]+i.bbox[3]) 62 | state = "MOVING" if i.is_move else "STOP" 63 | frame = cv2.rectangle( 64 | frame, (int(i.bbox[0]), int(i.bbox[1])), (x, y), (255, 0, 0), 3) 65 | cv2.putText(frame, 'conf= '+str(i.score), (int(i.bbox[0]), int(i.bbox[1]-3)), 66 | cv2.FONT_HERSHEY_SIMPLEX, 0.75, (0, 255, 0), 2, cv2.LINE_AA) 67 | cv2.putText(frame, 'id= '+str(i.track_id), (int(i.bbox[0]), y+3), 68 | cv2.FONT_HERSHEY_SIMPLEX, 0.75, (0, 255, 0), 2, cv2.LINE_AA) 69 | if i.is_move: 70 | state += ": 0 s" 71 | else: 72 | # Scale the wall-clock stop time by (processing fps / video fps) to 73 | # approximate how long the ship has been stopped in video time 74 | elapsed = time.time() - since 75 | process_fps = frame_count / elapsed 76 | ratio = process_fps/FPS 77 | timeend = time.time() 78 | state += ": " + \ 79 | str(np.round((timeend-i.timestart) 80 | * ratio, decimals=2))+" s" 81 | cv2.putText(frame, 'time: '+str(state), (x, y+3), 82 | cv2.FONT_HERSHEY_SIMPLEX, 0.75, (0, 255, 0), 2, cv2.LINE_AA) 83 | result.write(frame) 84 | cv2.imshow("Tracking", frame) 85 | if cv2.waitKey(1) & 0xFF == ord("c"): 86 | break 87 | else: 88 | break 89 | video.release() 90 | 
result.release() 91 | cv2.destroyAllWindows() 92 | print("Tracking on the video finished") 93 | print("The output video was successfully saved") --------------------------------------------------------------------------------