├── .dockerignore ├── .gitattributes ├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── docker-compose.yml ├── docs ├── Square2.png ├── box2.png ├── box3.png ├── client_server.png ├── demo1.gif ├── demo2.gif ├── logo.png ├── screenshot.png └── square.jpg ├── main.py ├── ml ├── detector.py ├── opencv │ ├── haar_cascades.py │ └── opencv_detector.py ├── tensorflow │ ├── ssd.py │ └── tf_detector.py └── torch │ ├── torch_detector.py │ └── yolo.py ├── requirements.txt ├── server ├── AiServer.py ├── MssClient.py ├── QtServer.py ├── http_server.py └── local_starter.py ├── tests ├── noCI │ ├── test_main.py │ ├── test_server_mssclient.py │ ├── test_server_qtserver.py │ └── test_utils_screen_overlay_handler.py ├── test_ml_detector.py ├── test_ml_opencv.py ├── test_ml_opencv_haar.py ├── test_ml_tensorflow.py ├── test_ml_torch.py ├── test_ml_torch_yolo.py ├── test_server_ai.py ├── test_server_http.py ├── test_utils_shared_variables.py ├── test_utils_threadpool.py └── test_utils_tracking.py └── utils ├── ThreadPool.py ├── screen_overlay_handler.py ├── shared_variables.py └── tracking.py /.dockerignore: -------------------------------------------------------------------------------- 1 | ./docs -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *old_remove 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # IPython 79 | profile_default/ 80 | ipython_config.py 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # celery beat schedule file 86 | celerybeat-schedule 87 | 88 | # SageMath parsed files 89 | *.sage.py 90 | 91 | # Environments 92 | .env 93 | .venv 94 | env/ 95 | venv/ 96 | ENV/ 97 | env.bak/ 98 | venv.bak/ 99 | 100 | # Spyder project settings 101 | .spyderproject 102 | .spyproject 103 | 104 | # Rope project settings 105 | .ropeproject 106 | 107 | # mkdocs documentation 108 | /site 109 | 110 | # mypy 111 | .mypy_cache/ 112 | .dmypy.json 113 | dmypy.json 114 | 115 | # Pyre type checker 116 | .pyre/ 117 | 118 | # Ignore ml models 119 | *.h5 120 | *.pt 121 | *.log -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # 2 | # The purpose of this Dockerfile is to build the main project for CI/CD 3 | # 4 | FROM python:3.8 5 | 6 | RUN mkdir /app 7 | WORKDIR /app 8 | 9 | RUN apt-get update && \ 10 | apt-get install -y libgl1-mesa-glx && \ 11 | apt-get clean && \ 12 | rm -rf /var/lib/apt/lists/* 13 | 14 | COPY ./requirements.txt /requirements.txt 15 | RUN pip install -r /requirements.txt 16 | 17 | #RUN yolo settings sync=false 18 | 19 | COPY . . 20 | 21 | CMD ["python3", "/app/main.py"] -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Grebtsew 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 
3 | 
4 | 
5 | 
6 | 7 | ![license](https://img.shields.io/github/license/grebtsew/Realtime-Screen-Object-Detection) 8 | ![size](https://img.shields.io/github/repo-size/grebtsew/Realtime-Screen-Object-Detection) 9 | ![commit](https://img.shields.io/github/last-commit/grebtsew/Realtime-Screen-Object-Detection) 10 | 11 | ![demo](docs/demo2.gif) 12 | 13 | `NOTE:` I use a low framerate and low-precision object detection to save some 14 | power on my laptop; feel free to change the models and detection settings and to remove the loop delays. 15 | 16 | # About 17 | RSOD (Realtime Screen Object Detection) is a program I created to visualize 18 | realtime object detections on a screen capture. To capture the PC screen I use the `mss` 19 | library. Detections are visualized through `PyQt5` as overlays. Feel free to try it out 20 | and change the setting variables in the `main.py` file. Possible usages for this program are 21 | to track and discover objects on screen to assist users, or perhaps to enhance 22 | the general user experience. 23 | 24 | # Install 25 | This program is written for **python3**. The required python3 libraries are the following: 26 | 27 | * `PyQt5` (used to overlay the screen to show detections) 28 | * `Tensorflow` (object detection) 29 | * `mss` (capture screen stream) 30 | * `Image` 31 | * `opencv-python` 32 | * `pyfiglet` 33 | * `numpy` 34 | 35 | If you have `Python3 pip` installed you can install the required packages by running: 36 | 37 | ```bash 38 | pip install -r requirements.txt 39 | ``` 40 | 41 | After installing the pip packages, 42 | run this to stop the ultralytics library from sending data to Google Analytics: 43 | ```bash 44 | yolo settings sync=false 45 | ``` 46 | 47 | 48 | # Run 49 | This tutorial takes you through the execution of this program. 50 | 51 | 1. Install the requirements above 52 | 2. Clone this repo or download it: 53 | ```git 54 | git clone https://github.com/grebtsew/Realtime-Screen-Object-Detection.git 55 | ``` 56 | 3. Run the python script `main.py` 57 | 58 | # Change Settings 59 | All runtime settings are plain variables at the top of `main.py` (for example `PRECISION`, `MAX_DETECTION`, `DETECTION_DURATION`, `SHOW_ONLY`, `WIDTH`/`HEIGHT` and `HTTP_SERVER`); edit them before starting the program. 60 | # Client Server 61 | A client/server solution has been added to further increase performance and accelerate detections. 62 | Test the implementation by running the `server/local_starter.py` file. The structure is explained below. 63 | 64 | 65 | For demo purposes the starter starts the three separate processes `MssClient`, `QtServer` and `AiServer`. 66 | The `MssClient` reads screen images and sends them to the `AiServer`. The `AiServer` performs 67 | object detections on the received images and sends detection boxes to the `QtServer`. The `QtServer` 68 | displays the incoming boxes from the `AiServer`. This allows for separation of the solution 69 | and also acceleration of the detections. I would recommend placing the `AiServer` on a PC with high-performance 70 | AI hardware and letting another PC focus only on displaying the detections. 71 | 72 | ![demo](docs/client_server.png) 73 | 74 | # Use Detection Request Mode 75 | In this mode the program waits for requests from other clients and shows the requested objects. 76 | Example client message: 77 | ``` 78 | # Data to add persons to detection 79 | data = { 80 | 'value': "person", 81 | 'type': 'add', 82 | 'api_crypt':"password-1" 83 | } 84 | 85 | # Data to remove dogs from detection 86 | data = { 87 | 'value': "dog", 88 | 'type': 'remove', 89 | 'api_crypt':"password-1" 90 | } 91 | ``` 92 | # Testing 93 | 94 | This project contains unit tests using pytest and build tests using Docker. 95 | The tests are used in CI GitHub Actions.
To run the tests locally, run the following commands: 96 | ```bash 97 | 98 | # Unit tests 99 | pytest 100 | 101 | # Build test, to check that everything still builds 102 | docker-compose up -d 103 | # This will generally complain about not being able to access Qt and MSS (GUI and screen grab). 104 | ``` 105 | 106 | # Deprecated demo 107 | ![demo](docs/demo1.gif) 108 | 109 | 110 | # Screenshot 111 | See a screenshot of a program execution below: 112 | ![screenshot](docs/screenshot.png) 113 | 114 | # License 115 | This repository uses the [MIT](LICENSE) license. 116 | 117 | 118 | COPYRIGHT @ Grebtsew 2019 119 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | # 2 | # The purpose of this file is to illustrate how all applications can be executed. 3 | # However, it is only the AiServer container that will run. 4 | # All other containers are meant to run locally, as they need local resources to work, 5 | # such as screen grab and GUI handling. 6 | # Thus, this docker-compose file is mainly for testing container builds and illustrating usage. 7 | # 8 | 9 | version: '3.8' 10 | 11 | services: 12 | 13 | qtserver: 14 | image: grebtsew/qtserver 15 | container_name: qtserver 16 | build: 17 | context: . 18 | entrypoint: python3 ./server/QtServer.py 19 | ports: 20 | - "8081:8081" 21 | networks: 22 | - rsod-net 23 | 24 | mssclient: 25 | image: grebtsew/mssclient 26 | container_name: mssclient 27 | build: 28 | context: . 29 | entrypoint: python3 ./server/MssClient.py 30 | networks: 31 | - rsod-net 32 | 33 | aiserver: 34 | image: grebtsew/aiserver 35 | container_name: aiserver 36 | build: 37 | context: . 38 | entrypoint: python3 ./server/AiServer.py 39 | ports: 40 | - "8585:8585" 41 | networks: 42 | - rsod-net 43 | 44 | mainapp: 45 | image: grebtsew/main 46 | container_name: main 47 | build: 48 | context: . 
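# NOTE: like the clients above, main.py needs a local display and screen access, so this container is expected to fail in plain Docker (see the header comment).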
49 | ports: 50 | - "5000:5000" 51 | networks: 52 | - rsod-net 53 | 54 | 55 | networks: 56 | rsod-net: 57 | driver: bridge 58 | -------------------------------------------------------------------------------- /docs/Square2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/docs/Square2.png -------------------------------------------------------------------------------- /docs/box2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/docs/box2.png -------------------------------------------------------------------------------- /docs/box3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/docs/box3.png -------------------------------------------------------------------------------- /docs/client_server.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/docs/client_server.png -------------------------------------------------------------------------------- /docs/demo1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/docs/demo1.gif -------------------------------------------------------------------------------- /docs/demo2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/docs/demo2.gif -------------------------------------------------------------------------------- /docs/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/docs/logo.png -------------------------------------------------------------------------------- /docs/screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/docs/screenshot.png -------------------------------------------------------------------------------- /docs/square.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/docs/square.jpg -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | ''' 2 | MAIN 3 | COPYRIGHT @ Grebtsew 2019 4 | 5 | This is main function, used to start instances of the full program 6 | ''' 7 | 8 | from PyQt5.QtCore import * 9 | from PyQt5.QtWidgets import QApplication, QMainWindow 10 | 11 | from utils.shared_variables import Shared_Variables 12 | from utils import screen_overlay_handler 13 | from utils.ThreadPool import * 14 | from ml import detector 15 | 16 | import time 17 | 18 | 
from pyfiglet import Figlet 19 | 20 | import logging 21 | logging.basicConfig( 22 | level=logging.DEBUG, # Set the logging threshold to DEBUG (or another level) 23 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', 24 | filename='last-run.log', # Log messages to a file (optional) 25 | filemode='w' # Append mode for the log file (optional) 26 | ) 27 | console_handler = logging.StreamHandler() # Use the default stream (sys.stdout) 28 | 29 | # Create a formatter for the console handler (optional) 30 | console_formatter = logging.Formatter('%(levelname)s - %(message)s') 31 | console_handler.setFormatter(console_formatter) 32 | 33 | # Add the console handler to the root logger 34 | root_logger = logging.getLogger() 35 | root_logger.addHandler(console_handler) 36 | 37 | # Create a logger for your module 38 | logger = logging.getLogger('realtime-screen-object-detection.rsod') 39 | 40 | # Change these variables if you want! 41 | MAX_BOX_AREA = 100000000 # pixels^2 42 | PRECISION = 0.7 # 60 % detection treshhold 43 | MAX_DETECTION = 5 44 | MAX_TRACKING_MISSES = 30 45 | WIDTH = 1920 46 | HEIGHT = 1080 47 | SHOW_ONLY = ["person","Person"] # Start Empty, receive items to show 48 | OFFSET = (0,0) 49 | DETECTION_SIZE = 480 50 | DETECTION_DURATION = 2 51 | RESET_SHOW_ONLY_ON_START=False 52 | HTTP_SERVER = False 53 | 54 | class MainGUI(QMainWindow): 55 | 56 | def initiate_shared_variables(self): 57 | self.shared_variables = Shared_Variables() 58 | self.shared_variables.MAX_BOX_AREA = MAX_BOX_AREA 59 | self.shared_variables.PRECISION = PRECISION 60 | self.shared_variables.MAX_DETECTION = MAX_DETECTION 61 | self.shared_variables.WIDTH = WIDTH 62 | self.shared_variables.HEIGHT = HEIGHT 63 | self.shared_variables.SHOW_ONLY = SHOW_ONLY 64 | self.shared_variables.list = [] 65 | self.shared_variables.OFFSET = OFFSET 66 | self.shared_variables.DETECTION_SIZE = DETECTION_SIZE 67 | self.shared_variables.DETECTION_DURATION = DETECTION_DURATION 68 | self.shared_variables.MAX_TRACKING_MISSES = MAX_TRACKING_MISSES 69 | self.shared_variables.HTTP_SERVER = HTTP_SERVER 70 | 71 | if RESET_SHOW_ONLY_ON_START: 72 | self.shared_variables.SHOW_ONLY = [] 73 | 74 | # Start webserver 75 | if HTTP_SERVER: 76 | from server.http_server import HTTPserver 77 | HTTPserver(shared_variables=self.shared_variables).start() 78 | 79 | def __init__(self): 80 | super(MainGUI, self).__init__() 81 | 82 | self.initiate_shared_variables() 83 | 84 | # Create detection and load model 85 | self.detection_model = self.shared_variables.model(shared_variables = self.shared_variables) 86 | self.detection_model.download_model() 87 | self.detection_model.load_model() 88 | 89 | self.threadpool = QThreadPool() 90 | 91 | logging.info("Multithreading with maximum %d threads" % self.threadpool.maxThreadCount()) 92 | 93 | 94 | self.timer = QTimer() 95 | self.timer.setInterval(10) 96 | self.timer.timeout.connect(self.print_output) 97 | self.timer.start() 98 | 99 | 100 | # Start Detection thread 101 | self.start_worker() 102 | 103 | def execute_this_fn(self, progress_callback): 104 | while True: 105 | 106 | if len(self.shared_variables.SHOW_ONLY) == 0: 107 | # how often we should detect stuff 108 | pass 109 | else: 110 | logging.debug("Trigger Detection...") 111 | if self.shared_variables.OutputFrame is not None: # wait for the first frame 112 | progress_callback.emit(self.detection_model.predict()) # detect and emits boxes! 
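# (the progress signal is connected to create_tracking_boxes, so each emitted detection list creates overlay TrackingBoxes on the main Qt thread)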
113 | 114 | time.sleep(self.shared_variables.DETECTION_DURATION) 115 | 116 | def create_tracking_boxes(self, boxes): 117 | if len(boxes)> 0: 118 | logging.debug(f"got detection now create trackerbox: {boxes}") 119 | 120 | for box in boxes: 121 | if len(self.shared_variables.list) < MAX_DETECTION: 122 | self.shared_variables.list.append(screen_overlay_handler.TrackingBox(len(self.shared_variables.list), self.shared_variables, box[0],box[1],box[2])) 123 | 124 | def print_output(self): 125 | remove = [] 126 | index = 0 127 | for box in self.shared_variables.list: 128 | if box.done: 129 | box.finish(self) 130 | remove.insert(0,index) 131 | index += 1 132 | 133 | for i in remove: 134 | del self.shared_variables.list[i] 135 | #logging.debug(self.shared_variables.list) 136 | 137 | 138 | def thread_complete(self): 139 | #logging.debug("Thread closed") 140 | pass 141 | 142 | def start_worker(self): 143 | # Pass the function to execute 144 | worker = Worker(self.execute_this_fn) # Any other args, kwargs are passed to the run function 145 | worker.signals.progress.connect(self.create_tracking_boxes) 146 | worker.signals.result.connect(self.print_output) 147 | worker.signals.finished.connect(self.thread_complete) 148 | # Execute 149 | self.threadpool.start(worker) 150 | 151 | # Main start here 152 | if __name__ == "__main__": 153 | f = Figlet(font='slant') 154 | logging.info (f.renderText('Realtime Screen stream with Ai detection Overlay')) 155 | logging.info("This program starts several threads that stream pc screen and" + 156 | "run object detection on it and show detections with PyQt5 overlay.") 157 | 158 | logging.info("Starting Program...") 159 | logging.info("All threads started, will take a few seconds to load model, enjoy!") 160 | 161 | logging.info("") 162 | logging.info("----- Settings -----") 163 | logging.info("Max box size : "+ str(MAX_BOX_AREA)) 164 | logging.info("Detection precision treshhold : " + str(100*PRECISION)+"%") 165 | logging.info("Max amount of detection : "+ str(MAX_DETECTION)) 166 | logging.info("Max amount of tracking misses : "+ str(MAX_TRACKING_MISSES)) 167 | logging.info("Do detections every : "+str(DETECTION_DURATION) + " second") 168 | logging.info("Rescale image detection size : " +str(DETECTION_SIZE)) 169 | logging.info("Classifications : " + str(SHOW_ONLY) + " * if empty all detections are allowed.") 170 | logging.info("Screen size : " + str(WIDTH) +"x"+str(HEIGHT)) 171 | logging.info("Screen offset : "+str(OFFSET)) 172 | logging.info("Activate HTTPserver : " + str(HTTP_SERVER)) 173 | logging.info("") 174 | 175 | logging.info("") 176 | logging.info("----- Usage -----") 177 | logging.info("Exit by typing : 'ctrl+c'") 178 | logging.info("") 179 | 180 | logging.info("") 181 | logging.info("Realtime-Screen-stream-with-Ai-detection-Overlay Copyright (C) 2019 Daniel Westberg") 182 | logging.info("This program comes with ABSOLUTELY NO WARRANTY;") 183 | logging.info("This is free software, and you are welcome to redistribute it under certain conditions;") 184 | logging.info("") 185 | 186 | app = QApplication([]) 187 | 188 | MainGUI() 189 | 190 | app.exec_() 191 | -------------------------------------------------------------------------------- /ml/detector.py: -------------------------------------------------------------------------------- 1 | """ 2 | Abstract Object Detector Class 3 | """ 4 | import math 5 | 6 | from utils import screen_overlay_handler 7 | from abc import ABC, abstractmethod 8 | 9 | class Detector(ABC): 10 | 11 | def __init__(self, shared_variables) -> 
None: 12 | super().__init__() 13 | 14 | self.shared_variables = shared_variables 15 | self.id = None 16 | 17 | def distance_between_boxes(self, box1, box2): 18 | return int(abs(math.hypot(box2[0]-box1[0], box2[1]-box1[1]))) 19 | 20 | def detection_exist(self, tracking_list, box): 21 | for tracking_box in tracking_list: 22 | if(self.distance_between_boxes(tracking_box.get_box(), box)) < 100: 23 | return True 24 | return False 25 | 26 | def create_new_tracking_box(self, scores,c, shared_variables, box): 27 | shared_variables.trackingboxes.append(screen_overlay_handler.TrackingBox(scores, c,shared_variables, box)) 28 | 29 | @abstractmethod 30 | def download_model(self): 31 | pass 32 | 33 | @abstractmethod 34 | def load_model(self, model_path): 35 | pass 36 | 37 | @abstractmethod 38 | def predict(self, image): 39 | pass 40 | -------------------------------------------------------------------------------- /ml/opencv/haar_cascades.py: -------------------------------------------------------------------------------- 1 | from ml.opencv import opencv_detector as od 2 | 3 | 4 | class HAAR_CASCADES(od.OPENCVDetector): 5 | def __init__(self, model_path, label_map_path): 6 | super().__init__() 7 | self.model_path = model_path 8 | self.label_map_path = label_map_path 9 | 10 | def download_model(self): 11 | super().download_model() 12 | # Add Haar-cascade specific setup here 13 | 14 | def load_model(self): 15 | super().load_model(self.model_path) 16 | # Haar-cascade specific loading logic 17 | 18 | def predict(self, image): 19 | # Haar-cascade specific prediction logic 20 | prediction = super().predict(image) 21 | # Post-process Haar-cascade predictions 22 | # Implement your Haar-cascade post-processing here 23 | return prediction 24 | -------------------------------------------------------------------------------- /ml/opencv/opencv_detector.py: -------------------------------------------------------------------------------- 1 | """ 2 | Class for specific usage of OpenCV 3 | """ 4 | from ml import detector as d 5 | 6 | import cv2 7 | 8 | class OPENCVDetector(d.Detector): 9 | def __init__(self): 10 | self.model = None 11 | 12 | def download_model(self): 13 | # Initialize OpenCV specific setup here 14 | # (nothing to download: OpenCV ships bundled Haar-cascade files under cv2.data.haarcascades) 15 | pass 16 | 17 | def load_model(self, model_path): 18 | #self.model = cv2.CascadeClassifier(model_path) 19 | pass 20 | 21 | def predict(self, image): 22 | # Implement OpenCV specific prediction logic here 23 | prediction = self.model.detectMultiScale(image) 24 | return prediction 25 | -------------------------------------------------------------------------------- /ml/tensorflow/ssd.py: -------------------------------------------------------------------------------- 1 | from ml.tensorflow import tf_detector as tf 2 | 3 | class SSD(tf.TFDetector): 4 | def __init__(self, model_path, label_map_path): 5 | super().__init__() 6 | self.model_path = model_path 7 | self.label_map_path = label_map_path 8 | 9 | def download_model(self): 10 | super().download_model() 11 | # Add SSD specific setup here 12 | 13 | def load_model(self): 14 | super().load_model(self.model_path) 15 | # SSD-specific loading logic 16 | 17 | def predict(self, image): 18 | # SSD specific prediction logic 19 | prediction = super().predict(image) 20 | # Post-process SSD predictions 21 | # Implement your SSD post-processing here 22 | return prediction 23 | -------------------------------------------------------------------------------- /ml/tensorflow/tf_detector.py: -------------------------------------------------------------------------------- 1 | """ 2 | Class for specific 
usage of tensorflow 2.X 3 | """ 4 | from ml import detector as d 5 | 6 | import tensorflow as tf 7 | import keras 8 | 9 | class TFDetector(d.Detector): 10 | def __init__(self): 11 | self.model = None 12 | 13 | def download_model(self): 14 | # Initialize TensorFlow specific setup here 15 | tf.compat.v1.enable_eager_execution() 16 | 17 | def load_model(self, model_path): 18 | self.model = tf.keras.models.load_model(model_path) 19 | 20 | def predict(self, image): 21 | # Implement TensorFlow specific prediction logic here 22 | prediction = self.model.predict(image) 23 | return prediction 24 | -------------------------------------------------------------------------------- /ml/torch/torch_detector.py: -------------------------------------------------------------------------------- 1 | """ 2 | Class for specific usage of torch 3 | """ 4 | from ml import detector as d 5 | 6 | import torch 7 | 8 | class TorchDetector(d.Detector): 9 | def __init__(self, shared_variables): 10 | self.model = None 11 | self.shared_variables = shared_variables 12 | 13 | def download_model(self): 14 | # Initialize torch specific setup here 15 | pass 16 | 17 | def load_model(self, model_path): 18 | pass 19 | 20 | def predict(self, image): 21 | # Implement torch specific prediction logic here 22 | pass -------------------------------------------------------------------------------- /ml/torch/yolo.py: -------------------------------------------------------------------------------- 1 | #import torch 2 | #from pathlib import Path 3 | #import os 4 | from ml.torch import torch_detector as td 5 | from ultralytics import YOLO as y 6 | import logging 7 | 8 | class YOLO(td.TorchDetector): 9 | def __init__(self, shared_variables): 10 | 11 | super().__init__(shared_variables) 12 | self.shared_variables = shared_variables 13 | 14 | def download_model(self): 15 | #... 
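# (nothing to fetch here: ultralytics downloads the yolov8n.pt weights automatically the first time load_model() constructs YOLO('yolov8n.pt'))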
16 | pass 17 | 18 | def load_model(self): 19 | # Load model 20 | self.model = y('yolov8n.pt') 21 | # YOLOv5-specific loading logic 22 | 23 | # TODO: inherit this 24 | self.shared_variables.detection_ready=True 25 | 26 | 27 | 28 | def predict(self): 29 | image = self.shared_variables.OutputFrame 30 | 31 | results = self.model.predict(image) 32 | 33 | # Access detected objects and their attributes 34 | detected_objects = [] 35 | 36 | 37 | 38 | for obj in results: 39 | classes = obj.names 40 | for i in range(len(obj.boxes.cls.tolist())): 41 | _box = obj.boxes.xywhn.tolist()[i] 42 | 43 | box = (_box[0],_box[1],_box[2],_box[3]) 44 | 45 | score = obj.boxes.conf.tolist()[i] 46 | classification = obj.boxes.cls.tolist()[i] 47 | 48 | __box = obj.boxes.xywh.tolist()[i] 49 | 50 | # Apply filters 51 | 52 | 53 | if len(self.shared_variables.SHOW_ONLY )> 0: 54 | if score >= self.shared_variables.PRECISION: 55 | if classes[classification] in self.shared_variables.SHOW_ONLY: 56 | if __box[2]*__box[3] <= self.shared_variables.MAX_BOX_AREA: 57 | detected_objects.append((score,classes[classification], box)) 58 | 59 | else: 60 | if score > self.shared_variables.PRECISION : 61 | if __box[2]*__box[3] <= self.shared_variables.MAX_BOX_AREA: 62 | detected_objects.append((score,classes[classification], box)) 63 | 64 | return detected_objects 65 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | #ML 2 | torch==2.0.1 3 | keras==2.13.1 4 | tensorflow==2.13.0 5 | ultralytics==8.0.171 6 | 7 | # Core 8 | opencv-contrib-python==4.8.0.76 9 | pandas==2.0.3 10 | pyqt5==5.15.9 11 | numpy==1.22.2 12 | Image==1.5.33 13 | opencv-python==4.8.0.76 14 | pyfiglet==0.8.post1 15 | mss==9.0.1 16 | pillow==10.0.0 17 | 18 | # Format and testing 19 | black==23.7.0 20 | pytest==7.4.0 21 | 22 | -------------------------------------------------------------------------------- /server/AiServer.py: -------------------------------------------------------------------------------- 1 | import time 2 | from threading import Thread, Event 3 | import socket 4 | import cv2 5 | import pickle 6 | import struct 7 | import logging 8 | from ml.torch import yolo 9 | 10 | """ 11 | COPYRIGHT @ Grebtsew 2019 12 | 13 | AiServer receives a couple of connections, reads images from incoming streams 14 | and send detections to the QtServer 15 | """ 16 | 17 | QtServer_address= [["127.0.0.1",8081]] 18 | 19 | class AiServer(Thread): 20 | HOST = '127.0.0.1' # Standard loopback interface address (localhost) 21 | PORT = 8585 # Port to listen on (non-privileged ports are > 1023) 22 | 23 | def __init__(self): 24 | super(AiServer, self).__init__() 25 | logging.info("Tensorflow Server started at ", self.HOST, self.PORT) 26 | # Start detections 27 | self.ai_thread = yolo.YOLO() 28 | # Setup output socket 29 | logging.debug("Ai Server try connecting to Qt Server ", QtServer_address[0][0],QtServer_address[0][1]) 30 | self.outSocket = socket.socket() 31 | self.outSocket.connect((QtServer_address[0][0],QtServer_address[0][1])) 32 | logging.debug("SUCCESS : Ai Server successfully connected to Qt Server!", ) 33 | 34 | def handle_connection(self, conn): 35 | with conn: 36 | data = b"" 37 | payload_size = struct.calcsize(">L") 38 | 39 | while True: 40 | # Recieve image package size 41 | while len(data) < payload_size: 42 | logging.debug("Recv: {}".format(len(data))) 43 | data += conn.recv(4096) 44 | 45 | packed_msg_size = data[:payload_size] 46 | data 
= data[payload_size:] 47 | msg_size = struct.unpack(">L", packed_msg_size)[0] 48 | 49 | logging.debug("msg_size: {}".format(msg_size)) 50 | 51 | # Recieve image 52 | while len(data) < msg_size: 53 | data += conn.recv(4096) 54 | frame_data = data[:msg_size] 55 | data = data[msg_size:] 56 | 57 | # Decode image 58 | frame=pickle.loads(frame_data, fix_imports=True, encoding="bytes") 59 | frame = cv2.imdecode(frame, cv2.IMREAD_COLOR) 60 | 61 | # do detetions 62 | # TODO: 63 | #self.ai_thread.frame = frame 64 | #self.ai_thread.run_async() 65 | #detect_res = self.ai_thread.get_result() 66 | 67 | # send detection result to QtServer 68 | if detect_res is not None: 69 | self.send(detect_res) 70 | 71 | 72 | def send(self, data): 73 | self.outSocket.sendall(pickle.dumps(data)) 74 | 75 | def run(self): 76 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as inSocket: 77 | inSocket.bind((self.HOST, self.PORT)) 78 | inSocket.listen() 79 | while True: 80 | conn, addr = inSocket.accept() 81 | Thread(target=self.handle_connection, args=(conn,)).start() 82 | 83 | 84 | if __name__ == '__main__': 85 | aiserver = AiServer().start() 86 | -------------------------------------------------------------------------------- /server/MssClient.py: -------------------------------------------------------------------------------- 1 | import threading 2 | from QtServer import QtServer 3 | import socket 4 | from mss import mss 5 | from PIL import Image 6 | import cv2 7 | import numpy as np 8 | import pickle 9 | import struct 10 | import time 11 | import logging 12 | 13 | """ 14 | COPYRIGHT @ Grebtsew 2019 15 | 16 | This Client Creates connections to servers in server_list and send images 17 | to each server then recieve result and display images 18 | """ 19 | 20 | server_list= [["127.0.0.1",8585]] 21 | QtServer_address= [["127.0.0.1",8081]] 22 | image_size_treshhold = 720 23 | screensize = (1920, 1080) 24 | 25 | class MssClient(threading.Thread): 26 | """ 27 | This client sends images for detection server 28 | """ 29 | 30 | def __init__(self, address,port): 31 | super(MssClient,self).__init__() 32 | self.address = address 33 | self.port = port 34 | self.s = socket.socket() 35 | #self.demo_start_tfserver() 36 | logging.info("MSS Client trying to connect to Tensorflow Server ", self.address, self.port) 37 | self.s.connect((self.address,self.port)) 38 | logging.info("SUCCESS : Mss Client successfully connected to Tensorflow Server!") 39 | self.encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 90] 40 | 41 | def demo_start_tfserver(self): 42 | from TfServer import TfServer 43 | TfServer().start() 44 | 45 | def send(self): 46 | result, frame = cv2.imencode('.jpg', self.image, self.encode_param) 47 | 48 | data = pickle.dumps(frame, 0) 49 | size = len(data) 50 | logging.debug("Sending Image of size ", size) 51 | self.s.sendall(struct.pack(">L", size) + data) 52 | 53 | def run(self): 54 | self.send() 55 | 56 | def downscale(image): 57 | height, width, channel = image.shape 58 | 59 | if height > image_size_treshhold: 60 | scale = height/image_size_treshhold 61 | 62 | image = cv2.resize(image, (int(width/scale), int(height/scale))) 63 | 64 | return image, scale 65 | 66 | 67 | 68 | if __name__ == '__main__': 69 | 70 | image = None 71 | 72 | # Start Create connections 73 | clientlist = [] 74 | for server in server_list: 75 | clientlist.append(MssClient(server[0],server[1])) 76 | 77 | # create mss 78 | sct = mss() 79 | monitor = {'top': 0, 'left': 0, 'width': screensize[0], 'height': screensize[1]} 80 | 81 | # start loop 82 | while 
True: 83 | # Recieve Image 84 | image = Image.frombytes('RGB', (screensize[0], screensize[1]), sct.grab(monitor).rgb) 85 | image = np.array(image) 86 | # Rescale Image 87 | image, scale = downscale(image) 88 | 89 | # async send image to all servers 90 | for server in clientlist: 91 | server.image = image 92 | server.send() 93 | -------------------------------------------------------------------------------- /server/QtServer.py: -------------------------------------------------------------------------------- 1 | import threading 2 | 3 | import sys 4 | sys.path.insert(0,'..') 5 | 6 | import os 7 | 8 | import logging 9 | from utils import label_map_util 10 | from utils.screen_overlay_handler import * 11 | import socket 12 | import pickle 13 | import numpy as np 14 | from PyQt5.QtCore import * 15 | from PyQt5.QtWidgets import QApplication 16 | import sys 17 | import utils.screen_overlay_handler 18 | 19 | """ 20 | COPYRIGHT @ Grebtsew 2019 21 | 22 | 23 | QtServer recieves detection boxes and visualize them. 24 | """ 25 | 26 | MAX_DETECTION = 5 27 | MAX_BOX_AREA = 1000000 # pixels^2 28 | PRECISION = 0.6 # 60 % detection treshhold 29 | MAX_DETECTION = 5 30 | WIDTH = 1920 31 | HEIGTH = 1080 32 | SHOW_ONLY = ["person"] 33 | BOX_VIS_TIME = 0.2 # in seconds 34 | 35 | # Dont change these 36 | list = [] 37 | queue = [] 38 | class QtServer(threading.Thread): 39 | """ 40 | This server recieves boxes and shows them in pyqt5 41 | """ 42 | 43 | def __init__(self, address, port): 44 | super(QtServer,self).__init__() 45 | self.address = address 46 | self.port = port 47 | self.categorylist = self.load_tf_categories() 48 | 49 | def load_tf_categories(self): 50 | self.NUM_CLASSES = 90 51 | CWD_PATH = os.path.dirname(os.getcwd()) 52 | self.PATH_TO_LABELS = os.path.join(CWD_PATH,'object_detection', 'data', 'mscoco_label_map.pbtxt') 53 | self.label_map = label_map_util.load_labelmap(self.PATH_TO_LABELS) 54 | self.categories = label_map_util.convert_label_map_to_categories(self.label_map, max_num_classes=self.NUM_CLASSES, use_display_name=True) 55 | self.category_index = label_map_util.create_category_index(self.categories) 56 | return self.categories 57 | 58 | def handle_connection(self, conn): 59 | with conn: 60 | while True: 61 | data = conn.recv(50000) # approx larger than the incoming tf result 62 | if not data: 63 | break 64 | else: 65 | try: 66 | dict = pickle.loads(data) 67 | except Exception: 68 | continue # If for some reason not entire package is recieved! 
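# the pickled payload is the tuple sent by the detection server: (boxes, scores, classes, num_detections)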
69 | 70 | boxes = np.squeeze(dict[0]) 71 | scores = np.squeeze(dict[1]) 72 | classification = np.squeeze(dict[2]) 73 | amount = np.squeeze(dict[3]) 74 | 75 | # loop through all detections 76 | for i in range(0,len(boxes)): 77 | # Calculate rescale rectangle 78 | x = int(WIDTH*boxes[i][1]) 79 | y = int(HEIGTH*boxes[i][0]) 80 | w = int(WIDTH*(boxes[i][3]-boxes[i][1])) 81 | h = int(HEIGTH*(boxes[i][2]-boxes[i][0])) 82 | c = "" 83 | 84 | # Check category in bounds 85 | if len(self.categorylist) >= classification[i]: 86 | c = str(self.categorylist[int(classification[i]-1)]['name']) 87 | 88 | if len(SHOW_ONLY) > 0: # dont show wrong items 89 | if not SHOW_ONLY.__contains__(c): 90 | continue 91 | if scores[i] > PRECISION: # precision treshold 92 | if w*h < MAX_BOX_AREA : # max box size check 93 | queue.append((scores[i], c,x,y,w,h)) # save all vis data in queue 94 | 95 | 96 | def run(self): 97 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 98 | s.bind((self.address, self.port)) 99 | s.listen() 100 | logging.info("Qt Server started at ", self.address, self.port ) 101 | while True: 102 | conn, addr = s.accept() 103 | threading.Thread(target=self.handle_connection, args=(conn,)).start() 104 | 105 | def show_rect(scores, c,x,y,w,h): 106 | list.append((screen_overlay_handler.create_box_with_image_score_classification("../images/square2.png",scores,c,x,y,w,h), time.time())) 107 | 108 | def remove_old_detections(): 109 | for box in list: 110 | if time.time() - box[1] > BOX_VIS_TIME: 111 | list.remove(box) 112 | 113 | def paint_rects(): 114 | if len(queue) > 0: 115 | for box in queue: 116 | show_rect(box[0],box[1],box[2],box[3],box[4],box[5]) 117 | queue.clear() 118 | else: 119 | time.sleep(0.1) 120 | remove_old_detections() 121 | 122 | if __name__ == '__main__': 123 | app = QApplication(sys.argv) # create window handler 124 | QtServer_address= [["127.0.0.1",8081]] 125 | qtserver = QtServer(QtServer_address[0][0],QtServer_address[0][1]) 126 | qtserver.start() 127 | 128 | while True: # Paint all incoming boxes on MAIN thread (required!) 129 | paint_rects() 130 | -------------------------------------------------------------------------------- /server/http_server.py: -------------------------------------------------------------------------------- 1 | """ 2 | This server will receive post requests containing an object that need to be detected on screen. 3 | This server will be used in a special mode of the program. 
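Example request (hypothetical client snippet; it assumes the default 127.0.0.1:5000 bind and the "password-1" key configured below, plus the `requests` package, which is not in requirements.txt):

    import requests
    requests.post("http://127.0.0.1:5000",
                  json={"value": "person", "type": "add", "api_crypt": "password-1"})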
4 | """ 5 | """ 6 | This server will receive HTTP post requests and send the data to flask 7 | """ 8 | from threading import Thread 9 | import http.server 10 | import json 11 | from functools import partial 12 | from http.server import BaseHTTPRequestHandler, HTTPServer 13 | import logging 14 | 15 | class S(BaseHTTPRequestHandler): 16 | 17 | def __init__(self,shared_variables, *args, **kwargs): 18 | self.CRYPT = "password-1" 19 | self.shared_variables = shared_variables 20 | super().__init__(*args, **kwargs) 21 | 22 | def _set_response(self): 23 | self.send_response(200, "ok") 24 | self.send_header('Access-Control-Allow-Origin', '*') 25 | self.send_header('Access-Control-Allow-Methods', 'POST, OPTIONS, HEAD, GET') 26 | self.send_header("Access-Control-Allow-Headers", "X-Requested-With") 27 | self.send_header("Access-Control-Allow-Headers", "Content-Type") 28 | self.send_header('Content-type', 'application/json') 29 | self.end_headers() 30 | 31 | def do_HEAD(self): 32 | self._set_response() 33 | 34 | def do_OPTIONS(self): 35 | self._set_response() 36 | 37 | def clear_timer(): 38 | self.shared_variables.SHOW_ONLY = [] 39 | 40 | def do_POST(self): 41 | logging.debug(self.client_address,self.headers) 42 | 43 | if self.headers['Content-Length']: 44 | 45 | content_length = int(self.headers['Content-Length']) # <--- Gets the size of data 46 | post_data = self.rfile.read(content_length) # <--- Gets the data itself 47 | # decode incoming data // see if password is correct here! 48 | try: 49 | data = json.loads(post_data) 50 | logging.debug(data) 51 | """ 52 | Data struct: 53 | data = { 54 | 'value': "person", 55 | 'type': 'add', 56 | 'api_crypt':"password-1" 57 | } 58 | """ 59 | if data['api_crypt'] : 60 | if data['api_crypt'] == self.CRYPT: 61 | if data["value"]: 62 | if data["type"] == "remove": 63 | logging.debug("Removed Class : "+ str(data["value"])) 64 | for box in self.shared_variables.list: 65 | if box.classification == data["value"]: 66 | box.remove() # stop running boxes! 67 | if data["value"] in self.shared_variables.SHOW_ONLY: 68 | self.shared_variables.SHOW_ONLY.remove(data["value"]) 69 | else: 70 | if not data["value"] in self.shared_variables.SHOW_ONLY: 71 | logging.debug("Added Class : "+ str(data["value"])) 72 | self.shared_variables.SHOW_ONLY.append(data["value"]) 73 | except Exception as e: 74 | logging.error("ERROR: "+str(e)) 75 | self._set_response() 76 | 77 | class HTTPserver(Thread): 78 | 79 | def __init__(self, shared_variables): 80 | super().__init__() 81 | self.shared_variables = shared_variables 82 | def run(self): 83 | server_address = ("127.0.0.1",5000) 84 | httpd = HTTPServer(server_address, partial(S, self.shared_variables)) 85 | try: 86 | logging.info("HTTP Server Started!") 87 | httpd.serve_forever() 88 | except KeyboardInterrupt: 89 | pass 90 | httpd.server_close() 91 | 92 | def safe(json, value): 93 | try: 94 | return json[value] 95 | except Exception: 96 | return 97 | -------------------------------------------------------------------------------- /server/local_starter.py: -------------------------------------------------------------------------------- 1 | import time 2 | import threading 3 | 4 | def start_server(name): 5 | import subprocess 6 | subprocess.run(["python", name+".py"]) 7 | 8 | if __name__ == '__main__': 9 | """ 10 | COPYRIGHT @ Grebtsew 2019 11 | 12 | This file starts the demo mode where all servers and clients run on local pc! 
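NOTE: start this script from inside the server/ directory; the servers are launched via relative paths (subprocess runs e.g. "python QtServer.py").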
13 | """ 14 | 15 | threading.Thread(target=start_server, args=("QtServer",)).start() 16 | time.sleep(2) 17 | threading.Thread(target=start_server, args=("AiServer",)).start() 18 | time.sleep(5) 19 | threading.Thread(target=start_server, args=("MssClient",)).start() 20 | -------------------------------------------------------------------------------- /tests/noCI/test_main.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/tests/noCI/test_main.py -------------------------------------------------------------------------------- /tests/noCI/test_server_mssclient.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/tests/noCI/test_server_mssclient.py -------------------------------------------------------------------------------- /tests/noCI/test_server_qtserver.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/tests/noCI/test_server_qtserver.py -------------------------------------------------------------------------------- /tests/noCI/test_utils_screen_overlay_handler.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/tests/noCI/test_utils_screen_overlay_handler.py -------------------------------------------------------------------------------- /tests/test_ml_detector.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/tests/test_ml_detector.py -------------------------------------------------------------------------------- /tests/test_ml_opencv.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/tests/test_ml_opencv.py -------------------------------------------------------------------------------- /tests/test_ml_opencv_haar.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/tests/test_ml_opencv_haar.py -------------------------------------------------------------------------------- /tests/test_ml_tensorflow.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/tests/test_ml_tensorflow.py -------------------------------------------------------------------------------- /tests/test_ml_torch.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/tests/test_ml_torch.py -------------------------------------------------------------------------------- /tests/test_ml_torch_yolo.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/tests/test_ml_torch_yolo.py -------------------------------------------------------------------------------- /tests/test_server_ai.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/tests/test_server_ai.py -------------------------------------------------------------------------------- /tests/test_server_http.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/tests/test_server_http.py -------------------------------------------------------------------------------- /tests/test_utils_shared_variables.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/tests/test_utils_shared_variables.py -------------------------------------------------------------------------------- /tests/test_utils_threadpool.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/tests/test_utils_threadpool.py -------------------------------------------------------------------------------- /tests/test_utils_tracking.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/tests/test_utils_tracking.py -------------------------------------------------------------------------------- /utils/ThreadPool.py: -------------------------------------------------------------------------------- 1 | from PyQt5.QtGui import * 2 | from PyQt5.QtWidgets import * 3 | from PyQt5.QtCore import * 4 | 5 | import traceback 6 | import sys 7 | 8 | class WorkerSignals(QObject): 9 | ''' 10 | Defines the signals available from a running worker thread. 11 | 12 | Supported signals are: 13 | 14 | finished 15 | No data 16 | 17 | error 18 | `tuple` (exctype, value, traceback.format_exc() ) 19 | 20 | result 21 | `object` data returned from processing, anything 22 | 23 | progress 24 | `int` indicating % progress 25 | 26 | ''' 27 | finished = pyqtSignal() 28 | error = pyqtSignal(tuple) 29 | result = pyqtSignal(object) 30 | progress = pyqtSignal(list) 31 | 32 | 33 | class Worker(QRunnable): 34 | ''' 35 | Worker thread 36 | 37 | Inherits from QRunnable to handler worker thread setup, signals and wrap-up. 38 | 39 | :param callback: The function callback to run on this worker thread. Supplied args and 40 | kwargs will be passed through to the runner. 
41 | :type callback: function 42 | :param args: Arguments to pass to the callback function 43 | :param kwargs: Keywords to pass to the callback function 44 | 45 | ''' 46 | 47 | def __init__(self, fn, *args, **kwargs): 48 | super(Worker, self).__init__() 49 | 50 | # Store constructor arguments (re-used for processing) 51 | self.fn = fn 52 | self.args = args 53 | self.kwargs = kwargs 54 | self.signals = WorkerSignals() 55 | 56 | # Add the callback to our kwargs 57 | self.kwargs['progress_callback'] = self.signals.progress 58 | 59 | @pyqtSlot() 60 | def run(self): 61 | ''' 62 | Initialise the runner function with passed args, kwargs. 63 | ''' 64 | 65 | # Retrieve args/kwargs here; and fire processing using them 66 | try: 67 | result = self.fn(*self.args, **self.kwargs) 68 | except: 69 | traceback.print_exc() 70 | exctype, value = sys.exc_info()[:2] 71 | self.signals.error.emit((exctype, value, traceback.format_exc())) 72 | else: 73 | self.signals.result.emit(result) # Return the result of the processing 74 | finally: 75 | self.signals.finished.emit() # Done 76 | -------------------------------------------------------------------------------- /utils/screen_overlay_handler.py: -------------------------------------------------------------------------------- 1 | from PyQt5.QtGui import * 2 | from PyQt5.QtWidgets import * 3 | from PyQt5.QtCore import * 4 | from utils.ThreadPool import * 5 | 6 | from utils.tracking import Tracking 7 | 8 | import logging 9 | import threading 10 | import sys 11 | import time 12 | 13 | ''' 14 | COPYRIGHT @ Grebtsew 2019 15 | 16 | This file contains functions for showing overlay detections 17 | ''' 18 | 19 | 20 | class TrackingBox(QSplashScreen): 21 | splash_pix = None 22 | done = False 23 | 24 | def __init__(self, id, shared_variables, score, classification, box, *args, **kwargs): 25 | super(TrackingBox, self).__init__(*args, **kwargs) 26 | self.classification = classification 27 | self.shared_variables = shared_variables 28 | self.counter = 0 29 | # TODO: this might be a little haxy 30 | self.x = int(box[0]*(self.shared_variables.WIDTH/self.shared_variables.DETECTION_SCALE)-(box[2]*(self.shared_variables.WIDTH/self.shared_variables.DETECTION_SCALE))/2) 31 | self.y = int(box[1]*(self.shared_variables.HEIGHT/self.shared_variables.DETECTION_SCALE)-(box[3]*(self.shared_variables.HEIGHT/self.shared_variables.DETECTION_SCALE))/2) 32 | self.width = int(box[2]*(self.shared_variables.WIDTH/self.shared_variables.DETECTION_SCALE)) 33 | self.height = int(box[3]*(self.shared_variables.HEIGHT/self.shared_variables.DETECTION_SCALE)) 34 | self.id = id 35 | self.splash_pix = QPixmap('./docs/box2.png') 36 | self.splash_pix = self.splash_pix.scaled(round(self.width*self.shared_variables.DETECTION_SCALE),round(self.height*self.shared_variables.DETECTION_SCALE)); 37 | self.setPixmap(self.splash_pix) 38 | 39 | self.setWindowFlag(Qt.WindowStaysOnTopHint) 40 | self.setAttribute(Qt.WA_TranslucentBackground) 41 | self.setAttribute(Qt.WA_NoSystemBackground) 42 | 43 | label = QLabel( self ) 44 | label.setWordWrap( True ) 45 | label.move(30,30) 46 | label.setStyleSheet(" color: rgb(0, 100, 200); font-size: 15pt; ") 47 | 48 | label.setText( str(int(100*score))+"%" + " " + classification ); 49 | self.move(self.x,self.y) 50 | self.show() 51 | 52 | self.tracking = Tracking( (self.x,self.y,self.width,self.height),self.shared_variables) 53 | 54 | self.threadpool = QThreadPool() 55 | 56 | logging.debug(f"New Box Created at {str(self.x)} {str(self.y)} Size {str(self.width)} {str(self.height)}") 57 | 58 
| self.start_worker() 59 | 60 | def progress_fn(self, n): 61 | logging.debug("%d%% done" % n) 62 | pass 63 | 64 | def remove(self): 65 | self.shared_variables.list.remove(self) 66 | self.done = True 67 | self.threadpool.cancel 68 | 69 | def execute_this_fn(self, progress_callback): 70 | 71 | if(not self.tracking.running): 72 | if not self.done: # Remove ourself from gui list 73 | self.shared_variables.list.remove(self) 74 | self.done = True 75 | self.threadpool.cancel 76 | else: 77 | self.tracking.run() 78 | 79 | return "Done." 80 | 81 | 82 | def print_output(self, s): 83 | #logging.debug(str(self.id)) 84 | self.hide() 85 | self.repaint_size(round(self.tracking.box[2]*self.shared_variables.DETECTION_SCALE), round(self.tracking.box[3]*self.shared_variables.DETECTION_SCALE)) 86 | self.move(round(self.tracking.box[0]*self.shared_variables.DETECTION_SCALE), round(self.tracking.box[1]*self.shared_variables.DETECTION_SCALE)) 87 | self.show() 88 | 89 | def thread_complete(self): 90 | #logging.debug("THREAD COMPLETE!") 91 | self.start_worker() 92 | 93 | def start_worker(self): 94 | # Pass the function to execute 95 | worker = Worker(self.execute_this_fn) # Any other args, kwargs are passed to the run function 96 | worker.signals.result.connect(self.print_output) 97 | worker.signals.finished.connect(self.thread_complete) 98 | worker.signals.progress.connect(self.progress_fn) 99 | 100 | # Execute 101 | self.threadpool.start(worker) 102 | 103 | def repaint_size(self, width, height): 104 | #splash_pix = QPixmap('../images/box2.png') 105 | self.splash_pix = self.splash_pix.scaled(width,height); 106 | self.setPixmap(self.splash_pix) 107 | 108 | 109 | def get_box(self): 110 | return self.tracking.box 111 | 112 | 113 | def create_box(x,y,width,height): 114 | ''' 115 | Show an overlaybox without label 116 | @Param x box left 117 | @Param y box up 118 | @Param width box width 119 | @Param height box height 120 | @Return overlay instance 121 | ''' 122 | 123 | splash_pix = QPixmap('images/square2.png') 124 | splash_pix = splash_pix.scaled(width,height); 125 | 126 | splash = QSplashScreen(splash_pix, Qt.WindowStaysOnTopHint) 127 | splash.setWindowOpacity(0.2) 128 | 129 | splash.setAttribute(Qt.WA_NoSystemBackground) 130 | splash.move(x,y) 131 | 132 | splash.show() 133 | return splash 134 | 135 | 136 | def create_box_with_score_classification(score, classification, x,y,width,height): 137 | ''' 138 | Show an overlaybox with label 139 | @Param score float 140 | @Param classification string 141 | @Param x box left 142 | @Param y box up 143 | @Param width box width 144 | @Param height box height 145 | @Return overlay instance 146 | ''' 147 | splash_pix = QPixmap('images/square2.png') 148 | splash_pix = splash_pix.scaled(width,height); 149 | 150 | splash = QSplashScreen(splash_pix, Qt.WindowStaysOnTopHint) 151 | splash.setWindowOpacity(0.2) 152 | 153 | label = QLabel( splash ); 154 | label.setWordWrap( True ); 155 | label.setText( str(int(100*score))+"%" + " " + classification ); 156 | 157 | splash.setAttribute(Qt.WA_NoSystemBackground) 158 | splash.move(x,y) 159 | 160 | splash.show() 161 | return splash 162 | 163 | def create_box_with_image_score_classification(image_path, score, classification, x,y,width,height): 164 | ''' 165 | Show an overlaybox with label 166 | @Param score float 167 | @Param classification string 168 | @Param x box left 169 | @Param y box up 170 | @Param width box width 171 | @Param height box height 172 | @Return overlay instance 173 | ''' 174 | splash_pix = QPixmap(image_path) 175 | 
splash_pix = splash_pix.scaled(width,height); 176 | 177 | splash = QSplashScreen(splash_pix, Qt.WindowStaysOnTopHint) 178 | splash.setWindowOpacity(0.2) 179 | 180 | label = QLabel( splash ); 181 | label.setWordWrap( True ); 182 | label.setText( str(int(100*score))+"%" + " " + classification ); 183 | 184 | splash.setAttribute(Qt.WA_NoSystemBackground) 185 | splash.move(x,y) 186 | splash.show() 187 | return splash 188 | 189 | #TODO 190 | def create_fancy_box(score, classification, x,y,width,height): 191 | # Create fancier box without image 192 | splash_pix = QPixmap(width, height) 193 | 194 | 195 | painter = QPainter(splash_pix) 196 | painter.setPen(QPen(Qt.blue, 10, Qt.SolidLine)) 197 | path = QPainterPath() 198 | 199 | path.addRoundedRect(QRectF(0+10,0+10,width-20,height-20), 30, 30); 200 | painter.drawPath(path); 201 | painter.end() 202 | 203 | 204 | splash = QSplashScreen(splash_pix, Qt.WindowStaysOnTopHint) 205 | splash.setWindowOpacity(1) 206 | splash.setAttribute(Qt.WA_TranslucentBackground) 207 | 208 | 209 | label = QLabel( splash ); 210 | label.setWordWrap( True ); 211 | label.move(30,30) 212 | label.setStyleSheet(" color: rgb(0, 100, 200); font-size: 30pt; ") 213 | label.setText( str(int(100*score))+"%" + " " + classification ); 214 | 215 | splash.setAttribute(Qt.WA_NoSystemBackground) 216 | splash.move(x,y) 217 | splash.show() 218 | return splash 219 | -------------------------------------------------------------------------------- /utils/shared_variables.py: -------------------------------------------------------------------------------- 1 | ''' 2 | COPYRIGHT @ Grebtsew 2019 3 | 4 | This file contains a shared variables class and a mss screen stream capture class 5 | ''' 6 | # Shared variables between threads 7 | from mss import mss 8 | from PIL import Image 9 | from threading import Thread 10 | 11 | import numpy as np 12 | import cv2 13 | import time 14 | import logging 15 | 16 | from ml.torch.yolo import YOLO 17 | from ml.tensorflow.ssd import SSD 18 | from ml.opencv.haar_cascades import HAAR_CASCADES 19 | 20 | # Global shared variables 21 | # an instace of this class share variables between system threads 22 | class Shared_Variables(): 23 | 24 | # Select model to use here! 
--------------------------------------------------------------------------------
/utils/shared_variables.py:
--------------------------------------------------------------------------------
'''
COPYRIGHT @ Grebtsew 2019

This file contains a shared-variables class and an mss screen-capture class.
'''
# Shared variables between threads
from mss import mss
from PIL import Image
from threading import Thread

import numpy as np
import cv2
import time
import logging

from ml.torch.yolo import YOLO
from ml.tensorflow.ssd import SSD
from ml.opencv.haar_cascades import HAAR_CASCADES


# Global shared variables:
# an instance of this class shares state between the system's threads.
class Shared_Variables():

    # Select the model to use here!
    model = YOLO  # YOLO | SSD | HAAR_CASCADES

    trackingboxes = []
    _initialized = 0
    OFFSET = (0, 0)
    HTTP_SERVER = False
    WIDTH, HEIGHT = 1920, 1080
    detection_ready = False
    category_index = None
    OutputFrame = None
    frame = None
    boxes = None
    categorylist = []
    category_max = None
    stream_running = True
    detection_running = True
    list = []
    DETECTION_SIZE = 640
    DETECTION_SCALE = 0
    MAX_TRACKING_MISSES = 15  # referenced by utils/tracking.py: give up after this many missed frames

    def __init__(self):
        self._initialized = 1
        Screen_Streamer(shared_variables=self).start()


class Screen_Streamer(Thread):
    def __init__(self, shared_variables=None):
        Thread.__init__(self)
        self.shared_variables = shared_variables

    def downscale(self, image):
        # Shrink the frame so its height does not exceed DETECTION_SIZE, and
        # return the scale factor needed to map boxes back to screen space.
        image_size_threshold = self.shared_variables.DETECTION_SIZE
        height, width, channel = image.shape
        scale = 1  # small frames are passed through unscaled

        if height > image_size_threshold:
            scale = height / image_size_threshold
            image = cv2.resize(image, (int(width / scale), int(height / scale)))

        return image, scale

    def run(self):
        sct = mss()
        monitor = {'top': self.shared_variables.OFFSET[0],
                   'left': self.shared_variables.OFFSET[1],
                   'width': self.shared_variables.WIDTH,
                   'height': self.shared_variables.HEIGHT}
        logging.info(f"MSS started with monitor : {monitor}")

        while self.shared_variables.stream_running:
            if self.shared_variables.detection_ready:
                img = Image.frombytes('RGB',
                                      (self.shared_variables.WIDTH, self.shared_variables.HEIGHT),
                                      sct.grab(monitor).rgb)
                #cv2.imshow('test', np.array(img))
                #cv2.waitKey(1)
                self.shared_variables.OutputFrame, self.shared_variables.DETECTION_SCALE = self.downscale(np.array(img))
                if cv2.waitKey(25) & 0xFF == ord('q'):  # only effective while a debug window is open
                    cv2.destroyAllWindows()
                    break
            else:
                time.sleep(0.1)
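# ---------------------------------------------------------------------------
# Hedged smoke test (an addition for illustration, not in the original
# module): capture the screen for a second and report the downscaled frame
# shape. Assumes a desktop session is available for mss to grab.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    shared = Shared_Variables()      # also starts the Screen_Streamer thread
    shared.detection_ready = True    # allow the streamer to start grabbing
    time.sleep(1.0)
    if shared.OutputFrame is not None:
        logging.info("frame shape: %s, scale: %.2f",
                     shared.OutputFrame.shape, shared.DETECTION_SCALE)
    shared.stream_running = False    # stop the streamer thread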
--------------------------------------------------------------------------------
/utils/tracking.py:
--------------------------------------------------------------------------------
import numpy as np
import logging
import math
import cv2


# Tracking thread:
# follows a single detection box across frames with an OpenCV tracker,
# smoothing the reported position with a Kalman filter.
class Tracking():
    tracker_test = None
    tracker = None
    frame = None
    running = True

    fail_counter = 0

    start_time = None
    end_time = None
    first_time = True
    first = True

    # Initiate tracker
    # parameters: box, shared_variables reference
    def __init__(self, box, shared_variables):
        self.box = box
        self.shared_variables = shared_variables

        # Constant-velocity Kalman filter: the state is (x, y, dx, dy) and
        # the measurement is the (x, y) corner reported by the tracker.
        self.kalman = cv2.KalmanFilter(4, 2, 0)
        self.kalman.measurementMatrix = np.array([[1, 0, 0, 0],
                                                  [0, 1, 0, 0]], np.float32)

        self.kalman.transitionMatrix = np.array([[1, 0, 1, 0],
                                                 [0, 1, 0, 1],
                                                 [0, 0, 1, 0],
                                                 [0, 0, 0, 1]], np.float32)

        self.kalman.processNoiseCov = np.array([[1, 0, 0, 0],
                                                [0, 1, 0, 0],
                                                [0, 0, 1, 0],
                                                [0, 0, 0, 1]], np.float32) * 0.03

    # Run
    # Thread run function
    def run(self):
        self.frame = self.shared_variables.OutputFrame

        if self.frame is not None:
            if self.first_time:
                self.update_custom_tracker()
                self.first_time = False
            self.object_custom_tracking()

    # Create_custom_tracker
    #
    # Create the custom tracker; the tracking method can be changed here.
    # Requires the cv2 contrib modules (opencv-contrib-python) to work!
    def create_custom_tracker(self):
        # higher tracking accuracy, tolerates lower FPS throughput:
        #self.tracker = cv2.TrackerCSRT_create()
        # faster FPS throughput at slightly lower tracking accuracy:
        self.tracker = cv2.TrackerKCF_create()
        # MOSSE when you need pure speed:
        #self.tracker = cv2.TrackerMOSSE_create()

    # Update_custom_tracker
    #
    # Set and reset the custom tracker
    def update_custom_tracker(self):
        self.create_custom_tracker()
        logging.debug("%s %s", self.frame.shape, self.box)
        self.tracker_test = self.tracker.init(self.frame, self.box)

    # def distance_between_boxes(self, box1, box2):
    #     return int(abs(math.hypot(box2[0] - box1[0], box2[1] - box1[1])))

    def get_box(self):
        return self.box

    # Object_custom_tracking
    #
    # Runs the OpenCV tracker selected in create_custom_tracker and feeds
    # the raw result through the Kalman filter.
    def object_custom_tracking(self):
        # Calculate
        self.tracker_test, box = self.tracker.update(self.frame)

        if self.tracker_test:
            #cv2.waitKey(1)
            #cv2.imshow("test", self.frame)

            if self.first:
                # Seed the filter state with the first measured position.
                A = self.kalman.statePost
                A[0:4] = np.array([[np.float32(box[0])], [np.float32(box[1])], [0], [0]])
                self.kalman.statePost = A
                self.kalman.statePre = A
                self.first = False

            current_measurement = np.array([[np.float32(box[0])], [np.float32(box[1])]])
            self.kalman.correct(current_measurement)
            prediction = self.kalman.predict()
            self.box = [int(prediction[0]), int(prediction[1]), box[2], box[3]]
            self.fail_counter = 0
        else:
            self.fail_counter += 1
            if self.fail_counter > self.shared_variables.MAX_TRACKING_MISSES:
                # The tracker lost the target too many frames in a row.
                self.running = False
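# ---------------------------------------------------------------------------
# Usage sketch (an illustrative assumption, not part of the original file):
# a detection loop would hand each new box to a Tracking instance and keep
# calling run() as fresh frames arrive, e.g.
#
#   shared = Shared_Variables()            # from utils.shared_variables
#   tracker = Tracking((50, 50, 120, 80), shared)
#   while tracker.running:
#       tracker.run()                      # KCF update + Kalman smoothing
#       x, y, w, h = tracker.get_box()
# ---------------------------------------------------------------------------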
# MultiTracking: a work-in-progress multi-object variant built on
# cv2.MultiTracker. Not wired into the main pipeline yet.
class MultiTracking():
    tracker_test = None
    tracker = None
    frame = None
    running = True
    first_time = True

    fail_counter = 0

    def __init__(self, shared_variables):
        self.box = None  # set once tracking results arrive
        self.shared_variables = shared_variables

    def run(self):
        self.frame = self.shared_variables.OutputFrame

        if self.frame is not None:
            if self.first_time:
                self.update_custom_tracker()
                self.first_time = False
            self.object_custom_tracking()

    def create_custom_tracker(self):
        self.tracker = cv2.MultiTracker_create()

    def update_custom_tracker(self):
        self.create_custom_tracker()

    def get_box(self):
        return self.box

    def add_tracker(self, frame, box):
        # Register one more box with the multi-tracker; CSRT trades speed
        # for accuracy, matching the single-object tracker options above.
        self.tracker.add(cv2.TrackerCSRT_create(), frame, tuple(box))

    def object_custom_tracking(self):
        # Calculate
        self.tracker_test, box = self.tracker.update(self.frame)
        # Update tracker box
        if self.tracker_test:
            #cv2.waitKey(1)
            #cv2.imshow("test", self.frame)
            self.box = box
            self.fail_counter = 0
        else:
            self.fail_counter += 1
            if self.fail_counter > 2:  # give up after a few missed frames
                self.running = False
--------------------------------------------------------------------------------