├── README.md
├── config.yaml
├── main.py
├── requirements.txt
├── shot_detector.py
└── utils.py

/README.md:
--------------------------------------------------------------------------------
# Basketball shot detection
A project that uses YOLOv8 to detect and analyze basketball shots in real time. The algorithm tracks the ball's motion, cleans noisy detections, and projects the ball's path as a straight line to register a make when that path passes through the hoop. It aims to enhance the playing experience and provide game analytics.

https://github.com/nitinhemaraj/AI-basketball-shot-detection/assets/90787449/2c42261d-0b8b-4a95-a4c1-e67875397c21

## Introduction
This project combines machine learning and computer vision to detect and analyze basketball shots in real time. Built on the YOLOv8 (You Only Look Once) object detection model and the OpenCV library, the program can process video from various sources, such as a live webcam feed or a pre-recorded video, making it useful for an immersive playing experience and enhanced game analytics.

## Model Training
Training uses the ultralytics YOLO implementation and a custom dataset specified in the 'config.yaml' file. The model is trained for a fixed number of epochs (100 in 'main.py'), and the weights of the best-performing model are saved for later use in shot detection. Although this model worked for my use case, a different dataset or training setup might work better for your project.

## Algorithm
The core of this project is an algorithm that uses the trained YOLOv8 model to detect basketballs and hoops in each frame. It then analyzes the motion and position of the ball relative to the hoop: a shot attempt is counted when the ball is first seen in an 'up' region above and around the rim and later in a 'down' region below the net.
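
As a rough illustration of how ball positions relative to the hoop become an attempt count, here is a minimal, self-contained sketch. It reuses the region bounds of `detect_up` and `detect_down` from 'utils.py', but the `Box` type, the `count_attempts` helper and the sample coordinates are hypothetical, and it skips the detection, data cleaning and every-10-frames check performed in 'shot_detector.py'.

```python
from dataclasses import dataclass
from typing import Iterable, Tuple

Point = Tuple[float, float]


@dataclass
class Box:
    """Hypothetical hoop box: center (cx, cy), width and height in pixels (y grows downward)."""
    cx: float
    cy: float
    w: float
    h: float


def in_up_region(ball: Point, hoop: Box) -> bool:
    # Same bounds as detect_up in utils.py: within 4 hoop widths horizontally,
    # between 2 hoop heights above the hoop center and the top edge of the hoop box
    x, y = ball
    return (hoop.cx - 4 * hoop.w < x < hoop.cx + 4 * hoop.w
            and hoop.cy - 2 * hoop.h < y < hoop.cy - 0.5 * hoop.h)


def in_down_region(ball: Point, hoop: Box) -> bool:
    # Same test as detect_down in utils.py: ball center below the bottom edge of the hoop box
    return ball[1] > hoop.cy + 0.5 * hoop.h


def count_attempts(ball_centers: Iterable[Point], hoop: Box) -> int:
    """Hypothetical helper: count 'up' -> 'down' transitions of the ball center."""
    attempts = 0
    up = False
    for center in ball_centers:
        if not up:
            up = in_up_region(center, hoop)
        elif in_down_region(center, hoop):
            # Seen above the rim first, now below the net: one attempt
            attempts += 1
            up = False  # reset and wait for the next shot
    return attempts


hoop = Box(cx=500, cy=200, w=60, h=50)
# The ball rises into the 'up' region, then drops below the net
path = [(300, 400), (380, 150), (460, 110), (495, 160), (500, 240), (505, 300)]
print(count_attempts(path, hoop))  # 1
```

Requiring the 'up' region before the 'down' region is what keeps a ball that merely rolls under the hoop from being counted as a shot; whether the attempt was a make is decided separately.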

To improve accuracy, the algorithm not only tracks the ball's position over time but also cleans both the ball and hoop positions. It filters out implausible detections (for example, a box that jumps too far between nearby frames or is far from square), discards old points beyond a fixed frame window, and prevents the tracker from jumping from one object to another.

To decide whether an attempt went in, the algorithm takes the last ball position above the rim and the next position below it, fits a straight line through those two points, and extends it to the rim's height. If the projected point lies between the ends of the rim, the algorithm registers the shot as a make.

## How to Use This Code
1. Clone this repository to your local machine.
2. Download the dataset specified in 'config.yaml' and adjust the paths in the configuration file to match your local setup.
3. Follow the instructions in 'main.py' to train the model and prepare for shot detection.
4. Run 'shot_detector.py' on a live feed from a webcam or phone for real-time shot detection, or on a pre-recorded video for shot analysis.
5. Make sure the required Python packages are installed, including OpenCV, NumPy, and ultralytics' YOLO (all listed in 'requirements.txt').

### Score Detection Accuracy: 95%
### Shot Detection Accuracy: 97%

--------------------------------------------------------------------------------
/config.yaml:
--------------------------------------------------------------------------------
train: ...\train\images
val: ...\valid\images
test: ...\test\images

# replace '...'
# with your local absolute paths
# download the dataset used here: https://universe.roboflow.com/034-ganesh-kumar-m-v-cs-r2lwe/basketball-lhqoe/dataset/1

nc: 2
names: ['Basketball', 'Basketball Hoop']

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
from ultralytics import YOLO

if __name__ == "__main__":

    # 0. There are a lot of requirements to create/use OpenCV projects. I used the following video to get started:
    #    https://youtu.be/WgPbbWmnXJ8

    # 1. Download the dataset specified in config.yaml
    # 2. Put the downloaded folders into your project and change the absolute paths in config.yaml
    # 3. Change the paths below to the correct relative paths for your project
    # 4. Run this script
    # 5. Go to runs/detect/train/weights and use best.pt as the model when running shot_detector.py

    # Load a pretrained model
    model = YOLO('Yolo-Weights/yolov8n.pt')

    # Train the model on the custom dataset
    results = model.train(data='config.yaml', epochs=100, imgsz=640)

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
# Object Detection Requirements

cvzone==1.5.6
ultralytics==8.0.26
hydra-core>=1.2.0
matplotlib>=3.2.2
numpy>=1.18.5
opencv-python==4.5.4.60
Pillow>=7.1.2
PyYAML>=5.3.1
requests>=2.23.0
scipy>=1.4.1
torch>=1.7.0
torchvision>=0.8.1
tqdm>=4.64.0
filterpy==1.4.5
scikit-image==0.19.3
lap==0.4.0
--------------------------------------------------------------------------------
/shot_detector.py:
--------------------------------------------------------------------------------
# Avi Shah - Basketball Shot Detector/Tracker - July 2023

from ultralytics import YOLO
import cv2
import cvzone
import math
import numpy as np
from utils import score, detect_down, detect_up, in_hoop_region, clean_hoop_pos, clean_ball_pos


class ShotDetector:
    def __init__(self):
        # Load the YOLO model created from main.py - change text to your relative path
        self.model = YOLO("best.pt")
        self.class_names = ['Basketball', 'Basketball Hoop']

        # Uncomment the line below to use a webcam (I streamed to my iPhone using Iriun Webcam)
        # self.cap = cv2.VideoCapture(0)

        # Use a video - replace text with your video path
        self.cap = cv2.VideoCapture("video_test_5.mp4")

        self.ball_pos = []  # list of tuples ((x_pos, y_pos), frame count, width, height, conf)
        self.hoop_pos = []  # list of tuples ((x_pos, y_pos), frame count, width, height, conf)

        self.frame_count = 0
        self.frame = None

        self.makes = 0
        self.attempts = 0

        # Used to detect shots (upper and lower region)
        self.up = False
        self.down = False
        self.up_frame = 0
        self.down_frame = 0

        # Used for green and red colors after make/miss
        self.fade_frames = 20
        self.fade_counter = 0
        self.overlay_color = (0, 0, 0)

        self.run()

    def run(self):
        while True:
            ret, self.frame = self.cap.read()

            if not ret:
                # End of the video or an error occurred
                break

            results = self.model(self.frame, stream=True)

            for r in results:
                boxes = r.boxes
                for box in boxes:
                    # Bounding box
                    x1, y1, x2, y2 = box.xyxy[0]
                    x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
                    w, h = x2 - x1, y2 - y1

                    # Confidence
                    conf = math.ceil((box.conf[0] * 100)) / 100

                    # Class name
                    cls = int(box.cls[0])
                    current_class = self.class_names[cls]

                    center = (int(x1 + w / 2), int(y1 + h / 2))

                    # Only create ball points if high confidence or near hoop
                    if (conf > .3 or
                            (in_hoop_region(center, self.hoop_pos) and conf > 0.15)) and current_class == "Basketball":
                        self.ball_pos.append((center, self.frame_count, w, h, conf))
                        cvzone.cornerRect(self.frame, (x1, y1, w, h))

                    # Create hoop points if high confidence
                    if conf > .5 and current_class == "Basketball Hoop":
                        self.hoop_pos.append((center, self.frame_count, w, h, conf))
                        cvzone.cornerRect(self.frame, (x1, y1, w, h))

            self.clean_motion()
            self.shot_detection()
            self.display_score()
            self.frame_count += 1

            cv2.imshow('Frame', self.frame)

            # Close if 'q' is pressed
            if cv2.waitKey(1) & 0xFF == ord('q'):  # higher waitKey slows video down, use 1 for webcam
                break

        self.cap.release()
        cv2.destroyAllWindows()

    def clean_motion(self):
        # Clean and display ball motion
        self.ball_pos = clean_ball_pos(self.ball_pos, self.frame_count)
        for i in range(0, len(self.ball_pos)):
            cv2.circle(self.frame, self.ball_pos[i][0], 2, (0, 0, 255), 2)

        # Clean hoop motion and display current hoop center
        if len(self.hoop_pos) > 1:
            self.hoop_pos = clean_hoop_pos(self.hoop_pos)
            cv2.circle(self.frame, self.hoop_pos[-1][0], 2, (128, 128, 0), 2)

    def shot_detection(self):
        if len(self.hoop_pos) > 0 and len(self.ball_pos) > 0:
            # Detect when the ball is in the 'up' and 'down' areas - it can only be in 'down' after it has been in 'up'
            if not self.up:
                self.up = detect_up(self.ball_pos, self.hoop_pos)
                if self.up:
                    self.up_frame = self.ball_pos[-1][1]

            if self.up and not self.down:
                self.down = detect_down(self.ball_pos, self.hoop_pos)
                if self.down:
                    self.down_frame = self.ball_pos[-1][1]

            # If the ball goes from the 'up' area to the 'down' area in that order, count an attempt and reset
            if self.frame_count % 10 == 0:
                if self.up and self.down and self.up_frame < self.down_frame:
                    self.attempts += 1
                    self.up = False
                    self.down = False

                    # If it is a make, put a green overlay
                    if score(self.ball_pos, self.hoop_pos):
                        self.makes += 1
                        self.overlay_color = (0, 255, 0)
                        self.fade_counter = self.fade_frames

                    # If it is a miss, put a red overlay
                    else:
                        self.overlay_color = (0, 0, 255)
                        self.fade_counter = self.fade_frames

    def display_score(self):
        # Add text
        text = str(self.makes) + " / " + str(self.attempts)
        cv2.putText(self.frame, text, (50, 125), cv2.FONT_HERSHEY_SIMPLEX, 3, (255, 255, 255), 6)
        cv2.putText(self.frame, text, (50, 125), cv2.FONT_HERSHEY_SIMPLEX, 3, (0, 0, 0), 3)

        # Gradually fade out color after shot
        if self.fade_counter > 0:
            alpha = 0.2 * (self.fade_counter / self.fade_frames)
            self.frame = cv2.addWeighted(self.frame, 1 - alpha, np.full_like(self.frame, self.overlay_color), alpha, 0)
            self.fade_counter -= 1


if __name__ == "__main__":
    ShotDetector()

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import math


def score(ball_pos, hoop_pos):
    x = []
    y = []
    rim_height = hoop_pos[-1][0][1] - 0.5 * hoop_pos[-1][3]

    # Get the last point above the rim and the point right after it (at or below the rim).
    # Checking ball_pos[i + 1] in the condition also keeps the index in range.
    for i in reversed(range(len(ball_pos) - 1)):
        if ball_pos[i][0][1] < rim_height <= ball_pos[i + 1][0][1]:
            x.append(ball_pos[i][0][0])
            y.append(ball_pos[i][0][1])
            x.append(ball_pos[i + 1][0][0])
            y.append(ball_pos[i + 1][0][1])
            break

    # Create a line through the two points
    if len(x) > 1:
        # Project the line to the rim height: x = x0 + (rim_height - y0) * (x1 - x0) / (y1 - y0).
        # y1 > y0 here, so this also works for a vertical path, where a fitted slope would be undefined.
        predicted_x = x[0] + (rim_height - y[0]) * (x[1] - x[0]) / (y[1] - y[0])

        # Check if the projected point lies between the ends of the rim
        rim_x1 = hoop_pos[-1][0][0] - 0.4 * hoop_pos[-1][2]
        rim_x2 = hoop_pos[-1][0][0] + 0.4 * hoop_pos[-1][2]
        if rim_x1 < predicted_x < rim_x2:
            return True

    return False


# Detects if the ball is below the net - used to detect shot attempts
def detect_down(ball_pos, hoop_pos):
    y = hoop_pos[-1][0][1] + 0.5 * hoop_pos[-1][3]
    if ball_pos[-1][0][1] > y:
        return True
    return False


# Detects if the ball is around the backboard - used to detect shot attempts
def detect_up(ball_pos, hoop_pos):
    x1 = hoop_pos[-1][0][0] - 4 * hoop_pos[-1][2]
    x2 = hoop_pos[-1][0][0] + 4 * hoop_pos[-1][2]
    y1 = hoop_pos[-1][0][1] - 2 * hoop_pos[-1][3]
    y2 = hoop_pos[-1][0][1]

    if x1 < ball_pos[-1][0][0] < x2 and y1 < ball_pos[-1][0][1] < y2 - 0.5 * hoop_pos[-1][3]:
        return True
    return False


# Checks if center point is near the hoop
def in_hoop_region(center, hoop_pos):
    if len(hoop_pos) < 1:
        return False
    x = center[0]
    y = center[1]

    x1 = hoop_pos[-1][0][0] - 1 * hoop_pos[-1][2]
    x2 = hoop_pos[-1][0][0] + 1 * hoop_pos[-1][2]
    y1 = hoop_pos[-1][0][1] - 1 * hoop_pos[-1][3]
    y2 = hoop_pos[-1][0][1] + 0.5 * hoop_pos[-1][3]

    if x1 < x < x2 and y1 < y < y2:
        return True
    return False


# Removes inaccurate data points
def clean_ball_pos(ball_pos, frame_count):
    # Remove the newest point if its size or position is implausible, to prevent jumping to the wrong ball
    if len(ball_pos) > 1:
        # Width and height
        w1 = ball_pos[-2][2]
        h1 = ball_pos[-2][3]
        w2 = ball_pos[-1][2]
        h2 = ball_pos[-1][3]

        # X and Y coordinates
        x1 = ball_pos[-2][0][0]
        y1 = ball_pos[-2][0][1]
        x2 = ball_pos[-1][0][0]
        y2 = ball_pos[-1][0][1]

        # Frame count
        f1 = ball_pos[-2][1]
        f2 = ball_pos[-1][1]
        f_dif = f2 - f1

        dist = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)

        max_dist = 4 * math.sqrt((w1) ** 2 + (h1) ** 2)

        # Ball should not move more than 4x its bounding-box diagonal within 5 frames
        if (dist > max_dist) and (f_dif < 5):
            ball_pos.pop()

        # Ball should be relatively square
        elif (w2*1.4 < h2) or (h2*1.4 < w2):
            ball_pos.pop()

    # Remove points older than 30 frames (several detections can be added in one frame)
    while len(ball_pos) > 0 and frame_count - ball_pos[0][1] > 30:
        ball_pos.pop(0)

    return ball_pos


def clean_hoop_pos(hoop_pos):
    # Prevents jumping from one hoop to another
    if len(hoop_pos) > 1:
        x1 = hoop_pos[-2][0][0]
        y1 = hoop_pos[-2][0][1]
        x2 = hoop_pos[-1][0][0]
        y2 = hoop_pos[-1][0][1]

        w1 = hoop_pos[-2][2]
        h1 = hoop_pos[-2][3]
        w2 = hoop_pos[-1][2]
        h2 = hoop_pos[-1][3]

        f1 = hoop_pos[-2][1]
        f2 = hoop_pos[-1][1]

        f_dif = f2 - f1

        dist = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)

        max_dist = 0.5 * math.sqrt(w1 ** 2 + h1 ** 2)

        # Hoop should not move more than 0.5x its bounding-box diagonal within 5 frames
        if dist > max_dist and f_dif < 5:
            hoop_pos.pop()

        # Hoop should be relatively square
        # (elif: w2/h2 describe the newest point, which may already have been removed above)
        elif (w2*1.3 < h2) or (h2*1.3 < w2):
            hoop_pos.pop()

    # Remove old points
    if len(hoop_pos) > 25:
        hoop_pos.pop(0)

    return hoop_pos

--------------------------------------------------------------------------------
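
The make test described in the README's Algorithm section comes down to simple plane geometry: extend the line through two ball centers, one above the rim and one at or below it, to the rim's height, and check where it lands. The sketch below isolates that step under stated assumptions; `crosses_rim` and its parameters are hypothetical names, pixel coordinates are assumed to grow downward as in OpenCV, and the default `half_width=0.4` mirrors the rim half-width factor used in 'utils.py'.

```python
from typing import Tuple

Point = Tuple[float, float]


def crosses_rim(above: Point, below: Point, hoop_center: Point, hoop_w: float, hoop_h: float,
                half_width: float = 0.4) -> bool:
    """Hypothetical helper: does the line through `above` and `below` meet the rim line between its ends?

    Pixel coordinates with y growing downward: `above` is higher in the image than the rim
    (smaller y) and `below` is at or below it.
    """
    (x1, y1), (x2, y2) = above, below
    rim_y = hoop_center[1] - 0.5 * hoop_h  # top edge of the hoop box, as in utils.py
    if not (y1 < rim_y <= y2):
        raise ValueError("expected one point above the rim and one at or below it")
    # Solve the line through the two points for x at y = rim_y.
    # y2 > y1, so there is no division by zero, even for a vertical path (x1 == x2).
    predicted_x = x1 + (rim_y - y1) * (x2 - x1) / (y2 - y1)
    rim_left = hoop_center[0] - half_width * hoop_w
    rim_right = hoop_center[0] + half_width * hoop_w
    return rim_left < predicted_x < rim_right


# Hoop centered at (500, 200), 60 x 50 px: rim at y = 175, spanning x = 476..524
print(crosses_rim((480, 150), (490, 200), (500, 200), 60, 50))  # True  (crosses at x = 485)
print(crosses_rim((420, 150), (440, 200), (500, 200), 60, 50))  # False (crosses at x = 430)
```

Interpolating directly between the two points gives the same line a degree-1 fit through them would, but it stays well defined when the ball drops straight down.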