├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md └── dependabot.yml ├── .gitignore ├── CNAME ├── DeepLabStream.py ├── LICENSE ├── Readme.md ├── __init__.py ├── _config.yml ├── app.py ├── convert_classifier.py ├── design_experiment.py ├── docs ├── DLSSTim_example.gif ├── GraphAbstract.png ├── design_experiment_gif.gif └── flowchart2.png ├── experiments ├── __init__.py ├── base │ ├── __init__.py │ ├── experiments.py │ ├── stimulation.py │ ├── stimulus_process.py │ └── triggers.py ├── configs │ ├── BaseConditionalExperiment_example.ini │ ├── BaseOptogeneticExperiment_example.ini │ ├── BaseTrialExperiment_example.ini │ └── default_config.ini ├── custom │ ├── __init__.py │ ├── classifier.py │ ├── experiments.py │ ├── featureextraction.py │ ├── stimulation.py │ ├── stimulus_process.py │ └── triggers.py ├── src │ ├── bluebar_whiteback_1920_1080.png │ ├── greenbar_whiteback_1920_1080.png │ ├── stuckinaloop.jpg │ └── whiteback_1920_1080.png └── utils │ ├── DAQ_output.py │ ├── __init__.py │ ├── exp_setup.py │ └── gpio_control.py ├── misc ├── DLStream_Logo_small.png ├── StartAnalysis2.png ├── StartExperiment2.png ├── StartRecording2.png ├── StartStream2.png ├── StopAnalysis2.png ├── StopExperiment2.png ├── StopRecording2.png └── StopStream2.png ├── requirements.txt ├── settings.ini └── utils ├── __init__.py ├── advanced_settings.ini ├── analysis.py ├── configloader.py ├── generic.py ├── gui_image.py ├── plotter.py ├── poser.py ├── pylon.py ├── realsense.py └── webcam.py /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug 6 | assignees: JensBlack 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 
18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. iOS] 28 | - Browser [e.g. chrome, safari] 29 | - Version [e.g. 22] 30 | 31 | **Additional context** 32 | Add any other context about the problem here. 33 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: enhancement 6 | assignees: JensBlack 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 
21 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: pip 4 | directory: "/" 5 | schedule: 6 | interval: monthly 7 | time: "04:00" 8 | open-pull-requests-limit: 10 9 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Pycharm idea files 10 | .idea/ -------------------------------------------------------------------------------- /CNAME: -------------------------------------------------------------------------------- 1 | www.dlstream.io -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # DeepLabStream 4 | 5 | ![GraphAbstract](docs/GraphAbstract.png) 6 | 7 | [![GitHub stars](https://img.shields.io/github/stars/SchwarzNeuroconLab/DeepLabStream.svg?style=social&label=Star)](https://github.com/SchwarzNeuroconLab/DeepLabStream) 8 | [![GitHub forks](https://img.shields.io/github/forks/SchwarzNeuroconLab/DeepLabStream.svg?style=social&label=Fork)](https://github.com/SchwarzNeuroconLab/DeepLabStream) 9 | [![License: GPL v3](https://img.shields.io/badge/License-GPL%20v3-blue.svg)](https://www.gnu.org/licenses/gpl-3.0) 10 | [![Twitter Follow](https://img.shields.io/twitter/follow/SNeuroconnect.svg?label=SNeuroconnect&style=social)](https://twitter.com/SNeuroconnect) 11 | 12 | [DeepLabStream](https://www.nature.com/articles/s42003-021-01654-9) is a python based multi-purpose tool that enables the realtime tracking and manipulation of animals during ongoing experiments. 
13 | Our toolbox was originally adapted from the previously published [DeepLabCut](https://github.com/AlexEMG/DeepLabCut) ([Mathis et al., 2018](https://www.nature.com/articles/s41593-018-0209-y)) and expanded on its core capabilities, but is now able to utilize a variety of different network architectures for online pose estimation 14 | ([SLEAP](https://github.com/murthylab/sleap), [DLC-Live](https://github.com/DeepLabCut/DeepLabCut-live), [DeepPoseKit's](https://github.com/jgraving/DeepPoseKit) StackedDenseNet, StackedHourGlass and [LEAP](https://github.com/murthylab/sleap)). 15 | 16 | DeepLabStream's core feature is the utilization of real-time tracking to orchestrate closed-loop experiments. This can be achieved using any type of camera-based video stream (incl. multiple streams). It enables experimental protocols that depend on a constant stream of body part positions and that drive the feedback activation of several input/output devices. Its capabilities range from simple region of interest (ROI) based triggers to head-direction or behavior dependent stimulation, including online classification ([SiMBA](https://www.biorxiv.org/content/10.1101/2020.04.19.049452v2), [B-SOID](https://www.biorxiv.org/content/10.1101/770271v2)).
17 | 18 | ![DLS_Stim](docs/DLSSTim_example.gif) 19 | 20 | ## Read the news: 21 | 22 | - Watch an introductory talk about DLStream on [Youtube](https://www.youtube.com/watch?v=ZspLDZb_kMI) 23 | - Real-time behavioral analysis using artificial intelligence @ [Federal Ministry of Education and Research](http://www.research-in-germany.org/news/2021/2/2021-02-19_Real-time_behavioural_analysis_using_artificial_intelligence.html), [Uni Bonn](https://www.uni-bonn.de/news/real-time-behavioral-analysis-using-artificial-intelligence) and [Phys.org](https://phys.org/news/2021-02-real-time-behavioral-analysis-artificial-intelligence.html) 24 | - We are featured on [OpenBehavior](https://edspace.american.edu/openbehavior/project/dlstream/) 25 | - We are also featured on [open-neuroscience.com](https://open-neuroscience.com/post/deeplabstream/)! 26 | 27 | ## New features: 28 | 29 | #### 03/2021: Online Behavior Classification using SiMBA and B-SOID: 30 | 31 | - full integration of online classification of user-defined behavior using [SiMBA](https://github.com/sgoldenlab/simba) and [B-SOID](https://github.com/YttriLab/B-SOID). 32 | - SOCIAL CLASSIFICATION with SiMBA 14bp two animal classification (more to come!) 33 | - Unsupervised Classification with B-SOID 34 | - New wiki guide and example experiment to get started with online classification: [Advanced Behavior Classification](https://github.com/SchwarzNeuroconLab/DeepLabStream/wiki/Advanced-Behavior-Classification) 35 | - this version has new requirements (numba, pure, scikit-learn), so be sure to install them (e.g. `pip install -r requirements.txt`). 36 | 37 | #### 02/2021: Multiple Animal Experiments (Pre-release): Full [SLEAP](https://github.com/murthylab/sleap) integration (Full release coming soon!) 
38 | 39 | - Updated [Installation](https://github.com/SchwarzNeuroconLab/DeepLabStream/wiki/Installation-&-Testing) (for SLEAP support) 40 | - Single Instance and Multiple Instance models (TopDown & BottomUp) integration 41 | - [New Multiple Animal Experiment Tutorial](https://github.com/SchwarzNeuroconLab/DeepLabStream/wiki/Multiple-Animal-Experiments) 42 | 43 | #### 01/2021: DLStream was published in [Communications Biology](https://www.nature.com/articles/s42003-021-01654-9) 44 | 45 | #### 12/2020: New pose estimation model integration 46 | - ([DLC-Live](https://github.com/DeepLabCut/DeepLabCut-live)) and pre-release of further integration ([DeepPoseKit's](https://github.com/jgraving/DeepPoseKit) StackedDenseNet, StackedHourGlass and [LEAP](https://github.com/murthylab/sleap)) 47 | 48 | ## Quick Reference: 49 | 50 | #### Check out our wiki: [DLStream Wiki](https://github.com/SchwarzNeuroconLab/DeepLabStream/wiki) 51 | 52 | #### Read the paper: [Schweihoff, et al. 2021](https://www.nature.com/articles/s42003-021-01654-9) 53 | 54 | #### Contributing 55 | 56 | If you have feature requests or questions regarding the design of experiments, join our [slack group](https://join.slack.com/t/dlstream/shared_invite/zt-pdh6twf5-xHbpLvJtrTx32t12w8RDVQ). 57 | 58 | We are constantly working to update and increase the capabilities of DLStream. 59 | We welcome all feedback and input from your side. 60 | 61 | 62 | ### 1. [Updated Installation & Testing](https://github.com/SchwarzNeuroconLab/DeepLabStream/wiki/Installation-&-Testing) 63 | 64 | ### 2. [How to use DLStream GUI](https://github.com/SchwarzNeuroconLab/DeepLabStream/wiki/How-to-use-DLStream) 65 | 66 | ### 3. Check out our [Out-of-the-Box](https://github.com/SchwarzNeuroconLab/DeepLabStream/wiki/Out-Of-The-Box:-Overview) 67 | 68 | ### 4. [Design an Out-of-the-Box Experiment](https://github.com/SchwarzNeuroconLab/DeepLabStream/wiki/Out-Of-The-Box:-Design-Experiments) 69 | 70 | ### What's underneath?: 71 | 72 | ### 5.
[Introduction to experiments](https://github.com/SchwarzNeuroconLab/DeepLabStream/wiki/Introduction) 73 | 74 | ### For advanced users: 75 | 76 | ### 6. [Design your first experiment](https://github.com/SchwarzNeuroconLab/DeepLabStream/wiki/My-first-experiment) 77 | 78 | ### 7. [Adapting an existing experiment to your own needs](https://github.com/SchwarzNeuroconLab/DeepLabStream/wiki/Adapting-an-existing-experiment-to-your-own-needs) 79 | 80 | 81 | 82 | ### How to use DeepLabStream 83 | 84 | Just run 85 | ``` 86 | cd DeepLabStream 87 | python app.py 88 | ``` 89 | 90 | You will see the main control panel of a GUI app. 91 | 92 | ![Main](https://user-images.githubusercontent.com/44863941/91172971-59faf000-e6dd-11ea-8b68-3c36db0ff22f.png) 93 | 94 | To start working with DeepLabStream, press the `Start Stream` button. It will activate the camera manager and show you the current view from the connected cameras. 95 | 96 | ![Stream](https://user-images.githubusercontent.com/44863941/91173024-7008b080-e6dd-11ea-84b0-b05ac408d9a2.png) 97 | 98 | After that you can `Start Analysis` to start DeepLabCut and receive pose estimations for each frame, or, additionally, you can `Start Recording` to record a 99 | video of the current feed (visible in the stream window). You will see your current video timestamp (counted in frames) and FPS after you press the `Start Analysis` button. 100 | 101 | ![Analysis](https://user-images.githubusercontent.com/44863941/91173049-7ac34580-e6dd-11ea-80b6-ad56cb9cf22c.png) 102 | 103 | As you can see, we track three points that represent three body parts of the mouse - nose, neck and tail root. 104 | Every single frame where the animal was tracked is written to a dataframe, which is saved as a .csv file after the analysis is finished.
105 | 106 | After you finish with tracking and/or recording the video, you can stop either function by pressing the corresponding "stop" button 107 | (so, `Stop Analysis` or `Stop Recording`) or you can stop the app and refresh all the timing at once by pressing the `Stop Stream` button. 108 | 109 | #### Experiments 110 | 111 | DeepLabStream was built specifically for closed-loop experiments, so with a properly implemented experiment protocol, running experiments on this system is as easy as 112 | pressing the `Start Experiment` button. Depending on your protocol and experimental goals, experiments could run and finish without any further engagement from the user. 113 | 114 | ![Start](https://user-images.githubusercontent.com/44863941/91173075-857dda80-e6dd-11ea-90a4-1e768cab41ad.png) 115 | 116 | In the provided `ExampleExperiment` two regions of interest (ROIs) are created inside an arena. The experiment is designed to count the number of times the mouse enters an ROI and trigger a corresponding visual stimulus on a screen. 117 | The high contrast stimuli (image files) are located within the `experiments/src` folder and specified within the `experiments.py` `ExampleExperiments` Class. 118 | 119 | ![Experiment](https://user-images.githubusercontent.com/44863941/91173098-90d10600-e6dd-11ea-94be-63e99f88df0a.png) 120 | 121 | As a visual representation of this event, the border of the ROI will turn green. 122 | 123 | All experimental output will be stored in a .csv file for easy postprocessing. Check out [Working with DLStream output](https://github.com/SchwarzNeuroconLab/DeepLabStream/wiki/Working-with-DLStream-output) for further details.
124 | 125 | Look at the [Introduction to experiments](https://github.com/SchwarzNeuroconLab/DeepLabStream/wiki/Introduction) to get an idea of how to design your own experiment in DeepLabStream or learn how to adapt one of the already published experiments at [Adapting an existing experiment](https://github.com/SchwarzNeuroconLab/DeepLabStream/wiki/Adapting-an-existing-experiment-to-your-own-needs). 126 | 127 | ## How does this work 128 | 129 | DeepLabStream uses the camera's video stream to simultaneously record a raw (read as unmodified) video of the ongoing experiment, 130 | send frames one-by-one to the neural network for analysis, and use the returned analysed data to plot and show a video stream for the experimenter to observe and control the experiment. 131 | Analysed data will also be utilized to enable closed-loop experiments without any human interference, using triggers to operate equipment on predefined conditions 132 | and to end, prolong or modify parts of the experimental protocol. 133 | 134 | ![Flowchart](docs/flowchart2.png) 135 | 136 | ### Known issues 137 | 138 | If you encounter any issues or errors, you can check out the wiki article ([Help there is an error!](https://github.com/SchwarzNeuroconLab/DeepLabStream/wiki/Help-there-is-an-error!)). If your issue is not listed yet, please refer to the issues and either submit a new issue or find a reported issue (which might already be solved) there. Thank you! 139 | 140 | ## References: 141 | 142 | If you use this code or data, please cite: 143 | 144 | Schweihoff, J.F., Loshakov, M., Pavlova, I. et al. DeepLabStream enables closed-loop behavioral experiments using deep learning-based markerless, real-time posture detection. 145 | 146 | Commun Biol 4, 130 (2021). https://doi.org/10.1038/s42003-021-01654-9 147 | 148 | ## License 149 | This project is licensed under the GNU General Public License v3.0. Note that the software is provided "as is", without warranty of any kind, expressed or implied.
150 | 151 | ## Authors 152 | 153 | Developed by: 154 | - Jens Schweihoff, jens.schweihoff@ukbonn.de 155 | 156 | - Matvey Loshakov, matveyloshakov@gmail.com 157 | 158 | Corresponding Author: Martin Schwarz, Martin.Schwarz@ukbonn.de 159 | 160 | ## Other References 161 | 162 | If you are using any of the following open-source code please cite them accordingly: 163 | 164 | > Simple Behavioral Analysis (SimBA) – an open source toolkit for computer classification of complex social behaviors in experimental animals; 165 | Simon RO Nilsson, Nastacia L. Goodwin, Jia Jie Choong, Sophia Hwang, Hayden R Wright, Zane C Norville, Xiaoyu Tong, Dayu Lin, Brandon S. Bentzley, Neir Eshel, Ryan J McLaughlin, Sam A. Golden 166 | bioRxiv 2020.04.19.049452; doi: https://doi.org/10.1101/2020.04.19.049452 167 | 168 | > B-SOiD: An Open Source Unsupervised Algorithm for Discovery of Spontaneous Behaviors; 169 | Alexander I. Hsu, Eric A. Yttri 170 | bioRxiv 770271; doi: https://doi.org/10.1101/770271 171 | 172 | > SLEAP: Multi-animal pose tracking; 173 | Talmo D. Pereira, Nathaniel Tabris, Junyu Li, Shruthi Ravindranath, Eleni S. Papadoyannis, Z. Yan Wang, David M. Turner, Grace McKenzie-Smith, Sarah D. Kocher, Annegret L. Falkner, Joshua W. 
Shaevitz, Mala Murthy 174 | bioRxiv 2020.08.31.276246; doi: https://doi.org/10.1101/2020.08.31.276246 175 | 176 | >Real-time, low-latency closed-loop feedback using markerless posture tracking; 177 | Gary A Kane, Gonçalo Lopes, Jonny L Saunders, Alexander Mathis, Mackenzie W Mathis; 178 | eLife 2020;9:e61909 doi: 10.7554/eLife.61909 179 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | # __init__.py 2 | -------------------------------------------------------------------------------- /_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-slate -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | """ 2 | DeepLabStream 3 | © J.Schweihoff, M. Loshakov 4 | University Bonn Medical Faculty, Germany 5 | https://github.com/SchwarzNeuroconLab/DeepLabStream 6 | Licensed under GNU General Public License v3.0 7 | """ 8 | 9 | import sys 10 | import os 11 | import cv2 12 | 13 | from DeepLabStream import DeepLabStream, show_stream 14 | from utils.generic import MissingFrameError 15 | from utils.configloader import MULTI_CAM, STREAMS, RECORD_EXP 16 | from utils.gui_image import QFrame, ImageWindow, emit_qframes 17 | 18 | from PySide2.QtCore import QThread 19 | from PySide2.QtWidgets import QPushButton, QApplication, QWidget, QGridLayout 20 | from PySide2.QtGui import QIcon 21 | 22 | # creating a complete thread process to work in the background 23 | class AThread(QThread): 24 | """ 25 | QThread is just one of the many PyQt ways to do multitasking 26 | This is, for most intents and purposes, identical to Python multithreading 27 | """ 28 | 29 | def start(self, **kwargs): 30 | """ 31 | Setting thread to active, creating a QFrames dictionary 32 | Then just starting parent 
method 33 | """ 34 | self.threadactive = True 35 | self.qframes = {} 36 | # changePixmap = pyqtSignal(QImage) 37 | for camera in stream_manager.enabled_cameras: 38 | self.qframes[camera] = QFrame() 39 | super().start(**kwargs) 40 | 41 | def run(self): 42 | """ 43 | Infinite loop with all the streaming, analysis and recording logic 44 | """ 45 | while self.threadactive: 46 | try: 47 | all_frames = stream_manager.get_frames() 48 | except MissingFrameError as e: 49 | """catch missing frame, stop Thread and save what can be saved""" 50 | print(*e.args, "\nShutting down DLStream and saving data...") 51 | stream_manager.finish_streaming() 52 | stream_manager.stop_cameras() 53 | self.stop() 54 | break 55 | 56 | color_frames, depth_maps, infra_frames = all_frames 57 | 58 | # writing the video 59 | if stream_manager.recording_status(): 60 | stream_manager.write_video(color_frames, stream_manager.frame_index) 61 | 62 | if stream_manager.dlc_status(): 63 | # outputting the frames 64 | res_frames, res_time = stream_manager.get_analysed_frames() 65 | # inputting the frames 66 | stream_manager.input_frames_for_analysis( 67 | all_frames, stream_manager.frame_index 68 | ) 69 | # streaming the stream 70 | if res_frames: 71 | self._stream_frames(res_frames) 72 | else: 73 | self._stream_frames(color_frames) 74 | 75 | stream_manager.frame_index += 1 76 | 77 | def stop(self): 78 | """ 79 | Setting thread to active, thus stopping the infinite loop 80 | """ 81 | self.threadactive = False 82 | 83 | def _stream_frames(self, frames): 84 | """ 85 | Shows some number of stream frames, depending on cameras quantity 86 | Method of streaming depends on platform 87 | Windows -> through openCV with their window objects 88 | Unix -> thought PyQt with some widget window 89 | :param frames: dictionary of frames in format of {camera:frame} 90 | """ 91 | if os.name == "nt": 92 | show_stream(frames) 93 | # very important line for openCV to work correctly 94 | # actually does nothing, but do NOT delete 
95 | cv2.waitKey(1) 96 | else: 97 | emit_qframes(frames, self.qframes) 98 | 99 | 100 | class ButtonWindow(QWidget): 101 | def __init__(self): 102 | super().__init__() 103 | # setting the icon for window 104 | self.setWindowIcon(QIcon("misc/DLStream_Logo_small.png")) 105 | self.setWindowTitle("DeepLabStream") 106 | self.title = "ButtonWindow" 107 | # next is the complete buttons dictionary with buttons, icons, functions and layouts 108 | self._buttons_dict = { 109 | "Start_Stream": { 110 | "Button": QPushButton("Start Stream"), 111 | "Icon": QIcon("misc/StartStream2.png"), 112 | "Function": self.start_stream, 113 | "Layout": (0, 0, 2, 2), 114 | "State": True, 115 | }, 116 | "Start_Analysis": { 117 | "Button": QPushButton("Start Analysis"), 118 | "Icon": QIcon("misc/StartAnalysis2.png"), 119 | "Function": self.start_analysis, 120 | "Layout": (2, 0, 2, 1), 121 | "State": False, 122 | }, 123 | "Start_Experiment": { 124 | "Button": QPushButton("Start Experiment"), 125 | "Icon": QIcon("misc/StartExperiment2.png"), 126 | "Function": self.start_experiment, 127 | "Layout": (4, 0, 2, 1), 128 | "State": False, 129 | }, 130 | "Start_Recording": { 131 | "Button": QPushButton("Start Recording"), 132 | "Icon": QIcon("misc/StartRecording2.png"), 133 | "Function": self.start_recording, 134 | "Layout": (6, 0, 2, 1), 135 | "State": False, 136 | }, 137 | "Stop_Stream": { 138 | "Button": QPushButton("Stop Stream"), 139 | "Icon": QIcon("misc/StopStream2.png"), 140 | "Function": self.stop_stream, 141 | "Layout": (8, 0, 2, 2), 142 | "State": False, 143 | }, 144 | "Stop_Analysis": { 145 | "Button": QPushButton("Stop Analysis"), 146 | "Icon": QIcon("misc/StopAnalysis2.png"), 147 | "Function": self.stop_analysis, 148 | "Layout": (2, 1, 2, 1), 149 | "State": False, 150 | }, 151 | "Stop_Experiment": { 152 | "Button": QPushButton("Stop Experiment"), 153 | "Icon": QIcon("misc/StopExperiment2.png"), 154 | "Function": self.stop_experiment, 155 | "Layout": (4, 1, 2, 1), 156 | "State": False, 157 | 
}, 158 | "Stop_Recording": { 159 | "Button": QPushButton("Stop Recording"), 160 | "Icon": QIcon("misc/StopRecording2.png"), 161 | "Function": self.stop_recording, 162 | "Layout": (6, 1, 2, 1), 163 | "State": False, 164 | }, 165 | } 166 | 167 | # creating button layout with icons and functionality 168 | self.initialize_buttons() 169 | self._thread = None 170 | self.image_windows = {} 171 | 172 | def start(self): 173 | self._thread.start() 174 | 175 | def stop(self): 176 | self._thread.stop() 177 | 178 | def initialize_buttons(self): 179 | """ 180 | Function to make button window great again 181 | Sets all buttons with an icon, function and position 182 | """ 183 | layout = QGridLayout() 184 | for func in self._buttons_dict: 185 | # setting icon 186 | self._buttons_dict[func]["Button"].setIcon(self._buttons_dict[func]["Icon"]) 187 | # setting function 188 | self._buttons_dict[func]["Button"].clicked.connect( 189 | self._buttons_dict[func]["Function"] 190 | ) 191 | # setting position 192 | layout.addWidget( 193 | self._buttons_dict[func]["Button"], *self._buttons_dict[func]["Layout"] 194 | ) 195 | # setting default state 196 | self._buttons_dict[func]["Button"].setEnabled( 197 | self._buttons_dict[func]["State"] 198 | ) 199 | # setting button size 200 | self._buttons_dict[func]["Button"].setMinimumHeight(100) 201 | # setting window layout for all buttons 202 | self.setLayout(layout) 203 | 204 | def buttons_toggle(self, *buttons): 205 | for button in buttons: 206 | self._buttons_dict[button]["Button"].setEnabled( 207 | not self._buttons_dict[button]["Button"].isEnabled() 208 | ) 209 | 210 | """ Button functions""" 211 | 212 | def start_stream(self): 213 | # initializing the stream manager cameras 214 | stream_manager.start_cameras(STREAMS, MULTI_CAM) 215 | 216 | # initializing background thread 217 | self._thread = AThread(self) 218 | self._thread.start() 219 | print("Streaming started") 220 | 221 | # flipping the state of the buttons 222 | self.buttons_toggle( 223 | 
"Start_Analysis", "Start_Recording", "Start_Stream", "Stop_Stream" 224 | ) 225 | 226 | # initializing image windows for Unix systems via PyQt 227 | if os.name != "nt": 228 | for camera in stream_manager.enabled_cameras: 229 | self.image_windows[camera] = ImageWindow(camera) 230 | self._thread.qframes[camera].signal.connect( 231 | self.image_windows[camera].set_image 232 | ) 233 | self.image_windows[camera].show() 234 | else: 235 | # for Windows it is taken care by openCV 236 | pass 237 | 238 | def stop_stream(self): 239 | # stopping background thread 240 | self._thread.stop() 241 | 242 | # flipping the state of the buttons 243 | for func in self._buttons_dict: 244 | self._buttons_dict[func]["Button"].setEnabled( 245 | self._buttons_dict[func]["State"] 246 | ) 247 | 248 | if os.name != "nt": 249 | for camera in self.image_windows: 250 | self.image_windows[camera].hide() 251 | else: 252 | pass 253 | 254 | print("Streaming stopped") 255 | stream_manager.finish_streaming() 256 | stream_manager.stop_cameras() 257 | 258 | def start_analysis(self): 259 | print("Analysis starting") 260 | self.buttons_toggle("Stop_Analysis", "Start_Analysis", "Start_Experiment") 261 | stream_manager.set_up_multiprocessing() 262 | stream_manager.start_dlc() 263 | stream_manager.create_output() 264 | 265 | def stop_analysis(self): 266 | print("Analysis stopped") 267 | self.buttons_toggle("Stop_Analysis", "Start_Analysis", "Start_Experiment") 268 | stream_manager.stop_dlc() 269 | 270 | def start_experiment(self): 271 | print("Experiment started") 272 | self.buttons_toggle("Stop_Experiment", "Start_Experiment") 273 | stream_manager.set_up_experiment() 274 | stream_manager.start_experiment() 275 | if RECORD_EXP: 276 | self.start_recording() 277 | 278 | def stop_experiment(self): 279 | print("Experiment stopped") 280 | self.buttons_toggle("Stop_Experiment", "Start_Experiment") 281 | stream_manager.stop_experiment() 282 | if RECORD_EXP: 283 | self.stop_recording() 284 | 285 | def 
start_recording(self): 286 | print("Recording started") 287 | self.buttons_toggle("Stop_Recording", "Start_Recording") 288 | stream_manager.start_recording() 289 | 290 | def stop_recording(self): 291 | print("Recording stopped") 292 | self.buttons_toggle("Stop_Recording", "Start_Recording") 293 | stream_manager.stop_recording() 294 | 295 | 296 | if __name__ == "__main__": 297 | stream_manager = DeepLabStream() 298 | app = QApplication([]) 299 | bt = ButtonWindow() 300 | bt.show() 301 | sys.exit(app.exec_()) 302 | -------------------------------------------------------------------------------- /convert_classifier.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import joblib 3 | import os 4 | from pure_sklearn.map import convert_estimator 5 | 6 | 7 | def load_classifier_SIMBA(path_to_sav): 8 | """Load saved classifier""" 9 | file = open(path_to_sav, "rb") 10 | classifier = pickle.load(file) 11 | file.close() 12 | return classifier 13 | 14 | def load_classifier_BSOID(path_to_sav): 15 | """Load saved classifier""" 16 | file = open(path_to_sav, "rb") 17 | clf = joblib.load(file) 18 | file.close() 19 | return clf 20 | 21 | 22 | def convert_classifier(path, origin: str): 23 | # convert to pure python estimator 24 | dir_path = os.path.dirname(path) 25 | filename = os.path.basename(path) 26 | filename, _ = filename.split(".") 27 | 28 | print("Loading classifier...") 29 | if origin.lower() == 'simba': 30 | clf = load_classifier_SIMBA(path) 31 | clf_pure_predict = convert_estimator(clf) 32 | with open(dir_path + "/" + filename + "_pure.sav", "wb") as f: 33 | pickle.dump(clf_pure_predict, f) 34 | 35 | elif origin.lower() == 'bsoid': 36 | clf_pack = load_classifier_BSOID(path) 37 | # bsoid exported classfier has format [a, b, c, clf, d, e] 38 | clf_pure_predict = convert_estimator(clf_pack[3]) 39 | clf_pack[3] =clf_pure_predict 40 | with open(dir_path + "/" + filename + "_pure.sav", "wb") as f: 41 | 
joblib.dump(clf_pack, f) 42 | else: 43 | raise ValueError(f'{origin} is not a valid classifier origin.') 44 | 45 | print(f"Converted Classifier {filename}") 46 | 47 | 48 | if __name__ == "__main__": 49 | 50 | """Converted BSOID Classifiers are not integrated yet, although you can already convert them here""" 51 | path_to_classifier = "PATH_TO_CLASSIFIER" 52 | convert_classifier(path_to_classifier, origin= 'SIMBA') 53 | -------------------------------------------------------------------------------- /design_experiment.py: -------------------------------------------------------------------------------- 1 | from experiments.utils.exp_setup import DlStreamConfigWriter 2 | import click 3 | 4 | 5 | @click.command() 6 | @click.option("--default", "default", is_flag=True) 7 | def design_experiment(default): 8 | config = DlStreamConfigWriter() 9 | input_dict = dict(EXPERIMENT=None, TRIGGER=None, PROCESS=None, STIMULATION=None) 10 | 11 | def get_input(input_name): 12 | click.echo( 13 | f"Choosing {input_name}... \n Available {input_name}s are: " 14 | + ", ".join(config.get_available_module_names(input_name)) 15 | ) 16 | input_value = click.prompt(f"Enter a base {input_name} module", type=str) 17 | while not config.check_if_default_exists(input_value, input_name): 18 | click.echo(f"{input_name} {input_value} does not exists.") 19 | input_value = click.prompt(f"Enter a base {input_name} module", type=str) 20 | return input_value 21 | 22 | """Experiment""" 23 | 24 | input_value = get_input("EXPERIMENT") 25 | input_dict["EXPERIMENT"] = input_value 26 | if input_value == "BaseOptogeneticExperiment": 27 | input_dict["STIMULATION"] = "BaseStimulation" 28 | input_dict["PROCESS"] = "BaseProtocolProcess" 29 | 30 | elif input_value == "BaseTrialExperiment": 31 | click.echo( 32 | f"Available Triggers are: " 33 | + ", ".join(config.get_available_module_names("TRIGGER")) 34 | ) 35 | click.echo( 36 | "Note, that you cannot select the same Trigger as selected in TRIGGER." 
37 | ) 38 | 39 | input_value = click.prompt( 40 | f"Enter TRIAL_TRIGGER for BaseTrialExperiment", type=str 41 | ) 42 | while not config.check_if_default_exists(input_value, "TRIGGER"): 43 | click.echo(f"TRIGGER {input_value} does not exists.") 44 | input_value = click.prompt(f"Enter a base TRIGGER module", type=str) 45 | 46 | input_dict["TRIAL_TRIGGER"] = input_value 47 | click.echo(f"TRIAL_TRIGGER for BaseTrialExperiment set to {input_value}.") 48 | 49 | """TRIGGER""" 50 | 51 | input_value = get_input("TRIGGER") 52 | input_dict["TRIGGER"] = input_value 53 | 54 | """Process""" 55 | 56 | if input_dict["PROCESS"] is None: 57 | input_value = get_input("PROCESS") 58 | input_dict["PROCESS"] = input_value 59 | 60 | """STIMULATION""" 61 | if input_dict["STIMULATION"] is None: 62 | input_value = get_input("STIMULATION") 63 | input_dict["STIMULATION"] = input_value 64 | 65 | """Setting Process Type""" 66 | 67 | if input_dict["EXPERIMENT"] == "BaseTrialExperiment": 68 | input_dict["PROCESS_TYPE"] = "trial" 69 | elif input_dict["STIMULATION"] == "BaseStimulation": 70 | input_dict["PROCESS_TYPE"] = "switch" 71 | elif ( 72 | input_dict["STIMULATION"] == "ScreenStimulation" 73 | or input_dict["STIMULATION"] == "RewardDispenser" 74 | ): 75 | input_dict["PROCESS_TYPE"] = "supply" 76 | 77 | if input_dict["EXPERIMENT"] == "BaseTrialExperiment": 78 | config.import_default( 79 | experiment_name=input_dict["EXPERIMENT"], 80 | trigger_name=input_dict["TRIGGER"], 81 | process_name=input_dict["PROCESS"], 82 | stimulation_name=input_dict["STIMULATION"], 83 | trial_trigger_name=input_dict["TRIAL_TRIGGER"], 84 | ) 85 | 86 | else: 87 | config.import_default( 88 | experiment_name=input_dict["EXPERIMENT"], 89 | trigger_name=input_dict["TRIGGER"], 90 | process_name=input_dict["PROCESS"], 91 | stimulation_name=input_dict["STIMULATION"], 92 | ) 93 | 94 | if "PROCESS_TYPE" in input_dict.keys(): 95 | config._change_parameter( 96 | module_name=input_dict["PROCESS"], 97 | parameter_name="TYPE", 98 | 
parameter_value=input_dict["PROCESS_TYPE"], 99 | ) 100 | 101 | if click.confirm( 102 | "Do you want to set parameters as well (Not recommended)? \n Note, that you can change them in the created file later." 103 | ): 104 | current_config = config.get_current_config() 105 | ignore_list = ["EXPERIMENT", "BaseProtocolProcess"] 106 | inner_ignore_list = [ 107 | "EXPERIMENTER", 108 | "PROCESS", 109 | "STIMULATION", 110 | "TRIGGER", 111 | "DEBUG", 112 | ] 113 | try: 114 | for module in current_config.keys(): 115 | parameter_dict = config.get_parameters(module) 116 | if module not in ignore_list: 117 | for input_key in parameter_dict.keys(): 118 | if input_key not in inner_ignore_list: 119 | click.echo( 120 | f"Default {input_key} is: " 121 | + str(parameter_dict[input_key]) 122 | ) 123 | input_value = click.prompt(f"Enter new value: ", type=str) 124 | config._change_parameter( 125 | module_name=module, 126 | parameter_name=input_key, 127 | parameter_value=input_value, 128 | ) 129 | except: 130 | click.echo( 131 | "Failed to set individual parameters. Please change them later in the config file..." 132 | ) 133 | else: 134 | click.echo( 135 | "Skipping parameters. Experiment config will be created with default values..." 136 | ) 137 | """Finish up""" 138 | # Name of experimentor 139 | experimenter = click.prompt("Enter an experimenter name", type=str) 140 | click.echo(f"Experimenter set to {experimenter}.") 141 | config.set_experimenter(experimenter) 142 | 143 | click.echo( 144 | "Current modules are:\n BaseExperiment: {}\n Trigger: {}\n Process: {} \n Stimulation: {}".format( 145 | input_dict["EXPERIMENT"], 146 | input_dict["TRIGGER"], 147 | input_dict["PROCESS"], 148 | input_dict["STIMULATION"], 149 | ) 150 | ) 151 | 152 | if click.confirm("Do you want to continue?"): 153 | config.write_ini() 154 | click.echo("Config was created. 
It can be found in experiments/configs") 155 | 156 | 157 | if __name__ == "__main__": 158 | design_experiment() 159 | -------------------------------------------------------------------------------- /docs/DLSSTim_example.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SchwarzNeuroconLab/DeepLabStream/5736b2b3ecb16c47e7ba48121717813bddacc020/docs/DLSSTim_example.gif -------------------------------------------------------------------------------- /docs/GraphAbstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SchwarzNeuroconLab/DeepLabStream/5736b2b3ecb16c47e7ba48121717813bddacc020/docs/GraphAbstract.png -------------------------------------------------------------------------------- /docs/design_experiment_gif.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SchwarzNeuroconLab/DeepLabStream/5736b2b3ecb16c47e7ba48121717813bddacc020/docs/design_experiment_gif.gif -------------------------------------------------------------------------------- /docs/flowchart2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SchwarzNeuroconLab/DeepLabStream/5736b2b3ecb16c47e7ba48121717813bddacc020/docs/flowchart2.png -------------------------------------------------------------------------------- /experiments/__init__.py: -------------------------------------------------------------------------------- 1 | # __init__.py 2 | -------------------------------------------------------------------------------- /experiments/base/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SchwarzNeuroconLab/DeepLabStream/5736b2b3ecb16c47e7ba48121717813bddacc020/experiments/base/__init__.py 
-------------------------------------------------------------------------------- /experiments/base/experiments.py: -------------------------------------------------------------------------------- 1 | """ 2 | DeepLabStream 3 | © J.Schweihoff, M. Loshakov 4 | University Bonn Medical Faculty, Germany 5 | https://github.com/SchwarzNeuroconLab/DeepLabStream 6 | Licensed under GNU General Public License v3.0 7 | """ 8 | 9 | import time 10 | 11 | from experiments.base.stimulation import BaseStimulation 12 | from experiments.base.stimulus_process import Timer 13 | from experiments.utils.exp_setup import ( 14 | get_experiment_settings, 15 | setup_trigger, 16 | setup_process, 17 | ) 18 | from utils.plotter import plot_triggers_response 19 | import random 20 | 21 | 22 | class BaseExperiment: 23 | """ 24 | Base class for standard experiments""" 25 | 26 | def __init__(self): 27 | self._name = "BaseExperiment" 28 | self._settings_dict = {} 29 | self.experiment_finished = False 30 | self._process = None 31 | self._event = None 32 | self._current_trial = None 33 | self._exp_timer = Timer(10) 34 | 35 | def check_skeleton(self, frame, skeleton): 36 | """ 37 | Checking each passed animal skeleton for a pre-defined set of conditions 38 | Outputting the visual representation, if exist 39 | Advancing trials according to inherent logic of an experiment 40 | :param frame: frame, on which animal skeleton was found 41 | :param skeleton: skeleton, consisting of multiple joints of an animal 42 | """ 43 | pass 44 | 45 | def check_exp_timer(self): 46 | """ 47 | Checking the experiment timer 48 | """ 49 | if not self._exp_timer.check_timer(): 50 | print("Experiment is finished") 51 | print("Time ran out.") 52 | self.stop_experiment() 53 | 54 | def start_experiment(self): 55 | """ 56 | Start the experiment 57 | """ 58 | if self._process is not None: 59 | self._process.start() 60 | if not self.experiment_finished: 61 | self._exp_timer.start() 62 | 63 | def stop_experiment(self): 64 | """ 65 | Stop 
the experiment and reset the timer 66 | """ 67 | self.experiment_finished = True 68 | print("Experiment completed!") 69 | self._exp_timer.reset() 70 | # don't forget to end the process! 71 | if self._process is not None: 72 | self._process.end() 73 | 74 | def get_settings(self): 75 | 76 | return self._settings_dict 77 | 78 | def get_name(self): 79 | 80 | return self._name 81 | 82 | 83 | class BaseTrialExperiment(BaseExperiment): 84 | def __init__(self): 85 | super().__init__() 86 | self._name = "BaseTrialExperiment" 87 | self.experiment_finished = False 88 | self._event = None 89 | self._print_check = False 90 | self._current_trial = None 91 | self._result_list = [] 92 | self._success_count = 0 93 | 94 | self._parameter_dict = dict( 95 | TRIGGER="str", 96 | PROCESS="str", 97 | INTERTRIAL_TIME="int", 98 | TRIAL_NAME="str", 99 | TRIAL_TRIGGER="str", 100 | TRIAL_TIME="int", 101 | STIMULUS_TIME="int", 102 | RESULT_FUNC="str", 103 | EXP_LENGTH="int", 104 | EXP_COMPLETION="int", 105 | EXP_TIME="int", 106 | ) 107 | 108 | self._settings_dict = get_experiment_settings(self._name, self._parameter_dict) 109 | self._process = setup_process(self._settings_dict["PROCESS"]) 110 | self._init_trigger = setup_trigger(self._settings_dict["TRIGGER"]) 111 | self._trials_list = self.generate_trials_list( 112 | self._trials, self._settings_dict["EXP_LENGTH"] 113 | ) 114 | self._trial_timer = Timer(self._settings_dict["TRIAL_TIME"]) 115 | self._exp_timer = Timer(self._settings_dict["EXP_TIME"]) 116 | self._intertrial_timer = Timer(self._settings_dict["INTERTRIAL_TIME"]) 117 | 118 | def check_skeleton(self, frame, skeleton): 119 | status, trial = self._process.get_status() 120 | if status: 121 | current_trial = self._trials[trial] 122 | condition, response = current_trial["trigger"].check_skeleton(skeleton) 123 | self._process.put(condition) 124 | result = self._process.get_result() 125 | if result is not None: 126 | self.process_result(result, trial) 127 | self._current_trial = None 128 | 
# check if all trials were successful until completion 129 | if self._success_count >= self._settings_dict["EXP_COMPLETION"]: 130 | print("Experiment is finished") 131 | print("Trial reached required amount of successes") 132 | self.stop_experiment() 133 | 134 | # if not continue 135 | print(" Going into Intertrial time.") 136 | self._intertrial_timer.reset() 137 | self._intertrial_timer.start() 138 | result = None 139 | plot_triggers_response(frame, response) 140 | 141 | elif not self._intertrial_timer.check_timer(): 142 | if self._current_trial is None: 143 | self._current_trial = next(self._trials_list, False) 144 | elif not self._current_trial: 145 | print("Experiment is finished due to max. trial repetition.") 146 | print(self._result_list) 147 | self.stop_experiment() 148 | else: 149 | init_result, response_body = self._init_trigger.check_skeleton(skeleton) 150 | if init_result: 151 | # check trial start triggers 152 | self._process.put_trial( 153 | self._trials[self._current_trial], self._current_trial 154 | ) 155 | self._print_check = False 156 | elif not self._print_check: 157 | print( 158 | "Next trial: #" 159 | + str(len(self._result_list) + 1) 160 | + " " 161 | + self._current_trial 162 | ) 163 | print( 164 | "Animal is not meeting trial start criteria, the start of trial is delayed." 
165 | ) 166 | self._print_check = True 167 | # if experimental time ran out, finish experiments 168 | super().check_exp_timer() 169 | 170 | def process_result(self, result, trial): 171 | """ 172 | Will add result if TRUE or reset comp_counter if FALSE 173 | :param result: bool if trial was successful 174 | :param trial: str name of the trial 175 | :return: 176 | """ 177 | self._result_list.append((trial, result)) 178 | if result is True: 179 | self._success_count += 1 180 | print("Trial successful!") 181 | else: 182 | print("Trial failed.") 183 | # 184 | 185 | @staticmethod 186 | def generate_trials_list(trials: dict, length: int): 187 | trials_list = [] 188 | for trial in range(length): 189 | trials_list.append(random.choice(list(trials.keys()))) 190 | return iter(trials_list) 191 | 192 | @property 193 | def _trials(self): 194 | 195 | trigger = setup_trigger(self._settings_dict["TRIAL_TRIGGER"]) 196 | if self._settings_dict["RESULT_FUNC"] == "all": 197 | result_func = all 198 | elif self._settings_dict["RESULT_FUNC"] == "any": 199 | result_func = any 200 | else: 201 | raise ValueError( 202 | f'Result function can only be "all" or "any", not {self._settings_dict["RESULT_FUNC"]}.' 
203 | ) 204 | trials = { 205 | self._settings_dict["TRIAL_NAME"]: dict( 206 | stimulus_timer=Timer(self._settings_dict["STIMULUS_TIME"]), 207 | success_timer=Timer(self._settings_dict["TRIAL_TIME"]), 208 | trigger=trigger, 209 | result_func=result_func, 210 | ) 211 | } 212 | 213 | return trials 214 | 215 | 216 | class BaseConditionalExperiment(BaseExperiment): 217 | """ 218 | Simple class to contain all of the experiment properties 219 | Uses multiprocess to ensure the best possible performance and 220 | to showcase that it is possible to work with any type of equipment, even timer-dependent 221 | """ 222 | 223 | def __init__(self): 224 | super().__init__() 225 | self._name = "BaseConditionalExperiment" 226 | self._parameter_dict = dict( 227 | TRIGGER="str", 228 | PROCESS="str", 229 | INTERTRIAL_TIME="int", 230 | EXP_LENGTH="int", 231 | EXP_TIME="int", 232 | ) 233 | self._settings_dict = get_experiment_settings(self._name, self._parameter_dict) 234 | self.experiment_finished = False 235 | self._process = setup_process(self._settings_dict["PROCESS"]) 236 | self._event = None 237 | self._event_count = 0 238 | self._current_trial = None 239 | 240 | self._exp_timer = Timer(self._settings_dict["EXP_TIME"]) 241 | self._intertrial_timer = Timer(self._settings_dict["INTERTRIAL_TIME"]) 242 | 243 | self._trigger = setup_trigger(self._settings_dict["TRIGGER"]) 244 | 245 | def check_skeleton(self, frame, skeleton): 246 | """ 247 | Checking each passed animal skeleton for a pre-defined set of conditions 248 | Outputting the visual representation, if exist 249 | Advancing trials according to inherent logic of an experiment 250 | :param frame: frame, on which animal skeleton was found 251 | :param skeleton: skeleton, consisting of multiple joints of an animal 252 | """ 253 | self.check_exp_timer() # checking if experiment is still on 254 | 255 | if self._event_count >= self._settings_dict["EXP_LENGTH"]: 256 | self.stop_experiment() 257 | 258 | elif not self.experiment_finished: 
259 | if not self._intertrial_timer.check_timer(): 260 | # check if condition is met 261 | result, response = self._trigger.check_skeleton(skeleton=skeleton) 262 | #set event to result to export to dataframe 263 | self._event = result 264 | if result: 265 | self._event_count += 1 266 | print(f"Stimulation #{self._event_count}") 267 | self._intertrial_timer.reset() 268 | self._intertrial_timer.start() 269 | 270 | plot_triggers_response(frame, response) 271 | self._process.put(result) 272 | 273 | def check_exp_timer(self): 274 | """ 275 | Checking the experiment timer 276 | """ 277 | if not self._exp_timer.check_timer(): 278 | print("Time ran out.") 279 | self.stop_experiment() 280 | 281 | def start_experiment(self): 282 | """ 283 | Start the experiment 284 | """ 285 | self._process.start() 286 | if not self.experiment_finished: 287 | self._exp_timer.start() 288 | 289 | def stop_experiment(self): 290 | """ 291 | Stop the experiment and reset the timer 292 | """ 293 | self.experiment_finished = True 294 | print("Experiment completed!") 295 | self._exp_timer.reset() 296 | # don't forget to end the process! 
297 | self._process.end() 298 | 299 | def get_trial(self): 300 | """ 301 | Check if event is going on right now 302 | return: bool 303 | """ 304 | return self._event 305 | 306 | 307 | """Standardexperiments that can be setup by using the experiment config""" 308 | 309 | 310 | class BaseOptogeneticExperiment(BaseExperiment): 311 | """Standard implementation of an optogenetic experiment""" 312 | 313 | def __init__(self): 314 | super().__init__() 315 | self.experiment_finished = False 316 | self._name = "BaseOptogeneticExperiment" 317 | 318 | # loading settings 319 | self._exp_parameter_dict = dict( 320 | TRIGGER="str", 321 | INTERTRIAL_TIME="int", 322 | MAX_STIM_TIME="int", 323 | MIN_STIM_TIME="int", 324 | MAX_TOTAL_STIM_TIME="int", 325 | EXP_TIME="int", 326 | PROCESS="str", 327 | ) 328 | self._settings_dict = get_experiment_settings( 329 | self._name, self._exp_parameter_dict 330 | ) 331 | self._process = setup_process(self._settings_dict["PROCESS"]) 332 | self._intertrial_timer = Timer(self._settings_dict["INTERTRIAL_TIME"]) 333 | self._exp_timer = Timer(self._settings_dict["EXP_TIME"]) 334 | self._event = False 335 | self._event_start = None 336 | 337 | # setting limits 338 | self._max_trial_time = self._settings_dict["MAX_STIM_TIME"] 339 | self._min_trial_time = self._settings_dict["MIN_STIM_TIME"] 340 | self._max_total_time = ( 341 | self._settings_dict["MAX_TOTAL_STIM_TIME"] 342 | if self._settings_dict["MAX_TOTAL_STIM_TIME"] is not None 343 | else self._settings_dict["EXP_TIME"] + 1 344 | ) 345 | 346 | # keeping count 347 | self._results = [] 348 | self._total_time = 0 349 | self._trial_time = 0 350 | # trigger 351 | self._trigger = setup_trigger(self._settings_dict["TRIGGER"]) 352 | 353 | def check_skeleton(self, frame, skeleton): 354 | 355 | if self._exp_timer.check_timer(): 356 | if self._total_time >= self._max_total_time: 357 | # check if total time to stimulate per experiment is reached 358 | print("Ending experiment, total event time ran out") 359 | 
self.stop_experiment() 360 | else: 361 | # if not continue 362 | if not self._intertrial_timer.check_timer(): 363 | # check if there is an intertrial time running right now, if not continue 364 | # check if the trigger is true 365 | result, _ = self._trigger.check_skeleton(skeleton) 366 | if result: 367 | if not self._event: 368 | # if a stimulation event wasn't started already, start one 369 | print("Starting Stimulation") 370 | self._event = True 371 | # and activate the laser, start the timer and reset the intertrial timer 372 | self._event_start = time.time() 373 | self._intertrial_timer.reset() 374 | else: 375 | if time.time() - self._event_start <= self._max_trial_time: 376 | # if the total event time has not reached the maximum time per event 377 | pass 378 | else: 379 | # if the maximum event time was reached, reset the event, 380 | # turn off the laser and start inter-trial time 381 | print("Ending Stimulation, Stimulation time ran out") 382 | self._event = False 383 | trial_time = time.time() - self._event_start 384 | self._total_time += trial_time 385 | self._results.append(trial_time) 386 | print("Stimulation duration", trial_time) 387 | self._intertrial_timer.start() 388 | else: 389 | # if the trigger is false 390 | if self._event: 391 | # but the stimulation is still going 392 | if time.time() - self._event_start < self._min_trial_time: 393 | # check if the minimum event time was not reached, then pass 394 | pass 395 | else: 396 | # if minumum event time has been reached, reset the event, 397 | # turn of the laser and start intertrial time 398 | print("Ending Stimulation, Trigger is False") 399 | self._event = False 400 | trial_time = time.time() - self._event_start 401 | self._total_time += trial_time 402 | self._results.append(trial_time) 403 | print("Stimulation duration", trial_time) 404 | self._intertrial_timer.start() 405 | self._process.put(self._event) 406 | 407 | else: 408 | # if maximum experiment time was reached, stop experiment 409 | 
print("Ending experiment, timer ran out") 410 | self.stop_experiment() 411 | 412 | def start_experiment(self): 413 | self._exp_timer.start() 414 | 415 | def stop_experiment(self): 416 | self.experiment_finished = True 417 | print("Experiment completed!") 418 | print("Total event duration", sum(self._results)) 419 | print(self._results) 420 | 421 | def get_trial(self): 422 | return self._event 423 | -------------------------------------------------------------------------------- /experiments/base/stimulation.py: -------------------------------------------------------------------------------- 1 | """ 2 | DeepLabStream 3 | © J.Schweihoff, M. Loshakov 4 | University Bonn Medical Faculty, Germany 5 | https://github.com/SchwarzNeuroconLab/DeepLabStream 6 | Licensed under GNU General Public License v3.0 7 | """ 8 | 9 | import time 10 | import cv2 11 | import numpy as np 12 | 13 | from experiments.utils.exp_setup import get_stimulation_settings 14 | 15 | 16 | class BaseStimulation: 17 | def __init__(self): 18 | self._name = "BaseStimulation" 19 | self._parameter_dict = dict(TYPE="str", PORT="str", IP="str", STIM_TIME="float") 20 | self._settings_dict = get_stimulation_settings(self._name, self._parameter_dict) 21 | self._running = False 22 | self._stim_device = self._setup_device( 23 | self._settings_dict["TYPE"], 24 | self._settings_dict["PORT"], 25 | self._settings_dict["IP"], 26 | ) 27 | 28 | @staticmethod 29 | def _setup_device(type, port, ip): 30 | device = None 31 | if type == "NI": 32 | from experiments.utils.DAQ_output import DigitalModDevice 33 | 34 | device = DigitalModDevice(port) 35 | 36 | if type == "RASPBERRY": 37 | from experiments.utils.gpio_control import DigitalPiDevice 38 | 39 | device = DigitalPiDevice(port) 40 | 41 | if type == "RASP_NETWORK": 42 | from experiments.utils.gpio_control import DigitalPiDevice 43 | 44 | if ip is not None: 45 | device = DigitalPiDevice(port, ip) 46 | else: 47 | raise ValueError("IP required for remote GPIO control.") 48 | 
49 | if type == "ARDUINO": 50 | from experiments.utils.gpio_control import DigitalArduinoDevice 51 | 52 | device = DigitalArduinoDevice(port) 53 | 54 | return device 55 | 56 | def stimulate(self): 57 | """Run stimulation and stop after being done""" 58 | if self._settings_dict["STIM_TIME"] is not None: 59 | print( 60 | "Stimulation: {} for {}.".format( 61 | self._name, self._settings_dict["STIM_TIME"] 62 | ) 63 | ) 64 | self._stim_device.turn_on() 65 | self._running = True 66 | time.sleep(self._settings_dict["STIM_TIME"]) 67 | self._stim_device.turn_off() 68 | self._running = False 69 | else: 70 | print("Stimulation: {} does not support stimulate().".format(self._name)) 71 | 72 | def remove(self): 73 | """remove stimulation (e.g. reward) and stop after being done""" 74 | print("Stimulation: {} does not support remove().".format(self._name)) 75 | 76 | def start(self): 77 | 78 | if not self._running: 79 | print("Stimulation: {} ON.".format(self._name)) 80 | self._stim_device.turn_on() 81 | self._running = True 82 | else: 83 | print("Stimulation was already ON.") 84 | 85 | def stop(self): 86 | if self._running: 87 | print("Stimulation: {} OFF.".format(self._name)) 88 | self._stim_device.turn_off() 89 | self._running = False 90 | else: 91 | print("Stimulation was already OFF.") 92 | 93 | 94 | class RewardDispenser(BaseStimulation): 95 | def __init__(self): 96 | self._name = "RewardDispenser" 97 | self._parameter_dict = dict( 98 | TYPE="str", 99 | IP="str", 100 | STIM_PORT="str", 101 | REMOVAL_PORT="str", 102 | STIM_TIME="float", 103 | REMOVAL_TIME="float", 104 | ) 105 | self._settings_dict = get_stimulation_settings(self._name, self._parameter_dict) 106 | self._running = False 107 | self._stim_device = self._setup_device( 108 | self._settings_dict["TYPE"], 109 | self._settings_dict["STIM_PORT"], 110 | self._settings_dict["IP"], 111 | ) 112 | self._removal_device = self._setup_device( 113 | self._settings_dict["TYPE"], 114 | self._settings_dict["REMOVAL_PORT"], 115 | 
self._settings_dict["IP"], 116 | ) 117 | 118 | @staticmethod 119 | def _setup_device(type, port, ip): 120 | device = None 121 | if type == "NI": 122 | from experiments.utils.DAQ_output import DigitalModDevice 123 | 124 | device = DigitalModDevice(port) 125 | 126 | if type == "RASPBERRY": 127 | from experiments.utils.gpio_control import DigitialPiBoardDevice 128 | 129 | device = DigitialPiBoardDevice(port) 130 | 131 | if type == "RASP_NETWORK": 132 | from experiments.utils.gpio_control import DigitialPiBoardDevice 133 | 134 | if ip is not None: 135 | device = DigitialPiBoardDevice(port, ip) 136 | else: 137 | raise ValueError("IP required for remote GPIO control.") 138 | 139 | return device 140 | 141 | def stimulate(self): 142 | """Run stimulation and stop after being done""" 143 | print( 144 | "Stimulation: {} for {}.".format( 145 | self._name, self._settings_dict["STIM_TIME"] 146 | ) 147 | ) 148 | self._stim_device.turn_on() 149 | self._running = True 150 | time.sleep(self._settings_dict["STIM_TIME"]) 151 | self._stim_device.turn_off() 152 | self._running = False 153 | 154 | def remove(self): 155 | """remove stimulation (e.g. reward) and stop after being done""" 156 | print( 157 | "Stimulation: {} for {}.".format( 158 | self._name, self._settings_dict["REMOVAL_TIME"] 159 | ) 160 | ) 161 | self._removal_device.turn_on() 162 | self._running = True 163 | time.sleep(self._settings_dict["REMOVAL_TIME"]) 164 | self._removal_device.turn_off() 165 | self._running = False 166 | 167 | def start(self): 168 | print( 169 | "Stimulation: {} does not support start(). Did you mean stimulate()?".format( 170 | self._name 171 | ) 172 | ) 173 | 174 | def stop(self): 175 | print( 176 | "Stimulation: {} does not support stop(). 
Did you mean remove()?".format( 177 | self._name 178 | ) 179 | ) 180 | 181 | 182 | class ScreenStimulation(BaseStimulation): 183 | def __init__(self): 184 | self._name = "ScreenStimulation" 185 | self._parameter_dict = dict(TYPE="str", STIM_PATH="str", BACKGROUND_PATH="str") 186 | self._settings_dict = get_stimulation_settings(self._name, self._parameter_dict) 187 | self._running = False 188 | self._stim_device = None 189 | self.height, self.width = None, None 190 | self.framerate = None # only for video 191 | 192 | self._stimulus = self._setup_stimulus( 193 | self._settings_dict["STIM_PATH"], type=self._settings_dict["TYPE"] 194 | ) 195 | 196 | if self._settings_dict["BACKGROUND_PATH"] is not None: 197 | self._background = self._setup_stimulus(self._settings_dict["BACKGROUND_PATH"], type="image") 198 | else: 199 | self._background = np.zeros((self.height, self.width, 3), np.uint8) 200 | 201 | self._window = None 202 | 203 | def _setup_stimulus(self, path, type="image"): 204 | if type == "image": 205 | img = cv2.imread(path, -1) 206 | stimulus = np.uint8(img) 207 | if self.width is None or self.height is None: 208 | self.height, self.width = img.shape[:2] 209 | 210 | elif type == "video": 211 | stimulus = cv2.VideoCapture(path) 212 | if self.width is None or self.height is None: 213 | self.height, self.width = int(stimulus.get(cv2.CAP_PROP_FRAME_HEIGHT)), int( 214 | stimulus.get(cv2.CAP_PROP_FRAME_WIDTH)) 215 | # get framerate of video 216 | self.framerate = int(stimulus.get(cv2.CAP_PROP_FPS)) 217 | 218 | return stimulus 219 | 220 | def _setup_window(self): 221 | cv2.namedWindow(self._name, cv2.WINDOW_NORMAL) 222 | 223 | def stimulate(self): 224 | """Run stimulation and stop after being done""" 225 | if self._window is None: 226 | self._setup_window() 227 | if self._settings_dict["TYPE"] == "image": 228 | cv2.imshow(self._name, self._stimulus) 229 | cv2.waitKey(1) 230 | 231 | elif self._settings_dict["TYPE"] == "video": 232 | while self._stimulus.isOpened(): 233 | 
self._running = True 234 | ret, frame = self._stimulus.read() 235 | last_frame_time = time.time() 236 | if ret is True: 237 | cv2.imshow(self._name, frame) 238 | running_time = time.time() - last_frame_time 239 | # if the video is faster than the framerate, wait until the next frame should be shown 240 | if running_time <= 1 / self.framerate: 241 | sleepy_time = int(np.ceil(1000 / self.framerate - running_time / 1000)) 242 | cv2.waitKey(sleepy_time) 243 | else: 244 | break 245 | if cv2.waitKey(1) & 0xFF == ord("q"): 246 | # if user presses q, the video restarts from the beginning 247 | break 248 | self._running = False 249 | #reset video for next stimulus 250 | self._stimulus.set(cv2.CAP_PROP_POS_FRAMES, 0) 251 | 252 | def remove(self): 253 | """remove stimulation (e.g. reward) and stop after being done""" 254 | if self._window is None: 255 | self._setup_window() 256 | 257 | cv2.imshow(self._name, self._background) 258 | # add wait key. window waits until user presses a key (because cv2 crashes otherwise) 259 | cv2.waitKey(1) 260 | 261 | def start(self): 262 | print( 263 | "Stimulation: {} does not support start(). Did you mean stimulate()?".format( 264 | self._name 265 | ) 266 | ) 267 | 268 | def stop(self): 269 | print( 270 | "Stimulation: {} does not support stop(). Did you mean remove()?".format( 271 | self._name 272 | ) 273 | ) 274 | -------------------------------------------------------------------------------- /experiments/base/stimulus_process.py: -------------------------------------------------------------------------------- 1 | """ 2 | DeepLabStream 3 | © J.Schweihoff, M. 
Loshakov 4 | University Bonn Medical Faculty, Germany 5 | https://github.com/SchwarzNeuroconLab/DeepLabStream 6 | Licensed under GNU General Public License v3.0 7 | """ 8 | 9 | import time 10 | import multiprocessing as mp 11 | from experiments.utils.exp_setup import get_process_settings, setup_stimulation 12 | 13 | 14 | class Timer: 15 | """ 16 | Very simple timer 17 | """ 18 | 19 | def __init__(self, seconds): 20 | """ 21 | Setting the time the timer needs to run 22 | :param seconds: time in seconds 23 | """ 24 | self._seconds = seconds 25 | self._start_time = None 26 | 27 | def start(self): 28 | """ 29 | Starting the timer 30 | If already started does nothing 31 | """ 32 | if not self._start_time: 33 | self._start_time = time.time() 34 | 35 | def check_timer(self): 36 | """ 37 | Check if the time has run out or not 38 | Returns False if timer is not started 39 | Returns True if timer has run less then _seconds (still runs) 40 | """ 41 | if self._start_time: 42 | current_time = time.time() 43 | return current_time - self._start_time <= self._seconds 44 | else: 45 | return False 46 | 47 | def return_time(self): 48 | 49 | if self._start_time: 50 | current_time = time.time() 51 | return current_time - self._start_time 52 | else: 53 | pass 54 | 55 | def reset(self): 56 | """ 57 | Resets the timer 58 | """ 59 | self._start_time = None 60 | 61 | def get_start_time(self): 62 | """ 63 | Returns the start time of the timer 64 | """ 65 | return self._start_time 66 | 67 | 68 | def base_conditional_switch_protocol_run(condition_q: mp.Queue, stimulus_name): 69 | condition = False 70 | stimulation = setup_stimulation(stimulus_name) 71 | while True: 72 | if condition_q.full(): 73 | condition = condition_q.get() 74 | if condition: 75 | stimulation.start() 76 | else: 77 | stimulation.stop() 78 | 79 | 80 | def base_conditional_supply_protocol_run(condition_q: mp.Queue, stimulus_name): 81 | condition = False 82 | stimulation = setup_stimulation(stimulus_name) 83 | while True: 84 | 
if condition_q.full(): 85 | condition = condition_q.get() 86 | if condition: 87 | stimulation.stimulate() 88 | else: 89 | stimulation.remove() 90 | 91 | 92 | def base_trial_protocol_run( 93 | trial_q: mp.Queue, condition_q: mp.Queue, success_q: mp.Queue, stimulation_name 94 | ): 95 | """ 96 | The function to use in ProtocolProcess class 97 | Designed to be run continuously alongside the main loop 98 | Three parameters are three mp.Queue classes, each passes corresponding values 99 | :param trial_q: the protocol name (inwards); dict of trial from respective experiment 100 | :param success_q: the result of each protocol (outwards) 101 | :param condition_q: collects trigger results from trial trigger 102 | :param stimulus_name: exact name of stimulus function in base.stimulation.py 103 | """ 104 | current_trial = None 105 | stimulation = setup_stimulation(stimulation_name) 106 | # starting the main loop without any protocol running 107 | while True: 108 | if trial_q.empty() and current_trial is None: 109 | pass 110 | elif trial_q.full(): 111 | current_trial = trial_q.get() 112 | finished_trial = False 113 | # starting timers 114 | current_trial["stimulus_timer"].start() 115 | current_trial["success_timer"].start() 116 | print("Starting protocol {}".format(current_trial)) 117 | condition_list = [] 118 | # this branch is for already running protocol 119 | elif current_trial is not None: 120 | # checking for stimulus timer and outputting correct image 121 | if current_trial["stimulus_timer"].check_timer(): 122 | # if stimulus timer is running, show stimulus 123 | stimulation.start() 124 | else: 125 | # if the timer runs out, finish protocol and reset timer 126 | stimulation.stop() 127 | current_trial["stimulus_timer"].reset() 128 | current_trial = None 129 | 130 | # checking if any condition was passed 131 | if condition_q.full(): 132 | stimulus_condition = condition_q.get() 133 | # checking if timer for condition is running and condition=True 134 | if 
current_trial["success_timer"].check_timer(): 135 | condition_list.append(stimulus_condition) 136 | 137 | # checking if the timer for condition has run out 138 | if not current_trial["success_timer"].check_timer() and not finished_trial: 139 | # resetting the timer 140 | print("Timer for condition run out") 141 | finished_trial = True 142 | # outputting the result, whatever it is 143 | success = current_trial["result_func"](condition_list) 144 | success_q.put(success) 145 | current_trial["success_timer"].reset() 146 | 147 | 148 | class BaseProtocolProcess: 149 | """ 150 | Class to help work with protocol function in multiprocessing 151 | """ 152 | 153 | def __init__(self): 154 | """ 155 | Setting up the three queues and the process itself 156 | """ 157 | self._name = "BaseProtocolProcess" 158 | self._parameter_dict = dict(TYPE="str", STIMULATION="str") 159 | self._settings_dict = get_process_settings(self._name, self._parameter_dict) 160 | 161 | if self._settings_dict["TYPE"] == "trial": 162 | self._trial_queue = mp.Queue(1) 163 | self._success_queue = mp.Queue(1) 164 | self._condition_queue = mp.Queue(1) 165 | self._protocol_process = mp.Process( 166 | target=base_trial_protocol_run, 167 | args=( 168 | self._trial_queue, 169 | self._condition_queue, 170 | self._success_queue, 171 | self._settings_dict["STIMULATION"], 172 | ), 173 | ) 174 | elif self._settings_dict["TYPE"] == "switch": 175 | self._condition_queue = mp.Queue(1) 176 | self._protocol_process = mp.Process( 177 | target=base_conditional_switch_protocol_run, 178 | args=(self._condition_queue, self._settings_dict["STIMULATION"]), 179 | ) 180 | 181 | elif self._settings_dict["TYPE"] == "supply": 182 | self._condition_queue = mp.Queue(1) 183 | self._protocol_process = mp.Process( 184 | target=base_conditional_supply_protocol_run, 185 | args=(self._condition_queue, self._settings_dict["STIMULATION"]), 186 | ) 187 | 188 | self._running = False 189 | self._current_trial = None 190 | 191 | def start(self): 192 
| """ 193 | Starting the process 194 | """ 195 | self._protocol_process.start() 196 | 197 | def end(self): 198 | """ 199 | Ending the process 200 | """ 201 | if ( 202 | self._settings_dict["TYPE"] == "switch" 203 | or self._settings_dict["TYPE"] == "supply" 204 | ): 205 | self._condition_queue.close() 206 | elif self._settings_dict["TYPE"] == "trial": 207 | self._trial_queue.close() 208 | self._success_queue.close() 209 | 210 | self._protocol_process.terminate() 211 | 212 | def get_status(self): 213 | """ 214 | Getting current status of the running protocol 215 | """ 216 | return self._running, self._current_trial 217 | 218 | def put(self, input_p): 219 | """ 220 | Passing the trial name to the process 221 | """ 222 | 223 | if self._condition_queue.empty(): 224 | self._condition_queue.put(input_p) 225 | 226 | def put_trial(self, trial: dict, trial_name): 227 | """ 228 | Passing the condition to the process 229 | """ 230 | if self._settings_dict["TYPE"] == "trial": 231 | if self._trial_queue.empty() and self._success_queue.empty(): 232 | self._trial_queue.put(trial) 233 | self._running = True 234 | self._current_trial = trial_name 235 | 236 | def get_result(self) -> bool: 237 | """ 238 | Getting result from the process 239 | """ 240 | if self._settings_dict["TYPE"] == "trial": 241 | if self._success_queue.full(): 242 | self._running = False 243 | return self._success_queue.get() 244 | else: 245 | return None 246 | -------------------------------------------------------------------------------- /experiments/configs/BaseConditionalExperiment_example.ini: -------------------------------------------------------------------------------- 1 | [EXPERIMENT] 2 | BASE = BaseConditionalExperiment 3 | EXPERIMENTER = Example 4 | 5 | [BaseConditionalExperiment] 6 | TRIGGER = BaseRegionTrigger 7 | PROCESS = BaseProtocolProcess 8 | INTERTRIAL_TIME = 40 9 | EXP_LENGTH = 40 10 | EXP_TIME = 3600 11 | 12 | [BaseRegionTrigger] 13 | TYPE = circle 14 | CENTER = 550, 63 15 | RADIUS = 30 16 
| BODYPARTS = neck 17 | DEBUG = False 18 | 19 | [BaseProtocolProcess] 20 | TYPE = switch 21 | STIMULATION = BaseStimulation 22 | 23 | [BaseStimulation] 24 | TYPE = NI 25 | PORT = Dev1/PFI6 26 | STIM_TIME = 3.5 27 | 28 | -------------------------------------------------------------------------------- /experiments/configs/BaseOptogeneticExperiment_example.ini: -------------------------------------------------------------------------------- 1 | [EXPERIMENT] 2 | BASE = BaseOptogeneticExperiment 3 | EXPERIMENTER = Example 4 | 5 | [BaseOptogeneticExperiment] 6 | TRIGGER = BaseRegionTrigger 7 | INTERTRIAL_TIME = 40 8 | MAX_STIM_TIME = 10 9 | MIN_STIM_TIME = 3 10 | MAX_TOTAL_STIM_TIME = 500 11 | EXP_TIME = 3600 12 | PROCESS = BaseProtocolProcess 13 | 14 | [BaseRegionTrigger] 15 | TYPE = circle 16 | CENTER = 550, 63 17 | RADIUS = 30 18 | BODYPARTS = neck 19 | DEBUG = False 20 | 21 | [BaseProtocolProcess] 22 | TYPE = switch 23 | STIMULATION = BaseStimulation 24 | 25 | [BaseStimulation] 26 | TYPE = NI 27 | PORT = Dev1/PFI6 28 | STIM_TIME = 3.5 29 | 30 | -------------------------------------------------------------------------------- /experiments/configs/BaseTrialExperiment_example.ini: -------------------------------------------------------------------------------- 1 | [EXPERIMENT] 2 | BASE = BaseTrialExperiment 3 | EXPERIMENTER = Example 4 | 5 | [BaseTrialExperiment] 6 | TRIGGER = BaseScreenTrigger 7 | PROCESS = BaseProtocolProcess 8 | INTERTRIAL_TIME = 40 9 | TRIAL_NAME = Trial 10 | TRIAL_TRIGGER = BaseRegionTrigger 11 | TRIAL_TIME = 10 12 | STIMULUS_TIME = 10 13 | RESULT_FUNC = any 14 | EXP_LENGTH = 40 15 | EXP_COMPLETION = 20 16 | EXP_TIME = 6000 17 | 18 | [BaseScreenTrigger] 19 | ANGLE = 60 20 | DIRECTION = North 21 | BODYPARTS = nose, neck 22 | DEBUG = False 23 | 24 | [BaseProtocolProcess] 25 | TYPE = trial 26 | STIMULATION = ScreenStimulation 27 | 28 | [ScreenStimulation] 29 | TYPE = image 30 | STIM_PATH = PATH_TO_IMAGE 31 | BACKGROUND_PATH = PATH_TO_BACKGROUND 32 | 33 
| [BaseRegionTrigger] 34 | TYPE = circle 35 | CENTER = 550, 63 36 | RADIUS = 30 37 | BODYPARTS = neck 38 | DEBUG = False 39 | 40 | -------------------------------------------------------------------------------- /experiments/configs/default_config.ini: -------------------------------------------------------------------------------- 1 | ;DO NOT REMOVE FILE 2 | ;[EXPERIMENTS] 3 | 4 | [BaseConditionalExperiment] 5 | TRIGGER = BaseRegionTrigger 6 | PROCESS = BaseProtocolProcess 7 | INTERTRIAL_TIME = 40 8 | EXP_LENGTH = 40 9 | EXP_TIME = 3600 10 | 11 | 12 | [BaseTrialExperiment] 13 | TRIGGER = BaseRegionTrigger 14 | PROCESS = BaseProtocolProcess 15 | INTERTRIAL_TIME = 40 16 | TRIAL_NAME = Trial 17 | TRIAL_TRIGGER = BaseRegionTrigger 18 | TRIAL_TIME = 10 19 | STIMULUS_TIME = 10 20 | RESULT_FUNC = any 21 | EXP_LENGTH = 40 22 | EXP_COMPLETION = 20 23 | EXP_TIME = 6000 24 | 25 | 26 | [BaseOptogeneticExperiment] 27 | TRIGGER = BaseRegionTrigger 28 | INTERTRIAL_TIME = 40 29 | MAX_STIM_TIME = 10 30 | MIN_STIM_TIME = 3 31 | MAX_TOTAL_STIM_TIME = 500 32 | EXP_TIME = 3600 33 | PROCESS = BaseProtocolProcess 34 | 35 | 36 | ;[TRIGGER] 37 | 38 | [BaseRegionTrigger] 39 | TYPE = circle 40 | CENTER= 550, 63 41 | RADIUS = 30 42 | BODYPARTS = neck 43 | DEBUG = False 44 | 45 | [BaseOutsideRegionTrigger] 46 | TYPE = circle 47 | CENTER= 550, 63 48 | RADIUS = 30 49 | BODYPARTS = neck 50 | DEBUG = False 51 | 52 | [BaseHeaddirectionTrigger] 53 | POINT= 550, 63 54 | ANGLE = 30 55 | BODYPARTS = nose, neck 56 | DEBUG = False 57 | 58 | [BaseEgoHeaddirectionTrigger] 59 | ANGLE = 30 60 | HEADDIRECTION = both 61 | BODYPARTS = nose, neck, tailroot 62 | DEBUG = False 63 | 64 | [BaseScreenTrigger] 65 | ANGLE = 30 66 | DIRECTION = North 67 | BODYPARTS = nose, neck 68 | DEBUG = False 69 | 70 | [BaseHeaddirectionROITrigger] 71 | TYPE = circle 72 | CENTER= 550, 63 73 | RADIUS = 30 74 | ROI_BODYPARTS = neck 75 | DEBUG = False 76 | POINT= 550, 63 77 | ANGLE = 30 78 | ANGLE_BODYPARTS = nose, neck 79 | 80 | 81 | 
[BaseSpeedTrigger] 82 | THRESHOLD = 2.6 83 | BODYPARTS = any 84 | DEBUG = False 85 | 86 | [BaseFreezeTrigger] 87 | THRESHOLD = 2.6 88 | BODYPARTS = any 89 | DEBUG = False 90 | 91 | ;[ProtocolProcess] 92 | 93 | [BaseProtocolProcess] 94 | TYPE= switch 95 | STIMULATION = BaseStimulation 96 | 97 | 98 | ;[STIMULATION] 99 | 100 | [BaseStimulation] 101 | ; can be NI, RASPBERRY or RASP_NETWORK 102 | TYPE = NI 103 | ;only used in RASP_NETWORK 104 | IP = None 105 | ;PORT parameter is used for all (Port from DAQ, PIN from Raspberry, or serial port from Arduino) 106 | PORT = Dev1/PFI6 107 | STIM_TIME = 3.5 108 | 109 | [RewardDispenser] 110 | ; can be NI, RASPBERRY, RASP_NETWORK or ARDUINO 111 | TYPE = NI 112 | ;only used in RASP_NETWORK 113 | IP = None 114 | ;PORT parameter is used for all (Port from DAQ, PIN from Raspberry, or serial port from Arduino) 115 | STIM_PORT = Dev1/PFI6 116 | REMOVAL_PORT = Dev1/PFI5 117 | STIM_TIME = 3.5 118 | REMOVAL_TIME = 3.5 119 | 120 | [ScreenStimulation] 121 | TYPE = image 122 | STIM_PATH = PATH_TO_IMAGE 123 | BACKGROUND_PATH = PATH_TO_BACKGROUND 124 | 125 | 126 | 127 | -------------------------------------------------------------------------------- /experiments/custom/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SchwarzNeuroconLab/DeepLabStream/5736b2b3ecb16c47e7ba48121717813bddacc020/experiments/custom/__init__.py -------------------------------------------------------------------------------- /experiments/custom/stimulation.py: -------------------------------------------------------------------------------- 1 | """ 2 | DeepLabStream 3 | © J.Schweihoff, M. 
Loshakov 4 | University Bonn Medical Faculty, Germany 5 | https://github.com/SchwarzNeuroconLab/DeepLabStream 6 | Licensed under GNU General Public License v3.0 7 | """ 8 | 9 | import time 10 | import os 11 | import cv2 12 | import numpy as np 13 | from experiments.utils.DAQ_output import DigitalModDevice 14 | 15 | 16 | def show_visual_stim_img(img_type="background", name="vistim", img_dict=None): 17 | """ 18 | Shows image in newly created or named window 19 | 20 | :param img_type: defines image through visual dictionary to be displayed 21 | :param name: name of window that is created or used by OpenCV to display image 22 | :param img_dict: optional custom image paths dictionary 23 | """ 24 | # Show image when called 25 | img_path = os.path.join(os.path.dirname(__file__), "src") 26 | if img_dict is None: 27 | visual = { 28 | "background": r"whiteback_1920_1080.png", 29 | "Greenbar_whiteback": r"greenbar_whiteback_1920_1080.png", 30 | "Bluebar_whiteback": r"bluebar_whiteback_1920_1080.png", 31 | } 32 | else: 33 | visual = img_dict 34 | # load image unchanged (-1), greyscale (0) or color (1) 35 | img = cv2.imread(os.path.join(img_path, visual[img_type]), -1) 36 | converted_image = np.uint8(img) 37 | cv2.namedWindow(name, cv2.WINDOW_NORMAL) 38 | cv2.imshow(name, converted_image) 39 | 40 | 41 | def toggle_device(): 42 | """Controls micro peristaltic pump via digital trigger signal.""" 43 | device = DigitalModDevice("Dev1/PFI2") 44 | device.toggle() 45 | 46 | 47 | def show_visual_stim_img(type="background", name="vistim"): 48 | """ 49 | Shows image in newly created or named window 50 | 51 | :param type: defines image through visual dictionary to be displayed 52 | :param name: name of window that is created or used by OpenCV to display image 53 | """ 54 | # Show image when called 55 | visual = { 56 | "background": dict(path=r"./experiments/src/whiteback_1920_1080.png"), 57 | "Greenbar_whiteback": dict( 58 | path=r"./experiments/src/greenbar_whiteback_1920_1080.png" 59 | 
), 60 | "Bluebar_whiteback": dict( 61 | path=r"./experiments/src/bluebar_whiteback_1920_1080.png" 62 | ), 63 | "DLStream_test": dict(path=r"./experiments/src/stuckinaloop.jpg"), 64 | } 65 | # load image unchanged (-1), greyscale (0) or color (1) 66 | img = cv2.imread(visual[type]["path"], -1) 67 | converted_image = np.uint8(img) 68 | cv2.namedWindow(name, cv2.WINDOW_NORMAL) 69 | cv2.imshow(name, converted_image) 70 | 71 | 72 | def show_visual_stim_vid(type, name="vistim"): 73 | """ 74 | Shows video in newly created or named window 75 | WARNING: LONG FILES WILL HOLD THE PROCESS NOTICEABLY 76 | :param type: defines video through visual dictionary to be displayed 77 | :param name: name of window that is created or used by OpenCV to display image 78 | """ 79 | # Show image when called 80 | visual = { 81 | "Vid1": dict(path=r"./experiments/src/video1.mp4"), 82 | "Vid2": dict(path=r"./experiments/src/video2.mp4"), 83 | } 84 | cap = cv2.VideoCapture(visual[type]["path"]) 85 | cv2.namedWindow(name, cv2.WINDOW_NORMAL) 86 | while cap.isOpened(): 87 | ret, frame = cap.read() 88 | 89 | # gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) 90 | 91 | if ret is True: 92 | cv2.imshow(name, frame) 93 | 94 | else: 95 | break 96 | 97 | if cv2.waitKey(1) & 0xFF == ord("q"): 98 | break 99 | 100 | cap.release() 101 | 102 | 103 | def toggle_device(): 104 | """Controls micro peristaltic pump via digital trigger signal.""" 105 | device = DigitalModDevice("Dev1/PFI2") 106 | device.toggle() 107 | 108 | 109 | """The following is the original stimulation we used for our experiments! If you are interested in using this, 110 | you will need to adapt the stimulation to your system! 
Otherwise I recommend looking at them for ideas how to incorporate 111 | your own experiment into DLStream!""" 112 | 113 | 114 | def laser_toggle(): 115 | """Toggle laser on or off 116 | Laser needs to be connected to DAQ_PORT and switched to "Digital modulation" 117 | If you use additional safety measurements to control the laser, make sure to undo them before starting the protocol!""" 118 | 119 | laser = DigitalModDevice(LSR_DAQ_PORT) 120 | laser.toggle() 121 | print("Laser was toggled") 122 | 123 | 124 | def laser_switch(switch: bool = False): 125 | """Toggle laser on or off 126 | Laser needs to be connected to DAQ_PORT and switched to "Digital modulation" 127 | If you use additional safety measurements to control the laser, make sure to undo them before starting the protocol!""" 128 | 129 | laser = DigitalModDevice(LSR_DAQ_PORT) 130 | if switch: 131 | laser.turn_on() 132 | print("Laser is switched on") 133 | 134 | else: 135 | laser.turn_off() 136 | print("Laser is switched off") 137 | 138 | 139 | def deliver_tone_shock(): 140 | """ 141 | Activates tone signal via digital trigger. Cycle is optional 142 | :param rep: Number of repetitions for signal [int] 143 | :param duration: Duration in seconds of signal for each rep [float] 144 | :param inter_time: Time in seconds between reps [float] 145 | """ 146 | 147 | tone_gen = DigitalModDevice("Dev1/PFI5") 148 | tone_gen.toggle() 149 | 150 | 151 | def deliver_airpuff(rep: int = 1, duration: float = 0.1, inter_time: float = 0.1): 152 | """Controls pressure micro-injector via digital trigger signal. 
153 | All other parameters need to be manually changed on the Device, this only triggers it!""" 154 | 155 | pump = DigitalModDevice(AP_DAQ_PORT) 156 | if rep > 1: 157 | pump.cycle(rep, duration, inter_time) 158 | else: 159 | pump.trigger() 160 | 161 | 162 | def deliver_liqreward(): 163 | """Controls micro peristaltic pump via digital trigger signal.""" 164 | pump_delivery = DigitalModDevice("Dev1/PFI2") 165 | pump_delivery.toggle() 166 | 167 | 168 | def withdraw_liqreward(): 169 | """activates micro peristaltic pump""" 170 | pump_withdraw = DigitalModDevice("Dev1/PFI6") 171 | pump_withdraw.timed_on(3.5) 172 | 173 | 174 | if __name__ == "__main__": 175 | toggle_device() 176 | time.sleep(3.5) 177 | toggle_device() 178 | -------------------------------------------------------------------------------- /experiments/custom/stimulus_process.py: -------------------------------------------------------------------------------- 1 | """ 2 | DeepLabStream 3 | © J.Schweihoff, M. Loshakov 4 | University Bonn Medical Faculty, Germany 5 | https://github.com/SchwarzNeuroconLab/DeepLabStream 6 | Licensed under GNU General Public License v3.0 7 | """ 8 | 9 | import time 10 | import cv2 11 | import multiprocessing as mp 12 | from experiments.custom.stimulation import ( 13 | show_visual_stim_img, 14 | deliver_liqreward, 15 | deliver_tone_shock, 16 | withdraw_liqreward, 17 | DigitalModDevice, 18 | ) 19 | from experiments.utils.gpio_control import DigitalArduinoDevice 20 | import random 21 | 22 | 23 | class Timer: 24 | """ 25 | Very simple timer 26 | """ 27 | 28 | def __init__(self, seconds): 29 | """ 30 | Setting the time the timer needs to run 31 | :param seconds: time in seconds 32 | """ 33 | self._seconds = seconds 34 | self._start_time = None 35 | 36 | def start(self): 37 | """ 38 | Starting the timer 39 | If already started does nothing 40 | """ 41 | if not self._start_time: 42 | self._start_time = time.time() 43 | 44 | def check_timer(self): 45 | """ 46 | Check if the time has run 
out or not 47 | Returns False if timer is not started 48 | Returns True if timer has run no longer than _seconds (still runs) 49 | """ 50 | if self._start_time: 51 | current_time = time.time() 52 | return current_time - self._start_time <= self._seconds 53 | else: 54 | return False 55 | 56 | def return_time(self): 57 | """Returns the seconds elapsed since start(), or None if the timer has not been started (or was reset)""" 58 | if self._start_time: 59 | current_time = time.time() 60 | return current_time - self._start_time 61 | else: 62 | pass 63 | 64 | def reset(self): 65 | """ 66 | Resets the timer 67 | """ 68 | self._start_time = None 69 | 70 | def get_start_time(self): 71 | """ 72 | Returns the start time of the timer 73 | """ 74 | return self._start_time 75 | 76 | 77 | def example_protocol_run(condition_q: mp.Queue): # shows the image named by the latest queue item (background for None) until "q" is pressed 78 | current_trial = None 79 | # dmod_device = DigitalModDevice('Dev1/PFI0') 80 | # led_machine = DigitalArduinoDevice("COM5") 81 | while True: 82 | # if no protocol is selected, running default picture (background) 83 | if condition_q.full(): 84 | current_trial = condition_q.get() 85 | if current_trial is not None: 86 | show_visual_stim_img(type=current_trial, name="DlStream") 87 | # dmod_device.toggle() 88 | # led_machine.turn_on() 89 | else: 90 | show_visual_stim_img(name="DlStream") 91 | # dmod_device.turn_off() 92 | # led_machine.turn_off() 93 | 94 | if cv2.waitKey(1) & 0xFF == ord("q"): 95 | break 96 | 97 | 98 | class ProtocolProcess: 99 | """ 100 | Class to help work with protocol function in multiprocessing 101 | """ 102 | 103 | def __init__(self): 104 | """ 105 | Setting up the three queues and the process itself 106 | """ 107 | self._trial_queue = mp.Queue(1) 108 | self._success_queue = mp.Queue(1) 109 | self._condition_queue = mp.Queue(1) 110 | self._protocol_process = None # set by subclasses (e.g. ExampleProtocolProcess); start() fails while it is None 111 | self._running = False 112 | self._current_trial = None 113 | 114 | def start(self): 115 | """ 116 | Starting the process 117 | """ 118 | self._protocol_process.start() 119 | 120 | def end(self): 121 | """ 122 | Ending the process 123 | """ 124 |
self._trial_queue.close() 125 | self._success_queue.close() 126 | self._condition_queue.close() 127 | self._protocol_process.terminate() 128 | 129 | def get_status(self): 130 | """ 131 | Getting current status of the running protocol 132 | """ 133 | return self._running, self._current_trial 134 | 135 | def set_trial(self, trial: str): 136 | """ 137 | Passing the trial name to the process 138 | """ 139 | if self._trial_queue.empty() and self._success_queue.empty(): 140 | self._trial_queue.put(trial) 141 | self._running = True 142 | self._current_trial = trial 143 | 144 | def pass_condition(self, condition: bool): 145 | """ 146 | Passing the condition to the process 147 | """ 148 | if self._condition_queue.empty(): 149 | self._condition_queue.put(condition) 150 | 151 | def get_result(self) -> bool: 152 | """ 153 | Getting result from the process (None if no result is available yet) 154 | """ 155 | if self._success_queue.full(): 156 | self._running = False 157 | return self._success_queue.get() 158 | 159 | 160 | class ExampleProtocolProcess(ProtocolProcess): 161 | """ 162 | Class to help work with protocol function in multiprocessing with simple stimulation 163 | """ 164 | 165 | def __init__(self): 166 | """ 167 | Setting up the three queues and the process itself 168 | """ 169 | super().__init__() 170 | self._protocol_process = mp.Process( 171 | target=example_protocol_run, args=(self._trial_queue,) # NOTE(review): the trial queue is passed as example_protocol_run's condition_q, so set_trial() feeds it and the condition queue is never read 172 | ) 173 | 174 | 175 | """The following is the original protocols we used for our experiments! If you are interested in using this, 176 | you will need to adapt the stimulation to your system!
Otherwise I recommend looking at them for ideas how to incorporate 177 | your own experiment into DLStream!""" 178 | 179 | 180 | def start_unconditional(protocol): # placeholder: only prints a message 181 | print("Running some stuff, water or sound for {} protocol".format(protocol)) 182 | 183 | 184 | def classic_protocol_run_old( 185 | trial_q: mp.Queue, condition_q: mp.Queue, success_q: mp.Queue, trials: dict 186 | ): 187 | """ 188 | The function to use in ProtocolProcess class 189 | Designed to be run continuously alongside the main loop 190 | Three parameters are three mp.Queue classes, each passes corresponding values 191 | :param trial_q: the protocol name (inwards) 192 | :param condition_q: the condition (inwards) 193 | :param success_q: the result of each protocol (outwards) 194 | :param trials: dict of possible trials 195 | """ 196 | # setting up different trials 197 | current_trial = None 198 | # starting the main loop without any protocol running 199 | while True: 200 | # if no protocol is selected, running default picture (background) 201 | if trial_q.empty() and current_trial is None: 202 | # print('No protocol running') 203 | show_visual_stim_img(name="inside") 204 | # if some protocol is passed, set up protocol timers and variables 205 | elif trial_q.full(): 206 | current_trial = trial_q.get() 207 | finished_trial = False 208 | # starting timers 209 | stimulus_timer = trials[current_trial]["stimulus_timer"] 210 | success_timer = trials[current_trial]["success_timer"] 211 | print("Starting protocol {}".format(current_trial)) 212 | stimulus_timer.start() 213 | success_timer.start() 214 | condition_list = [] 215 | # this branch is for already running protocol 216 | elif current_trial is not None: 217 | # checking for stimulus timer and outputting correct image 218 | if stimulus_timer.check_timer(): 219 | # if stimulus timer is running, show stimulus 220 | show_visual_stim_img(current_trial, name="inside") 221 | else: 222 | # if the timer runs out, finish protocol and reset timer 223 |
trials[current_trial]["stimulus_timer"].reset() 224 | current_trial = None 225 | 226 | # checking if any condition was passed 227 | if condition_q.full(): 228 | stimulus_condition = condition_q.get() 229 | # checking if timer for condition is running and condition=True 230 | if success_timer.check_timer(): 231 | # print('That was a success!') 232 | condition_list.append(stimulus_condition) 233 | # elif success_timer.check_timer() and not stimulus_condition: 234 | # # print('That was not a success') 235 | # condition_list.append(False) 236 | 237 | # checking if the timer for condition has run out 238 | if not success_timer.check_timer() and not finished_trial: 239 | if CTRL: # NOTE(review): CTRL and INTERTRIAL_TIME (used below) are not defined or imported in this module, so this raises NameError when reached 240 | # start a random time interval 241 | # TODO: working ctrl timer that does not set new time each frame... 242 | ctrl_time = random.randint(0, INTERTRIAL_TIME + 1) 243 | ctrl_timer = Timer(ctrl_time) 244 | ctrl_timer.start() 245 | print("Waiting for extra" + str(ctrl_time) + " sec") 246 | if not ctrl_timer.check_timer(): 247 | # in ctrl just randomly decide between the two 248 | print("Random choice between both stimuli") 249 | if random.random() >= 0.5: 250 | # very fast random choice between TRUE and FALSE 251 | deliver_liqreward() 252 | print("Delivered Reward") 253 | 254 | else: 255 | deliver_tone_shock() 256 | print("Delivered Aversive") 257 | 258 | ctrl_timer.reset() 259 | finished_trial = True 260 | # outputting the result, whatever it is 261 | success = trials[current_trial]["result_func"](condition_list) 262 | success_q.put(success) 263 | trials[current_trial]["success_timer"].reset() 264 | 265 | else: 266 | if current_trial == "Bluebar_whiteback": 267 | deliver_tone_shock() 268 | print("Delivered Aversive") 269 | elif current_trial == "Greenbar_whiteback": 270 | if trials[current_trial]["random_reward"]: 271 | if random.random() >= 0.5: 272 | # very fast random choice between TRUE and FALSE 273 | deliver_liqreward() 274 | print("Delivered Reward") 275 | else: 276 | print("No
Reward") 277 | else: 278 | deliver_liqreward() 279 | # resetting the timer 280 | print("Timer for condition run out") 281 | finished_trial = True 282 | # outputting the result, whatever it is 283 | success = trials[current_trial]["result_func"](condition_list) 284 | success_q.put(success) 285 | trials[current_trial]["success_timer"].reset() 286 | 287 | # don't delete that: cv2.waitKey lets OpenCV process window events and "q" ends the loop 288 | if cv2.waitKey(1) & 0xFF == ord("q"): 289 | break 290 | 291 | 292 | def classic_protocol_run( 293 | trial_q: mp.Queue, condition_q: mp.Queue, success_q: mp.Queue, trials: dict 294 | ): 295 | """ 296 | The function to use in ProtocolProcess class 297 | Designed to be run continuously alongside the main loop 298 | Three parameters are three mp.Queue classes, each passes corresponding values 299 | :param trial_q: the protocol name (inwards) 300 | :param condition_q: the condition (inwards) 301 | :param success_q: the result of each protocol (outwards) 302 | :param trials: dict of possible trials 303 | """ 304 | # setting up different trials 305 | current_trial = None 306 | # starting the main loop without any protocol running 307 | while True: 308 | # if no protocol is selected, running default picture (background) 309 | if trial_q.empty() and current_trial is None: 310 | # print('No protocol running') 311 | show_visual_stim_img(name="inside") 312 | # if some protocol is passed, set up protocol timers and variables 313 | elif trial_q.full(): 314 | current_trial = trial_q.get() 315 | finished_trial = False 316 | delivery = False 317 | reward_del = False 318 | # starting timers 319 | stimulus_timer = trials[current_trial]["stimulus_timer"] 320 | collection_timer = trials[current_trial]["collection_timer"] 321 | success_timer = trials[current_trial]["success_timer"] 322 | delivery_timer = Timer(3.5) # when it expires, deliver_liqreward() is called a second time (see below) 323 | shock_timer = Timer(3.5) # when it expires, deliver_tone_shock() is called a second time (see below) 324 | # withdraw_timer = Timer(3.5) 325 | print("Starting protocol {}".format(current_trial)) 326 | stimulus_timer.start() 327 | success_timer.start() 328 |
condition_list = [] 329 | collection_list = [] 330 | # this branch is for already running protocol 331 | elif current_trial is not None: 332 | # checking for stimulus timer and outputting correct image 333 | if stimulus_timer.check_timer(): 334 | # if stimulus timer is running, show stimulus 335 | show_visual_stim_img(current_trial, name="inside") 336 | else: 337 | # if the timer runs out, reset the stimulus timer and show the background 338 | trials[current_trial]["stimulus_timer"].reset() 339 | show_visual_stim_img(name="inside") 340 | # checking if any condition was passed 341 | if condition_q.full(): 342 | stimulus_condition = condition_q.get() 343 | # checking if timer for condition is running and condition=True 344 | if success_timer.check_timer(): 345 | condition_list.append(stimulus_condition) 346 | elif not success_timer.check_timer() and collection_timer.check_timer(): 347 | collection_list.append(stimulus_condition) 348 | 349 | # checking if the timer for condition has run out 350 | if not success_timer.check_timer() and not finished_trial: 351 | 352 | if not delivery: 353 | if current_trial is not None: 354 | print("Timer for condition ran out") 355 | print_check = True # NOTE(review): print_check is never read 356 | # check whether animal collected within success timer 357 | success = trials[current_trial]["result_func"](condition_list) 358 | trials[current_trial]["success_timer"].reset() 359 | 360 | print("Stimulation.") 361 | 362 | if current_trial == "Bluebar_whiteback": 363 | deliver_tone_shock() 364 | print("Aversive") 365 | shock_timer.start() 366 | elif current_trial == "Greenbar_whiteback": 367 | deliver_liqreward() 368 | delivery_timer.start() 369 | reward_del = True 370 | print("Reward") 371 | delivery = True 372 | collection_timer.start() 373 | elif delivery: 374 | # resetting the timer 375 | if not collection_timer.check_timer(): 376 | finished_trial = True 377 | # check whether animal collected at all 378 | collect = any(collection_list) 379 | if not collect and reward_del: 380 | # if the animal
didn't go to collect the reward, withdraw the reward again. 381 | withdraw_liqreward() 382 | # withdraw_timer.start() 383 | trials[current_trial]["collection_timer"].reset() 384 | current_trial = None 385 | # put success in queue and finish trial 386 | success_q.put(success) 387 | 388 | if ( 389 | not delivery_timer.check_timer() 390 | and delivery_timer.get_start_time() is not None 391 | ): 392 | deliver_liqreward() 393 | delivery_timer.reset() 394 | if ( 395 | not shock_timer.check_timer() 396 | and shock_timer.get_start_time() is not None 397 | ): 398 | deliver_tone_shock() 399 | shock_timer.reset() 400 | 401 | # if not withdraw_timer.check_timer() and withdraw_timer.get_start_time() is not None: 402 | # withdraw_liqreward(False) 403 | # withdraw_timer.reset() 404 | # delivery = False 405 | 406 | # don't delete that: cv2.waitKey lets OpenCV process window events and "q" ends the loop 407 | if cv2.waitKey(1) & 0xFF == ord("q"): 408 | break 409 | 410 | 411 | def simple_protocol_run(trial_q: mp.Queue, success_q: mp.Queue, trials: dict): 412 | """ 413 | The function to use in ProtocolProcess class 414 | Designed to be run continuously alongside the main loop 415 | Two parameters are mp.Queue classes, each passes corresponding values 416 | :param trial_q: the protocol name (inwards) 417 | :param success_q: the result of each protocol (outwards) 418 | :param trials: dict of possible trials (not used by this function) 419 | """ 420 | current_trial = None 421 | # starting the main loop without any protocol running 422 | while True: 423 | if trial_q.empty() and current_trial is None: 424 | pass # NOTE(review): busy-waits without sleeping while no trial is queued 425 | elif trial_q.full(): 426 | current_trial = trial_q.get() 427 | print(current_trial) 428 | # this branch is for already running protocol 429 | elif current_trial is not None: 430 | print("Stimulating...") 431 | current_trial = None 432 | success_q.put(True) 433 | deliver_liqreward() 434 | time.sleep(3.5) 435 | deliver_liqreward() 436 | 437 | 438 | class ClassicProtocolProcess: 439 | """ 440 | Class to help work with protocol function in multiprocessing 441 | """ 442 | 443
| def __init__(self, trials): 444 | """ 445 | Setting up the three queues and the process itself 446 | """ 447 | self._trial_queue = mp.Queue(1) 448 | self._success_queue = mp.Queue(1) 449 | self._condition_queue = mp.Queue(1) 450 | self._protocol_process = mp.Process( 451 | target=classic_protocol_run, 452 | args=( 453 | self._trial_queue, 454 | self._condition_queue, 455 | self._success_queue, 456 | trials, 457 | ), 458 | ) 459 | self._running = False 460 | self._current_trial = None 461 | 462 | def start(self): 463 | """ 464 | Starting the process 465 | """ 466 | self._protocol_process.start() 467 | 468 | def end(self): 469 | """ 470 | Ending the process 471 | """ 472 | self._trial_queue.close() 473 | self._success_queue.close() 474 | self._condition_queue.close() 475 | self._protocol_process.terminate() 476 | 477 | def get_status(self): 478 | """ 479 | Getting current status of the running protocol 480 | """ 481 | return self._running, self._current_trial 482 | 483 | def set_trial(self, trial: str): 484 | """ 485 | Passing the trial name to the process 486 | """ 487 | if self._trial_queue.empty() and self._success_queue.empty(): 488 | self._trial_queue.put(trial) 489 | self._running = True 490 | self._current_trial = trial 491 | 492 | def pass_condition(self, condition: bool): 493 | """ 494 | Passing the condition to the process 495 | """ 496 | if self._condition_queue.empty(): 497 | self._condition_queue.put(condition) 498 | 499 | def get_result(self) -> bool: 500 | """ 501 | Getting result from the process (None if no result is available yet) 502 | """ 503 | if self._success_queue.full(): 504 | self._running = False 505 | return self._success_queue.get() 506 | 507 | 508 | class SimpleProtocolProcess(ClassicProtocolProcess): 509 | """ 510 | Class to help work with protocol function in multiprocessing with simple stimulation 511 | """ 512 | 513 | def __init__(self, trials): 514 | """ 515 | Setting up the three queues and the process itself 516 | """ 517 | super().__init__(trials) # also builds a classic_protocol_run process, which is replaced below before it can be started 518 |
self._protocol_process = mp.Process( 519 | target=simple_protocol_run, 520 | args=(self._trial_queue, self._success_queue, trials), 521 | ) 522 | -------------------------------------------------------------------------------- /experiments/src/bluebar_whiteback_1920_1080.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SchwarzNeuroconLab/DeepLabStream/5736b2b3ecb16c47e7ba48121717813bddacc020/experiments/src/bluebar_whiteback_1920_1080.png -------------------------------------------------------------------------------- /experiments/src/greenbar_whiteback_1920_1080.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SchwarzNeuroconLab/DeepLabStream/5736b2b3ecb16c47e7ba48121717813bddacc020/experiments/src/greenbar_whiteback_1920_1080.png -------------------------------------------------------------------------------- /experiments/src/stuckinaloop.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SchwarzNeuroconLab/DeepLabStream/5736b2b3ecb16c47e7ba48121717813bddacc020/experiments/src/stuckinaloop.jpg -------------------------------------------------------------------------------- /experiments/src/whiteback_1920_1080.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SchwarzNeuroconLab/DeepLabStream/5736b2b3ecb16c47e7ba48121717813bddacc020/experiments/src/whiteback_1920_1080.png -------------------------------------------------------------------------------- /experiments/utils/DAQ_output.py: -------------------------------------------------------------------------------- 1 | """ 2 | DeepLabStream 3 | © J.Schweihoff, M.
Loshakov 4 | University Bonn Medical Faculty, Germany 5 | https://github.com/SchwarzNeuroconLab/DeepLabStream 6 | Licensed under GNU General Public License v3.0 7 | """ 8 | 9 | import nidaqmx 10 | import time 11 | 12 | 13 | class NotFoundException(nidaqmx._lib.DaqNotFoundError): 14 | pass 15 | 16 | 17 | class Device: 18 | """ 19 | Modulated devices connected to the DAQ Board 20 | """ 21 | 22 | def __init__(self, port): 23 | """ 24 | :param port: output port on the DAQ board connected to the Device 25 | """ 26 | self.INPUT_PORT = port 27 | self._status = False 28 | 29 | def get_port(self): 30 | return self.INPUT_PORT 31 | 32 | def get_status(self): 33 | return self._status 34 | 35 | 36 | class DigitalModDevice(Device): 37 | """ 38 | Digital modulated devices 39 | """ 40 | 41 | def __init__(self, digital_output_port): 42 | """ 43 | :param digital_output_port: the digital output port on the DAQ board connected to the Device 44 | """ 45 | super().__init__(digital_output_port) 46 | self._t_switch = False 47 | 48 | def trigger(self): 49 | """ 50 | triggers devices via Digital output of NIDAQ board 51 | """ 52 | TRIGGER = [True, False] 53 | try: 54 | with nidaqmx.Task() as task: 55 | task.do_channels.add_do_chan(self.INPUT_PORT) 56 | task.write(TRIGGER, auto_start=True) 57 | except NotFoundException: 58 | print("DAQ device not found") 59 | 60 | def turn_on(self): 61 | try: 62 | with nidaqmx.Task() as task: 63 | task.do_channels.add_do_chan(self.INPUT_PORT) 64 | task.write(True, auto_start=True) 65 | except NotFoundException: 66 | print("DAQ device not found") 67 | 68 | def turn_off(self): 69 | try: 70 | with nidaqmx.Task() as task: 71 | task.do_channels.add_do_chan(self.INPUT_PORT) 72 | task.write(False, auto_start=True) 73 | except NotFoundException: 74 | print("DAQ device not found") 75 | 76 | def toggle(self): 77 | """ 78 | Digital modulation of Device via Digital output of NIDAQ board 79 | Toggles Device on if off and vice versa 80 | """ 81 | self._t_switch = not 
self._t_switch 82 | try: 83 | with nidaqmx.Task() as task: 84 | task.do_channels.add_do_chan(self.INPUT_PORT) 85 | task.write(self._t_switch, auto_start=True) 86 | except NotFoundException: 87 | print("DAQ device not found") 88 | 89 | def timed_on(self, on_time): 90 | """ 91 | Digital modulation of Device via Digital output of NIDAQ board 92 | :param on_time: the amount of time that the Device should stay turned ON in seconds 93 | """ 94 | if on_time > 0: 95 | self.toggle() 96 | time.sleep(on_time) 97 | self.toggle() 98 | 99 | def cycle(self, repeats, on_time, off_time): 100 | """ 101 | Digital modulation of Device via Digital output of NIDAQ board 102 | :param repeats: the number of repeats the ON/OFF cycle should run 103 | :param on_time: the amount of time that the Device should stay turned ON in seconds 104 | :param off_time: the amount of time that the Device should stay turned OFF between cycles in seconds 105 | """ 106 | for i in range(repeats): 107 | self.timed_on(on_time) 108 | time.sleep(off_time) 109 | 110 | 111 | class AnalogModDevice(Device): 112 | """Analog modulated devices""" 113 | 114 | def __init__(self, AO_DAQ_PORT): 115 | """ 116 | :param AO_DAQ_PORT: the analog output port on the DAQ board connected to the Device 117 | """ 118 | super().__init__(AO_DAQ_PORT) 119 | 120 | def amod_decive(self, V): 121 | """Changes Output current of Analog Output Port defined amount in Volt 122 | 123 | :param V: Amount of Current as float if turned ON 124 | """ 125 | with nidaqmx.Task() as task: 126 | task.ao_channels.add_ao_voltage_chan(self.INPUT_PORT) 127 | task.write([V], auto_start=True) 128 | -------------------------------------------------------------------------------- /experiments/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SchwarzNeuroconLab/DeepLabStream/5736b2b3ecb16c47e7ba48121717813bddacc020/experiments/utils/__init__.py 
-------------------------------------------------------------------------------- /experiments/utils/exp_setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | DeepLabStream 3 | © J.Schweihoff, M. Loshakov 4 | University Bonn Medical Faculty, Germany 5 | https://github.com/SchwarzNeuroconLab/DeepLabStream 6 | Licensed under GNU General Public License v3.0 7 | """ 8 | 9 | 10 | import os 11 | import configparser as cfg 12 | from datetime import date 13 | from utils.configloader import EXP_NAME, EXP_ORIGIN 14 | 15 | 16 | def get_config_settings(name, parameter_dict, config_file_name): 17 | 18 | config = cfg.ConfigParser() 19 | path = os.path.join(os.path.dirname(__file__), "..", "configs", config_file_name) 20 | with open(path) as file: 21 | config.read_file(file) 22 | 23 | config_dict = {} 24 | for parameter in list(parameter_dict.keys()): 25 | if parameter_dict[parameter] == "int": 26 | try: 27 | config_dict[parameter] = config[name].getint(parameter) 28 | except: 29 | config_dict[parameter] = None 30 | print( 31 | "Did not find valid {} entry for {} in {}. Setting to None.".format( 32 | parameter, name, config_file_name 33 | ) 34 | ) 35 | if parameter_dict[parameter] == "float": 36 | try: 37 | config_dict[parameter] = config[name].getfloat(parameter) 38 | except: 39 | config_dict[parameter] = None 40 | print( 41 | "Did not find valid {} entry for {} in {}. Setting to None.".format( 42 | parameter, name, config_file_name 43 | ) 44 | ) 45 | 46 | elif parameter_dict[parameter] == "tuple": 47 | try: 48 | config_dict[parameter] = tuple( 49 | int(entry) for entry in config[name].get(parameter).split(",") 50 | ) 51 | except: 52 | config_dict[parameter] = None 53 | print( 54 | "Did not find valid {} entry for {} in {}. 
Setting to None.".format( 55 | parameter, name, config_file_name 56 | ) 57 | ) 58 | 59 | elif parameter_dict[parameter] == "list": 60 | try: 61 | config_dict[parameter] = list( 62 | str(entry) for entry in config[name].get(parameter).split(",") 63 | ) 64 | except: 65 | config_dict[parameter] = None 66 | print( 67 | "Did not find valid {} entry for {} in {}. Setting to None.".format( 68 | parameter, name, config_file_name 69 | ) 70 | ) 71 | 72 | elif parameter_dict[parameter] == "boolean": 73 | try: 74 | config_dict[parameter] = config[name].getboolean(parameter) 75 | except: 76 | config_dict[parameter] = None 77 | print( 78 | "Did not find valid {} entry for {} in {}. Setting to None.".format( 79 | parameter, name, config_file_name 80 | ) 81 | ) 82 | 83 | elif parameter_dict[parameter] == "str": 84 | try: 85 | config_dict[parameter] = config[name].get(parameter) 86 | except: 87 | config_dict[parameter] = None 88 | print( 89 | "Did not find valid {} entry for {} in {}. Setting to None.".format( 90 | parameter, name, config_file_name 91 | ) 92 | ) 93 | 94 | return config_dict 95 | 96 | 97 | def get_experiment_settings(experiment_name, parameter_dict): 98 | 99 | experiment_config = get_config_settings( 100 | experiment_name, parameter_dict, f"{EXP_NAME}.ini" 101 | ) 102 | 103 | return experiment_config 104 | 105 | 106 | def get_stimulation_settings(stimulation_name, parameter_dict): 107 | experiment_config = get_config_settings( 108 | stimulation_name, parameter_dict, f"{EXP_NAME}.ini" 109 | ) 110 | 111 | return experiment_config 112 | 113 | 114 | def get_trigger_settings(trigger_name, parameter_dict): 115 | experiment_config = get_config_settings( 116 | trigger_name, parameter_dict, f"{EXP_NAME}.ini" 117 | ) 118 | 119 | return experiment_config 120 | 121 | 122 | def get_process_settings(process_name, parameter_dict): 123 | experiment_config = get_config_settings( 124 | process_name, parameter_dict, f"{EXP_NAME}.ini" 125 | ) 126 | 127 | return experiment_config 128 | 
129 | 130 | def setup_experiment(): 131 | if EXP_ORIGIN.upper() == "BASE": 132 | config = cfg.ConfigParser() 133 | path = os.path.join( 134 | os.path.dirname(__file__), "..", "configs", f"{EXP_NAME}.ini" 135 | ) 136 | try: 137 | with open(path) as file: 138 | config.read_file(file) 139 | except FileNotFoundError: 140 | raise FileNotFoundError( 141 | f"{EXP_NAME}.ini was not found. Make sure it exists." 142 | ) 143 | 144 | experiment_name = config["EXPERIMENT"]["BASE"] 145 | import importlib 146 | 147 | mod = importlib.import_module("experiments.base.experiments") 148 | try: 149 | experiment_class = getattr(mod, experiment_name) 150 | experiment = experiment_class() 151 | except Exception: 152 | raise ValueError( 153 | f"Experiment: {experiment_name} not in base.experiments.py." 154 | ) 155 | 156 | elif EXP_ORIGIN.upper() == "CUSTOM": 157 | 158 | experiment_name = EXP_NAME 159 | import importlib 160 | 161 | mod = importlib.import_module("experiments.custom.experiments") 162 | try: 163 | experiment_class = getattr(mod, experiment_name) 164 | experiment = experiment_class() 165 | except Exception: 166 | raise ValueError( 167 | f"Experiment: {experiment_name} not in custom.experiments.py." 168 | ) 169 | 170 | else: 171 | raise ValueError( 172 | f'Experiment Origin "{EXP_ORIGIN}" not valid. Pick CUSTOM or BASE.' 
173 | ) 174 | 175 | return experiment 176 | 177 | 178 | def setup_trigger(trigger_name): 179 | import importlib 180 | 181 | mod = importlib.import_module("experiments.base.triggers") 182 | try: 183 | trigger_class = getattr(mod, trigger_name) 184 | trigger = trigger_class() 185 | except Exception: 186 | raise ValueError(f"Trigger: {trigger_name} not in base.triggers.py.") 187 | 188 | return trigger 189 | 190 | 191 | def setup_process(process_name): 192 | import importlib 193 | 194 | mod = importlib.import_module("experiments.base.stimulus_process") 195 | try: 196 | process_class = getattr(mod, process_name) 197 | process = process_class() 198 | except Exception: 199 | raise ValueError(f"Process: {process_name} not in base.stimulus_process.py.") 200 | 201 | return process 202 | 203 | 204 | def setup_stimulation(stimulus_name): 205 | import importlib 206 | 207 | mod = importlib.import_module("experiments.base.stimulation") 208 | try: 209 | stimulation_class = getattr(mod, stimulus_name) 210 | stimulation = stimulation_class() 211 | except Exception: 212 | raise ValueError(f"Stimulus: {stimulus_name} not in stimulation.py.") 213 | 214 | return stimulation 215 | 216 | 217 | class DlStreamConfigWriter: 218 | def __init__(self): 219 | self._config = self._init_configparser() 220 | self._default_config = self._init_configparser() 221 | self._init_configparser() 222 | self._filename = None 223 | self._default_path = os.path.join(os.path.dirname(__file__), "..", "configs") 224 | self._dlstream_dict = dict( 225 | EXPERIMENT=dict(BASE="DEFAULT", EXPERIMENTER="DEFAULT") 226 | ) 227 | self._date = date.today().strftime("%d%m%Y") 228 | # TODO: Make this adaptive! 
229 | self._available_modules = dict( 230 | EXPERIMENT=[ 231 | "BaseConditionalExperiment", 232 | "BaseOptogeneticExperiment", 233 | "BaseTrialExperiment", 234 | ], 235 | TRIGGER=[ 236 | "BaseRegionTrigger", 237 | "BaseOutsideRegionTrigger", 238 | "BaseHeaddirectionTrigger", 239 | "BaseEgoHeaddirectionTrigger", 240 | "BaseScreenTrigger", 241 | "BaseSpeedTrigger", 242 | "BaseFreezeTrigger", 243 | "BaseHeaddirectionROITrigger", 244 | ], 245 | PROCESS=["BaseProtocolProcess"], 246 | STIMULATION=["BaseStimulation", "RewardDispenser", "ScreenStimulation"], 247 | ) 248 | 249 | @staticmethod 250 | def _init_configparser(): 251 | config = cfg.ConfigParser() 252 | config.optionxform = str 253 | return config 254 | 255 | def _read_default_config(self): 256 | try: 257 | self._default_config.read( 258 | os.path.join(self._default_path, "default_config.ini") 259 | ) 260 | except FileNotFoundError: 261 | raise FileNotFoundError( 262 | "The default_config.ini was not found. Make sure it exists." 263 | ) 264 | 265 | def _read_config(self, config_path): 266 | 267 | try: 268 | self._config.read(config_path) 269 | except FileNotFoundError: 270 | raise FileNotFoundError("Config file does not exist at this location.") 271 | 272 | def set_experimenter(self, name): 273 | self._dlstream_dict["EXPERIMENT"]["EXPERIMENTER"] = name 274 | 275 | def set_experiment(self, experiment_name): 276 | self._dlstream_dict["EXPERIMENT"]["BASE"] = experiment_name 277 | 278 | def import_default( 279 | self, 280 | experiment_name, 281 | trigger_name=None, 282 | process_name=None, 283 | stimulation_name=None, 284 | trial_trigger_name=None, 285 | ): 286 | 287 | self._read_default_config() 288 | 289 | self.set_experiment(experiment_name) 290 | 291 | try: 292 | self._dlstream_dict[experiment_name] = self._default_config[experiment_name] 293 | except Exception: 294 | raise ValueError(f"Unknown Experiment: {experiment_name}.") 295 | 296 | if trigger_name is not None: 297 | self._dlstream_dict[trigger_name] = 
self._default_config[trigger_name] 298 | self._dlstream_dict[experiment_name]["TRIGGER"] = trigger_name 299 | else: 300 | trigger_name = self._dlstream_dict[experiment_name]["TRIGGER"] 301 | if trigger_name is not None: 302 | self._dlstream_dict[trigger_name] = self._default_config[trigger_name] 303 | 304 | if process_name is not None: 305 | self._dlstream_dict[process_name] = self._default_config[process_name] 306 | self._dlstream_dict[experiment_name]["PROCESS"] = process_name 307 | else: 308 | process_name = self._dlstream_dict[experiment_name]["PROCESS"] 309 | if process_name is None: 310 | self._dlstream_dict[process_name] = self._default_config[process_name] 311 | 312 | if stimulation_name is not None: 313 | self._dlstream_dict[stimulation_name] = self._default_config[ 314 | stimulation_name 315 | ] 316 | self._dlstream_dict[process_name]["STIMULATION"] = stimulation_name 317 | else: 318 | stimulation_name = self._dlstream_dict[process_name]["STIMULATION"] 319 | if stimulation_name is not None: 320 | self._dlstream_dict[stimulation_name] = self._default_config[ 321 | stimulation_name 322 | ] 323 | 324 | # TODO: Make this adaptive 325 | if experiment_name == "BaseTrialExperiment" and trial_trigger_name is not None: 326 | # TODO: ADD option to use the same trigger as trigger_name 327 | if not trigger_name == trial_trigger_name: 328 | self._dlstream_dict[trial_trigger_name] = self._default_config[ 329 | trial_trigger_name 330 | ] 331 | else: 332 | raise ValueError( 333 | f"Trial Trigger can currently not be the same as Trigger." 
334 | ) 335 | 336 | def import_custom(self, config_path): 337 | 338 | self._read_config(config_path) 339 | self._dlstream_dict = self._config._sections 340 | 341 | def _set_path(self): 342 | if self._filename is None: 343 | experiment_name = self._dlstream_dict["EXPERIMENT"]["BASE"] 344 | self._filename = f"{experiment_name}_{self._date}.ini" 345 | 346 | file = open(os.path.join(self._default_path, self._filename), "w") 347 | return file 348 | 349 | def _change_module(self, module_type: str, module_name: str): 350 | """Changes module in a given settings.ini (resets parameters on that module) 351 | :param module_type str: Module type (TRIGGER, PROCESS, STIMULATION, EXPERIMENT etc.) 352 | :param module_name str: Exact name of new module (with Camelcase) 353 | :param config_path: path to config that needs changing""" 354 | 355 | module_type = module_type.upper() 356 | self._read_default_config() 357 | # self.import_custom(config_path) 358 | 359 | for key in self._dlstream_dict.keys(): 360 | if module_type in self._dlstream_dict[key].keys(): 361 | old_module = self._dlstream_dict[key][module_type] 362 | self._dlstream_dict[key][module_type] = module_name 363 | self._dlstream_dict.pop(old_module, None) 364 | self._dlstream_dict[module_name] = self._default_config[module_name] 365 | print(f"Changed {old_module} to {module_name}.") 366 | 367 | def _change_parameter( 368 | self, module_name: str, parameter_name: str, parameter_value: str 369 | ): 370 | 371 | parameter_name = parameter_name.upper() 372 | # self.import_custom(config_path) 373 | if module_name in self._dlstream_dict.keys(): 374 | if parameter_name in self._dlstream_dict[module_name].keys(): 375 | old_value = self._dlstream_dict[module_name][parameter_name] 376 | if not isinstance(parameter_value, str): 377 | parameter_value = str(parameter_value) 378 | self._dlstream_dict[module_name][parameter_name] = parameter_value 379 | print( 380 | f"Changed {parameter_name} in {module_name} from {old_value} to 
{parameter_value}." 381 | ) 382 | else: 383 | raise ValueError( 384 | f"Parameter {parameter_name} does not exist in given config." 385 | ) 386 | 387 | else: 388 | raise ValueError(f"Module {module_name} does not exist in given config.") 389 | 390 | def check_if_default_exists(self, module_name, module_type): 391 | # TODO: adjust to make adaptive 392 | # self._read_default_config() 393 | if module_name in self._available_modules[module_type]: 394 | return True 395 | else: 396 | return False 397 | 398 | def get_available_module_names(self, module_type): 399 | return self._available_modules[module_type] 400 | 401 | def change_modules(self, config_path, module_dict: dict, overwrite: bool = False): 402 | """Changes multiple modules at once, 403 | :param module_dict: dictionary in style:(module_type = module_name)""" 404 | 405 | self.import_custom(config_path) 406 | 407 | for key, value in module_dict.items(): 408 | self._change_module(module_type=str(key).upper(), module_name=value) 409 | 410 | if overwrite: 411 | self.write_ini(path=config_path) 412 | 413 | def change_parameters( 414 | self, config_path, parameter_dict: dict, overwrite: bool = False 415 | ): 416 | """Changes multiple modules at once, 417 | :param paramter_dict: nested dictionary in style:{module_name: dict(parameter_name = value)}""" 418 | 419 | self.import_custom(config_path) 420 | for key in parameter_dict.keys(): 421 | for inner_key, value in parameter_dict[key].items(): 422 | self._change_parameter( 423 | module_name=str(key).upper(), 424 | parameter_name=inner_key, 425 | parameter_value=value, 426 | ) 427 | 428 | if overwrite: 429 | self.write_ini(path=config_path) 430 | 431 | def change_config(self, config_path, config_dict: dict, overwrite: bool = False): 432 | """Changes multiple modules and parameters at once, 433 | :param config_dict: nested dictionary in style:dict(module_type ={module_name: dict(parameter_name = value)})""" 434 | 435 | self.import_custom(config_path) 436 | 437 | for key, 
inner_dict in config_dict.items(): 438 | if isinstance(inner_dict, dict): 439 | for inner_key, most_inner_dict in inner_dict.items(): 440 | self._change_module(module_type=key, module_name=inner_key) 441 | if isinstance(most_inner_dict, dict): 442 | for most_inner_key in most_inner_dict.keys(): 443 | 444 | self._change_parameter( 445 | module_name=inner_key, 446 | parameter_name=str(most_inner_key).upper(), 447 | parameter_value=most_inner_dict[most_inner_key], 448 | ) 449 | else: 450 | self._change_module(module_type=key, module_name=inner_dict) 451 | 452 | if overwrite: 453 | self.write_ini(path=config_path) 454 | 455 | def write_ini(self, path: str = None): 456 | if path is None: 457 | file = self._set_path() 458 | else: 459 | self._filename = os.path.basename(path) 460 | file = open(path, "w") 461 | self._config = self._init_configparser() 462 | for key in self._dlstream_dict.keys(): 463 | self._config.add_section(key) 464 | for parameter, value in self._dlstream_dict[key].items(): 465 | self._config.set(key, parameter, str(value)) 466 | self._config.write(file) 467 | file.close() 468 | print(f"Created {self._filename}.") 469 | 470 | def set_filename(self, filename): 471 | self._filename = filename + ".ini" 472 | 473 | def get_current_config(self): 474 | return self._dlstream_dict 475 | 476 | def get_parameters(self, module_name): 477 | if module_name in self._dlstream_dict.keys(): 478 | print(self._dlstream_dict.keys()) 479 | return self._dlstream_dict[module_name] 480 | else: 481 | raise ValueError(f"{module_name} is not valid.") 482 | 483 | 484 | if __name__ == "__main__": 485 | 486 | exp = setup_experiment() 487 | print(exp) 488 | -------------------------------------------------------------------------------- /experiments/utils/gpio_control.py: -------------------------------------------------------------------------------- 1 | from gpiozero import DigitalOutputDevice 2 | from gpiozero.pins.pigpio import PiGPIOFactory 3 | 4 | import serial 5 | 6 | 7 | 
class DigitalPiDevice: 8 | """ 9 | Digital modulated devices in combination with Raspberry Pi GPIO 10 | Setup: https://gpiozero.readthedocs.io/en/stable/remote_gpio.html 11 | """ 12 | 13 | def __init__(self, PIN, BOARD_IP: str = None): 14 | 15 | """ 16 | :param BOARD_IP: IP adress of board connected to the Device 17 | """ 18 | if BOARD_IP is not None: 19 | self._factory = PiGPIOFactory(host=BOARD_IP) 20 | self._device = DigitalOutputDevice(PIN, pin_factory=self._factory) 21 | else: 22 | self._factory = None 23 | self._device = DigitalOutputDevice(PIN) 24 | self._running = False 25 | 26 | def turn_on(self): 27 | self._device.on() 28 | self._running = True 29 | 30 | def turn_off(self): 31 | self._device.off() 32 | self._running = False 33 | 34 | def toggle(self): 35 | self._device.toggle() 36 | self._running = self._device.is_active 37 | 38 | 39 | class DigitalArduinoDevice: 40 | """ 41 | Digital modulated devices in combination with Arduino boards connected via USB 42 | setup: https://pythonforundergradengineers.com/python-arduino-LED.html 43 | 44 | """ 45 | 46 | def __init__(self, PORT): 47 | """ 48 | :param PORT: USB PORT of the arduino board 49 | """ 50 | self._device = serial.Serial(PORT, baudrate=9600) 51 | self._running = False 52 | 53 | def turn_on(self): 54 | self._device.write(b"H") 55 | self._running = True 56 | 57 | def turn_off(self): 58 | self._device.write(b"L") 59 | self._running = False 60 | 61 | def toggle(self): 62 | if self._running: 63 | self.turn_off() 64 | else: 65 | self.turn_on() 66 | -------------------------------------------------------------------------------- /misc/DLStream_Logo_small.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SchwarzNeuroconLab/DeepLabStream/5736b2b3ecb16c47e7ba48121717813bddacc020/misc/DLStream_Logo_small.png -------------------------------------------------------------------------------- /misc/StartAnalysis2.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/SchwarzNeuroconLab/DeepLabStream/5736b2b3ecb16c47e7ba48121717813bddacc020/misc/StartAnalysis2.png -------------------------------------------------------------------------------- /misc/StartExperiment2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SchwarzNeuroconLab/DeepLabStream/5736b2b3ecb16c47e7ba48121717813bddacc020/misc/StartExperiment2.png -------------------------------------------------------------------------------- /misc/StartRecording2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SchwarzNeuroconLab/DeepLabStream/5736b2b3ecb16c47e7ba48121717813bddacc020/misc/StartRecording2.png -------------------------------------------------------------------------------- /misc/StartStream2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SchwarzNeuroconLab/DeepLabStream/5736b2b3ecb16c47e7ba48121717813bddacc020/misc/StartStream2.png -------------------------------------------------------------------------------- /misc/StopAnalysis2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SchwarzNeuroconLab/DeepLabStream/5736b2b3ecb16c47e7ba48121717813bddacc020/misc/StopAnalysis2.png -------------------------------------------------------------------------------- /misc/StopExperiment2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SchwarzNeuroconLab/DeepLabStream/5736b2b3ecb16c47e7ba48121717813bddacc020/misc/StopExperiment2.png -------------------------------------------------------------------------------- /misc/StopRecording2.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/SchwarzNeuroconLab/DeepLabStream/5736b2b3ecb16c47e7ba48121717813bddacc020/misc/StopRecording2.png -------------------------------------------------------------------------------- /misc/StopStream2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SchwarzNeuroconLab/DeepLabStream/5736b2b3ecb16c47e7ba48121717813bddacc020/misc/StopStream2.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | 2 | PySide2==5.15.2.1 3 | gpiozero==1.6.2 4 | pigpio==1.78 5 | pyserial==3.5 6 | nidaqmx==0.6.4 7 | click==8.1.3 8 | opencv-python==4.5.2.54 9 | opencv-contrib-python==4.5.2.54 10 | numpy>=1.14.5 11 | pandas==1.5.2 12 | scikit-learn== 0.24.1 13 | scikit-image==0.17.2 14 | scipy==1.4.1 15 | pure-predict==0.0.4 16 | numba==0.56.0 17 | -------------------------------------------------------------------------------- /settings.ini: -------------------------------------------------------------------------------- 1 | [Streaming] 2 | RESOLUTION = 960, 540 3 | FRAMERATE = 30 4 | #put in the path to your outputfolder 5 | OUTPUT_DIRECTORY = C:/Output 6 | #if you have connected multiple cameras (USB), you will need to select the number OpenCV has given them. 7 | #Default is "0", which takes the first available camera. 
8 | CAMERA_SOURCE = 0 9 | #you can use "camera", "ipwebcam" or "video" to select your input source 10 | STREAMING_SOURCE = camera 11 | 12 | [Pose Estimation] 13 | #possible origins are: SLEAP, DLC, DLC-LIVE,MADLC, DEEPPOSEKIT 14 | MODEL_ORIGIN = MODEL_ORIGIN 15 | #takes path to model or models (in case of SLEAP topdown, bottom up) in style "string" or "string , string", without "" 16 | # E.g.: MODEL_PATH = D:\SLEAP\models\baseline_model.centroids , D:\SLEAP\models\baseline_model.topdown 17 | MODEL_PATH = PATH_TO_MODEL 18 | MODEL_NAME = NAME_OF_MODEL 19 | ; only used in DLC-LIVE and DeepPoseKit for now; if left empty or too short, auto-naming will be enabled in style bp1, bp2 ... 20 | ALL_BODYPARTS = bp1, bp2, bp3, bp4 21 | 22 | [Experiment] 23 | #Available parameters are "CUSTOM" and "BASE" 24 | EXP_ORIGIN = CUSTOM 25 | #Name of the experiment config in /experiments/configs or name of the custom experiment in /experiments/custom/experiments.py 26 | EXP_NAME = ExampleExperiment 27 | #if you want the experiment to be recorded as a raw video set this to "True". 28 | RECORD_EXP = True 29 | 30 | [Classification] 31 | PATH_TO_CLASSIFIER = PATH_TO_CLASSIFIER 32 | #time window used for feature extraction (currently only works with 15) 33 | TIME_WINDOW = 15 34 | #number of parallel classifiers to run, this is dependent on your performance time. You need at least 1 more classifier than your average classification time. 35 | POOL_SIZE = 1 36 | #threshold to accept a classification probability as positive detection (SIMBA + ) 37 | THRESHOLD = 0.9 38 | # class/category of identified behavior to use as trigger (only used for B-SOID) 39 | TRIGGER = 0 40 | #feature extraction currently works with millimeters, not pixels, so be sure to enter the factor (as in simba). 41 | PIXPERMM = 1 42 | 43 | [Video] 44 | #Full path to video that you want to use as input. Needs "STREAMING_SOURCE" set to "video"!
45 | VIDEO_SOURCE = PATH_TO_VIDEO 46 | 47 | [IPWEBCAM] 48 | #Standard Port is 5555 if you followed the SmoothStream setup 49 | PORT = 5555 50 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | -------------------------------------------------------------------------------- /utils/advanced_settings.ini: -------------------------------------------------------------------------------- 1 | #advanced settings: only change them if you know what you are doing! 2 | 3 | [Streaming] 4 | STREAMS = color, depth, infrared 5 | MULTIPLE_DEVICES = False 6 | STACK_FRAMES = False 7 | PASS_SEPARATE = False 8 | ANIMALS_NUMBER = 1 9 | 10 | CROP = False 11 | CROP_X = 0, 1280 12 | CROP_Y = 0, 500 13 | 14 | [Pose Estimation] 15 | FLATTEN_MA = FALSE 16 | SPLIT_MA = FALSE 17 | #handle missing will handle missing/NaN values. E.g., when using DLC pose estimation with likelihood filter. 18 | #Handle with care: NaN values might result in unexpected behavior during experiments and when triggers are calculated! 19 | #default is "skip", for SIMBA use "null", for BSOID use "pass", "reset" will set all values to NaN for the entire skeleton 20 | HANDLE_MISSING = skip 21 | 22 | # These settings work in synergy with HANDLE_MISSING (currently only for DLC) 23 | FILTER_LIKELIHOOD = True 24 | # Likelihood threshold used to filter DLC pose estimation. Filtered values will be set to NaN, NaN. 25 | LIKELIHOOD_THRESHOLD = 0.9 26 | # this is a legacy option for original DLSTREAM & DLC interaction and will soon be deprecated.
27 | USE_DLSTREAM_POSTURE_DETECTION = FALSE 28 | 29 | [Video] 30 | REPEAT_VIDEO = True 31 | -------------------------------------------------------------------------------- /utils/analysis.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | DeepLabStream 4 | © J.Schweihoff, M. Loshakov 5 | University Bonn Medical Faculty, Germany 6 | https://github.com/SchwarzNeuroconLab/DeepLabStream 7 | Licensed under GNU General Public License v3.0 8 | """ 9 | 10 | 11 | import os 12 | from typing import Union, List, Tuple 13 | import numpy as np 14 | import pandas as pd 15 | import math 16 | 17 | 18 | class ROI: 19 | """ 20 | Creating a ROI with given parameters 21 | center - a tuple with (x,y) coordinates 22 | h - radius by Y-axis 23 | k - radius by X-axis 24 | name (optional) - name for ROI 25 | """ 26 | 27 | def __init__(self, center: tuple, h: int, k: int, name: str = "ROI"): 28 | self._name = name 29 | self._x_center, self._y_center = center 30 | self._x_radius = k 31 | self._y_radius = h 32 | 33 | # creating a coordinates box 34 | self._box = [ 35 | self._x_center - self._x_radius, 36 | self._y_center - self._y_radius, 37 | self._x_center + self._x_radius, 38 | self._y_center + self._y_radius, 39 | ] 40 | 41 | def get_box(self): 42 | """ 43 | Returns box coordinates as [x1, y1, x2, y2] 44 | """ 45 | return self._box 46 | 47 | def get_x_radius(self): 48 | """ 49 | Returns x_radius 50 | """ 51 | return self._x_radius 52 | 53 | def get_y_radius(self): 54 | """ 55 | Returns y_radius 56 | """ 57 | return self._y_radius 58 | 59 | def get_center(self): 60 | """ 61 | Returns center coordinates as x, y 62 | """ 63 | return self._x_center, self._y_center 64 | 65 | def get_name(self): 66 | """ 67 | Returns ROI name 68 | """ 69 | return self._name 70 | 71 | def set_name(self, name: str): 72 | """ 73 | Returns ROI name 74 | """ 75 | self._name = name 76 | 77 | 78 | class RectangleROI(ROI): 79 | """ 80 | Creating a 
rectangle ROI with given parameters 81 | center - a tuple with (x,y) coordinates 82 | h - radius by Y-axis 83 | k - radius by X-axis 84 | name (optional) - name for ROI 85 | """ 86 | 87 | def __init__(self, center: tuple, h: int, k: int, name: str = "RectangleROI"): 88 | super().__init__(center, h, k, name) 89 | 90 | def check_point(self, x: int, y: int): 91 | """ 92 | Checking if point with given coordinates x,y is inside ROI 93 | Returns True or False 94 | """ 95 | check = (-self._x_radius <= x - self._x_center <= self._x_radius) and ( 96 | -self._y_radius <= y - self._y_center <= self._y_radius 97 | ) 98 | return check 99 | 100 | 101 | class EllipseROI(ROI): 102 | """ 103 | Creating a ROI ellipse with given parameters 104 | center - a tuple with (x,y) coordinates 105 | h - radius by Y-axis 106 | k - radius by X-axis 107 | name (optional) - name for ROI 108 | """ 109 | 110 | def __init__(self, center: tuple, h: int, k: int, name: str = "EllipseROI"): 111 | super().__init__(center, h, k, name) 112 | 113 | def check_point(self, x: int, y: int): 114 | """ 115 | Checking if point with given coordinates x,y is inside ROI 116 | Returns True or False 117 | """ 118 | check = ((x - self._x_center) ** 2 / self._x_radius ** 2) + ( 119 | (y - self._y_center) ** 2 / self._y_radius ** 2 120 | ) 121 | return check <= 1 122 | 123 | 124 | def calculate_distance(point1: tuple, point2: tuple) -> float: 125 | """ 126 | Calculates distance between two points (x1,y1) and (x2,y2) 127 | """ 128 | if np.isnan(point1).any() or np.isnan(point2).any(): 129 | return np.nan 130 | x1, y1 = point1 131 | x2, y2 = point2 132 | distance = math.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2) 133 | return distance 134 | 135 | 136 | def calculate_distance_for_bodyparts( 137 | dataframe: pd.DataFrame, body_parts: Union[List[str], str] 138 | ) -> List[pd.Series]: 139 | """ 140 | Calculating distances traveled for each frame for desired body parts 141 | :param dataframe DataFrame: dataframe to calculate distances 
on 142 | Should have columns with X and Y coordinates of desired body_parts 143 | :param body_parts str or list of str: part or parts to calculate distances for 144 | Can be either string or list of strings 145 | :return list: returns list of pd.Series with distances for each bodypart 146 | """ 147 | df = dataframe 148 | # creating temporary dataframe for calculations 149 | temp_df = pd.DataFrame() 150 | # creating empty list for results 151 | results = [] 152 | 153 | def distance_func(row, bodypart): 154 | """ 155 | Function to actually calculate distance 156 | """ 157 | return math.sqrt( 158 | row["{}_travel_X".format(bodypart)] ** 2 159 | + row["{}_travel_Y".format(bodypart)] ** 2 160 | ) 161 | 162 | def calc_distance(bodypart): 163 | """ 164 | Function to create temporary dataframe columns and do calculations on them 165 | Then append the resulting series to results list 166 | """ 167 | temp_df["{}_travel_X".format(bodypart)] = ( 168 | df["{}_X".format(bodypart)].diff().astype(float) 169 | ) 170 | temp_df["{}_travel_Y".format(bodypart)] = ( 171 | df["{}_Y".format(bodypart)].diff().astype(float) 172 | ) 173 | results.append(temp_df.apply(distance_func, axis=1, args=(bodypart,))) 174 | 175 | # checking if provided body_parts is list or not 176 | if isinstance(body_parts, list): 177 | # if list, calculate for every body part in it 178 | for part in body_parts: 179 | calc_distance(part) 180 | else: 181 | # if not, calculate for one provided part 182 | calc_distance(body_parts) 183 | return results 184 | 185 | 186 | def calculate_speed_for_bodyparts( 187 | dataframe: pd.DataFrame, body_parts: Union[List[str], str] 188 | ) -> List[pd.Series]: 189 | """ 190 | Calculating speed in pixels per seconds for each frame for desired body parts 191 | :param dataframe DataFrame: dataframe to calculate speeds on 192 | Should have columns distances travelled for each desired body part 193 | :param body_parts str or list of str: part or parts to calculate distances for 194 | Can 
be either string or list of strings 195 | :return list: returns list of pd.Series with speeds for each bodypart 196 | """ 197 | df = dataframe 198 | # creating temporary dataframe for calculations 199 | temp_df = pd.DataFrame() 200 | # calculating time differences between each frame 201 | temp_df["Time_diff"] = df["Time"].diff().astype(float) 202 | # creating empty list for results 203 | results = [] 204 | 205 | def speed_func(row, bodypart): 206 | """ 207 | Function to actually calculate speed 208 | """ 209 | if row["Time_diff"] != 0: 210 | return row["distance_{}".format(bodypart)] / row["Time_diff"] 211 | else: 212 | return np.nan 213 | 214 | def check_for_distance(bodypart): 215 | """ 216 | Check if column with distance for desired body part exists in provided dataframe 217 | If true, copy it to temp_df 218 | Otherwise, raise ValueError exception 219 | """ 220 | if "distance_{}".format(bodypart) in df.columns: 221 | temp_df["distance_{}".format(bodypart)] = df["distance_{}".format(bodypart)] 222 | else: 223 | raise ValueError( 224 | "Distances travelled should be calculated beforehand for each bodypart" 225 | ) 226 | 227 | # checking if provided body_parts is list or not 228 | if isinstance(body_parts, list): 229 | # if list, calculate for every body part in it 230 | for part in body_parts: 231 | check_for_distance(part) 232 | results.append(temp_df.apply(speed_func, axis=1, args=(part,))) 233 | else: 234 | # if not, calculate for one provided part 235 | check_for_distance(body_parts) 236 | results.append(temp_df.apply(speed_func, axis=1, args=(body_parts,))) 237 | return results 238 | 239 | 240 | def angle_between_vectors( 241 | xa: int, ya: int, xb: int, yb: int, xc: int, yc: int 242 | ) -> Tuple[str, float]: 243 | """ 244 | Calculating angle between vectors, defined by coordinates 245 | Returns angle and direction (left, right, forward or backward) 246 | *ISSUE* - if y axis is reversed, directions would also be reversed 247 | """ 248 | # using atan2() 
formula for both vectors 249 | dir_ab = math.atan2(ya - yb, xa - xb) 250 | dir_bc = math.atan2(yb - yc, xb - xc) 251 | 252 | # angle between vectors in radians 253 | rad_angle = dir_ab - dir_bc 254 | pi = math.pi 255 | 256 | # converting to degrees 257 | angle = rad_angle 258 | if pi < angle: 259 | angle -= 2 * pi 260 | elif -pi > angle: 261 | angle += 2 * pi 262 | angle = math.degrees(angle) 263 | 264 | # defining the direction 265 | if 180 > angle > 0: 266 | direction = "left" 267 | elif -180 < angle < 0: 268 | direction = "right" 269 | elif abs(angle) == 180: 270 | direction = "backwards" 271 | else: 272 | direction = "forward" 273 | 274 | return direction, angle 275 | 276 | 277 | ## miscellaneous ## 278 | def cls(): 279 | os.system("cls" if os.name == "nt" else "clear") 280 | -------------------------------------------------------------------------------- /utils/configloader.py: -------------------------------------------------------------------------------- 1 | """ 2 | DeepLabStream 3 | © J.Schweihoff, M. Loshakov 4 | University Bonn Medical Faculty, Germany 5 | https://github.com/SchwarzNeuroconLab/DeepLabStream 6 | Licensed under GNU General Public License v3.0 7 | """ 8 | 9 | import time 10 | import os 11 | import configparser as cfg 12 | 13 | # loading DeepLabStream configuration 14 | # remember when it was called DSC? 
15 | dsc_config = cfg.ConfigParser() 16 | adv_dsc_config = cfg.ConfigParser() 17 | 18 | 19 | def get_script_path(): 20 | return os.path.dirname(os.path.join(os.path.dirname(__file__), "..")) 21 | 22 | 23 | cfg_path = os.path.join(os.path.dirname(__file__), "..", "settings.ini") 24 | with open(cfg_path) as cfg_file: 25 | dsc_config.read_file(cfg_file) 26 | 27 | adv_cfg_path = os.path.join(os.path.dirname(__file__), "advanced_settings.ini") 28 | with open(adv_cfg_path) as adv_cfg_file: 29 | adv_dsc_config.read_file(adv_cfg_file) 30 | # DeepLabCut 31 | # deeplabcut_config = dict(dsc_config.items('DeepLabCut')) 32 | 33 | # poseestimation 34 | MODEL_ORIGIN = dsc_config["Pose Estimation"].get("MODEL_ORIGIN") 35 | model_path_string = [ 36 | str(part).strip() 37 | for part in dsc_config["Pose Estimation"].get("MODEL_PATH").split(",") 38 | ] 39 | MODEL_PATH = model_path_string[0] if len(model_path_string) <= 1 else model_path_string 40 | MODEL_NAME = dsc_config["Pose Estimation"].get("MODEL_NAME") 41 | ALL_BODYPARTS = tuple( 42 | part.strip() 43 | for part in dsc_config["Pose Estimation"].get("ALL_BODYPARTS").split(",") 44 | ) 45 | 46 | 47 | # Streaming items 48 | 49 | try: 50 | RESOLUTION = tuple( 51 | int(part) for part in dsc_config["Streaming"].get("RESOLUTION").split(",") 52 | ) 53 | except ValueError: 54 | print( 55 | "Incorrect resolution in config!\n" 56 | 'Using default value "RESOLUTION = 848, 480"' 57 | ) 58 | RESOLUTION = (848, 480) 59 | 60 | FRAMERATE = dsc_config["Streaming"].getint("FRAMERATE") 61 | OUT_DIR = dsc_config["Streaming"].get("OUTPUT_DIRECTORY") 62 | CAMERA_SOURCE = dsc_config["Streaming"].get("CAMERA_SOURCE") 63 | STREAMING_SOURCE = dsc_config["Streaming"].get("STREAMING_SOURCE") 64 | # Video 65 | VIDEO_SOURCE = dsc_config["Video"].get("VIDEO_SOURCE") 66 | 67 | # IPWEBCAM 68 | PORT = dsc_config["IPWEBCAM"].get("PORT") 69 | 70 | 71 | # experiment 72 | EXP_ORIGIN = dsc_config["Experiment"].get("EXP_ORIGIN") 73 | EXP_NAME = 
dsc_config["Experiment"].get("EXP_NAME") 74 | RECORD_EXP = dsc_config["Experiment"].getboolean("RECORD_EXP") 75 | 76 | START_TIME = time.time() 77 | 78 | # Classification 79 | PATH_TO_CLASSIFIER = dsc_config["Classification"].get("PATH_TO_CLASSIFIER") 80 | POOL_SIZE = dsc_config["Classification"].getint("POOL_SIZE") 81 | 82 | # SIMBA 83 | PIXPERMM = dsc_config["Classification"].getfloat("PIXPERMM") 84 | THRESHOLD = dsc_config["Classification"].getfloat("THRESHOLD") 85 | TRIGGER = dsc_config["Classification"].getfloat("TRIGGER") 86 | 87 | # BSOID 88 | TIME_WINDOW = dsc_config["Classification"].getint("TIME_WINDOW") 89 | 90 | 91 | """advanced settings""" 92 | STREAMS = [ 93 | str(part).strip() for part in adv_dsc_config["Streaming"].get("STREAMS").split(",") 94 | ] 95 | MULTI_CAM = adv_dsc_config["Streaming"].getboolean("MULTIPLE_DEVICES") 96 | STACK_FRAMES = ( 97 | adv_dsc_config["Streaming"].getboolean("STACK_FRAMES") 98 | if adv_dsc_config["Streaming"].getboolean("STACK_FRAMES") is not None 99 | else False 100 | ) 101 | ANIMALS_NUMBER = ( 102 | adv_dsc_config["Streaming"].getint("ANIMALS_NUMBER") 103 | if adv_dsc_config["Streaming"].getint("ANIMALS_NUMBER") is not None 104 | else 1 105 | ) 106 | PASS_SEPARATE = adv_dsc_config["Streaming"].getboolean("PASS_SEPARATE") 107 | 108 | REPEAT_VIDEO = adv_dsc_config["Video"].getboolean("REPEAT_VIDEO") 109 | CROP = adv_dsc_config["Streaming"].getboolean("CROP") 110 | CROP_X = [ 111 | int(str(part).strip()) 112 | for part in adv_dsc_config["Streaming"].get("CROP_X").split(",") 113 | ] 114 | CROP_Y = [ 115 | int(str(part).strip()) 116 | for part in adv_dsc_config["Streaming"].get("CROP_Y").split(",") 117 | ] 118 | 119 | USE_DLSTREAM_POSTURE_DETECTION = adv_dsc_config["Pose Estimation"].getboolean("USE_DLSTREAM_POSTURE_DETECTION") 120 | FLATTEN_MA = adv_dsc_config["Pose Estimation"].getboolean("FLATTEN_MA") 121 | SPLIT_MA = adv_dsc_config["Pose Estimation"].getboolean("SPLIT_MA") 122 | HANDLE_MISSING = adv_dsc_config["Pose 
Estimation"].get("HANDLE_MISSING") 123 | FILTER_LIKELIHOOD = adv_dsc_config["Pose Estimation"].getboolean("FILTER_LIKELIHOOD") 124 | LIKELIHOOD_THRESHOLD = adv_dsc_config["Pose Estimation"].getfloat("LIKELIHOOD_THRESHOLD") 125 | -------------------------------------------------------------------------------- /utils/generic.py: -------------------------------------------------------------------------------- 1 | """ 2 | DeepLabStream 3 | © J.Schweihoff, M. Loshakov 4 | University Bonn Medical Faculty, Germany 5 | https://github.com/SchwarzNeuroconLab/DeepLabStream 6 | Licensed under GNU General Public License v3.0 7 | """ 8 | import time 9 | import cv2 10 | import numpy as np 11 | 12 | from utils.configloader import ( 13 | CAMERA_SOURCE, 14 | VIDEO_SOURCE, 15 | RESOLUTION, 16 | FRAMERATE, 17 | REPEAT_VIDEO, 18 | ) 19 | 20 | 21 | class MissingFrameError(Exception): 22 | """Custom expection to be raised when frame is not received. Should be caught in app.py and deeplabstream.py 23 | to stop dlstream gracefully""" 24 | 25 | 26 | class GenericManager: 27 | """ 28 | Camera manager class for generic (not specified) cameras 29 | """ 30 | 31 | def __init__(self): 32 | """ 33 | Generic camera manager from video source 34 | Uses pure opencv 35 | """ 36 | self._source = CAMERA_SOURCE if CAMERA_SOURCE is not None else 0 37 | self._manager_name = "generic" 38 | self._enabled_devices = {} 39 | self._camera = None 40 | # Will be called when enabling stream! 
Important for restart of stream 41 | # self._camera = cv2.VideoCapture(int(self._source)) 42 | self._camera_name = "Camera {}".format(self._source) 43 | 44 | def get_connected_devices(self) -> list: 45 | """ 46 | Getter for stored connected devices list 47 | """ 48 | return [self._camera_name] 49 | 50 | def get_enabled_devices(self) -> dict: 51 | """ 52 | Getter for enabled devices dictionary 53 | """ 54 | return self._enabled_devices 55 | 56 | def enable_stream(self, resolution, framerate, *args): 57 | """ 58 | Enable one stream with given parameters 59 | (hopefully) 60 | """ 61 | width, height = resolution 62 | self._camera = cv2.VideoCapture(int(self._source)) 63 | self._camera.set(cv2.CAP_PROP_FRAME_WIDTH, width) 64 | self._camera.set(cv2.CAP_PROP_FRAME_HEIGHT, height) 65 | self._camera.set(cv2.CAP_PROP_FPS, framerate) 66 | 67 | def enable_device(self, *args): 68 | """ 69 | Redirects to enable_all_devices() 70 | """ 71 | self.enable_all_devices() 72 | 73 | def enable_all_devices(self): 74 | """ 75 | We don't need to enable anything with opencv 76 | """ 77 | self._enabled_devices = {self._camera_name: self._camera} 78 | 79 | def get_frames(self) -> tuple: 80 | """ 81 | Collect frames for camera and outputs it in 'color' dictionary 82 | ***depth and infrared are not used here*** 83 | :return: tuple of three dictionaries: color, depth, infrared 84 | """ 85 | color_frames = {} 86 | depth_maps = {} 87 | infra_frames = {} 88 | ret, image = self._camera.read() 89 | if ret: 90 | color_frames[self._camera_name] = image 91 | else: 92 | raise MissingFrameError( 93 | "No frame was received from the camera. Make sure that the camera is connected " 94 | "and that the camera source is set correctly." 
95 | ) 96 | 97 | return color_frames, depth_maps, infra_frames 98 | 99 | def stop(self): 100 | """ 101 | Stops camera 102 | """ 103 | self._camera.release() 104 | self._enabled_devices = {} 105 | 106 | def get_name(self) -> str: 107 | return self._manager_name 108 | 109 | 110 | class VideoManager(GenericManager): 111 | 112 | """ 113 | Camera manager class for analyzing videos 114 | """ 115 | 116 | def __init__(self): 117 | """ 118 | Generic video manager from video files 119 | Uses pure opencv 120 | """ 121 | super().__init__() 122 | # will be defined in enable_stream 123 | self._camera = None 124 | self._camera_name = "Video" 125 | self.initial_wait = False 126 | self.last_frame_time = time.time() 127 | 128 | def enable_stream(self, resolution, framerate, *args): 129 | """ 130 | Enable one stream with given parameters 131 | (hopefully) 132 | """ 133 | # set video to first frame 134 | self._camera = cv2.VideoCapture(VIDEO_SOURCE) 135 | self._camera.set(cv2.CAP_PROP_POS_FRAMES, 0) 136 | 137 | def get_frames(self) -> tuple: 138 | """ 139 | Collect frames for camera and outputs it in 'color' dictionary 140 | ***depth and infrared are not used here*** 141 | :return: tuple of three dictionaries: color, depth, infrared 142 | """ 143 | 144 | color_frames = {} 145 | depth_maps = {} 146 | infra_frames = {} 147 | ret, image = self._camera.read() 148 | self.last_frame_time = time.time() 149 | if ret: 150 | if not self.initial_wait: 151 | cv2.waitKey(1000) 152 | self.initial_wait = True 153 | image = cv2.resize(image, RESOLUTION) 154 | color_frames[self._camera_name] = image 155 | running_time = time.time() - self.last_frame_time 156 | if running_time <= 1 / FRAMERATE: 157 | sleepy_time = int(np.ceil(1000 / FRAMERATE - running_time / 1000)) 158 | cv2.waitKey(sleepy_time) 159 | elif REPEAT_VIDEO: 160 | # cycle the video for testing purposes 161 | self._camera.set(cv2.CAP_PROP_POS_FRAMES, 0) 162 | return self.get_frames() 163 | else: 164 | raise MissingFrameError( 165 | "The 
video reached the end or is damaged. Use REPEAT_VIDEO in the advanced_settings to repeat videos." 166 | ) 167 | 168 | return color_frames, depth_maps, infra_frames 169 | -------------------------------------------------------------------------------- /utils/gui_image.py: -------------------------------------------------------------------------------- 1 | """ 2 | DeepLabStream 3 | © J.Schweihoff, M. Loshakov 4 | University Bonn Medical Faculty, Germany 5 | https://github.com/SchwarzNeuroconLab/DeepLabStream 6 | Licensed under GNU General Public License v3.0 7 | """ 8 | 9 | from utils.configloader import RESOLUTION 10 | 11 | import cv2 12 | from PySide2.QtCore import QObject 13 | from PySide2.QtCore import Signal as pyqtSignal 14 | from PySide2.QtCore import Slot as pyqtSlot 15 | from PySide2.QtWidgets import QWidget, QLabel 16 | from PySide2.QtGui import QImage, QPixmap 17 | 18 | width, height = RESOLUTION 19 | 20 | 21 | class ImageWindow(QWidget): 22 | """ 23 | Image window is used to show an image within PyQt GUI 24 | This example is hardcoded for width and height of currently defined resolution 25 | """ 26 | 27 | def __init__(self, name): 28 | super().__init__() 29 | self.title = name 30 | self.left = 0 31 | self.top = 0 32 | self.width, self.height = width, height 33 | self.label = QLabel(self) 34 | self.init_ui() 35 | 36 | @pyqtSlot(QImage) 37 | def set_image(self, image): 38 | """ 39 | This is a slot for QImage widget 40 | We use this to "catch" an emitted image 41 | :param image: QImage 42 | """ 43 | # using label to output the image 44 | self.label.setPixmap(QPixmap.fromImage(image)) 45 | 46 | def init_ui(self): 47 | """ 48 | Creating a UI itself 49 | """ 50 | # title 51 | self.setWindowTitle(self.title) 52 | # geometric parameters 53 | self.setGeometry(self.left, self.top, self.width, self.height) 54 | self.resize(self.width, self.height) 55 | # create a label 56 | self.label.resize(self.width, self.height) 57 | 58 | def closeEvent(self, event): 59 | # 
keeps window from closing on accident 60 | event.ignore() 61 | 62 | 63 | class QFrame(QObject): 64 | """ 65 | This is a dummy class to workaround PyQt5 restrictions on declaring signals 66 | Basically, it does not allow to be declared in an instance of a class, or it will become bounded 67 | You need to define it as a class variable, but that does not work every time, especially for dynamic tasks 68 | In short - 69 | Okay: 70 | class A: 71 | smth = pyqtSignal() 72 | Not okay: 73 | class A: 74 | def __init__() 75 | self.smth = pyqtSignal() 76 | 77 | But both of the examples provide the same functionality, A.smth is a pyqtSignal 78 | So we workaround it, making a dummy class and going with the "okay" route 79 | This is essential for multicam support 80 | 81 | ! Probably have a tax on the performance 82 | """ 83 | 84 | signal = pyqtSignal(QImage) 85 | 86 | 87 | def emit_qframes(frames, qframes): 88 | """ 89 | Emit some number of stream frames, depending on cameras quantity 90 | Should do it through QFrame objects, defined above 91 | Don't forget to connect QFrame objects to PyQt slots in widgets! 92 | 93 | :param frames: dictionary of frames in format of {camera:frame} 94 | :param qframes: dictionary of qframes in the same format as frames 95 | """ 96 | for camera in frames: 97 | # converting to RGB 98 | rgb_image = cv2.cvtColor(frames[camera], cv2.COLOR_BGR2RGB) 99 | h, w, ch = rgb_image.shape 100 | # converting to QImage 101 | qpicture = QImage(rgb_image.data, w, h, ch * w, QImage.Format_RGB888) 102 | # scaling QImage to resolution 103 | scaled_qpicture = qpicture.scaled(width, height, Qt.KeepAspectRatio) 104 | # emitting the picture 105 | qframes[camera].signal.emit(scaled_qpicture) 106 | -------------------------------------------------------------------------------- /utils/plotter.py: -------------------------------------------------------------------------------- 1 | """ 2 | DeepLabStream 3 | © J.Schweihoff, M. 
Loshakov 4 | University Bonn Medical Faculty, Germany 5 | https://github.com/SchwarzNeuroconLab/DeepLabStream 6 | Licensed under GNU General Public License v3.0 7 | """ 8 | 9 | import cv2 10 | import numpy as np 11 | 12 | 13 | def plot_dots(image, coordinates, color, cond=False): 14 | """ 15 | Takes the image and positional arguments from pose to plot corresponding dot 16 | Returns the resulting image 17 | """ 18 | cv2.circle(image, coordinates, 3, color, -1) 19 | if cond: 20 | cv2.circle(image, (10, 10), 10, (0, 255, 0), -1) 21 | return image 22 | 23 | 24 | def plot_bodyparts(image, skeletons): 25 | """ 26 | Takes the image and skeletons list to plot them 27 | :return: resulting image 28 | """ 29 | res_image = image.copy() 30 | # predefined colors list 31 | colors_list = [ 32 | (0, 0, 255), 33 | (0, 255, 0), 34 | (0, 255, 255), 35 | (255, 0, 0), 36 | (255, 0, 255), 37 | (255, 255, 0), 38 | (255, 255, 128), 39 | (0, 0, 128), 40 | (0, 128, 0), 41 | (0, 128, 128), 42 | (0, 128, 255), 43 | (0, 255, 128), 44 | (128, 0, 0), 45 | (128, 0, 128), 46 | (128, 0, 255), 47 | (128, 128, 0), 48 | (128, 128, 128), 49 | (128, 128, 255), 50 | (128, 255, 0), 51 | (128, 255, 128), 52 | (128, 255, 255), 53 | (255, 0, 128), 54 | (255, 128, 0), 55 | (255, 128, 128), 56 | (255, 128, 255), 57 | ] 58 | # color = (255, 0, 0) 59 | 60 | for num, animal in enumerate(skeletons): 61 | bodyparts = animal.keys() 62 | bp_count = len(bodyparts) 63 | # colors = dict(zip(bodyparts, colors_list[:bp_count])) 64 | for part in animal: 65 | # check for NaNs and skip 66 | if not any(np.isnan(animal[part])): 67 | plot_dots(res_image, tuple(map(int, animal[part])), colors_list[num]) 68 | # plot_dots(res_image, tuple(animal[part]), colors[part]) 69 | else: 70 | pass 71 | return res_image 72 | 73 | 74 | def plot_metadata_frame( 75 | image, frame_width, frame_height, current_fps, current_elapsed_time 76 | ): 77 | """ 78 | Takes the image and plots metadata 79 | :return: resulting image 80 | """ 81 | res_image = 
image.copy() 82 | font = cv2.FONT_HERSHEY_PLAIN 83 | 84 | cv2.putText( 85 | res_image, 86 | "Time: " + str(round(current_elapsed_time, 2)), 87 | (int(frame_width * 0.8), int(frame_height * 0.9)), 88 | font, 89 | 1, 90 | (255, 255, 0), 91 | ) 92 | cv2.putText( 93 | res_image, 94 | "FPS: " + str(round(current_fps, 1)), 95 | (int(frame_width * 0.8), int(frame_height * 0.94)), 96 | font, 97 | 1, 98 | (255, 255, 0), 99 | ) 100 | return res_image 101 | 102 | 103 | def plot_dlc_bodyparts(image, bodyparts): 104 | """ 105 | Plots dlc bodyparts on given image 106 | adapted from plotter 107 | """ 108 | 109 | for bp in bodyparts: 110 | center = tuple(bp.astype(int)) 111 | cv2.circle(image, center=center, radius=3, color=(255, 0, 0), thickness=2) 112 | return image 113 | 114 | 115 | def plot_triggers_response(image, response): 116 | """ 117 | Plots trigger response on given image 118 | """ 119 | if "plot" in response: 120 | plot = response["plot"] 121 | if "line" in plot: 122 | #make sure they are int for openCV. No half pixels there... 123 | 124 | plot['line']["pt1"] = tuple([int(i) for i in plot['line']["pt1"]]) 125 | plot['line']["pt2"] = tuple([int(i) for i in plot['line']["pt2"]]) 126 | cv2.line(image, **plot["line"], thickness=4) 127 | if "text" in plot: 128 | #make sure they are int for openCV. No half pixels there... 129 | plot['text']["org"] = tuple([int(i) for i in plot['text']["org"]]) 130 | font = cv2.FONT_HERSHEY_PLAIN 131 | cv2.putText(image, **plot["text"], fontFace=font, fontScale=1) 132 | if "circle" in plot: 133 | #make sure they are int for openCV. No half pixels there... 134 | plot['circle']["center"] = tuple([int(i) for i in plot['circle']["center"]]) 135 | plot['circle']["radius"] = int(plot['circle']["radius"]) 136 | 137 | cv2.circle(image, **plot["circle"], thickness=2) 138 | if "square" in plot: 139 | #make sure they are int for openCV. No half pixels there... 
140 | plot['square']["pt1"] = tuple([int(i) for i in plot['square']["pt1"]]) 141 | plot['square']["pt2"] = tuple([int(i) for i in plot['square']["pt2"]]) 142 | 143 | cv2.rectangle(image, **plot["square"], thickness=2) 144 | -------------------------------------------------------------------------------- /utils/poser.py: -------------------------------------------------------------------------------- 1 | """ 2 | DeepLabStream 3 | © J.Schweihoff, M. Loshakov 4 | University Bonn Medical Faculty, Germany 5 | https://github.com/SchwarzNeuroconLab/DeepLabStream 6 | Licensed under GNU General Public License v3.0 7 | """ 8 | 9 | import sys 10 | import os 11 | import importlib.util 12 | from itertools import product, combinations 13 | 14 | import numpy as np 15 | from skimage.feature import peak_local_max 16 | from scipy.ndimage.measurements import label, maximum_position 17 | from scipy.ndimage.morphology import generate_binary_structure, binary_erosion 18 | from scipy.ndimage.filters import maximum_filter 19 | 20 | from utils.analysis import calculate_distance 21 | from utils.configloader import ( 22 | MODEL_ORIGIN, 23 | MODEL_NAME, 24 | MODEL_PATH, 25 | ALL_BODYPARTS, 26 | FLATTEN_MA, 27 | SPLIT_MA, 28 | HANDLE_MISSING, 29 | ANIMALS_NUMBER, 30 | FILTER_LIKELIHOOD, 31 | LIKELIHOOD_THRESHOLD, 32 | USE_DLSTREAM_POSTURE_DETECTION, 33 | ) 34 | 35 | # suppressing unnecessary warnings 36 | import warnings 37 | 38 | warnings.simplefilter(action="ignore", category=FutureWarning) 39 | warnings.simplefilter(action="ignore", category=DeprecationWarning) 40 | 41 | # trying importing functions using deeplabcut module, if DLC 2 is installed correctly 42 | if MODEL_ORIGIN in ("DLC", "MADLC"): 43 | try: 44 | # checking for DLC-core 45 | if importlib.util.find_spec("deeplabcutcore") is not None: 46 | import deeplabcutcore.pose_estimation_tensorflow.nnet.predict as predict 47 | from deeplabcutcore.pose_estimation_tensorflow.config import load_config 48 | # trying to import "classic" DLC2 
49 | else: 50 | import deeplabcut.pose_estimation_tensorflow.nnet.predict as predict 51 | from deeplabcut.pose_estimation_tensorflow.config import load_config 52 | 53 | if MODEL_ORIGIN == "MADLC": 54 | from deeplabcut.pose_estimation_tensorflow.nnet import ( 55 | predict_multianimal, 56 | ) 57 | 58 | models_folder = "pose_estimation_tensorflow/models/" 59 | # if not DLC 2 is not installed, try import from DLC 1 the old way 60 | except ImportError: 61 | # adding DLC posing path and loading modules from it 62 | sys.path.insert(0, MODEL_PATH + "/pose-tensorflow") 63 | from config import load_config 64 | from nnet import predict 65 | 66 | models_folder = "pose-tensorflow/models/" 67 | 68 | elif MODEL_ORIGIN == "DEEPPOSEKIT": 69 | from deepposekit.models import load_model 70 | 71 | elif MODEL_ORIGIN == "DLC-LIVE": 72 | from dlclive import DLCLive 73 | from utils.configloader import MODEL_PATH 74 | 75 | elif MODEL_ORIGIN == "SLEAP": 76 | from sleap import load_model 77 | from utils.configloader import MODEL_PATH 78 | 79 | 80 | class SkeletonError(Exception): 81 | """Custom expection to be raised when issues with the skeleton is not received""" 82 | 83 | 84 | def load_deeplabcut(): 85 | """ 86 | Loads TensorFlow with predefined in config DeepLabCut model 87 | 88 | :return: tuple of DeepLabCut config, TensorFlow session, inputs and outputs 89 | """ 90 | model = os.path.join(MODEL_PATH, models_folder, MODEL_NAME) 91 | cfg = load_config(os.path.join(model, "test/pose_cfg.yaml")) 92 | snapshots = sorted( 93 | [sn.split(".")[0] for sn in os.listdir(model + "/train/") if "index" in sn] 94 | ) 95 | cfg["init_weights"] = model + "/train/" + snapshots[-1] 96 | 97 | sess, inputs, outputs = predict.setup_pose_prediction(cfg) 98 | return cfg, sess, inputs, outputs 99 | 100 | 101 | # pure DLC 102 | def get_pose(image, config, session, inputs, outputs): 103 | """ 104 | Gets scoremap, local reference and pose from DeepLabCut using given image 105 | Pose is most probable points for each 
joint, and not really used later 106 | Scoremap and local reference is essential to extract skeletons 107 | :param image: frame which would be analyzed 108 | :param config, session, inputs, outputs: DeepLabCut configuration and TensorFlow variables from load_deeplabcut() 109 | 110 | :return: tuple of scoremap, local reference and pose 111 | """ 112 | scmap, locref, pose = predict.getpose(image, config, session, inputs, outputs, True) 113 | return scmap, locref, pose 114 | 115 | 116 | def find_local_peaks_new( 117 | scoremap: np.ndarray, local_reference: np.ndarray, animal_number: int, config: dict 118 | ) -> dict: 119 | """ 120 | Function for finding local peaks for each joint on provided scoremap 121 | :param scoremap: scmap from get_pose function 122 | :param local_reference: locref from get_pose function 123 | :param animal_number: number of animals for which we need to find peaks, also used for critical joints 124 | Critical joint are used to define skeleton of an animal 125 | There can not be more than animal_number point for each critical joint 126 | :param config: DeepLabCut config from load_deeplabcut() 127 | 128 | :returns all_joints dictionary with coordinates as list of tuples for each joint 129 | """ 130 | 131 | # loading animal joints from config 132 | all_joints_names = config["all_joints_names"] 133 | # critical_joints = ['neck', 'tailroot'] 134 | all_peaks = {} 135 | # loading stride from config 136 | stride = config["stride"] 137 | # filtering scoremap 138 | scoremap[scoremap < 0.1] = 0 139 | for joint_num, joint in enumerate(all_joints_names): 140 | all_peaks[joint] = [] 141 | # selecting the joint in scoremap and locref 142 | lr_joint = local_reference[:, :, joint_num] 143 | sm_joint = scoremap[:, :, joint_num] 144 | 145 | # applying maximum filter with footprint 146 | neighborhood = generate_binary_structure(2, 1) 147 | sm_max_filter = maximum_filter(sm_joint, footprint=neighborhood) 148 | # eroding filtered scoremap 149 | erosion_structure = 
generate_binary_structure(2, 3) 150 | sm_max_filter_eroded = binary_erosion( 151 | sm_max_filter, structure=erosion_structure 152 | ).astype(sm_max_filter.dtype) 153 | # labeling eroded filtered scoremap 154 | labeled_sm_eroded, num = label(sm_max_filter_eroded) 155 | # if joint is 'critical' and we have too few labels then we try a workaround to ensure maximum found peaks 156 | # for all other joints - normal procedure with cutoff point at animal_number 157 | peaks = maximum_position( 158 | sm_joint, labels=labeled_sm_eroded, index=range(1, num + 1) 159 | ) 160 | if num != animal_number: 161 | peaks = [ 162 | tuple(peak) 163 | for peak in peak_local_max( 164 | sm_joint, min_distance=4, num_peaks=animal_number 165 | ) 166 | ] 167 | 168 | if len(peaks) > animal_number: 169 | peaks = peaks[: animal_number + 1] 170 | 171 | # using scoremap peaks to get the coordinates on original image 172 | for peak in peaks: 173 | offset = lr_joint[peak] 174 | prob = sm_joint[peak] # not used 175 | # some weird DLC magic with stride and offsets 176 | coordinates = np.floor( 177 | np.array(peak)[::-1] * stride + 0.5 * stride + offset 178 | ) 179 | all_peaks[joint].append([tuple(coordinates.astype(int)), joint]) 180 | return all_peaks 181 | 182 | 183 | def filter_pose_by_likelihood(pose, threshold: float = 0.1): 184 | """ 185 | Filters pose estimation by likelihood threshold. Estimates below threshold are set to NaN and handled downstream 186 | of this function in calculate skeletons. 
187 | :param pose: pose estimation (e.g., from DLC) 188 | :param threshold: likelihood threshold to filter by 189 | :return filtered_pose: pose estimation filtered by likelihood (may contain NaN) in shape of pose, 190 | likelihood will be set to 2 in case of filtered bodyparts 191 | """ 192 | filtered_pose = pose.copy() 193 | if MODEL_ORIGIN == 'DLC': 194 | """ DLC pose output is an np.array with [bp*[X,Y, Likelihood]]""" 195 | for num, bp in enumerate(filtered_pose): 196 | if bp[2] < threshold: 197 | # set new threshold to "2" (number outside of normal range to signify filter) 198 | # if just past to "calculate_skeleton" it will be ignored 199 | filtered_pose[num] = np.array([np.NaN, np.NaN, 2]) 200 | 201 | return filtered_pose 202 | 203 | 204 | def calculate_dlstream_skeletons(peaks: dict, animals_number: int) -> list: 205 | """ 206 | Creating skeletons from given peaks 207 | There could be no more skeletons than animals_number 208 | Only unique skeletons output 209 | """ 210 | # creating a cartesian product out of all joints 211 | # this product contains all possible variations (dots clusters) of all joints groups 212 | cartesian_p = product(*peaks.values(), repeat=1) 213 | 214 | def calculate_closest_distances(dots_cluster: list) -> float: 215 | """ 216 | Calculating a sum of all distances between all dots in a cluster 217 | """ 218 | # extracting dots coordinates from given list 219 | dots_coordinates = (dot[0] for dot in dots_cluster) 220 | # calculating sum of each dots cluster 221 | product_sum = sum( 222 | calculate_distance(*c) for c in combinations(dots_coordinates, 2) 223 | ) 224 | return product_sum 225 | 226 | # sorting groups by their sum 227 | sorted_product = sorted( 228 | cartesian_p, key=lambda c: calculate_closest_distances(c), reverse=False 229 | ) 230 | 231 | # creating skeletons from top dots cluster 232 | def compare_clusters(unique_clusters: list, new_cluster: tuple) -> bool: 233 | """ 234 | Compare some new cluster against every existing 
unique cluster to find if it is unique 235 | :param unique_clusters: list of existing unique cluster 236 | :param new_cluster: cluster with same dots 237 | :return: if new cluster is unique 238 | """ 239 | # compare each element of tuple for uniqueness 240 | # finding unique combinations of joints within all possible combinations 241 | compare = lambda cl1, cl2: not any([(s1 == s2) for s1, s2 in zip(cl1, cl2)]) 242 | # create a uniqueness check list 243 | # so if a list consists at least one False then new_cluster is not unique 244 | comparison = [compare(u_cluster, new_cluster) for u_cluster in unique_clusters] 245 | return all(comparison) 246 | 247 | def create_animal_skeleton(dots_cluster: tuple) -> dict: 248 | """ 249 | Creating a easy to read skeleton from dots cluster 250 | Format for each joint: 251 | {'joint_name': (x,y)} 252 | """ 253 | skeleton = {} 254 | for dot in dots_cluster: 255 | skeleton[dot[-1]] = dot[0] 256 | return skeleton 257 | 258 | top_unique_clusters = [] 259 | animal_skeletons = [] 260 | 261 | if sorted_product: 262 | # add first cluster in our sorted list 263 | top_unique_clusters.append(sorted_product[0]) 264 | for cluster in sorted_product[1:]: 265 | # check if cluster is unique and we have a room for it 266 | if ( 267 | compare_clusters(top_unique_clusters, cluster) 268 | and len(top_unique_clusters) < animals_number 269 | ): 270 | top_unique_clusters.append(cluster) 271 | # there couldn't be more clusters then animal_number limit 272 | elif len(top_unique_clusters) == animals_number: 273 | break 274 | 275 | # creating a skeleton out of each cluster in our top clusters list 276 | for unique_cluster in top_unique_clusters: 277 | animal_skeletons.append(create_animal_skeleton(unique_cluster)) 278 | 279 | return animal_skeletons 280 | 281 | 282 | # maDLC 283 | def get_ma_pose(image, config, session, inputs, outputs): 284 | """ 285 | Gets scoremap, local reference and pose from DeepLabCut using given image 286 | Pose is most probable 
points for each joint, and not really used later 287 | Scoremap and local reference is essential to extract skeletons 288 | :param image: frame which would be analyzed 289 | :param config, session, inputs, outputs: DeepLabCut configuration and TensorFlow variables from load_deeplabcut() 290 | 291 | :return: tuple of scoremap, local reference and pose 292 | """ 293 | scmap, locref, paf, pose = predict_multianimal.get_detectionswithcosts( 294 | image, 295 | config, 296 | session, 297 | inputs, 298 | outputs, 299 | outall=True, 300 | nms_radius=5.0, 301 | det_min_score=0.1, 302 | c_engine=False, 303 | ) 304 | 305 | return pose 306 | 307 | 308 | def calculate_ma_skeletons( 309 | pose: dict, animals_number: int, threshold: float = 0.1 310 | ) -> list: 311 | """ 312 | Creating skeletons from given pose in maDLC 313 | There could be no more skeletons than animals_number 314 | Only unique skeletons output 315 | """ 316 | 317 | def filter_mapredictions(pose): 318 | detection = [] 319 | conf = np.array(pose["confidence"]) 320 | coords = np.array(pose["coordinates"]) 321 | for num, bp in enumerate(pose["coordinates"][0]): 322 | if len(bp) > 0: 323 | conf_bp = conf[num].flatten() 324 | fltred_bp = bp[conf_bp >= threshold, :] 325 | # todo: add function to only take top k-highest poses with k = animal number 326 | detection.append(fltred_bp) 327 | else: 328 | detection.append(np.array([])) 329 | return detection 330 | 331 | def extract_to_animal_skeleton(coords): 332 | """ 333 | Creating a easy to read skeleton from dots cluster 334 | Format for each joint: 335 | {'joint_name': (x,y)} 336 | """ 337 | bodyparts = np.array(coords) 338 | skeletons = {} 339 | for bp in range(len(bodyparts)): 340 | for animal_num in range(animals_number): 341 | if "Mouse" + str(animal_num + 1) not in skeletons.keys(): 342 | skeletons["Mouse" + str(animal_num + 1)] = {} 343 | if len(bodyparts[bp]) >= animals_number: 344 | skeletons["Mouse" + str(animal_num + 1)][ 345 | "bp" + str(bp + 1) 346 | ] = 
bodyparts[bp][animal_num].astype(float) 347 | else: 348 | if animal_num < len(bodyparts[bp]): 349 | skeletons["Mouse" + str(animal_num + 1)][ 350 | "bp" + str(bp + 1) 351 | ] = bodyparts[bp][animal_num].astype(float) 352 | else: 353 | skeletons["Mouse" + str(animal_num + 1)][ 354 | "bp" + str(bp + 1) 355 | ] = np.array([np.NaN, np.NaN]) 356 | 357 | return skeletons 358 | 359 | detections = filter_mapredictions(pose) 360 | animal_skeletons = extract_to_animal_skeleton(detections) 361 | animal_skeletons = list(animal_skeletons.values()) 362 | 363 | return animal_skeletons 364 | 365 | 366 | # DLC LIVE & DeepPoseKit 367 | def load_dpk(): 368 | model = load_model(MODEL_PATH) 369 | return model.predict_model 370 | 371 | 372 | def load_dlc_live(): 373 | return DLCLive(MODEL_PATH) 374 | 375 | 376 | def load_sleap(): 377 | model = load_model(MODEL_PATH, batch_size=1) 378 | model.inference_model 379 | return model.inference_model 380 | 381 | 382 | def flatten_maDLC_skeletons(skeletons): 383 | """Flattens maDLC multi skeletons into one skeleton to simulate dlc output 384 | where animals are not identical e.g. for animals with different fur colors (SIMBA)""" 385 | flat_skeletons = dict() 386 | for num, skeleton in enumerate(skeletons): 387 | for bp, value in skeleton.items(): 388 | flat_skeletons[f"{num}_{bp}"] = value 389 | 390 | return [flat_skeletons] 391 | 392 | 393 | def split_flat_skeleton(skeletons): 394 | """Splits flat multi skeletons (e.g. from flatten_maDLCskeleton) into seperate skeleton to simulate output 395 | where animals are identity tracked (e.g. SLEAP)""" 396 | flat_skeletons = skeletons[0] 397 | split_skeletons = [] 398 | bp_per_animal, remainder = divmod(len(flat_skeletons), ANIMALS_NUMBER) 399 | if remainder > 0: 400 | raise SkeletonError( 401 | f"The number of body parts ({len(flat_skeletons)}) cannot be split equally into {ANIMALS_NUMBER} animals." 
402 | ) 403 | else: 404 | for animal in range(ANIMALS_NUMBER): 405 | single_skeleton = list(flat_skeletons.keys())[ 406 | bp_per_animal * animal : bp_per_animal * animal + bp_per_animal 407 | ] 408 | split_skeletons.append( 409 | {x: flat_skeletons[x] for x in flat_skeletons if x in single_skeleton} 410 | ) 411 | 412 | return split_skeletons 413 | 414 | 415 | def transform_2skeleton(pose): 416 | """ 417 | Transforms pose estimation into DLStream style "skeleton" posture. 418 | If ALL_BODYPARTS is not sufficient, it will autoname the bodyparts in style bp0, bp1 ... 419 | """ 420 | try: 421 | skeleton = dict() 422 | counter = 0 423 | for bp in pose: 424 | skeleton[ALL_BODYPARTS[counter]] = tuple(np.array(bp[0:2], dtype=float)) 425 | counter += 1 426 | except IndexError: 427 | skeleton = dict() 428 | counter = 0 429 | for bp in pose: 430 | skeleton[f"bp{counter}"] = tuple(np.array(bp[0:2], dtype=float)) 431 | counter += 1 432 | 433 | return skeleton 434 | 435 | 436 | def transform_2pose(skeleton): 437 | pose = np.array([*skeleton.values()]) 438 | return pose 439 | 440 | 441 | def handle_missing_bp(animal_skeletons: list): 442 | """handles missing bodyparts (NaN values) by selected method in advanced_settings.ini 443 | If HANDLE_MISSING is skip: the complete skeleton is removed (default); 444 | If HANDLE_MISSING is null: the missing coordinate is set to 0.0, not recommended for experiments 445 | where continuous monitoring of parameters is necessary. 446 | If HANDLE_MISSING is pass: the missing coordinate is left NaN. This is useful to keep identities, but can yield 447 | unexpected results down the line if NaN values are not caught. 448 | If HANDLE_MISSING is reset: the whole skeleton is set to NaN. This is useful to keep identities, but can yield 449 | unexpected results down the line if NaN values are not caught. 450 | 451 | Missing skeletons will not be passed to the trigger, while resetting coordinates might lead to false results returned 452 | by triggers. 
453 | 454 | 455 | :param: animal_skeletons: list of skeletons returned by calculate skeleton 456 | :return animal_skeleton with handled missing values""" 457 | 458 | for skeleton in animal_skeletons: 459 | for bodypart, coordinates in skeleton.items(): 460 | np_coords = np.array((coordinates)) 461 | if any(np.isnan(np_coords)): 462 | if HANDLE_MISSING == "pass": 463 | # do nothing 464 | pass 465 | elif HANDLE_MISSING == "skip": 466 | # remove the whole skeleton 467 | animal_skeletons.remove(skeleton) 468 | break 469 | elif HANDLE_MISSING == "null": 470 | # remove replace coordinates with 0,0 471 | new_coordinates = np.nan_to_num(np_coords, copy=True) 472 | skeleton[bodypart] = tuple(new_coordinates) 473 | elif HANDLE_MISSING == "reset": 474 | # reset complete skeleton to NaN, NaN 475 | reset_skeleton = {bp: (np.NaN, np.NaN) for bp in skeleton} 476 | animal_skeletons = [ 477 | reset_skeleton if i == skeleton else i for i in animal_skeletons 478 | ] 479 | break 480 | else: 481 | animal_skeletons.remove(skeleton) 482 | break 483 | 484 | return animal_skeletons 485 | 486 | 487 | def arrange_flatskeleton(skeleton, n_animals, n_bp_animal, switch_dict): 488 | """changes sequence of bodypart sets (skeletons) in multi animal tracking with flat skeleton output (multiple animals in single skeleton) by switching position of pairs. 489 | E.g. in pose estimation with different fur colors. Note: When switching muliple animals the new position of the previous switches will be used. 490 | :param skeleton: flat skeleton of pose estimation in style {bp1: (x,y), bp2: (x2,y2) ...} 491 | :param n_animals: number of animals in total represented by skeleton 492 | :param n_bp_animal: number of bodyparts per animal in skeleton 493 | :param switch_dict: dictionary containing position of bodypart set (animal) in flat skeleton as key and bp set to exchange with as value. 494 | e.g. 
switch_dict = dict(1 = 2, 3 = 4) 495 | :return: skeleton with new order 496 | """ 497 | flat_pose = transform_2pose(skeleton) 498 | ra_dict = {} 499 | # slicing the animals out 500 | for num_animal in range(n_animals): 501 | ra_dict[num_animal] = flat_pose[ 502 | num_animal * n_bp_animal : num_animal * n_bp_animal + n_bp_animal 503 | ] 504 | # switching positions 505 | for orig_pos, switch_pos in switch_dict.items(): 506 | # extract old 507 | orig = ra_dict[orig_pos] 508 | switch = ra_dict[switch_pos] 509 | # set to new position 510 | ra_dict[orig_pos] = switch 511 | ra_dict[switch_pos] = orig 512 | # extracting pose 513 | arranged_pose = np.array([*ra_dict.values()]).reshape(flat_pose.shape) 514 | # transforming it to skeleton 515 | flat_skeleton = transform_2skeleton(arranged_pose) 516 | return flat_skeleton 517 | 518 | 519 | def calculate_skeletons_dlc_live(pose) -> list: 520 | """ 521 | Creating skeletons from given pose 522 | There could be no more skeletons than animals_number 523 | Only unique skeletons output 524 | """ 525 | skeletons = [transform_2skeleton(pose)] 526 | return skeletons 527 | 528 | 529 | def calculate_sleap_skeletons(pose) -> list: 530 | """ 531 | Creating skeleton from sleap output 532 | """ 533 | skeletons = [] 534 | for animal in range(pose.shape[0]): 535 | skeleton = transform_2skeleton(pose[animal]) 536 | skeletons.append(skeleton) 537 | return skeletons 538 | 539 | 540 | def calculate_skeletons(peaks: dict, animals_number: int) -> list: 541 | """ 542 | Creating skeletons from given peaks 543 | There could be no more skeletons than animals_number 544 | Only unique skeletons output 545 | adaptive to chosen model origin 546 | """ 547 | if MODEL_ORIGIN == "DLC": 548 | if USE_DLSTREAM_POSTURE_DETECTION: 549 | animal_skeletons = calculate_dlstream_skeletons(peaks, animals_number) 550 | else: 551 | if FILTER_LIKELIHOOD: 552 | peaks = filter_pose_by_likelihood(peaks, LIKELIHOOD_THRESHOLD) 553 | animal_skeletons = 
calculate_skeletons_dlc_live(peaks) 554 | 555 | if animals_number != 1 and SPLIT_MA: 556 | animal_skeletons = split_flat_skeleton(animal_skeletons) 557 | else: 558 | pass 559 | 560 | elif MODEL_ORIGIN == "MADLC": 561 | animal_skeletons = calculate_ma_skeletons(peaks, animals_number) 562 | if FLATTEN_MA: 563 | animal_skeletons = flatten_maDLC_skeletons(animal_skeletons) 564 | else: 565 | pass 566 | 567 | elif MODEL_ORIGIN == "DLC-LIVE" or MODEL_ORIGIN == "DEEPPOSEKIT": 568 | animal_skeletons = calculate_skeletons_dlc_live(peaks) 569 | if animals_number != 1 and not SPLIT_MA: 570 | raise SkeletonError( 571 | "Multiple animals are currently not supported by DLC-LIVE." 572 | " If you are using differently colored animals, please refer to the bodyparts directly (as a flattened skeleton) or use SPLIT_MA in the advanced settings." 573 | ) 574 | elif SPLIT_MA: 575 | animal_skeletons = split_flat_skeleton(animal_skeletons) 576 | else: 577 | pass 578 | 579 | elif MODEL_ORIGIN == "SLEAP": 580 | animal_skeletons = calculate_sleap_skeletons(peaks) 581 | if FLATTEN_MA: 582 | animal_skeletons = flatten_maDLC_skeletons(animal_skeletons) 583 | elif animals_number != 1 and SPLIT_MA: 584 | animal_skeletons = split_flat_skeleton(animal_skeletons) 585 | else: 586 | pass 587 | 588 | animal_skeletons = handle_missing_bp(animal_skeletons) 589 | 590 | return animal_skeletons 591 | -------------------------------------------------------------------------------- /utils/pylon.py: -------------------------------------------------------------------------------- 1 | """ 2 | DeepLabStream 3 | © J.Schweihoff, M. 
Loshakov 4 | University Bonn Medical Faculty, Germany 5 | https://github.com/SchwarzNeuroconLab/DeepLabStream 6 | Licensed under GNU General Public License v3.0 7 | """ 8 | 9 | from pypylon import pylon 10 | import cv2 11 | 12 | 13 | class PylonManager: 14 | """ 15 | Basler cameras manager class 16 | """ 17 | 18 | def __init__(self): 19 | self._manager_name = "Basler Pylon" 20 | self._factory = pylon.TlFactory.GetInstance() 21 | self._enabled_devices = {} 22 | self._resolution = None 23 | self._converter = pylon.ImageFormatConverter() 24 | self._converter.OutputPixelFormat = pylon.PixelType_BGR8packed 25 | self._converter.OutputBitAlignment = pylon.OutputBitAlignment_MsbAligned 26 | 27 | @property 28 | def _connected_devices(self) -> dict: 29 | """ 30 | Create a dict with all connected devices from self._factory 31 | """ 32 | return { 33 | device.GetSerialNumber(): device 34 | for device in self._factory.EnumerateDevices() 35 | } 36 | 37 | def get_connected_devices(self) -> list: 38 | """ 39 | Getter for stored connected devices serials list 40 | """ 41 | return list(self._connected_devices.keys()) 42 | 43 | def enable_stream(self, resolution, *args): 44 | """ 45 | Enable stream with given parameters 46 | Pretty meaningless for pylon manager, just sets the desired resolution 47 | """ 48 | self._resolution = resolution 49 | 50 | def enable_device(self, device_serial: str, *args): 51 | """ 52 | Camera starter 53 | """ 54 | camera = pylon.InstantCamera( 55 | self._factory.CreateDevice(self._connected_devices[device_serial]) 56 | ) 57 | self._enabled_devices[camera.DeviceInfo.GetSerialNumber()] = camera 58 | # grabbing continuously (video) with minimal delay 59 | camera.StartGrabbing(pylon.GrabStrategy_LatestImageOnly) 60 | 61 | def enable_all_devices(self): 62 | """ 63 | Starts the cameras with minimal delay 64 | """ 65 | for device in self._connected_devices: 66 | self.enable_device(device) 67 | 68 | def get_enabled_devices(self) -> dict: 69 | """ 70 | Getter for 
enabled devices dictionary 71 | """ 72 | return self._enabled_devices 73 | 74 | def get_frames(self) -> tuple: 75 | """ 76 | Collect frames for cameras and outputs it in 'color' dictionary 77 | ***depth and infrared are not used in pylon*** 78 | :return: tuple of three dictionaries: color, depth, infrared 79 | """ 80 | color_frames = {} 81 | depth_maps = {} 82 | infra_frames = {} 83 | for camera_name, camera in self._enabled_devices.items(): 84 | grabbed_frame = camera.RetrieveResult( 85 | 5000, pylon.TimeoutHandling_ThrowException 86 | ) 87 | # converting to opencv bgr format 88 | image = self._converter.Convert(grabbed_frame) 89 | img = image.GetArray() 90 | color_frames[camera_name] = cv2.resize(img, self._resolution) 91 | grabbed_frame.Release() 92 | return color_frames, depth_maps, infra_frames 93 | 94 | def stop(self): 95 | """ 96 | Stops cameras 97 | """ 98 | for camera_name, camera in self._enabled_devices.items(): 99 | camera.StopGrabbing() 100 | self._enabled_devices = {} 101 | 102 | def get_name(self) -> str: 103 | return self._manager_name 104 | -------------------------------------------------------------------------------- /utils/realsense.py: -------------------------------------------------------------------------------- 1 | """ 2 | DeepLabStream 3 | © J.Schweihoff, M. 
Loshakov 4 | University Bonn Medical Faculty, Germany 5 | https://github.com/SchwarzNeuroconLab/DeepLabStream 6 | Licensed under GNU General Public License v3.0 7 | """ 8 | 9 | import warnings 10 | import numpy as np 11 | 12 | warnings.filterwarnings( 13 | category=FutureWarning, action="ignore" 14 | ) # filter unwanted warnings 15 | import pyrealsense2 as prs2 16 | 17 | 18 | class RealSenseManager: 19 | """ 20 | RealSense cameras manager class 21 | """ 22 | 23 | def __init__(self): 24 | """ 25 | Everything needed to initialize we get from the environment 26 | List of connected devices is created once and is not updated later 27 | """ 28 | self._manager_name = "Intel RealSense" 29 | self._config, self._context = self.realsense_environment() 30 | self._connected_devices = self.find_connected_devices() 31 | self._enabled_devices = {} 32 | # initializing colorizer, that is not really used afterwards 33 | self._colorizer = prs2.colorizer() 34 | # Create alignment primitive with color as its target stream: 35 | self._align = prs2.align(prs2.stream.color) 36 | 37 | @staticmethod 38 | def realsense_environment() -> tuple: 39 | """ 40 | Getting config and context to find devices and enable them 41 | """ 42 | config = prs2.config() 43 | context = prs2.context() 44 | return config, context 45 | 46 | def get_name(self): 47 | return self._manager_name 48 | 49 | def find_connected_devices(self) -> list: 50 | """ 51 | Create a list with all connected devices from self._context 52 | """ 53 | connected_devices = [] 54 | for device in self._context.devices: 55 | connected_devices.append(device.get_info(prs2.camera_info.serial_number)) 56 | return connected_devices 57 | 58 | def get_connected_devices(self) -> list: 59 | """ 60 | Getter for stored connected devices list 61 | """ 62 | return self._connected_devices 63 | 64 | def enable_stream(self, resolution: tuple, framerate: int, stream_type: str): 65 | """ 66 | Enable one stream with given parameters 67 | :param resolution: 
resolution of the stream in format of (width, height) 68 | :param stream_type: type of stream to enable, supported ones: 'color', 'depth', 'infrared' 69 | :param framerate: maximum stream framerate 70 | """ 71 | width, heigth = resolution 72 | if stream_type == "color": 73 | self._config.enable_stream( 74 | stream_type=prs2.stream.color, 75 | width=width, 76 | height=heigth, 77 | format=prs2.format.bgr8, 78 | framerate=framerate, 79 | ) 80 | elif stream_type == "depth": 81 | self._config.enable_stream( 82 | stream_type=prs2.stream.depth, 83 | width=width, 84 | height=heigth, 85 | format=prs2.format.z16, 86 | framerate=framerate, 87 | ) 88 | elif stream_type == "infrared": 89 | self._config.enable_stream( 90 | stream_type=prs2.stream.infrared, 91 | stream_index=1, # we are using only the first camera here 92 | width=width, 93 | height=heigth, 94 | format=prs2.format.y8, 95 | framerate=framerate, 96 | ) 97 | 98 | def enable_device(self, device_serial: str): 99 | """ 100 | Enable one device with given serial 101 | """ 102 | pipeline = prs2.pipeline(self._context) 103 | self._config.enable_device(device_serial) 104 | pipeline_profile = pipeline.start(self._config) 105 | # enabling the emitter 106 | sensor = pipeline_profile.get_device().first_depth_sensor() 107 | sensor.set_option(prs2.option.emitter_enabled, 0) 108 | # storing the enabled device in enabled devices dictionary 109 | self._enabled_devices[device_serial] = (pipeline, pipeline_profile) 110 | 111 | def enable_all_devices(self): 112 | """ 113 | Cycles through all connected devices and enables them 114 | """ 115 | for device_serial in self._connected_devices: 116 | self.enable_device(device_serial) 117 | 118 | def get_enabled_devices(self) -> dict: 119 | """ 120 | Getter for enabled devices dictionary 121 | """ 122 | return self._enabled_devices 123 | 124 | def get_frames(self) -> tuple: 125 | """ 126 | Collect frames for each enabled stream from each enabled device and output them in the corresponding 
dictionary 127 | :return: tuple of three dictionaries: color, depth, infrared 128 | """ 129 | color_frames = {} 130 | depth_maps = {} 131 | infra_frames = {} 132 | for serial, device in self._enabled_devices.items(): 133 | device_pipeline, device_profile = device 134 | streams = device_profile.get_streams() 135 | frameset = device_pipeline.wait_for_frames() 136 | 137 | # Alignment for depth stream 138 | # currently not used 139 | # to enable it, uncomment following line 140 | # frameset = self._align.process(frameset) 141 | 142 | for stream in streams: 143 | if stream.stream_type() == prs2.stream.color: 144 | color_frames[serial] = frameset.get_color_frame().get_data() 145 | elif stream.stream_type() == prs2.stream.depth: 146 | depth_maps[serial] = frameset.get_depth_frame() 147 | elif stream.stream_type() == prs2.stream.infrared: 148 | infra_frames[serial] = frameset.get_infrared_frame(1).get_data() 149 | 150 | return color_frames, depth_maps, infra_frames 151 | 152 | def colorize_depth_frame(self, depth_frame: prs2.depth_frame) -> prs2.depth_frame: 153 | """ 154 | Colorizes the depth frame 155 | Not currently used due to very heavy CPU load 156 | OCV works with this better 157 | """ 158 | colorized_frame = np.asanyarray( 159 | self._colorizer.colorize(depth_frame).get_data() 160 | ) 161 | return colorized_frame 162 | 163 | def stop(self): 164 | """ 165 | Stops every device and stream 166 | """ 167 | self._config.disable_all_streams() 168 | for serial, device in self._enabled_devices.items(): 169 | device_pipeline, device_profile = device 170 | device_pipeline.stop() 171 | self._enabled_devices = {} 172 | -------------------------------------------------------------------------------- /utils/webcam.py: -------------------------------------------------------------------------------- 1 | """ 2 | DeepLabStream 3 | © J.Schweihoff, M. 
Loshakov 4 | University Bonn Medical Faculty, Germany 5 | https://github.com/SchwarzNeuroconLab/DeepLabStream 6 | Licensed under GNU General Public License v3.0 7 | """ 8 | 9 | 10 | from utils.generic import GenericManager, MissingFrameError 11 | import time 12 | import base64 13 | 14 | import cv2 15 | import numpy as np 16 | import zmq 17 | 18 | from utils.configloader import RESOLUTION, FRAMERATE, PORT 19 | 20 | 21 | class WebCamManager(GenericManager): 22 | def __init__(self): 23 | """ 24 | Binds the computer to a ip address and starts listening for incoming streams. 25 | Adapted from StreamViewer.py https://github.com/CT83/SmoothStream 26 | """ 27 | super().__init__() 28 | self._context = zmq.Context() 29 | self._footage_socket = self._context.socket(zmq.SUB) 30 | self._footage_socket.bind("tcp://*:" + PORT) 31 | self._footage_socket.setsockopt_string(zmq.SUBSCRIBE, np.unicode("")) 32 | 33 | self._camera = None 34 | self._camera_name = "webcam" 35 | self.initial_wait = False 36 | self.last_frame_time = time.time() 37 | 38 | @staticmethod 39 | def string_to_image(string): 40 | """ 41 | Taken from https://github.com/CT83/SmoothStream 42 | """ 43 | 44 | img = base64.b64decode(string) 45 | npimg = np.fromstring(img, dtype=np.uint8) 46 | return cv2.imdecode(npimg, 1) 47 | 48 | def get_frames(self) -> tuple: 49 | """ 50 | Collect frames for camera and outputs it in 'color' dictionary 51 | ***depth and infrared are not used here*** 52 | :return: tuple of three dictionaries: color, depth, infrared 53 | """ 54 | 55 | color_frames = {} 56 | depth_maps = {} 57 | infra_frames = {} 58 | 59 | if self._footage_socket: 60 | ret = True 61 | else: 62 | ret = False 63 | self.last_frame_time = time.time() 64 | if ret: 65 | # if not self.initial_wait: 66 | # cv2.waitKey(1000) 67 | # self.initial_wait = True 68 | # receives frame from stream 69 | image = self._footage_socket.recv_string() 70 | # converts image from str to image format that cv can handle 71 | image = 
self.string_to_image(image) 72 | image = cv2.resize(image, RESOLUTION) 73 | color_frames[self._camera_name] = image 74 | running_time = time.time() - self.last_frame_time 75 | if running_time <= 1 / FRAMERATE: 76 | sleepy_time = int(np.ceil(1000 / FRAMERATE - running_time / 1000)) 77 | cv2.waitKey(sleepy_time) 78 | 79 | else: 80 | raise MissingFrameError( 81 | "No frame was received from the webcam stream. Make sure that you started streaming on the host machine." 82 | ) 83 | 84 | return color_frames, depth_maps, infra_frames 85 | 86 | def enable_stream(self, resolution, framerate, *args): 87 | """ 88 | Not used for webcam streaming over network 89 | """ 90 | pass 91 | --------------------------------------------------------------------------------