├── .dockerignore
├── .gitattributes
├── .gitignore
├── Dockerfile
├── LICENSE
├── README.md
├── docker-compose.yml
├── docs
├── Square2.png
├── box2.png
├── box3.png
├── client_server.png
├── demo1.gif
├── demo2.gif
├── logo.png
├── screenshot.png
└── square.jpg
├── main.py
├── ml
├── detector.py
├── opencv
│ ├── haar_cascades.py
│ └── opencv_detector.py
├── tensorflow
│ ├── ssd.py
│ └── tf_detector.py
└── torch
│ ├── torch_detector.py
│ └── yolo.py
├── requirements.txt
├── server
├── AiServer.py
├── MssClient.py
├── QtServer.py
├── http_server.py
└── local_starter.py
├── tests
├── noCI
│ ├── test_main.py
│ ├── test_server_mssclient.py
│ ├── test_server_qtserver.py
│ └── test_utils_screen_overlay_handler.py
├── test_ml_detector.py
├── test_ml_opencv.py
├── test_ml_opencv_haar.py
├── test_ml_tensorflow.py
├── test_ml_torch.py
├── test_ml_torch_yolo.py
├── test_server_ai.py
├── test_server_http.py
├── test_utils_shared_variables.py
├── test_utils_threadpool.py
└── test_utils_tracking.py
└── utils
├── ThreadPool.py
├── screen_overlay_handler.py
├── shared_variables.py
└── tracking.py
/.dockerignore:
--------------------------------------------------------------------------------
1 | ./docs
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *old_remove
2 |
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | .hypothesis/
51 | .pytest_cache/
52 |
53 | # Translations
54 | *.mo
55 | *.pot
56 |
57 | # Django stuff:
58 | *.log
59 | local_settings.py
60 | db.sqlite3
61 |
62 | # Flask stuff:
63 | instance/
64 | .webassets-cache
65 |
66 | # Scrapy stuff:
67 | .scrapy
68 |
69 | # Sphinx documentation
70 | docs/_build/
71 |
72 | # PyBuilder
73 | target/
74 |
75 | # Jupyter Notebook
76 | .ipynb_checkpoints
77 |
78 | # IPython
79 | profile_default/
80 | ipython_config.py
81 |
82 | # pyenv
83 | .python-version
84 |
85 | # celery beat schedule file
86 | celerybeat-schedule
87 |
88 | # SageMath parsed files
89 | *.sage.py
90 |
91 | # Environments
92 | .env
93 | .venv
94 | env/
95 | venv/
96 | ENV/
97 | env.bak/
98 | venv.bak/
99 |
100 | # Spyder project settings
101 | .spyderproject
102 | .spyproject
103 |
104 | # Rope project settings
105 | .ropeproject
106 |
107 | # mkdocs documentation
108 | /site
109 |
110 | # mypy
111 | .mypy_cache/
112 | .dmypy.json
113 | dmypy.json
114 |
115 | # Pyre type checker
116 | .pyre/
117 |
118 | # Ignore ml models
119 | *.h5
120 | *.pt
121 | *.log
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | #
2 | # The purpose of this Dockerfile is to build the main project for CI/CD
3 | #
4 | FROM python:3.8
5 |
6 | RUN mkdir /app
7 | WORKDIR /app
8 |
9 | RUN apt-get update && \
10 | apt-get install -y libgl1-mesa-glx && \
11 | apt-get clean && \
12 | rm -rf /var/lib/apt/lists/*
13 |
14 | COPY ./requirements.txt /requirements.txt
15 | RUN pip install -r /requirements.txt
16 |
17 | #RUN yolo settings sync=false
18 |
19 | COPY . .
20 |
21 | CMD ["python3", "/app/main.py"]
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Grebtsew
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 | 
8 | 
9 | 
10 |
11 | 
12 |
13 | `NOTE:` I use low framerate and low precision object detection to save some
14 | power on my laptop, feel free to change models, detection and remove loop delays.
15 |
16 | # About
17 | RSOD (Realtime Screen Object Detection) is a program I created to visualize
18 | realtime object detections on screen capture. To capture pc screen I use the `mss`
19 | library. Detections are visualized through `PyQt5` as overlays. Feel free to try it out
20 | and change setting variables in the `main.py` file. Possible usages for this program can
21 | be to track and discover objects on screen to assist users or perhaps to enhance
22 | the general user experience.
23 |
24 | # Install
25 | This program is written for **python3**. The required python3 libraries are the following:
26 |
27 | * `PyQt5` (used to overlay screen to show detections)
28 | * `Tensorflow` (object detection)
29 | * `mss` (capture screen stream)
30 | * `Image`
31 | * `opencv-python`
32 | * `pyfiglet`
33 | * `numpy`
34 |
35 | If you have `Python3 pip` installed you can install required packages by running:
36 |
37 | ```bash
38 | pip install -r requirements.txt
39 | ```
40 |
41 | After installing pip packages.
42 | Run this to stop ultralytics library from sending data to google analytics:
43 | ```bash
44 | yolo settings sync=false
45 | ```
46 |
47 |
48 | # Run
49 | This tutorial takes you through the execution of this program.
50 |
51 | 1. Install requirements above
52 | 2. Clone this repo or download it:
53 | ```git
54 | git clone https://github.com/grebtsew/Realtime-Screen-Object-Detection.git
55 | ```
56 | 3. Run python script `main.py`
57 |
58 | # Change Settings
59 |
60 | # Client Server
61 | A client/server solution has been added to further increase performance and accelerate detections.
62 | Test the implementation by running the `local_starter.py` file. The structure is explained below.
63 |
64 |
65 | For demo purposes the starter starts the three separate processes `MssClient`, `QtServer` and `TfServer`.
66 | The `MssClient` reads screen images and sends them to the `TfServer`. The `TfServer` will perform
67 | object detections on the received images and send detection boxes to the `QtServer`. The `QtServer`
68 | will display the incoming boxes from the `TfServer`. This will allow for separation of the solution
69 | and also acceleration of the detections. I would recommend placing the `TfServer` on a pc with high performance
70 | ai hardware and let another pc only focus on displaying the detections.
71 |
72 | 
73 |
74 | # Use detection request Mode
75 | In this mode the program will wait for requests from other clients and show the requested objects.
76 | Example client message:
77 | ```
78 | # Data to add persons to detection
79 | data = {
80 | 'value': "person",
81 | 'type': 'add',
82 | 'api_crypt':"password-1"
83 | }
84 |
85 | # Data to remove dogs from detection
86 | data = {
87 | 'value': "dog",
88 | 'type': 'remove',
89 | 'api_crypt':"password-1"
90 | }
91 | ```
92 | # Testing
93 |
94 | This project contains unit tests using pytest and build tests using docker.
95 | The tests are used in CI github actions. To run tests locally run the commands:
96 | ```bash
97 |
98 | # Unit tests
99 | pytest
100 |
101 | # Build test, to see that all versions still exists
102 | docker-compose up -d
103 | # This will generally complain about not accessing QT and MSS. (GUI and screen grab)
104 | ```
105 |
106 | # Deprecated demo
107 | 
108 |
109 |
110 | # Screenshot
111 | See screenshot of program execution below:
112 | 
113 |
114 | # License
115 | This repository uses [MIT](LICENSE) license.
116 |
117 |
118 | COPYRIGHT @ Grebtsew 2019
119 |
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | #
2 | # The purpose of this file is to illustrate how all applications can be executed.
3 | # However, it is only the AiServer container that will run.
4 | # All other containers are meant to run locally as they need local resources to work,
5 | # such as screen grabbing and GUI handling.
6 | # Thus, this docker-compose file is mainly for testing container builds and illustrating usage.
7 | #
8 |
9 | version: '3.8'
10 |
11 | services:
12 |
13 | qtserver:
14 | image: grebtsew/qtserver
15 | container_name: qtserver
16 | build:
17 | context: .
18 | entrypoint: python3 ./server/QtServer.py
19 | ports:
20 | - "8081:8081"
21 | networks:
22 | - rsod-net
23 |
24 | mssclient:
25 | image: grebtsew/mssclient
26 | container_name: mssclient
27 | build:
28 | context: .
29 | entrypoint: python3 ./server/MssClient.py
30 | networks:
31 | - rsod-net
32 |
33 | aiserver:
34 | image: grebtsew/aiserver
35 | container_name: aiserver
36 | build:
37 | context: .
38 | entrypoint: python3 ./server/AiServer.py
39 | ports:
40 | - "8585:8585"
41 | networks:
42 | - rsod-net
43 |
44 | mainapp:
45 | image: grebtsew/main
46 | container_name: main
47 | build:
48 | context: .
49 | ports:
50 | - "5000:5000"
51 | networks:
52 | - rsod-net
53 |
54 |
55 | networks:
56 | rsod-net:
57 | driver: bridge
58 |
--------------------------------------------------------------------------------
/docs/Square2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/docs/Square2.png
--------------------------------------------------------------------------------
/docs/box2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/docs/box2.png
--------------------------------------------------------------------------------
/docs/box3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/docs/box3.png
--------------------------------------------------------------------------------
/docs/client_server.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/docs/client_server.png
--------------------------------------------------------------------------------
/docs/demo1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/docs/demo1.gif
--------------------------------------------------------------------------------
/docs/demo2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/docs/demo2.gif
--------------------------------------------------------------------------------
/docs/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/docs/logo.png
--------------------------------------------------------------------------------
/docs/screenshot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/docs/screenshot.png
--------------------------------------------------------------------------------
/docs/square.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/docs/square.jpg
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
'''
MAIN
COPYRIGHT @ Grebtsew 2019

This is the main entry module, used to start instances of the full program.
It wires up logging, module-level settings, and the MainGUI Qt application.
'''

from PyQt5.QtCore import *
from PyQt5.QtWidgets import QApplication, QMainWindow

from utils.shared_variables import Shared_Variables
from utils import screen_overlay_handler
from utils.ThreadPool import *
from ml import detector

import time

from pyfiglet import Figlet

import logging
logging.basicConfig(
    level=logging.DEBUG, # Set the logging threshold to DEBUG (or another level)
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    filename='last-run.log', # Log messages to a file (optional)
    filemode='w' # 'w' truncates (overwrites) the log file on each run
)
console_handler = logging.StreamHandler() # StreamHandler defaults to sys.stderr

# Create a formatter for the console handler (optional)
console_formatter = logging.Formatter('%(levelname)s - %(message)s')
console_handler.setFormatter(console_formatter)

# Add the console handler to the root logger
root_logger = logging.getLogger()
root_logger.addHandler(console_handler)

# Create a logger for your module
logger = logging.getLogger('realtime-screen-object-detection.rsod')

# Change these variables if you want!
MAX_BOX_AREA = 100000000 # pixels^2 — detections larger than this are dropped
PRECISION = 0.7 # 70% detection confidence threshold
MAX_DETECTION = 5 # max concurrent overlay boxes
MAX_TRACKING_MISSES = 30
WIDTH = 1920
HEIGHT = 1080
SHOW_ONLY = ["person","Person"] # classes to display; cleared on start if RESET_SHOW_ONLY_ON_START
OFFSET = (0,0)
DETECTION_SIZE = 480
DETECTION_DURATION = 2 # seconds between detection passes
RESET_SHOW_ONLY_ON_START=False
HTTP_SERVER = False
53 |
class MainGUI(QMainWindow):
    """Invisible main window that owns the detection worker thread, the Qt
    timer that prunes finished overlay boxes, and the shared settings object."""

    def initiate_shared_variables(self):
        """Create the Shared_Variables container, copy the module-level
        settings into it, and optionally start the HTTP control server."""
        self.shared_variables = Shared_Variables()
        self.shared_variables.MAX_BOX_AREA = MAX_BOX_AREA
        self.shared_variables.PRECISION = PRECISION
        self.shared_variables.MAX_DETECTION = MAX_DETECTION
        self.shared_variables.WIDTH = WIDTH
        self.shared_variables.HEIGHT = HEIGHT
        self.shared_variables.SHOW_ONLY = SHOW_ONLY
        self.shared_variables.list = []  # active TrackingBox overlays
        self.shared_variables.OFFSET = OFFSET
        self.shared_variables.DETECTION_SIZE = DETECTION_SIZE
        self.shared_variables.DETECTION_DURATION = DETECTION_DURATION
        self.shared_variables.MAX_TRACKING_MISSES = MAX_TRACKING_MISSES
        self.shared_variables.HTTP_SERVER = HTTP_SERVER

        if RESET_SHOW_ONLY_ON_START:
            self.shared_variables.SHOW_ONLY = []

        # Start webserver
        if HTTP_SERVER:
            from server.http_server import HTTPserver
            HTTPserver(shared_variables=self.shared_variables).start()

    def __init__(self):
        super(MainGUI, self).__init__()

        self.initiate_shared_variables()

        # Create detection and load model
        # NOTE(review): shared_variables.model is never assigned in this file —
        # presumably Shared_Variables provides a default detector class; confirm.
        self.detection_model = self.shared_variables.model(shared_variables = self.shared_variables)
        self.detection_model.download_model()
        self.detection_model.load_model()

        self.threadpool = QThreadPool()

        logging.info("Multithreading with maximum %d threads" % self.threadpool.maxThreadCount())

        # Poll every 10 ms to clean up overlay boxes that flagged themselves done.
        self.timer = QTimer()
        self.timer.setInterval(10)
        self.timer.timeout.connect(self.print_output)
        self.timer.start()

        # Start Detection thread
        self.start_worker()

    def execute_this_fn(self, progress_callback):
        """Worker loop: run one detection pass every DETECTION_DURATION
        seconds and emit the resulting boxes via progress_callback."""
        while True:

            if len(self.shared_variables.SHOW_ONLY) == 0:
                # Nothing whitelisted — skip this cycle entirely.
                pass
            else:
                logging.debug("Trigger Detection...")
                if self.shared_variables.OutputFrame is not None: # wait for the first frame
                    progress_callback.emit(self.detection_model.predict()) # detect and emits boxes!

            time.sleep(self.shared_variables.DETECTION_DURATION)

    def create_tracking_boxes(self, boxes):
        """Create overlay TrackingBoxes for new detections, capped at
        MAX_DETECTION concurrent boxes."""
        if len(boxes)> 0:
            logging.debug(f"got detection now create trackerbox: {boxes}")

        for box in boxes:
            if len(self.shared_variables.list) < MAX_DETECTION:
                self.shared_variables.list.append(screen_overlay_handler.TrackingBox(len(self.shared_variables.list), self.shared_variables, box[0],box[1],box[2]))

    def print_output(self):
        """Remove overlay boxes that are done. Indices are collected in
        descending order (insert at 0) so deletions do not shift the
        positions of entries still pending removal."""
        remove = []
        index = 0
        for box in self.shared_variables.list:
            if box.done:
                box.finish(self)
                remove.insert(0,index)
            index += 1

        for i in remove:
            del self.shared_variables.list[i]
        #logging.debug(self.shared_variables.list)

    def thread_complete(self):
        """Worker-finished callback (the detection loop never ends normally)."""
        #logging.debug("Thread closed")
        pass

    def start_worker(self):
        """Launch the detection loop on the Qt thread pool and wire its
        signals to the box-creation and cleanup handlers."""
        # Pass the function to execute
        worker = Worker(self.execute_this_fn) # Any other args, kwargs are passed to the run function
        worker.signals.progress.connect(self.create_tracking_boxes)
        worker.signals.result.connect(self.print_output)
        worker.signals.finished.connect(self.thread_complete)
        # Execute
        self.threadpool.start(worker)
150 |
# Main start here
if __name__ == "__main__":
    # Render an ASCII-art startup banner for the log.
    f = Figlet(font='slant')
    logging.info (f.renderText('Realtime Screen stream with Ai detection Overlay'))
    # Fix: the two concatenated literals were missing a separating space,
    # producing "...pc screen andrun object detection...".
    logging.info("This program starts several threads that stream pc screen and " +
                 "run object detection on it and show detections with PyQt5 overlay.")

    logging.info("Starting Program...")
    logging.info("All threads started, will take a few seconds to load model, enjoy!")

    # Dump the effective settings so a run's configuration is recorded.
    logging.info("")
    logging.info("----- Settings -----")
    logging.info("Max box size : "+ str(MAX_BOX_AREA))
    logging.info("Detection precision treshhold : " + str(100*PRECISION)+"%")
    logging.info("Max amount of detection : "+ str(MAX_DETECTION))
    logging.info("Max amount of tracking misses : "+ str(MAX_TRACKING_MISSES))
    logging.info("Do detections every : "+str(DETECTION_DURATION) + " second")
    logging.info("Rescale image detection size : " +str(DETECTION_SIZE))
    logging.info("Classifications : " + str(SHOW_ONLY) + " * if empty all detections are allowed.")
    logging.info("Screen size : " + str(WIDTH) +"x"+str(HEIGHT))
    logging.info("Screen offset : "+str(OFFSET))
    logging.info("Activate HTTPserver : " + str(HTTP_SERVER))
    logging.info("")

    logging.info("")
    logging.info("----- Usage -----")
    logging.info("Exit by typing : 'ctrl+c'")
    logging.info("")

    logging.info("")
    logging.info("Realtime-Screen-stream-with-Ai-detection-Overlay Copyright (C) 2019 Daniel Westberg")
    logging.info("This program comes with ABSOLUTELY NO WARRANTY;")
    logging.info("This is free software, and you are welcome to redistribute it under certain conditions;")
    logging.info("")

    # Construct the Qt application and hand control to its event loop.
    app = QApplication([])

    MainGUI()

    app.exec_()
--------------------------------------------------------------------------------
/ml/detector.py:
--------------------------------------------------------------------------------
1 | """
2 | Abstract Object Detector Class
3 | """
4 | import math
5 |
6 | from utils import screen_overlay_handler
7 | from abc import ABC, abstractmethod
8 |
class Detector(ABC):
    """Abstract base class for all object detectors.

    Concrete subclasses (TensorFlow, Torch, OpenCV, ...) implement model
    download, loading and prediction. This base class carries the shared
    configuration object and geometry helpers used for tracking.
    """

    def __init__(self, shared_variables, id=None) -> None:
        # Fix: the original called super().__init__(shared_variables), which
        # raises TypeError — object.__init__ accepts no extra arguments.
        super().__init__()

        self.shared_variables = shared_variables
        # Fix: the original assigned the *builtin* `id` function here; an
        # optional identifier parameter (default None) is what was intended.
        self.id = id

    def distance_between_boxes(self, box1, box2):
        """Return the integer Euclidean distance between two box origins."""
        return int(abs(math.hypot(box2[0]-box1[0], box2[1]-box1[1])))

    def detection_exist(self, tracking_list, box):
        """Return True if `box` lies within 100 px of any tracked box."""
        for tracking_box in tracking_list:
            if(self.distance_between_boxes(tracking_box.get_box(), box)) < 100:
                return True
        return False

    def create_new_tracking_box(self, scores, c, shared_variables, box):
        """Append a new overlay TrackingBox for a fresh detection."""
        shared_variables.trackingboxes.append(screen_overlay_handler.TrackingBox(scores, c, shared_variables, box))

    @abstractmethod
    def download_model(self, url):
        """Fetch model weights (may be a no-op for bundled weights)."""

    @abstractmethod
    def load_model(self, model_path):
        """Load the model into memory, ready for predict()."""

    @abstractmethod
    def predict(self, image):
        """Run detection on `image` and return filtered results."""
--------------------------------------------------------------------------------
/ml/opencv/haar_cascades.py:
--------------------------------------------------------------------------------
1 | from ml.opencv import opencv_detector as od
2 |
3 |
class HAAR_CASCADES(od.OPENCVDetector):
    """Haar-cascade detector stub built on OPENCVDetector.

    NOTE(review): the comments below mention "SSD" — this class appears to
    have been copy-pasted from ml/tensorflow/ssd.py; confirm intent.
    """

    def __init__(self, model_path, label_map_path):
        super().__init__()
        # Paths to the cascade model file and its label map (presumably).
        self.model_path = model_path
        self.label_map_path = label_map_path

    def setup(self):
        # NOTE(review): OPENCVDetector defines no setup(); this super() call
        # would raise AttributeError if invoked — confirm.
        super().setup()
        # Add SSD specific setup here

    def load(self):
        # NOTE(review): OPENCVDetector defines load_model(), not load(); this
        # super() call would raise AttributeError if invoked — confirm.
        super().load(self.model_path)
        # SSD-specific loading logic

    def predict(self, image):
        # Delegates to OPENCVDetector.predict; post-processing still TODO.
        prediction = super().predict(image)
        # Post-process SSD predictions
        # Implement your SSD post-processing here
        return prediction
--------------------------------------------------------------------------------
/ml/opencv/opencv_detector.py:
--------------------------------------------------------------------------------
1 | """
2 | Class for specific usage of tensorflow 2.X
3 | """
4 | from ml import detector as d
5 |
6 | import cv2
7 |
8 | class OPENCVDetector(d.Detector):
9 | def __init__(self):
10 | self.model = None
11 |
12 | def download_model(self):
13 | # Initialize TensorFlow specific setup here
14 | #tf.compat.v1.enable_eager_execution()
15 | pass
16 |
17 | def load_model(self, model_path):
18 | #self.model = tf.keras.models.load_model(model_path)
19 | pass
20 |
21 | def predict(self, image):
22 | # Implement TensorFlow specific prediction logic here
23 | prediction = self.model.predict(image)
24 | return prediction
25 |
--------------------------------------------------------------------------------
/ml/tensorflow/ssd.py:
--------------------------------------------------------------------------------
1 | from ml.tensorflow import tf_detector as tf
2 |
class SSD(tf.TFDetector):
    """SSD detector stub built on TFDetector."""

    def __init__(self, model_path, label_map_path):
        super().__init__()
        # Paths to the saved model and its label map.
        self.model_path = model_path
        self.label_map_path = label_map_path

    def setup(self):
        # NOTE(review): TFDetector defines no setup(); this super() call would
        # raise AttributeError if invoked — confirm.
        super().setup()
        # Add SSD specific setup here

    def load(self):
        # NOTE(review): TFDetector defines load_model(), not load(); this
        # super() call would raise AttributeError if invoked — confirm.
        super().load(self.model_path)
        # SSD-specific loading logic

    def predict(self, image):
        # Delegates to TFDetector.predict; post-processing still TODO.
        prediction = super().predict(image)
        # Post-process SSD predictions
        # Implement your SSD post-processing here
        return prediction
23 |
--------------------------------------------------------------------------------
/ml/tensorflow/tf_detector.py:
--------------------------------------------------------------------------------
1 | """
2 | Class for specific usage of tensorflow 2.X
3 | """
4 | from ml import detector as d
5 |
6 | import tensorflow as tf
7 | import keras
8 |
class TFDetector(d.Detector):
    """TensorFlow/Keras implementation of the abstract Detector.

    NOTE(review): __init__ does not call super().__init__, so
    self.shared_variables from d.Detector is never set — confirm intended.
    """

    def __init__(self):
        self.model = None  # set by load_model

    def download_model(self):
        # Despite the name, nothing is downloaded here — this only enables
        # eager execution for TF1-compatibility mode.
        tf.compat.v1.enable_eager_execution()

    def load_model(self, model_path):
        # Load a saved Keras model from disk.
        self.model = tf.keras.models.load_model(model_path)

    def predict(self, image):
        # Run Keras inference and return the raw prediction output.
        prediction = self.model.predict(image)
        return prediction
24 |
--------------------------------------------------------------------------------
/ml/torch/torch_detector.py:
--------------------------------------------------------------------------------
1 | """
2 | Class for specific usage of tensorflow 2.X
3 | """
4 | from ml import detector as d
5 |
6 | import torch
7 |
8 | class TorchDetector(d.Detector):
9 | def __init__(self, shared_variables):
10 | self.model = None
11 |
12 | def download_model(self):
13 | # Initialize TensorFlow specific setup here
14 | pass
15 |
16 | def load_model(self, model_path):
17 | pass
18 |
19 | def predict(self, image):
20 | # Implement TensorFlow specific prediction logic here
21 | pass
--------------------------------------------------------------------------------
/ml/torch/yolo.py:
--------------------------------------------------------------------------------
1 | #import torch
2 | #from pathlib import Path
3 | #import os
4 | from ml.torch import torch_detector as td
5 | from ultralytics import YOLO as y
6 | import logging
7 |
class YOLO(td.TorchDetector):
    """Ultralytics YOLOv8 detector operating on the shared screen frame.

    predict() reads shared_variables.OutputFrame and returns a list of
    (score, class_name, (x, y, w, h)) tuples with normalized coordinates,
    filtered by the shared PRECISION / MAX_BOX_AREA / SHOW_ONLY settings.
    """

    def __init__(self, shared_variables):
        super().__init__(shared_variables)
        self.shared_variables = shared_variables

    def download_model(self):
        # ultralytics fetches the weights on first load, nothing to do here.
        pass

    def load_model(self):
        """Load YOLOv8 nano weights and signal detection readiness."""
        self.model = y('yolov8n.pt')
        # TODO: inherit this flag handling from TorchDetector
        self.shared_variables.detection_ready=True

    def predict(self):
        """Run one detection pass on the latest shared frame."""
        image = self.shared_variables.OutputFrame

        results = self.model.predict(image)

        # Hoist shared settings out of the loop.
        show_only = self.shared_variables.SHOW_ONLY
        precision = self.shared_variables.PRECISION
        max_area = self.shared_variables.MAX_BOX_AREA

        detected_objects = []

        for obj in results:
            classes = obj.names
            cls_list = obj.boxes.cls.tolist()
            conf_list = obj.boxes.conf.tolist()
            xywhn_list = obj.boxes.xywhn.tolist()  # normalized boxes (for display)
            xywh_list = obj.boxes.xywh.tolist()    # pixel boxes (for area filter)

            for i in range(len(cls_list)):
                score = conf_list[i]
                label = classes[cls_list[i]]
                _box = xywhn_list[i]
                box = (_box[0], _box[1], _box[2], _box[3])
                area = xywh_list[i][2] * xywh_list[i][3]

                # Apply filters. Fix: the original used `>=` for the
                # confidence threshold in the SHOW_ONLY branch but `>` in the
                # unrestricted branch; unified to `>=` for consistency.
                if score < precision or area > max_area:
                    continue
                if show_only and label not in show_only:
                    continue
                detected_objects.append((score, label, box))

        return detected_objects
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | #ML
2 | torch==2.0.1
3 | keras==2.13.1
4 | tensorflow==2.13.0
5 | ultralytics==8.0.171
6 |
7 | # Core
8 | opencv-contrib-python==4.8.0.76
9 | pandas==2.0.3
10 | pyqt5==5.15.9
11 | numpy==1.22.2
12 | Image==1.5.33
13 | opencv-python==4.8.0.76
14 | pyfiglet==0.8.post1
15 | mss==9.0.1
16 | pillow==10.0.0
17 |
18 | # Format and testing
19 | black==23.7.0
20 | pytest==7.4.0
21 |
22 |
--------------------------------------------------------------------------------
/server/AiServer.py:
--------------------------------------------------------------------------------
1 | import time
2 | from threading import Thread, Event
3 | import socket
4 | import cv2
5 | import pickle
6 | import struct
7 | import logging
8 | from ml.torch import yolo
9 |
10 | """
11 | COPYRIGHT @ Grebtsew 2019
12 |
13 | AiServer receives a couple of connections, reads images from incoming streams
14 | and send detections to the QtServer
15 | """
16 |
17 | QtServer_address= [["127.0.0.1",8081]]
18 |
class AiServer(Thread):
    """Accepts image streams over TCP, runs detection, and forwards the
    resulting boxes to the QtServer.

    Listens on HOST:PORT; each accepted connection is handled in its own
    thread by handle_connection().
    """

    HOST = '127.0.0.1'  # Standard loopback interface address (localhost)
    PORT = 8585  # Port to listen on (non-privileged ports are > 1023)

    def __init__(self):
        super(AiServer, self).__init__()
        # Fix: use lazy %-style logging args — the original passed extra
        # positional args to a message with no placeholders, which makes the
        # logging module raise a formatting error instead of logging them.
        logging.info("Tensorflow Server started at %s:%s", self.HOST, self.PORT)
        # Start detections
        # NOTE(review): yolo.YOLO.__init__ requires a shared_variables
        # argument; this no-arg call would raise TypeError — confirm.
        self.ai_thread = yolo.YOLO()
        # Setup output socket
        logging.debug("Ai Server try connecting to Qt Server %s:%s",
                      QtServer_address[0][0], QtServer_address[0][1])
        self.outSocket = socket.socket()
        self.outSocket.connect((QtServer_address[0][0], QtServer_address[0][1]))
        logging.debug("SUCCESS : Ai Server successfully connected to Qt Server!")

    def handle_connection(self, conn):
        """Read length-prefixed pickled JPEG frames from conn, decode them,
        run detection, and forward results to the QtServer."""
        with conn:
            data = b""
            payload_size = struct.calcsize(">L")

            while True:
                # Receive the 4-byte big-endian image package size
                while len(data) < payload_size:
                    logging.debug("Recv: %s", len(data))
                    data += conn.recv(4096)

                packed_msg_size = data[:payload_size]
                data = data[payload_size:]
                msg_size = struct.unpack(">L", packed_msg_size)[0]

                logging.debug("msg_size: %s", msg_size)

                # Receive the image payload
                while len(data) < msg_size:
                    data += conn.recv(4096)
                frame_data = data[:msg_size]
                data = data[msg_size:]

                # Decode image.
                # SECURITY: pickle.loads on network data can execute arbitrary
                # code — only run this server on a trusted network.
                frame = pickle.loads(frame_data, fix_imports=True, encoding="bytes")
                frame = cv2.imdecode(frame, cv2.IMREAD_COLOR)

                # do detections
                # Fix: detect_res was referenced below without ever being
                # assigned (NameError); default to None until the detection
                # pipeline is wired up.
                detect_res = None
                # TODO:
                #self.ai_thread.frame = frame
                #self.ai_thread.run_async()
                #detect_res = self.ai_thread.get_result()

                # send detection result to QtServer
                if detect_res is not None:
                    self.send(detect_res)

    def send(self, data):
        """Pickle and push data to the connected QtServer socket."""
        self.outSocket.sendall(pickle.dumps(data))

    def run(self):
        """Accept incoming connections forever, one handler thread each."""
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as inSocket:
            inSocket.bind((self.HOST, self.PORT))
            inSocket.listen()
            while True:
                conn, addr = inSocket.accept()
                Thread(target=self.handle_connection, args=(conn,)).start()
83 |
if __name__ == '__main__':
    # NOTE: Thread.start() returns None, so `aiserver` is always None; the
    # server keeps running on its own thread regardless.
    aiserver = AiServer().start()
86 |
--------------------------------------------------------------------------------
/server/MssClient.py:
--------------------------------------------------------------------------------
1 | import threading
2 | from QtServer import QtServer
3 | import socket
4 | from mss import mss
5 | from PIL import Image
6 | import cv2
7 | import numpy as np
8 | import pickle
9 | import struct
10 | import time
11 | import logging
12 |
13 | """
14 | COPYRIGHT @ Grebtsew 2019
15 |
16 | This Client Creates connections to servers in server_list and send images
17 | to each server then recieve result and display images
18 | """
19 |
# Detection servers this client streams screen captures to.
server_list= [["127.0.0.1",8585]]
# Address of the QtServer that visualizes the detection results.
QtServer_address= [["127.0.0.1",8081]]
# Images taller than this (pixels) are downscaled before sending.
# (sic: "treshhold" misspelling kept — the name is referenced elsewhere.)
image_size_treshhold = 720
# Captured screen resolution as (width, height).
screensize = (1920, 1080)
24 |
class MssClient(threading.Thread):
    """
    This client sends images to a detection server over a plain TCP socket.
    """

    def __init__(self, address,port):
        """Connect to the detection server.

        @Param address server host
        @Param port server port
        """
        super(MssClient,self).__init__()
        self.address = address
        self.port = port
        self.s = socket.socket()
        #self.demo_start_tfserver()
        # Bug fix: logging takes lazy %-style arguments; the original passed
        # extra positionals without placeholders, which makes the logging
        # module raise an internal formatting error.
        logging.info("MSS Client trying to connect to Tensorflow Server %s:%s", self.address, self.port)
        self.s.connect((self.address,self.port))
        logging.info("SUCCESS : Mss Client successfully connected to Tensorflow Server!")
        # JPEG quality used when encoding frames for transmission.
        self.encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 90]

    def demo_start_tfserver(self):
        # Demo helper: start a TfServer in-process (imported lazily so the
        # dependency is only required when the demo mode is used).
        from TfServer import TfServer
        TfServer().start()

    def send(self):
        """JPEG-encode self.image, pickle it, and send it with a length prefix."""
        result, frame = cv2.imencode('.jpg', self.image, self.encode_param)

        data = pickle.dumps(frame, 0)
        size = len(data)
        # Bug fix: same print-style logging call as in __init__.
        logging.debug("Sending Image of size %s", size)
        # 4-byte big-endian length prefix lets the server frame the stream.
        self.s.sendall(struct.pack(">L", size) + data)

    def run(self):
        self.send()
55 |
def downscale(image, threshold=None):
    """Downscale *image* so its height does not exceed the threshold.

    @Param image HxWxC numpy array
    @Param threshold max height in pixels; defaults to the module-level
                     image_size_treshhold (backward compatible)
    @Return (image, scale) where scale is the division factor that was
            applied (1 when no resize was needed)
    """
    if threshold is None:
        threshold = image_size_treshhold
    height, width, channel = image.shape

    # Bug fix: the original left `scale` unbound when the image was already
    # small enough, raising UnboundLocalError at the return statement.
    scale = 1
    if height > threshold:
        scale = height/threshold
        image = cv2.resize(image, (int(width/scale), int(height/scale)))

    return image, scale
65 |
66 |
67 |
if __name__ == '__main__':

    image = None

    # Open one client connection per configured detection server.
    clientlist = [MssClient(host, port) for host, port in server_list]

    # Screen-capture source covering the configured screen area.
    sct = mss()
    monitor = {'top': 0, 'left': 0, 'width': screensize[0], 'height': screensize[1]}

    # Capture / downscale / broadcast loop.
    while True:
        # Grab one full-screen frame as an RGB numpy array.
        image = Image.frombytes('RGB', (screensize[0], screensize[1]), sct.grab(monitor).rgb)
        image = np.array(image)
        # Shrink it below the configured size threshold before sending.
        image, scale = downscale(image)

        # Hand the frame to every connected server and transmit it.
        for client in clientlist:
            client.image = image
            client.send()
93 |
--------------------------------------------------------------------------------
/server/QtServer.py:
--------------------------------------------------------------------------------
# Standard library
import sys

sys.path.insert(0,'..')  # make the parent directory importable before the utils imports below

import logging
import os
import pickle
import socket
import sys  # duplicated in the original file; kept
import threading
import time  # bug fix: time.time()/time.sleep() are used below but `time` was never imported

# Third party
import numpy as np
from PyQt5.QtCore import *
from PyQt5.QtWidgets import QApplication

# Local (these require the sys.path tweak above)
from utils import label_map_util
from utils.screen_overlay_handler import *
import utils.screen_overlay_handler
18 |
19 | """
20 | COPYRIGHT @ Grebtsew 2019
21 |
22 |
23 | QtServer recieves detection boxes and visualize them.
24 | """
25 |
# Bug fix: MAX_DETECTION was defined twice with the same value; the
# duplicate assignment has been removed.
MAX_DETECTION = 5
MAX_BOX_AREA = 1000000 # pixels^2
PRECISION = 0.6 # 60 % detection threshold
WIDTH = 1920   # screen width used to rescale normalized boxes
HEIGTH = 1080  # screen height (sic: misspelled name kept for compatibility)
SHOW_ONLY = ["person"]  # only visualize these classes; empty list disables the filter
BOX_VIS_TIME = 0.2 # in seconds

# Dont change these
# NOTE: `list` shadows the builtin; the name is kept because the functions
# below reference these globals by name.
list = []   # currently displayed overlay boxes as (splash, created_at) tuples
queue = []  # pending detections waiting to be painted on the main thread
class QtServer(threading.Thread):
    """
    This server receives pickled detection results over a socket and queues
    them for visualization in PyQt5.
    """

    def __init__(self, address, port):
        super(QtServer,self).__init__()
        self.address = address
        self.port = port
        self.categorylist = self.load_tf_categories()

    def load_tf_categories(self):
        """Load the MSCOCO label map and build the category list/index."""
        self.NUM_CLASSES = 90
        CWD_PATH = os.path.dirname(os.getcwd())
        self.PATH_TO_LABELS = os.path.join(CWD_PATH,'object_detection', 'data', 'mscoco_label_map.pbtxt')
        self.label_map = label_map_util.load_labelmap(self.PATH_TO_LABELS)
        self.categories = label_map_util.convert_label_map_to_categories(self.label_map, max_num_classes=self.NUM_CLASSES, use_display_name=True)
        self.category_index = label_map_util.create_category_index(self.categories)
        return self.categories

    def handle_connection(self, conn):
        """Receive pickled (boxes, scores, classes, count) payloads and queue
        the detections that pass the class/precision/size filters."""
        with conn:
            while True:
                data = conn.recv(50000) # approx larger than the incoming tf result
                if not data:
                    break
                else:
                    try:
                        # Renamed from `dict`, which shadowed the builtin.
                        payload = pickle.loads(data)
                    except Exception:
                        continue # If for some reason not the entire package is received!

                    boxes = np.squeeze(payload[0])
                    scores = np.squeeze(payload[1])
                    classification = np.squeeze(payload[2])
                    amount = np.squeeze(payload[3])

                    # loop through all detections
                    for i in range(0,len(boxes)):
                        # Rescale normalized [ymin, xmin, ymax, xmax] to pixels
                        x = int(WIDTH*boxes[i][1])
                        y = int(HEIGTH*boxes[i][0])
                        w = int(WIDTH*(boxes[i][3]-boxes[i][1]))
                        h = int(HEIGTH*(boxes[i][2]-boxes[i][0]))
                        c = ""

                        # Check category in bounds
                        if len(self.categorylist) >= classification[i]:
                            c = str(self.categorylist[int(classification[i]-1)]['name'])

                        if len(SHOW_ONLY) > 0: # dont show wrong items
                            if not SHOW_ONLY.__contains__(c):
                                continue
                        if scores[i] > PRECISION: # precision treshold
                            if w*h < MAX_BOX_AREA : # max box size check
                                queue.append((scores[i], c,x,y,w,h)) # save all vis data in queue


    def run(self):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind((self.address, self.port))
            s.listen()
            # Bug fix: logging takes lazy %-style arguments; the original
            # passed extra positionals without placeholders, which breaks
            # the log formatting.
            logging.info("Qt Server started at %s:%s", self.address, self.port)
            while True:
                conn, addr = s.accept()
                threading.Thread(target=self.handle_connection, args=(conn,)).start()
104 |
def show_rect(scores, c,x,y,w,h):
    """Create an overlay box for one detection and remember its creation time."""
    # Bug fix: the bare name `screen_overlay_handler` was never bound in this
    # module (only `import utils.screen_overlay_handler` and a star import
    # exist), so the original raised NameError; qualify through `utils`.
    box = utils.screen_overlay_handler.create_box_with_image_score_classification(
        "../images/square2.png", scores, c, x, y, w, h)
    list.append((box, time.time()))
107 |
def remove_old_detections():
    """Drop overlay boxes older than BOX_VIS_TIME from the display list."""
    now = time.time()
    # Bug fix: the original removed items from `list` while iterating it,
    # which silently skips elements. Rebuild in place instead (slice
    # assignment keeps the shared `list` object identity).
    list[:] = [box for box in list if now - box[1] <= BOX_VIS_TIME]
112 |
def paint_rects():
    """Drain the pending-detection queue onto the screen, then age out old boxes."""
    if queue:
        for score, cls, x, y, w, h in queue:
            show_rect(score, cls, x, y, w, h)
        queue.clear()
    else:
        # Nothing to paint — sleep briefly to avoid a busy loop.
        time.sleep(0.1)
    remove_old_detections()
121 |
if __name__ == '__main__':
    app = QApplication(sys.argv) # Qt requires an application object for overlays
    QtServer_address= [["127.0.0.1",8081]]
    host, port = QtServer_address[0]
    qtserver = QtServer(host, port)
    qtserver.start()

    # Qt widgets must be created and painted on the MAIN thread (required!)
    while True:
        paint_rects()
130 |
--------------------------------------------------------------------------------
/server/http_server.py:
--------------------------------------------------------------------------------
1 | """
2 | This server will receive post requests containing an object that need to be detected on screen.
3 | This server will be used in a special mode of the program.
4 | """
5 | """
6 | This server will receive HTTP post requests and send the data to flask
7 | """
8 | from threading import Thread
9 | import http.server
10 | import json
11 | from functools import partial
12 | from http.server import BaseHTTPRequestHandler, HTTPServer
13 | import logging
14 |
class S(BaseHTTPRequestHandler):
    """Request handler that updates the shared detection settings via JSON POSTs."""

    def __init__(self,shared_variables, *args, **kwargs):
        # NOTE(review): hard-coded API password; presumably a demo credential.
        self.CRYPT = "password-1"
        self.shared_variables = shared_variables
        super().__init__(*args, **kwargs)

    def _set_response(self):
        """Send a 200 response with permissive CORS and JSON content headers."""
        self.send_response(200, "ok")
        self.send_header('Access-Control-Allow-Origin', '*')
        self.send_header('Access-Control-Allow-Methods', 'POST, OPTIONS, HEAD, GET')
        self.send_header("Access-Control-Allow-Headers", "X-Requested-With")
        self.send_header("Access-Control-Allow-Headers", "Content-Type")
        self.send_header('Content-type', 'application/json')
        self.end_headers()

    def do_HEAD(self):
        self._set_response()

    def do_OPTIONS(self):
        self._set_response()

    def clear_timer(self):
        # Bug fix: the original was `def clear_timer():` (missing self), so
        # invoking it as a method would fail.
        self.shared_variables.SHOW_ONLY = []

    def do_POST(self):
        """Handle add/remove of a detection class, guarded by the API password."""
        # Bug fix: logging takes lazy %-style arguments; the original passed
        # two positionals without a format string.
        logging.debug("POST from %s headers=%s", self.client_address, self.headers)

        if self.headers['Content-Length']:

            content_length = int(self.headers['Content-Length']) # <--- Gets the size of data
            post_data = self.rfile.read(content_length) # <--- Gets the data itself
            # decode incoming data // see if password is correct here!
            try:
                data = json.loads(post_data)
                logging.debug(data)
                """
                Data struct:
                data = {
                    'value': "person",
                    'type': 'add',
                    'api_crypt':"password-1"
                }
                """
                if data['api_crypt'] :
                    if data['api_crypt'] == self.CRYPT:
                        if data["value"]:
                            if data["type"] == "remove":
                                logging.debug("Removed Class : "+ str(data["value"]))
                                # Bug fix: iterate a snapshot — box.remove()
                                # mutates shared_variables.list while we loop.
                                for box in tuple(self.shared_variables.list):
                                    if box.classification == data["value"]:
                                        box.remove() # stop running boxes!
                                if data["value"] in self.shared_variables.SHOW_ONLY:
                                    self.shared_variables.SHOW_ONLY.remove(data["value"])
                            else:
                                if not data["value"] in self.shared_variables.SHOW_ONLY:
                                    logging.debug("Added Class : "+ str(data["value"]))
                                    self.shared_variables.SHOW_ONLY.append(data["value"])
            except Exception as e:
                logging.error("ERROR: "+str(e))
        self._set_response()
76 |
class HTTPserver(Thread):
    """Background thread that serves the JSON control API on 127.0.0.1:5000."""

    def __init__(self, shared_variables):
        super().__init__()
        self.shared_variables = shared_variables

    def run(self):
        # Bind the shared state into the handler class; HTTPServer
        # instantiates a new handler per request.
        handler = partial(S, self.shared_variables)
        httpd = HTTPServer(("127.0.0.1", 5000), handler)
        try:
            logging.info("HTTP Server Started!")
            httpd.serve_forever()
        except KeyboardInterrupt:
            pass
        httpd.server_close()
91 |
def safe(json, value):
    """Return json[value], or None when the key/index is absent.

    NOTE: the parameter name `json` shadows the json module inside this
    function; it is kept for backward compatibility with callers.

    @Param json a dict-like (or indexable) object, may be None
    @Param value key or index to look up
    @Return the stored value, or None when it cannot be retrieved
    """
    try:
        return json[value]
    except (KeyError, IndexError, TypeError):
        # Narrowed from a broad `except Exception` so genuine programming
        # errors are not silently swallowed; return is now explicit.
        return None
97 |
--------------------------------------------------------------------------------
/server/local_starter.py:
--------------------------------------------------------------------------------
1 | import time
2 | import threading
3 |
def start_server(name):
    """Launch `<name>.py` as a child python process and block until it exits."""
    import subprocess
    script = "{}.py".format(name)
    subprocess.run(["python", script])
7 |
if __name__ == '__main__':
    """
    COPYRIGHT @ Grebtsew 2019

    This file starts the demo mode where all servers and clients run on local pc!
    """

    # Launch each component in its own thread, pausing between launches so
    # the servers are listening before their clients try to connect.
    startup = [("QtServer", 2), ("TfServer", 5), ("MssClient", 0)]
    for target, delay in startup:
        threading.Thread(target=start_server, args=(target,)).start()
        if delay:
            time.sleep(delay)
20 |
--------------------------------------------------------------------------------
/tests/noCI/test_main.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/tests/noCI/test_main.py
--------------------------------------------------------------------------------
/tests/noCI/test_server_mssclient.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/tests/noCI/test_server_mssclient.py
--------------------------------------------------------------------------------
/tests/noCI/test_server_qtserver.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/tests/noCI/test_server_qtserver.py
--------------------------------------------------------------------------------
/tests/noCI/test_utils_screen_overlay_handler.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/tests/noCI/test_utils_screen_overlay_handler.py
--------------------------------------------------------------------------------
/tests/test_ml_detector.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/tests/test_ml_detector.py
--------------------------------------------------------------------------------
/tests/test_ml_opencv.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/tests/test_ml_opencv.py
--------------------------------------------------------------------------------
/tests/test_ml_opencv_haar.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/tests/test_ml_opencv_haar.py
--------------------------------------------------------------------------------
/tests/test_ml_tensorflow.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/tests/test_ml_tensorflow.py
--------------------------------------------------------------------------------
/tests/test_ml_torch.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/tests/test_ml_torch.py
--------------------------------------------------------------------------------
/tests/test_ml_torch_yolo.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/tests/test_ml_torch_yolo.py
--------------------------------------------------------------------------------
/tests/test_server_ai.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/tests/test_server_ai.py
--------------------------------------------------------------------------------
/tests/test_server_http.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/tests/test_server_http.py
--------------------------------------------------------------------------------
/tests/test_utils_shared_variables.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/tests/test_utils_shared_variables.py
--------------------------------------------------------------------------------
/tests/test_utils_threadpool.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/tests/test_utils_threadpool.py
--------------------------------------------------------------------------------
/tests/test_utils_tracking.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/grebtsew/Realtime-Screen-Object-Detection/fe3f1992226615a67863eca08bd461857f40a443/tests/test_utils_tracking.py
--------------------------------------------------------------------------------
/utils/ThreadPool.py:
--------------------------------------------------------------------------------
1 | from PyQt5.QtGui import *
2 | from PyQt5.QtWidgets import *
3 | from PyQt5.QtCore import *
4 |
5 | import traceback
6 | import sys
7 |
class WorkerSignals(QObject):
    '''
    Signals a running Worker thread can emit.

    finished
        no payload; always emitted when the job ends

    error
        `tuple` of (exctype, value, traceback.format_exc())

    result
        `object` returned by the job on success

    progress
        `list` carrying progress information
    '''
    finished = pyqtSignal()
    error = pyqtSignal(tuple)
    result = pyqtSignal(object)
    progress = pyqtSignal(list)
31 |
32 |
class Worker(QRunnable):
    '''
    Worker thread

    Inherits from QRunnable to handle worker thread setup, signals and wrap-up.

    :param fn: The function callback to run on this worker thread. Supplied args and
               kwargs will be passed through to the runner, plus a
               `progress_callback` keyword carrying the progress signal.
    :type fn: function
    :param args: Arguments to pass to the callback function
    :param kwargs: Keywords to pass to the callback function

    '''

    def __init__(self, fn, *args, **kwargs):
        super(Worker, self).__init__()

        # Store constructor arguments (re-used for processing)
        self.fn = fn
        self.args = args
        self.kwargs = kwargs
        self.signals = WorkerSignals()

        # Add the callback to our kwargs so the job can report progress
        self.kwargs['progress_callback'] = self.signals.progress

    @pyqtSlot()
    def run(self):
        '''
        Initialise the runner function with passed args, kwargs and emit
        result/error signals; `finished` is always emitted.
        '''

        # Retrieve args/kwargs here; and fire processing using them
        try:
            result = self.fn(*self.args, **self.kwargs)
        except Exception:
            # Bug fix: narrowed from a bare `except:` so SystemExit and
            # KeyboardInterrupt still propagate; application errors are
            # reported through the error signal as before.
            traceback.print_exc()
            exctype, value = sys.exc_info()[:2]
            self.signals.error.emit((exctype, value, traceback.format_exc()))
        else:
            self.signals.result.emit(result) # Return the result of the processing
        finally:
            self.signals.finished.emit() # Done
76 |
--------------------------------------------------------------------------------
/utils/screen_overlay_handler.py:
--------------------------------------------------------------------------------
1 | from PyQt5.QtGui import *
2 | from PyQt5.QtWidgets import *
3 | from PyQt5.QtCore import *
4 | from utils.ThreadPool import *
5 |
6 | from utils.tracking import Tracking
7 |
8 | import logging
9 | import threading
10 | import sys
11 | import time
12 |
13 | '''
14 | COPYRIGHT @ Grebtsew 2019
15 |
16 | This file contains functions for showing overlay detections
17 | '''
18 |
19 |
class TrackingBox(QSplashScreen):
    """Transparent splash-screen overlay that follows one tracked detection."""
    splash_pix = None
    done = False

    def __init__(self, id, shared_variables, score, classification, box, *args, **kwargs):
        super(TrackingBox, self).__init__(*args, **kwargs)
        self.classification = classification
        self.shared_variables = shared_variables
        self.counter = 0
        # TODO: this might be a little haxy
        # Convert the detection-scaled, center-based box to screen pixels.
        self.x = int(box[0]*(self.shared_variables.WIDTH/self.shared_variables.DETECTION_SCALE)-(box[2]*(self.shared_variables.WIDTH/self.shared_variables.DETECTION_SCALE))/2)
        self.y = int(box[1]*(self.shared_variables.HEIGHT/self.shared_variables.DETECTION_SCALE)-(box[3]*(self.shared_variables.HEIGHT/self.shared_variables.DETECTION_SCALE))/2)
        self.width = int(box[2]*(self.shared_variables.WIDTH/self.shared_variables.DETECTION_SCALE))
        self.height = int(box[3]*(self.shared_variables.HEIGHT/self.shared_variables.DETECTION_SCALE))
        self.id = id
        self.splash_pix = QPixmap('./docs/box2.png')
        self.splash_pix = self.splash_pix.scaled(round(self.width*self.shared_variables.DETECTION_SCALE),round(self.height*self.shared_variables.DETECTION_SCALE))
        self.setPixmap(self.splash_pix)

        self.setWindowFlag(Qt.WindowStaysOnTopHint)
        self.setAttribute(Qt.WA_TranslucentBackground)
        self.setAttribute(Qt.WA_NoSystemBackground)

        label = QLabel( self )
        label.setWordWrap( True )
        label.move(30,30)
        label.setStyleSheet(" color: rgb(0, 100, 200); font-size: 15pt; ")

        label.setText( str(int(100*score))+"%" + " " + classification )
        self.move(self.x,self.y)
        self.show()

        self.tracking = Tracking( (self.x,self.y,self.width,self.height),self.shared_variables)

        self.threadpool = QThreadPool()

        logging.debug(f"New Box Created at {str(self.x)} {str(self.y)} Size {str(self.width)} {str(self.height)}")

        self.start_worker()

    def progress_fn(self, n):
        logging.debug("%d%% done" % n)

    def remove(self):
        """Detach this box from the shared GUI list and stop its worker loop."""
        self.shared_variables.list.remove(self)
        self.done = True
        # Bug fix: the original `self.threadpool.cancel` was a bare attribute
        # access (a no-op); clear() actually drops any queued runnables.
        self.threadpool.clear()

    def execute_this_fn(self, progress_callback):
        """Worker body: advance tracking one step, or clean up when tracking died."""
        if(not self.tracking.running):
            if not self.done: # Remove ourself from gui list
                self.shared_variables.list.remove(self)
                self.done = True
                self.threadpool.clear()  # was a no-op `self.threadpool.cancel`
        else:
            self.tracking.run()

        return "Done."


    def print_output(self, s):
        # Reposition/rescale the overlay to the latest tracked box.
        self.hide()
        self.repaint_size(round(self.tracking.box[2]*self.shared_variables.DETECTION_SCALE), round(self.tracking.box[3]*self.shared_variables.DETECTION_SCALE))
        self.move(round(self.tracking.box[0]*self.shared_variables.DETECTION_SCALE), round(self.tracking.box[1]*self.shared_variables.DETECTION_SCALE))
        self.show()

    def thread_complete(self):
        # Bug fix: only reschedule while the box is still alive; the original
        # restarted the worker unconditionally, looping forever after removal.
        if not self.done:
            self.start_worker()

    def start_worker(self):
        # Pass the function to execute
        worker = Worker(self.execute_this_fn) # Any other args, kwargs are passed to the run function
        worker.signals.result.connect(self.print_output)
        worker.signals.finished.connect(self.thread_complete)
        worker.signals.progress.connect(self.progress_fn)

        # Execute
        self.threadpool.start(worker)

    def repaint_size(self, width, height):
        """Rescale the box pixmap to the given pixel size."""
        self.splash_pix = self.splash_pix.scaled(width,height)
        self.setPixmap(self.splash_pix)


    def get_box(self):
        return self.tracking.box
111 |
112 |
def create_box(x,y,width,height):
    '''
    Show an overlay box without a label.
    @Param x box left
    @Param y box top
    @Param width box width
    @Param height box height
    @Return overlay instance
    '''
    pixmap = QPixmap('images/square2.png').scaled(width, height)

    overlay = QSplashScreen(pixmap, Qt.WindowStaysOnTopHint)
    overlay.setWindowOpacity(0.2)
    overlay.setAttribute(Qt.WA_NoSystemBackground)
    overlay.move(x, y)

    overlay.show()
    return overlay
134 |
135 |
def create_box_with_score_classification(score, classification, x,y,width,height):
    '''
    Show an overlay box with a "<score>% <class>" label.
    @Param score float
    @Param classification string
    @Param x box left
    @Param y box top
    @Param width box width
    @Param height box height
    @Return overlay instance
    '''
    pixmap = QPixmap('images/square2.png').scaled(width, height)

    overlay = QSplashScreen(pixmap, Qt.WindowStaysOnTopHint)
    overlay.setWindowOpacity(0.2)

    label = QLabel(overlay)
    label.setWordWrap(True)
    label.setText( str(int(100*score))+"%" + " " + classification )

    overlay.setAttribute(Qt.WA_NoSystemBackground)
    overlay.move(x, y)

    overlay.show()
    return overlay
162 |
def create_box_with_image_score_classification(image_path, score, classification, x,y,width,height):
    '''
    Show an overlay box drawn from a custom image, labelled "<score>% <class>".
    @Param image_path path of the pixmap to display
    @Param score float
    @Param classification string
    @Param x box left
    @Param y box top
    @Param width box width
    @Param height box height
    @Return overlay instance
    '''
    pixmap = QPixmap(image_path).scaled(width, height)

    overlay = QSplashScreen(pixmap, Qt.WindowStaysOnTopHint)
    overlay.setWindowOpacity(0.2)

    label = QLabel(overlay)
    label.setWordWrap(True)
    label.setText( str(int(100*score))+"%" + " " + classification )

    overlay.setAttribute(Qt.WA_NoSystemBackground)
    overlay.move(x, y)
    overlay.show()
    return overlay
188 |
#TODO
def create_fancy_box(score, classification, x,y,width,height):
    '''Draw a rounded-rectangle overlay (no image) with a score/class label.'''
    # Paint a rounded rectangle into a fresh pixmap.
    pixmap = QPixmap(width, height)

    painter = QPainter(pixmap)
    painter.setPen(QPen(Qt.blue, 10, Qt.SolidLine))
    outline = QPainterPath()
    outline.addRoundedRect(QRectF(0+10,0+10,width-20,height-20), 30, 30)
    painter.drawPath(outline)
    painter.end()

    overlay = QSplashScreen(pixmap, Qt.WindowStaysOnTopHint)
    overlay.setWindowOpacity(1)
    overlay.setAttribute(Qt.WA_TranslucentBackground)

    label = QLabel(overlay)
    label.setWordWrap(True)
    label.move(30,30)
    label.setStyleSheet(" color: rgb(0, 100, 200); font-size: 30pt; ")
    label.setText( str(int(100*score))+"%" + " " + classification )

    overlay.setAttribute(Qt.WA_NoSystemBackground)
    overlay.move(x, y)
    overlay.show()
    return overlay
219 |
--------------------------------------------------------------------------------
/utils/shared_variables.py:
--------------------------------------------------------------------------------
1 | '''
2 | COPYRIGHT @ Grebtsew 2019
3 |
4 | This file contains a shared variables class and a mss screen stream capture class
5 | '''
6 | # Shared variables between threads
7 | from mss import mss
8 | from PIL import Image
9 | from threading import Thread
10 |
11 | import numpy as np
12 | import cv2
13 | import time
14 | import logging
15 |
16 | from ml.torch.yolo import YOLO
17 | from ml.tensorflow.ssd import SSD
18 | from ml.opencv.haar_cascades import HAAR_CASCADES
19 |
20 | # Global shared variables
21 | # an instace of this class share variables between system threads
class Shared_Variables():
    """Global state shared between the system's threads.

    NOTE: the mutable class-level attributes are shared by all instances;
    the program uses a single instance, matching the original behavior.
    """

    # Select model to use here!
    model = YOLO # YOLO | SSD | HAAR_CASCADES

    trackingboxes = []          # active TrackingBox overlays
    _initialized = 0
    OFFSET = (0,0)              # (top, left) capture offset in pixels
    HTTP_SERVER = False         # whether the HTTP control API is enabled
    WIDTH, HEIGHT = 1920, 1080  # capture resolution
    detection_ready = False     # set once the detector has loaded
    category_index = None
    OutputFrame = None          # latest downscaled frame for detection/tracking
    frame = None
    boxes = None
    categorylist = []
    category_max = None
    stream_running = True       # screen-streamer loop flag
    detection_running = True    # detector loop flag
    list = []                   # NOTE: shadows the builtin; referenced by this name elsewhere
    DETECTION_SIZE = 640        # max frame height fed to the detector
    DETECTION_SCALE = 0         # downscale factor of OutputFrame (set by the streamer)

    def __init__(self):
        # Bug fix: the original called Thread.__init__(self) although this
        # class does not subclass Thread; the call only polluted the
        # instance with unused Thread attributes and has been removed.
        self._initialized = 1
        Screen_Streamer(shared_variables=self).start()
49 |
50 |
class Screen_Streamer(Thread):
    """Thread that grabs the screen with mss and publishes downscaled frames
    to shared_variables.OutputFrame."""

    def __init__(self, shared_variables = None ):
        Thread.__init__(self)
        self.shared_variables = shared_variables


    def downscale(self, image):
        """Shrink *image* so its height is at most DETECTION_SIZE.

        @Param image HxWxC numpy array
        @Return (image, scale) where scale is the division factor applied
                (1 when the image was already small enough)
        """
        image_size_treshhold = self.shared_variables.DETECTION_SIZE
        height, width, channel = image.shape

        # Bug fix: `scale` was unbound when no resize happened, raising
        # UnboundLocalError for frames already below the threshold.
        scale = 1
        if height > image_size_treshhold:
            scale = height/image_size_treshhold
            image = cv2.resize(image, (int(width/scale), int(height/scale)))

        return image, scale


    def run(self):
        """Capture loop: grab, downscale and publish frames while streaming."""
        sct = mss()
        monitor = {'top': self.shared_variables.OFFSET[0], 'left': self.shared_variables.OFFSET[1], 'width': self.shared_variables.WIDTH, 'height': self.shared_variables.HEIGHT}
        logging.info(f"MSS started with monitor : {monitor}")

        while self.shared_variables.stream_running:
            if self.shared_variables.detection_ready:
                img = Image.frombytes('RGB', (self.shared_variables.WIDTH, self.shared_variables.HEIGHT), sct.grab(monitor).rgb)
                self.shared_variables.OutputFrame, self.shared_variables.DETECTION_SCALE = self.downscale(np.array(img))
                # Allow a 'q' keypress in any OpenCV window to stop the stream.
                if cv2.waitKey(25) & 0xFF == ord('q'):
                    cv2.destroyAllWindows()
                    break
            else:
                # Wait for the detector to come up before streaming.
                time.sleep(0.1)
86 |
--------------------------------------------------------------------------------
/utils/tracking.py:
--------------------------------------------------------------------------------
1 | from PyQt5.QtGui import *
2 | from PyQt5.QtWidgets import *
3 | from PyQt5.QtCore import *
4 |
5 | import numpy as np
6 | import logging
7 | import time
8 | import traceback, sys
9 | # Tracking thread
10 |
11 | # imports
12 | from math import hypot
13 | import math
14 | import cv2
15 | import sys
16 | import threading
17 | import datetime
18 |
19 |
class Tracking():
    """OpenCV single-object tracker (KCF) whose position is smoothed with a
    constant-velocity Kalman filter."""
    tracker_test = None   # last tracker.update() success flag
    tracker = None        # cv2 tracker instance
    frame = None          # latest frame pulled from shared_variables
    running = True        # set False after too many consecutive misses

    fail_counter = 0

    start_time = None
    end_time = None
    first_time = True     # first run(): the tracker must be initialized
    first = True          # first successful update: seed the Kalman state
    # Initiate tracker
    # parameters: initial box, shared_variables reference
    #
    def __init__(self, box, shared_variables):
        self.box = box
        self.shared_variables = shared_variables

        # State is (x, y, dx, dy); only (x, y) is measured.
        self.kalman = cv2.KalmanFilter(4, 2, 0)
        self.kalman.measurementMatrix = np.array([[1,0,0,0],
                                            [0,1,0,0]],np.float32)

        self.kalman.transitionMatrix = np.array([[1,0,1,0],
                                            [0,1,0,1],
                                            [0,0,1,0],
                                            [0,0,0,1]],np.float32)

        self.kalman.processNoiseCov = np.array([[1,0,0,0],
                                            [0,1,0,0],
                                            [0,0,1,0],
                                            [0,0,0,1]],np.float32) * 0.03


    # Run
    # Advance tracking one step against the latest shared frame.
    #
    def run(self):
        self.frame = self.shared_variables.OutputFrame

        if self.frame is not None:
            if self.first_time:
                self.update_custom_tracker()
                self.first_time = False
            self.object_custom_tracking()



    # Create_custom_tracker
    #
    # Create custom tracker, can change tracking method here
    # will need cv2 and cv2-contrib to work!
    #
    def create_custom_tracker(self):
        #higher object tracking accuracy and can tolerate slower FPS throughput
        #self.tracker = cv2.TrackerCSRT_create()
        #faster FPS throughput but can handle slightly lower object tracking accuracy
        self.tracker = cv2.TrackerKCF_create()
        #MOSSE when you need pure speed
        #self.tracker = cv2.TrackerMOSSE_create()

    # Update_custom_tracker
    #
    # Set and reset custom tracker
    #
    def update_custom_tracker(self):
        self.create_custom_tracker()
        # Bug fix: replaced a leftover debug print() with a logging call.
        logging.debug("tracker init: frame shape %s, box %s", self.frame.shape, self.box)
        self.tracker_test = self.tracker.init( self.frame, self.box)

    def get_box(self):
        return self.box

    # Object_Custom_tracking
    #
    # This function uses the OpenCV tracking form uncommented in update_custom_tracking
    #
    def object_custom_tracking(self):

        # Calculate
        self.tracker_test, box = self.tracker.update(self.frame)

        if self.tracker_test:
            if self.first:
                # Seed the Kalman state with the first measured position.
                A = self.kalman.statePost
                A[0:4] = np.array([[np.float32(box[0])], [np.float32(box[1])],[0],[0]])
                self.kalman.statePost = A
                self.kalman.statePre = A
                self.first = False

            current_measurement = np.array([[np.float32(box[0])], [np.float32(box[1])]])
            self.kalman.correct(current_measurement)
            prediction = self.kalman.predict()
            # Smoothed position, raw width/height.
            self.box = [int(prediction[0]), int(prediction[1]), box[2], box[3]]
            self.fail_counter = 0

        else:
            self.fail_counter+=1
            # NOTE(review): MAX_TRACKING_MISSES is read from shared_variables
            # but is not defined in the Shared_Variables class shown in this
            # repo — confirm it is set elsewhere.
            if(self.fail_counter > self.shared_variables.MAX_TRACKING_MISSES): # too many missed frames
                self.running = False
130 |
class MultiTracking():
    """Multi-object variant of Tracking built on cv2.MultiTracker.

    NOTE(review): this class was broken/dead code (undefined `box`,
    `bboxes`, `createTrackerByName`, `first_time`); fixed minimally so it
    is constructible and internally consistent.
    """
    tracker_test = None
    tracker = None
    frame = None
    running = True
    # Bug fix: `first_time` was never defined, so run() raised AttributeError.
    first_time = True

    fail_counter = 0

    def __init__(self, shared_variables):
        # Bug fix: the original read an undefined global `box`.
        self.box = None
        self.shared_variables = shared_variables

    def run(self):
        self.frame = self.shared_variables.OutputFrame

        if self.frame is not None:
            if self.first_time:
                self.update_custom_tracker()
                self.first_time = False
            self.object_custom_tracking()

    def create_custom_tracker(self):
        # Bug fix: the original assigned `self.Tracker` while every other
        # method reads `self.tracker`.
        self.tracker = cv2.MultiTracker_create()

    def update_custom_tracker(self):
        self.create_custom_tracker()

    def get_box(self):
        return self.box

    def add_tracker(self, frame, box):
        # Bug fix: the original iterated an undefined `bboxes` and called an
        # undefined `createTrackerByName`; register one CSRT tracker for `box`.
        self.tracker.add(cv2.TrackerCSRT_create(), frame, box)

    def object_custom_tracking(self):
        # Calculate
        self.tracker_test, box = self.tracker.update(self.frame)
        if self.tracker_test:
            cv2.waitKey(1)
            cv2.imshow("test", self.frame)
            self.box = box
            self.fail_counter = 0


        else:
            self.fail_counter+=1
            if(self.fail_counter > 2): # missed-frames threshold
                self.running = False
186 |
--------------------------------------------------------------------------------