├── LICENSE ├── README.md ├── example.jpg ├── requirements.txt. └── run.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Roi Levy 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Hair-Detection 2 | Hair Mask RCNN using matterport model 3 | ![example](example.jpg) 4 | 5 | ## Set-Up 6 | ### Download requirments:
7 | ```zsh 8 | $ pip install -r requirements.txt 9 | ``` 10 | ### Download and extract dataset and weights directly to repository folder:
11 | [Dataset](https://drive.google.com/file/d/1C-0foSYsKBh1bxp9XRIMXKUO6er4OqZc/view?usp=sharing)
12 | [Weights](https://drive.google.com/file/d/1ZbWTqWLi7w-lVvf7TQ59Gqil_SJnofbE/view?usp=sharing)

13 | 14 | #### Folder path should look like 15 | .
16 | |
17 | |── dataset
18 | │               |── train
19 | │               |── val
20 | |── mask_rcnn_hair_0200.h5
21 | |── run.py
22 |
23 | ## Usage 24 | ### Training:
25 | To train run
26 | ```zsh 27 | $ python3 run.py train --dataset=path/to/dataset --weights=path/to/weights 28 | ``` 29 | For example to start training from the coco model
30 | ```zsh 31 | $ python3 run.py train --dataset=./data/dataset --weights=coco 32 | ``` 33 | Or to continue training from custom weights(for example project weights)
34 | ```zsh 35 | $ python3 run.py train --dataset=./data/dataset --weights=~/proj/Hair-Detection/data/weights/mask_rcnn_hair_0200.h5 36 | ``` 37 | ### Run:
38 | Run this command on an image
39 | ```zsh 40 | $ python3 run.py mask --image=path/to/image --weights=path/to/weights 41 | ``` 42 | The image is saved into the project directory 43 | ## Dataset 44 | [Figaro-1k](https://drive.google.com/file/d/1G7VWeIy2t0yM7bdOeFrf6Eqf6Z_aF0f-/view?usp=sharing): It contains 1050 unconstrained view images with persons, subdivided into seven different hairstyles classes (straight, wavy, curly, kinky, braids, dreadlocks, short), where each image is provided with the related manually segmented hair mask. 45 | -------------------------------------------------------------------------------- /example.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Papich23691/Hair-Detection/0efb9ccd17502eeb11d7dda5e92f45407a972357/example.jpg -------------------------------------------------------------------------------- /requirements.txt.: -------------------------------------------------------------------------------- 1 | absl-py==0.7.1 2 | alabaster==0.7.12 3 | appnope==0.1.0 4 | astor==0.8.0 5 | astroid==2.2.5 6 | attrs==19.1.0 7 | autopep8==1.4.4 8 | Babel==2.7.0 9 | backcall==0.1.0 10 | backports.weakref==1.0.post1 11 | beautifulsoup4==4.8.0 12 | bleach==3.1.0 13 | blis==0.2.4 14 | bs4==0.0.1 15 | certifi==2019.3.9 16 | chardet==3.0.4 17 | Click==7.0 18 | contextlib2==0.6.0 19 | cycler==0.10.0 20 | cymem==2.0.2 21 | Cython==0.29.13 22 | decorator==4.4.0 23 | defusedxml==0.6.0 24 | dlib==19.17.0 25 | docopt==0.6.2 26 | docutils==0.14 27 | en-core-web-lg==2.1.0 28 | en-core-web-sm==2.1.0 29 | entrypoints==0.3 30 | Flask==1.1.1 31 | gast==0.2.2 32 | google-pasta==0.1.7 33 | grpcio==1.21.1 34 | h5py==2.9.0 35 | idna==2.8 36 | imageio==2.5.0 37 | imagesize==1.1.0 38 | imutils==0.5.3 39 | ipykernel==5.1.2 40 | ipython==7.7.0 41 | ipython-genutils==0.2.0 42 | ipywidgets==7.5.1 43 | isort==4.3.20 44 | itsdangerous==1.1.0 45 | jedi==0.14.1 46 | Jinja2==2.10.1 47 | joblib==0.13.2 48 | jsonschema==3.0.1 49 | jupyter-client==5.3.3 50 | jupyter-console==6.0.0 51 | jupyter-core==4.5.0 52 | Keras==2.2.4 53 | Keras-Applications==1.0.7 54 | Keras-Preprocessing==1.0.9 55 | kiwisolver==1.1.0 56 | lazy-object-proxy==1.4.1 57 | lxml==4.4.1 58 | Markdown==3.1.1 59 | MarkupSafe==1.1.1 60 | matplotlib==3.1.1 61 | mccabe==0.6.1 62 | mistune==0.8.4 63 | mock==3.0.5 64 | mrcnn==0.2 65 | murmurhash==1.0.2 66 | nbconvert==5.6.0 67 | nbformat==4.4.0 68 | networkx==2.3 69 | nose==1.3.7 70 | notebook==6.0.1 71 | numpy==1.16.3 72 | opencv-python==4.1.0.25 73 | packaging==19.0 74 | pandas==0.25.0 75 | pandocfilters==1.4.2 76 | parso==0.5.1 77 | pexpect==4.7.0 78 | pickleshare==0.7.5 79 | Pillow==6.1.0 80 | pipreqs==0.4.9 81 | plac==0.9.6 82 | preshed==2.0.1 83 | prometheus-client==0.7.1 84 | prompt-toolkit==2.0.9 85 | protobuf==3.7.1 86 | ptyprocess==0.6.0 87 | pycocotools==2.1 88 | pycodestyle==2.5.0 89 | Pygments==2.4.1 90 | pylint==2.3.1 91 | pyparsing==2.4.0 92 | pyrsistent==0.15.2 93 | python-dateutil==2.8.0 94 | pytz==2019.1 95 | PyWavelets==1.0.3 96 | PyYAML==5.1 97 | pyzmq==18.1.0 98 | qtconsole==4.5.5 99 | requests==2.21.0 100 | scikit-image==0.15.0 101 | scikit-learn==0.21.2 102 | scipy==1.3.0 103 | Send2Trash==1.5.0 104 | six==1.12.0 105 | sklearn==0.0 106 | snowballstemmer==1.2.1 107 | soupsieve==1.9.2 108 | spacy==2.1.4 109 | Sphinx==2.0.1 110 | sphinxcontrib-applehelp==1.0.1 111 | sphinxcontrib-devhelp==1.0.1 112 | sphinxcontrib-htmlhelp==1.0.2 113 | sphinxcontrib-jsmath==1.0.1 114 | sphinxcontrib-qthelp==1.0.2 115 | sphinxcontrib-serializinghtml==1.1.3 116 | srsly==0.0.5 117 | TBB==0.1 118 | tensorboard==1.14.0 119 | tensorflow==1.14.0 120 | tensorflow-cpu==0.0.0 121 | tensorflow-estimator==1.14.0 122 | termcolor==1.1.0 123 | terminado==0.8.2 124 | testpath==0.4.2 125 | tflearn==0.3.2 126 | thinc==7.0.4 127 | tornado==6.0.3 128 | tqdm==4.32.0 129 | traitlets==4.3.2 130 | typed-ast==1.3.5 131 | urllib3==1.24.3 132 | virtualenv==16.6.1 133 | wasabi==0.2.2 134 | wcwidth==0.1.7 135 | webencodings==0.5.1 136 | Werkzeug==0.15.4 137 | widgetsnbextension==3.5.1 138 | wrapt==1.11.1 139 | xmltodict==0.12.0 140 | yarg==0.1.9 141 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2017 Matterport, Inc. 3 | Licensed under the MIT License (see LICENSE for details) 4 | Written by Waleed Abdulla 5 | 6 | # Train a new model starting from pre-trained COCO weights 7 | python3 run.py train --dataset=/path/to/dataset --weights=coco 8 | # Resume training a model that you had trained earlier 9 | python3 run.py train --dataset=/path/to/dataset --weights=last 10 | # Train a new model starting from ImageNet weights 11 | python3 run.py train --dataset=/path/to/dataset --weights=imagenet 12 | # Apply color mask to an image 13 | python3 run.py mask --weights=/path/to/weights/file.h5 --image= 14 | # Apply color mask to video using the last weights you trained 15 | python3 run.py mask --weights=last --video= 16 | """ 17 | 18 | import os 19 | import sys 20 | import json 21 | import datetime 22 | import numpy as np 23 | import skimage.draw 24 | import cv2 25 | from mrcnn.visualize import display_instances 26 | import matplotlib.pyplot as plt 27 | 28 | # Root directory of the project 29 | ROOT_DIR = os.path.abspath("./") 30 | 31 | # Import Mask RCNN 32 | sys.path.append(ROOT_DIR) # To find local version of the library 33 | from mrcnn.config import Config 34 | from mrcnn import model as modellib, utils 35 | 36 | # Path to trained weights file 37 | COCO_WEIGHTS_PATH = os.path.join(ROOT_DIR, "./mask_rcnn_coco.h5") 38 | 39 | # Directory to save logs and model checkpoints, if not provided 40 | # through the command line argument --logs 41 | DEFAULT_LOGS_DIR = os.path.join(ROOT_DIR, "logs") 42 | 43 | ############################################################ 44 | # Configurations 45 | ############################################################ 46 | 47 | 48 | class CustomConfig(Config): 49 | """Configuration for training on the Hair dataset. 50 | Derives from the base Config class and overrides some values. 51 | """ 52 | # Give the configuration a recognizable name 53 | NAME = "Hair" 54 | 55 | # Running on CPU 56 | IMAGES_PER_GPU = 1 57 | 58 | # Number of classes (including background) 59 | NUM_CLASSES = 1 + 1 # Background + Hair 60 | 61 | # Number of training steps per epoch 62 | STEPS_PER_EPOCH = 10 63 | 64 | # Skip detections with < 90% confidence 65 | DETECTION_MIN_CONFIDENCE = 0.9 66 | 67 | 68 | ############################################################ 69 | # Dataset 70 | ############################################################ 71 | 72 | class CustomDataset(utils.Dataset): 73 | def load_custom(self, dataset_dir, subset): 74 | # Add classes. We have only one class to add. 75 | self.add_class("Hair", 1, "Hair") 76 | dataset_dir = os.path.join(dataset_dir, subset) 77 | for filename in os.listdir(os.path.join(dataset_dir, 'photos')): 78 | if not filename.endswith('jpg'): #Only jpg photos from dataset 79 | continue 80 | input_path = os.path.join(dataset_dir, 'photos',filename) 81 | img = cv2.imread(input_path) 82 | height, width = img.shape[:2] 83 | 84 | self.add_image( 85 | "Hair", # for a single class just add the name here 86 | image_id=filename, # use file name as a unique image id 87 | path= input_path, 88 | width=width, height=height) 89 | 90 | def load_mask(self,image_id): 91 | """Generate instance masks for an image from database. 92 | Returns: 93 | masks: A bool array of shape [height, width, instance count] with 94 | one mask per instance. 95 | """ 96 | image_info = self.image_info[image_id] 97 | if image_info["source"] != "Hair": 98 | return super(self.__class__, self).load_mask(image_id) 99 | 100 | info = self.image_info[image_id] 101 | if info['id'].startswith('v'): # If validation mask or training 102 | dataset_dir = os.path.join(args.dataset, 'val') 103 | else: 104 | dataset_dir = os.path.join(args.dataset, 'train') 105 | for maskf in os.listdir(os.path.join(dataset_dir, 'masks')): 106 | mname,png = os.path.splitext(maskf) 107 | iname,jpg = os.path.splitext(info['id']) 108 | if mname == iname: 109 | mask=cv2.imread(os.path.join(dataset_dir, 'masks',maskf)) # Reading mask data from dataset 110 | 111 | # Return mask, and array of class IDs of each instance. Since we have 112 | # one class ID only, we return an array of 1s 113 | return mask.astype(np.bool), np.ones([mask.shape[-1]], dtype=np.int32) 114 | 115 | def image_reference(self, image_id): 116 | """Return the path of the image.""" 117 | info = self.image_info[image_id] 118 | if info["source"] == "Hair": 119 | return info["path"] 120 | else: 121 | super(self.__class__, self).image_reference(image_id) 122 | 123 | 124 | def train(model): 125 | """Train the model.""" 126 | # Training dataset. 127 | dataset_train = CustomDataset() 128 | dataset_train.load_custom(args.dataset, "train") 129 | dataset_train.prepare() 130 | 131 | # Validation dataset 132 | dataset_val = CustomDataset() 133 | dataset_val.load_custom(args.dataset, "val") 134 | dataset_val.prepare() 135 | 136 | # *** This training schedule is an example. Update to your needs *** 137 | 138 | # Training - Stage 1 139 | print("Training network heads") 140 | model.train(dataset_train, dataset_val, 141 | learning_rate=config.LEARNING_RATE, 142 | epochs=60, 143 | layers='heads') 144 | 145 | # Training - Stage 2 146 | # Finetune layers from ResNet stage 4 and up 147 | print("Fine tune Resnet stage 4 and up") 148 | model.train(dataset_train, dataset_val, 149 | learning_rate=config.LEARNING_RATE, 150 | epochs=140, 151 | layers='4+') 152 | 153 | # Training - Stage 3 154 | # Fine tune all layers 155 | print("Fine tune all layers") 156 | model.train(dataset_train, dataset_val, 157 | learning_rate=config.LEARNING_RATE / 10, 158 | epochs=220, 159 | layers='all') 160 | 161 | 162 | def apply_mask(image, mask): 163 | """Apply color splash effect. 164 | image: RGB image [height, width, 3] 165 | mask: instance segmentation mask [height, width, instance count] 166 | Returns result image. 167 | """ 168 | # Make a grayscale copy of the image. The grayscale copy still 169 | # has 3 RGB channels, though. 170 | blank = np.zeros(image.shape, dtype=np.uint8) 171 | # Copy color pixels from the original color image where mask is set 172 | if mask.shape[-1] > 0: 173 | # We're treating all instances as one, so collapse the mask into one layer 174 | mask = (np.sum(mask, -1, keepdims=True) >= 1) 175 | crop = np.where(mask, image, blank).astype(np.uint8) 176 | else: 177 | crop = blank.astype(np.uint8) 178 | return crop 179 | 180 | 181 | def detect_and_mask(model, image_path=None, video_path=None): 182 | assert image_path or video_path 183 | 184 | if image_path: 185 | import cv2 186 | # Run model detection and generate the mask 187 | print("Running on {}".format(args.image)) 188 | # Read image 189 | image = cv2.imread(args.image) 190 | # Detect objects 191 | image = cv2.cvtColor(image,cv2.COLOR_BGR2RGB) 192 | r = model.detect([image], verbose=1)[0] 193 | # Mask 194 | crop = apply_mask(image, r['masks']) 195 | # Save output 196 | crop = cv2.cvtColor(crop,cv2.COLOR_RGB2BGR) 197 | file_name = "crop_{:%Y%m%dT%H%M%S}.png".format(datetime.datetime.now()) 198 | cv2.imwrite(file_name,crop) 199 | elif video_path: 200 | import cv2 201 | # Video capture 202 | vcapture = cv2.VideoCapture(video_path) 203 | width = int(vcapture.get(cv2.CAP_PROP_FRAME_WIDTH)) 204 | height = int(vcapture.get(cv2.CAP_PROP_FRAME_HEIGHT)) 205 | fps = vcapture.get(cv2.CAP_PROP_FPS) 206 | 207 | # Define codec and create video writer 208 | file_name = "crop_{:%Y%m%dT%H%M%S}.avi".format(datetime.datetime.now()) 209 | vwriter = cv2.VideoWriter(file_name, 210 | cv2.VideoWriter_fourcc(*'MJPG'), 211 | fps, (width, height)) 212 | 213 | count = 0 214 | success = True 215 | while success: 216 | print("frame: ", count) 217 | # Read next image 218 | success, image = vcapture.read() 219 | if success: 220 | # OpenCV returns images as BGR, convert to RGB 221 | image = image[..., ::-1] 222 | # Detect objects 223 | r = model.detect([image], verbose=0)[0] 224 | # Color splash 225 | crop = apply_mask(image, r['masks']) 226 | # RGB -> BGR to save image to video 227 | crop = crop[..., ::-1] 228 | # Add image to video writer 229 | vwriter.write(crop) 230 | count += 1 231 | vwriter.release() 232 | print("Saved to ", file_name) 233 | 234 | ############################################################ 235 | # Training 236 | ############################################################ 237 | 238 | if __name__ == '__main__': 239 | import argparse 240 | 241 | # Parse command line arguments 242 | parser = argparse.ArgumentParser( 243 | description='Train Mask R-CNN to detect custom class.') 244 | parser.add_argument("command", 245 | metavar="", 246 | help="'train' or 'mask'") 247 | parser.add_argument('--dataset', required=False, 248 | metavar="/path/to/custom/dataset/", 249 | help='Directory of the custom dataset') 250 | parser.add_argument('--weights', required=True, 251 | metavar="/path/to/weights.h5", 252 | help="Path to weights .h5 file or 'coco'") 253 | parser.add_argument('--logs', required=False, 254 | default=DEFAULT_LOGS_DIR, 255 | metavar="/path/to/logs/", 256 | help='Logs and checkpoints directory (default=logs/)') 257 | parser.add_argument('--image', required=False, 258 | metavar="path or URL to image", 259 | help='Image to apply the color mask effect on') 260 | parser.add_argument('--video', required=False, 261 | metavar="path or URL to video", 262 | help='Video to apply the color mask effect on') 263 | args = parser.parse_args() 264 | 265 | # Validate arguments 266 | if args.command == "train": 267 | assert args.dataset, "Argument --dataset is required for training" 268 | elif args.command == "mask": 269 | assert args.image or args.video,\ 270 | "Provide --image or --video to apply color mask" 271 | 272 | print("Weights: ", args.weights) 273 | print("Dataset: ", args.dataset) 274 | print("Logs: ", args.logs) 275 | 276 | # Configurations 277 | if args.command == "train": 278 | config = CustomConfig() 279 | else: 280 | class InferenceConfig(CustomConfig): 281 | # Set batch size to 1 since we'll be running inference on 282 | # one image at a time. Batch size = GPU_COUNT * IMAGES_PER_GPU 283 | GPU_COUNT = 1 284 | IMAGES_PER_GPU = 1 285 | config = InferenceConfig() 286 | config.display() 287 | 288 | # Create model 289 | if args.command == "train": 290 | model = modellib.MaskRCNN(mode="training", config=config, 291 | model_dir=args.logs) 292 | else: 293 | model = modellib.MaskRCNN(mode="inference", config=config, 294 | model_dir=args.logs) 295 | 296 | # Select weights file to load 297 | if args.weights.lower() == "coco": 298 | weights_path = COCO_WEIGHTS_PATH 299 | # Download weights file 300 | if not os.path.exists(weights_path): 301 | utils.download_trained_weights(weights_path) 302 | elif args.weights.lower() == "last": 303 | # Find last trained weights 304 | weights_path = model.find_last()[1] 305 | elif args.weights.lower() == "imagenet": 306 | # Start from ImageNet trained weights 307 | weights_path = model.get_imagenet_weights() 308 | else: 309 | weights_path = args.weights 310 | 311 | # Load weights 312 | print("Loading weights ", weights_path) 313 | if args.weights.lower() == "coco": 314 | # Exclude the last layers because they require a matching 315 | # number of classes 316 | model.load_weights(weights_path, by_name=True, exclude=[ 317 | "mrcnn_class_logits", "mrcnn_bbox_fc", 318 | "mrcnn_bbox", "mrcnn_mask"]) 319 | else: 320 | model.load_weights(weights_path, by_name=True) 321 | 322 | # Train or evaluate 323 | if args.command == "train": 324 | train(model) 325 | elif args.command == "mask": 326 | detect_and_mask(model, image_path=args.image, 327 | video_path=args.video) 328 | else: 329 | print("'{}' is not recognized. " 330 | "Use 'train' or 'mask'".format(args.command)) 331 | --------------------------------------------------------------------------------