├── .gitignore ├── LICENSE ├── README.md ├── clip_app ├── __init__.py ├── clip_app_pipeline.py ├── clip_callback.py ├── clip_hailopython.py ├── clip_pipeline.py ├── gui.py ├── logger_setup.py └── text_image_matcher.py ├── clip_application.py ├── community_projects ├── ad_genie │ ├── README.md │ ├── ad_genie.py │ ├── data_preparation.py │ ├── lables_preparation.py │ └── resources │ │ ├── ad_genie.gif │ │ └── structure.jpeg ├── baiby_monitor │ ├── README.md │ ├── archive │ │ └── hazardous_emb.json │ ├── download_resources.sh │ ├── embeddings │ │ ├── cry_detection_emb.json │ │ └── sleep_detection_emb.json │ ├── requirements.txt │ ├── send_message │ │ ├── example_usage.py │ │ └── telegram_messenger.py │ └── src │ │ ├── .gitignore │ │ ├── __init__.py │ │ ├── baiby_telegram.py │ │ ├── clip_pipeline.py │ │ ├── lullaby_callback.py │ │ ├── match_handler.py │ │ ├── play_lullaby.py │ │ └── telegram.ini.example ├── community_projects.md └── template_example │ ├── README.md │ ├── downdload_resources.sh │ ├── requirments.txt │ └── template_example.py ├── compile_postprocess.sh ├── cpp ├── TextImageMatcher.cpp ├── TextImageMatcher.hpp ├── clip.cpp ├── clip.hpp ├── clip_croppers.cpp ├── clip_croppers.hpp ├── clip_croppers_new.cpp ├── clip_matcher.cpp ├── clip_matcher.hpp └── meson.build ├── download_resources.sh ├── embeddings.json ├── example_embeddings.json ├── install.sh ├── meson.build ├── requirements.txt ├── resources ├── CLIP_UI.png ├── Hackathon-banner-2024.png ├── configs │ └── yolov5_personface.json ├── github_clip_based_classification.png └── texts_json_example.json ├── run_tests.sh ├── setup.py ├── setup_env.sh └── tests ├── test_clip_app.py ├── test_demo_clip.py └── test_resources └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled files 2 | __pycache__/ 3 | clip_app/__pycache__/ 4 | *.pyc 5 | *.pyo 6 | 7 | # Resources 8 | resources/*.so 9 | resources/*.hef 10 | resources/*.mp4 11 | 12 | # Logs and temp files 13 | *.log 14 | *.tmp 15 | 16 | # Packages and dependencies 17 | *.egg-info/ 18 | hailo_clip_venv/ 19 | 20 | # Build, cache, and coverage files 21 | build/ 22 | build.release/ 23 | dist/ 24 | *.egg 25 | *.egg-info/ 26 | .cache/ 27 | coverage/ 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Hailo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![](resources/github_clip_based_classification.png) 2 | # CLIP Zero Shot Inference Application 3 | 4 | This is an example application that runs CLIP inference on a video in real time. The image embeddings are accelerated by the Hailo-8/8L AI processor, while the text embeddings run on the host. Text embeddings need to be calculated only once per text. If they do not need to be updated in real time, they can be saved to a JSON file and loaded on the next run. 5 | 6 | Click the image below to watch the demo on YouTube. 7 | 8 | [![Watch the demo on YouTube](https://img.youtube.com/vi/XXizBHtCLew/0.jpg)](https://youtu.be/XXizBHtCLew) 9 | 10 | ## Prerequisites 11 | 12 | This application is compatible with x86 and RPi5 (8GB) systems. 13 | 14 | This example has been tested with the following Hailo TAPPAS versions: 15 | - v3.30.0 16 | - v3.31.0 17 | 18 | Please ensure that one of these versions is installed on your system. 19 | 20 | - **`hailo-tappas-core`**: TAPPAS core installation using a `.deb` file or `apt install` (Raspberry Pi platforms). 21 | - **`hailo_tappas`**: For full TAPPAS installation. See instructions in our [TAPPAS repository](https://github.com/hailo-ai/tappas). 22 | 23 | 24 | This repo uses the [Hailo Apps Infra repository](https://github.com/hailo-ai/hailo-apps-infra). It will be installed automatically into your virtualenv when running the installation script. You can also clone it manually; see the instructions in the Hailo Apps Infra repository. 25 | 26 | 27 | ## Installation 28 | To install the application, clone the repository and run the installation script: 29 | 30 | ```bash 31 | ./install.sh 32 | ``` 33 | The script will prepare a virtual environment and install the required dependencies. 34 | 35 | ## Usage 36 | 37 | To prepare the environment, run the following command: 38 | 39 | ```bash 40 | source setup_env.sh 41 | ``` 42 | 43 | Run the example: 44 | 45 | ```bash 46 | python clip_application.py --input demo 47 | ``` 48 | On the first run, CLIP will download the required models. This will happen only once. 49 | 50 | 51 | ## User Guide 52 | Watch the Hailo CLIP Zero Shot Classification Tutorial 53 | 54 | [![Tutorial: Hailo CLIP Zero Shot Classification Application](https://img.youtube.com/vi/xhXOxgEE6K4/0.jpg)](https://youtu.be/xhXOxgEE6K4) 55 | 56 | ### Arguments 57 | 58 | ```bash 59 | python clip_application.py -h 60 | ``` 61 | 62 | ### Modes 63 | 64 | - **Default mode (`--detector none`)**: Runs CLIP inference on the entire frame, which is the intended use for CLIP and provides the best results. 65 | - **Person mode (`--detector person`)**: Runs CLIP inference on detected persons. CLIP acts as a person classifier and runs every second per tracked person. This interval can be adjusted in the code. 66 | - **Face mode (`--detector face`)**: Runs CLIP inference on detected faces. This mode may not perform as well as person mode due to cropped faces being less represented in the dataset. Experiment to see if it fits your application.
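For example, each mode maps to a single command-line invocation of the application (shown here with the demo input; any other supported input source works the same way):

```bash
# CLIP inference on the whole frame (default)
python clip_application.py --input demo --detector none

# CLIP classification per tracked person
python clip_application.py --input demo --detector person

# CLIP classification per detected face
python clip_application.py --input demo --detector face
```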
67 | 68 | ### Using a Webcam as Input 69 | 70 | #### USB Camera 71 | Before running the application, ensure a camera is connected to your device. Use the `--input` flag to specify the camera device, defaulting to `/dev/video0`. 72 | You can check which USB webcam device is connected by running the following command: 73 | ```bash 74 | get-usb-camera 75 | ``` 76 | 77 | Once you identify your camera device, you can run the application as follows: 78 | ```bash 79 | python clip_application.py --input /dev/video0 80 | ``` 81 | #### rpi Camera 82 | ```bash 83 | python clip_application.py --input rpi 84 | ``` 85 | 86 | ### UI Controls 87 | 88 | ![UI Controls](resources/CLIP_UI.png) 89 | 90 | - **Threshold Slider**: Adjusts the threshold for CLIP classification. Classifications with probabilities lower than this threshold will be ignored. 91 | - **Negative Checkbox**: Marks the classification as a negative prompt. It will be included in the Softmax calculation but will not be shown in the output. 92 | - **Ensemble Checkbox**: Enables ensemble mode, where the prompt text embedding is calculated with variations to improve results. See `ensemble_template` in `text_image_matcher.py` for more details. 93 | - **Text Description**: The text prompt for CLIP classification. 94 | - **Probability Bars**: Displays the probability of various classifications in real-time. 95 | - **Load Button**: Loads the text embeddings from a JSON file specified by the `--json-path` flag. 96 | - **Save Button**: Saves the text embeddings to a JSON file specified by the `--json-path` flag. 97 | - **Track ID**: Displays the classification probabilities for a specific person in person mode. The track ID appears in the bottom left corner of the bounding box. 98 | - **Quit Button**: Exits the application. 99 | 100 | ## Tips for Good Prompt Usage 101 | 102 | - Keep in mind that the network was trained on image + caption pairs. Your text description should be somewhat similar. For example, a text description of "A photo of a cat" will give a better score than "cat". 103 | - The app has a pre-defined "prefix" of "A photo of a" which you can change in the `TextImageMatcher` class. 104 | - The pipeline output will select one of the classes as "the best one". There is no `background` class. You should define a "negative" prompt (or prompts) to be used as `background`. When set as `negative`, the class will be used in the "best match" algorithm but will not be shown in the output. 105 | - You can also use `threshold` to fine-tune detection sensitivity. However, using `negative` prompts is better for detecting specific classes. 106 | - Negative prompts should be used to "peel off" similar classifications to your target. For example, "a man with a red shirt" will have a high score for just a man or a shirt of a different color. Add negative prompts like "a man with a blue shirt" to ensure you do not get lots of false classifications. 107 | - Play around with prompts to see what works best for your application. 108 | 109 | ## Integrating Your Code 110 | 111 | You can integrate your code in the `clip_application.py` file. This file includes a user-defined `app_callback` function that is called after the CLIP inference and before the display. You can use it to add your logic to the app. The `app_callback_class` will be passed to the callback function and can be used to access the app's data. 112 | 113 | ### Online Text Embeddings 114 | 115 | - The application will run the text embeddings on the host, allowing you to change the text on the fly. 
This mode might not work on weak machines as it requires a host with enough memory to run the text embeddings model (on CPU). See [Offline Text Embeddings](#offline-text-embeddings) for more details. 116 | - You can set which JSON file to use for saving and loading embeddings using the `--json-path` flag. If not set, `embeddings.json` will be used. 117 | - If you wish to load/save your JSON, use the `--json-path` flag explicitly. 118 | 119 | ### Offline Text Embeddings 120 | 121 | - To run without online text embeddings, you can set the `--disable-runtime-prompts` flag. This will speed up the load time and save memory. Additionally, you can use the app without the `torch` and `torchvision` dependencies. This might be suitable for final application deployment. 122 | - You can save the embeddings to a JSON file and load them on the next run. This will not require running the text embeddings on the host. 123 | - If you need to prepare text embeddings on a weak machine, you can use the `text_image_matcher` tool. This tool will run the text embeddings on the host and save them to a JSON file without running the full pipeline. This tool assumes the first text is a 'positive' prompt and the rest are negative. 124 | 125 | #### Arguments 126 | ```bash 127 | text_image_matcher -h 128 | usage: text_image_matcher [-h] [--output OUTPUT] [--interactive] [--image-path IMAGE_PATH] [--texts-list TEXTS_LIST [TEXTS_LIST ...]] [--texts-json TEXTS_JSON] 129 | 130 | options: 131 | -h, --help show this help message and exit 132 | --output OUTPUT output file name default=text_embeddings.json 133 | --interactive input text from interactive shell 134 | --image-path IMAGE_PATH 135 | Optional, path to image file to match. Note image embeddings are not running on Hailo here. 136 | --texts-list TEXTS_LIST [TEXTS_LIST ...] 137 | A list of texts to add to the matcher, the first one will be the searched text, the others will be considered negative prompts. Example: --texts-list "cat" "dog" "yellow car" 138 | --texts-json TEXTS_JSON 139 | A json of texts to add to the matcher, the json will include 2 keys negative and positive, the values are going to be lists of texts. 140 | Example: resources/texts_json_example.json 141 | 142 | ``` 143 | 144 | ## CPP Code Compilation 145 | 146 | Some CPP code is used in this app for post-processing and cropping. This code should be compiled before running the example. It uses Hailo `pkg-config` to find the required libraries. 147 | 148 | The compilation script is `compile_postprocess.sh`. You can run it manually, but it will be executed automatically when installing the package. The post-process `.so` files will be installed under the resources directory. 149 | 150 | ## Known Issues 151 | #### Known Issue with Setuptools 152 | When running with TAPPAS Docker, you might encounter this error: 153 | 154 | ```plaintext 155 | ImportError: cannot import name 'packaging' from 'pkg_resources' 156 | ``` 157 | 158 | This is a known issue with setuptools version 70.0.0 and above. To fix it, either downgrade setuptools to version 69.5.1: 159 | 160 | ```bash 161 | pip install setuptools==69.5.1 162 | ``` 163 | 164 | Or upgrade setuptools to the latest version: 165 | 166 | ```bash 167 | pip install --upgrade setuptools 168 | ``` 169 | ## Hailo Apps Infra 170 | The Hailo Apps Infra repository contains the infrastructure of Hailo applications and pipelines. 171 | You can find it here: [Hailo Apps Infra](https://github.com/hailo-ai/hailo-apps-infra).
172 | 173 | ## Contributing 174 | 175 | We welcome contributions from the community. You can contribute by: 176 | 1. Contribute to our [Community Projects](community_projects/community_projects.md). 177 | 2. Reporting issues and bugs. 178 | 3. Suggesting new features or improvements. 179 | 4. Joining the discussion on the [Hailo Community Forum](https://community.hailo.ai/). 180 | 181 | 182 | ## License 183 | 184 | This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details. 185 | 186 | ## Disclaimer 187 | 188 | This code example is provided by Hailo solely on an “AS IS” basis and “with all faults.” No responsibility or liability is accepted or shall be imposed upon Hailo regarding the accuracy, merchantability, completeness, or suitability of the code example. Hailo shall not have any liability or responsibility for errors or omissions in, or any business decisions made by you in reliance on this code example or any part of it. If an error occurs when running this example, please open a ticket in the "Issues" tab. 189 | -------------------------------------------------------------------------------- /clip_app/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hailo-ai/hailo-CLIP/da90cf46d1cfb36f2b546226758c77c23c1d62b7/clip_app/__init__.py -------------------------------------------------------------------------------- /clip_app/clip_app_pipeline.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import logging 4 | import sys 5 | import signal 6 | import importlib.util 7 | from functools import partial 8 | import gi 9 | import threading 10 | gi.require_version('Gtk', '3.0') 11 | gi.require_version('Gst', '1.0') 12 | from gi.repository import Gtk, Gst, GLib 13 | from clip_app.logger_setup import setup_logger, set_log_level 14 | from clip_app.clip_pipeline import get_pipeline 15 | from clip_app.text_image_matcher import text_image_matcher 16 | from clip_app.clip_callback import app_callback_class, dummy_callback 17 | from clip_app import gui 18 | from hailo_apps_infra.gstreamer_app import picamera_thread 19 | from hailo_apps_infra.gstreamer_helper_pipelines import get_source_type 20 | 21 | # add logging 22 | logger = setup_logger() 23 | set_log_level(logger, logging.INFO) 24 | 25 | 26 | class ClipApp(): 27 | def __init__(self, user_data, app_callback): 28 | self.args = self.parse_arguments().parse_args() 29 | 30 | self.app_callback = app_callback 31 | set_log_level(logger, logging.INFO) 32 | 33 | self.user_data = user_data 34 | self.win = AppWindow(self.args, self.user_data, self.app_callback) 35 | def run(self): 36 | self.win.connect("destroy", self.on_destroy) 37 | self.win.show_all() 38 | signal.signal(signal.SIGINT, signal.SIG_DFL) 39 | Gtk.main() 40 | 41 | def on_destroy(self, window): 42 | logger.info("Destroying window...") 43 | window.quit_button_clicked(None) 44 | 45 | def parse_arguments(self): 46 | parser = argparse.ArgumentParser(description="Hailo online CLIP app") 47 | parser.add_argument("--input", "-i", type=str, default="/dev/video0", help="Input source. Can be a file, USB (webcam), RPi camera (CSI camera module). \ 48 | For RPi camera use '-i rpi' \ 49 | For demo video use '--input demo'. 
\ 50 | Default is /dev/video0.") 51 | parser.add_argument("--detector", "-d", type=str, choices=["person", "face", "none"], default="none", help="Which detection pipeline to use.") 52 | parser.add_argument("--json-path", type=str, default=None, help="Path to JSON file to load and save embeddings. If not set, embeddings.json will be used.") 53 | parser.add_argument("--disable-sync", action="store_true",help="Disables display sink sync, will run as fast as possible. Relevant when using file source.") 54 | parser.add_argument("--dump-dot", action="store_true", help="Dump the pipeline graph to a dot file.") 55 | parser.add_argument("--detection-threshold", type=float, default=0.5, help="Detection threshold.") 56 | parser.add_argument("--show-fps", "-f", action="store_true", help="Print FPS on sink.") 57 | parser.add_argument("--disable-runtime-prompts", action="store_true", help="When set, app will not support runtime prompts. Default is False.") 58 | 59 | return parser 60 | 61 | class AppWindow(Gtk.Window): 62 | # Add GUI functions to the AppWindow class 63 | build_ui = gui.build_ui 64 | add_text_boxes = gui.add_text_boxes 65 | update_text_boxes = gui.update_text_boxes 66 | update_text_prefix = gui.update_text_prefix 67 | quit_button_clicked = gui.quit_button_clicked 68 | on_text_box_updated = gui.on_text_box_updated 69 | on_slider_value_changed = gui.on_slider_value_changed 70 | on_negative_check_button_toggled = gui.on_negative_check_button_toggled 71 | on_ensemble_check_button_toggled = gui.on_ensemble_check_button_toggled 72 | on_load_button_clicked = gui.on_load_button_clicked 73 | on_save_button_clicked = gui.on_save_button_clicked 74 | update_progress_bars = gui.update_progress_bars 75 | on_track_id_update = gui.on_track_id_update 76 | disable_text_boxes = gui.disable_text_boxes 77 | 78 | # Add the get_pipeline function to the AppWindow class 79 | get_pipeline = get_pipeline 80 | 81 | 82 | def __init__(self, args, user_data, app_callback): 83 | Gtk.Window.__init__(self, title="Clip App") 84 | self.set_border_width(10) 85 | self.set_default_size(1, 1) 86 | self.fullscreen_mode = False 87 | 88 | self.current_path = os.path.dirname(os.path.realpath(__file__)) 89 | # move self.current_path one directory up to get the path to the workspace 90 | self.current_path = os.path.dirname(self.current_path) 91 | os.environ["GST_DEBUG_DUMP_DOT_DIR"] = self.current_path 92 | 93 | self.tappas_postprocess_dir = os.environ.get('TAPPAS_POST_PROC_DIR', '') 94 | if self.tappas_postprocess_dir == '': 95 | logger.error("TAPPAS_POST_PROC_DIR environment variable is not set. 
Please set it by sourcing setup_env.sh") 96 | sys.exit(1) 97 | 98 | # Create options menu 99 | self.options_menu = args 100 | 101 | self.dump_dot = self.options_menu.dump_dot 102 | self.video_source = self.options_menu.input 103 | self.source_type = get_source_type(self.video_source) 104 | self.sync = "false" if (self.options_menu.disable_sync or self.source_type != "file") else "true" 105 | self.show_fps = self.options_menu.show_fps 106 | self.json_file = os.path.join(self.current_path, "embeddings.json") if self.options_menu.json_path is None else self.options_menu.json_path 107 | if self.options_menu.input == "demo": 108 | self.input = os.path.join(self.current_path, "resources", "clip_example.mp4") 109 | self.json_file = os.path.join(self.current_path, "example_embeddings.json") if self.options_menu.json_path is None else self.options_menu.json_path 110 | else: 111 | self.input = self.options_menu.input 112 | self.detector = self.options_menu.detector 113 | self.user_data = user_data 114 | self.app_callback = app_callback 115 | # get current path 116 | Gst.init(None) 117 | self.pipeline = self.create_pipeline() 118 | if self.input == "rpi": 119 | picam_thread = threading.Thread(target=picamera_thread, args=(self.pipeline, 1280, 720, 'RGB')) 120 | picam_thread.start() 121 | bus = self.pipeline.get_bus() 122 | bus.add_signal_watch() 123 | bus.connect("message", self.on_message) 124 | 125 | # get text_image_matcher instance 126 | self.text_image_matcher = text_image_matcher 127 | self.text_image_matcher.set_threshold(self.options_menu.detection_threshold) 128 | 129 | # build UI 130 | self.max_entries = 6 131 | self.build_ui(self.options_menu) 132 | 133 | # set runtime 134 | if self.options_menu.disable_runtime_prompts: 135 | logger.info("No text embedding runtime selected, adding new text is disabled. 
Loading %s", self.json_file) 136 | self.disable_text_boxes() 137 | self.on_load_button_clicked(None) 138 | else: 139 | self.text_image_matcher.init_clip() 140 | 141 | if self.text_image_matcher.model_runtime is not None: 142 | logger.info("Using %s for text embedding", self.text_image_matcher.model_runtime) 143 | self.on_load_button_clicked(None) 144 | 145 | 146 | identity = self.pipeline.get_by_name("identity_callback") 147 | if identity is None: 148 | logger.warning("identity_callback element not found, add in your pipeline where you want the callback to be called.") 149 | else: 150 | identity_pad = identity.get_static_pad("src") 151 | identity_pad.add_probe(Gst.PadProbeType.BUFFER, partial(self.app_callback, self), self.user_data) 152 | # start the pipeline 153 | self.pipeline.set_state(Gst.State.PLAYING) 154 | 155 | if self.dump_dot: 156 | GLib.timeout_add_seconds(5, self.dump_dot_file) 157 | 158 | self.update_text_boxes() 159 | 160 | # Define a timeout duration in nanoseconds (e.g., 5 seconds) 161 | timeout_ns = 5 * Gst.SECOND 162 | 163 | # Wait until state change is done or until the timeout occurs 164 | state_change_return, _state, _pending = self.pipeline.get_state(timeout_ns) 165 | 166 | if state_change_return == Gst.StateChangeReturn.SUCCESS: 167 | logger.info("Pipeline state changed to PLAYING successfully.") 168 | elif state_change_return == Gst.StateChangeReturn.ASYNC: 169 | logger.info("State change is ongoing asynchronously.") 170 | elif state_change_return == Gst.StateChangeReturn.FAILURE: 171 | logger.info("State change failed.") 172 | else: 173 | logger.warning("Unknown state change return value.") 174 | 175 | 176 | def dump_dot_file(self): 177 | logger.info("Dumping dot file...") 178 | Gst.debug_bin_to_dot_file(self.pipeline, Gst.DebugGraphDetails.ALL, "pipeline") 179 | return False 180 | 181 | 182 | def on_message(self, bus, message): 183 | t = message.type 184 | if t == Gst.MessageType.EOS: 185 | self.on_eos() 186 | elif t == Gst.MessageType.ERROR: 187 | err, debug = message.parse_error() 188 | logger.error("Error: %s %s", err, debug) 189 | self.shutdown() 190 | # print QOS messages 191 | elif t == Gst.MessageType.QOS: 192 | # print which element is reporting QOS 193 | src = message.src.get_name() 194 | logger.info("QOS from %s", src) 195 | return True 196 | 197 | 198 | def on_eos(self): 199 | logger.info("EOS received, shutting down the pipeline.") 200 | self.pipeline.set_state(Gst.State.PAUSED) 201 | GLib.usleep(100000) # 0.1 second delay 202 | 203 | self.pipeline.set_state(Gst.State.READY) 204 | GLib.usleep(100000) # 0.1 second delay 205 | 206 | self.pipeline.set_state(Gst.State.NULL) 207 | GLib.idle_add(Gtk.main_quit) 208 | 209 | def shutdown(self): 210 | logger.info("Sending EOS event to the pipeline...") 211 | self.pipeline.send_event(Gst.Event.new_eos()) 212 | 213 | def create_pipeline(self): 214 | pipeline_str = get_pipeline(self) 215 | logger.info('PIPELINE:\ngst-launch-1.0 %s', pipeline_str) 216 | try: 217 | pipeline = Gst.parse_launch(pipeline_str) 218 | except GLib.Error as e: 219 | logger.error("An error occurred while parsing the pipeline: %s", e) 220 | return pipeline 221 | 222 | 223 | if __name__ == "__main__": 224 | user_data = app_callback_class() 225 | clip = ClipApp(user_data, dummy_callback) 226 | clip.run() 227 | -------------------------------------------------------------------------------- /clip_app/clip_callback.py: -------------------------------------------------------------------------------- 1 | import gi 2 | gi.require_version('Gst', 
'1.0') 3 | from gi.repository import Gst 4 | 5 | class app_callback_class: 6 | def __init__(self): 7 | self.frame_count = 0 8 | self.use_frame = False 9 | self.running = True 10 | 11 | def increment(self): 12 | self.frame_count += 1 13 | 14 | def get_count(self): 15 | return self.frame_count 16 | 17 | def dummy_callback(self, pad, info, user_data): 18 | """ 19 | A minimal dummy callback function that returns immediately. 20 | 21 | Args: 22 | pad: The GStreamer pad 23 | info: The probe info 24 | user_data: User-defined data passed to the callback 25 | 26 | Returns: 27 | Gst.PadProbeReturn.OK 28 | """ 29 | return Gst.PadProbeReturn.OK 30 | 31 | -------------------------------------------------------------------------------- /clip_app/clip_hailopython.py: -------------------------------------------------------------------------------- 1 | import hailo 2 | import numpy as np 3 | # Importing VideoFrame before importing GST is must 4 | from gsthailo import VideoFrame 5 | from gi.repository import Gst 6 | from clip_app.text_image_matcher import text_image_matcher 7 | 8 | def run(video_frame: VideoFrame): 9 | top_level_matrix = video_frame.roi.get_objects_typed(hailo.HAILO_MATRIX) 10 | if len(top_level_matrix) == 0: 11 | detections = video_frame.roi.get_objects_typed(hailo.HAILO_DETECTION) 12 | else: 13 | detections = [video_frame.roi] # Use the ROI as the detection 14 | 15 | embeddings_np = None 16 | used_detection = [] 17 | track_id_focus = text_image_matcher.track_id_focus # Used to focus on a specific track_id 18 | update_tracked_probability = None 19 | for detection in detections: 20 | results = detection.get_objects_typed(hailo.HAILO_MATRIX) 21 | if len(results) == 0: 22 | # print("No matrix found in detection") 23 | continue 24 | # Convert the matrix to a NumPy array 25 | detection_embeddings = np.array(results[0].get_data()) 26 | used_detection.append(detection) 27 | if embeddings_np is None: 28 | embeddings_np = detection_embeddings[np.newaxis, :] 29 | else: 30 | embeddings_np = np.vstack((embeddings_np, detection_embeddings)) 31 | if track_id_focus is not None: 32 | track = detection.get_objects_typed(hailo.HAILO_UNIQUE_ID) 33 | if len(track) == 1: 34 | track_id = track[0].get_id() 35 | # If we have a track_id_focus, update only the tracked_probability of the focused track 36 | if track_id == track_id_focus: 37 | update_tracked_probability = len(used_detection) - 1 38 | if embeddings_np is not None: 39 | matches = text_image_matcher.match(embeddings_np, report_all=True, update_tracked_probability=update_tracked_probability) 40 | for match in matches: 41 | # (row_idx, label, confidence, entry_index) = match 42 | detection = used_detection[match.row_idx] 43 | old_classification = detection.get_objects_typed(hailo.HAILO_CLASSIFICATION) 44 | if (match.passed_threshold and not match.negative): 45 | # Add label as classification metadata 46 | classification = hailo.HailoClassification('clip', match.text, match.similarity) 47 | detection.add_object(classification) 48 | # remove old classification 49 | for old in old_classification: 50 | detection.remove_object(old) 51 | return Gst.FlowReturn.OK 52 | -------------------------------------------------------------------------------- /clip_app/clip_pipeline.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | # Pipeline parameters 5 | video_width = 1280 6 | video_height = 720 7 | batch_size = 8 8 | 9 | # Check Hailo Device Type from the environment variable DEVICE_ARCHITECTURE 10 | 
# If the environment variable is not set, default to HAILO8L 11 | device_architecture = os.getenv("DEVICE_ARCHITECTURE") 12 | if device_architecture is None or device_architecture == "HAILO8L": 13 | device_architecture = "HAILO8L" 14 | # HEF files for H8L 15 | YOLO5_HEF_NAME = "yolov5s_personface_h8l_pi.hef" 16 | CLIP_HEF_NAME = "clip_resnet_50x4_h8l.hef" 17 | else: 18 | device_architecture = "HAILO8" 19 | # HEF files for H8 20 | YOLO5_HEF_NAME = "yolov5s_personface.hef" 21 | CLIP_HEF_NAME = "clip_resnet_50x4.hef" 22 | 23 | from hailo_apps_infra.gstreamer_helper_pipelines import ( 24 | SOURCE_PIPELINE, 25 | QUEUE, 26 | INFERENCE_PIPELINE, 27 | INFERENCE_PIPELINE_WRAPPER, 28 | TRACKER_PIPELINE, 29 | DISPLAY_PIPELINE, 30 | CROPPER_PIPELINE 31 | ) 32 | 33 | 34 | ################################################################### 35 | # NEW helper function to add in your gstreamer_helper_pipelines.py 36 | ################################################################### 37 | 38 | 39 | def get_pipeline(self): 40 | # Initialize directories and paths 41 | RESOURCES_DIR = os.path.join(self.current_path, "resources") 42 | POSTPROCESS_DIR = self.tappas_postprocess_dir 43 | hailopython_path = os.path.join(self.current_path, "clip_app/clip_hailopython.py") 44 | # personface 45 | YOLO5_POSTPROCESS_SO = os.path.join(POSTPROCESS_DIR, "libyolo_post.so") 46 | YOLO5_NETWORK_NAME = "yolov5_personface_letterbox" 47 | YOLO5_HEF_PATH = os.path.join(RESOURCES_DIR, YOLO5_HEF_NAME) 48 | YOLO5_CONFIG_PATH = os.path.join(RESOURCES_DIR, "configs/yolov5_personface.json") 49 | DETECTION_POST_PIPE = f'hailofilter so-path={YOLO5_POSTPROCESS_SO} qos=false function_name={YOLO5_NETWORK_NAME} config-path={YOLO5_CONFIG_PATH} ' 50 | hef_path = YOLO5_HEF_PATH 51 | 52 | # CLIP 53 | clip_hef_path = os.path.join(RESOURCES_DIR, CLIP_HEF_NAME) 54 | clip_postprocess_so = os.path.join(RESOURCES_DIR, "libclip_post.so") 55 | DEFAULT_CROP_SO = os.path.join(RESOURCES_DIR, "libclip_croppers.so") 56 | clip_matcher_so = os.path.join(RESOURCES_DIR, "libclip_matcher.so") 57 | clip_matcher_config = os.path.join(self.current_path, "embeddings.json") 58 | 59 | source_pipeline = SOURCE_PIPELINE( 60 | video_source=self.input, 61 | video_width=video_width, 62 | video_height=video_height, 63 | video_format='RGB', 64 | name='source' 65 | ) 66 | 67 | detection_pipeline = INFERENCE_PIPELINE( 68 | hef_path=hef_path, 69 | post_process_so=YOLO5_POSTPROCESS_SO, 70 | batch_size=batch_size, 71 | config_json=YOLO5_CONFIG_PATH, 72 | post_function_name=YOLO5_NETWORK_NAME, 73 | scheduler_priority=31, 74 | scheduler_timeout_ms=100, 75 | name='detection_inference' 76 | ) 77 | 78 | if self.detector == "none": 79 | detection_pipeline_wrapper = "" 80 | else: 81 | detection_pipeline_wrapper = INFERENCE_PIPELINE_WRAPPER(detection_pipeline) 82 | 83 | 84 | clip_pipeline = INFERENCE_PIPELINE( 85 | hef_path=clip_hef_path, 86 | post_process_so=clip_postprocess_so, 87 | batch_size=batch_size, 88 | name='clip_inference', 89 | scheduler_timeout_ms=1000, 90 | scheduler_priority=16, 91 | ) 92 | 93 | if self.detector == "person": 94 | class_id = 1 95 | crop_function_name = "person_cropper" 96 | elif self.detector == "face": 97 | class_id = 2 98 | crop_function_name = "face_cropper" 99 | else: # fast_sam 100 | class_id = 0 101 | crop_function_name = "object_cropper" 102 | 103 | tracker_pipeline = TRACKER_PIPELINE(class_id=class_id, keep_past_metadata=True) 104 | 105 | 106 | 107 | # Clip pipeline with cropper integration 108 | clip_cropper_pipeline = CROPPER_PIPELINE( 109 | 
inner_pipeline=clip_pipeline, 110 | so_path=DEFAULT_CROP_SO, 111 | function_name=crop_function_name, 112 | name='clip_cropper' 113 | ) 114 | 115 | # Clip pipeline with muxer integration (no cropper) 116 | clip_pipeline_wrapper = f'tee name=clip_t hailomuxer name=clip_hmux \ 117 | clip_t. ! {QUEUE(name="clip_bypass_q", max_size_buffers=20)} ! clip_hmux.sink_0 \ 118 | clip_t. ! {QUEUE(name="clip_muxer_queue")} ! videoscale n-threads=4 qos=false ! {clip_pipeline} ! clip_hmux.sink_1 \ 119 | clip_hmux. ! {QUEUE(name="clip_hmux_queue")} ' 120 | 121 | # TBD aggregator does not support ROI classification 122 | # clip_pipeline_wrapper = INFERENCE_PIPELINE_WRAPPER(clip_pipeline, name='clip') 123 | 124 | display_pipeline = DISPLAY_PIPELINE(sync=self.sync, show_fps=self.show_fps) 125 | 126 | # Text to image matcher 127 | CLIP_PYTHON_MATCHER = f'hailopython name=pyproc module={hailopython_path} qos=false ' 128 | CLIP_CPP_MATCHER = f'hailofilter so-path={clip_matcher_so} qos=false config-path={clip_matcher_config} ' 129 | 130 | clip_postprocess_pipeline = f' {CLIP_PYTHON_MATCHER} ! \ 131 | {QUEUE(name="clip_postprocess_queue")} ! \ 132 | identity name=identity_callback ' 133 | 134 | # PIPELINE 135 | if self.detector == "none": 136 | PIPELINE = f'{source_pipeline} ! \ 137 | {clip_pipeline_wrapper} ! \ 138 | {clip_postprocess_pipeline} ! \ 139 | {display_pipeline}' 140 | else: 141 | PIPELINE = f'{source_pipeline} ! \ 142 | {detection_pipeline_wrapper} ! \ 143 | {tracker_pipeline} ! \ 144 | {clip_cropper_pipeline} ! \ 145 | {clip_postprocess_pipeline} ! \ 146 | {display_pipeline}' 147 | 148 | return PIPELINE 149 | -------------------------------------------------------------------------------- /clip_app/gui.py: -------------------------------------------------------------------------------- 1 | import gi 2 | gi.require_version('Gtk', '3.0') 3 | from gi.repository import Gtk, GLib 4 | 5 | from clip_app.logger_setup import setup_logger 6 | logger = setup_logger() 7 | 8 | def build_ui(self, args): 9 | ui_vbox = Gtk.Box(orientation=Gtk.Orientation.VERTICAL, spacing=6) 10 | self.add(ui_vbox) 11 | self.ui_vbox = ui_vbox 12 | 13 | # Create a label and slider for threshold control 14 | slider_label = Gtk.Label("Threshold:", valign=Gtk.Align.CENTER) 15 | self.slider = Gtk.Scale.new_with_range(Gtk.Orientation.HORIZONTAL, 0.0, 1.0, 0.05) 16 | self.slider.set_value(args.detection_threshold) 17 | self.slider.connect("value-changed", self.on_slider_value_changed) 18 | 19 | # Pack the label and slider into a horizontal box and add it to the vertical box 20 | hbox = Gtk.Box(orientation=Gtk.Orientation.HORIZONTAL, spacing=10) 21 | hbox.pack_start(slider_label, False, False, 0) 22 | hbox.pack_start(self.slider, True, True, 0) 23 | ui_vbox.pack_start(hbox, False, False, 0) 24 | 25 | # Text boxes to control text embeddings 26 | self.add_text_boxes() 27 | 28 | # add 2 buttons to hbox load and save 29 | hbox = Gtk.Box(orientation=Gtk.Orientation.HORIZONTAL, spacing=6) 30 | self.load_button = Gtk.Button(label="Load") 31 | self.load_button.connect("clicked", self.on_load_button_clicked) 32 | hbox.pack_start(self.load_button, False, False, 0) 33 | self.save_button = Gtk.Button(label="Save") 34 | self.save_button.connect("clicked", self.on_save_button_clicked) 35 | hbox.pack_start(self.save_button, False, False, 0) 36 | # Add label and track_id input box 37 | self.track_id_label = Gtk.Label(label="Track ID") 38 | self.track_id_entry = Gtk.Entry() 39 | hbox.pack_end(self.track_id_entry, False, False, 0) 40 | 
hbox.pack_end(self.track_id_label, False, False, 0) 41 | self.track_id_entry.connect("activate", lambda widget: self.on_track_id_update(widget)) 42 | self.track_id_entry.connect("focus-out-event", lambda widget, event: self.on_track_id_update(widget)) 43 | 44 | ui_vbox.pack_start(hbox, False, False, 0) 45 | 46 | # Quit Button 47 | quit_button = Gtk.Button(label="Quit") 48 | quit_button.connect("clicked", self.quit_button_clicked) 49 | ui_vbox.pack_start(quit_button, False, False, 0) 50 | 51 | 52 | def add_text_boxes(self, N=6): 53 | """Adds N text boxes to the GUI and sets up callbacks for text changes.""" 54 | self.text_boxes = [] 55 | self.probability_progress_bars = [] 56 | self.negative_check_buttons = [] 57 | self.ensemble_check_buttons = [] 58 | self.text_prefix_labels = [] 59 | 60 | # Create vertical boxes for each column 61 | vbox1 = Gtk.Box(orientation=Gtk.Orientation.VERTICAL) 62 | vbox2 = Gtk.Box(orientation=Gtk.Orientation.VERTICAL) 63 | vbox3 = Gtk.Box(orientation=Gtk.Orientation.VERTICAL) 64 | vbox4 = Gtk.Box(orientation=Gtk.Orientation.VERTICAL) 65 | vbox5 = Gtk.Box(orientation=Gtk.Orientation.VERTICAL) 66 | 67 | # Adding header line to the vertical boxes 68 | vbox1.pack_start(Gtk.Label(label="Negative", width_chars=10), False, False, 0) 69 | vbox2.pack_start(Gtk.Label(label="Ensemble", width_chars=10), False, False, 0) 70 | vbox3.pack_start(Gtk.Label(label="Prefix", width_chars=10), False, False, 0) 71 | vbox4.pack_start(Gtk.Label(label="Text Description", width_chars=20), False, False, 0) 72 | vbox5.pack_start(Gtk.Label(label="Probability", width_chars=10), False, False, 0) 73 | 74 | for i in range(N): 75 | # Create and add a negative check button with a callback 76 | negative_check_button = Gtk.CheckButton() 77 | negative_check_button.connect("toggled", self.on_negative_check_button_toggled, i) 78 | negative_check_button.set_halign(Gtk.Align.CENTER) 79 | negative_check_button.set_valign(Gtk.Align.CENTER) 80 | vbox1.pack_start(negative_check_button, True, True, 0) 81 | self.negative_check_buttons.append(negative_check_button) 82 | # Create and add an ensemble check button 83 | ensemble_check_button = Gtk.CheckButton() 84 | ensemble_check_button.connect("toggled", self.on_ensemble_check_button_toggled, i) 85 | ensemble_check_button.set_halign(Gtk.Align.CENTER) 86 | ensemble_check_button.set_valign(Gtk.Align.CENTER) 87 | vbox2.pack_start(ensemble_check_button, True, True, 0) 88 | self.ensemble_check_buttons.append(ensemble_check_button) 89 | # Create and add a label 90 | label = Gtk.Label(label=f"{self.text_image_matcher.text_prefix}") 91 | vbox3.pack_start(label, True, True, 0) 92 | self.text_prefix_labels.append(label) 93 | 94 | # Create and add a text box with callbacks 95 | text_box = Gtk.Entry() 96 | text_box.set_width_chars(20) # Adjust the width to align with the "Text Description" header 97 | text_box.connect("activate", lambda widget, idx=i: self.on_text_box_updated(widget, None, idx)) 98 | text_box.connect("focus-out-event", lambda widget, event, idx=i: self.on_text_box_updated(widget, event, idx)) 99 | vbox4.pack_start(text_box, True, True, 0) 100 | self.text_boxes.append(text_box) 101 | 102 | # Create and add a progress bar with vertical alignment 103 | progress_bar = Gtk.ProgressBar() 104 | progress_bar.set_fraction(0.0) 105 | progress_bar.set_valign(Gtk.Align.CENTER) 106 | vbox5.pack_start(progress_bar, True, True, 0) 107 | self.probability_progress_bars.append(progress_bar) 108 | 109 | # Create a horizontal box to hold the vertical boxes 110 | hbox = 
Gtk.Box(orientation=Gtk.Orientation.HORIZONTAL, spacing=6) 111 | hbox.pack_start(vbox1, False, False, 0) 112 | hbox.pack_start(vbox2, False, False, 0) 113 | hbox.pack_start(vbox3, False, False, 0) 114 | hbox.pack_start(vbox4, True, True, 0) 115 | hbox.pack_start(vbox5, True, True, 0) 116 | 117 | # Schedule the update_progress_bars method to be called every half a second 118 | GLib.timeout_add(500, self.update_progress_bars) 119 | 120 | self.ui_vbox.pack_start(hbox, False, False, 0) 121 | 122 | 123 | def update_text_boxes(self): 124 | if len(self.text_image_matcher.entries) > self.max_entries: 125 | return 126 | for i, entry in enumerate(self.text_image_matcher.entries): 127 | self.text_boxes[i].set_text(entry.text) 128 | self.negative_check_buttons[i].set_active(entry.negative) 129 | self.ensemble_check_buttons[i].set_active(entry.ensemble) 130 | 131 | 132 | def update_text_prefix(self, new_text_prefix): 133 | self.text_image_matcher.text_prefix = new_text_prefix 134 | for label in self.text_prefix_labels: 135 | label.set_text(new_text_prefix) 136 | 137 | 138 | def quit_button_clicked(self, widget): 139 | logger.info("Quit button clicked") 140 | self.shutdown() 141 | 142 | 143 | def on_text_box_updated(self, widget, event, idx): 144 | """Callback function for text box updates.""" 145 | text = widget.get_text() 146 | logger.info("Text box %s updated: %s", idx, text) 147 | self.text_image_matcher.add_text(widget.get_text(), idx) 148 | 149 | def on_track_id_update(self, widget): 150 | """Callback function for track id updates.""" 151 | track_id_focus = widget.get_text() 152 | # check if track id is a number 153 | if not track_id_focus.isdigit(): 154 | logger.warning("Track ID must be a number, got: %s", track_id_focus) 155 | widget.set_text("") 156 | self.text_image_matcher.track_id_focus = None 157 | return 158 | logger.info("Track ID updated: %s", track_id_focus) 159 | self.text_image_matcher.track_id_focus = int(track_id_focus) 160 | 161 | def on_slider_value_changed(self, widget): 162 | value = float(widget.get_value()) 163 | logger.info("Setting detection threshold to: %s", value) 164 | self.text_image_matcher.set_threshold(value) 165 | 166 | def on_negative_check_button_toggled(self, widget, idx): 167 | negative = widget.get_active() 168 | logger.info("Text box %s is set to negative: %s", idx, negative) 169 | self.text_image_matcher.entries[idx].negative = negative 170 | 171 | def on_ensemble_check_button_toggled(self, widget, idx): 172 | ensemble = widget.get_active() 173 | logger.info("Text box %s is set to ensemble: %s", idx, ensemble) 174 | # Encode text with new ensemble option 175 | self.text_image_matcher.add_text(self.text_boxes[idx].get_text(), idx, ensemble=ensemble) 176 | 177 | def on_load_button_clicked(self, widget): 178 | """Callback function for the load button.""" 179 | logger.info("Loading embeddings from %s\n", self.json_file) 180 | self.text_image_matcher.load_embeddings(self.json_file) 181 | if len(self.text_image_matcher.entries) > self.max_entries: 182 | print(f"Load more then {self.max_entries} embeddings.\nSkipping updating text boxes.") 183 | self.update_text_boxes() 184 | self.slider.set_value(self.text_image_matcher.threshold) 185 | self.update_text_prefix(self.text_image_matcher.text_prefix) 186 | 187 | def on_save_button_clicked(self, widget): 188 | """Callback function for the save button.""" 189 | logger.info("Saving embeddings to %s\n", self.json_file) 190 | self.text_image_matcher.save_embeddings(self.json_file) 191 | 192 | def 
update_progress_bars(self): 193 | """Updates the progress bars based on the current probability values.""" 194 | if len(self.text_image_matcher.entries) > self.max_entries: 195 | return 196 | for i, entry in enumerate(self.text_image_matcher.entries): 197 | if entry.text != "": 198 | self.probability_progress_bars[i].set_fraction(entry.tracked_probability) 199 | else: 200 | self.probability_progress_bars[i].set_fraction(0.0) 201 | return True 202 | 203 | def disable_text_boxes(self): 204 | for text_box in self.text_boxes: 205 | text_box.set_editable(False) 206 | -------------------------------------------------------------------------------- /clip_app/logger_setup.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | def setup_logger(name=__name__, level=logging.DEBUG): 4 | # Set up a logger with the specified name and level. 5 | # Args: 6 | # name (str): The name of the logger. Defaults to the name of the current module. 7 | # level (int): The initial logging level. Defaults to logging.DEBUG. 8 | # Returns: 9 | # logger: The configured logger instance. 10 | 11 | # Get the logger with the specified name 12 | logger = logging.getLogger(name) 13 | # Set the logging level 14 | logger.setLevel(level) 15 | if not logger.handlers: 16 | # Create a console handler with the specified level 17 | console_handler = logging.StreamHandler() 18 | console_handler.setLevel(level) 19 | # Create a formatter 20 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 21 | # Add the formatter to the handler 22 | console_handler.setFormatter(formatter) 23 | # Add the handler to the logger 24 | logger.addHandler(console_handler) 25 | return logger 26 | 27 | def set_log_level(logger, level): 28 | """ 29 | Set the logging level for both the logger and its handlers. 30 | Args: 31 | logger: The logger instance whose level should be changed. 32 | level (int): The new logging level. 33 | """ 34 | logger.setLevel(level) 35 | for handler in logger.handlers: 36 | handler.setLevel(level) 37 | -------------------------------------------------------------------------------- /clip_app/text_image_matcher.py: -------------------------------------------------------------------------------- 1 | import time 2 | import json 3 | import os 4 | import logging 5 | import sys 6 | import argparse 7 | import numpy as np 8 | from PIL import Image 9 | 10 | from clip_app.logger_setup import setup_logger, set_log_level 11 | 12 | """ 13 | This class is used to store the text embeddings and match them to image embeddings 14 | This class should be used as a singleton! 15 | An instance of this class is created in the end of this file. 16 | import text_image_matcher from this file to make sure that only one instance of the TextImageMatcher class is created. 17 | Example: from TextImageMatcher import text_image_matcher 18 | """ 19 | 20 | # Set up the logger 21 | logger = setup_logger() 22 | # Change the log level to INFO 23 | set_log_level(logger, logging.INFO) 24 | 25 | # Set up global variables. 
Only required imports are done in the init functions 26 | clip = None 27 | torch = None 28 | 29 | 30 | class TextEmbeddingEntry: 31 | def __init__(self, text="", embedding=None, negative=False, ensemble=False): 32 | self.text = text 33 | self.embedding = embedding if embedding is not None else np.array([]) 34 | self.negative = negative 35 | self.ensemble = ensemble 36 | self.probability = 0.0 37 | self.tracked_probability = 0.0 38 | 39 | def to_dict(self): 40 | return { 41 | "text": self.text, 42 | "embedding": self.embedding.tolist(), # Convert numpy array to list 43 | "negative": self.negative, 44 | "ensemble": self.ensemble 45 | } 46 | 47 | 48 | class Match: 49 | def __init__(self, row_idx, text, similarity, entry_index, negative, passed_threshold): 50 | self.row_idx = row_idx # row index in the image embedding 51 | self.text = text # best matching text 52 | self.similarity = similarity # similarity between the image and best text embeddings 53 | self.entry_index = entry_index # index of the entry in TextImageMatcher.entries 54 | self.negative = negative # True if the best match is a negative entry 55 | self.passed_threshold = passed_threshold # True if the similarity is above the threshold 56 | 57 | def to_dict(self): 58 | return { 59 | "row_idx": self.row_idx, 60 | "text": self.text, 61 | "similarity": self.similarity, 62 | "entry_index": self.entry_index, 63 | "negative": self.negative, 64 | "passed_threshold": self.passed_threshold 65 | } 66 | 67 | 68 | class TextImageMatcher: 69 | _instance = None 70 | 71 | def __new__(cls): 72 | if cls._instance is None: 73 | cls._instance = super(TextImageMatcher, cls).__new__(cls) 74 | return cls._instance 75 | 76 | def __init__(self, model_name="RN50x4", threshold=0.8, max_entries=6): 77 | self.model = None # model is initialized in init_clip 78 | self.preprocess = None # preprocess is initialized in init_clip 79 | self.model_runtime = None 80 | self.model_name = model_name 81 | self.threshold = threshold 82 | self.run_softmax = True 83 | self.device = "cpu" 84 | 85 | self.max_entries = max_entries 86 | self.entries = [TextEmbeddingEntry() for _ in range(max_entries)] 87 | self.user_data = None # user data can be used to store additional information 88 | self.text_prefix = "A photo of a " 89 | self.ensemble_template = [ 90 | 'a photo of a {}.', 91 | 'a photo of the {}.', 92 | 'a photo of my {}.', 93 | 'a photo of a big {}.', 94 | 'a photo of a small {}.', 95 | ] 96 | self.track_id_focus = None # Used to focus on specific track id when showing confidence 97 | 98 | def init_clip(self): 99 | """Initialize the CLIP model.""" 100 | global clip, torch 101 | import clip 102 | import torch 103 | logger.info("Loading model %s on device %s, this might take a while...", self.model_name, self.device) 104 | self.model, self.preprocess = clip.load(self.model_name, device=self.device) 105 | self.model_runtime = "clip" 106 | 107 | def set_threshold(self, new_threshold): 108 | self.threshold = new_threshold 109 | 110 | def set_text_prefix(self, new_text_prefix): 111 | self.text_prefix = new_text_prefix 112 | 113 | def set_ensemble_template(self, new_ensemble_template): 114 | self.ensemble_template = new_ensemble_template 115 | 116 | def update_text_entries(self, new_entry, index=None): 117 | if index is None: 118 | for i, entry in enumerate(self.entries): 119 | if entry.text == "": 120 | self.entries[i] = new_entry 121 | return 122 | if len(self.entries) == self.max_entries: 123 | logger.info(f"Entry list has more then {self.max_entries} entries, The gui will 
not show the prompts.") 124 | self.entries.append(new_entry) 125 | elif 0 <= index < len(self.entries): 126 | self.entries[index] = new_entry 127 | else: 128 | logger.error("Index out of bounds: %s", index) 129 | 130 | def add_text(self, text, index=None, negative=False, ensemble=False): 131 | if self.model_runtime is None: 132 | logger.error("No model is loaded. Please call init_clip before calling add_text.") 133 | return 134 | text_entries = [template.format(text) for template in self.ensemble_template] if ensemble else [self.text_prefix + text] 135 | logger.debug("Adding text entries: %s", text_entries) 136 | 137 | global clip, torch 138 | text_tokens = clip.tokenize(text_entries).to(self.device) 139 | with torch.no_grad(): 140 | text_features = self.model.encode_text(text_tokens) 141 | text_features /= text_features.norm(dim=-1, keepdim=True) 142 | ensemble_embedding = torch.mean(text_features, dim=0).cpu().numpy().flatten() 143 | new_entry = TextEmbeddingEntry(text, ensemble_embedding, negative, ensemble) 144 | self.update_text_entries(new_entry, index) 145 | 146 | def get_embeddings(self): 147 | """Return a list of indexes to self.entries if entry.text != "".""" 148 | return [i for i, entry in enumerate(self.entries) if entry.text != ""] 149 | 150 | def get_texts(self): 151 | """Return all entries' text (not only valid ones).""" 152 | return [entry.text for entry in self.entries] 153 | 154 | def save_embeddings(self, filename): 155 | data_to_save = { 156 | "threshold": self.threshold, 157 | "text_prefix": self.text_prefix, 158 | "ensemble_template": self.ensemble_template, 159 | "entries": [entry.to_dict() for entry in self.entries] 160 | } 161 | with open(filename, 'w', encoding='utf-8') as f: 162 | json.dump(data_to_save, f) 163 | 164 | def load_embeddings(self, filename): 165 | if not os.path.isfile(filename): 166 | with open(filename, 'w', encoding='utf-8') as f: 167 | f.write('') # Create an empty file or initialize with some data 168 | logger.info("File %s does not exist, creating it.", filename) 169 | else: 170 | try: 171 | with open(filename, 'r', encoding='utf-8') as f: 172 | data = json.load(f) 173 | self.threshold = data['threshold'] 174 | self.text_prefix = data['text_prefix'] 175 | self.ensemble_template = data['ensemble_template'] 176 | self.entries = [TextEmbeddingEntry(text=entry['text'], 177 | embedding=np.array(entry['embedding']), 178 | negative=entry['negative'], 179 | ensemble=entry['ensemble']) 180 | for entry in data['entries']] 181 | except Exception as e: 182 | logger.error("Error while loading file %s: %s. Maybe you forgot to save your embeddings?", filename, e) 183 | 184 | def get_image_embedding(self, image): 185 | if self.model_runtime is None: 186 | logger.error("No model is loaded. 
Please call init_clip before calling get_image_embedding.") 187 | return None 188 | image_input = self.preprocess(image).unsqueeze(0).to(self.device) 189 | with torch.no_grad(): 190 | image_embedding = self.model.encode_image(image_input) 191 | image_embedding /= image_embedding.norm(dim=-1, keepdim=True) 192 | return image_embedding.cpu().numpy().flatten() 193 | 194 | def match(self, image_embedding_np, report_all=False, update_tracked_probability=None): 195 | """ 196 | This function is used to match an image embedding to a text embedding 197 | Returns a list of tuples: (row_idx, text, similarity, entry_index) 198 | row_idx is the index of the row in the image embedding 199 | text is the best matching text 200 | similarity is the similarity between the image and text embeddings 201 | entry_index is the index of the entry in self.entries 202 | If the best match is a negative entry, or if the similarity is below the threshold, the tuple is not returned 203 | If no match is found, an empty list is returned 204 | If report_all is True, the function returns a list of all matches, 205 | including negative entries and entries below the threshold. 206 | """ 207 | if len(image_embedding_np.shape) == 1: 208 | image_embedding_np = image_embedding_np.reshape(1, -1) 209 | results = [] 210 | all_dot_products = None 211 | valid_entries = self.get_embeddings() 212 | if len(valid_entries) == 0: 213 | return [] 214 | text_embeddings_np = np.array([self.entries[i].embedding for i in valid_entries]) 215 | for row_idx, image_embedding_1d in enumerate(image_embedding_np): 216 | dot_products = np.dot(text_embeddings_np, image_embedding_1d) 217 | all_dot_products = dot_products[np.newaxis, :] if all_dot_products is None else np.vstack((all_dot_products, dot_products)) 218 | 219 | if self.run_softmax: 220 | similarities = np.exp(100 * dot_products) 221 | similarities /= np.sum(similarities) 222 | else: 223 | # These magic numbers were collected by running actual inferences and measureing statistics. 
224 | # stats min: 0.27013595659637846, max: 0.4043235050452188, avg: 0.33676838831786493 225 | # map to [0,1] 226 | similarities = (dot_products - 0.27) / (0.41 - 0.27) 227 | similarities = np.clip(similarities, 0, 1) 228 | 229 | best_idx = np.argmax(similarities) 230 | best_similarity = similarities[best_idx] 231 | for i, _ in enumerate(similarities): 232 | self.entries[valid_entries[i]].probability = similarities[i] 233 | if update_tracked_probability is None or update_tracked_probability == row_idx: 234 | logger.debug("Updating tracked probability for entry %s to %s", valid_entries[i], similarities[i]) 235 | self.entries[valid_entries[i]].tracked_probability = similarities[i] 236 | new_match = Match(row_idx, 237 | self.entries[valid_entries[best_idx]].text, 238 | best_similarity, valid_entries[best_idx], 239 | self.entries[valid_entries[best_idx]].negative, 240 | best_similarity > self.threshold) 241 | if not report_all and new_match.negative: 242 | continue 243 | if report_all or new_match.passed_threshold: 244 | results.append(new_match) 245 | 246 | logger.debug("Best match output: %s", results) 247 | return results 248 | 249 | 250 | text_image_matcher = TextImageMatcher() 251 | 252 | 253 | def main(): 254 | parser = argparse.ArgumentParser() 255 | parser.add_argument("--output", type=str, default="text_embeddings.json", help="output file name default=text_embeddings.json") 256 | parser.add_argument("--interactive", action="store_true", help="input text from interactive shell") 257 | parser.add_argument("--image-path", type=str, default=None, help="Optional, path to image file to match. Note image embeddings are not running on Hailo here.") 258 | parser.add_argument('--texts-list', nargs='+', help='A list of texts to add to the matcher, the first one will be the searched text, the others will be considered negative prompts.\n Example: --texts-list "cat" "dog" "yellow car"') 259 | parser.add_argument('--texts-json', type=str, help='A json of texts to add to the matcher, the json will include 2 keys negative and positive, the values are going to be lists of texts\n Example: --texts-json resources/texts_json_example.json') 260 | args = parser.parse_args() 261 | 262 | matcher = TextImageMatcher() 263 | matcher.init_clip() 264 | texts = [] 265 | if args.interactive: 266 | while True: 267 | text = input(f'Enter text (leave empty to finish) {matcher.text_prefix}: ') 268 | if text == "": 269 | break 270 | texts.append(text) 271 | elif args.texts_list: 272 | texts = args.texts_list 273 | elif args.texts_json: 274 | with open(args.texts_json, 'r') as file: 275 | data = json.load(file) 276 | texts_positive = data['positive'] 277 | texts_negative = data['negative'] 278 | else: 279 | texts = ["birthday cake", "person", "landscape"] 280 | 281 | if not args.texts_json: 282 | texts_positive = [texts[0]] 283 | texts_negative = texts[1:] 284 | 285 | logger.info("Adding text embeddings: ") 286 | 287 | for text in texts_positive: 288 | logger.info('%s%s (%s)', matcher.text_prefix, text, "positive") 289 | for text in texts_negative: 290 | logger.info('%s%s (%s)', matcher.text_prefix, text, "negative") 291 | 292 | 293 | start_time = time.time() 294 | 295 | for text in texts_positive: 296 | matcher.add_text(text, negative=False) 297 | for text in texts_negative: 298 | matcher.add_text(text, negative=True) 299 | 300 | end_time = time.time() 301 | logger.info("Time taken to add %s text embeddings using add_text(): %.4f seconds", len(texts), end_time - start_time) 302 | 303 | 
matcher.save_embeddings(args.output) 304 | 305 | if args.image_path is None: 306 | logger.info("No image path provided, skipping image embedding generation") 307 | sys.exit() 308 | 309 | image = Image.open(args.image_path) 310 | image_embedding = matcher.get_image_embedding(image) 311 | 312 | start_time = time.time() 313 | result = matcher.match(image_embedding, report_all=True) 314 | end_time = time.time() 315 | 316 | if result: 317 | logger.info("Best match: %s", result[0].text) 318 | 319 | valid_entries = matcher.get_embeddings() 320 | for i in valid_entries: 321 | logger.info("Entry %s: %s similarity: %.4f", i, matcher.entries[i].text, matcher.entries[i].probability) 322 | logger.info("Time taken to run match(): %.4f seconds", end_time - start_time) 323 | 324 | if __name__ == "__main__": 325 | main() 326 | -------------------------------------------------------------------------------- /clip_application.py: -------------------------------------------------------------------------------- 1 | import gi 2 | gi.require_version('Gst', '1.0') 3 | from gi.repository import Gst 4 | import hailo 5 | from clip_app.clip_app_pipeline import ClipApp 6 | 7 | class app_callback_class: 8 | def __init__(self): 9 | self.frame_count = 0 10 | self.use_frame = False 11 | self.running = True 12 | 13 | def increment(self): 14 | self.frame_count += 1 15 | 16 | def get_count(self): 17 | return self.frame_count 18 | 19 | 20 | def app_callback(self, pad, info, user_data): 21 | """ 22 | This is the callback function that will be called when data is available 23 | from the pipeline. 24 | Processing time should be kept to a minimum in this function. 25 | If longer processing is needed, consider using a separate thread / process. 26 | """ 27 | # Get the GstBuffer from the probe info 28 | buffer = info.get_buffer() 29 | # Check if the buffer is valid 30 | if buffer is None: 31 | return Gst.PadProbeReturn.OK 32 | string_to_print = "" 33 | # Get the detections from the buffer 34 | roi = hailo.get_roi_from_buffer(buffer) 35 | detections = roi.get_objects_typed(hailo.HAILO_DETECTION) 36 | if len(detections) == 0: 37 | detections = [roi] # Use the ROI as the detection 38 | # Parse the detections 39 | for detection in detections: 40 | track = detection.get_objects_typed(hailo.HAILO_UNIQUE_ID) 41 | track_id = None 42 | label = None 43 | confidence = 0.0 44 | for track_id_obj in track: 45 | track_id = track_id_obj.get_id() 46 | if track_id is not None: 47 | string_to_print += f'Track ID: {track_id} ' 48 | classifications = detection.get_objects_typed(hailo.HAILO_CLASSIFICATION) 49 | if len(classifications) > 0: 50 | string_to_print += ' CLIP Classifications:' 51 | for classification in classifications: 52 | label = classification.get_label() 53 | confidence = classification.get_confidence() 54 | string_to_print += f'Label: {label} Confidence: {confidence:.2f} ' 55 | string_to_print += '\n' 56 | if isinstance(detection, hailo.HailoDetection): 57 | label = detection.get_label() 58 | bbox = detection.get_bbox() 59 | confidence = detection.get_confidence() 60 | string_to_print += f"Detection: {label} {confidence:.2f}\n" 61 | if string_to_print: 62 | print(string_to_print) 63 | return Gst.PadProbeReturn.OK 64 | 65 | def main(): 66 | user_data = app_callback_class() 67 | clip = ClipApp(user_data, app_callback) 68 | clip.run() 69 | 70 | if __name__ == "__main__": 71 | main() 72 | -------------------------------------------------------------------------------- /community_projects/ad_genie/README.md: 
--------------------------------------------------------------------------------
 1 | ![](../../resources/Hackathon-banner-2024.png)
 2 | # AD GENIE - Personalized Advertisement
 3 | ![](resources/ad_genie.gif)
 4 | 
 5 | # Watch our YouTube video
 6 | ### [![Watch our beautifully illustrated demo!](https://img.youtube.com/vi/0_v2V7lV514/0.jpg)]()
 7 | 
 8 | ## Overview
 9 | This project personalizes ads using the CLIP model.
10 | It runs on a Raspberry Pi 5 with the AI+ HAT (Hailo-8 device).
11 | The application runs in real time: it receives input from a USB camera (a person wearing a certain outfit) and displays the images from a clothing database that best match that person's style.
12 | 
13 | This system resembles a content personalization system in public spaces, capable of interacting with nearby individuals and tailoring commercial content to their preferences.
14 | 
15 | The system can be utilized in various public settings, such as shopping malls, street billboards, and bus station displays. In retail settings, it can elevate shop window displays, attracting customers and encouraging them to enter the store.
16 | 
17 | ## Setup Instructions
18 | - Follow the README setup instructions of the [CLIP application example](../../README.md)
19 | 
20 | ### Register and Download
21 | - Register and download the Zara dataset from [Zara Dataset](https://www.kaggle.com/datasets/abhinavtyagi2708/zara-dataset-men-and-women-clothing)
22 | 
23 | ### Organize Dataset Structure
24 | - The downloaded dataset is a zip file; unzip it into a folder.
25 | - The dataset contains a Men directory nested inside another Men directory, and the same applies for Women. Remove the redundant nested directories.
26 | - Delete all categories that are not clothing, such as:
27 |   - Shoes
28 |   - Bags
29 |   - Jewelry
30 |   - Special Prices
31 |   - Perfumes
32 |   - Accessories
33 |   - Beauty
34 | 
35 | ![](resources/structure.jpeg)
36 | ### Run Data Preparation Script
37 | 
38 | - Execute the data_preparation.py script to create zara.json and an images directory containing all images under the resources directory. Note that this process might take 30-40 minutes, and some downloads may fail.
39 | ```bash
40 | python data_preparation.py --data
41 | ```
42 | ### Run Labels Preparation Script
43 | - Execute the lables_preparation.py script to generate lables.json.
44 | ```bash
45 | python lables_preparation.py --json-path
46 | ```
47 | ### Creating Data Embeddings
48 | - Generate data embeddings using the following command (this step takes approximately 15 minutes):
49 | ```bash
50 | text_image_matcher --texts-json resources/lables.json --output resources/data_embdedding.json
51 | ```
52 | ### Adjusting the Threshold
53 | - Open resources/data_embdedding.json and locate the threshold at the beginning of the file. Change its value from 0.8 to 0.01.
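If you prefer to script this step, the same change can be made with a few lines of Python. This is a minimal sketch that assumes the top-level `threshold` key produced by `text_image_matcher` in the previous step:

```python
import json

emb_path = "resources/data_embdedding.json"  # file created in the previous step

with open(emb_path, "r") as f:
    data = json.load(f)

data["threshold"] = 0.01  # the default of 0.8 is too strict for this use case

with open(emb_path, "w") as f:
    json.dump(data, f, indent=4)
```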
54 | 55 | ## Running Example 56 | ```bash 57 | python ad_genie.py -d person -i /dev/video0 --json-path resources/data_embdedding.json 58 | ``` 59 | - Check Your USB camera port using get-usb-camera 60 | -------------------------------------------------------------------------------- /community_projects/ad_genie/ad_genie.py: -------------------------------------------------------------------------------- 1 | import gi 2 | gi.require_version('Gst', '1.0') 3 | from gi.repository import Gst 4 | import os 5 | import json 6 | import random 7 | import time 8 | import multiprocessing 9 | import hailo 10 | from PIL import Image, ImageDraw 11 | from hailo_apps_infra.gstreamer_app import app_callback_class 12 | from clip_app.clip_app_pipeline import ClipApp 13 | 14 | class user_app_callback_class(app_callback_class): 15 | """ 16 | User-defined callback class to process detections and manage display updates. 17 | """ 18 | 19 | def __init__(self): 20 | """ 21 | Initializes the user callback class, including loading resources, 22 | setting up the display, and starting the label processing process. 23 | """ 24 | super().__init__() 25 | screen_width = 1080 26 | screen_height = 1920 27 | self.display = DisplayManager(screen_width, screen_height) 28 | self.display.show() 29 | CLOTHES_JSON_PATH = "resources/zara.json" 30 | CLOTHES_FOLDER = os.path.join("static", "clothes") 31 | # Load the clothes mapping from JSON 32 | with open(CLOTHES_JSON_PATH, "r", encoding="utf-8") as f: 33 | clothes_map = json.load(f) 34 | self.clothes_map = clothes_map 35 | self.MAX_QUEUE_SIZE = 3 36 | self.labels_queue = multiprocessing.Queue(maxsize=self.MAX_QUEUE_SIZE) 37 | client_process = multiprocessing.Process(target=self.label_to_css, args=(self.labels_queue,)) 38 | client_process.start() 39 | 40 | def parse_lable(self,lable_str): 41 | """ 42 | Parses a label string to determine the corresponding clothing item file. 43 | """ 44 | 45 | if "Men" in lable_str: 46 | gender = "Men" 47 | elif "Women" in lable_str: 48 | gender = "Women" 49 | # We also assume the item is after "wearing". 50 | # For example, "a men wearing REFLECTIVE EFFECT JACKET" 51 | # -> item_str = "REFLECTIVE EFFECT JACKET" 52 | parts = lable_str.split("wearing a ") 53 | if len(parts) > 1: 54 | item_str = parts[1].strip().upper() # "REFLECTIVE EFFECT JACKET" 55 | else: 56 | item_str = None 57 | # Attempt to look up the file 58 | matched_file = None 59 | if gender and item_str: 60 | if gender in self.clothes_map.keys() and item_str in self.clothes_map[gender].keys(): 61 | matched_file = self.clothes_map[gender][item_str][0] 62 | return matched_file 63 | 64 | def update_image(self, file = None): 65 | """ 66 | Updates the display with a new image. 67 | """ 68 | if file is None: 69 | self.display.update_image(f"resources/images/{self.choose_random()}") 70 | else: 71 | self.display.update_image(f"resources/images/{file}") 72 | self.display.show() 73 | 74 | def choose_random(self): 75 | """ 76 | Chooses a random clothing image from the loaded clothes map. 77 | """ 78 | gender = random.choice(list(self.clothes_map.keys())) 79 | item_str = random.choice(list(self.clothes_map[gender].keys())) 80 | file = random.choice(list(self.clothes_map[gender][item_str])) 81 | return file 82 | 83 | def label_to_css(self, queue_in,) -> None: 84 | """ 85 | Processes labels from the queue and updates the display accordingly. 
86 | """ 87 | start = time.time() 88 | while True: 89 | if not queue_in.empty(): 90 | label = queue_in.get() 91 | label = queue_in.get() 92 | now_time = time.time() 93 | if now_time - start < 2: 94 | continue 95 | start = time.time() 96 | matched_file = self.parse_lable(label) 97 | self.update_image(matched_file) 98 | 99 | def increment(self): 100 | """Increments the frame count (for potential future use).""" 101 | self.frame_count += 1 102 | 103 | def get_count(self): 104 | """ 105 | Retrieves the current frame count. 106 | """ 107 | return self.frame_count 108 | 109 | def user_app_callback(self, pad, info, user_data): 110 | """ 111 | This is the callback function that will be called when data is available 112 | from the pipeline. 113 | Processing time should be kept to a minimum in this function. 114 | If longer processing is needed, consider using a separate thread / process. 115 | """ 116 | # Get the GstBuffer from the probe info 117 | buffer = info.get_buffer() 118 | # Check if the buffer is valid 119 | if buffer is None: 120 | return Gst.PadProbeReturn.OK 121 | string_to_print = "" 122 | # Get the detections from the buffer 123 | roi = hailo.get_roi_from_buffer(buffer) 124 | detections = roi.get_objects_typed(hailo.HAILO_DETECTION) 125 | if len(detections) == 0: 126 | detections = [roi] # Use the ROI as the detection 127 | # Parse the detections 128 | for detection in detections: 129 | track = detection.get_objects_typed(hailo.HAILO_UNIQUE_ID) 130 | track_id = None 131 | label = None 132 | confidence = 0.0 133 | for track_id_obj in track: 134 | track_id = track_id_obj.get_id() 135 | if track_id is not None: 136 | string_to_print += f'Track ID: {track_id} ' 137 | classifications = detection.get_objects_typed(hailo.HAILO_CLASSIFICATION) 138 | if len(classifications) > 0: 139 | string_to_print += ' CLIP Classifications:' 140 | for classification in classifications: 141 | label = classification.get_label() 142 | user_data.labels_queue.put(label) 143 | confidence = classification.get_confidence() 144 | string_to_print += f'Label: {label} Confidence: {confidence:.2f} ' 145 | string_to_print += '\n' 146 | if isinstance(detection, hailo.HailoDetection): 147 | label = detection.get_label() 148 | bbox = detection.get_bbox() 149 | confidence = detection.get_confidence() 150 | string_to_print += f"Detection: {label} {confidence:.2f}\n" 151 | if string_to_print: 152 | print(string_to_print) 153 | return Gst.PadProbeReturn.OK 154 | 155 | class DisplayManager: 156 | """ 157 | Manages the display canvas for showing images and logos. 158 | """ 159 | def __init__(self, screen_width, screen_height): 160 | """ 161 | Initializes the display manager with screen dimensions. 162 | """ 163 | self.screen_width = screen_width 164 | self.screen_height = screen_height 165 | self.canvas = Image.new('RGB', (screen_width, screen_height), color='white') # white canvas 166 | self.image = None 167 | self.left_logo = None 168 | self.right_logo = None 169 | 170 | def update_image(self, image_path): 171 | """ 172 | Updates the canvas with a new image. 173 | """ 174 | new_image = Image.open(image_path) 175 | img_width, img_height = new_image.size 176 | left_margin = (self.screen_width - img_width) // 2 177 | top_margin = (self.screen_height - img_height) // 2 178 | self.canvas.paste(new_image, (left_margin, top_margin)) 179 | self.image = new_image 180 | 181 | def update_logos(self, left_logo_path=None, right_logo_path=None): 182 | """ 183 | Updates the canvas with logos at the bottom corners. 
184 | """ 185 | draw = ImageDraw.Draw(self.canvas) # to draw white placeholders for logos 186 | 187 | # Handle the left logo 188 | if left_logo_path: 189 | left_logo = Image.open(left_logo_path).convert('RGBA') # Ensure RGBA mode for transparency 190 | left_logo = left_logo.resize((100, 100)) # Resize logo if needed 191 | self.left_logo = left_logo 192 | logo_x = 0 193 | logo_y = self.screen_height - left_logo.height 194 | # Use the alpha channel of the logo as the mask for pasting 195 | self.canvas.paste(left_logo, (logo_x, logo_y), left_logo.split()[3]) # Alpha channel as mask 196 | 197 | # Handle the right logo 198 | if right_logo_path: 199 | right_logo = Image.open(right_logo_path).convert('RGBA') # Ensure RGBA mode for transparency 200 | right_logo = right_logo.resize((100, 100)) # Resize logo if needed 201 | self.right_logo = right_logo 202 | logo_x = self.screen_width - right_logo.width 203 | logo_y = self.screen_height - right_logo.height 204 | # Use the alpha channel of the logo as the mask for pasting 205 | self.canvas.paste(right_logo, (logo_x, logo_y), right_logo.split()[3]) # Alpha channel as mask 206 | 207 | def show(self): 208 | """Displays the current canvas.""" 209 | self.canvas.show() 210 | 211 | def save(self, save_path): 212 | """ 213 | Saves the current canvas to a file. 214 | """ 215 | self.canvas.save(save_path) 216 | 217 | if __name__ == "__main__": 218 | user_data = user_app_callback_class() 219 | clip = ClipApp(user_data, user_app_callback) 220 | clip.run() 221 | 222 | -------------------------------------------------------------------------------- /community_projects/ad_genie/data_preparation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import ast 4 | import requests 5 | from collections import defaultdict 6 | import json 7 | import argparse 8 | 9 | 10 | def download_images(base_dir, retries=1, dest_dir="resources/images"): 11 | """ 12 | Downloads images specified in CSV files located within a directory structure. 13 | 14 | Args: 15 | base_dir (str): The base directory containing nested folders and CSV files. 16 | retries (int): Number of retry attempts for downloading an image in case of failure. 17 | dest_dir (str): Destination directory to save downloaded images. 18 | 19 | Outputs: 20 | - Downloads images to the specified destination directory. 21 | - Logs any failed downloads to a file named `failed_links.txt` in the base directory. 22 | - Saves metadata about successfully downloaded images in `resources/zara.json`. 23 | 24 | Notes: 25 | - CSV files should have an "image" column with image URLs (in a JSON-like format) and a "name" column for product names. 26 | - Folder names categorize data into keys like "Men" and "Women" in the JSON output. 
27 | """ 28 | failed_links = [] 29 | data_dict = defaultdict(dict) 30 | data_dict['Men'] = defaultdict(dict) 31 | data_dict['Women'] = defaultdict(dict) 32 | for _, dirs, _ in os.walk(base_dir): 33 | for dir1 in dirs: 34 | if dir1 != dest_dir: 35 | for root, _, files in os.walk(os.path.join(base_dir, dir1)): 36 | for file in files: 37 | if file.endswith(".csv"): 38 | file_path = os.path.join(root, file) 39 | folder_name = os.path.basename(root) 40 | csv_filename = os.path.splitext(file)[0] 41 | 42 | with open(file_path, 'r', encoding='utf-8') as f: 43 | reader = csv.DictReader(f) 44 | for row_number, row in enumerate(reader): 45 | # Normalize and handle different column naming 46 | image_field = next( 47 | (field for field in row if 'image' in field.lower()), None 48 | ) 49 | product_field = next( 50 | (field for field in row if 'name' in field.lower()), None 51 | ) 52 | if not image_field: 53 | print(f"No valid image column found in {file}. Skipping...") 54 | continue 55 | 56 | try: 57 | image_data = ast.literal_eval(row[image_field]) 58 | for image_num, image_url in enumerate(image_data): 59 | image_link = list(image_url.keys())[0] 60 | # Generate the file name 61 | image_name = f"{folder_name}_{csv_filename}_{row_number}_{image_num}.jpg" 62 | # import ipdb; ipdb.set_trace() 63 | download_path = os.path.join(dest_dir, image_name) 64 | 65 | # Download the image with retry logic 66 | success = download_image_with_retry(image_link, download_path, retries) 67 | if not success: 68 | failed_links.append((image_link, download_path)) 69 | else: 70 | if row[product_field] not in data_dict[folder_name].keys(): 71 | data_dict[folder_name][row[product_field]] = [] 72 | data_dict[folder_name][row[product_field]].append(image_name) 73 | except (ValueError, KeyError) as e: 74 | print(f"Error processing row {row_number} in {file}: {e}") 75 | 76 | # Write all failed links to a file 77 | failed_links_file = os.path.join(base_dir, "failed_links.txt") 78 | with open(failed_links_file, 'w') as f: 79 | for link, path in failed_links: 80 | f.write(f"{link} -> {path}\n") 81 | print(f"Failed links logged to {failed_links_file}") 82 | with open('resources/zara.json', 'w') as file: 83 | json.dump(data_dict, file, indent=4) 84 | 85 | 86 | def download_image_with_retry(url, save_path, retries): 87 | """ 88 | Downloads an image from a URL with retry logic. 89 | """ 90 | try: 91 | response = requests.get(url, stream=True, timeout=10) 92 | response.raise_for_status() 93 | with open(save_path, 'wb') as f: 94 | for chunk in response.iter_content(1024): 95 | f.write(chunk) 96 | print(f"Downloaded: {save_path}") 97 | return True 98 | except requests.RequestException as e: 99 | print(f"Failed for {url}: {e}") 100 | return False 101 | 102 | 103 | def parse_arguments(): 104 | """ 105 | Parses command-line arguments for the script. 106 | 107 | Returns: 108 | argparse.Namespace: Parsed arguments containing: 109 | - data (str): Path to the base directory containing the dataset. 
110 | """ 111 | parser = argparse.ArgumentParser(description="Data Preparation for Ad Genie") 112 | parser.add_argument("--data", "-d", type=str, default="resources/zara_dataset", help="Enter the path to the zara dataset") 113 | return parser.parse_args() 114 | 115 | # Parse arguments and run the download_images function 116 | args = parse_arguments() 117 | base_directory = args.data 118 | download_images(base_directory) 119 | -------------------------------------------------------------------------------- /community_projects/ad_genie/lables_preparation.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import argparse 4 | 5 | def parse_arguments(): 6 | """ 7 | Parse command-line arguments for the script. 8 | 9 | Returns: 10 | argparse.Namespace: Parsed arguments containing: 11 | - json_path (str): Path to the input JSON file. 12 | - output (str): Path to the output JSON file. 13 | """ 14 | parser = argparse.ArgumentParser(description="Lables Preparation for Ad Genie") 15 | parser.add_argument("--json-path", "-p", type=str, default="resources/zara.json", help="Enter the path to the zara json file") 16 | parser.add_argument("--output", "-o", type=str, default="resources/lables.json", help="Enter the path for the output labels json") 17 | return parser.parse_args() 18 | 19 | args = parse_arguments() 20 | 21 | # Open the input JSON file and load its content into the `data` variable 22 | with open(args.json_path, 'r') as file: 23 | data = json.load(file) 24 | 25 | # Initialize the labels dictionary 26 | labels = {} 27 | labels['negative'] = [] 28 | labels['positive'] = [] 29 | 30 | # Process the input data to generate positive labels 31 | for gender in data: 32 | for label in data[gender]: 33 | labels['positive'].append(f'a {os.path.basename(gender)} wearing a {label}') 34 | 35 | # Write the generated labels to the specified output JSON file 36 | with open(args.output, 'w') as file: 37 | json.dump(labels, file, indent=4) -------------------------------------------------------------------------------- /community_projects/ad_genie/resources/ad_genie.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hailo-ai/hailo-CLIP/da90cf46d1cfb36f2b546226758c77c23c1d62b7/community_projects/ad_genie/resources/ad_genie.gif -------------------------------------------------------------------------------- /community_projects/ad_genie/resources/structure.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hailo-ai/hailo-CLIP/da90cf46d1cfb36f2b546226758c77c23c1d62b7/community_projects/ad_genie/resources/structure.jpeg -------------------------------------------------------------------------------- /community_projects/baiby_monitor/README.md: -------------------------------------------------------------------------------- 1 | ![](../../resources/Hackathon-banner-2024.png) 2 | 3 | # BAIby Monitor 4 | 5 | The BAIby Monitor project is an open-source initiative aimed at developing a smart baby monitoring system that utilizes machine learning to detect a baby's cries and other activities, providing real-time notifications to parents or caregivers. 
6 | 7 | ## Demo Video 8 | 9 | For a visual demonstration of the BAIby Monitor project, watch this video: 10 | 11 | [![BAIby Monitor Demo](https://img.youtube.com/vi/sXgL5g_A-u0/0.jpg)](https://youtu.be/sXgL5g_A-u0) 12 | 13 | 14 | ## Features 15 | 16 | - **Cry Detection**: Employs machine learning algorithms to distinguish a baby's cries from other sounds, ensuring accurate alerts. 17 | - **Real-Time Notifications**: Sends immediate alerts to connected devices when the baby cries or unusual activity is detected. 18 | - **Activity Monitoring**: Tracks the baby's movements and sounds, offering insights into sleep patterns and behavior. 19 | - **User-Friendly Interface**: Provides an intuitive dashboard for monitoring and configuring settings. 20 | 21 | ## Installation 22 | 23 | ## Setup Instructions 24 | 1. Follow README setup instructions of [CLIP application example](../../README.md) 25 | 26 | 2. **Navigate to the Project Directory**: 27 | 28 | ```bash 29 | cd hailo-CLIP/community_projects/baiby_monitor 30 | ``` 31 | 32 | 3. **Install Dependencies**: 33 | 34 | Install the required packages: 35 | 36 | ```bash 37 | pip install -r requirements.txt 38 | ``` 39 | 40 | 4. **Set Up the Environment**: 41 | 42 | Create a `.env` file in the project directory and add necessary configuration variables as specified in `.env.example`. 43 | 44 | ** Telegram Bot Activation **: 45 | a. Find @bAIbyMonbot in the 'Telegram' App. 46 | b. Press the 'Start' Button. 47 | c. You are ready to receive messages to your Telegram. 48 | 49 | 6. **Run the Application**: 50 | 51 | ```bash 52 | python src/baiby_telegram.py 53 | ``` 54 | 55 | ## Usage 56 | 57 | - **Access the Dashboard**: Once the application is running, navigate to `http://localhost:5001` in your web browser to access the monitoring dashboard. 58 | - **Configure Settings**: Use the dashboard to adjust sensitivity levels, notification preferences, and other settings. 59 | - **Monitor Alerts**: Receive real-time notifications on the dashboard and connected devices when the system detects the baby crying or unusual activity. 60 | -------------------------------------------------------------------------------- /community_projects/baiby_monitor/download_resources.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Set the resource directory 4 | RESOURCE_DIR="./resources" 5 | mkdir -p "$RESOURCE_DIR" 6 | 7 | # Define download function with file existence check and retries 8 | download_model() { 9 | file_name=$(basename "$2") 10 | resource_dir="$1" 11 | 12 | if [ ! -f "$resource_dir/$file_name" ]; then 13 | echo "Downloading $file_name..." 14 | wget --tries=3 --retry-connrefused --quiet --show-progress "$2" -P "$resource_dir" || { 15 | echo "Failed to download $file_name after multiple attempts." 16 | # Instead of exit 1, log and continue 17 | echo "Download failed for $file_name. Continuing..." 18 | } 19 | else 20 | echo "File $file_name already exists. Skipping download." 21 | fi 22 | } 23 | 24 | # Define all URLs in arrays 25 | RESOURCES=( 26 | "https://hailo-csdata.s3.eu-west-2.amazonaws.com/resources/hackathon/baiby/brahms-lullaby.mp3" 27 | ) 28 | 29 | # Run downloads for each array 30 | for url in "${RESOURCES[@]}"; do 31 | download_model "$RESOURCE_DIR" "$url" & 32 | done 33 | 34 | 35 | # Wait for all background downloads to complete 36 | wait 37 | 38 | echo "All downloads completed successfully!" 
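The alerts mentioned in the Usage section above are served by the Flask app in `src/baiby_telegram.py`, which listens on port 5001 and exposes a `/notify` route. Once that service is running, a small client sketch like the one below can push a test message to the configured Telegram chats (the message text is only an example):

```python
import requests

# POST a JSON body with a "message" field to the notification service
resp = requests.post(
    "http://localhost:5001/notify",
    json={"message": "Test alert: baby is crying"},
    timeout=10,
)
print(resp.status_code, resp.json())
```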
-------------------------------------------------------------------------------- /community_projects/baiby_monitor/requirements.txt: -------------------------------------------------------------------------------- 1 | playsound 2 | flask 3 | requests 4 | config 5 | -------------------------------------------------------------------------------- /community_projects/baiby_monitor/send_message/example_usage.py: -------------------------------------------------------------------------------- 1 | from telegram_messenger import AlertHandler, TelegramBot 2 | 3 | # Still does not work correctly. Does not send the message to the bot. 4 | 5 | # TODO: Move to VENV 6 | TOKEN = "7949633686:AAHe8DWw9vpdkGPbDOyJfGZ82bXhe1PCChI" 7 | # Initialize the bot 8 | telegram_bot = TelegramBot(TOKEN) 9 | 10 | sender = AlertHandler(telegram_bot) 11 | sender.receive_alert("This is an alert message!") -------------------------------------------------------------------------------- /community_projects/baiby_monitor/send_message/telegram_messenger.py: -------------------------------------------------------------------------------- 1 | # pip install python-telegram-bot 2 | 3 | 4 | import logging 5 | from telegram import Update 6 | from telegram.ext import Application, CommandHandler, MessageHandler, filters, CallbackContext 7 | 8 | # TODO: Move to VENV 9 | TOKEN = "7949633686:AAHe8DWw9vpdkGPbDOyJfGZ82bXhe1PCChI" 10 | WELCOME_MESSAGE = "Welcome to the bAIby_monitor_bot!" 11 | 12 | # Set up logging 13 | logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO) 14 | logger = logging.getLogger(__name__) 15 | 16 | class TelegramBot: 17 | def __init__(self, token: str): 18 | self.token = token 19 | self.application = Application.builder().token(self.token).build() 20 | self.last_chat_id = None 21 | 22 | # Register handlers 23 | self.application.add_handler(CommandHandler("start", self.start)) 24 | self.application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, self.handle_message)) 25 | self.application.add_handler(CommandHandler("send_last", self.send_message_to_last_user)) 26 | 27 | async def handle_message(self, update: Update, context: CallbackContext) -> None: 28 | chat_id = update.message.chat_id 29 | user_first_name = update.message.chat.first_name 30 | 31 | # Check if the chat_id is already stored 32 | with open("chat_ids.txt", "r") as file: 33 | chat_ids = file.read().splitlines() 34 | 35 | if str(chat_id) not in chat_ids: 36 | # Save the new chat_id 37 | with open("chat_ids.txt", "a") as file: 38 | file.write(f"{chat_id}\n") 39 | 40 | # Update the last chat ID 41 | self.last_chat_id = "983475502" 42 | 43 | # Send a welcome message to the new user 44 | await context.bot.send_message(chat_id=chat_id, text=f"Hello {user_first_name}, {WELCOME_MESSAGE}") 45 | 46 | async def start(self, update: Update, context: CallbackContext) -> None: 47 | await update.message.reply_text(WELCOME_MESSAGE) 48 | 49 | async def send_message_to_last_user(self, context: CallbackContext) -> None: 50 | if self.last_chat_id: 51 | await context.bot.send_message(chat_id=self.last_chat_id, text="This is a message to the last added user.") 52 | 53 | def run(self) -> None: 54 | self.application.run_polling() 55 | 56 | def send_alert(self, message: str) -> None: 57 | if self.last_chat_id: 58 | self.application.bot.send_message(chat_id=self.last_chat_id, text=message) 59 | 60 | class AlertHandler: 61 | def __init__(self, bot: TelegramBot): 62 | self.bot = bot 63 | 64 | def receive_alert(self, 
message: str) -> None: 65 | self.bot.send_alert(message) 66 | 67 | if __name__ == "__main__": 68 | # Create the chat_ids.txt file if it doesn't exist 69 | open("chat_ids.txt", "a").close() 70 | 71 | # Initialize the bot 72 | telegram_bot = TelegramBot(TOKEN) 73 | 74 | # Initialize the alert handler with the bot 75 | alert_handler = AlertHandler(telegram_bot) 76 | 77 | # Run the bot 78 | telegram_bot.run() 79 | -------------------------------------------------------------------------------- /community_projects/baiby_monitor/src/.gitignore: -------------------------------------------------------------------------------- 1 | resources/ -------------------------------------------------------------------------------- /community_projects/baiby_monitor/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hailo-ai/hailo-CLIP/da90cf46d1cfb36f2b546226758c77c23c1d62b7/community_projects/baiby_monitor/src/__init__.py -------------------------------------------------------------------------------- /community_projects/baiby_monitor/src/baiby_telegram.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from flask import Flask, request, jsonify 4 | import argparse 5 | import requests 6 | import os 7 | import json 8 | import configparser 9 | 10 | app = Flask(__name__) 11 | 12 | INI_PATH = os.path.join(os.path.dirname(__file__), 'telegram.ini') 13 | 14 | def read_ids_from_ini(file_path, section, option): 15 | """ 16 | Reads a list of IDs from a specified section and option in an INI file. 17 | 18 | Args: 19 | file_path (str): Path to the INI file. 20 | section (str): Section in the INI file. 21 | option (str): Option under the section. 22 | 23 | Returns: 24 | list: List of IDs as strings. 25 | """ 26 | config = configparser.ConfigParser() 27 | config.read(file_path) 28 | 29 | if config.has_section(section) and config.has_option(section, option) : 30 | if section == 'IDs': 31 | ids = config.get(section, option) 32 | return [id.strip() for id in ids.split(",")] 33 | else: 34 | return config.get(section, option) 35 | else: 36 | raise ValueError(f"Section '{section}' or option '{option}' not found in the INI file.") 37 | 38 | 39 | def get_ids_from_URL(): 40 | TOKEN = read_ids_from_ini(INI_PATH, "BOT", "token") 41 | url = f"https://api.telegram.org/bot{TOKEN}/getUpdates" 42 | resp = requests.get(url).json() 43 | 44 | print(f'Listing all Users in Telegram bot bot{TOKEN}') 45 | for i, res in enumerate(resp['result']): 46 | print('\t{i}: Name = {name}, ID = {id}'.format(i=i, 47 | name = res['message']['chat']['first_name'], 48 | id = res['message']['chat']['id'])) 49 | 50 | 51 | def send_telegram_message(message: str, debug: bool = False) -> tuple[bool, str]: 52 | """Send a Telegram message. 53 | 54 | Args: 55 | message (str): The message to be sent. 56 | debug (bool, optional): Whether to enable debug mode. Defaults to False. 57 | 58 | Returns: 59 | tuple[bool, str]: A tuple where the first element indicates the success status 60 | (True for success, False for failure), and the second element is the success message. 
61 | """ 62 | 63 | chat_ids = read_ids_from_ini(INI_PATH, "IDs", "list") 64 | BOT_TOKEN = read_ids_from_ini(INI_PATH, "BOT", "token") 65 | 66 | TELEGRAM_URL = f"https://api.telegram.org/bot{BOT_TOKEN}/sendMessage" 67 | 68 | for c_id in chat_ids: 69 | payload = { 70 | "chat_id": c_id, 71 | "text": message, 72 | "parse_mode": "HTML" 73 | } 74 | 75 | try: 76 | if debug: 77 | print(f"Sending request to: {TELEGRAM_URL}") 78 | print(f"Payload: {json.dumps(payload, indent=2)}") 79 | 80 | response = requests.post(TELEGRAM_URL, json=payload) 81 | if debug: 82 | print(f"Response status: {response.status_code}") 83 | print(f"Response body: {response.text}") 84 | 85 | response.raise_for_status() 86 | if debug: 87 | print("Message sent successfully") 88 | 89 | except requests.exceptions.RequestException as e: 90 | return False, f"Failed to send message: {str(e)}\nResponse: {response.text if 'response' in locals() else 'No response'}" 91 | 92 | return True, "All Messages sent successfully" 93 | 94 | 95 | @app.route('/notify', methods=['POST']) 96 | def notify(): 97 | 98 | data = request.get_json() 99 | if not data or 'message' not in data: 100 | return jsonify({'error': 'No message provided'}), 400 101 | 102 | success, message = send_telegram_message(data['message'], debug=True) 103 | if success: 104 | return jsonify({'status': 'success', 'message': message}) 105 | return jsonify({'status': 'error', 'message': message}), 500 106 | 107 | def set_args_parser(parser): 108 | parser.description = "A python script for sending Telegram messages to bots" 109 | parser.add_argument('-c', '--chat-id', help='Optional flag to get a list of chat IDs.', action='store_true') 110 | 111 | def main(): 112 | app.run(host='0.0.0.0', port=5001, debug=True) 113 | 114 | if __name__ == '__main__': 115 | parser = argparse.ArgumentParser() 116 | set_args_parser(parser) 117 | args = parser.parse_args() 118 | if args.chat_id: 119 | get_ids_from_URL() 120 | else: 121 | main() 122 | 123 | -------------------------------------------------------------------------------- /community_projects/baiby_monitor/src/clip_pipeline.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | batch_size = 8 4 | video_sink = "xvimagesink" 5 | # Note: only 16:9 resolutions are supported 6 | # RES_X = 1920 7 | # RES_Y = 1080 8 | RES_X = 1280 9 | RES_Y = 720 10 | 11 | # Check Hailo Device Type from the environment variable DEVICE_ARCHITECTURE 12 | # If the environment variable is not set, default to HAILO8L 13 | device_architecture = os.getenv("DEVICE_ARCHITECTURE") 14 | if device_architecture is None or device_architecture == "HAILO8L": 15 | device_architecture = "HAILO8L" 16 | # HEF files for H8L 17 | YOLO5_HEF_NAME = "yolov5s_personface_h8l_pi.hef" 18 | CLIP_HEF_NAME = "clip_resnet_50x4_h8l.hef" 19 | else: 20 | device_architecture = "HAILO8" 21 | # HEF files for H8 22 | YOLO5_HEF_NAME = "yolov5s_personface.hef" 23 | CLIP_HEF_NAME = "clip_resnet_50x4.hef" 24 | 25 | from hailo_apps_infra.gstreamer_helper_pipelines import ( 26 | SOURCE_PIPELINE, 27 | QUEUE, 28 | INFERENCE_PIPELINE, 29 | INFERENCE_PIPELINE_WRAPPER, 30 | TRACKER_PIPELINE, 31 | DISPLAY_PIPELINE, 32 | CROPPER_PIPELINE 33 | ) 34 | 35 | 36 | ################################################################### 37 | # NEW helper function to add in your gstreamer_helper_pipelines.py 38 | ################################################################### 39 | 40 | 41 | def get_pipeline(self): 42 | # Initialize directories and paths 43 | 
RESOURCES_DIR = os.path.join(self.current_path, "resources") 44 | POSTPROCESS_DIR = self.tappas_postprocess_dir 45 | hailopython_path = os.path.join(self.current_path, "clip_app/clip_hailopython.py") 46 | # personface 47 | YOLO5_POSTPROCESS_SO = os.path.join(POSTPROCESS_DIR, "libyolo_post.so") 48 | YOLO5_NETWORK_NAME = "yolov5_personface_letterbox" 49 | YOLO5_HEF_PATH = os.path.join(RESOURCES_DIR, YOLO5_HEF_NAME) 50 | YOLO5_CONFIG_PATH = os.path.join(RESOURCES_DIR, "configs/yolov5_personface.json") 51 | DETECTION_POST_PIPE = f'hailofilter so-path={YOLO5_POSTPROCESS_SO} qos=false function_name={YOLO5_NETWORK_NAME} config-path={YOLO5_CONFIG_PATH} ' 52 | hef_path = YOLO5_HEF_PATH 53 | 54 | # CLIP 55 | clip_hef_path = os.path.join(RESOURCES_DIR, CLIP_HEF_NAME) 56 | clip_postprocess_so = os.path.join(RESOURCES_DIR, "libclip_post.so") 57 | DEFAULT_CROP_SO = os.path.join(RESOURCES_DIR, "libclip_croppers.so") 58 | clip_matcher_so = os.path.join(RESOURCES_DIR, "libclip_matcher.so") 59 | clip_matcher_config = os.path.join(self.current_path, "embeddings.json") 60 | 61 | source_pipeline = SOURCE_PIPELINE( 62 | video_source=self.input_uri, 63 | video_width=RES_X, 64 | video_height=RES_Y, 65 | video_format='RGB', 66 | name='source' 67 | ) 68 | 69 | detection_pipeline = INFERENCE_PIPELINE( 70 | hef_path=hef_path, 71 | post_process_so=YOLO5_POSTPROCESS_SO, 72 | batch_size=batch_size, 73 | config_json=YOLO5_CONFIG_PATH, 74 | post_function_name=YOLO5_NETWORK_NAME, 75 | scheduler_priority=31, 76 | scheduler_timeout_ms=100, 77 | name='detection_inference' 78 | ) 79 | 80 | if self.detector == "none": 81 | detection_pipeline_wrapper = "" 82 | else: 83 | detection_pipeline_wrapper = INFERENCE_PIPELINE_WRAPPER(detection_pipeline) 84 | 85 | 86 | clip_pipeline = INFERENCE_PIPELINE( 87 | hef_path=clip_hef_path, 88 | post_process_so=clip_postprocess_so, 89 | batch_size=batch_size, 90 | name='clip_inference', 91 | scheduler_timeout_ms=1000, 92 | scheduler_priority=16, 93 | ) 94 | 95 | if self.detector == "person": 96 | class_id = 1 97 | crop_function_name = "person_cropper" 98 | elif self.detector == "face": 99 | class_id = 2 100 | crop_function_name = "face_cropper" 101 | else: # fast_sam 102 | class_id = 0 103 | crop_function_name = "object_cropper" 104 | 105 | tracker_pipeline = TRACKER_PIPELINE(class_id=class_id, keep_past_metadata=True) 106 | 107 | 108 | 109 | # Clip pipeline with cropper integration 110 | clip_cropper_pipeline = CROPPER_PIPELINE( 111 | inner_pipeline=clip_pipeline, 112 | so_path=DEFAULT_CROP_SO, 113 | function_name=crop_function_name, 114 | name='clip_cropper' 115 | ) 116 | 117 | # Clip pipeline with muxer integration (no cropper) 118 | clip_pipeline_wrapper = " ! 
".join([ 119 | "tee name=clip_t hailomuxer name=clip_hmux clip_t.", 120 | str(QUEUE(name="clip_bypass_q", max_size_buffers=20)), 121 | "clip_hmux.sink_0 clip_t.", 122 | str(QUEUE(name="clip_muxer_queue")), 123 | "videoscale n-threads=4 qos=false", 124 | clip_pipeline, 125 | "clip_hmux.sink_1 clip_hmux.", 126 | str(QUEUE(name="clip_hmux_queue")) 127 | ]) 128 | 129 | # TBD aggregator does not support ROI classification 130 | # clip_pipeline_wrapper = INFERENCE_PIPELINE_WRAPPER(clip_pipeline, name='clip') 131 | 132 | display_pipeline = DISPLAY_PIPELINE(video_sink=video_sink, sync=self.sync_req, show_fps=self.show_fps) 133 | 134 | # Text to image matcher 135 | CLIP_PYTHON_MATCHER = f'hailopython name=pyproc module={hailopython_path} qos=false ' 136 | CLIP_CPP_MATCHER = f'hailofilter so-path={clip_matcher_so} qos=false config-path={clip_matcher_config} ' 137 | 138 | clip_postprocess_pipeline = " ! ".join([ 139 | CLIP_PYTHON_MATCHER, 140 | str(QUEUE(name="clip_postprocess_queue")), 141 | "identity name=identity_callback" 142 | ]) 143 | 144 | # PIPELINE 145 | if self.detector == "none": 146 | PIPELINE = " ! ".join([ 147 | source_pipeline, 148 | clip_pipeline_wrapper, 149 | clip_postprocess_pipeline, 150 | display_pipeline 151 | ]) 152 | else: 153 | PIPELINE = " ! ".join([ 154 | source_pipeline, 155 | detection_pipeline_wrapper, 156 | tracker_pipeline, 157 | clip_cropper_pipeline, 158 | clip_postprocess_pipeline, 159 | display_pipeline 160 | ]) 161 | 162 | return PIPELINE 163 | -------------------------------------------------------------------------------- /community_projects/baiby_monitor/src/lullaby_callback.py: -------------------------------------------------------------------------------- 1 | import gi 2 | import hailo 3 | import os 4 | 5 | gi.require_version('Gst', '1.0') 6 | from gi.repository import Gst 7 | from clip_app.text_image_matcher import text_image_matcher 8 | from community_projects.baiby_monitor.src.match_handler import MatchHandler 9 | 10 | 11 | 12 | current_path = os.path.dirname(os.path.realpath(__file__)) 13 | embedding_path = os.path.join(current_path, "..", "embeddings") 14 | json_files = [os.path.join(embedding_path, f) for f in os.listdir(embedding_path) if os.path.isfile(os.path.join(embedding_path, f))] 15 | len_json_files = len(json_files) 16 | 17 | 18 | match_handler = MatchHandler() 19 | 20 | 21 | class app_callback_class: 22 | def __init__(self): 23 | self.frame_count = 0 24 | self.use_frame = False 25 | self.running = True 26 | self.text_image_matcher = text_image_matcher 27 | 28 | def increment(self): 29 | self.frame_count += 1 30 | 31 | def get_count(self): 32 | return self.frame_count 33 | 34 | 35 | def app_callback(self, pad, info, user_data): 36 | """ 37 | This is the callback function that will be called when data is available 38 | from the pipeline. 39 | Processing time should be kept to a minimum in this function. 40 | If longer processing is needed, consider using a separate thread / process. 
41 | """ 42 | # Get the GstBuffer from the probe info 43 | buffer = info.get_buffer() 44 | # Check if the buffer is valid 45 | if buffer is None: 46 | return Gst.PadProbeReturn.OK 47 | string_to_print = "" 48 | # Get the detections from the buffer 49 | roi = hailo.get_roi_from_buffer(buffer) 50 | detections = roi.get_objects_typed(hailo.HAILO_DETECTION) 51 | if len(detections) == 0: 52 | detections = [roi] # Use the ROI as the detection 53 | user_data.increment() 54 | # Switch embeddings every 10 frames 55 | if user_data.get_count() % 10 == 0: 56 | # Load embeddings from the next json file 57 | json_file = json_files[user_data.get_count() // 10 % len_json_files] 58 | # TODO: add logging or remove print 59 | print(f"Loading embeddings from {json_file}") 60 | user_data.text_image_matcher.load_embeddings(json_file) 61 | # Parse the detections 62 | for detection in detections: 63 | track = detection.get_objects_typed(hailo.HAILO_UNIQUE_ID) 64 | track_id = None 65 | label = None 66 | confidence = 0.0 67 | for track_id_obj in track: 68 | track_id = track_id_obj.get_id() 69 | if track_id is not None: 70 | string_to_print += f'Track ID: {track_id} ' 71 | classifications = detection.get_objects_typed(hailo.HAILO_CLASSIFICATION) 72 | if len(classifications) > 0: 73 | string_to_print += ' CLIP Classifications:' 74 | for classification in classifications: 75 | label = classification.get_label() 76 | match_handler.handle(label) 77 | confidence = classification.get_confidence() 78 | string_to_print += f'Label: {label} Confidence: {confidence:.2f} ' 79 | string_to_print += '\n' 80 | if isinstance(detection, hailo.HailoDetection): 81 | label = detection.get_label() 82 | bbox = detection.get_bbox() 83 | confidence = detection.get_confidence() 84 | string_to_print += f"Detection: {label} {confidence:.2f}\n" 85 | # if string_to_print: 86 | # print(string_to_print) 87 | return Gst.PadProbeReturn.OK 88 | -------------------------------------------------------------------------------- /community_projects/baiby_monitor/src/match_handler.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Callable 3 | from community_projects.baiby_monitor.src.play_lullaby import play_mp3 4 | from community_projects.baiby_monitor.src.baiby_telegram import send_telegram_message 5 | 6 | 7 | @dataclass 8 | class DetectionClass: 9 | function: Callable 10 | counter: int = field(default=0, init=False) 11 | argument: str = field(default="") 12 | is_activated: bool = field(default=False, init=False) 13 | 14 | def __post_init__(self): 15 | if self.function is None: 16 | raise ValueError("function must be provided") 17 | 18 | 19 | class MatchHandler: 20 | _instance = None 21 | 22 | BEHAVIOR_DICT = { 23 | # Cry detection 24 | "Calm baby": None, 25 | "Crying baby": DetectionClass(function=send_telegram_message, argument="Baby is crying"), 26 | 27 | # Sleep detection 28 | "awaken baby": DetectionClass(function=play_mp3), 29 | "sleeping baby": None, 30 | } 31 | 32 | def __new__(cls): 33 | if cls._instance is None: 34 | cls._instance = super(MatchHandler, cls).__new__(cls) 35 | return cls._instance 36 | 37 | def handle(self, label: str) -> None: 38 | detection_class = self.BEHAVIOR_DICT.get(label) 39 | if detection_class: 40 | detection_class.counter += 1 41 | if detection_class.counter >= 110 and detection_class.is_activated is False: 42 | print(f"\nDetected {label}\n") 43 | detection_class.function(detection_class.argument) 44 | 
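            # is_activated (set just below) is never cleared, so each behavior fires its
            # action at most once per run; reset the flag if repeated alerts are desired.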
detection_class.is_activated = True 45 | detection_class.counter = 0 46 | -------------------------------------------------------------------------------- /community_projects/baiby_monitor/src/play_lullaby.py: -------------------------------------------------------------------------------- 1 | # When called will activate a lullaby / song / melody 2 | 3 | import os 4 | 5 | from playsound import playsound 6 | 7 | 8 | CURRENT_DIRECTORY = os.path.dirname(os.path.realpath(__file__)) 9 | DEFAULT_MP3_FILE = os.path.join(CURRENT_DIRECTORY, "..", "resources", "brahms-lullaby.mp3") 10 | 11 | def play_mp3(mp3_file_path: str = DEFAULT_MP3_FILE): 12 | try: 13 | # Play the MP3 file 14 | print("Playing the MP3 file...") 15 | # playssound(mp3_file_path) 16 | except Exception as e: 17 | print(f"An error occurred while playing the file: {e}") 18 | 19 | if __name__ == "__main__": 20 | play_mp3() 21 | -------------------------------------------------------------------------------- /community_projects/baiby_monitor/src/telegram.ini.example: -------------------------------------------------------------------------------- 1 | [BOT] 2 | ;if not connecting to existing Bot, go to 3 | ; https://web.telegram.org/k/#@BotFather 4 | ; type /start 5 | ; and then /setname 6 | ; and /newbot 7 | ; you will get your bot token 8 | token = abcd 9 | 10 | [IDs] 11 | ;list of chat_ID, can be obtained by joining the telegram bot @bAIbyMonbot 12 | ;from the URL https://web.telegram.org/k/#@bAIbyMonbot 13 | ;and pressing /start 14 | ;after that you can run 15 | ;python telegra.py --get-chat-ids and get a list of all IDs, add your chat id to 16 | ;this list 17 | list = 1234 18 | 19 | -------------------------------------------------------------------------------- /community_projects/community_projects.md: -------------------------------------------------------------------------------- 1 | # Community Projects 2 | 3 | **Welcome to our Community Projects**! 🎉 4 | Here, you’ll find all our awesome community examples, based on **Hailo** and **Raspberry Pi** and we’d love for you to contribute too! 🚀 5 | Check out how you can get involved and share your own creations here: 6 | [How to Add a Community Project](community_projects.md#how-to-add-a-community-project) 7 | 8 | ## Ad Genie 9 | A system that personalizes ads by matching outfits to individual styles in real time. 10 | 11 | [![Watch the demo on YouTube](https://img.youtube.com/vi/0_v2V7lV514/0.jpg)](https://youtu.be/0_v2V7lV514) 12 | 13 | For more information see [Ad Genie Example Documentation.](ad_genie/README.md) 14 | 15 | 16 | ## BAIby Monitor 17 | 18 | A smart baby monitor detecting cries and activity with real-time alerts. 19 | 20 | [![VIDEO](https://img.youtube.com/vi/sXgL5g_A-u0/0.jpg)](https://youtu.be/sXgL5g_A-u0) 21 | 22 | For more information see [BAIby Monitor Example Documentation.](baiby_monitor/README.md) 23 | 24 | 25 | ## More Projects 26 | 27 | You can explore more community projects in our hailo-rpi5-examples repository - 28 | [hailo-rpi5-examples community projects](https://github.com/hailo-ai/hailo-rpi5-examples/blob/main/community_projects/community_projects.md) 29 | 30 | # How to Add a Community Project 31 | 32 | This guide will walk you through creating and structuring your community project in the repository. 33 | 34 | --- 35 | 36 | ## 1. Run the Clip Appliacation 37 | Ensure your environment is set up properly by following the repository instructions and successfully running the clip application. This step verifies your setup. 38 | 39 | --- 40 | 41 | ## 2. 
Review the Clip Application Documentation 42 | 43 | Familiarize yourself with the Clip Application by reading the documentation [here](../README.md). 44 | 45 | --- 46 | 47 | ## 3. Create Your Project Directory 48 | 49 | Navigate to the `community_projects` folder and create a directory for your project: 50 | 51 | ```bash 52 | cd community_projects 53 | mkdir 54 | cd 55 | ``` 56 | 57 | --- 58 | 59 | ## 4. Copy the Template Example 60 | 61 | Copy the contents of the template example to your project directory: 62 | 63 | ```bash 64 | cp ../template_example/* . 65 | ``` 66 | 67 | --- 68 | 69 | # Using Clip Application as a Template 70 | 71 | The following sections explain how to build your project based on the existing Clip Application 72 | 73 | --- 74 | 75 | ## Modify the Callback Class 76 | 77 | 1. Update the callback class, which inherits from `app_callback_class`. 78 | 2. Add the necessary members and methods to customize your application's behavior. 79 | 80 | --- 81 | 82 | ## Define the Callback Function 83 | 84 | 1. Define a callback function. 85 | 2. This function will handle data from the pipeline and can include your custom logic. 86 | 87 | --- 88 | 89 | ## Modify the Main Function 90 | 91 | 1. Update the `main` function to initialize and run your application. 92 | 93 | This structure ensures your project can run as a standalone script. 94 | 95 | --- 96 | 97 | ## Adding New Networks and Post-Processes 98 | 99 | We are working on a "Community Model Zoo" to allow users to share models and post-processes on our servers. Meanwhile, follow these steps: 100 | 101 | 1. Save your HEF file on a file-sharing service like Google Drive. 102 | 2. Provide a `download_resources.sh` script to automate the download. 103 | 3. If you develop a new post-process: 104 | - Add the necessary code and a compilation script. 105 | - Refer to the [Hailo Apps Infra repository](https://github.com/hailo-ai/hailo-apps-infra) for guidance on creating and compiling new post-processes. 106 | 107 | --- 108 | 109 | ## Update Project Files 110 | See the template example [README.md](./template_example/README.md) for guidance on updating your project files. 111 | 112 | --- 113 | 114 | # Pull Requests (PRs) 115 | 116 | To contribute your project or improvements: 117 | 118 | 1. Submit a PR to the **`dev` branch** of the repository. 119 | 2. Your code should remain within your `community_projects/` directory. **PRs modifying core code will be rejected.** 120 | 3. If you must alter core code for your project: 121 | - Copy the relevant code into your directory and modify it as needed. 122 | - Alternatively, suggest manual edits in your instructions. 123 | - **Be aware**: This approach may cause compatibility issues in future releases due to lack of backward compatibility. Breaking your code. 124 | 125 | Suggestions for improving the core codebase are welcome. However, they must be generic, well-tested, and adaptable to multiple platforms. 126 | If you identify missing functionality in our framework, you are welcome to implement it in your project directory. Exceptional features or common functions might be integrated into our core codebase. 127 | **Important:** Code added to the core must meet these requirements: 128 | - Thoroughly tested and verified. 129 | - Generic and adaptable to multiple platforms. 130 | 131 | By adhering to these guidelines, you help maintain the repository's stability and enable better integration of your work. 
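Putting the callback-class, callback-function and `main` steps above together, a new project can start from a skeleton like the one below (adapted from `community_projects/template_example/template_example.py`; the class and function names are placeholders, and your project logic goes inside the callback):

```python
import gi
gi.require_version('Gst', '1.0')
from gi.repository import Gst
import hailo
from clip_app.clip_app_pipeline import ClipApp
from hailo_apps_infra.gstreamer_app import app_callback_class


class MyProjectCallback(app_callback_class):
    def __init__(self):
        super().__init__()
        # Add project-specific members here (queues, state, resources, ...)


def my_project_callback(self, pad, info, user_data):
    buffer = info.get_buffer()
    if buffer is None:
        return Gst.PadProbeReturn.OK
    roi = hailo.get_roi_from_buffer(buffer)
    detections = roi.get_objects_typed(hailo.HAILO_DETECTION) or [roi]  # fall back to the full frame
    for detection in detections:
        for classification in detection.get_objects_typed(hailo.HAILO_CLASSIFICATION):
            # React to CLIP matches here
            print(classification.get_label(), classification.get_confidence())
    return Gst.PadProbeReturn.OK


if __name__ == "__main__":
    user_data = MyProjectCallback()
    ClipApp(user_data, my_project_callback).run()
```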
132 | 133 | --- 134 | 135 | # Important Guidelines 136 | 137 | ### **Do Not Add Binary Files** 138 | - Avoid adding non-code files (e.g., images, HEFs, videos) directly to the repository. 139 | - Use a `download_resources.sh` script to fetch these files from external sources like Google Drive. 140 | - For Model Zoo HEFs, download them directly from Hailo's servers. 141 | ## Code of Conduct 142 | 143 | We are committed to fostering a welcoming and inclusive community. Please ensure that your contributions adhere to the following guidelines: 144 | 145 | - Use clean and non-offensive language. 146 | - Be respectful and considerate of others. 147 | - Avoid any form of harassment or discrimination. 148 | 149 | --- -------------------------------------------------------------------------------- /community_projects/template_example/README.md: -------------------------------------------------------------------------------- 1 | # Template Example Project 2 | Create a README.md file in your project directory. This file should include: 3 | 4 | ## Overview 5 | A summary of your project. 6 | 7 | ## Video 8 | Add a video link from YouTube with a brief description of what the video demonstrates - for example: 9 | [![Watch the demo on YouTube](https://img.youtube.com/vi/XXizBHtCLew/0.jpg)](https://youtu.be/XXizBHtCLew) 10 | 11 | ## Versions 12 | Specify the versions of Hailo examples you verified the software with. 13 | 14 | ## Setup Instructions 15 | How to install dependencies and run the project - for example: 16 | 17 | Run the following commands: 18 | ```bash 19 | pip install -r requirements.txt 20 | ./download_resources.sh 21 | ``` 22 | 23 | ## Usage 24 | Examples of how to run the script - for example: 25 | 26 | Basic usage: 27 | ```bash 28 | python template_example.py 29 | ``` 30 | 31 | A good README ensures that others can understand and use your project. 32 | If needed, you can add additional sections or links to external documentation. 33 | For example Thingiverse, Instructables, or other resources that can help users understand your project. -------------------------------------------------------------------------------- /community_projects/template_example/downdload_resources.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Instructions: 4 | # 1. This script downloads the specified file from the Hailo Model Zoo. 5 | # 2. The file will be saved into the 'resources' directory. 6 | # 3. Ensure 'wget' is installed on your system. 7 | 8 | # URL of the file to download 9 | FILE_URL="https://hailo-model-zoo.s3.eu-west-2.amazonaws.com/ModelZoo/Compiled/v2.13.0/hailo8/yolov5m_wo_spp.hef" 10 | 11 | # Create resources directory if it doesn't exist 12 | mkdir -p ./resources 13 | 14 | # Function to download the file 15 | download_file() { 16 | URL=$1 17 | FILENAME=$(basename "$URL") 18 | OUTPUT_FILE="./resources/$FILENAME" 19 | 20 | echo "Downloading: $FILENAME" 21 | 22 | # Download the file 23 | wget --quiet --show-progress --no-clobber --directory-prefix=./resources "$URL" || { 24 | echo "Error downloading: $URL" 25 | return 1 26 | } 27 | 28 | echo "Successfully downloaded: $FILENAME" 29 | } 30 | 31 | # Main logic 32 | echo "Starting download..." 33 | 34 | # Download the specified file 35 | download_file "$FILE_URL" 36 | 37 | echo "Download completed." 
-------------------------------------------------------------------------------- /community_projects/template_example/requirments.txt: -------------------------------------------------------------------------------- 1 | gdown -------------------------------------------------------------------------------- /community_projects/template_example/template_example.py: -------------------------------------------------------------------------------- 1 | import gi 2 | gi.require_version('Gst', '1.0') 3 | from gi.repository import Gst 4 | import hailo 5 | from clip_app.clip_app_pipeline import ClipApp 6 | from hailo_apps_infra.gstreamer_app import app_callback_class 7 | 8 | 9 | class user_app_callback_class(app_callback_class): 10 | def __init__(self): 11 | self.frame_count = 0 12 | self.use_frame = False 13 | self.running = True 14 | 15 | def increment(self): 16 | self.frame_count += 1 17 | 18 | def get_count(self): 19 | return self.frame_count 20 | 21 | 22 | def user_app_callback(self, pad, info, user_data): 23 | """ 24 | This is the callback function that will be called when data is available 25 | from the pipeline. 26 | Processing time should be kept to a minimum in this function. 27 | If longer processing is needed, consider using a separate thread / process. 28 | """ 29 | # Get the GstBuffer from the probe info 30 | buffer = info.get_buffer() 31 | # Check if the buffer is valid 32 | if buffer is None: 33 | return Gst.PadProbeReturn.OK 34 | string_to_print = "" 35 | # Get the detections from the buffer 36 | roi = hailo.get_roi_from_buffer(buffer) 37 | detections = roi.get_objects_typed(hailo.HAILO_DETECTION) 38 | if len(detections) == 0: 39 | detections = [roi] # Use the ROI as the detection 40 | # Parse the detections 41 | for detection in detections: 42 | track = detection.get_objects_typed(hailo.HAILO_UNIQUE_ID) 43 | track_id = None 44 | label = None 45 | confidence = 0.0 46 | for track_id_obj in track: 47 | track_id = track_id_obj.get_id() 48 | if track_id is not None: 49 | string_to_print += f'Track ID: {track_id} ' 50 | classifications = detection.get_objects_typed(hailo.HAILO_CLASSIFICATION) 51 | if len(classifications) > 0: 52 | string_to_print += ' CLIP Classifications:' 53 | for classification in classifications: 54 | label = classification.get_label() 55 | confidence = classification.get_confidence() 56 | string_to_print += f'Label: {label} Confidence: {confidence:.2f} ' 57 | string_to_print += '\n' 58 | if isinstance(detection, hailo.HailoDetection): 59 | label = detection.get_label() 60 | bbox = detection.get_bbox() 61 | confidence = detection.get_confidence() 62 | string_to_print += f"Detection: {label} {confidence:.2f}\n" 63 | if string_to_print: 64 | print(string_to_print) 65 | return Gst.PadProbeReturn.OK 66 | 67 | def main(): 68 | user_data = user_app_callback_class() 69 | clip = ClipApp(user_data, user_app_callback) 70 | clip.run() 71 | 72 | if __name__ == "__main__": 73 | main() 74 | -------------------------------------------------------------------------------- /compile_postprocess.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Set the project directory name 4 | PROJECT_DIR="." 
5 | 6 | # Get the build mode from the command line (default to release) 7 | if [ "$1" = "debug" ]; then 8 | BUILD_MODE="debug" 9 | else 10 | BUILD_MODE="release" 11 | fi 12 | 13 | # Create the build directory 14 | BUILD_DIR="$PROJECT_DIR/build.$BUILD_MODE" 15 | mkdir -p $BUILD_DIR 16 | cd $BUILD_DIR 17 | 18 | # Configure the project with Meson 19 | meson setup .. --buildtype=$BUILD_MODE 20 | 21 | # Compile the project 22 | ninja 23 | 24 | # Install the project (optional) 25 | ninja install 26 | -------------------------------------------------------------------------------- /cpp/TextImageMatcher.cpp: -------------------------------------------------------------------------------- 1 | #include "TextImageMatcher.hpp" 2 | // Define static members 3 | TextImageMatcher* TextImageMatcher::instance = nullptr; 4 | std::mutex TextImageMatcher::mutex; -------------------------------------------------------------------------------- /cpp/TextImageMatcher.hpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | #ifndef TEXTIMAGEMATCHER_H 20 | #define TEXTIMAGEMATCHER_H 21 | 22 | // Usage: 23 | // To use the singleton instance of TextImageMatcher, you would call: 24 | // auto matcher = TextImageMatcher::getInstance("", 0.5f, 5) 25 | 26 | class TextEmbeddingEntry { 27 | public: 28 | std::string text; 29 | xt::xarray embedding; // Use xtensor array for embedding 30 | bool negative; 31 | bool ensemble; 32 | double probability; // Use double for probability 33 | 34 | TextEmbeddingEntry(std::string txt, const std::vector& emb, bool neg, bool ens) 35 | : text(txt), negative(neg), ensemble(ens), probability(0.0) { 36 | // Adapt the std::vector to xt::xarray 37 | embedding = xt::adapt(emb, {emb.size()}); 38 | } 39 | }; 40 | 41 | class Match { 42 | 43 | public: 44 | int row_idx; 45 | std::string text; 46 | double similarity; // Use double for similarity 47 | int entry_index; 48 | bool negative; 49 | bool passed_threshold; 50 | 51 | Match(int r_idx, std::string txt, double sim, int e_idx, bool neg, bool passed) 52 | : row_idx(r_idx), text(txt), similarity(sim), entry_index(e_idx), negative(neg), passed_threshold(passed) {} 53 | }; 54 | 55 | class TextImageMatcher { 56 | public: 57 | std::string model_name; 58 | double threshold; // Use double for threshold 59 | int max_entries; 60 | bool run_softmax = true; 61 | std::vector entries; 62 | std::string user_data = ""; 63 | std::string text_prefix = "A photo of a "; 64 | 65 | private: 66 | // Singleton instance 67 | static TextImageMatcher* instance; 68 | static std::mutex mutex; 69 | 70 | // Prevent Copy Construction and Assignment 71 | TextImageMatcher(const TextImageMatcher&) = delete; 72 | TextImageMatcher& operator=(const TextImageMatcher&) = delete; 73 | 74 | // Private Constructor 75 | TextImageMatcher(std::string m_name, double thresh, int max_ents) 76 | : model_name(m_name), threshold(thresh), max_entries(max_ents) { 77 | // Initialize entries with default TextEmbeddingEntry 78 | for (int i = 0; i < max_entries; ++i) { 79 | entries.push_back(TextEmbeddingEntry("", std::vector(), false, false)); 80 | } 81 | } 82 | std::atomic m_debug;//When set outputs all matches overrides match(report_all = false) 83 | 84 | public: 85 | // Public Method to get the singleton instance 86 | static 
TextImageMatcher* getInstance(std::string model_name, float threshold, int max_entries) { 87 | std::lock_guard lock(mutex); // Thread-safe in a multi-threaded environment 88 | if (instance == nullptr) { 89 | instance = new TextImageMatcher(model_name, threshold, max_entries); 90 | } 91 | return instance; 92 | } 93 | 94 | // Destructor 95 | ~TextImageMatcher() { 96 | // Cleanup code 97 | } 98 | 99 | void set_threshold(double new_threshold) { 100 | threshold = new_threshold; 101 | } 102 | 103 | void set_text_prefix(std::string new_text_prefix) { 104 | text_prefix = new_text_prefix; 105 | } 106 | 107 | std::vector get_embeddings() { 108 | std::vector valid_entries; 109 | for (size_t i = 0; i < entries.size(); i++) { 110 | if (!entries[i].text.empty()) { 111 | valid_entries.push_back(i); 112 | } 113 | } 114 | return valid_entries; 115 | } 116 | 117 | void load_embeddings(std::string filename) { 118 | if (!std::filesystem::exists(filename)) { 119 | std::ofstream file(filename); 120 | file.close(); 121 | std::cout << "File " << filename << " does not exist, creating it." << std::endl; 122 | } else { 123 | try { 124 | std::ifstream f(filename); 125 | nlohmann::json data; 126 | f >> data; 127 | 128 | threshold = data["threshold"].get(); 129 | text_prefix = data["text_prefix"].get(); 130 | 131 | entries.clear(); 132 | for (size_t i = 0; i < data["entries"].size(); i++) { 133 | std::string text = data["entries"][i]["text"]; 134 | std::vector embedding = data["entries"][i]["embedding"].get>(); 135 | bool negative = data["entries"][i]["negative"]; 136 | bool ensemble = data["entries"][i]["ensemble"]; 137 | entries.push_back(TextEmbeddingEntry(text, embedding, negative, ensemble)); 138 | } 139 | } catch (const std::exception& e) { 140 | std::cout << "Error while loading file " << filename << ": " << e.what() << ". Maybe you forgot to save your embeddings?" 
<< std::endl; 141 | } 142 | } 143 | } 144 | void set_debug(bool debug) { 145 | m_debug.store(debug); 146 | std::cout << "Setting debug to: " << m_debug.load() << std::endl; 147 | } 148 | 149 | std::vector match(const xt::xarray& image_embedding_np, bool report_all = false) { 150 | 151 | bool report_all_debug = report_all || m_debug.load(); 152 | 153 | std::vector results; 154 | // Ensure the input is a 2D array 155 | xt::xarray image_embedding = image_embedding_np; 156 | if (image_embedding.dimension() == 1) { 157 | image_embedding = image_embedding.reshape({1, -1}); 158 | } 159 | // Getting valid entries 160 | std::vector valid_entries = get_embeddings(); 161 | if (valid_entries.empty()) { 162 | return results; // Return an empty list if no valid entries 163 | } 164 | 165 | std::vector> to_stack; 166 | to_stack.reserve(valid_entries.size()); // Reserve memory in advance 167 | 168 | xt::xarray text_embeddings_np; // Declare text_embeddings_np outside the if block 169 | 170 | if (!valid_entries.empty()) { 171 | for (size_t entry_idx : valid_entries) { 172 | if (entry_idx < entries.size()) { 173 | to_stack.push_back(entries[entry_idx].embedding); 174 | } 175 | } 176 | 177 | if (!to_stack.empty()) { 178 | // Initialize text_embeddings_np with the correct shape 179 | text_embeddings_np.resize({to_stack.size(), to_stack.front().size()}); 180 | for (size_t i = 0; i < to_stack.size(); ++i) { 181 | xt::view(text_embeddings_np, i, xt::all()) = to_stack[i]; 182 | } 183 | } 184 | } 185 | 186 | // Looping through each image embedding 187 | for (std::size_t row_idx = 0; row_idx < image_embedding.shape()[0]; ++row_idx) { 188 | auto image_embedding_1d = xt::view(image_embedding, row_idx); 189 | xt::xarray dot_products = xt::linalg::dot(text_embeddings_np, image_embedding_1d); 190 | xt::xarray similarities; 191 | 192 | if (run_softmax) { 193 | similarities = xt::exp(100 * dot_products); 194 | double sum = xt::sum(similarities)(); 195 | if (sum != 0) { 196 | similarities /= sum; 197 | } else { 198 | similarities = xt::zeros({dot_products.size()}); 199 | } 200 | } else { 201 | // These values are based on statistics collected for the RN50x4 model 202 | similarities = (dot_products - 0.27) / (0.41 - 0.27); 203 | similarities = xt::clip(similarities, 0, 1); 204 | } 205 | 206 | int best_idx = xt::argmax(similarities)(); 207 | double best_similarity = similarities[best_idx]; 208 | 209 | // Updating probabilities in entries 210 | for (size_t i = 0; i < similarities.size(); i++) { 211 | entries[valid_entries[i]].probability = similarities[i]; 212 | } 213 | 214 | // Creating a new match object 215 | Match new_match(row_idx, 216 | entries[valid_entries[best_idx]].text, 217 | best_similarity, 218 | valid_entries[best_idx], 219 | entries[valid_entries[best_idx]].negative, 220 | best_similarity > threshold); 221 | 222 | // Filtering results based on conditions 223 | if (!report_all_debug && new_match.negative) { 224 | continue; 225 | } 226 | if (report_all_debug || new_match.passed_threshold) { 227 | results.push_back(new_match); 228 | } 229 | } 230 | 231 | // print results 232 | // std::cout << "Best match output: ["; 233 | // for (const auto& match : results) { 234 | // std::cout << match.text << ", "; 235 | // } 236 | // std::cout << "]" << std::endl; 237 | 238 | return results; 239 | } 240 | }; 241 | 242 | #endif // TEXTIMAGEMATCHER_H 243 | -------------------------------------------------------------------------------- /cpp/clip.cpp: -------------------------------------------------------------------------------- 
1 | /** 2 | * Copyright (c) 2021-2022 Hailo Technologies Ltd. All rights reserved. 3 | * Distributed under the LGPL license (https://www.gnu.org/licenses/old-licenses/lgpl-2.1.txt) 4 | **/ 5 | #include 6 | #include "common/tensors.hpp" 7 | #include "common/math.hpp" 8 | #include "clip.hpp" 9 | #include "hailo_tracker.hpp" 10 | #include "hailo_xtensor.hpp" 11 | #include "xtensor/xadapt.hpp" 12 | #include "xtensor/xarray.hpp" 13 | 14 | #define OUTPUT_LAYER_NAME "clip_resnet_50x4/conv89" 15 | 16 | 17 | ClipParams *init(std::string config_path, std::string func_name) 18 | { 19 | if (config_path == "NULL") 20 | { 21 | config_path = "hailo_tracker"; 22 | } 23 | ClipParams *params = new ClipParams(config_path); 24 | return params; 25 | } 26 | 27 | void clip(HailoROIPtr roi, std::string layer_name, std::string tracker_name) 28 | { 29 | if (!roi->has_tensors()) 30 | { 31 | return; 32 | } 33 | // Remove previous matrices 34 | roi->remove_objects_typed(HAILO_MATRIX); 35 | 36 | auto tensor = roi->get_tensor(layer_name); 37 | xt::xarray embeddings = common::get_xtensor_float(tensor); 38 | 39 | // vector normalization 40 | auto normalized_embedding = common::vector_normalization(embeddings); 41 | HailoMatrixPtr hailo_matrix = hailo_common::create_matrix_ptr(normalized_embedding); 42 | roi->add_object(hailo_matrix); 43 | } 44 | 45 | void filter(HailoROIPtr roi, void *params_void_ptr) 46 | { 47 | ClipParams *params = reinterpret_cast(params_void_ptr); 48 | clip(roi, OUTPUT_LAYER_NAME, params->tracker_name); 49 | } 50 | -------------------------------------------------------------------------------- /cpp/clip.hpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2021-2022 Hailo Technologies Ltd. All rights reserved. 3 | * Distributed under the LGPL license (https://www.gnu.org/licenses/old-licenses/lgpl-2.1.txt) 4 | **/ 5 | #pragma once 6 | #include "hailo_objects.hpp" 7 | #include "hailo_common.hpp" 8 | 9 | __BEGIN_DECLS 10 | class ClipParams 11 | { 12 | public: 13 | std::string tracker_name; // Should have the same name as the relevant hailo_tracker 14 | ClipParams(std::string tracker_name) : tracker_name(tracker_name) {} 15 | }; 16 | ClipParams *init(std::string config_path, std::string func_name); 17 | 18 | void filter(HailoROIPtr roi, void *params_void_ptr); 19 | __END_DECLS -------------------------------------------------------------------------------- /cpp/clip_croppers.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2021-2022 Hailo Technologies Ltd. All rights reserved. 3 | * Distributed under the LGPL license (https://www.gnu.org/licenses/old-licenses/lgpl-2.1.txt) 4 | **/ 5 | 6 | // This file is a modified version of tappas/core/hailo/libs/croppers/vms/vms_croppers.cpp 7 | 8 | #include 9 | #include 10 | 11 | #include "clip_croppers.hpp" 12 | 13 | #define PERSON_LABEL "person" 14 | #define FACE_LABEL "face" 15 | #define OBJECT_LABEL "object" 16 | 17 | /** 18 | * @brief Get the tracking Hailo Unique Id object from a Hailo Detection. 
19 | * 20 | * @param detection HailoDetectionPtr 21 | * @return HailoUniqueIdPtr pointer to the Hailo Unique Id object 22 | */ 23 | HailoUniqueIDPtr get_tracking_id(HailoDetectionPtr detection) 24 | { 25 | for (auto obj : detection->get_objects_typed(HAILO_UNIQUE_ID)) 26 | { 27 | HailoUniqueIDPtr id = std::dynamic_pointer_cast(obj); 28 | if (id->get_mode() == TRACKING_ID) 29 | { 30 | return id; 31 | } 32 | } 33 | return nullptr; 34 | } 35 | 36 | std::map track_counter; 37 | std::deque order; 38 | long unsigned int max_tracked_objects = 100; 39 | 40 | /** 41 | * @brief Returns a boolean indicating if tracker update is required for a given detection. 42 | * It is determined by the number of frames since the last update. 43 | * How many frames to wait for an update are defined in TRACK_UPDATE. 44 | * 45 | * @param detection HailoDetectionPtr 46 | * @param use_track_update boolean can override the default behaviour, false will always require an update 47 | * @return boolean indicating if tracker update is required. 48 | */ 49 | bool track_update(HailoDetectionPtr detection, bool use_track_update, int TRACK_UPDATE=15) 50 | { 51 | auto tracking_obj = get_tracking_id(detection); 52 | if (!tracking_obj) 53 | { 54 | // No tracking object found - track update required. 55 | return false; 56 | } 57 | if (use_track_update) 58 | { 59 | int tracking_id = tracking_obj->get_id(); 60 | auto counter = track_counter.find(tracking_id); 61 | if (counter == track_counter.end()) 62 | { 63 | // Emplace new element to the track_counter map. track update required. 64 | if (track_counter.size() >= max_tracked_objects) 65 | { 66 | // Remove the oldest entry 67 | int oldest_id = order.front(); 68 | order.pop_front(); 69 | track_counter.erase(oldest_id); 70 | } 71 | track_counter.emplace(tracking_id, 0); 72 | order.push_back(tracking_id); 73 | return true; 74 | } 75 | else if (counter->second >= TRACK_UPDATE) 76 | { 77 | // Counter passed the TRACK_UPDATE limit - set existing track to 0. track update required. 78 | track_counter[tracking_id] = 0; 79 | return true; 80 | } 81 | else if (counter->second < TRACK_UPDATE) 82 | { 83 | // Counter is still below TRACK_UPDATE_LIMIT - increasing the existing value. track update should be skipped. 84 | track_counter[tracking_id] += 1; 85 | } 86 | return false; 87 | } 88 | // Use track update is false - track update required. 89 | return true; 90 | } 91 | 92 | /** 93 | * @brief Returns a vector of detections to crop and resize. 94 | * 95 | * @param image The original picture (cv::Mat). 96 | * @param roi The main ROI of this picture. 97 | * @param label The label to crop. 98 | * @param crop_every_x_frames Run crop every X frames per tracked object. 99 | * @param max_crops_per_frame Max number of objects to crop per frame. 100 | * @return std::vector vector of ROI's to crop and resize. 101 | */ 102 | 103 | std::vector object_crop(const std::shared_ptr& image, const HailoROIPtr& roi, const std::string label=PERSON_LABEL, 104 | int crop_every_x_frames=30, int max_crops_per_frame=5) 105 | { 106 | auto object_counter = 0; 107 | std::vector crop_rois; 108 | 109 | std::vector detections_ptrs = hailo_common::get_hailo_detections(roi); 110 | std::vector detections_to_crop; 111 | 112 | for (HailoDetectionPtr &detection : detections_ptrs) 113 | { 114 | auto detection_label = detection->get_label(); 115 | if (label != detection->get_label()) 116 | { 117 | // Not the label we are looking for. 
118 |             continue;
119 |         }
120 |         auto tracking_obj = get_tracking_id(detection);
121 |         if (!tracking_obj)
122 |         {
123 |             // Object is not tracked; don't crop it.
124 |             continue;
125 |         }
126 |         if (track_update(detection, true, crop_every_x_frames))
127 |         {
128 |             detections_to_crop.emplace_back(detection);
129 |             object_counter += 1;
130 |             if (object_counter >= max_crops_per_frame)
131 |             {
132 |                 break;
133 |             }
134 | 
135 |         }
136 |     }
137 | 
138 |     for (HailoDetectionPtr &detection : detections_to_crop)
139 |     {
140 |         crop_rois.emplace_back(detection);
141 |     }
142 |     return crop_rois;
143 | }
144 | 
145 | std::vector<HailoROIPtr> face_cropper(std::shared_ptr<HailoMat> image, HailoROIPtr roi)
146 | {
147 |     return object_crop(image, roi, FACE_LABEL, 15, 8);
148 | }
149 | 
150 | std::vector<HailoROIPtr> person_cropper(std::shared_ptr<HailoMat> image, HailoROIPtr roi)
151 | {
152 |     return object_crop(image, roi, PERSON_LABEL, 15, 8);
153 | }
154 | 
155 | std::vector<HailoROIPtr> object_cropper(std::shared_ptr<HailoMat> image, HailoROIPtr roi)
156 | {
157 |     return object_crop(image, roi, OBJECT_LABEL, 15, 8);
158 | }
159 | 
--------------------------------------------------------------------------------
/cpp/clip_croppers.hpp:
--------------------------------------------------------------------------------
1 | /**
2 |  * Copyright (c) 2021-2022 Hailo Technologies Ltd. All rights reserved.
3 |  * Distributed under the LGPL license (https://www.gnu.org/licenses/old-licenses/lgpl-2.1.txt)
4 |  **/
5 | #pragma once
6 | #include <vector>
7 | #include <memory>
8 | #include "hailo_objects.hpp"
9 | #include "hailo_common.hpp"
10 | #include "hailomat.hpp"
11 | 
12 | __BEGIN_DECLS
13 | std::vector<HailoROIPtr> person_cropper(std::shared_ptr<HailoMat> mat, HailoROIPtr roi);
14 | std::vector<HailoROIPtr> face_cropper(std::shared_ptr<HailoMat> image, HailoROIPtr roi);
15 | std::vector<HailoROIPtr> object_cropper(std::shared_ptr<HailoMat> image, HailoROIPtr roi);
16 | 
17 | __END_DECLS
--------------------------------------------------------------------------------
/cpp/clip_croppers_new.cpp:
--------------------------------------------------------------------------------
1 | /**
2 |  * Copyright (c) 2021-2022 Hailo Technologies Ltd. All rights reserved.
3 |  * Distributed under the LGPL license (https://www.gnu.org/licenses/old-licenses/lgpl-2.1.txt)
4 |  **/
5 | // Note: this implementation uses the HailoUserMeta object to store the crop aging value.
6 | // There is an issue with the destructor of this object that causes a segmentation fault when the object is destroyed.
7 | 
8 | // This file is a modified version of tappas/core/hailo/libs/croppers/vms/vms_croppers.cpp
9 | 
10 | #include <algorithm>
11 | #include <vector>
12 | #include "clip_croppers.hpp"
13 | 
14 | #define PERSON_LABEL "person"
15 | #define FACE_LABEL "face"
16 | #define OBJECT_LABEL "object"
17 | 
18 | /**
19 |  * @brief Get the tracking Hailo Unique Id object from a Hailo Detection.
20 |  *
21 |  * @param detection HailoDetectionPtr
22 |  * @return HailoUniqueIdPtr pointer to the Hailo Unique Id object
23 |  */
24 | HailoUniqueIDPtr get_tracking_id(HailoDetectionPtr detection)
25 | {
26 |     for (auto obj : detection->get_objects_typed(HAILO_UNIQUE_ID))
27 |     {
28 |         HailoUniqueIDPtr id = std::dynamic_pointer_cast<HailoUniqueID>(obj);
29 |         if (id->get_mode() == TRACKING_ID)
30 |         {
31 |             return id;
32 |         }
33 |     }
34 |     return nullptr;
35 | }
36 | 
37 | /**
38 |  * @brief Reset crop aging; if not found, create it.
39 |  *
40 |  * @param detection HailoDetectionPtr
41 |  * @return none
42 |  */
43 | 
44 | void reset_crop_aging(HailoDetectionPtr detection)
45 | {
46 |     for (auto obj : detection->get_objects_typed(HAILO_USER_META))
47 |     {
48 |         return; // NOTE: early return skips the reset logic below; the user-meta path is effectively disabled (see the destructor issue noted in the file header).
49 |         HailoUserMetaPtr meta = std::dynamic_pointer_cast<HailoUserMeta>(obj);
50 |         if (meta->get_user_string() == "CROP_AGING")
51 |         {
52 |             meta->set_user_int(0);
53 |             return;
54 |         }
55 |     }
56 |     // If we got here, it means the crop aging meta was not found. Create it.
57 |     HailoUserMetaPtr meta = std::make_shared<HailoUserMeta>();
58 |     meta->set_user_string("CROP_AGING");
59 |     meta->set_user_int(0);
60 |     detection->add_object(meta);
61 |     return;
62 | }
63 | 
64 | /**
65 |  * @brief Get and increase the crop aging meta from a Hailo Detection.
66 |  *
67 |  * @param detection HailoDetectionPtr
68 |  * @param increase boolean indicating if the crop aging should be increased
69 |  * @return int crop aging value
70 |  */
71 | int get_crop_aging(HailoDetectionPtr detection, bool increase=false)
72 | {
73 |     return 30; // NOTE: hard-coded early return bypasses the user-meta lookup below, so every tracked object always appears due for cropping (see the destructor issue noted in the file header).
74 |     for (auto obj : detection->get_objects_typed(HAILO_USER_META))
75 |     {
76 |         HailoUserMetaPtr meta = std::dynamic_pointer_cast<HailoUserMeta>(obj);
77 |         if (meta->get_user_string() == "CROP_AGING")
78 |         {
79 |             int crop_aging = meta->get_user_int();
80 |             if (increase)
81 |             {
82 |                 crop_aging += 1;
83 |                 meta->set_user_int(crop_aging);
84 |             }
85 |             return crop_aging;
86 |         }
87 |     }
88 |     // If we got here, it means the crop aging meta was not found. Create it.
89 |     reset_crop_aging(detection);
90 |     return 0;
91 | }
92 | 
93 | /**
94 |  * @brief Returns a vector of detections to crop and resize.
95 |  *
96 |  * @param image The original picture (cv::Mat).
97 |  * @param roi The main ROI of this picture.
98 |  * @param label The label to crop.
99 |  * @param crop_every_x_frames Run crop every X frames per tracked object.
100 |  * @param max_crops_per_frame Max number of objects to crop per frame.
101 |  * @return std::vector<HailoROIPtr> vector of ROIs to crop and resize.
102 |  */
103 | 
104 | std::vector<HailoROIPtr> object_crop(const std::shared_ptr<HailoMat>& image, const HailoROIPtr& roi, const std::string label=PERSON_LABEL,
105 |                                      int crop_every_x_frames=30, int max_crops_per_frame=2)
106 | {
107 |     auto object_counter = 0;
108 |     std::vector<HailoROIPtr> crop_rois;
109 | 
110 |     std::vector<HailoDetectionPtr> detections_ptrs = hailo_common::get_hailo_detections(roi);
111 |     std::vector<HailoDetectionPtr> detections_to_crop;
112 | 
113 |     for (HailoDetectionPtr &detection : detections_ptrs)
114 |     {
115 |         auto detection_label = detection->get_label();
116 |         if (label != detection->get_label())
117 |         {
118 |             // Not the label we are looking for.
119 |             continue;
120 |         }
121 |         auto tracking_obj = get_tracking_id(detection);
122 |         if (!tracking_obj)
123 |         {
124 |             // Object is not tracked; don't crop it.
125 |             continue;
126 |         }
127 |         if (get_crop_aging(detection, true) < crop_every_x_frames) // also increases crop aging
128 |         {
129 |             // Crop aging is below the crop_every_x_frames limit; don't crop it.
130 |             continue;
131 |         }
132 |         detections_to_crop.emplace_back(detection);
133 |     }
134 |     // Sort detections by crop_aging in descending order.
135 | std::sort(detections_to_crop.begin(), detections_to_crop.end(), [](HailoDetectionPtr a, HailoDetectionPtr b) { 136 | return get_crop_aging(a) > get_crop_aging(b); 137 | }); 138 | 139 | for (HailoDetectionPtr &detection : detections_to_crop) 140 | { 141 | crop_rois.emplace_back(detection); 142 | // printf("cropping id %d aging %d\n", get_tracking_id(detection)->get_id(), get_crop_aging(detection)); 143 | reset_crop_aging(detection); 144 | object_counter += 1; 145 | if (object_counter >= max_crops_per_frame) 146 | { 147 | break; 148 | } 149 | } 150 | return crop_rois; 151 | } 152 | 153 | std::vector face_cropper(std::shared_ptr image, HailoROIPtr roi) 154 | { 155 | return object_crop(image, roi, FACE_LABEL, 10, 2); 156 | } 157 | 158 | std::vector person_cropper(std::shared_ptr image, HailoROIPtr roi) 159 | { 160 | return object_crop(image, roi, PERSON_LABEL, 10, 2); 161 | } 162 | 163 | std::vector object_cropper(std::shared_ptr image, HailoROIPtr roi) 164 | { 165 | return object_crop(image, roi, OBJECT_LABEL, 30, 1); 166 | } 167 | -------------------------------------------------------------------------------- /cpp/clip_matcher.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2021-2022 Hailo Technologies Ltd. All rights reserved. 3 | * Distributed under the LGPL license (https://www.gnu.org/licenses/old-licenses/lgpl-2.1.txt) 4 | **/ 5 | #include 6 | #include "common/tensors.hpp" 7 | #include "common/math.hpp" 8 | #include "hailo_tracker.hpp" 9 | #include "hailo_xtensor.hpp" 10 | #include "hailo_objects.hpp" 11 | #include "hailo_common.hpp" 12 | #include "xtensor/xadapt.hpp" 13 | #include "xtensor/xarray.hpp" 14 | 15 | #include "clip_matcher.hpp" 16 | #include "TextImageMatcher.hpp" 17 | TextImageMatcher* matcher = TextImageMatcher::getInstance("", 0.8f, 6); 18 | 19 | static xt::xarray get_xtensor(HailoMatrixPtr matrix) 20 | { 21 | // Adapt a HailoTensorPtr to an xarray (quantized) 22 | xt::xarray xtensor = xt::adapt(matrix->get_data().data(), matrix->size(), xt::no_ownership(), matrix->shape()); 23 | // remove (squeeze) xtensor first dim 24 | xtensor = xt::squeeze(xtensor, 0); 25 | return xtensor; 26 | } 27 | 28 | 29 | void* init(std::string config_path, std::string func_name) 30 | { 31 | if (config_path == "NULL") 32 | { 33 | std::cout << "No default JSON provided" << std::endl; 34 | } 35 | else 36 | { 37 | matcher->load_embeddings(config_path); 38 | matcher->run_softmax = false; 39 | } 40 | return nullptr; 41 | } 42 | 43 | void update_config(std::string config_path) 44 | { 45 | if (config_path == "NULL") 46 | { 47 | std::cout << "No default JSON provided" << std::endl; 48 | } 49 | else 50 | { 51 | matcher->load_embeddings(config_path); 52 | } 53 | return; 54 | } 55 | 56 | void filter(HailoROIPtr roi) 57 | { 58 | // define 2D array for image embedding 59 | xt::xarray image_embedding; 60 | 61 | // vector to hold detections 62 | std::vector detections_ptrs; 63 | 64 | // vector to hold used detections 65 | // std::vector used_detections; 66 | std::vector used_detections; 67 | 68 | // Check if roi is used for clip 69 | auto roi_matrixs = roi->get_objects_typed(HAILO_MATRIX); 70 | if (!roi_matrixs.empty()) 71 | { 72 | HailoMatrixPtr matrix_ptr = std::dynamic_pointer_cast(roi_matrixs[0]); 73 | xt::xarray embeddings = get_xtensor(matrix_ptr); 74 | image_embedding = embeddings; 75 | used_detections.push_back(roi); 76 | } 77 | else 78 | { 79 | // Get detections from roi 80 | detections_ptrs = 
hailo_common::get_hailo_detections(roi); 81 | 82 | for (HailoDetectionPtr &detection : detections_ptrs) 83 | { 84 | auto matrix_objs = detection->get_objects_typed(HAILO_MATRIX); 85 | for (auto matrix : matrix_objs) 86 | { 87 | HailoMatrixPtr matrix_ptr = std::dynamic_pointer_cast(matrix); 88 | xt::xarray embeddings = get_xtensor(matrix_ptr); 89 | // if image_embedding is empty or 0-dimensional, initialize it with embeddings 90 | if (image_embedding.size() == 0 || image_embedding.dimension() == 0) 91 | { 92 | image_embedding = embeddings; 93 | } 94 | else 95 | { 96 | // if image_embedding is not empty and not 0-dimensional, concatenate it with embeddings 97 | image_embedding = xt::concatenate(xt::xtuple(image_embedding, embeddings), 0); 98 | } 99 | used_detections.push_back(detection); 100 | } 101 | } 102 | } 103 | // if image_embedding is empty, return 104 | if (image_embedding.size() == 0 || image_embedding.dimension() == 0) 105 | { 106 | return; 107 | } 108 | std::vector matches = matcher->match(image_embedding); 109 | for (auto &match : matches) 110 | { 111 | auto detection = used_detections[match.row_idx]; 112 | auto old_classifications = hailo_common::get_hailo_classifications(detection); 113 | for (auto old_classification : old_classifications) 114 | { 115 | if (old_classification->get_classification_type() == "clip") 116 | detection->remove_object(old_classification); 117 | } 118 | if (match.negative || !match.passed_threshold) 119 | { 120 | continue; 121 | } 122 | HailoClassificationPtr classification = std::make_shared(std::string("clip"), match.text, match.similarity); 123 | detection->add_object(classification); 124 | } 125 | } 126 | 127 | void run(HailoROIPtr roi) 128 | { 129 | filter(roi); 130 | } 131 | -------------------------------------------------------------------------------- /cpp/clip_matcher.hpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2021-2022 Hailo Technologies Ltd. All rights reserved. 
3 | * Distributed under the LGPL license (https://www.gnu.org/licenses/old-licenses/lgpl-2.1.txt) 4 | **/ 5 | 6 | #pragma once 7 | #include "hailo_objects.hpp" 8 | #include "hailo_common.hpp" 9 | 10 | __BEGIN_DECLS 11 | void *init(std::string config_path, std::string func_name); 12 | void filter(HailoROIPtr roi); 13 | void run(HailoROIPtr roi); 14 | void update_config(std::string config_path); 15 | __END_DECLS 16 | -------------------------------------------------------------------------------- /cpp/meson.build: -------------------------------------------------------------------------------- 1 | 2 | ################################################ 3 | # CLIP SOURCES 4 | ################################################ 5 | clip_sources = [ 6 | 'clip.cpp', 7 | ] 8 | shared_library('clip_post', 9 | clip_sources, 10 | dependencies : postprocess_dep, 11 | gnu_symbol_visibility : 'default', 12 | install: true, 13 | install_dir: join_paths(meson.project_source_root(), 'resources'), 14 | ) 15 | 16 | ################################################ 17 | # clip_cropper SOURCES 18 | ################################################ 19 | clip_croppers_sources = [ 20 | 'clip_croppers.cpp', 21 | ] 22 | shared_library('clip_croppers', 23 | clip_croppers_sources, 24 | dependencies : postprocess_dep, 25 | gnu_symbol_visibility : 'default', 26 | install: true, 27 | install_dir: join_paths(meson.project_source_root(), 'resources'), 28 | ) 29 | 30 | ################################################ 31 | # clip_matcher SOURCES 32 | ################################################ 33 | 34 | # sudo apt-get install libblas-dev liblapack-dev 35 | # to find blas.pc 36 | # find /usr -name '*blas*.pc' 37 | 38 | cblas_dep = dependency('blas') 39 | 40 | clip_matcher_sources = [ 41 | 'clip_matcher.cpp','TextImageMatcher.cpp', 42 | ] 43 | shared_library('clip_matcher', 44 | clip_matcher_sources, 45 | dependencies : [postprocess_dep, cblas_dep], 46 | gnu_symbol_visibility : 'default', 47 | install: true, 48 | install_dir: join_paths(meson.project_source_root(), 'resources'), 49 | ) 50 | 51 | -------------------------------------------------------------------------------- /download_resources.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e # Exit immediately if a command exits with a non-zero status 4 | 5 | # Set the resource directory 6 | RESOURCE_DIR="./resources" 7 | mkdir -p "$RESOURCE_DIR" 8 | 9 | # Define download function with file existence check and retries 10 | download_model() { 11 | local url=$1 12 | local file_name=$(basename "$url") 13 | 14 | # Check if the file is for H8L and rename it accordingly 15 | if [[ ( "$url" == *"hailo8l"* || "$url" == *"h8l_rpi"* ) && ( "$url" != *"barcode"* ) ]]; then 16 | file_name="${file_name%.hef}_h8l.hef" 17 | fi 18 | 19 | local file_path="$RESOURCE_DIR/$file_name" 20 | 21 | if [ ! -f "$file_path" ]; then 22 | echo "Downloading $file_name..." 23 | wget -q --show-progress "$url" -O "$file_path" || { 24 | echo "Failed to download $file_name after multiple attempts." 25 | exit 1 26 | } 27 | else 28 | echo "File $file_name already exists in $RESOURCE_DIR. Skipping download." 
29 | fi 30 | } 31 | 32 | # Define all URLs in arrays 33 | H8_HEFS=( 34 | "https://hailo-model-zoo.s3.eu-west-2.amazonaws.com/ModelZoo/Compiled/v2.9.0/clip_resnet_50x4.hef" 35 | "https://hailo-tappas.s3.eu-west-2.amazonaws.com/v3.26/general/hefs/yolov5s_personface.hef" 36 | ) 37 | 38 | H8L_HEFS=( 39 | "https://hailo-csdata.s3.eu-west-2.amazonaws.com/resources/hefs/h8l_rpi/clip_resnet_50x4_h8l.hef" 40 | "https://hailo-csdata.s3.eu-west-2.amazonaws.com/resources/hefs/h8l_rpi/yolov5s_personface_h8l_pi.hef" 41 | ) 42 | 43 | VIDEOS=( 44 | "https://hailo-csdata.s3.eu-west-2.amazonaws.com/resources/video/clip_example.mp4" 45 | ) 46 | 47 | # If --all flag is provided, download everything in parallel 48 | if [ "$1" == "--all" ]; then 49 | echo "Downloading all models and video resources..." 50 | for url in "${H8_HEFS[@]}" "${H8L_HEFS[@]}" "${VIDEOS[@]}"; do 51 | download_model "$url" & 52 | done 53 | else 54 | if [ "$DEVICE_ARCHITECTURE" == "HAILO8L" ]; then 55 | echo "Downloading HAILO8L models..." 56 | for url in "${H8L_HEFS[@]}"; do 57 | download_model "$url" & 58 | done 59 | elif [ "$DEVICE_ARCHITECTURE" == "HAILO8" ]; then 60 | echo "Downloading HAILO8 models..." 61 | for url in "${H8_HEFS[@]}"; do 62 | download_model "$url" & 63 | done 64 | fi 65 | fi 66 | 67 | # Download additional videos 68 | for url in "${VIDEOS[@]}"; do 69 | download_model "$url" & 70 | done 71 | 72 | # Wait for all background downloads to complete 73 | wait 74 | 75 | echo "All downloads completed successfully!" -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e # Exit immediately if a command exits with a non-zero status 4 | 5 | # Source environment variables and activate virtual environment 6 | echo "Sourcing environment variables and activating virtual environment..." 7 | source setup_env.sh 8 | 9 | # Install additional system dependencies (if needed) 10 | echo "Installing additional system dependencies..." 11 | sudo apt-get -y install libblas-dev nlohmann-json3-dev 12 | 13 | # Initialize variables 14 | DOWNLOAD_RESOURCES_FLAG="" 15 | PYHAILORT_WHL="" 16 | PYTAPPAS_WHL="" 17 | INSTALL_TEST_REQUIREMENTS=false 18 | TAG="25.3.1" 19 | 20 | # Parse command-line arguments 21 | while [[ "$#" -gt 0 ]]; do 22 | case $1 in 23 | --pyhailort) PYHAILORT_WHL="$2"; shift ;; 24 | --pytappas) PYTAPPAS_WHL="$2"; shift ;; 25 | --test) INSTALL_TEST_REQUIREMENTS=true ;; 26 | --all) DOWNLOAD_RESOURCES_FLAG="--all" ;; 27 | --tag) TAG="$2"; shift ;; # New parameter to specify tag for Hailo Apps Infra 28 | *) echo "Unknown parameter passed: $1"; exit 1 ;; 29 | esac 30 | shift 31 | done 32 | 33 | # Install specified Python wheels 34 | if [[ -n "$PYHAILORT_WHL" ]]; then 35 | echo "Installing pyhailort wheel: $PYHAILORT_WHL" 36 | pip install "$PYHAILORT_WHL" 37 | fi 38 | 39 | if [[ -n "$PYTAPPAS_WHL" ]]; then 40 | echo "Installing pytappas wheel: $PYTAPPAS_WHL" 41 | pip install "$PYTAPPAS_WHL" 42 | fi 43 | 44 | # Install test requirements if needed 45 | if [[ "$INSTALL_TEST_REQUIREMENTS" == true ]]; then 46 | echo "Installing test requirements..." 47 | python3 -m pip install -r tests/test_resources/requirements.txt 48 | fi 49 | 50 | # Install the package using setup.py in editable mode 51 | echo "Installing the package using setup.py in editable mode..." 52 | python3 -m pip install -v -e . 
53 | 54 | # Install Hailo Apps Infrastructure from specified tag/branch 55 | echo "Installing Hailo Apps Infrastructure from version: $TAG..." 56 | pip install "git+https://github.com/hailo-ai/hailo-apps-infra.git@$TAG" 57 | 58 | # Download resources needed for the pipelines 59 | echo "Downloading resources needed for the pipelines..." 60 | ./download_resources.sh $DOWNLOAD_RESOURCES_FLAG 61 | 62 | echo "Installation completed successfully." 63 | -------------------------------------------------------------------------------- /meson.build: -------------------------------------------------------------------------------- 1 | project('clip_app', 'c', 'cpp', 2 | version : '1.1.1', 3 | default_options : [ 'warning_level=1', 4 | 'buildtype=release', 5 | 'c_std=c11', 'cpp_std=c++17'] 6 | ) 7 | 8 | postprocess_dep = dependency('hailo-tappas-core', version : '>=3.30.0', required : false) 9 | 10 | if not postprocess_dep.found() 11 | postprocess_dep = dependency('hailo_tappas', version : '>=3.30.0', required : true) 12 | endif 13 | subdir('cpp') 14 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy<2.0.0 2 | torch>=1.9.0 3 | torchvision>=0.10.0 4 | openai-clip 5 | Pillow -------------------------------------------------------------------------------- /resources/CLIP_UI.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hailo-ai/hailo-CLIP/da90cf46d1cfb36f2b546226758c77c23c1d62b7/resources/CLIP_UI.png -------------------------------------------------------------------------------- /resources/Hackathon-banner-2024.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hailo-ai/hailo-CLIP/da90cf46d1cfb36f2b546226758c77c23c1d62b7/resources/Hackathon-banner-2024.png -------------------------------------------------------------------------------- /resources/configs/yolov5_personface.json: -------------------------------------------------------------------------------- 1 | { 2 | "iou_threshold": 0.5, 3 | "detection_threshold": 0.5, 4 | "output_activation": "none", 5 | "label_offset":1, 6 | "max_boxes":10000, 7 | "anchors": [ 8 | [ 9 | 116, 10 | 90, 11 | 156, 12 | 198, 13 | 373, 14 | 326 15 | ], 16 | [ 17 | 30, 18 | 61, 19 | 62, 20 | 45, 21 | 59, 22 | 119 23 | ], 24 | [ 25 | 10, 26 | 13, 27 | 16, 28 | 30, 29 | 33, 30 | 23 31 | ] 32 | ], 33 | "labels": [ 34 | "unlabeled", 35 | "person", 36 | "face" 37 | ] 38 | } -------------------------------------------------------------------------------- /resources/github_clip_based_classification.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hailo-ai/hailo-CLIP/da90cf46d1cfb36f2b546226758c77c23c1d62b7/resources/github_clip_based_classification.png -------------------------------------------------------------------------------- /resources/texts_json_example.json: -------------------------------------------------------------------------------- 1 | { 2 | "negative": [ 3 | "person" 4 | ], 5 | "positive": [ 6 | "keyboard", 7 | "mouse", 8 | "cat", 9 | "dog" 10 | ] 11 | } -------------------------------------------------------------------------------- /run_tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Source environment setup 4 | source setup_env.sh 5 | 6 | # 
Install pytest requirements if not already installed 7 | pip install -r tests/test_resources/requirements.txt 8 | 9 | # Run the tests 10 | echo "Running CLIP application tests..." 11 | pytest tests/test_clip_app.py -v --log-cli-level=INFO 12 | pytest tests/test_demo_clip.py -v --log-cli-level=INFO 13 | 14 | 15 | # Exit with the pytest return code 16 | exit $? -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | import os 3 | import sys 4 | import subprocess 5 | import logging 6 | 7 | # Configure logging 8 | logging.basicConfig(level=logging.INFO) 9 | logger = logging.getLogger(__name__) 10 | 11 | def check_hailo_package(): 12 | try: 13 | import hailo 14 | except ImportError: 15 | logger.error("Hailo python package not found. Please make sure you're in the Hailo virtual environment. Run 'source setup_env.sh' and try again.") 16 | sys.exit(1) 17 | 18 | def read_requirements(): 19 | """Reads requirements from requirements.txt, converting any 'git+https://' lines to PEP 508 syntax.""" 20 | with open("requirements.txt", "r") as f: 21 | lines = f.read().splitlines() 22 | 23 | new_lines = [] 24 | for line in lines: 25 | # If the line starts with git+https, convert it to PEP 508 form: @ git+https://... 26 | if line.startswith("git+https://"): 27 | # Choose a name that matches or approximates the actual package 28 | package_name = "hailo-apps-infra" 29 | pep_508_line = f"{package_name} @ {line}" 30 | new_lines.append(pep_508_line) 31 | else: 32 | new_lines.append(line) 33 | 34 | return new_lines 35 | 36 | def run_shell_command(command, error_message): 37 | logger.info(f"Running command: {command}") 38 | result = subprocess.run(command, shell=True) 39 | if result.returncode != 0: 40 | logger.error(f"{error_message}. 
Exit code: {result.returncode}") 41 | sys.exit(result.returncode) 42 | 43 | def main(): 44 | check_hailo_package() 45 | 46 | requirements = read_requirements() 47 | 48 | logger.info("Compiling C++ code...") 49 | run_shell_command("./compile_postprocess.sh", "Failed to compile C++ code") 50 | 51 | logger.info("Downloading Resources...") 52 | run_shell_command("./download_resources.sh", "Failed to download resources") 53 | 54 | setup( 55 | name='clip-app', 56 | version='0.6', 57 | description='Real time CLIP zero shot classification and detection', 58 | long_description=open('README.md').read(), 59 | long_description_content_type='text/markdown', 60 | author='Hailo', 61 | author_email='support@hailo.ai', 62 | packages=find_packages(), 63 | install_requires=requirements, 64 | entry_points={ 65 | 'console_scripts': [ 66 | 'text_image_matcher=clip_app.text_image_matcher:main', 67 | ], 68 | }, 69 | package_data={ 70 | 'clip_app': ['*.json', '*.sh', '*.cpp', '*.hpp', '*.pc'], 71 | }, 72 | include_package_data=True, 73 | ) 74 | 75 | if __name__ == '__main__': 76 | main() 77 | -------------------------------------------------------------------------------- /setup_env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # TAPPAS CORE Definitions 4 | CORE_VENV_NAME="hailo_clip_venv" 5 | CORE_REQUIRED_VERSION=("3.30.0" "3.31.0" "3.32.0") 6 | 7 | # TAPPAS Definitions 8 | TAPPAS_VENV_NAME="hailo_tappas_venv" 9 | TAPPAS_REQUIRED_VERSION=("3.30.0" "3.31.0" "3.32.0") 10 | 11 | # Function to check if the script is being sourced 12 | is_sourced() { 13 | if [ -n "$ZSH_VERSION" ]; then 14 | [[ -o sourced ]] 15 | elif [ -n "$BASH_VERSION" ]; then 16 | [[ "${BASH_SOURCE[0]}" != "$0" ]] 17 | else 18 | echo "Unsupported shell. Please use bash or zsh." 19 | return 1 20 | fi 21 | } 22 | 23 | # Only proceed if the script is being sourced 24 | if is_sourced; then 25 | echo "Setting up the environment..." 26 | 27 | # Check if we are working with hailo-tappas-core or hailo_tappas 28 | if pkg-config --exists hailo-tappas-core; then 29 | TAPPAS_CORE=1 30 | VENV_NAME=$CORE_VENV_NAME 31 | REQUIRED_VERSION=("${CORE_REQUIRED_VERSION[@]}") 32 | echo "Setting up the environment for hailo-tappas-core..." 33 | TAPPAS_VERSION=$(pkg-config --modversion hailo-tappas-core) 34 | else 35 | TAPPAS_CORE=0 36 | REQUIRED_VERSION=("${TAPPAS_REQUIRED_VERSION[@]}") 37 | echo "Setting up the environment for hailo_tappas..." 38 | TAPPAS_VERSION=$(pkg-config --modversion hailo_tappas) 39 | TAPPAS_WORKSPACE=$(pkg-config --variable=tappas_workspace hailo_tappas) 40 | export TAPPAS_WORKSPACE 41 | echo "TAPPAS_WORKSPACE set to $TAPPAS_WORKSPACE" 42 | if [[ "$TAPPAS_WORKSPACE" == "/local/workspace/tappas" ]]; then 43 | VENV_NAME="DOCKER" 44 | else 45 | VENV_NAME=$TAPPAS_VENV_NAME 46 | fi 47 | fi 48 | 49 | # Check if TAPPAS_VERSION is in REQUIRED_VERSION 50 | version_match=0 51 | for version in "${REQUIRED_VERSION[@]}"; do 52 | if [ "$TAPPAS_VERSION" = "$version" ]; then 53 | version_match=1 54 | break 55 | fi 56 | done 57 | 58 | if [ "$version_match" -eq 1 ]; then 59 | echo "TAPPAS_VERSION is ${TAPPAS_VERSION}. Proceeding..." 60 | else 61 | echo "TAPPAS_VERSION is ${TAPPAS_VERSION} not in the list of required versions ${REQUIRED_VERSION[*]}." 
62 | return 1 63 | fi 64 | 65 | if [ $TAPPAS_CORE -eq 1 ]; then 66 | # Get the directory of the current script 67 | SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]:-${(%):-%N}}")" &> /dev/null && pwd)" 68 | # Check if we are in the defined virtual environment 69 | if [[ "$VIRTUAL_ENV" == *"$VENV_NAME"* ]]; then 70 | echo "You are in the $VENV_NAME virtual environment." 71 | else 72 | echo "You are not in the $VENV_NAME virtual environment." 73 | # Check if the virtual environment exists in the same directory as the script 74 | if [ -d "$SCRIPT_DIR/$VENV_NAME" ]; then 75 | echo "Virtual environment exists. Activating..." 76 | source "$SCRIPT_DIR/$VENV_NAME/bin/activate" 77 | else 78 | echo "Virtual environment does not exist. Creating and activating..." 79 | python3 -m venv --system-site-packages "$SCRIPT_DIR/$VENV_NAME" 80 | source "$SCRIPT_DIR/$VENV_NAME/bin/activate" 81 | fi 82 | fi 83 | TAPPAS_POST_PROC_DIR=$(pkg-config --variable=tappas_postproc_lib_dir hailo-tappas-core) 84 | else 85 | if [[ "$VENV_NAME" == "DOCKER" ]]; then 86 | echo "Running in DOCKER using default virtualenv" 87 | else 88 | # Check if we are in the defined virtual environment 89 | if [[ "$VIRTUAL_ENV" == *"$VENV_NAME"* ]]; then 90 | echo "You are in the $VENV_NAME virtual environment." 91 | else 92 | echo "You are not in the $VENV_NAME virtual environment." 93 | # Activate TAPPAS virtual environment 94 | VENV_PATH="${TAPPAS_WORKSPACE}/hailo_tappas_venv/bin/activate" 95 | if [ -f "$VENV_PATH" ]; then 96 | echo "Activating virtual environment..." 97 | source "$VENV_PATH" 98 | else 99 | echo "Error: Virtual environment not found at $VENV_PATH." 100 | return 1 101 | fi 102 | fi 103 | fi 104 | TAPPAS_POST_PROC_DIR="${TAPPAS_WORKSPACE}/apps/h8/gstreamer/libs/post_processes/" 105 | fi 106 | export TAPPAS_POST_PROC_DIR 107 | echo "TAPPAS_POST_PROC_DIR set to $TAPPAS_POST_PROC_DIR" 108 | 109 | # Get the Device Architecture 110 | output=$(hailortcli fw-control identify | tr -d '\0') 111 | # Extract the Device Architecture from the output 112 | device_arch=$(echo "$output" | grep "Device Architecture" | awk -F": " '{print $2}') 113 | # if the device architecture is not found, output the error message and return 114 | if [ -z "$device_arch" ]; then 115 | echo "Error: Device Architecture not found. Please check the connection to the device." 116 | return 1 117 | fi 118 | # Export the Device Architecture to an environment variable 119 | export DEVICE_ARCHITECTURE="$device_arch" 120 | # Print the environment variable to verify 121 | echo "DEVICE_ARCHITECTURE is set to: $DEVICE_ARCHITECTURE" 122 | else 123 | echo "This script needs to be sourced to correctly set up the environment. Please run '. $(basename "$0")' instead of executing it." 
124 | fi 125 | -------------------------------------------------------------------------------- /tests/test_clip_app.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | import subprocess 4 | import time 5 | import signal 6 | import json 7 | import cv2 8 | import numpy as np 9 | import gi 10 | gi.require_version('Gst', '1.0') 11 | from gi.repository import Gst, GLib 12 | 13 | # Import clip app modules 14 | from clip_app.text_image_matcher import TextImageMatcher, text_image_matcher 15 | import clip_application as callback_module 16 | 17 | class TestSanityCheck: 18 | """Basic sanity checks for the CLIP application.""" 19 | 20 | def test_hailo_runtime(self): 21 | """Test if Hailo runtime is installed and accessible.""" 22 | try: 23 | result = subprocess.run(['hailortcli', '--version'], 24 | check=True, capture_output=True, text=True) 25 | assert result.returncode == 0, "Hailo runtime check failed" 26 | print(f"Hailo runtime version: {result.stdout.strip()}") 27 | except subprocess.CalledProcessError as e: 28 | pytest.fail(f"Hailo runtime is not properly installed: {str(e)}") 29 | 30 | def test_required_files(self): 31 | """Test if all required files exist.""" 32 | required_files = [ 33 | 'setup_env.sh', 34 | 'download_resources.sh', 35 | 'compile_postprocess.sh', 36 | 'requirements.txt', 37 | 'resources/libclip_post.so', 38 | 'resources/libclip_matcher.so', 39 | 'resources/libclip_croppers.so' 40 | ] 41 | for file in required_files: 42 | assert os.path.exists(file), f"Required file missing: {file}" 43 | 44 | def test_environment_variables(self): 45 | """Test if required environment variables are set.""" 46 | required_vars = ['TAPPAS_POST_PROC_DIR', 'DEVICE_ARCHITECTURE'] 47 | for var in required_vars: 48 | assert os.environ.get(var) is not None, f"Environment variable {var} is not set" 49 | print(f"{var} is set to: {os.environ.get(var)}") 50 | 51 | def test_verify_device_architecture(self): 52 | """Verify device architecture is properly detected.""" 53 | device_arch = os.environ.get('DEVICE_ARCHITECTURE') 54 | assert device_arch in ['HAILO8', 'HAILO8L'], f"Invalid device architecture: {device_arch}" 55 | 56 | class TestTextImageMatcher: 57 | """Tests for the TextImageMatcher functionality.""" 58 | 59 | @pytest.fixture(scope="class") 60 | def matcher(self): 61 | """Fixture providing a TextImageMatcher instance.""" 62 | return text_image_matcher 63 | 64 | def test_singleton_pattern(self, matcher): 65 | """Verify TextImageMatcher implements singleton pattern correctly.""" 66 | matcher2 = TextImageMatcher() 67 | assert matcher is matcher2, "TextImageMatcher singleton pattern failed" 68 | 69 | def test_threshold_setting(self, matcher): 70 | """Test setting and getting threshold values.""" 71 | test_threshold = 0.75 72 | matcher.set_threshold(test_threshold) 73 | assert matcher.threshold == test_threshold 74 | 75 | # Test bounds 76 | matcher.set_threshold(0.0) 77 | assert matcher.threshold == 0.0 78 | matcher.set_threshold(1.0) 79 | assert matcher.threshold == 1.0 80 | 81 | def test_text_prefix(self, matcher): 82 | """Test setting and getting text prefix.""" 83 | test_prefix = "Testing: " 84 | original_prefix = matcher.text_prefix 85 | matcher.set_text_prefix(test_prefix) 86 | assert matcher.text_prefix == test_prefix 87 | # Restore original prefix 88 | matcher.set_text_prefix(original_prefix) 89 | 90 | def test_embeddings_save_load(self, matcher, tmp_path): 91 | """Test saving and loading embeddings.""" 92 | test_file = tmp_path 
/ "test_embeddings.json" 93 | 94 | # Set test values 95 | test_threshold = 0.85 96 | test_prefix = "Test prefix: " 97 | matcher.threshold = test_threshold 98 | matcher.text_prefix = test_prefix 99 | 100 | # Save embeddings 101 | matcher.save_embeddings(str(test_file)) 102 | 103 | # Modify values 104 | matcher.threshold = 0.5 105 | matcher.text_prefix = "Changed: " 106 | 107 | # Load embeddings 108 | matcher.load_embeddings(str(test_file)) 109 | 110 | # Verify values were restored 111 | assert matcher.threshold == test_threshold 112 | assert matcher.text_prefix == test_prefix 113 | 114 | class TestGStreamerPipeline: 115 | """Tests for GStreamer pipeline functionality.""" 116 | 117 | @classmethod 118 | def setup_class(cls): 119 | """Initialize GStreamer for pipeline tests.""" 120 | Gst.init(None) 121 | 122 | def test_basic_pipeline(self): 123 | """Test basic GStreamer pipeline creation and state changes.""" 124 | pipeline_str = ( 125 | 'videotestsrc num-buffers=10 ! ' 126 | 'video/x-raw,format=RGB,width=640,height=480 ! ' 127 | 'fakesink' 128 | ) 129 | 130 | try: 131 | pipeline = Gst.parse_launch(pipeline_str) 132 | assert pipeline is not None 133 | 134 | # Test state changes 135 | ret = pipeline.set_state(Gst.State.PLAYING) 136 | assert ret != Gst.StateChangeReturn.FAILURE 137 | 138 | time.sleep(1) # Let it run briefly 139 | 140 | ret = pipeline.set_state(Gst.State.NULL) 141 | assert ret != Gst.StateChangeReturn.FAILURE 142 | 143 | except GLib.Error as e: 144 | pytest.fail(f"Pipeline creation failed: {e}") 145 | 146 | def test_plugin_availability(self): 147 | """Test if required GStreamer plugins are available.""" 148 | required_plugins = ['hailo', 'hailotools'] 149 | for plugin in required_plugins: 150 | registry = Gst.Registry.get() 151 | plugin_obj = registry.find_plugin(plugin) 152 | assert plugin_obj is not None, f"Required GStreamer plugin '{plugin}' not found" 153 | 154 | class TestCallbackFunctionality: 155 | """Tests for callback functionality.""" 156 | 157 | def test_callback_class(self): 158 | """Test the callback class functionality.""" 159 | callback_instance = callback_module.app_callback_class() 160 | 161 | # Test initial state 162 | assert callback_instance.frame_count == 0 163 | assert callback_instance.use_frame is False 164 | assert callback_instance.running is True 165 | 166 | # Test frame counter 167 | callback_instance.increment() 168 | assert callback_instance.frame_count == 1 169 | assert callback_instance.get_count() == 1 170 | 171 | class TestEdgeCases: 172 | """Tests for edge cases and error handling.""" 173 | 174 | 175 | 176 | def test_empty_embeddings(self): 177 | """Test matcher behavior with empty embeddings.""" 178 | matcher = text_image_matcher 179 | empty_embeddings = matcher.get_embeddings() 180 | assert isinstance(empty_embeddings, list), "get_embeddings should return a list" 181 | 182 | # Match with empty embeddings should return empty list 183 | result = matcher.match(np.array([0.1, 0.2, 0.3])) 184 | assert isinstance(result, list), "match should return a list" 185 | assert len(result) == 0, "match with empty embeddings should return empty list" 186 | 187 | def test_clean_shutdown(): 188 | """Test clean shutdown of pipeline and resources.""" 189 | Gst.init(None) 190 | pipeline_str = 'videotestsrc ! 
fakesink' 191 | pipeline = Gst.parse_launch(pipeline_str) 192 | 193 | # Start pipeline 194 | pipeline.set_state(Gst.State.PLAYING) 195 | 196 | # Simulate shutdown 197 | time.sleep(0.1) 198 | pipeline.send_event(Gst.Event.new_eos()) 199 | pipeline.set_state(Gst.State.NULL) 200 | 201 | # Verify pipeline is properly cleaned up 202 | state = pipeline.get_state(0)[1] 203 | assert state == Gst.State.NULL, "Pipeline not properly cleaned up" 204 | 205 | if __name__ == "__main__": 206 | pytest.main(["-v", __file__]) 207 | 208 | -------------------------------------------------------------------------------- /tests/test_demo_clip.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import subprocess 3 | import os 4 | import sys 5 | import time 6 | import signal 7 | import glob 8 | import logging 9 | import gi 10 | gi.require_version('Gst', '1.0') 11 | from gi.repository import Gst, GLib 12 | import json 13 | from pathlib import Path 14 | import numpy as np 15 | 16 | 17 | # Add path for clip app 18 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 19 | 20 | try: 21 | from picamera2 import Picamera2 22 | rpi_camera_available = True 23 | except ImportError: 24 | rpi_camera_available = False 25 | 26 | # Constants 27 | TEST_RUN_TIME = 5 # Same as Hailo RPi examples 28 | LOG_DIR = "logs" 29 | os.makedirs(LOG_DIR, exist_ok=True) 30 | 31 | def get_usb_video_devices(): 32 | """Get a list of video devices that are connected via USB and have video capture capability.""" 33 | video_devices = [f'/dev/{device}' for device in os.listdir('/dev') if device.startswith('video')] 34 | usb_video_devices = [] 35 | 36 | for device in video_devices: 37 | try: 38 | udevadm_cmd = ["udevadm", "info", "--query=all", "--name=" + device] 39 | result = subprocess.run(udevadm_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 40 | output = result.stdout.decode('utf-8') 41 | 42 | if "ID_BUS=usb" in output and ":capture:" in output: 43 | usb_video_devices.append(device) 44 | except Exception as e: 45 | print(f"Error checking device {device}: {e}") 46 | 47 | return usb_video_devices 48 | 49 | def check_rpi_camera_available(): 50 | """Check if RPi camera is available.""" 51 | try: 52 | result = subprocess.run( 53 | ['rpicam-hello', '-t', '1'], 54 | capture_output=True, 55 | text=True, 56 | timeout=2 57 | ) 58 | return "no cameras available" not in result.stderr.lower() 59 | except (subprocess.TimeoutExpired, FileNotFoundError): 60 | return False 61 | 62 | def test_rpi_camera_connection(): 63 | """Test if RPI camera is connected by running rpicam-hello.""" 64 | log_file_path = os.path.join(LOG_DIR, "rpi_camera_test.log") 65 | 66 | with open(log_file_path, "w") as log_file: 67 | process = subprocess.Popen( 68 | ['rpicam-hello', '-t', '0', '--post-process-file', '/usr/share/rpi-camera-assets/hailo_yolov6_inference.json'], 69 | stdout=subprocess.PIPE, stderr=subprocess.PIPE) 70 | 71 | try: 72 | time.sleep(TEST_RUN_TIME) 73 | process.send_signal(signal.SIGTERM) 74 | process.wait(timeout=5) 75 | except subprocess.TimeoutExpired: 76 | process.kill() 77 | pytest.fail(f"RPI camera connection test could not be terminated") 78 | 79 | stdout, stderr = process.communicate() 80 | log_file.write(f"rpi_camera stdout:\n{stdout.decode()}\n") 81 | log_file.write(f"rpi_camera stderr:\n{stderr.decode()}\n") 82 | 83 | if "ERROR: *** no cameras available ***" in stderr.decode(): 84 | pytest.skip("RPI camera is not connected") 85 | else: 86 | log_file.write("RPI camera 
is connected and working.\n") 87 | 88 | def test_demo_clip(): 89 | """Test CLIP application with demo video.""" 90 | log_file_path = os.path.join(LOG_DIR, "clip_demo_test.log") 91 | 92 | with open(log_file_path, "w") as log_file: 93 | process = subprocess.Popen( 94 | ['python', 'clip_application.py', '--input', 'demo', '--disable-runtime-prompts'], 95 | stdout=subprocess.PIPE, stderr=subprocess.PIPE) 96 | 97 | try: 98 | time.sleep(TEST_RUN_TIME) 99 | process.send_signal(signal.SIGTERM) 100 | process.wait(timeout=5) 101 | except subprocess.TimeoutExpired: 102 | process.kill() 103 | pytest.fail(f"Demo clip test could not be terminated") 104 | 105 | stdout, stderr = process.communicate() 106 | log_file.write(f"Demo clip stdout:\n{stdout.decode()}\n") 107 | log_file.write(f"Demo clip stderr:\n{stderr.decode()}\n") 108 | 109 | assert "Traceback" not in stderr.decode(), f"Exception occurred in demo test: {stderr.decode()}" 110 | assert "Error" not in stderr.decode(), f"Error occurred in demo test: {stderr.decode()}" 111 | log_file.write("Demo clip test completed successfully.\n") 112 | 113 | def test_all_detectors(): 114 | """Test CLIP application with different detectors.""" 115 | detectors = ['none', 'person', 'face'] 116 | 117 | for detector in detectors: 118 | log_file_path = os.path.join(LOG_DIR, f"clip_{detector}_test.log") 119 | with open(log_file_path, "w") as log_file: 120 | process = subprocess.Popen( 121 | ['python', 'clip_application.py', '--input', 'demo', '--detector', detector, '--disable-runtime-prompts'], 122 | stdout=subprocess.PIPE, stderr=subprocess.PIPE) 123 | 124 | try: 125 | time.sleep(TEST_RUN_TIME) 126 | process.send_signal(signal.SIGTERM) 127 | process.wait(timeout=5) 128 | except subprocess.TimeoutExpired: 129 | process.kill() 130 | pytest.fail(f"{detector} detector test could not be terminated") 131 | 132 | stdout, stderr = process.communicate() 133 | log_file.write(f"{detector} detector stdout:\n{stdout.decode()}\n") 134 | log_file.write(f"{detector} detector stderr:\n{stderr.decode()}\n") 135 | 136 | assert "Traceback" not in stderr.decode(), f"Exception occurred with detector {detector}: {stderr.decode()}" 137 | assert "Error" not in stderr.decode(), f"Error occurred with detector {detector}: {stderr.decode()}" 138 | log_file.write(f"{detector} detector test completed successfully.\n") 139 | 140 | @pytest.mark.camera 141 | def test_usb_camera(): 142 | """Test CLIP application with USB camera.""" 143 | usb_cameras = get_usb_video_devices() 144 | if not usb_cameras: 145 | pytest.skip("No USB cameras found") 146 | 147 | for camera in usb_cameras: 148 | device_name = os.path.basename(camera) 149 | log_file_path = os.path.join(LOG_DIR, f"clip_usb_camera_{device_name}_test.log") 150 | 151 | with open(log_file_path, "w") as log_file: 152 | process = subprocess.Popen( 153 | ['python', 'clip_application.py', '--input', camera, '--disable-runtime-prompts'], 154 | stdout=subprocess.PIPE, stderr=subprocess.PIPE) 155 | 156 | try: 157 | time.sleep(TEST_RUN_TIME) 158 | process.send_signal(signal.SIGTERM) 159 | process.wait(timeout=5) 160 | except subprocess.TimeoutExpired: 161 | process.kill() 162 | pytest.fail(f"USB camera test for {camera} could not be terminated") 163 | 164 | stdout, stderr = process.communicate() 165 | log_file.write(f"USB camera stdout:\n{stdout.decode()}\n") 166 | log_file.write(f"USB camera stderr:\n{stderr.decode()}\n") 167 | 168 | assert "Traceback" not in stderr.decode(), f"Exception occurred with USB camera: {stderr.decode()}" 169 | assert "Error" not 
in stderr.decode(), f"Error occurred with USB camera: {stderr.decode()}" 170 | log_file.write("USB camera test completed successfully.\n") 171 | 172 | @pytest.mark.camera 173 | def test_rpi_camera(): 174 | """Test CLIP application with RPi camera.""" 175 | if not check_rpi_camera_available(): 176 | pytest.skip("RPi camera not available") 177 | 178 | log_file_path = os.path.join(LOG_DIR, "clip_rpi_camera_test.log") 179 | 180 | with open(log_file_path, "w") as log_file: 181 | process = subprocess.Popen( 182 | ['python', 'clip_application.py', '--input', 'rpi', '--disable-runtime-prompts'], 183 | stdout=subprocess.PIPE, stderr=subprocess.PIPE) 184 | 185 | try: 186 | time.sleep(TEST_RUN_TIME) 187 | process.send_signal(signal.SIGTERM) 188 | process.wait(timeout=5) 189 | except subprocess.TimeoutExpired: 190 | process.kill() 191 | pytest.fail("RPi camera test could not be terminated") 192 | 193 | stdout, stderr = process.communicate() 194 | log_file.write(f"RPi camera stdout:\n{stdout.decode()}\n") 195 | log_file.write(f"RPi camera stderr:\n{stderr.decode()}\n") 196 | 197 | assert "Traceback" not in stderr.decode(), f"Exception occurred with RPi camera: {stderr.decode()}" 198 | assert "Error" not in stderr.decode(), f"Error occurred with RPi camera: {stderr.decode()}" 199 | log_file.write("RPi camera test completed successfully.\n") 200 | 201 | class TestRuntimePrompts: 202 | """Tests for runtime prompts functionality.""" 203 | 204 | def test_runtime_prompts_enabled(self): 205 | """Test with runtime prompts enabled (default).""" 206 | process = subprocess.Popen( 207 | ['python', 'clip_application.py', '--input', 'demo'], 208 | stdout=subprocess.PIPE, stderr=subprocess.PIPE, 209 | text=True) 210 | 211 | try: 212 | time.sleep(TEST_RUN_TIME) 213 | process.send_signal(signal.SIGTERM) 214 | process.wait(timeout=5) 215 | stdout, stderr = process.communicate() 216 | 217 | if stderr and ("Error:" in stderr or "Traceback" in stderr): 218 | pytest.fail(f"Runtime prompts test failed with error:\n{stderr}") 219 | 220 | assert process.returncode in [0, -15], f"Unexpected return code {process.returncode}" 221 | finally: 222 | if process.poll() is None: 223 | process.kill() 224 | 225 | if __name__ == "__main__": 226 | pytest.main(["-v", __file__]) 227 | -------------------------------------------------------------------------------- /tests/test_resources/requirements.txt: -------------------------------------------------------------------------------- 1 | pytest 2 | pytest-timeout 3 | --------------------------------------------------------------------------------
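A short usage note on the test suite above (a sketch of standard pytest usage, not a script from this repository): the camera-dependent tests in `tests/test_demo_clip.py` are tagged with `@pytest.mark.camera`, while `run_tests.sh` simply runs both test files in full. The marker can be used to include or exclude the hardware-dependent tests; note that no `pytest.ini` or `conftest.py` registering the marker is shown in the repository, so pytest may warn about an unknown mark unless it is registered.

```bash
# Run only the camera tests (requires a USB or RPi camera to be attached):
pytest tests/test_demo_clip.py -v -m camera

# Skip the camera tests, e.g. on a machine without a camera:
pytest tests/test_demo_clip.py -v -m "not camera"
```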