├── .gitignore
├── LICENSE
├── README.md
├── gui.py
├── requirements.txt
└── run_primates.sh

/.gitignore:
--------------------------------------------------------------------------------
# Created by .ignore support plugin (hsz.mobi)
### Python template
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 damaggu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

/README.md:
--------------------------------------------------------------------------------
# Idtracking GUI accompanying SIPEC

## usage

Start the GUI with the animal names, comma-separated, in the order of the number keys (`--names` is required):
```
python gui.py --names <name1>,<name2>,<name3>,<name4>
```
The videos to annotate are set in the `videos` list (together with `base_path`) in `main()` of gui.py.
Results are written to `--out_folder` (default `./results/`) and the window size is set with `--window_size` (default 1024 pixels).

While loading, the terminal prints the current step until the GUI window opens.
From then on, everything is controlled from the keyboard.

The first frame shows the first mask to label, outlined by a green rectangle.
Determine the identity of that individual.
To move through the video press
```
k - jump forward by 50 frames
j - jump backward by 50 frames
i - jump forward by 250 frames
u - jump backward by 250 frames
b - return to the frame that is currently being labeled
```

To label the ID, press the number key tied to the animal's position in `--names`
(i.e. with `--names Bob,John,...`: 1 - Bob, 2 - John, ...).
The focus then moves to the next mask; once every mask in the frame is labeled,
the GUI jumps ahead in steps of 100 frames to the next frame that contains masks.
If you are unsure, follow the rule: in doubt - leave it out! and press
```
t - too difficult to annotate
w - wrong mask (not an animal)
```
The mask is then stored with that label and the program skips to the next mask.
Press `h` before the number key to flag a label as hard to see.

Further keys:
```
. / , - move the focus to the next / previous mask in the frame
= / - - zoom in / out on the focused mask
p     - step through previously labeled frames
m     - make the current frame the focus frame (discards its existing labels)
q     - save and quit
```
Annotations are saved to `<out_folder>/IDresults_<video>.npy` and loaded again on the next start.


## installation

Create a fresh Python environment (conda must be installed) with:
```
conda create -n idtracker python=3.6
```
then activate it and install the necessary packages from the requirements file:
```
conda activate idtracker
pip install -r requirements.txt
```
and done!
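
## reading the results

The GUI writes its annotations with `np.save` to `<out_folder>/IDresults_<filename>.npy`. Judging from `save_mask_result()` in gui.py, the file holds a dict that maps each labeled frame index to `{"frame": ..., "masks": ..., "results": {mask_id: label}}`, where each label is an animal name (or `wrong_mask` / `too_difficult`) joined with `_easy` or `_hard`. The following is a minimal sketch, assuming that layout, for reading such a file back and counting the labels; the helper names `load_id_results`, `split_label` and `summarize_labels` are hypothetical and not part of gui.py.

```python
from collections import Counter

import numpy as np


def load_id_results(path):
    """Load an IDresults_<filename>.npy file (hypothetical helper).

    np.save() wraps the results dict in a 0-d object array, so it must be
    loaded with allow_pickle=True and unwrapped with .item().
    """
    return np.load(path, allow_pickle=True).item()


def split_label(label):
    """Split a stored label such as "Bob_easy" or "wrong_mask_hard" into
    (identity, difficulty).

    The difficulty flag is always the last underscore-separated token, so
    rsplit keeps identities that contain underscores (e.g. "wrong_mask") whole.
    """
    identity, difficulty = label.rsplit("_", 1)
    return identity, difficulty


def summarize_labels(results):
    """Count (identity, difficulty) pairs over all annotated masks."""
    counts = Counter()
    for entry in results.values():
        # frames without a "results" dict contribute nothing
        for label in entry.get("results", {}).values():
            counts[split_label(label)] += 1
    return counts
```

For example, `summarize_labels(load_id_results("./results/IDresults_<filename>.npy"))` returns a `Counter` keyed by `(identity, difficulty)`, which makes it easy to check how many masks each animal received before training on the labels.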


## Cite:

If you use any part of this code for your work, please cite the following:

```
SIPEC: the deep-learning Swiss knife for behavioral data analysis
Markus Marks, Jin Qiuhan, Oliver Sturman, Lukas von Ziegler, Sepp Kollmorgen, Wolfger von der Behrens, Valerio Mante, Johannes Bohacek, Mehmet Fatih Yanik
bioRxiv 2020.10.26.355115; doi: https://doi.org/10.1101/2020.10.26.355115
```

/gui.py:
--------------------------------------------------------------------------------
import sys
from joblib import Parallel, delayed

sys.path.extend(["./venv/lib/python3.7/site-packages"])
import os
import random  # needed by resize_image(mode="crop")
import time
from argparse import ArgumentParser
from time import sleep

import cv2
import numpy as np
from skimage.transform import rescale
import skvideo.io
from tqdm import tqdm
from glob import glob

import skimage.color
import skimage.io
import skimage.transform

import gc
import joblib
from multiprocessing import Process
import concurrent.futures

from distutils.version import LooseVersion

import pickle

# number of pixel rows cropped from the top and bottom of each 2048x2048
# molded frame before display (highres is 500)
crop = 500

parser = ArgumentParser()

parser.add_argument(
    "--filename",
    action="store",
    dest="filename",
    type=str,
    help="filename of the video to be processed (has to be a segmented one)",
)

parser.add_argument(
    "--names",
    action="store",
    dest="names",
    type=str,
    required=True,
    help="comma-separated animal names, in the order of the number keys 1, 2, 3, ...",
)

parser.add_argument(
    "--out_folder",
    action="store",
    dest="results_sink",
    type=str,
    default="./results/",
    help="folder where results should be saved",
)

parser.add_argument(
    "--num_masks",
    action="store",
    dest="num_masks",
    type=int,
    default=40,
    help="number of masks to be labeled for this video",
)

parser.add_argument(
    "--window_size",
    action="store",
    dest="window_size",
    type=int,
    default=1024,
    help="size of the GUI in pixels",
)

parser.add_argument(
    "--species",
    action="store",
    dest="species",
    type=str,
    help="define the species to annotate primate/mouse/else",
)


def resize(
    image,
    output_shape,
    order=1,
    mode="constant",
    cval=0,
    clip=True,
    preserve_range=False,
    anti_aliasing=False,
    anti_aliasing_sigma=None,
):
    """A wrapper for Scikit-Image resize().

    Scikit-Image generates warnings on every call to resize() if it doesn't
    receive the right parameters. The right parameters depend on the version
    of skimage. This solves the problem by using different parameters per
    version. And it provides a central place to control resizing defaults.
    """
    if LooseVersion(skimage.__version__) >= LooseVersion("0.14"):
        # New in 0.14: anti_aliasing. Default it to False for backward
        # compatibility with skimage 0.13.
        return skimage.transform.resize(
            image,
            output_shape,
            order=order,
            mode=mode,
            cval=cval,
            clip=clip,
            preserve_range=preserve_range,
            anti_aliasing=anti_aliasing,
            anti_aliasing_sigma=anti_aliasing_sigma,
        )
    else:
        return skimage.transform.resize(
            image,
            output_shape,
            order=order,
            mode=mode,
            cval=cval,
            clip=clip,
            preserve_range=preserve_range,
        )


def resize_image(image, min_dim=None, max_dim=None, min_scale=None, mode="square"):
    # Keep track of image dtype and return results in the same dtype
    image_dtype = image.dtype
    # Default window (y1, x1, y2, x2) and default scale == 1.
    h, w = image.shape[:2]
    window = (0, 0, h, w)
    scale = 1
    padding = [(0, 0), (0, 0), (0, 0)]
    crop = None

    if mode == "none":
        return image, window, scale, padding, crop

    # Scale?
    if min_dim:
        # Scale up but not down
        scale = max(1, min_dim / min(h, w))
    if min_scale and scale < min_scale:
        scale = min_scale

    # Does it exceed max dim?
    if max_dim and mode == "square":
        image_max = max(h, w)
        if round(image_max * scale) > max_dim:
            scale = max_dim / image_max

    # Resize image using bilinear interpolation
    if scale != 1:
        image = resize(image, (round(h * scale), round(w * scale)), preserve_range=True)

    # Need padding or cropping?
    if mode == "square":
        # Get new height and width
        h, w = image.shape[:2]
        top_pad = (max_dim - h) // 2
        bottom_pad = max_dim - h - top_pad
        left_pad = (max_dim - w) // 2
        right_pad = max_dim - w - left_pad
        padding = [(top_pad, bottom_pad), (left_pad, right_pad), (0, 0)]
        image = np.pad(image, padding, mode="constant", constant_values=0)
        window = (top_pad, left_pad, h + top_pad, w + left_pad)
    elif mode == "pad64":
        h, w = image.shape[:2]
        # Both sides must be divisible by 64
        assert min_dim % 64 == 0, "Minimum dimension must be a multiple of 64"
        # Height
        if h % 64 > 0:
            max_h = h - (h % 64) + 64
            top_pad = (max_h - h) // 2
            bottom_pad = max_h - h - top_pad
        else:
            top_pad = bottom_pad = 0
        # Width
        if w % 64 > 0:
            max_w = w - (w % 64) + 64
            left_pad = (max_w - w) // 2
            right_pad = max_w - w - left_pad
        else:
            left_pad = right_pad = 0
        padding = [(top_pad, bottom_pad), (left_pad, right_pad), (0, 0)]
        image = np.pad(image, padding, mode="constant", constant_values=0)
        window = (top_pad, left_pad, h + top_pad, w + left_pad)
    elif mode == "crop":
        # Pick a random crop
        h, w = image.shape[:2]
        y = random.randint(0, (h - min_dim))
        x = random.randint(0, (w - min_dim))
        crop = (y, x, min_dim, min_dim)
        image = image[y : y + min_dim, x : x + min_dim]
        window = (0, 0, min_dim, min_dim)
    else:
        raise Exception("Mode {} not supported".format(mode))
    return image.astype(image_dtype), window, scale, padding, crop


def mold_image(img):
    image, window, scale, padding, crop = resize_image(
        img[:, :, :], min_dim=2048, min_scale=2048, max_dim=2048, mode="square"
    )
    return image


class WindowHandler:
    frames = None
    current_frame = None

    def __init__(
        self,
        frames_path,
        name_indicators,
        filename,
        results_sink,
        masks,
        num_masks,
        stepsize,
        window_size,
    ):

        super().__init__()
        self.masks = masks
        self.frames_path = frames_path
        self.name_indicators = name_indicators
        self.filename = filename
        self.results_sink = results_sink
        self.current_mask_focus = 0
        self.zoom = False
        self.stepsize = stepsize
        self.break_status = False
        self.num_masks = num_masks

        # start timer for persistent saving
        self.start_time = time.time()
        self.last_save_time = self.start_time

        # opencv params

        self.window_name = "output"

        cv2.namedWindow(
            self.window_name, cv2.WINDOW_NORMAL
        )  # Create window with freedom of dimensions
        cv2.resizeWindow(self.window_name, window_size, window_size)
        self.font = cv2.FONT_HERSHEY_SIMPLEX
        self.bottomLeftCornerOfText = (512, 512)
        self.fontScale = 500
        self.fontColor = (0, 255, 0)
        self.lineType = 20
        self.font_thickness = 1

        try:
            self.results = self.load_data()
            self.previous_frame_focus = 0
            print("Loading previous annotations")
        except FileNotFoundError:
            self.results = {}
            self.previous_frame_focus = None

        self.manual_mode = False
        self.current_mask = 0
        self.current_frame_focus = 1
        self.current_frame = self.current_frame_focus
        self.current_difficulty_flag = "easy"
        self.mask_focus = 0

        self.mask_color_default_focus_frame = (125, 125, 125)
        self.mask_color_default = (75, 75, 75)
        self.mask_color_focus = (0, 255, 0)
        self.mask_color_labeled = (0, 0, 255)

        self.local_slider_window = 20
        self.local_slider_lower_window = self.current_frame - self.local_slider_window
        self.local_slider_higher_window = self.current_frame + self.local_slider_window

        self.regular_focus_interval = 100

        # load frames --- old frames

        # self.overall_frames = len(glob(frames_path + '*.npy'))
        self.load_frames(0, 0)
        self.overall_frames = len(self.frames)
        self.frame_buffer = min(10000, self.overall_frames)
        # self.frame_batches = int(float(self.overall_frames)/float(self.frame_buffer))
        self.frame_current_batch = 0

        self.local_slider = cv2.createTrackbar(
            "Local Slider",
            self.window_name,
            self.local_slider_lower_window,
            self.local_slider_higher_window,
            self.on_change_local,
        )
        self.global_slider = cv2.createTrackbar(
            "Global Slider",
            self.window_name,
            self.current_frame,
            self.overall_frames - 1,
            self.on_change_global,
        )

    def load_frames(self, start, end):
        # note: start/end are currently unused; the whole 10000-frame segment
        # selected by the filename suffix is (re)loaded
        print("loading frames")

        # TODO: optim code
        basepath = ""

        segment = int(self.filename.split("_")[-1][-1])
        vidname = self.filename.split("1_")[0] + "1"

        vid = basepath + vidname + ".mp4"
        idx = int(self.filename.split("1_")[-1])
        batch_size = 10000

        videodata = skvideo.io.vread(vid, as_grey=False)
        videodata = videodata[idx * batch_size : (idx + 1) * batch_size]
        results_list = Parallel(
            n_jobs=40, max_nbytes=None, backend="multiprocessing", verbose=40
        )(delayed(mold_image)(image) for image in videodata)
        results = {}
        for idx, el in enumerate(results_list):
            results[idx] = el
        self.frames = results
        print("frames loaded", str(len(results)))

    def on_change_local(self, pos):
        self.current_frame = pos
        cv2.setTrackbarPos("Local Slider", self.window_name, pos)
        cv2.setTrackbarPos("Global Slider", self.window_name, pos)

    def on_change_global(self, pos):
        self.current_frame = pos
        cv2.setTrackbarPos("Global Slider", self.window_name, pos)
        if not self.local_slider_lower_window < pos < self.local_slider_higher_window:
            self.local_slider_lower_window = pos - self.local_slider_window
            self.local_slider_higher_window = pos + self.local_slider_window
            cv2.setTrackbarMin(
                "Local Slider",
                winname=self.window_name,
                minval=self.local_slider_lower_window,
            )
            cv2.setTrackbarMax(
                "Local Slider",
                winname=self.window_name,
                maxval=self.local_slider_higher_window,
            )
        cv2.setTrackbarPos("Local Slider", self.window_name, pos)

    def close(self):
        print("writing data, do not interrupt!")
        self.save_data()
        print("done writing data")
        cv2.destroyAllWindows()

    def save_data(self):
        np.save(self.results_sink + "IDresults_" + self.filename + ".npy", self.results)

    def load_data(self):
        return np.load(
            self.results_sink + "IDresults_" + self.filename + ".npy", allow_pickle=True
        ).item()

    def clocked_save(self):
        # save data roughly once a minute; comparing against the time of the last
        # save avoids missing the narrow window that a modulo check depends on
        if time.time() - self.last_save_time >= 60:
            self.save_data()
            self.last_save_time = time.time()

    def mask_to_opencv(self, frame, mask, color, animal_id=None, mask_id=None):
        cv2.rectangle(frame, (mask[1], mask[0]), (mask[3], mask[2]), color, 3)
        if animal_id:
            if mask_id == 0:
                cv2.putText(
                    frame,
                    animal_id,
                    (mask[1], mask[0]),
                    self.font,
                    0.5,
                    self.mask_color_labeled,
                    self.font_thickness,
                    cv2.LINE_AA,
                )
            if mask_id == 1:
                cv2.putText(
                    frame,
                    animal_id,
                    (mask[3], mask[2]),
                    self.font,
                    0.5,
                    self.mask_color_labeled,
                    self.font_thickness,
                    cv2.LINE_AA,
                )
            if mask_id == 2:
                cv2.putText(
                    frame,
                    animal_id,
                    (mask[1], mask[2]),
                    self.font,
                    0.5,
                    self.mask_color_labeled,
                    self.font_thickness,
                    cv2.LINE_AA,
                )
            if mask_id == 3:
                cv2.putText(
                    frame,
                    animal_id,
                    (mask[3], mask[0]),
                    self.font,
                    0.5,
                    self.mask_color_labeled,
                    self.font_thickness,
                    cv2.LINE_AA,
                )

    def display_mask(self, mask_id, current_mask):

        frame = self.frames[self.current_frame]
        is_labeled = None
        try:
            is_labeled = mask_id in self.results[self.current_frame]["results"].keys()
        except KeyError:
            pass
        # display focus mask in focus frame
        if self.current_frame == self.current_frame_focus:
            if mask_id == self.current_mask_focus:
                self.mask_to_opencv(frame, current_mask, self.mask_color_focus)
            elif is_labeled:
                animal_id = self.results[self.current_frame]["results"][mask_id]
                self.mask_to_opencv(
                    frame,
                    current_mask,
                    self.mask_color_labeled,
                    animal_id=animal_id,
                    mask_id=mask_id,
                )
            else:
                self.mask_to_opencv(
                    frame, current_mask, self.mask_color_default_focus_frame
                )
        else:
            if is_labeled:
                animal_id = self.results[self.current_frame]["results"][mask_id]
                self.mask_to_opencv(
                    frame,
                    current_mask,
                    self.mask_color_labeled,
                    animal_id=animal_id,
                    mask_id=mask_id,
                )
            else:
                self.mask_to_opencv(frame, current_mask, self.mask_color_default)

    def draw_random_frame(self, window=50):
        draw = 0
        print(self.results.keys())
        while True:
            draw = np.random.randint(0, len(self.masks), 1)[0]
            print(draw)
            if (
                window < draw < len(self.masks) - window
                and draw not in self.results.keys()
            ):
                print("True")
                break
            print("False")
        return draw

    def adjust_trackbar(self):

        # check whether current frame outside focus
        if not self.frame_current_batch == self.check_batchnum(self.current_frame):
            print("reloading frames")
            self.frame_current_batch = self.check_batchnum(self.current_frame)
            self.load_frames(
                self.frame_buffer * self.frame_current_batch,
                self.frame_buffer * (self.frame_current_batch + 1),
            )

            self.current_mask_focus = 0

        cv2.setTrackbarPos("Global Slider", self.window_name, self.current_frame)
        if (
            not self.local_slider_lower_window
            < self.current_frame
            < self.local_slider_higher_window
        ):
            self.local_slider_lower_window = (
                self.current_frame - self.local_slider_window
            )
            self.local_slider_higher_window = (
                self.current_frame + self.local_slider_window
            )
            cv2.setTrackbarMin(
                "Local Slider",
                winname=self.window_name,
                minval=self.local_slider_lower_window,
            )
            cv2.setTrackbarMax(
                "Local Slider",
                winname=self.window_name,
                maxval=self.local_slider_higher_window,
            )
        cv2.setTrackbarPos("Local Slider", self.window_name, self.current_frame)

    def check_batchnum(self, frame):
        # index of the frame_buffer-sized batch that contains `frame`, or -1 if
        # the frame lies outside all batches; the number of batches is computed
        # here (ceil division, at least 1) because self.frame_batches is never set
        num_batches = max(1, -(-self.overall_frames // max(1, self.frame_buffer)))
        for i in range(num_batches):
            if self.frame_buffer * i <= frame < self.frame_buffer * (i + 1):
                return i
        return -1

    def set_new_regular_focus(self, interval=500):
        self.current_mask += 1
        self.current_frame_focus = (
            self.current_frame_focus + self.regular_focus_interval
        )
        self.current_frame = self.current_frame_focus

        # advance by regular_focus_interval until a frame with at least one mask
        # is found; the bound check stops at the end of the mask list instead of
        # looping forever on IndexError
        while self.current_frame < len(self.masks):
            try:
                self.masks[self.current_frame]["rois"][0]
                break
            except IndexError:
                self.current_frame_focus = (
                    self.current_frame_focus + self.regular_focus_interval
                )
                self.current_frame = self.current_frame_focus

        if (
            self.current_frame >= len(self.masks)
            or self.current_frame > self.overall_frames - 50
        ):
            self.break_status = True
        else:
            self.adjust_trackbar()

    def set_new_random_focus(self):
        self.current_mask += 1
        self.current_frame_focus = self.draw_random_frame()
        self.current_frame = self.current_frame_focus
        self.current_mask_focus = 0

        self.adjust_trackbar()

    def set_focus(self, focus_frame):
        # self.current_mask += 1
        print(str(focus_frame))
        self.current_frame = focus_frame
        # self.current_mask_focus = 0

        self.adjust_trackbar()

    def display_frame(self):
        if self.zoom:
            img = self.frames[self.current_frame]
            y1, x1, y2, x2 = self.masks[self.current_frame]["rois"][
                self.current_mask_focus
            ]
            center_x = float(x2 + x1) / 2.0
            center_y = float(y2 + y1) / 2.0

            # TODO: make relative
            masked_img = img[
                int(center_y - 200) : int(center_y + 200),
                int(center_x - 200) : int(center_x + 200),
            ]

            # TODO: determine value or find best fixed
            # the global `crop` is not applied here: the 400x400 cut-out rescaled
            # by 1.75 is only 700 px high, so [crop:-crop] would leave it empty
            rescaled_img = rescale(masked_img, 1.75, multichannel=True)
            cv2.imshow(
                "output",
                cv2.cvtColor(rescaled_img.astype("float32"), cv2.COLOR_BGR2RGB),
            )
        else:
            curr_img = self.frames[self.current_frame][crop:-crop, :]
            cv2.putText(
                curr_img,
                "Frame: " + str(self.current_frame),
                (1000, 800),
                self.font,
                4,
                (255, 255, 255),
                self.font_thickness,
                cv2.LINE_AA,
            )
            cv2.putText(
                curr_img,
                "Mask: " + str(len(self.results) + 1),
                (1000, 900),
                self.font,
                4,
                (255, 255, 255),
                self.font_thickness,
                cv2.LINE_AA,
            )
            cv2.imshow("output", cv2.cvtColor(curr_img, cv2.COLOR_BGR2RGB))
        return

    def display_all_indicators(self):
        for idx, indicator in enumerate(self.name_indicators.keys()):
            cv2.putText(
                self.frames[self.current_frame],
                self.name_indicators[indicator] + " : " + str(indicator + 1),
                (10, 850 + 25 * idx),
                self.font,
                0.5,
                (255, 255, 255),
                self.font_thickness,
                cv2.LINE_AA,
            )

    def display_all_keys(self):
        # TODO: explain 'b','w','t'
        ##
        dist_1 = 200
        cv2.putText(
            self.frames[self.current_frame],
            "p -- display previous mask",
            (dist_1, 850 + 0),
            self.font,
            0.5,
            (255, 255, 255),
            self.font_thickness,
            cv2.LINE_AA,
        )
        cv2.putText(
            self.frames[self.current_frame],
            "b -- reset view to current mask",
            (dist_1, 850 + 25),
            self.font,
            0.5,
            (255, 255, 255),
            self.font_thickness,
            cv2.LINE_AA,
        )
        cv2.putText(
            self.frames[self.current_frame],
            "m -- trigger manual mode on current frame",
            (dist_1, 850 + 50),
            self.font,
            0.5,
            (255, 255, 255),
            self.font_thickness,
            cv2.LINE_AA,
        )
        cv2.putText(
            self.frames[self.current_frame],
            "w -- wrong mask (not an animal)",
            (dist_1, 850 + 75),
            self.font,
            0.5,
            (255, 255, 255),
            self.font_thickness,
            cv2.LINE_AA,
        )
        cv2.putText(
            self.frames[self.current_frame],
            "h -- difficult to annotate (hard to see from current frame)",
            (dist_1, 850 + 100),
            self.font,
            0.5,
            (255, 255, 255),
            self.font_thickness,
            cv2.LINE_AA,
        )
        cv2.putText(
            self.frames[self.current_frame],
            "t -- too difficult to annotate",
            (dist_1, 850 + 125),
            self.font,
            0.5,
            (255, 255, 255),
            self.font_thickness,
            cv2.LINE_AA,
        )

        ##
        dist_2 = 600
        cv2.putText(
            self.frames[self.current_frame],
            ". -- change mask to focus on forward",
            (dist_2, 850 + 0),
            self.font,
            0.5,
            (255, 255, 255),
            self.font_thickness,
            cv2.LINE_AA,
        )
        cv2.putText(
            self.frames[self.current_frame],
            ", -- change mask to focus on backward",
            (dist_2, 850 + 25),
            self.font,
            0.5,
            (255, 255, 255),
            self.font_thickness,
            cv2.LINE_AA,
        )
        cv2.putText(
            self.frames[self.current_frame],
            "= -- zoom in",
            (dist_2, 850 + 50),
            self.font,
            0.5,
            (255, 255, 255),
            self.font_thickness,
            cv2.LINE_AA,
        )
        cv2.putText(
            self.frames[self.current_frame],
            "- -- zoom out",
            (dist_2, 850 + 75),
            self.font,
            0.5,
            (255, 255, 255),
            self.font_thickness,
            cv2.LINE_AA,
        )

        pass

    def save_mask_result(self, result):
        if self.current_frame == self.current_frame_focus:
            try:
                self.results[self.current_frame]["frame"]
            except KeyError:
                self.results[self.current_frame] = {
                    "frame": self.frames[self.current_frame],
                    "masks": self.masks[self.current_frame],
                }
            try:
                results = self.results[self.current_frame]["results"]
                results[self.current_mask_focus] = (
                    result + "_" + self.current_difficulty_flag
                )
            except KeyError:
                # first result indicates mask_id, second indicates animal id
                results = {
                    self.current_mask_focus: result + "_" + self.current_difficulty_flag
                }
                self.results[self.current_frame]["results"] = results
            # move on to the next regular focus frame once all masks in this
            # frame are labeled
            if (
                self.current_mask_focus
                == len(self.masks[self.current_frame_focus]["rois"]) - 1
            ):
                print("setting new focus")
                self.set_new_regular_focus()
                self.previous_frame_focus = 0
                self.current_difficulty_flag = "easy"
            # otherwise next mask
            else:
                self.current_mask_focus += 1
                self.current_difficulty_flag = "easy"

    def display_current_frame(self):
        # if masks available, display them
        try:
            current_masks = self.masks[self.current_frame]["rois"]
            for mask_id, mask in enumerate(current_masks):
                self.display_mask(mask_id, mask)
        except IndexError:
            pass
        try:
            self.display_frame()
        except IndexError:
            self.break_status = True

        return

    def check_keys(self):
        frameclick = cv2.waitKey(1) & 0xFF
        # quit
        if frameclick == ord("q"):
            self.break_status = True
        # 'a': block until 'a' is pressed again, then quit
        if frameclick == ord("a"):
            while True:
                if cv2.waitKey(20) & 0xFF == ord("a"):
                    self.break_status = True
                    break
                sleep(0.1)
        # jump forward by stepsize frames
        if frameclick == ord("k"):
            self.current_frame += self.stepsize
            cv2.putText(
                self.frames[self.current_frame],
                "Mask: " + str(self.current_mask),
                (10, 100),
                self.font,
                4,
                (255, 255, 255),
                2,
                cv2.LINE_AA,
            )
            self.adjust_trackbar()

        # jump backward by stepsize frames
        if frameclick == ord("j"):
            self.current_frame -= self.stepsize
            cv2.putText(
                self.frames[self.current_frame],
                "Mask: " + str(self.current_mask),
                (10, 100),
                self.font,
                4,
                (255, 255, 255),
                2,
                cv2.LINE_AA,
            )
            self.adjust_trackbar()

        # jump forward by 5 * stepsize frames
        if frameclick == ord("i"):
            self.current_frame += int(self.stepsize * 5)
            cv2.putText(
                self.frames[self.current_frame],
                "Mask: " + str(self.current_mask),
                (10, 100),
                self.font,
                4,
                (255, 255, 255),
                2,
                cv2.LINE_AA,
            )
            self.adjust_trackbar()

        # jump backward by 5 * stepsize frames
        if frameclick == ord("u"):
            self.current_frame -= int(self.stepsize * 5)
            cv2.putText(
                self.frames[self.current_frame],
                "Mask: " + str(self.current_mask),
                (10, 100),
                self.font,
                4,
                (255, 255, 255),
                2,
                cv2.LINE_AA,
            )
            self.adjust_trackbar()

        # TODO: finish easy/hard flag
        if frameclick == ord("h"):
            self.current_difficulty_flag = "hard"
        # step through previously labeled frames
        if frameclick == ord("p") and self.previous_frame_focus is not None:
            if self.previous_frame_focus < len(list(self.results.keys())):
                self.current_frame = list(self.results.keys())[
                    self.previous_frame_focus
                ]
                self.previous_frame_focus += 1
        # back to current focus
        if frameclick == ord("b"):
            self.set_focus(self.current_frame_focus)
        # zoom in
        if frameclick == ord("="):
            self.zoom = True
        # zoom out
        if frameclick == ord("-"):
            self.zoom = False
        # change mask focus increasing
        if frameclick == ord("."):
            if (
                self.current_mask_focus
                < len(self.masks[self.current_frame]["rois"]) - 1
            ):
                self.current_mask_focus += 1
        # change mask focus decreasing
        if frameclick == ord(","):
            if self.current_mask_focus > 0:
                self.current_mask_focus -= 1
        # TODO: implement
        # manual mode trigger: make the current frame the focus frame and
        # discard any labels it already has
        if frameclick == ord("m"):
            self.current_frame_focus = self.current_frame
            if self.current_frame_focus in self.results.keys():
                del self.results[self.current_frame_focus]

        # skipping a mask
        # TODO: fix what is results
        if frameclick == ord("w"):
            self.save_mask_result("wrong_mask")
        if frameclick == ord("t"):
            self.save_mask_result("too_difficult")
        # labeling one of the animals
        for j in range(1, len(self.name_indicators) + 1):
            if frameclick == ord(str(j)):
                # TODO: multiple results are in same FOV
                self.save_mask_result(self.name_indicators[j - 1])

    def check_num_results(self):
        if len(self.results) == self.num_masks:
            self.break_status = True
            print("labeled enough masks")
        return

    def update(self):
        try:
            self.display_current_frame()
            # self.display_all_indicators()
            # self.display_all_keys()
            self.check_keys()
            self.clocked_save()
            self.check_num_results()
        except (FileNotFoundError, IndexError):
            self.break_status = True
        return self.break_status


def load_mask(video_path):
    # disabling the garbage collector speeds up unpickling large result lists;
    # try/finally makes sure it is re-enabled even if loading fails
    gc.disable()
    try:
        with open(video_path + "SegResults.pkl", "rb") as handle:
            masks = pickle.load(handle)
    finally:
        gc.enable()
    return masks


def load_mask_parallel(video_path):
    gc.disable()
    try:
        with open(video_path + "SegResults.pkl", "rb") as handle:
            masks = pickle.load(handle)
    finally:
        gc.enable()
    return masks


def main():
    args = parser.parse_args()
    species = args.species
    names = args.names
    results_sink = args.results_sink
    num_masks = args.num_masks
    window_size = args.window_size
    names = names.split(",")
    name_indicators = {}
    for idx, el in enumerate(names):
        name_indicators[idx] = el

    base_path = ""
    videos = [
        "",
    ]
    if not os.path.exists(results_sink):
        os.makedirs(results_sink)

    future = None
    executor = None
    myhandler = None
    idx = 0

    for video_id, filename in enumerate(videos):

        print("LOADING NEXT VIDEO")

        if myhandler:
            del myhandler
        # check num masks
        start = time.time()
        video_path = base_path + filename + "/"

        if executor:
            preload_masks = future.result()
            masks = preload_masks
            executor.shutdown(wait=False)
        else:
            masks = load_mask(video_path)
        print("loading mask took", time.time() - start)
        print("len masks", str(len(masks)))
        idx = int(filename.split("1_")[-1])
        batch_size = 10000
        masks = masks[idx * batch_size : (idx + 1) * batch_size]

        frames_path = video_path + "frames/"
        stepsize = 50

        print("initiating handler")
        # init handler
        myhandler = WindowHandler(
            frames_path,
            name_indicators,
            filename,
            results_sink,
            masks,
            num_masks,
            stepsize,
            window_size,
        )

        breaking = True
        while breaking:
            break_status = myhandler.update()
            if break_status:
                print("saving data, don't interrupt")
                myhandler.save_data()
                print("data saved, good to interrupt")
                cv2.destroyAllWindows()
                breaking = False
                break
            sleep(0.05)


if __name__ == "__main__":
    main()

/requirements.txt:
--------------------------------------------------------------------------------
scikit-image
absl-py==0.7.1
astor==0.8.0
attrs==19.1.0
backcall==0.1.0
bleach==3.1.0
certifi==2019.6.16
cycler==0.10.0
decorator==4.4.0
defusedxml==0.6.0
entrypoints==0.3
freetype-py==2.1.0.post1
gast==0.2.2
google-pasta==0.1.7
grpcio==1.22.0
h5py==2.9.0
ipykernel==5.1.1
ipython==7.7.0
ipython-genutils==0.2.0
ipywidgets==7.5.1
jedi==0.14.1
Jinja2==2.10.1
joblib==0.13.2
jsonschema==3.0.2
jupyter==1.0.0
jupyter-client==5.3.1
jupyter-console==6.0.0
jupyter-core==4.5.0
Keras==2.2.4
Keras-Applications==1.0.8
Keras-Preprocessing==1.1.0
kiwisolver==1.1.0
Markdown==3.1.1
MarkupSafe==1.1.1
matplotlib==3.1.1
mistune==0.8.4
nbconvert==5.5.0
nbformat==4.4.0
notebook==6.0.0
numpy==1.17.0
opencv-python==4.1.0.25
pandocfilters==1.4.2 43 | parso==0.5.1 44 | pexpect==4.7.0 45 | pickleshare==0.7.5 46 | prometheus-client==0.7.1 47 | prompt-toolkit==2.0.9 48 | protobuf==3.9.1 49 | ptyprocess==0.6.0 50 | Pygments==2.4.2 51 | pyparsing==2.4.2 52 | pyrsistent==0.15.4 53 | python-dateutil==2.8.0 54 | PyYAML==5.1.2 55 | pyzmq==18.0.2 56 | qtconsole==4.5.2 57 | scikit-learn==0.21.3 58 | scipy==1.3.0 59 | Send2Trash==1.5.0 60 | six==1.12.0 61 | tensorboard==1.14.0 62 | tensorflow==1.14.0 63 | tensorflow-estimator==1.14.0 64 | termcolor==1.1.0 65 | terminado==0.8.2 66 | testpath==0.4.2 67 | tornado==6.0.3 68 | tqdm==4.33.0 69 | traitlets==4.3.2 70 | wcwidth==0.1.7 71 | webencodings==0.5.1 72 | Werkzeug==0.15.5 73 | widgetsnbextension==3.5.1 74 | wrapt==1.11.2 75 | -------------------------------------------------------------------------------- /run_primates.sh: -------------------------------------------------------------------------------- 1 | python3 gui.py \ 2 | --filename 'video1' \ 3 | --num_masks 120 \ 4 | --names 'Charles,Max,Paul,Alan' \ 5 | --species 'primate' 6 | --------------------------------------------------------------------------------
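`load_mask` switches off Python's cyclic garbage collector while unpickling `SegResults.pkl`. Unpickling allocates many container objects, and those allocations can repeatedly trigger collection passes, so pausing the collector can shorten large loads. Below is a minimal sketch of the same pattern as a reusable context manager that also restores the collector when the file is missing or unreadable. `gc_paused` and `load_pickle_fast` are hypothetical helpers, not part of gui.py.

```python
import gc
import pickle
from contextlib import contextmanager


@contextmanager
def gc_paused():
    """Temporarily disable the cyclic garbage collector, then restore its prior state."""
    was_enabled = gc.isenabled()
    gc.disable()
    try:
        yield
    finally:
        # Re-enable only if it was on before, so nested or pre-disabled uses behave correctly.
        if was_enabled:
            gc.enable()


def load_pickle_fast(path):
    """Unpickle the file at `path` with the collector paused; it is restored even on error."""
    with gc_paused(), open(path, "rb") as handle:
        return pickle.load(handle)
```

With this helper, the body of `load_mask` would reduce to `return load_pickle_fast(video_path + "SegResults.pkl")`. Restoring the previous state, rather than unconditionally calling `gc.enable()`, keeps the helper safe when the caller has already disabled the collector.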